Skip to content

Commit b4fe0f8

Browse files
tests.
1 parent c389a80 commit b4fe0f8

2 files changed

Lines changed: 4 additions & 4 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray,Tuple}}
22

33
function Mooncake.rrule!!(
4-
dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{Union{<:Number,<:Tuple}}
4+
dw::CoDual{<:DI.DifferentiateWith}, x::Union{CoDual{<:Number},CoDual{<:Tuple}}
55
)
66
primal_func = primal(dw)
77
primal_x = primal(x)
@@ -82,11 +82,11 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra
8282
return y, pullback
8383
end
8484

85-
function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:diffwith})
85+
function Mooncake.generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:diffwith})
8686
return Any[], Any[]
8787
end
8888

89-
function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diffwith})
89+
function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diffwith})
9090
test_cases = reduce(
9191
vcat,
9292
map([Float64, Float32]) do P

DifferentiationInterface/test/Back/DifferentiateWith/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@ test_differentiation(
3333

3434
@testset "new" begin
3535
Mooncake.TestUtils.run_rrule!!_test_cases(StableRNG, Val(:diffwith))
36-
end
36+
end

0 commit comments

Comments
 (0)