@@ -396,20 +396,23 @@ end
396396
397397# # HVP
398398
399- struct FastDifferentiationHVPPrep{E2,E2!} <: HVPPrep
399+ struct FastDifferentiationHVPPrep{E2,E2!,E1 } <: HVPPrep
400400 hvp_exe:: E2
401401 hvp_exe!:: E2!
402+ gradient_prep:: E1
402403end
403404
404- function DI. prepare_hvp (f, :: AutoFastDifferentiation , x, tx:: NTuple )
405+ function DI. prepare_hvp (f, backend :: AutoFastDifferentiation , x, tx:: NTuple )
405406 x_var = make_variables (:x , size (x)... )
406407 y_var = f (x_var)
407408
408409 x_vec_var = vec (x_var)
409410 hv_vec_var, v_vec_var = hessian_times_v (y_var, x_vec_var)
410411 hvp_exe = make_function (hv_vec_var, vcat (x_vec_var, v_vec_var); in_place= false )
411412 hvp_exe! = make_function (hv_vec_var, vcat (x_vec_var, v_vec_var); in_place= true )
412- return FastDifferentiationHVPPrep (hvp_exe, hvp_exe!)
413+
414+ gradient_prep = DI. prepare_gradient (f, backend, x)
415+ return FastDifferentiationHVPPrep (hvp_exe, hvp_exe!, gradient_prep)
413416end
414417
415418function DI. hvp (
@@ -439,6 +442,28 @@ function DI.hvp!(
439442 return tg
440443end
441444
445+ function DI. gradient_and_hvp (
446+ f, prep:: FastDifferentiationHVPPrep , backend:: AutoFastDifferentiation , x, tx:: NTuple
447+ )
448+ tg = DI. hvp (f, prep, backend, x, tx)
449+ grad = DI. gradient (f, prep. gradient_prep, backend, x)
450+ return grad, tg
451+ end
452+
453+ function DI. gradient_and_hvp! (
454+ f,
455+ grad,
456+ tg:: NTuple ,
457+ prep:: FastDifferentiationHVPPrep ,
458+ backend:: AutoFastDifferentiation ,
459+ x,
460+ tx:: NTuple ,
461+ )
462+ DI. hvp! (f, tg, prep, backend, x, tx)
463+ DI. gradient! (f, grad, prep. gradient_prep, backend, x)
464+ return grad, tg
465+ end
466+
442467# # Hessian
443468
444469struct FastDifferentiationHessianPrep{G,E2,E2!} <: HessianPrep
0 commit comments