Skip to content

Commit 07d272a

Browse files
authored
Improve Enzyme extension, add hessians to ForwardDiff and ReverseDiff (#118)
* Improve Enzyme extension, add hessians to ForwardDiff and ReverseDiff * Remove Enzyme forward from type stable backends
1 parent 633af63 commit 07d272a

13 files changed

Lines changed: 151 additions & 31 deletions

File tree

docs/src/overview.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,19 @@ Several variants of each operator are defined:
3636
!!! warning
3737
The "bang-bang" syntactic convention `!!` signals that some of the arguments _can_ be mutated, but they do not _have to be_.
3838
Such arguments will always be part of the return, so that one can simply reuse the operator's output and forget its input.
39-
4039
In other words, this is good:
4140
```julia
42-
grad = gradient!!(f, grad, backend, x) # do this
41+
# work with grad_in
42+
grad_out = gradient!!(f, grad_in, backend, x)
43+
# work with grad_out
4344
```
44-
On the other hand, this is bad, because if `grad` has not been mutated, you will get wrong results:
45+
On the other hand, this is bad, because if `grad_in` has not been mutated, the result only exists in the return value that you discarded, and it is lost:
4546
```julia
46-
gradient!!(f, grad, backend, x) # don't do this
47+
# work with grad_in
48+
gradient!!(f, grad_in, backend, x)
49+
# mistakenly keep working with grad_in
4750
```
51+
Note that we don't guarantee `grad_out` will have the same type as `grad_in`.
4852

4953
## Second order
5054

ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module DifferentiationInterfaceEnzymeExt
22

33
using ADTypes: ADTypes, AutoEnzyme
4-
using DifferentiationInterface: myupdate!!
54
import DifferentiationInterface as DI
65
using DocStringExtensions
76
using Enzyme:
@@ -49,6 +48,12 @@ function DI.basis(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T
4948
return b
5049
end
5150

51+
# Convert `x_target` to the type of `x`, overwrite every entry with the zero of
# `eltype(x)`, and return the converted object.
# The returned object is whatever `convert` produced, so callers must use the return
# value rather than assume `x_target` itself was modified.
function zero_sametype!!(x_target, x)
    zero_value = zero(eltype(x))
    converted = convert(typeof(x), x_target)
    converted .= zero_value
    return converted
end
56+
5257
include("forward_allocating.jl")
5358
include("forward_mutating.jl")
5459

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,76 @@
11
## Pushforward
22

33
# Return the primal output `y` together with the pushforward `new_dy` of `f` at `x`
# along the tangent `dx`, using Enzyme's `autodiff` in `backend.mode`.
function DI.value_and_pushforward(f, backend::AutoForwardEnzyme, x, dx, extras::Nothing)
    # The tangent is converted to the type of `x` before both are paired in `Duplicated`.
    dx_sametype = convert(typeof(x), dx)
    y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx_sametype))
    return y, new_dy
end
8+
9+
# Return only the pushforward of `f` at `x` along `dx` (no primal value), using
# Enzyme's `autodiff` with the `DuplicatedNoNeed` return annotation.
function DI.pushforward(f, backend::AutoForwardEnzyme, x, dx, extras::Nothing)
    # The tangent is converted to the type of `x` before both are paired in `Duplicated`.
    tangent = convert(typeof(x), dx)
    result = autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, tangent))
    # `only` extracts the single element of the returned collection.
    return only(result)
end
14+
15+
# `!!` variant for an allocating `f`: the `_dy` argument is ignored and the call is
# forwarded to the allocating `DI.value_and_pushforward`.
function DI.value_and_pushforward!!(
    f, _dy, backend::AutoForwardEnzyme, x, dx, extras::Nothing
)
    # dy cannot be passed anyway
    result = DI.value_and_pushforward(f, backend, x, dx, extras)
    return result
end
21+
22+
# `!!` variant for an allocating `f`: the `_dy` argument is ignored and the call is
# forwarded to the allocating `DI.pushforward`.
function DI.pushforward!!(f, _dy, backend::AutoForwardEnzyme, x, dx, extras::Nothing)
    # dy cannot be passed anyway
    new_dy = DI.pushforward(f, backend, x, dx, extras)
    return new_dy
end
26+
27+
## Gradient
28+
29+
# Gradient of `f` at the array `x` through Enzyme's `gradient` in `backend.mode`.
# The result is collected and then reshaped to `size(x)`.
function DI.gradient(f, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing)
    raw_grad = gradient(backend.mode, f, x)
    grad_array = collect(raw_grad)
    return reshape(grad_array, size(x))
end
32+
33+
# Return `f(x)` together with the gradient obtained from `DI.gradient`.
# The primal value comes from a separate call to `f`, made before the gradient call.
function DI.value_and_gradient(
    f, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing
)
    y = f(x)
    grad = DI.gradient(f, backend, x, extras)
    return y, grad
end
38+
39+
# `!!` variant: the `_grad` argument is not used; the result comes from the
# allocating `DI.gradient`.
function DI.gradient!!(
    f, _grad, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing
)
    grad = DI.gradient(f, backend, x, extras)
    return grad
