Skip to content

Commit e2611fa

Browse files
authored
Hessian from HVPs, tested in each mod (#116)
1 parent 38fd496 commit e2611fa

5 files changed

Lines changed: 35 additions & 10 deletions

File tree

ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ module DifferentiationInterfaceZygoteExt
33
using ADTypes: AutoZygote
44
import DifferentiationInterface as DI
55
using DocStringExtensions
6-
using Zygote: ZygoteRuleConfig, gradient, jacobian, pullback, withgradient, withjacobian
6+
using Zygote:
7+
ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian
78

89
DI.supports_mutation(::AutoZygote) = DI.MutationNotSupported()
910

@@ -36,12 +37,26 @@ end
3637

3738
## Jacobian
3839

39-
function DI.value_and_jacobian(f, ::AutoZygote, x::AbstractArray, extras::Nothing)
40-
return f(x), only(jacobian(f, x))
40+
function DI.value_and_jacobian(f, ::AutoZygote, x, extras::Nothing)
41+
return f(x), only(jacobian(f, x)) # https://github.com/FluxML/Zygote.jl/issues/1506
4142
end
4243

43-
function DI.jacobian(f, ::AutoZygote, x::AbstractArray, extras::Nothing)
44+
function DI.jacobian(f, ::AutoZygote, x, extras::Nothing)
4445
return only(jacobian(f, x))
4546
end
4647

48+
function DI.value_and_jacobian!!(f, jac, backend::AutoZygote, x, extras::Nothing)
49+
return DI.value_and_jacobian(f, backend, x, extras)
50+
end
51+
52+
function DI.jacobian!!(f, jac, backend::AutoZygote, x, extras::Nothing)
53+
return DI.jacobian(f, backend, x, extras)
54+
end
55+
56+
## Hessian
57+
58+
function DI.hessian(f, ::AutoZygote, x, extras::Nothing)
59+
return hessian(f, x)
60+
end
61+
4762
end

src/DifferentiationInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ using ADTypes:
1818
AbstractSymbolicDifferentiationMode
1919
using DocStringExtensions
2020
using FillArrays: OneElement
21-
using LinearAlgebra: dot
21+
using LinearAlgebra: Symmetric, dot
2222

2323
"""
2424
AutoFastDifferentiation

src/hessian.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ function hessian(f, backend::AbstractADType, x, extras=prepare_hessian(f, backen
1010
end
1111

1212
function hessian(f, backend::SecondOrder, x, extras=prepare_hessian(f, backend, x))
13-
# suboptimal for reverse-over-forward
14-
gradient_closure(z) = gradient(f, inner(backend), z, inner(extras))
15-
hess = jacobian(gradient_closure, outer(backend), x, outer(extras))
16-
return hess
13+
hess = stack(vec(CartesianIndices(x))) do j
14+
hess_col_j = hvp(f, backend, x, basis(backend, x, j), extras)
15+
vec(hess_col_j)
16+
end
17+
return Symmetric(hess)
1718
end

src/hvp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function hvp(f, backend::SecondOrder, x, v, extras=prepare_hvp(f, backend, x))
2929
return hvp_aux(f, backend, x, v, extras, hvp_mode(backend))
3030
end
3131

32-
function hvp_aux(f, backend, x, v, extras, orwardOverReverse)
32+
function hvp_aux(f, backend, x, v, extras, ::ForwardOverReverse)
3333
# JVP of the gradient
3434
gradient_closure(z) = gradient(f, inner(backend), z, inner(extras))
3535
p = pushforward(gradient_closure, outer(backend), x, v, outer(extras))

test/second_order.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
using Enzyme: Enzyme
2+
using FiniteDiff: FiniteDiff
23
using ForwardDiff: ForwardDiff
34
using ReverseDiff: ReverseDiff
5+
using Tracker: Tracker
6+
using Zygote: Zygote
47

58
second_order_backends = [AutoForwardDiff(), AutoReverseDiff()]
69

710
second_order_mixed_backends = [
11+
# forward over forward
812
SecondOrder(AutoEnzyme(Enzyme.Forward), AutoForwardDiff()),
913
SecondOrder(AutoForwardDiff(), AutoEnzyme(Enzyme.Forward)),
14+
# forward over reverse
1015
SecondOrder(AutoForwardDiff(), AutoZygote()),
16+
# reverse over forward
17+
SecondOrder(AutoEnzyme(Enzyme.Reverse), AutoForwardDiff()),
18+
# reverse over reverse
19+
SecondOrder(AutoReverseDiff(), AutoZygote()),
1120
]
1221

1322
for backend in vcat(second_order_backends, second_order_mixed_backends)

0 commit comments

Comments
 (0)