Skip to content

Commit f262504

Browse files
ExpandingMangdalle
andauthored
Adapt to Enzyme v0.13 (#471)
* updates for Enzyme 0.13 * remove old enzyme version * small fix * several more fixes * fix something really dumb * some more fixes * that wasnt actually a bug in Enzyme, its just really confusing * Clean up both modes, use batches in forward mode * Add in-place reverse batching, remove Lux tests temporarily * Fixes * Remove hvp * Remove nested docstring * Remove lux --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent 5750b31 commit f262504

12 files changed

Lines changed: 418 additions & 330 deletions

File tree

.github/workflows/Test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ jobs:
5252
- Misc/SparsityDetector
5353
- Misc/ZeroBackends
5454
- Down/Flux
55-
- Down/Lux
55+
# - Down/Lux
5656
exclude:
5757
# lts
5858
- version: "lts"

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ ADTypes = "1.7.0"
4848
ChainRulesCore = "1.23.0"
4949
Compat = "3.46,4.2"
5050
Diffractor = "=0.2.6"
51-
Enzyme = "0.12.35"
51+
Enzyme = "0.13.1"
5252
FastDifferentiation = "0.3.17"
5353
FiniteDiff = "2.23.1"
5454
FiniteDifferences = "0.12.31"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module DifferentiationInterfaceEnzymeExt
22

33
using ADTypes: ADTypes, AutoEnzyme
4+
using Base: Fix1
45
import DifferentiationInterface as DI
56
using DifferentiationInterface:
67
DerivativeExtras,
@@ -20,23 +21,24 @@ using DifferentiationInterface:
2021
using Enzyme:
2122
Active,
2223
Annotation,
24+
BatchDuplicated,
2325
Const,
2426
Duplicated,
2527
DuplicatedNoNeed,
2628
EnzymeCore,
2729
Forward,
2830
ForwardMode,
31+
ForwardWithPrimal,
2932
MixedDuplicated,
3033
Mode,
3134
Reverse,
32-
ReverseWithPrimal,
33-
ReverseSplitWithPrimal,
3435
ReverseMode,
36+
ReverseModeSplit,
37+
ReverseSplitWithPrimal,
38+
ReverseWithPrimal,
3539
autodiff,
36-
autodiff_deferred,
37-
autodiff_deferred_thunk,
3840
autodiff_thunk,
39-
chunkedonehot,
41+
create_shadows,
4042
gradient,
4143
gradient!,
4244
guess_activity,
@@ -47,6 +49,8 @@ using Enzyme:
4749
make_zero!,
4850
onehot
4951

52+
DI.check_available(::AutoEnzyme) = true
53+
5054
include("utils.jl")
5155

5256
include("forward_onearg.jl")
@@ -55,6 +59,4 @@ include("forward_twoarg.jl")
5559
include("reverse_onearg.jl")
5660
include("reverse_twoarg.jl")
5761

58-
include("second_order.jl")
59-
6062
end # module

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 71 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,72 @@
11
## Pushforward
22

33
function DI.prepare_pushforward(
4-
f, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::Tangents
4+
f, ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::Tangents
55
)
66
return NoPushforwardExtras()
77
end
88

99
function DI.value_and_pushforward(
1010
f,
11-
extras::NoPushforwardExtras,
12-
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
11+
::NoPushforwardExtras,
12+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
1313
x,
14-
tx::Tangents,
14+
tx::Tangents{1},
1515
)
16-
ty = map(tx) do dx
17-
only(DI.pushforward(f, extras, backend, x, Tangents(dx)))
18-
end
19-
y = f(x)
20-
return y, ty
16+
f_and_df = get_f_and_df(f, backend)
17+
dx_sametype = convert(typeof(x), only(tx))
18+
x_and_dx = Duplicated(x, dx_sametype)
19+
dy, y = autodiff(forward_mode_withprimal(backend), f_and_df, x_and_dx)
20+
return y, Tangents(dy)
2121
end
2222

2323
function DI.value_and_pushforward(
2424
f,
2525
::NoPushforwardExtras,
26-
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
26+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
2727
x,
28-
tx::Tangents{1},
29-
)
30-
dx = only(tx)
31-
f_and_df = get_f_and_df(f, backend)
32-
dx_sametype = convert(typeof(x), dx)
33-
x_and_dx = Duplicated(x, dx_sametype)
34-
y, new_dy = if backend isa AutoDeferredEnzyme
35-
autodiff_deferred(forward_mode(backend), f_and_df, Duplicated, x_and_dx)
36-
else
37-
autodiff(forward_mode(backend), f_and_df, Duplicated, x_and_dx)
38-
end
39-
return y, Tangents(new_dy)
28+
tx::Tangents{B},
29+
) where {B}
30+
f_and_df = get_f_and_df(f, backend, Val(B))
31+
dxs_sametype = map(Fix1(convert, typeof(x)), tx.d)
32+
x_and_dxs = BatchDuplicated(x, dxs_sametype)
33+
dys, y = autodiff(forward_mode_withprimal(backend), f_and_df, x_and_dxs)
34+
return y, Tangents(dys...)
4035
end
4136

