Skip to content

Commit 06841e0

Browse files
authored
Improve Enzyme in forward mode with chunk size (#186)
1 parent 7cc6daf commit 06841e0

8 files changed

Lines changed: 141 additions & 19 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ using DifferentiationInterface:
77
NoGradientExtras,
88
NoJacobianExtras,
99
NoPullbackExtras,
10-
NoPushforwardExtras
10+
NoPushforwardExtras,
11+
pick_chunksize
1112
using DocStringExtensions
1213
using Enzyme:
1314
Active,
@@ -22,6 +23,7 @@ using Enzyme:
2223
ReverseMode,
2324
autodiff,
2425
autodiff_thunk,
26+
chunkedonehot,
2527
gradient,
2628
gradient!,
2729
jacobian,

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,56 +33,78 @@ end
3333

3434
## Gradient
3535

36-
DI.prepare_gradient(f, ::AutoForwardEnzyme, x) = NoGradientExtras()
36+
struct EnzymeForwardGradientExtras{C,O}
37+
shadow::O
38+
end
39+
40+
function DI.prepare_gradient(f, ::AutoForwardEnzyme, x)
41+
C = pick_chunksize(length(x))
42+
shadow = chunkedonehot(x, Val(C))
43+
return EnzymeForwardGradientExtras{C,typeof(shadow)}(shadow)
44+
end
3745

38-
function DI.gradient(f, backend::AutoForwardEnzyme, x::AbstractArray, ::NoGradientExtras)
39-
return reshape(collect(gradient(backend.mode, f, x)), size(x))
46+
function DI.gradient(
47+
f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C}
48+
) where {C}
49+
grad_tup = gradient(backend.mode, f, x, Val{C}(); shadow=extras.shadow)
50+
return reshape(collect(grad_tup), size(x))
4051
end
4152

4253
function DI.value_and_gradient(
43-
f, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoGradientExtras
54+
f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras
4455
)
4556
return f(x), DI.gradient(f, backend, x, extras)
4657
end
4758

4859
function DI.gradient!(
49-
f, grad, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoGradientExtras
50-
)
51-
return copyto!(grad, DI.gradient(f, backend, x, extras))
60+
f, grad, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C}
61+
) where {C}
62+
grad_tup = gradient(backend.mode, f, x, Val{C}(); shadow=extras.shadow)
63+
return copyto!(grad, grad_tup)
5264
end
5365

5466
function DI.value_and_gradient!(
55-
f, grad, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoGradientExtras
56-
)
57-
y, new_grad = DI.value_and_gradient(f, backend, x, extras)
58-
return y, copyto!(grad, new_grad)
67+
f, grad, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C}
68+
) where {C}
69+
grad_tup = gradient(backend.mode, f, x, Val{C}(); shadow=extras.shadow)
70+
return f(x), copyto!(grad, grad_tup)
5971
end
6072

6173
## Jacobian
6274

63-
DI.prepare_jacobian(f, ::AutoForwardEnzyme, x) = NoJacobianExtras()
75+
struct EnzymeForwardOneArgJacobianExtras{C,O}
76+
shadow::O
77+
end
78+
79+
function DI.prepare_jacobian(f, ::AutoForwardEnzyme, x)
80+
C = pick_chunksize(length(x))
81+
shadow = chunkedonehot(x, Val(C))
82+
return EnzymeForwardOneArgJacobianExtras{C,typeof(shadow)}(shadow)
83+
end
6484

65-
function DI.jacobian(f, backend::AutoForwardEnzyme, x::AbstractArray, ::NoJacobianExtras)
66-
jac_wrongshape = jacobian(backend.mode, f, x)
85+
function DI.jacobian(
86+
f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras{C}
87+
) where {C}
88+
jac_wrongshape = jacobian(backend.mode, f, x, Val{C}(); shadow=extras.shadow)
6789
nx = length(x)
6890
ny = length(jac_wrongshape) ÷ length(x)
6991
return reshape(jac_wrongshape, ny, nx)
7092
end
7193

7294
function DI.value_and_jacobian(
73-
f, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoJacobianExtras
95+
f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras
7496
)
7597
return f(x), DI.jacobian(f, backend, x, extras)
7698
end
7799

78100
function DI.jacobian!(
79-
f, jac, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoJacobianExtras
101+
f, jac, backend::AutoForwardEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras
80102
)
81103
return copyto!(jac, DI.jacobian(f, backend, x, extras))
82104
end
83105

