Skip to content

Commit 0f0b9fc

Browse files
changes from reviews
1 parent ec4b75d commit 0f0b9fc

1 file changed

Lines changed: 106 additions & 41 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Lines changed: 106 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray,Tuple}}
22

3-
# nested vectors, similar are not supported
3+
# nested vectors (eg. [[1.0]]), Tuples (eg. ((1.0,),)) or similar (eg. [(1.0,)]) primal types are not supported by DI yet !
4+
# This is because basis construction (DI.basis) does not have overloads for these types.
5+
# For details, refer commented out test cases to see where the pullback creation fails.
46
function Mooncake.rrule!!(
57
dw::CoDual{<:DI.DifferentiateWith}, x::Union{CoDual{<:Number},CoDual{<:Tuple}}
68
)
@@ -45,7 +47,9 @@ function Mooncake.rrule!!(
4547
elseif typeof(primal(y)) <: Tuple
4648
pullback_tuple!!
4749
else
48-
error("$(typeof(primal(y))) primal type currently not supported.")
50+
error(
51+
"For the function type $(typeof(primal_func)) and input type $(typeof(primal_x)), the primal type $(typeof(primal(y))) is currently not supported.",
52+
)
4953
end
5054

5155
return y, pullback
@@ -97,7 +101,9 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra
97101
elseif typeof(primal(y)) <: Tuple
98102
pullback_tuple!!
99103
else
100-
error("$(typeof(primal(y))) primal type currently not supported.")
104+
error(
105+
"For the function type $(typeof(primal_func)) and input type $(typeof(primal_x)), the primal type $(typeof(primal(y))) is currently not supported.",
106+
)
101107
end
102108

103109
return y, pullback
@@ -113,40 +119,37 @@ function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diff
113119
map([(x) -> DI.DifferentiateWith(x, DI.AutoFiniteDiff())]) do F
114120
map([Float64, Float32]) do P
115121
return Any[
116-
(false, :stability, nothing, F(cosh), P(0.3)),
117-
(false, :stability, nothing, F(sinh), P(0.3)),
118-
(false, :stability, nothing, F(Base.FastMath.exp10_fast), P(0.5)),
119-
(false, :stability, nothing, F(Base.FastMath.exp2_fast), P(0.5)),
120-
(false, :stability, nothing, F(Base.FastMath.exp_fast), P(5.0)),
121-
(false, :none, nothing, F(copy), rand(Int32, 5)),
122+
# (false, :none, nothing, F(identity), ((1.0,),)), # (DI.basis fails for this, correct it!)
123+
# (false, :none, nothing, F(identity), [[1.0]]), # (DI.basis fails for this, correct it!)
124+
(false, :stability_and_allocs, nothing, F(cosh), P(0.3)),
125+
(false, :stability_and_allocs, nothing, F(sinh), P(0.3)),
126+
(
127+
false,
128+
:stability_and_allocs,
129+
nothing,
130+
F(Base.FastMath.exp10_fast),
131+
P(0.5),
132+
),
133+
(
134+
false,
135+
:stability_and_allocs,
136+
nothing,
137+
F(Base.FastMath.exp2_fast),
138+
P(0.5),
139+
),
140+
(
141+
false,
142+
:stability_and_allocs,
143+
nothing,
144+
F(Base.FastMath.exp_fast),
145+
P(5.0),
146+
),
147+
(false, :stability, nothing, F(copy), rand(Int32, 5)),
122148
]
123149
end
124150
end...,
125151
)
126152

127-
map([(x) -> DI.DifferentiateWith(x, DI.AutoZygote())]) do F
128-
map([Float64, Float32]) do P
129-
push!(
130-
test_cases,
131-
Any[
132-
(false, :stability, nothing, F(Base.FastMath.sincos), P(3.0)),
133-
(false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[P(1.0)]),
134-
]...,
135-
)
136-
end
137-
end
138-
139-
map([(x) -> DI.DifferentiateWith(x, DI.AutoZygote())]) do F
140-
push!(
141-
test_cases,
142-
Any[
143-
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctlz_int), 5),
144-
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctpop_int), 5),
145-
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.cttz_int), 5),
146-
]...,
147-
)
148-
end
149-
150153
map([(x) -> DI.DifferentiateWith(x, DI.AutoFiniteDiff())]) do F
151154
push!(
152155
test_cases,
@@ -155,14 +158,14 @@ function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diff
155158
(
156159
# Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent.
157160
false,
158-
:none,
161+
:stability,
159162
nothing,
160163
F(x -> +(x...)),
161164
randn(33),
162165
),
163166
(
164167
false,
165-
:none,
168+
:stability,
166169
nothing,
167170
(F(
168171
function (x)
@@ -174,19 +177,36 @@ function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diff
174177
)),
175178
5.0,
176179
),
177-
(false, :none, nothing, F(Mooncake.__vec_to_tuple), [1.0]),
178-
# (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[(1.0,)]), DI.basis fails for this, correct it!
179-
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctlz_int), 5),
180-
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctpop_int), 5),
181-
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.cttz_int), 5),
180+
# (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[(1.0,)]), # (DI.basis fails for this, correct it!)
182181
(
183182
false,
184-
:stability,
183+
:stability_and_allocs,
184+
nothing,
185+
F(Mooncake.IntrinsicsWrappers.ctlz_int),
186+
5,
187+
),
188+
(
189+
false,
190+
:stability_and_allocs,
191+
nothing,
192+
F(Mooncake.IntrinsicsWrappers.ctpop_int),
193+
5,
194+
),
195+
(
196+
false,
197+
:stability_and_allocs,
198+
nothing,
199+
F(Mooncake.IntrinsicsWrappers.cttz_int),
200+
5,
201+
),
202+
(
203+
false,
204+
:stability_and_allocs,
185205
nothing,
186206
F(Mooncake.IntrinsicsWrappers.abs_float),
187207
5.0f0,
188208
),
189-
(false, :stability, nothing, F(deepcopy), 5.0),
209+
(false, :stability_and_allocs, nothing, F(deepcopy), 5.0),
190210
(false, :stability, nothing, F(deepcopy), randn(5)),
191211
(false, :stability_and_allocs, nothing, F(sin), 1.1),
192212
(false, :stability_and_allocs, nothing, F(sin), 1.0f1),
@@ -198,6 +218,51 @@ function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diff
198218
)
199219
end
200220

221+
map([(x) -> DI.DifferentiateWith(x, DI.AutoForwardDiff())]) do F
222+
map([Float64, Float32]) do P
223+
push!(
224+
test_cases,
225+
Any[
226+
(
227+
false,
228+
:stability_and_allocs,
229+
nothing,
230+
F(Base.FastMath.sincos),
231+
P(3.0),
232+
),
233+
(false, :none, nothing, F(Mooncake.__vec_to_tuple), [P(1.0)]),
234+
]...,
235+
)
236+
end
237+
238+
push!(
239+
test_cases,
240+
Any[
241+
(
242+
false,
243+
:stability_and_allocs,
244+
nothing,
245+
F(Mooncake.IntrinsicsWrappers.ctlz_int),
246+
5,
247+
),
248+
(
249+
false,
250+
:stability_and_allocs,
251+
nothing,
252+
F(Mooncake.IntrinsicsWrappers.ctpop_int),
253+
5,
254+
),
255+
(
256+
false,
257+
:stability_and_allocs,
258+
nothing,
259+
F(Mooncake.IntrinsicsWrappers.cttz_int),
260+
5,
261+
),
262+
]...,
263+
)
264+
end
265+
201266
memory = Any[]
202267
return test_cases, memory
203268
end

0 commit comments

Comments
 (0)