4237
function DI.pushforward(
4338
f,
4439
::NoPushforwardExtras,
45-
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
40+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
4641
x,
4742
tx::Tangents{1},
4843
)
49-
dx = only(tx)
5044
f_and_df = get_f_and_df(f, backend)
51-
dx_sametype = convert(typeof(x), dx)
45+
dx_sametype = convert(typeof(x), only(tx))
5246
x_and_dx = Duplicated(x, dx_sametype)
53-
new_dy = if backend isa AutoDeferredEnzyme
54-
only(autodiff_deferred(forward_mode(backend), f_and_df, DuplicatedNoNeed, x_and_dx))
55-
else
56-
only(autodiff(forward_mode(backend), f_and_df, DuplicatedNoNeed, x_and_dx))
57-
end
58-
return Tangents(new_dy)
47+
dy = only(autodiff(forward_mode_noprimal(backend), f_and_df, x_and_dx))
48+
return Tangents(dy)
49+
end
50+
51+
function DI.pushforward(
52+
f,
53+
::NoPushforwardExtras,
54+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
55+
x,
56+
tx::Tangents{B},
57+
) where {B}
58+
f_and_df = get_f_and_df(f, backend, Val(B))
59+
dxs_sametype = map(Fix1(convert, typeof(x)), tx.d)
60+
x_and_dxs = BatchDuplicated(x, dxs_sametype)
61+
dys = only(autodiff(forward_mode_noprimal(backend), f_and_df, x_and_dxs))
62+
return Tangents(dys...)
5963
end
6064

6165
function DI.value_and_pushforward!(
6266
f,
6367
ty::Tangents,
6468
extras::NoPushforwardExtras,
65-
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
69+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
6670
x,
6771
tx::Tangents,
6872
)
@@ -75,7 +79,7 @@ function DI.pushforward!(
7579
f,
7680
ty::Tangents,
7781
extras::NoPushforwardExtras,
78-
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
82+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
7983
x,
8084
tx::Tangents,
8185
)
@@ -86,15 +90,15 @@ end
8690
## Gradient
8791

8892
struct EnzymeForwardGradientExtras{B,O} <: GradientExtras
89-
shadow::O
93+
shadows::O
9094
end
9195

9296
function DI.prepare_gradient(
9397
f, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x
9498
)
9599
B = pick_batchsize(backend, length(x))
96-
shadow = chunkedonehot(x, Val(B))
97-
return EnzymeForwardGradientExtras{B,typeof(shadow)}(shadow)
100+
shadows = create_shadows(Val(B), x)
101+
return EnzymeForwardGradientExtras{B,typeof(shadows)}(shadows)
98102
end
99103

100104
function DI.gradient(
@@ -104,17 +108,23 @@ function DI.gradient(
104108
x,
105109
) where {B}
106110
f_and_df = get_f_and_df(f, backend)
107-
grad_tup = gradient(forward_mode(backend), f_and_df, x, Val(B); shadow=extras.shadow)
108-
return reshape(collect(grad_tup), size(x))
111+
derivs = gradient(
112+
forward_mode_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=extras.shadows
113+
)
114+
return only(derivs)
109115
end
110116

111117
function DI.value_and_gradient(
112118
f,
113-
extras::EnzymeForwardGradientExtras,
119+
extras::EnzymeForwardGradientExtras{B},
114120
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
115121
x,
116-
)
117-
return f(x), DI.gradient(f, extras, backend, x)
122+
) where {B}
123+
f_and_df = get_f_and_df(f, backend)
124+
(; derivs, val) = gradient(
125+
forward_mode_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=extras.shadows
126+
)
127+
return val, only(derivs)
118128
end
119129

