@@ -3,7 +3,8 @@ module DifferentiationInterfaceZygoteExt
33using ADTypes: AutoZygote
44import DifferentiationInterface as DI
55using DocStringExtensions
6- using Zygote: ZygoteRuleConfig, gradient, jacobian, pullback, withgradient, withjacobian
6+ using Zygote:
7+ ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian
78
89DI. supports_mutation (:: AutoZygote ) = DI. MutationNotSupported ()
910
3637
3738# # Jacobian
3839
39- function DI. value_and_jacobian (f, :: AutoZygote , x:: AbstractArray , extras:: Nothing )
40- return f (x), only (jacobian (f, x))
40+ function DI. value_and_jacobian (f, :: AutoZygote , x, extras:: Nothing )
41+ return f (x), only (jacobian (f, x)) # https://github.com/FluxML/Zygote.jl/issues/1506
4142end
4243
43- function DI. jacobian (f, :: AutoZygote , x:: AbstractArray , extras:: Nothing )
44+ function DI. jacobian (f, :: AutoZygote , x, extras:: Nothing )
4445 return only (jacobian (f, x))
4546end
4647
48+ function DI. value_and_jacobian!! (f, jac, backend:: AutoZygote , x, extras:: Nothing )
49+ return DI. value_and_jacobian (f, backend, x, extras)
50+ end
51+
52+ function DI. jacobian!! (f, jac, backend:: AutoZygote , x, extras:: Nothing )
53+ return DI. jacobian (f, backend, x, extras)
54+ end
55+
56+ # # Hessian
57+
58+ function DI. hessian (f, :: AutoZygote , x, extras:: Nothing )
59+ return hessian (f, x)
60+ end
61+
4762end
0 commit comments