end
44+
45+
# `!!` variant: the `_grad` argument is not used; the result comes from the
# allocating `DI.value_and_gradient`.
function DI.value_and_gradient!!(
    f, _grad, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing
)
    result = DI.value_and_gradient(f, backend, x, extras)
    return result
end
50+
51+
## Jacobian
52+
53+
# Jacobian of `f` at the array `x` through Enzyme's `jacobian` in `backend.mode`.
# The raw result is reshaped into a matrix with `length(x)` columns; the row count
# is inferred from the total number of entries.
function DI.jacobian(f, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing)
    raw_jac = jacobian(backend.mode, f, x)
    n_in = length(x)
    n_out = div(length(raw_jac), length(x))
    return reshape(raw_jac, n_out, n_in)
end
59+
60+
# Return `f(x)` together with the Jacobian obtained from `DI.jacobian`.
# The primal value comes from a separate call to `f`, made before the Jacobian call.
function DI.value_and_jacobian(
    f, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing
)
    y = f(x)
    jac = DI.jacobian(f, backend, x, extras)
    return y, jac
end
65+
66+
# `!!` variant: the `_jac` argument is not used; the result comes from the
# allocating `DI.jacobian`.
function DI.jacobian!!(
    f, _jac, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing
)
    jac = DI.jacobian(f, backend, x, extras)
    return jac
end
71+
72+
# `!!` variant: the `_jac` argument is not used; the result comes from the
# allocating `DI.value_and_jacobian`.
function DI.value_and_jacobian!!(
    f, _jac, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing
)
    result = DI.value_and_jacobian(f, backend, x, extras)
    return result
end

ext/DifferentiationInterfaceEnzymeExt/forward_mutating.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
# Pushforward for a mutating function `f!(y, x)`, using Enzyme's `autodiff` in
# `backend.mode`.
# Returns `y` and the output tangent. The tangent is the object produced by
# `zero_sametype!!(dy, y)`, so callers must use the returned value rather than `dy`.
function DI.value_and_pushforward!!(
    f!, y, dy, backend::AutoForwardEnzyme, x, dx, extras::Nothing
)
    # Input tangent converted to the type of `x` before being paired with it.
    dx_sametype = convert(typeof(x), dx)
    # Output tangent: `dy` converted to the type of `y` and reset to zero.
    dy_sametype = zero_sametype!!(dy, y)
    autodiff(
        backend.mode, f!, Const, Duplicated(y, dy_sametype), Duplicated(x, dx_sametype)
    )
    return y, dy_sametype
end

ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
## Pullback
22

3+
### Out-of-place
4+
35
function DI.value_and_pullback(
46
f, ::AutoReverseEnzyme, x::Number, dy::Number, extras::Nothing
57
)
@@ -20,28 +22,28 @@ function DI.value_and_pullback(
2022
return y, new_dx
2123
end
2224

25+
### In-place
26+
2327
# Pullback of `f` at the array `x` for a scalar cotangent `dy`.
# Returns the primal output `y` and the pullback. The pullback is the object produced
# by `zero_sametype!!(dx, x)`, so callers must use the returned value rather than `dx`.
function DI.value_and_pullback!!(
    f, dx, ::AutoReverseEnzyme, x::AbstractArray, dy::Number, extras::Nothing
)
    # `dx` converted to the type of `x` and reset to zero before differentiation.
    dx_sametype = zero_sametype!!(dx, x)
    _, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx_sametype))
    # Scale the accumulated derivative by the scalar cotangent.
    dx_sametype .*= dy
    return y, dx_sametype
end
3235

3336
# Pullback of `f` at the array `x` for an array cotangent `dy`, using Enzyme's split
# forward/reverse thunks.
# Returns the primal output `y` and the pullback. The pullback is the object produced
# by `zero_sametype!!(dx, x)`, so callers must use the returned value rather than `dx`.
function DI.value_and_pullback!!(
    f, dx, ::AutoReverseEnzyme, x::AbstractArray, dy::AbstractArray, extras::Nothing
)
    # `dx` converted to the type of `x` and reset to zero before differentiation.
    dx_sametype = zero_sametype!!(dx, x)
    forw, rev = autodiff_thunk(
        ReverseSplitWithPrimal, Const{typeof(f)}, Duplicated, Duplicated{typeof(x)}
    )
    tape, y, new_dy = forw(Const(f), Duplicated(x, dx_sametype))
    # Copy the cotangent into the third value returned by the forward pass before
    # running the reverse pass.
    new_dy .= dy
    rev(Const(f), Duplicated(x, dx_sametype), tape)
    return y, dx_sametype
end
4648

