@@ -275,3 +275,96 @@ function DI.value_gradient_and_hessian!(
275275 DI. hessian! (f, hess, prep, backend, x)
276276 return y, grad, hess
277277end
278+
279+ # # HVP
280+
281+ struct SymbolicsOneArgHVPPrep{G,E2,E2!} <: HVPPrep
282+ gradient_prep:: G
283+ hvp_exe:: E2
284+ hvp_exe!:: E2!
285+ end
286+
287+ function DI. prepare_hvp (f, backend:: AutoSymbolics , x, tx:: NTuple )
288+ x_var = variables (:x , axes (x)... )
289+ dx_var = variables (:dx , axes (x)... )
290+ # Symbolic.hessian only accepts vectors
291+ hess_var = hessian (f (x_var), vec (x_var))
292+ hvp_vec_var = hess_var * vec (dx_var)
293+
294+ res = build_function (hvp_vec_var, vcat (vec (x_var), vec (dx_var)); expression= Val (false ))
295+ (hvp_exe, hvp_exe!) = res
296+
297+ gradient_prep = DI. prepare_gradient (f, backend, x)
298+ return SymbolicsOneArgHVPPrep (gradient_prep, hvp_exe, hvp_exe!)
299+ end
300+
301+ function DI. hvp (f, prep:: SymbolicsOneArgHVPPrep , :: AutoSymbolics , x, tx:: NTuple )
302+ return map (tx) do dx
303+ v_vec = vcat (vec (x), vec (dx))
304+ dg_vec = prep. hvp_exe (v_vec)
305+ reshape (dg_vec, size (x))
306+ end
307+ end
308+
309+ function DI. hvp! (
310+ f, tg:: NTuple , prep:: SymbolicsOneArgHVPPrep , :: AutoSymbolics , x, tx:: NTuple
311+ )
312+ for b in eachindex (tx, tg)
313+ dx, dg = tx[b], tg[b]
314+ v_vec = vcat (vec (x), vec (dx))
315+ prep. hvp_exe! (vec (dg), v_vec)
316+ end
317+ return tg
318+ end
319+
320+ # # Second derivative
321+
322+ struct SymbolicsOneArgSecondDerivativePrep{D,E1,E1!} <: SecondDerivativePrep
323+ derivative_prep:: D
324+ der2_exe:: E1
325+ der2_exe!:: E1!
326+ end
327+
328+ function DI. prepare_second_derivative (f, backend:: AutoSymbolics , x)
329+ x_var = variable (:x )
330+ der_var = derivative (f (x_var), x_var)
331+ der2_var = derivative (der_var, x_var)
332+
333+ res = build_function (der2_var, x_var; expression= Val (false ))
334+ (der2_exe, der2_exe!) = if res isa Tuple
335+ res
336+ elseif res isa RuntimeGeneratedFunction
337+ res, nothing
338+ end
339+ derivative_prep = DI. prepare_derivative (f, backend, x)
340+ return SymbolicsOneArgSecondDerivativePrep (derivative_prep, der2_exe, der2_exe!)
341+ end
342+
343+ function DI. second_derivative (
344+ f, prep:: SymbolicsOneArgSecondDerivativePrep , :: AutoSymbolics , x
345+ )
346+ return prep. der2_exe (x)
347+ end
348+
349+ function DI. second_derivative! (
350+ f, der2, prep:: SymbolicsOneArgSecondDerivativePrep , :: AutoSymbolics , x
351+ )
352+ prep. der2_exe! (der2, x)
353+ return der2
354+ end
355+
356+ function DI. value_derivative_and_second_derivative (
357+ f, prep:: SymbolicsOneArgSecondDerivativePrep , backend:: AutoSymbolics , x
358+ )
359+ y, der = DI. value_and_derivative (f, prep. derivative_prep, backend, x)
360+ der2 = DI. second_derivative (f, prep, backend, x)
361+ return y, der, der2
362+ end
363+
364+ function DI. value_derivative_and_second_derivative! (
365+ f, der, der2, prep:: SymbolicsOneArgSecondDerivativePrep , backend:: AutoSymbolics , x
366+ )
367+ y, _ = DI. value_and_derivative! (f, der, prep. derivative_prep, backend, x)
368+ DI. second_derivative! (f, der2, prep, backend, x)
369+ return y, der, der2
370+ end
0 commit comments