@@ -219,14 +219,8 @@ function DI.gradient(
219219 contexts:: Vararg{Context,C} ,
220220) where {F,C}
221221 f_and_df = get_f_and_df (f, backend)
222- grad = make_zero (x)
223- autodiff (
224- reverse_noprimal (backend),
225- f_and_df,
226- Active,
227- Duplicated (x, grad),
228- map (translate, contexts)... ,
229- )
222+ ders = gradient (reverse_noprimal (backend), f_and_df, x, map (translate, contexts)... )
223+ grad = first (ders)
230224 return grad
231225end
232226
@@ -237,14 +231,10 @@ function DI.value_and_gradient(
237231 contexts:: Vararg{Context,C} ,
238232) where {F,C}
239233 f_and_df = get_f_and_df (f, backend)
240- grad = make_zero (x)
241- _, y = autodiff (
242- reverse_withprimal (backend),
243- f_and_df,
244- Active,
245- Duplicated (x, grad),
246- map (translate, contexts)... ,
234+ ders, y = gradient (
235+ reverse_withprimal (backend), f_and_df, x, map (translate, contexts)...
247236 )
237+ grad = first (ders)
248238 return y, grad
249239end
250240
@@ -272,13 +262,8 @@ function DI.gradient(
272262 contexts:: Vararg{Context,C} ,
273263) where {F,C}
274264 f_and_df = get_f_and_df (f, backend)
275- grad = make_zero (x)
276- autodiff (
277- reverse_noprimal (backend),
278- f_and_df,
279- Duplicated (x, grad),
280- map (translate, contexts)... ,
281- )
265+ ders = gradient (reverse_noprimal (backend), f_and_df, x, map (translate, contexts)... )
266+ grad = first (ders)
282267 return grad
283268end
284269
@@ -300,7 +285,7 @@ function DI.gradient!(
300285 Duplicated (x, grad_righttype),
301286 map (translate, contexts)... ,
302287 )
303- grad isa typeof (x) || copyto! (grad, grad_righttype)
288+ grad === grad_righttype || copyto! (grad, grad_righttype)
304289 return grad
305290end
306291
@@ -312,14 +297,10 @@ function DI.value_and_gradient(
312297 contexts:: Vararg{Context,C} ,
313298) where {F,C}
314299 f_and_df = get_f_and_df (f, backend)
315- grad = make_zero (x)
316- _, y = autodiff (
317- reverse_withprimal (backend),
318- f_and_df,
319- Active,
320- Duplicated (x, grad),
321- map (translate, contexts)... ,
300+ ders, y = gradient (
301+ reverse_withprimal (backend), f_and_df, x, map (translate, contexts)...
322302 )
303+ grad = first (ders)
323304 return y, grad
324305end
325306
@@ -341,7 +322,7 @@ function DI.value_and_gradient!(
341322 Duplicated (x, grad_righttype),
342323 map (translate, contexts)... ,
343324 )
344- grad isa typeof (x) || copyto! (grad, grad_righttype)
325+ grad === grad_righttype || copyto! (grad, grad_righttype)
345326 return y, grad
346327end
347328
0 commit comments