84106
function DI.value_and_jacobian!(
85-
f, jac, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoJacobianExtras
107+
f, jac, backend::AutoForwardEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras
86108
)
87109
y, new_jac = DI.value_and_jacobian(f, backend, x, extras)
88110
return y, copyto!(jac, new_jac)

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,63 @@ end
8484
function DI.value_and_gradient!(f, grad, backend::AutoReverseEnzyme, x, ::NoGradientExtras)
8585
return DI.value_and_pullback!(f, grad, backend, x, one(eltype(x)), NoPullbackExtras())
8686
end
87+
88+
## Jacobian
89+
90+
# see https://github.com/EnzymeAD/Enzyme.jl/issues/1391
91+
92+
#=
93+
94+
struct EnzymeReverseOneArgJacobianExtras{C,N} end
95+
96+
function DI.prepare_jacobian(f, ::AutoReverseEnzyme, x)
97+
C = pick_chunksize(length(x))
98+
y = f(x)
99+
N = length(y)
100+
return EnzymeReverseOneArgJacobianExtras{C,N}()
101+
end
102+
103+
function DI.jacobian(
104+
f,
105+
backend::AutoReverseEnzyme,
106+
x::AbstractArray,
107+
::EnzymeReverseOneArgJacobianExtras{C,N},
108+
) where {C,N}
109+
jac_wrongshape = jacobian(backend.mode, f, x, Val{N}(), Val{C}())
110+
nx = length(x)
111+
ny = length(jac_wrongshape) ÷ length(x)
112+
jac_rightshape = reshape(jac_wrongshape, ny, nx)
113+
return jac_rightshape
114+
end
115+
116+
function DI.value_and_jacobian(
117+
f,
118+
backend::AutoReverseEnzyme,
119+
x::AbstractArray,
120+
extras::EnzymeReverseOneArgJacobianExtras,
121+
)
122+
return f(x), DI.jacobian(f, backend, x, extras)
123+
end
124+
125+
function DI.jacobian!(
126+
f,
127+
jac,
128+
backend::AutoReverseEnzyme,
129+
x::AbstractArray,
130+
extras::EnzymeReverseOneArgJacobianExtras,
131+
)
132+
return copyto!(jac, DI.jacobian(f, backend, x, extras))
133+
end
134+
135+
function DI.value_and_jacobian!(
136+
f,
137+
jac,
138+
backend::AutoReverseEnzyme,
139+
x::AbstractArray,
140+
extras::EnzymeReverseOneArgJacobianExtras,
141+
)
142+
y, new_jac = DI.value_and_jacobian(f, backend, x, extras)
143+
return y, copyto!(jac, new_jac)
144+
end
145+
146+
=#

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ include("hessian.jl")
5757

5858
include("check.jl")
5959
include("sparse.jl")
60+
include("chunk.jl")
6061

6162
export AutoChainRules,
6263
AutoDiffractor,
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#=
2+
This heuristic is taken from ForwardDiff.jl.
3+
Source file: https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/prelude.jl
4+
=#
5+
6+
const DEFAULT_CHUNKSIZE = 8

"""
    pick_chunksize(input_length; threshold=DEFAULT_CHUNKSIZE)

Pick a reasonable chunk size for chunked derivative evaluation with an input of length `input_length`.

If `input_length <= threshold`, the whole input fits in a single chunk and `input_length` itself is returned.
Otherwise the input is split into the smallest number of chunks such that no chunk exceeds `threshold`, and the size of the largest chunk in an even split is returned.

The result is always an `Int` and cannot be larger than `threshold` (by default `DEFAULT_CHUNKSIZE=$DEFAULT_CHUNKSIZE`).

# Keywords

- `threshold`: maximum admissible chunk size, must be at least 1.

# Throws

- `ArgumentError` if `threshold < 1`.
"""
function pick_chunksize(input_length::Integer; threshold::Integer=DEFAULT_CHUNKSIZE)
    threshold >= 1 || throw(ArgumentError("threshold must be positive, got $threshold"))
    if input_length <= threshold
        # Convert so that both branches return `Int`, whatever `Integer` subtype was given.
        return Int(input_length)
    else
        # Integer ceiling division is exact, unlike rounding a floating-point quotient.
        nchunks = cld(input_length, threshold)
        return Int(cld(input_length, nchunks))
    end
end
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using DifferentiationInterface: pick_chunksize, DEFAULT_CHUNKSIZE
2+
3+
@test pick_chunksize.(1:DEFAULT_CHUNKSIZE) == 1:DEFAULT_CHUNKSIZE
4+
@test all(
5+
pick_chunksize.((DEFAULT_CHUNKSIZE + 1):(5DEFAULT_CHUNKSIZE)) .<= DEFAULT_CHUNKSIZE
6+
)
7+
@test all(
8+
pick_chunksize.((DEFAULT_CHUNKSIZE + 1):(5DEFAULT_CHUNKSIZE)) .>= DEFAULT_CHUNKSIZE / 2
9+
)

DifferentiationInterface/test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,9 @@ include("test_imports.jl")
4343
@testset "Weird arrays" begin
4444
include("weird_arrays.jl")
4545
end
46+
47+
@testset "Chunks" begin
48+
include("chunk.jl")
49+
end
4650
end
4751
end;

DifferentiationInterface/test/type_stability.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
type_stable_backends = [AutoForwardDiff(), AutoEnzyme(Enzyme.Reverse)]
1+
type_stable_backends = [
2+
AutoForwardDiff(), AutoEnzyme(Enzyme.Forward), AutoEnzyme(Enzyme.Reverse)
3+
]
24

35
test_differentiation(
46
type_stable_backends;

0 commit comments

Comments
 (0)