6161
6262# ## Out-of-place
6363
64- function DI. value_and_pullback (
65- f:: F ,
66- :: NoPullbackPrep ,
67- backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
68- x:: Number ,
69- ty:: NTuple{1} ,
70- contexts:: Vararg{Context,C} ,
71- ) where {F,C}
72- f_and_df = force_annotation (get_f_and_df (f, backend))
73- mode = reverse_split_withprimal (backend)
74- RA = eltype (ty) <: Number ? Active : Duplicated
75- dinputs, result = seeded_autodiff_thunk (
76- mode, only (ty), f_and_df, RA, Active (x), map (translate, contexts)...
77- )
78- return result, (first (dinputs),)
79- end
80-
81- function DI. value_and_pullback (
82- f:: F ,
83- :: NoPullbackPrep ,
84- backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
85- x:: Number ,
86- ty:: NTuple{B} ,
87- contexts:: Vararg{Context,C} ,
88- ) where {F,B,C}
89- f_and_df = force_annotation (get_f_and_df (f, backend, Val (B)))
90- mode = reverse_split_withprimal (backend)
91- RA = eltype (ty) <: Number ? Active : BatchDuplicated
92- dinputs, result = batch_seeded_autodiff_thunk (
93- mode, ty, f_and_df, RA, Active (x), map (translate, contexts)...
94- )
95- return result, values (first (dinputs))
96- end
97-
9864function DI. value_and_pullback (
9965 f:: F ,
10066 :: NoPullbackPrep ,
@@ -105,12 +71,18 @@ function DI.value_and_pullback(
10571) where {F,C}
10672 f_and_df = force_annotation (get_f_and_df (f, backend))
10773 mode = reverse_split_withprimal (backend)
108- RA = eltype (ty) <: Number ? Active : Duplicated
74+ IA = guess_activity (typeof (x), mode)
75+ RA = guess_activity (eltype (ty), mode)
10976 dx = make_zero (x)
110- _ , result = seeded_autodiff_thunk (
111- mode, only (ty), f_and_df, RA, Duplicated ( x, dx), map (translate, contexts)...
77+ dinputs , result = seeded_autodiff_thunk (
78+ mode, only (ty), f_and_df, RA, annotate (IA, x, dx), map (translate, contexts)...
11279 )
113- return result, (dx,)
80+ new_dx = first (dinputs)
81+ if isnothing (new_dx)
82+ return result, (dx,)
83+ else
84+ return result, (new_dx,)
85+ end
11486end
11587
11688function DI. value_and_pullback (
@@ -123,12 +95,18 @@ function DI.value_and_pullback(
12395) where {F,B,C}
12496 f_and_df = force_annotation (get_f_and_df (f, backend, Val (B)))
12597 mode = reverse_split_withprimal (backend)
126- RA = eltype (ty) <: Number ? Active : BatchDuplicated
98+ IA = batchify_activity (guess_activity (typeof (x), mode), Val (B))
99+ RA = batchify_activity (guess_activity (eltype (ty), mode), Val (B))
127100 tx = ntuple (_ -> make_zero (x), Val (B))
128- _ , result = batch_seeded_autodiff_thunk (
129- mode, ty, f_and_df, RA, BatchDuplicated ( x, tx), map (translate, contexts)...
101+ dinputs , result = batch_seeded_autodiff_thunk (
102+ mode, ty, f_and_df, RA, annotate (IA, x, tx), map (translate, contexts)...
130103 )
131- return result, tx
104+ new_tx = values (first (dinputs))
105+ if isnothing (new_tx)
106+ return result, tx
107+ else
108+ return result, new_tx
109+ end
132110end
133111
134112function DI. pullback (
@@ -155,7 +133,7 @@ function DI.value_and_pullback!(
155133) where {F,C}
156134 f_and_df = force_annotation (get_f_and_df (f, backend))
157135 mode = reverse_split_withprimal (backend)
158- RA = eltype (ty) <: Number ? Active : Duplicated
136+ RA = guess_activity ( eltype (ty), mode)
159137 dx_righttype = convert (typeof (x), only (tx))
160138 make_zero! (dx_righttype)
161139 _, result = seeded_autodiff_thunk (
@@ -181,7 +159,7 @@ function DI.value_and_pullback!(
181159) where {F,B,C}
182160 f_and_df = force_annotation (get_f_and_df (f, backend, Val (B)))
183161 mode = reverse_split_withprimal (backend)
184- RA = eltype (ty) <: Number ? Active : BatchDuplicated
162+ RA = batchify_activity ( guess_activity ( eltype (ty), mode), Val (B))
185163 tx_righttype = map (Fix1 (convert, typeof (x)), tx)
186164 make_zero! (tx_righttype)
187165 _, result = batch_seeded_autodiff_thunk (
@@ -213,29 +191,39 @@ end
213191# ## Without preparation
214192
215193function DI. gradient (
216- f:: F ,
217- backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
218- x,
219- contexts:: Vararg{Context,C} ,
194+ f:: F , backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} , x, contexts:: Vararg{Context,C}
220195) where {F,C}
221196 f_and_df = get_f_and_df (f, backend)
222- ders = gradient (reverse_noprimal (backend), f_and_df, x, map (translate, contexts)... )
223- grad = first (ders)
224- return grad
197+ mode = reverse_noprimal (backend)
198+ IA = guess_activity (typeof (x), mode)
199+ grad = make_zero (x)
200+ dinputs = only (
201+ autodiff (mode, f_and_df, Active, annotate (IA, x, grad), map (translate, contexts)... )
202+ )
203+ new_grad = first (dinputs)
204+ if isnothing (new_grad)
205+ return grad
206+ else
207+ return new_grad
208+ end
225209end
226210
227211function DI. value_and_gradient (
228- f:: F ,
229- backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
230- x,
231- contexts:: Vararg{Context,C} ,
212+ f:: F , backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} , x, contexts:: Vararg{Context,C}
232213) where {F,C}
233214 f_and_df = get_f_and_df (f, backend)
234- ders, y = gradient (
235- reverse_withprimal (backend), f_and_df, x, map (translate, contexts)...
215+ mode = reverse_withprimal (backend)
216+ IA = guess_activity (typeof (x), mode)
217+ grad = make_zero (x)
218+ dinputs, result = autodiff (
219+ mode, f_and_df, Active, annotate (IA, x, grad), map (translate, contexts)...
236220 )
237- grad = first (ders)
238- return y, grad
221+ new_grad = first (dinputs)
222+ if isnothing (new_grad)
223+ return result, grad
224+ else
225+ return result, new_grad
226+ end
239227end
240228
241229# ## With preparation
@@ -245,10 +233,7 @@ struct EnzymeGradientPrep{G} <: GradientPrep
245233end
246234
247235function DI. prepare_gradient (
248- f:: F ,
249- :: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
250- x,
251- contexts:: Vararg{Context,C} ,
236+ f:: F , :: AutoEnzyme{<:Union{ReverseMode,Nothing}} , x, contexts:: Vararg{Context,C}
252237) where {F,C}
253238 grad_righttype = make_zero (x)
254239 return EnzymeGradientPrep (grad_righttype)
@@ -257,21 +242,18 @@ end
257242function DI. gradient (
258243 f:: F ,
259244 :: EnzymeGradientPrep ,
260- backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const} } ,
245+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
261246 x,
262247 contexts:: Vararg{Context,C} ,
263248) where {F,C}
264- f_and_df = get_f_and_df (f, backend)
265- ders = gradient (reverse_noprimal (backend), f_and_df, x, map (translate, contexts)... )
266- grad = first (ders)
267- return grad
249+ return DI. gradient (f, backend, x, contexts... )
268250end
269251
270252function DI. gradient! (
271253 f:: F ,
272254 grad,
273255 prep:: EnzymeGradientPrep ,
274- backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const} } ,
256+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
275257 x,
276258 contexts:: Vararg{Context,C} ,
277259) where {F,C}
@@ -292,23 +274,18 @@ end
292274function DI. value_and_gradient (
293275 f:: F ,
294276 :: EnzymeGradientPrep ,
295- backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const} } ,
277+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
296278 x,
297279 contexts:: Vararg{Context,C} ,
298280) where {F,C}
299- f_and_df = get_f_and_df (f, backend)
300- ders, y = gradient (
301- reverse_withprimal (backend), f_and_df, x, map (translate, contexts)...
302- )
303- grad = first (ders)
304- return y, grad
281+ return DI. value_and_gradient (f, backend, x, contexts... )
305282end
306283
307284function DI. value_and_gradient! (
308285 f:: F ,
309286 grad,
310287 prep:: EnzymeGradientPrep ,
311- backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const} } ,
288+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
312289 x,
313290 contexts:: Vararg{Context,C} ,
314291) where {F,C}
328305
329306# # Jacobian
330307
308+ # TODO : does not support static arrays
309+
310+ #=
331311struct EnzymeReverseOneArgJacobianPrep{Sy,B} <: JacobianPrep end
332312
333313function EnzymeReverseOneArgJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B}
@@ -385,3 +365,4 @@ function DI.value_and_jacobian!(
385365 y, new_jac = DI.value_and_jacobian(f, prep, backend, x)
386366 return y, copyto!(jac, new_jac)
387367end
368+ =#
0 commit comments