120130
function DI.gradient!(
@@ -124,9 +134,7 @@ function DI.gradient!(
124134
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
125135
x,
126136
) where {B}
127-
f_and_df = get_f_and_df(f, backend)
128-
grad_tup = gradient(forward_mode(backend), f_and_df, x, Val(B); shadow=extras.shadow)
129-
return copyto!(grad, grad_tup)
137+
return copyto!(grad, DI.gradient(f, extras, backend, x))
130138
end
131139

132140
function DI.value_and_gradient!(
@@ -136,27 +144,24 @@ function DI.value_and_gradient!(
136144
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
137145
x,
138146
) where {B}
139-
f_and_df = get_f_and_df(f, backend)
140-
grad_tup = gradient(forward_mode(backend), f_and_df, x, Val(B); shadow=extras.shadow)
141-
return f(x), copyto!(grad, grad_tup)
147+
y, new_grad = DI.value_and_gradient(f, extras, backend, x)
148+
return y, copyto!(grad, new_grad)
142149
end
143150

144151
## Jacobian
145152

146153
struct EnzymeForwardOneArgJacobianExtras{B,O} <: JacobianExtras
147-
shadow::O
154+
shadows::O
155+
output_length::Int
148156
end
149157

150158
function DI.prepare_jacobian(
151159
f, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x
152160
)
161+
y = f(x)
153162
B = pick_batchsize(backend, length(x))
154-
if B == 1
155-
shadow = onehot(x)
156-
else
157-
shadow = chunkedonehot(x, Val(B))
158-
end
159-
return EnzymeForwardOneArgJacobianExtras{B,typeof(shadow)}(shadow)
163+
shadows = create_shadows(Val(B), x)
164+
return EnzymeForwardOneArgJacobianExtras{B,typeof(shadows)}(shadows, length(y))
160165
end
161166

162167
function DI.jacobian(
@@ -166,21 +171,25 @@ function DI.jacobian(
166171
x,
167172
) where {B}
168173
f_and_df = get_f_and_df(f, backend)
169-
jac_wrongshape = jacobian(
170-
forward_mode(backend), f_and_df, x, Val(B); shadow=extras.shadow
174+
derivs = jacobian(
175+
forward_mode_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=extras.shadows
171176
)
172-
nx = length(x)
173-
ny = length(jac_wrongshape) ÷ length(x)
174-
return reshape(jac_wrongshape, ny, nx)
177+
jac_tensor = only(derivs)
178+
return maybe_reshape(jac_tensor, extras.output_length, length(x))
175179
end
176180

177181
function DI.value_and_jacobian(
178182
f,
179-
extras::EnzymeForwardOneArgJacobianExtras,
183+
extras::EnzymeForwardOneArgJacobianExtras{B},
180184
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
181185
x,
182-
)
183-
return f(x), DI.jacobian(f, extras, backend, x)
186+
) where {B}
187+
f_and_df = get_f_and_df(f, backend)
188+
(; derivs, val) = jacobian(
189+
forward_mode_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=extras.shadows
190+
)
191+
jac_tensor = only(derivs)
192+
return val, maybe_reshape(jac_tensor, extras.output_length, length(x))
184193
end
185194

186195
function DI.jacobian!(
Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,41 @@
11
## Pushforward
22

33
function DI.prepare_pushforward(
4-
f!, y, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::Tangents
4+
f!, y, ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::Tangents
55
)
66
return NoPushforwardExtras()
77
end
88

99
function DI.value_and_pushforward(
1010
f!,
1111
y,
12-
extras::NoPushforwardExtras,
13-
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
12+
::NoPushforwardExtras,
13+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
1414
x,
15-
tx::Tangents,
15+
tx::Tangents{1},
1616
)
17-
ty = map(tx) do dx
18-
only(DI.pushforward(f!, y, extras, backend, x, Tangents(dx)))
19-
end
20-
f!(y, x)
21-
return y, ty
17+
f!_and_df! = get_f_and_df(f!, backend)
18+
dx_sametype = convert(typeof(x), only(tx))
19+
dy_sametype = make_zero(y)
20+
x_and_dx = Duplicated(x, dx_sametype)
21+
y_and_dy = Duplicated(y, dy_sametype)
22+
autodiff(forward_mode_noprimal(backend), f!_and_df!, Const, y_and_dy, x_and_dx)
23+
return y, Tangents(dy_sametype)
2224
end
2325

2426
function DI.value_and_pushforward(
2527
f!,
2628
y,
2729
::NoPushforwardExtras,
28-
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
30+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
2931
x,
30-
tx::Tangents{1},
31-
)
32-
dx = only(tx)
33-
f!_and_df! = get_f_and_df(f!, backend)
34-
dx_sametype = convert(typeof(x), dx)
35-
dy_sametype = make_zero(y)
36-
y_and_dy = Duplicated(y, dy_sametype)
37-
x_and_dx = Duplicated(x, dx_sametype)
38-
if backend isa AutoDeferredEnzyme
39-
autodiff_deferred(forward_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx)
40-
else
41-
autodiff(forward_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx)
42-
end
43-
return y, Tangents(dy_sametype)
32+
tx::Tangents{B},
33+
) where {B}
34+
f!_and_df! = get_f_and_df(f!, backend, Val(B))
35+
dxs_sametype = map(Fix1(convert, typeof(x)), tx.d)
36+
dys_sametype = ntuple(_ -> make_zero(y), Val(B))
37+
x_and_dxs = BatchDuplicated(x, dxs_sametype)
38+
y_and_dys = BatchDuplicated(y, dys_sametype)
39+
autodiff(forward_mode_noprimal(backend), f!_and_df!, Const, y_and_dys, x_and_dxs)
40+
return y, Tangents(dys_sametype...)
4441
end

0 commit comments

Comments
 (0)