11@is_primitive MinimalCtx Tuple{DI. DifferentiateWith,<: Union{Number,AbstractArray,Tuple} }
22
3- # nested vectors, similar are not supported
3+ # nested vectors (eg. [[1.0]]), Tuples (eg. ((1.0,),)) or similar (eg. [(1.0,)]) primal types are not supported by DI yet !
4+ # This is because basis construction (DI.basis) does not have overloads for these types.
5+ # For details, refer commented out test cases to see where the pullback creation fails.
46function Mooncake. rrule!! (
57 dw:: CoDual{<:DI.DifferentiateWith} , x:: Union{CoDual{<:Number},CoDual{<:Tuple}}
68)
@@ -45,7 +47,9 @@ function Mooncake.rrule!!(
4547 elseif typeof (primal (y)) <: Tuple
4648 pullback_tuple!!
4749 else
48- error (" $(typeof (primal (y))) primal type currently not supported." )
50+ error (
51+ " For the function type $(typeof (primal_func)) and input type $(typeof (primal_x)) , the primal type $(typeof (primal (y))) is currently not supported." ,
52+ )
4953 end
5054
5155 return y, pullback
@@ -97,7 +101,9 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra
97101 elseif typeof (primal (y)) <: Tuple
98102 pullback_tuple!!
99103 else
100- error (" $(typeof (primal (y))) primal type currently not supported." )
104+ error (
105+ " For the function type $(typeof (primal_func)) and input type $(typeof (primal_x)) , the primal type $(typeof (primal (y))) is currently not supported." ,
106+ )
101107 end
102108
103109 return y, pullback
@@ -113,40 +119,37 @@ function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diff
113119 map ([(x) -> DI. DifferentiateWith (x, DI. AutoFiniteDiff ())]) do F
114120 map ([Float64, Float32]) do P
115121 return Any[
116- (false , :stability , nothing , F (cosh), P (0.3 )),
117- (false , :stability , nothing , F (sinh), P (0.3 )),
118- (false , :stability , nothing , F (Base. FastMath. exp10_fast), P (0.5 )),
119- (false , :stability , nothing , F (Base. FastMath. exp2_fast), P (0.5 )),
120- (false , :stability , nothing , F (Base. FastMath. exp_fast), P (5.0 )),
121- (false , :none , nothing , F (copy), rand (Int32, 5 )),
122+ # (false, :none, nothing, F(identity), ((1.0,),)), # (DI.basis fails for this, correct it!)
123+ # (false, :none, nothing, F(identity), [[1.0]]), # (DI.basis fails for this, correct it!)
124+ (false , :stability_and_allocs , nothing , F (cosh), P (0.3 )),
125+ (false , :stability_and_allocs , nothing , F (sinh), P (0.3 )),
126+ (
127+ false ,
128+ :stability_and_allocs ,
129+ nothing ,
130+ F (Base. FastMath. exp10_fast),
131+ P (0.5 ),
132+ ),
133+ (
134+ false ,
135+ :stability_and_allocs ,
136+ nothing ,
137+ F (Base. FastMath. exp2_fast),
138+ P (0.5 ),
139+ ),
140+ (
141+ false ,
142+ :stability_and_allocs ,
143+ nothing ,
144+ F (Base. FastMath. exp_fast),
145+ P (5.0 ),
146+ ),
147+ (false , :stability , nothing , F (copy), rand (Int32, 5 )),
122148 ]
123149 end
124150 end ... ,
125151 )
126152
127- map ([(x) -> DI. DifferentiateWith (x, DI. AutoZygote ())]) do F
128- map ([Float64, Float32]) do P
129- push! (
130- test_cases,
131- Any[
132- (false , :stability , nothing , F (Base. FastMath. sincos), P (3.0 )),
133- (false , :none , nothing , F (Mooncake. __vec_to_tuple), Any[P (1.0 )]),
134- ]. .. ,
135- )
136- end
137- end
138-
139- map ([(x) -> DI. DifferentiateWith (x, DI. AutoZygote ())]) do F
140- push! (
141- test_cases,
142- Any[
143- (false , :stability , nothing , F (Mooncake. IntrinsicsWrappers. ctlz_int), 5 ),
144- (false , :stability , nothing , F (Mooncake. IntrinsicsWrappers. ctpop_int), 5 ),
145- (false , :stability , nothing , F (Mooncake. IntrinsicsWrappers. cttz_int), 5 ),
146- ]. .. ,
147- )
148- end
149-
150153 map ([(x) -> DI. DifferentiateWith (x, DI. AutoFiniteDiff ())]) do F
151154 push! (
152155 test_cases,
@@ -155,14 +158,14 @@ function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diff
155158 (
156159 # Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent.
157160 false ,
158- :none ,
161+ :stability ,
159162 nothing ,
160163 F (x -> + (x... )),
161164 randn (33 ),
162165 ),
163166 (
164167 false ,
165- :none ,
168+ :stability ,
166169 nothing ,
167170 (F (
168171 function (x)
@@ -174,19 +177,36 @@ function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diff
174177 )),
175178 5.0 ,
176179 ),
177- (false , :none , nothing , F (Mooncake. __vec_to_tuple), [1.0 ]),
178- # (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[(1.0,)]), DI.basis fails for this, correct it!
179- (false , :stability , nothing , F (Mooncake. IntrinsicsWrappers. ctlz_int), 5 ),
180- (false , :stability , nothing , F (Mooncake. IntrinsicsWrappers. ctpop_int), 5 ),
181- (false , :stability , nothing , F (Mooncake. IntrinsicsWrappers. cttz_int), 5 ),
180+ # (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[(1.0,)]), # (DI.basis fails for this, correct it!)
182181 (
183182 false ,
184- :stability ,
183+ :stability_and_allocs ,
184+ nothing ,
185+ F (Mooncake. IntrinsicsWrappers. ctlz_int),
186+ 5 ,
187+ ),
188+ (
189+ false ,
190+ :stability_and_allocs ,
191+ nothing ,
192+ F (Mooncake. IntrinsicsWrappers. ctpop_int),
193+ 5 ,
194+ ),
195+ (
196+ false ,
197+ :stability_and_allocs ,
198+ nothing ,
199+ F (Mooncake. IntrinsicsWrappers. cttz_int),
200+ 5 ,
201+ ),
202+ (
203+ false ,
204+ :stability_and_allocs ,
185205 nothing ,
186206 F (Mooncake. IntrinsicsWrappers. abs_float),
187207 5.0f0 ,
188208 ),
189- (false , :stability , nothing , F (deepcopy), 5.0 ),
209+ (false , :stability_and_allocs , nothing , F (deepcopy), 5.0 ),
190210 (false , :stability , nothing , F (deepcopy), randn (5 )),
191211 (false , :stability_and_allocs , nothing , F (sin), 1.1 ),
192212 (false , :stability_and_allocs , nothing , F (sin), 1.0f1 ),
@@ -198,6 +218,51 @@ function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diff
198218 )
199219 end
200220
221+ map ([(x) -> DI. DifferentiateWith (x, DI. AutoForwardDiff ())]) do F
222+ map ([Float64, Float32]) do P
223+ push! (
224+ test_cases,
225+ Any[
226+ (
227+ false ,
228+ :stability_and_allocs ,
229+ nothing ,
230+ F (Base. FastMath. sincos),
231+ P (3.0 ),
232+ ),
233+ (false , :none , nothing , F (Mooncake. __vec_to_tuple), [P (1.0 )]),
234+ ]. .. ,
235+ )
236+ end
237+
238+ push! (
239+ test_cases,
240+ Any[
241+ (
242+ false ,
243+ :stability_and_allocs ,
244+ nothing ,
245+ F (Mooncake. IntrinsicsWrappers. ctlz_int),
246+ 5 ,
247+ ),
248+ (
249+ false ,
250+ :stability_and_allocs ,
251+ nothing ,
252+ F (Mooncake. IntrinsicsWrappers. ctpop_int),
253+ 5 ,
254+ ),
255+ (
256+ false ,
257+ :stability_and_allocs ,
258+ nothing ,
259+ F (Mooncake. IntrinsicsWrappers. cttz_int),
260+ 5 ,
261+ ),
262+ ]. .. ,
263+ )
264+ end
265+
201266 memory = Any[]
202267 return test_cases, memory
203268end
0 commit comments