4749
function DI.value_and_pullback(f, backend::AutoReverseEnzyme, x::AbstractArray, dy, extras)
@@ -58,6 +60,5 @@ end
5860
# Gradient of `f` at the array `x` through Enzyme's `gradient!` in `Reverse` mode.
# The returned gradient is `grad` converted to the type of `x`, so callers must use
# the returned value rather than `grad`.
function DI.gradient!!(f, grad, ::AutoReverseEnzyme, x::AbstractArray, extras::Nothing)
    grad_sametype = convert(typeof(x), grad)
    gradient!(Reverse, grad_sametype, f, x)
    return grad_sametype
end

ext/DifferentiationInterfaceEnzymeExt/reverse_mutating.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@ end
1111
# Pullback for a mutating function `f!(y, x)`, using Enzyme's `autodiff` in `Reverse`
# mode.
# Returns `y` and the pullback. The pullback is the object produced by
# `zero_sametype!!(dx, x)`, so callers must use the returned value rather than `dx`.
function DI.value_and_pullback!!(
    f!, y, dx, ::AutoReverseEnzyme, x::AbstractArray, dy, extras::Nothing
)
    # `dx` converted to the type of `x` and reset to zero before differentiation.
    dx_sametype = zero_sametype!!(dx, x)
    # A copy of `dy`, converted to the type of `y`, is handed to Enzyme instead of
    # the caller's `dy`.
    dy_sametype = convert(typeof(y), copy(dy))
    autodiff(Reverse, f!, Const, Duplicated(y, dy_sametype), Duplicated(x, dx_sametype))
    return y, dx_sametype
end

ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using ForwardDiff:
99
DerivativeConfig,
1010
ForwardDiff,
1111
GradientConfig,
12+
HessianConfig,
1213
JacobianConfig,
1314
Tag,
1415
derivative,
@@ -17,6 +18,8 @@ using ForwardDiff:
1718
extract_derivative!,
1819
gradient,
1920
gradient!,
21+
hessian,
22+
hessian!,
2023
jacobian,
2124
jacobian!,
2225
value

ext/DifferentiationInterfaceForwardDiffExt/allocating.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,19 @@ end
7070
# Jacobian of `f` at `x` through ForwardDiff's `jacobian`, reusing the given
# `JacobianConfig`.
function DI.jacobian(f, ::AutoForwardDiff, x::AbstractArray, config::JacobianConfig)
    jac = jacobian(f, x, config)
    return jac
end
73+
74+
## Hessian
75+
76+
# Build the `HessianConfig` for `f` and `x`, with the chunk chosen by
# `choose_chunk(backend, x)`.
function DI.prepare_hessian(f, backend::AutoForwardDiff, x::AbstractArray)
    chunk = choose_chunk(backend, x)
    return HessianConfig(f, x, chunk)
end
79+
80+
# Hessian of `f` at `x` through ForwardDiff's `hessian!`, which is given `hess` as
# its first argument; the value returned by `hessian!` is passed back to the caller.
function DI.hessian!!(
    f, hess::AbstractMatrix, ::AutoForwardDiff, x::AbstractArray, config::HessianConfig
)
    result = hessian!(hess, f, x, config)
    return result
end
85+
86+
# Hessian of `f` at `x` through ForwardDiff's `hessian`, reusing the given
# `HessianConfig`.
function DI.hessian(f, ::AutoForwardDiff, x::AbstractArray, config::HessianConfig)
    hess = hessian(f, x, config)
    return hess
end

ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,19 @@ using DocStringExtensions
77
using LinearAlgebra: mul!
88
using ReverseDiff:
99
CompiledGradient,
10+
CompiledHessian,
1011
CompiledJacobian,
1112
GradientConfig,
1213
GradientTape,
14+
HessianConfig,
15+
HessianTape,
1316
JacobianConfig,
1417
JacobianTape,
1518
compile,
1619
gradient,
1720
gradient!,
21+
hessian,
22+
hessian!,
1823
jacobian,
1924
jacobian!
2025

ext/DifferentiationInterfaceReverseDiffExt/allocating.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,29 @@ function DI.jacobian(
141141
)
142142
return jacobian!(tape, x)
143143
end
144+
145+
## Hessian
146+
147+
# Record a `HessianTape` for `f` at `x`; when `backend.compile` is true, the
# compiled version of that tape is returned instead.
function DI.prepare_hessian(f, backend::AutoReverseDiff, x::AbstractArray)
    recorded = HessianTape(f, x)
    return backend.compile ? compile(recorded) : recorded
end
154+
155+
# Hessian at `x` through ReverseDiff's `hessian!`, which is given `hess` and the
# prepared tape. `_f` is not used here: only the tape is passed on.
function DI.hessian!!(
    _f,
    hess::AbstractMatrix,
    ::AutoReverseDiff,
    x::AbstractArray,
    tape::Union{HessianTape,CompiledHessian},
)
    result = hessian!(hess, tape, x)
    return result
end
164+
165+
# Hessian at `x` through ReverseDiff's two-argument `hessian!(tape, x)`.
# `_f` is not used here: only the tape is passed on.
function DI.hessian(
    _f, ::AutoReverseDiff, x::AbstractArray, tape::Union{HessianTape,CompiledHessian}
)
    hess = hessian!(tape, x)
    return hess
end

0 commit comments

Comments
 (0)