11@is_primitive MinimalCtx Tuple{DI. DifferentiateWith,<: Union{Number,AbstractArray,Tuple} }
22
3+ # nested vectors, similar are not supported
34function Mooncake. rrule!! (
45 dw:: CoDual{<:DI.DifferentiateWith} , x:: Union{CoDual{<:Number},CoDual{<:Tuple}}
56)
@@ -10,31 +11,41 @@ function Mooncake.rrule!!(
1011
1112 # output is a vector, so we need to use the vector pullback
1213 function pullback_array!! (dy:: NoRData )
13- tx = DI. pullback (f, backend, primal_x, (fdata ( y. dx) ,))
14- @assert only (tx) isa rdata_type (typeof (primal_x))
15- return NoRData (), only (tx)
14+ tx = DI. pullback (f, backend, primal_x, (y. dx,))
15+ @assert rdata ( only (tx)) isa rdata_type (tangent_type ( typeof (primal_x) ))
16+ return NoRData (), rdata ( only (tx) )
1617 end
1718
1819 # output is a scalar, so we can use the scalar pullback
1920 function pullback_scalar!! (dy:: Number )
2021 tx = DI. pullback (f, backend, primal_x, (dy,))
21- @assert only (tx) isa rdata_type (typeof (primal_x))
22- return NoRData (), only (tx)
22+ @assert rdata ( only (tx)) isa rdata_type (tangent_type ( typeof (primal_x) ))
23+ return NoRData (), rdata ( only (tx) )
2324 end
2425
2526 # output is a Tuple, NTuple
2627 function pullback_tuple!! (dy:: Tuple )
2728 tx = DI. pullback (f, backend, primal_x, (dy,))
28- @assert only (tx) isa rdata_type (typeof (primal_x))
29- return NoRData (), only (tx)
29+ @assert rdata ( only (tx)) isa rdata_type (tangent_type ( typeof (primal_x) ))
30+ return NoRData (), rdata ( only (tx) )
3031 end
3132
32- pullback = if typeof (primal (y)) <: Number
33+ # inputs are non Differentiable
34+ function pullback_nodiff!! (dy:: NoRData )
35+ @assert tangent_type (typeof (primal (x))) <: NoTangent
36+ return NoRData (), dy
37+ end
38+
39+ pullback = if tangent_type (typeof (primal (x))) <: NoTangent
40+ pullback_nodiff!!
41+ elseif typeof (primal (y)) <: Number
3342 pullback_scalar!!
3443 elseif typeof (primal (y)) <: Array
3544 pullback_array!!
36- else
45+ elseif typeof ( primal (y)) <: Tuple
3746 pullback_tuple!!
47+ else
48+ error (" $(typeof (primal (y))) primal type currently not supported." )
3849 end
3950
4051 return y, pullback
4354function Mooncake. rrule!! (dw:: CoDual{<:DI.DifferentiateWith} , x:: CoDual{<:AbstractArray} )
4455 primal_func = primal (dw)
4556 primal_x = primal (x)
46- fdata_arg = fdata ( x. dx)
57+ fdata_arg = x. dx
4758 (; f, backend) = primal_func
4859 y = zero_fcodual (f (primal_x))
4960
5061 # output is a vector, so we need to use the vector pullback
5162 function pullback_array!! (dy:: NoRData )
52- tx = DI. pullback (f, backend, primal_x, (fdata ( y. dx) ,))
53- @assert first (only (tx)) isa rdata_type (typeof (first (primal_x)))
63+ tx = DI. pullback (f, backend, primal_x, (y. dx,))
64+ @assert rdata ( first (only (tx))) isa rdata_type (tangent_type ( typeof (first (primal_x) )))
5465 fdata_arg .+ = only (tx)
5566 return NoRData (), dy
5667 end
5768
5869 # output is a scalar, so we can use the scalar pullback
5970 function pullback_scalar!! (dy:: Number )
6071 tx = DI. pullback (f, backend, primal_x, (dy,))
61- @assert first (only (tx)) isa rdata_type (typeof (first (primal_x)))
72+ @assert rdata ( first (only (tx))) isa rdata_type (tangent_type ( typeof (first (primal_x) )))
6273 fdata_arg .+ = only (tx)
6374 return NoRData (), NoRData ()
6475 end
6576
6677 # output is a Tuple, NTuple
6778 function pullback_tuple!! (dy:: Tuple )
6879 tx = DI. pullback (f, backend, primal_x, (dy,))
69- @assert first (only (tx)) isa rdata_type (typeof (first (primal_x)))
80+ @assert rdata ( first (only (tx))) isa rdata_type (tangent_type ( typeof (first (primal_x) )))
7081 fdata_arg .+ = only (tx)
7182 return NoRData (), NoRData ()
7283 end
7384
74- pullback = if typeof (primal (y)) <: Number
85+ # inputs are non Differentiable
86+ function pullback_nodiff!! (dy:: NoRData )
87+ @assert tangent_type (typeof (primal (x))) <: Vector{NoTangent}
88+ return NoRData (), dy
89+ end
90+
91+ pullback = if tangent_type (typeof (primal (x))) <: Vector{NoTangent}
92+ pullback_nodiff!!
93+ elseif typeof (primal (y)) <: Number
7594 pullback_scalar!!
76- elseif typeof (primal (y)) <: Array
95+ elseif typeof (primal (y)) <: AbstractArray
7796 pullback_array!!
78- else
97+ elseif typeof ( primal (y)) <: Tuple
7998 pullback_tuple!!
99+ else
100+ error (" $(typeof (primal (y))) primal type currently not supported." )
80101 end
81102
82103 return y, pullback
89110function Mooncake. generate_hand_written_rrule!!_test_cases (rng_ctor, :: Val{:diffwith} )
90111 test_cases = reduce (
91112 vcat,
92- map ([Float64, Float32]) do P
93- return Any[
94- (false , :stability_and_allocs , nothing , cosh, P (0.3 )),
95- (false , :stability_and_allocs , nothing , sinh, P (0.3 )),
96- (false , :stability_and_allocs , nothing , Base. FastMath. exp10_fast, P (0.5 )),
97- (false , :stability_and_allocs , nothing , Base. FastMath. exp2_fast, P (0.5 )),
98- (false , :stability_and_allocs , nothing , Base. FastMath. exp_fast, P (5.0 )),
99- (false , :stability_and_allocs , nothing , Base. FastMath. sincos, P (3.0 )),
100- ]
101- end ,
113+ map ([(x) -> DI. DifferentiateWith (x, DI. AutoFiniteDiff ())]) do F
114+ map ([Float64, Float32]) do P
115+ return Any[
116+ (false , :stability , nothing , F (cosh), P (0.3 )),
117+ (false , :stability , nothing , F (sinh), P (0.3 )),
118+ (false , :stability , nothing , F (Base. FastMath. exp10_fast), P (0.5 )),
119+ (false , :stability , nothing , F (Base. FastMath. exp2_fast), P (0.5 )),
120+ (false , :stability , nothing , F (Base. FastMath. exp_fast), P (5.0 )),
121+ (false , :none , nothing , F (copy), rand (Int32, 5 )),
122+ ]
123+ end
124+ end ... ,
102125 )
103- push! (test_cases, (false , :stability , nothing , copy, randn (5 , 4 )))
104- push! (test_cases, (
105- # Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent.
106- false ,
107- :none ,
108- nothing ,
109- x -> + (x... ),
110- randn (33 ),
111- ))
112- push! (
113- test_cases,
114- (
115- false ,
116- :none ,
117- nothing ,
118- (
119- function (x)
120- rx = Ref (x)
121- return Base. pointerref (
122- Base. bitcast (Ptr{Float64}, pointer_from_objref (rx)), 1 , 1
123- )
124- end
125- ),
126- 5.0 ,
127- ),
128- )
129- push! (
130- test_cases,
131- (
132- false ,
133- :none ,
134- nothing ,
135- x -> (pointerset (pointer (x), UInt8 (3 ), 2 , 1 ); x),
136- rand (UInt8, 5 ),
137- ),
138- )
139- push! (test_cases, (false , :none , nothing , Mooncake. __vec_to_tuple, [1.0 ]))
140- push! (test_cases, (false , :none , nothing , Mooncake. __vec_to_tuple, Any[1.0 ]))
141- push! (test_cases, (false , :none , nothing , Mooncake. __vec_to_tuple, Any[[1.0 ]]))
142- push! (test_cases, (false , :stability , nothing , Mooncake. IntrinsicsWrappers. ctlz_int, 5 ))
143- push! (
144- test_cases, (false , :stability , nothing , Mooncake. IntrinsicsWrappers. ctpop_int, 5 )
145- )
146- push! (test_cases, (false , :stability , nothing , Mooncake. IntrinsicsWrappers. cttz_int, 5 ))
147- push! (
148- test_cases, (false , :stability , nothing , Mooncake. IntrinsicsWrappers. abs_float, 5.0 )
149- )
150- push! (
151- test_cases,
152- (false , :stability , nothing , Mooncake. IntrinsicsWrappers. abs_float, 5.0f0 ),
153- )
154- push! (test_cases, (false , :stability , nothing , deepcopy, 5.0 ))
155- push! (test_cases, (false , :stability , nothing , deepcopy, randn (5 )))
156- push! (test_cases, (false , :stability_and_allocs , nothing , sin, 1.1 ))
157- push! (test_cases, (false , :stability_and_allocs , nothing , sin, 1.0f1 ))
158- push! (test_cases, (false , :stability_and_allocs , nothing , cos, 1.1 ))
159- push! (test_cases, (false , :stability_and_allocs , nothing , cos, 1.0f1 ))
160- push! (test_cases, (false , :stability_and_allocs , nothing , exp, 1.1 ))
161- push! (test_cases, (false , :stability_and_allocs , nothing , exp, 1.0f1 ))
162-
163- # additional_test_set = Mooncake.tangent_test_cases()
164- # function is_valid(f)
165- # try
166- # isa(f([1.0, 2.0]), Union{<:Number,<:AbstractArray})
167- # catch
168- # false
169- # end
170- # end
171- # for test in additional_test_set
172- # if is_valid(test[2])
173- # push!(test_cases, test)
174- # end
175- # end
126+
127+ map ([(x) -> DI. DifferentiateWith (x, DI. AutoZygote ())]) do F
128+ map ([Float64, Float32]) do P
129+ push! (
130+ test_cases,
131+ Any[
132+ (false , :stability , nothing , F (Base. FastMath. sincos), P (3.0 )),
133+ (false , :none , nothing , F (Mooncake. __vec_to_tuple), Any[P (1.0 )]),
134+ ]. .. ,
135+ )
136+ end
137+ end
138+
139+ map ([(x) -> DI. DifferentiateWith (x, DI. AutoZygote ())]) do F
140+ push! (
141+ test_cases,
142+ Any[
143+ (false , :stability , nothing , F (Mooncake. IntrinsicsWrappers. ctlz_int), 5 ),
144+ (false , :stability , nothing , F (Mooncake. IntrinsicsWrappers. ctpop_int), 5 ),
145+ (false , :stability , nothing , F (Mooncake. IntrinsicsWrappers. cttz_int), 5 ),
146+ ]. .. ,
147+ )
148+ end
149+
150+ map ([(x) -> DI. DifferentiateWith (x, DI. AutoFiniteDiff ())]) do F
151+ push! (
152+ test_cases,
153+ Any[
154+ (false , :stability , nothing , copy, randn (5 , 4 )),
155+ (
156+ # Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent.
157+ false ,
158+ :none ,
159+ nothing ,
160+ F (x -> + (x... )),
161+ randn (33 ),
162+ ),
163+ (
164+ false ,
165+ :none ,
166+ nothing ,
167+ (F (
168+ function (x)
169+ rx = Ref (x)
170+ return Base. pointerref (
171+ Base. bitcast (Ptr{Float64}, pointer_from_objref (rx)), 1 , 1
172+ )
173+ end ,
174+ )),
175+ 5.0 ,
176+ ),
177+ (false , :none , nothing , F (Mooncake. __vec_to_tuple), [1.0 ]),
178+ # (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[(1.0,)]), DI.basis fails for this, correct it!
179+ (false , :stability , nothing , F (Mooncake. IntrinsicsWrappers. ctlz_int), 5 ),
180+ (false , :stability , nothing , F (Mooncake. IntrinsicsWrappers. ctpop_int), 5 ),
181+ (false , :stability , nothing , F (Mooncake. IntrinsicsWrappers. cttz_int), 5 ),
182+ (
183+ false ,
184+ :stability ,
185+ nothing ,
186+ F (Mooncake. IntrinsicsWrappers. abs_float),
187+ 5.0f0 ,
188+ ),
189+ (false , :stability , nothing , F (deepcopy), 5.0 ),
190+ (false , :stability , nothing , F (deepcopy), randn (5 )),
191+ (false , :stability_and_allocs , nothing , F (sin), 1.1 ),
192+ (false , :stability_and_allocs , nothing , F (sin), 1.0f1 ),
193+ (false , :stability_and_allocs , nothing , F (cos), 1.1 ),
194+ (false , :stability_and_allocs , nothing , F (cos), 1.0f1 ),
195+ (false , :stability_and_allocs , nothing , F (exp), 1.1 ),
196+ (false , :stability_and_allocs , nothing , F (exp), 1.0f1 ),
197+ ]. .. ,
198+ )
199+ end
176200
177201 memory = Any[]
178202 return test_cases, memory
0 commit comments