@@ -210,82 +210,138 @@ end
210210
211211# # Gradient
212212
213+ # ## Without preparation
214+
215+ function DI. gradient (
216+ f:: F ,
217+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
218+ x,
219+ contexts:: Vararg{Context,C} ,
220+ ) where {F,C}
221+ f_and_df = get_f_and_df (f, backend)
222+ grad = make_zero (x)
223+ autodiff (
224+ reverse_noprimal (backend),
225+ f_and_df,
226+ Active,
227+ Duplicated (x, grad),
228+ map (translate, contexts)... ,
229+ )
230+ return grad
231+ end
232+
233+ function DI. value_and_gradient (
234+ f:: F ,
235+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
236+ x,
237+ contexts:: Vararg{Context,C} ,
238+ ) where {F,C}
239+ f_and_df = get_f_and_df (f, backend)
240+ grad = make_zero (x)
241+ _, y = autodiff (
242+ reverse_withprimal (backend),
243+ f_and_df,
244+ Active,
245+ Duplicated (x, grad),
246+ map (translate, contexts)... ,
247+ )
248+ return y, grad
249+ end
250+
251+ # ## With preparation
252+
253+ struct EnzymeGradientPrep{G} <: GradientPrep
254+ grad_righttype:: G
255+ end
256+
213257function DI. prepare_gradient (
214258 f:: F ,
215259 :: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
216260 x,
217261 contexts:: Vararg{Context,C} ,
218262) where {F,C}
219- return NoGradientPrep ()
263+ grad_righttype = make_zero (x)
264+ return EnzymeGradientPrep (grad_righttype)
220265end
221266
222267function DI. gradient (
223268 f:: F ,
224- :: NoGradientPrep ,
269+ :: EnzymeGradientPrep ,
225270 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
226271 x,
227272 contexts:: Vararg{Context,C} ,
228273) where {F,C}
229274 f_and_df = get_f_and_df (f, backend)
230- derivs = gradient (reverse_noprimal (backend), f_and_df, x, map (translate, contexts)... )
231- return first (derivs)
275+ grad = make_zero (x)
276+ autodiff (
277+ reverse_noprimal (backend),
278+ f_and_df,
279+ Duplicated (x, grad),
280+ map (translate, contexts)... ,
281+ )
282+ return grad
232283end
233284
234285function DI. gradient! (
235286 f:: F ,
236287 grad,
237- :: NoGradientPrep ,
288+ prep :: EnzymeGradientPrep ,
238289 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
239290 x,
240291 contexts:: Vararg{Context,C} ,
241292) where {F,C}
242293 f_and_df = get_f_and_df (f, backend)
243- dx_righttype = convert ( typeof (x), grad)
244- make_zero! (dx_righttype )
294+ grad_righttype = grad isa typeof (x) ? grad : prep . grad_righttype
295+ make_zero! (grad_righttype )
245296 autodiff (
246297 reverse_noprimal (backend),
247298 f_and_df,
248299 Active,
249- Duplicated (x, dx_righttype ),
300+ Duplicated (x, grad_righttype ),
250301 map (translate, contexts)... ,
251302 )
252- dx_righttype === grad || copyto! (grad, dx_righttype )
303+ grad isa typeof (x) || copyto! (grad, grad_righttype )
253304 return grad
254305end
255306
256307function DI. value_and_gradient (
257308 f:: F ,
258- :: NoGradientPrep ,
309+ :: EnzymeGradientPrep ,
259310 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
260311 x,
261312 contexts:: Vararg{Context,C} ,
262313) where {F,C}
263314 f_and_df = get_f_and_df (f, backend)
264- (; derivs, val) = gradient (
265- reverse_withprimal (backend), f_and_df, x, map (translate, contexts)...
315+ grad = make_zero (x)
316+ _, y = autodiff (
317+ reverse_withprimal (backend),
318+ f_and_df,
319+ Active,
320+ Duplicated (x, grad),
321+ map (translate, contexts)... ,
266322 )
267- return val, first (derivs)
323+ return y, grad
268324end
269325
270326function DI. value_and_gradient! (
271327 f:: F ,
272328 grad,
273- :: NoGradientPrep ,
329+ prep :: EnzymeGradientPrep ,
274330 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
275331 x,
276332 contexts:: Vararg{Context,C} ,
277333) where {F,C}
278334 f_and_df = get_f_and_df (f, backend)
279- dx_righttype = convert ( typeof (x), grad)
280- make_zero! (dx_righttype )
335+ grad_righttype = grad isa typeof (x) ? grad : prep . grad_righttype
336+ make_zero! (grad_righttype )
281337 _, y = autodiff (
282338 reverse_withprimal (backend),
283339 f_and_df,
284340 Active,
285- Duplicated (x, dx_righttype ),
341+ Duplicated (x, grad_righttype ),
286342 map (translate, contexts)... ,
287343 )
288- dx_righttype === grad || copyto! (grad, dx_righttype )
344+ grad isa typeof (x) || copyto! (grad, grad_righttype )
289345 return y, grad
290346end
291347
0 commit comments