1- @is_primitive MinimalCtx Tuple{DI. DifferentiateWith,<: Union{Number,AbstractArray,Tuple} }
1+ @is_primitive MinimalCtx Tuple{DI. DifferentiateWith,<: Any }
2+
3+ struct MooncakeDifferentiateWithError <: Exception
4+ F:: Type
5+ X:: Type
6+ Y:: Type
7+ function MooncakeDifferentiateWithError (:: F , :: X , :: Y ) where {F,X,Y}
8+ return new (F, X, Y)
9+ end
10+ end
11+
12+ function Base. showerror (io:: IO , e:: MooncakeDifferentiateWithError )
13+ return print (
14+ io,
15+ " MooncakeDifferentiateWithError: For the function type $(e. F) and input type $(e. X) , the output type $(e. Y) is currently not supported." ,
16+ )
17+ end
218
3- # nested vectors (eg. [[1.0]]), Tuples (eg. ((1.0,),)) or similar (eg. [(1.0,)]) primal types are not supported by DI yet !
4- # This is because basis construction (DI.basis) does not have overloads for these types.
519# For details, refer commented out test cases to see where the pullback creation fails.
6- function Mooncake. rrule!! (
7- dw:: CoDual{<:DI.DifferentiateWith} , x:: Union{CoDual{<:Number},CoDual{<:Tuple}}
8- )
20+ function Mooncake. rrule!! (dw:: CoDual{<:DI.DifferentiateWith} , x:: CoDual{<:Number} )
921 primal_func = primal (dw)
1022 primal_x = primal (x)
1123 (; f, backend) = primal_func
@@ -25,31 +37,12 @@ function Mooncake.rrule!!(
2537 return NoRData (), rdata (only (tx))
2638 end
2739
28- # output is a Tuple, NTuple
29- function pullback_tuple!! (dy:: Tuple )
30- tx = DI. pullback (f, backend, primal_x, (dy,))
31- @assert rdata (only (tx)) isa rdata_type (tangent_type (typeof (primal_x)))
32- return NoRData (), rdata (only (tx))
33- end
34-
35- # inputs are non Differentiable
36- function pullback_nodiff!! (dy:: NoRData )
37- @assert tangent_type (typeof (primal (x))) <: NoTangent
38- return NoRData (), dy
39- end
40-
41- pullback = if tangent_type (typeof (primal (x))) <: NoTangent
42- pullback_nodiff!!
43- elseif primal (y) isa Number
40+ pullback = if primal (y) isa Number
4441 pullback_scalar!!
45- elseif primal (y) <: AbstractArray
42+ elseif primal (y) isa AbstractArray
4643 pullback_array!!
47- elseif primal (y) <: Tuple
48- pullback_tuple!!
4944 else
50- error (
51- " For the function type $(typeof (primal_func)) and input type $(typeof (primal_x)) , the primal type $(typeof (primal (y))) is currently not supported." ,
52- )
45+ throw (MooncakeDifferentiateWithError (primal_func, primal_x, primal (y)))
5346 end
5447
5548 return y, pullback
@@ -78,191 +71,13 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra
7871 return NoRData (), NoRData ()
7972 end
8073
81- # output is a Tuple, NTuple
82- function pullback_tuple!! (dy:: Tuple )
83- tx = DI. pullback (f, backend, primal_x, (dy,))
84- @assert rdata (first (only (tx))) isa rdata_type (tangent_type (typeof (first (primal_x))))
85- fdata_arg .+ = only (tx)
86- return NoRData (), NoRData ()
87- end
88-
89- # inputs are non Differentiable
90- function pullback_nodiff!! (dy:: NoRData )
91- @assert tangent_type (typeof (primal (x))) <: Vector{NoTangent}
92- return NoRData (), dy
93- end
94-
95- pullback = if tangent_type (typeof (primal (x))) <: Vector{NoTangent}
96- pullback_nodiff!!
97- elseif primal (y) isa Number
74+ pullback = if primal (y) isa Number
9875 pullback_scalar!!
99- elseif primal (y) <: AbstractArray
76+ elseif primal (y) isa AbstractArray
10077 pullback_array!!
101- elseif primal (y) <: Tuple
102- pullback_tuple!!
10378 else
104- error (
105- " For the function type $(typeof (primal_func)) and input type $(typeof (primal_x)) , the primal type $(typeof (primal (y))) is currently not supported." ,
106- )
79+ throw (MooncakeDifferentiateWithError (primal_func, primal_x, primal (y)))
10780 end
10881
10982 return y, pullback
11083end
111-
112- function Mooncake. generate_derived_rrule!!_test_cases (rng_ctor, :: Val{:diffwith} )
113- return Any[], Any[]
114- end
115-
116- function Mooncake. generate_hand_written_rrule!!_test_cases (rng_ctor, :: Val{:diffwith} )
117- test_cases = reduce (
118- vcat,
119- map ([(x) -> DI. DifferentiateWith (x, DI. AutoFiniteDiff ())]) do F
120- map ([Float64, Float32]) do P
121- return Any[
122- # (false, :none, nothing, F(identity), ((1.0,),)), # (DI.basis fails for this, correct it!)
123- # (false, :none, nothing, F(identity), [[1.0]]), # (DI.basis fails for this, correct it!)
124- (false , :stability_and_allocs , nothing , F (cosh), P (0.3 )),
125- (false , :stability_and_allocs , nothing , F (sinh), P (0.3 )),
126- (
127- false ,
128- :stability_and_allocs ,
129- nothing ,
130- F (Base. FastMath. exp10_fast),
131- P (0.5 ),
132- ),
133- (
134- false ,
135- :stability_and_allocs ,
136- nothing ,
137- F (Base. FastMath. exp2_fast),
138- P (0.5 ),
139- ),
140- (
141- false ,
142- :stability_and_allocs ,
143- nothing ,
144- F (Base. FastMath. exp_fast),
145- P (5.0 ),
146- ),
147- (false , :stability , nothing , F (copy), rand (Int32, 5 )),
148- ]
149- end
150- end ... ,
151- )
152-
153- map ([(x) -> DI. DifferentiateWith (x, DI. AutoFiniteDiff ())]) do F
154- push! (
155- test_cases,
156- Any[
157- (false , :stability , nothing , copy, randn (5 , 4 )),
158- (
159- # Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent.
160- false ,
161- :stability ,
162- nothing ,
163- F (x -> + (x... )),
164- randn (33 ),
165- ),
166- (
167- false ,
168- :stability ,
169- nothing ,
170- (F (
171- function (x)
172- rx = Ref (x)
173- return Base. pointerref (
174- Base. bitcast (Ptr{Float64}, pointer_from_objref (rx)), 1 , 1
175- )
176- end ,
177- )),
178- 5.0 ,
179- ),
180- # (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[(1.0,)]), # (DI.basis fails for this, correct it!)
181- (
182- false ,
183- :stability_and_allocs ,
184- nothing ,
185- F (Mooncake. IntrinsicsWrappers. ctlz_int),
186- 5 ,
187- ),
188- (
189- false ,
190- :stability_and_allocs ,
191- nothing ,
192- F (Mooncake. IntrinsicsWrappers. ctpop_int),
193- 5 ,
194- ),
195- (
196- false ,
197- :stability_and_allocs ,
198- nothing ,
199- F (Mooncake. IntrinsicsWrappers. cttz_int),
200- 5 ,
201- ),
202- (
203- false ,
204- :stability_and_allocs ,
205- nothing ,
206- F (Mooncake. IntrinsicsWrappers. abs_float),
207- 5.0f0 ,
208- ),
209- (false , :stability_and_allocs , nothing , F (deepcopy), 5.0 ),
210- (false , :stability , nothing , F (deepcopy), randn (5 )),
211- (false , :stability_and_allocs , nothing , F (sin), 1.1 ),
212- (false , :stability_and_allocs , nothing , F (sin), 1.0f1 ),
213- (false , :stability_and_allocs , nothing , F (cos), 1.1 ),
214- (false , :stability_and_allocs , nothing , F (cos), 1.0f1 ),
215- (false , :stability_and_allocs , nothing , F (exp), 1.1 ),
216- (false , :stability_and_allocs , nothing , F (exp), 1.0f1 ),
217- ]. .. ,
218- )
219- end
220-
221- map ([(x) -> DI. DifferentiateWith (x, DI. AutoForwardDiff ())]) do F
222- map ([Float64, Float32]) do P
223- push! (
224- test_cases,
225- Any[
226- (
227- false ,
228- :stability_and_allocs ,
229- nothing ,
230- F (Base. FastMath. sincos),
231- P (3.0 ),
232- ),
233- (false , :none , nothing , F (Mooncake. __vec_to_tuple), [P (1.0 )]),
234- ]. .. ,
235- )
236- end
237-
238- push! (
239- test_cases,
240- Any[
241- (
242- false ,
243- :stability_and_allocs ,
244- nothing ,
245- F (Mooncake. IntrinsicsWrappers. ctlz_int),
246- 5 ,
247- ),
248- (
249- false ,
250- :stability_and_allocs ,
251- nothing ,
252- F (Mooncake. IntrinsicsWrappers. ctpop_int),
253- 5 ,
254- ),
255- (
256- false ,
257- :stability_and_allocs ,
258- nothing ,
259- F (Mooncake. IntrinsicsWrappers. cttz_int),
260- 5 ,
261- ),
262- ]. .. ,
263- )
264- end
265-
266- memory = Any[]
267- return test_cases, memory
268- end
0 commit comments