1- @is_primitive MinimalCtx Tuple{DI. DifferentiateWith,<: Union{Number,AbstractArray} }
1+ @is_primitive MinimalCtx Tuple{DI. DifferentiateWith,<: Union{Number,AbstractArray,Tuple } }
22
3- function Mooncake. rrule!! (dw:: CoDual{<:DI.DifferentiateWith} , x:: CoDual{<:Number} )
3+ function Mooncake. rrule!! (
4+ dw:: CoDual{<:DI.DifferentiateWith} , x:: CoDual{Union{<:Number,<:Tuple}}
5+ )
46 primal_func = primal (dw)
57 primal_x = primal (x)
68 (; f, backend) = primal_func
@@ -20,7 +22,22 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number
2022 return NoRData (), only (tx)
2123 end
2224
23- return y, typeof (primal (y)) <: Number ? pullback_scalar!! : pullback_array!!
25+ # output is a Tuple, NTuple
26+ function pullback_tuple!! (dy:: Tuple )
27+ tx = DI. pullback (f, backend, primal_x, (dy,))
28+ @assert only (tx) isa rdata_type (typeof (primal_x))
29+ return NoRData (), only (tx)
30+ end
31+
32+ pullback = if typeof (primal (y)) <: Number
33+ pullback_scalar!!
34+ elseif typeof (primal (y)) <: Array
35+ pullback_array!!
36+ else
37+ pullback_tuple!!
38+ end
39+
40+ return y, pullback
2441end
2542
2643function Mooncake. rrule!! (dw:: CoDual{<:DI.DifferentiateWith} , x:: CoDual{<:AbstractArray} )
@@ -46,5 +63,117 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra
4663 return NoRData (), NoRData ()
4764 end
4865
49- return y, typeof (primal (y)) <: Number ? pullback_scalar!! : pullback_array!!
66+ # output is a Tuple, NTuple
67+ function pullback_tuple!! (dy:: Tuple )
68+ tx = DI. pullback (f, backend, primal_x, (dy,))
69+ @assert first (only (tx)) isa rdata_type (typeof (first (primal_x)))
70+ fdata_arg .+ = only (tx)
71+ return NoRData (), NoRData ()
72+ end
73+
74+ pullback = if typeof (primal (y)) <: Number
75+ pullback_scalar!!
76+ elseif typeof (primal (y)) <: Array
77+ pullback_array!!
78+ else
79+ pullback_tuple!!
80+ end
81+
82+ return y, pullback
83+ end
84+
85+ function generate_derived_rrule!!_test_cases (rng_ctor, :: Val{:diffwith} )
86+ return Any[], Any[]
87+ end
88+
89+ function generate_hand_written_rrule!!_test_cases (rng_ctor, :: Val{:diffwith} )
90+ test_cases = reduce (
91+ vcat,
92+ map ([Float64, Float32]) do P
93+ return Any[
94+ (false , :stability_and_allocs , nothing , cosh, P (0.3 )),
95+ (false , :stability_and_allocs , nothing , sinh, P (0.3 )),
96+ (false , :stability_and_allocs , nothing , Base. FastMath. exp10_fast, P (0.5 )),
97+ (false , :stability_and_allocs , nothing , Base. FastMath. exp2_fast, P (0.5 )),
98+ (false , :stability_and_allocs , nothing , Base. FastMath. exp_fast, P (5.0 )),
99+ (false , :stability_and_allocs , nothing , Base. FastMath. sincos, P (3.0 )),
100+ ]
101+ end ,
102+ )
103+ push! (test_cases, (false , :stability , nothing , copy, randn (5 , 4 )))
104+ push! (test_cases, (
105+ # Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent.
106+ false ,
107+ :none ,
108+ nothing ,
109+ x -> + (x... ),
110+ randn (33 ),
111+ ))
112+ push! (
113+ test_cases,
114+ (
115+ false ,
116+ :none ,
117+ nothing ,
118+ (
119+ function (x)
120+ rx = Ref (x)
121+ return Base. pointerref (
122+ Base. bitcast (Ptr{Float64}, pointer_from_objref (rx)), 1 , 1
123+ )
124+ end
125+ ),
126+ 5.0 ,
127+ ),
128+ )
129+ push! (
130+ test_cases,
131+ (
132+ false ,
133+ :none ,
134+ nothing ,
135+ x -> (pointerset (pointer (x), UInt8 (3 ), 2 , 1 ); x),
136+ rand (UInt8, 5 ),
137+ ),
138+ )
139+ push! (test_cases, (false , :none , nothing , Mooncake. __vec_to_tuple, [1.0 ]))
140+ push! (test_cases, (false , :none , nothing , Mooncake. __vec_to_tuple, Any[1.0 ]))
141+ push! (test_cases, (false , :none , nothing , Mooncake. __vec_to_tuple, Any[[1.0 ]]))
142+ push! (test_cases, (false , :stability , nothing , Mooncake. IntrinsicsWrappers. ctlz_int, 5 ))
143+ push! (
144+ test_cases, (false , :stability , nothing , Mooncake. IntrinsicsWrappers. ctpop_int, 5 )
145+ )
146+ push! (test_cases, (false , :stability , nothing , Mooncake. IntrinsicsWrappers. cttz_int, 5 ))
147+ push! (
148+ test_cases, (false , :stability , nothing , Mooncake. IntrinsicsWrappers. abs_float, 5.0 )
149+ )
150+ push! (
151+ test_cases,
152+ (false , :stability , nothing , Mooncake. IntrinsicsWrappers. abs_float, 5.0f0 ),
153+ )
154+ push! (test_cases, (false , :stability , nothing , deepcopy, 5.0 ))
155+ push! (test_cases, (false , :stability , nothing , deepcopy, randn (5 )))
156+ push! (test_cases, (false , :stability_and_allocs , nothing , sin, 1.1 ))
157+ push! (test_cases, (false , :stability_and_allocs , nothing , sin, 1.0f1 ))
158+ push! (test_cases, (false , :stability_and_allocs , nothing , cos, 1.1 ))
159+ push! (test_cases, (false , :stability_and_allocs , nothing , cos, 1.0f1 ))
160+ push! (test_cases, (false , :stability_and_allocs , nothing , exp, 1.1 ))
161+ push! (test_cases, (false , :stability_and_allocs , nothing , exp, 1.0f1 ))
162+
163+ # additional_test_set = Mooncake.tangent_test_cases()
164+ # function is_valid(f)
165+ # try
166+ # isa(f([1.0, 2.0]), Union{<:Number,<:AbstractArray})
167+ # catch
168+ # false
169+ # end
170+ # end
171+ # for test in additional_test_set
172+ # if is_valid(test[2])
173+ # push!(test_cases, test)
174+ # end
175+ # end
176+
177+ memory = Any[]
178+ return test_cases, memory
50179end
0 commit comments