Skip to content

Commit 097930a

Browse files
authored
Add constant contexts for Enzyme (#489)
1 parent 0a0a943 commit 097930a

10 files changed

Lines changed: 298 additions & 146 deletions

File tree

DifferentiationInterface/docs/src/explanation/advanced.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Another option would be creating a closure, but that is sometimes undesirable.
1212

1313
!!! warning
1414
This feature is still experimental, and will likely not be supported by all backends.
15-
At the moment, it only works with ForwardDiff.
15+
At the moment, it only works with ForwardDiff, Zygote and Enzyme.
1616

1717
### Types of contexts
1818

DifferentiationInterface/docs/src/tutorials/advanced.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,6 @@ extras_other_constant = prepare_gradient(f_multiarg, backend, x, Constant(-1))
4949
gradient(f_multiarg, extras_other_constant, backend, x, Constant(10))
5050
```
5151

52-
!!! warning
53-
At the moment, contexts only work with ForwardDiff, but we will add compatibility with other backends soon.
54-
This trick will be especially important to leverage Enzyme's annotations for increased performance.
55-
5652
## Sparsity
5753

5854
Sparse AD is very useful when Jacobian or Hessian matrices have a lot of zeros.

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using ADTypes: ADTypes, AutoEnzyme
44
using Base: Fix1
55
import DifferentiationInterface as DI
66
using DifferentiationInterface:
7+
Context,
78
DerivativeExtras,
89
GradientExtras,
910
JacobianExtras,
Lines changed: 62 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,112 @@
11
## Pushforward
22

33
function DI.prepare_pushforward(
4-
f, ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::Tangents
5-
)
4+
f::F,
5+
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
6+
x,
7+
tx::Tangents,
8+
contexts::Vararg{Context,C},
9+
) where {F,C}
610
return NoPushforwardExtras()
711
end
812

913
function DI.value_and_pushforward(
10-
f,
14+
f::F,
1115
::NoPushforwardExtras,
1216
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
1317
x,
1418
tx::Tangents{1},
15-
)
19+
contexts::Vararg{Context,C},
20+
) where {F,C}
1621
f_and_df = get_f_and_df(f, backend)
1722
dx_sametype = convert(typeof(x), only(tx))
1823
x_and_dx = Duplicated(x, dx_sametype)
19-
dy, y = autodiff(forward_mode_withprimal(backend), f_and_df, x_and_dx)
24+
dy, y = autodiff(
25+
forward_mode_withprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...
26+
)
2027
return y, Tangents(dy)
2128
end
2229

2330
function DI.value_and_pushforward(
24-
f,
31+
f::F,
2532
::NoPushforwardExtras,
2633
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
2734
x,
2835
tx::Tangents{B},
29-
) where {B}
36+
contexts::Vararg{Context,C},
37+
) where {F,B,C}
3038
f_and_df = get_f_and_df(f, backend, Val(B))
3139
dxs_sametype = map(Fix1(convert, typeof(x)), tx.d)
3240
x_and_dxs = BatchDuplicated(x, dxs_sametype)
33-
dys, y = autodiff(forward_mode_withprimal(backend), f_and_df, x_and_dxs)
41+
dys, y = autodiff(
42+
forward_mode_withprimal(backend), f_and_df, x_and_dxs, map(translate, contexts)...
43+
)
3444
return y, Tangents(dys...)
3545
end
3646

3747
function DI.pushforward(
38-
f,
48+
f::F,
3949
::NoPushforwardExtras,
4050
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
4151
x,
4252
tx::Tangents{1},
43-
)
53+
contexts::Vararg{Context,C},
54+
) where {F,C}
4455
f_and_df = get_f_and_df(f, backend)
4556
dx_sametype = convert(typeof(x), only(tx))
4657
x_and_dx = Duplicated(x, dx_sametype)
47-
dy = only(autodiff(forward_mode_noprimal(backend), f_and_df, x_and_dx))
58+
dy = only(
59+
autodiff(
60+
forward_mode_noprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...
61+
),
62+
)
4863
return Tangents(dy)
4964
end
5065

5166
function DI.pushforward(
52-
f,
67+
f::F,
5368
::NoPushforwardExtras,
5469
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
5570
x,
5671
tx::Tangents{B},
57-
) where {B}
72+
contexts::Vararg{Context,C},
73+
) where {F,B,C}
5874
f_and_df = get_f_and_df(f, backend, Val(B))
5975
dxs_sametype = map(Fix1(convert, typeof(x)), tx.d)
6076
x_and_dxs = BatchDuplicated(x, dxs_sametype)
61-
dys = only(autodiff(forward_mode_noprimal(backend), f_and_df, x_and_dxs))
77+
dys = only(
78+
autodiff(
79+
forward_mode_noprimal(backend), f_and_df, x_and_dxs, map(translate, contexts)...
80+
),
81+
)
6282
return Tangents(dys...)
6383
end
6484

6585
function DI.value_and_pushforward!(
66-
f,
86+
f::F,
6787
ty::Tangents,
6888
extras::NoPushforwardExtras,
6989
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
7090
x,
7191
tx::Tangents,
72-
)
92+
contexts::Vararg{Context,C},
93+
) where {F,C}
7394
# dy cannot be passed anyway
74-
y, new_ty = DI.value_and_pushforward(f, extras, backend, x, tx)
95+
y, new_ty = DI.value_and_pushforward(f, extras, backend, x, tx, contexts...)
7596
return y, copyto!(ty, new_ty)
7697
end
7798

7899
function DI.pushforward!(
79-
f,
100+
f::F,
80101
ty::Tangents,
81102
extras::NoPushforwardExtras,
82103
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
83104
x,
84105
tx::Tangents,
85-
)
106+
contexts::Vararg{Context,C},
107+
) where {F,C}
86108
# dy cannot be passed anyway
87-
return copyto!(ty, DI.pushforward(f, extras, backend, x, tx))
109+
return copyto!(ty, DI.pushforward(f, extras, backend, x, tx, contexts...))
88110
end
89111

90112
## Gradient
@@ -94,19 +116,19 @@ struct EnzymeForwardGradientExtras{B,O} <: GradientExtras
94116
end
95117

96118
function DI.prepare_gradient(
97-
f, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x
98-
)
119+
f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x
120+
) where {F}
99121
B = pick_batchsize(backend, length(x))
100122
shadows = create_shadows(Val(B), x)
101123
return EnzymeForwardGradientExtras{B,typeof(shadows)}(shadows)
102124
end
103125

104126
function DI.gradient(
105-
f,
127+
f::F,
106128
extras::EnzymeForwardGradientExtras{B},
107129
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
108130
x,
109-
) where {B}
131+
) where {F,B}
110132
f_and_df = get_f_and_df(f, backend)
111133
derivs = gradient(
112134
forward_mode_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=extras.shadows
@@ -115,11 +137,11 @@ function DI.gradient(
115137
end
116138

117139
function DI.value_and_gradient(
118-
f,
140+
f::F,
119141
extras::EnzymeForwardGradientExtras{B},
120142
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
121143
x,
122-
) where {B}
144+
) where {F,B}
123145
f_and_df = get_f_and_df(f, backend)
124146
(; derivs, val) = gradient(
125147
forward_mode_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=extras.shadows
@@ -128,22 +150,22 @@ function DI.value_and_gradient(
128150
end
129151

130152
function DI.gradient!(
131-
f,
153+
f::F,
132154
grad,
133155
extras::EnzymeForwardGradientExtras{B},
134156
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
135157
x,
136-
) where {B}
158+
) where {F,B}
137159
return copyto!(grad, DI.gradient(f, extras, backend, x))
138160
end
139161

140162
function DI.value_and_gradient!(
141-
f,
163+
f::F,
142164
grad,
143165
extras::EnzymeForwardGradientExtras{B},
144166
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
145167
x,
146-
) where {B}
168+
) where {F,B}
147169
y, new_grad = DI.value_and_gradient(f, extras, backend, x)
148170
return y, copyto!(grad, new_grad)
149171
end
@@ -156,20 +178,20 @@ struct EnzymeForwardOneArgJacobianExtras{B,O} <: JacobianExtras
156178
end
157179

158180
function DI.prepare_jacobian(
159-
f, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x
160-
)
181+
f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x
182+
) where {F}
161183
y = f(x)
162184
B = pick_batchsize(backend, length(x))
163185
shadows = create_shadows(Val(B), x)
164186
return EnzymeForwardOneArgJacobianExtras{B,typeof(shadows)}(shadows, length(y))
165187
end
166188

167189
function DI.jacobian(
168-
f,
190+
f::F,
169191
extras::EnzymeForwardOneArgJacobianExtras{B},
170192
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
171193
x,
172-
) where {B}
194+
) where {F,B}
173195
f_and_df = get_f_and_df(f, backend)
174196
derivs = jacobian(
175197
forward_mode_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=extras.shadows
@@ -179,11 +201,11 @@ function DI.jacobian(
179201
end
180202

181203
function DI.value_and_jacobian(
182-
f,
204+
f::F,
183205
extras::EnzymeForwardOneArgJacobianExtras{B},
184206
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
185207
x,
186-
) where {B}
208+
) where {F,B}
187209
f_and_df = get_f_and_df(f, backend)
188210
(; derivs, val) = jacobian(
189211
forward_mode_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=extras.shadows
@@ -193,22 +215,22 @@ function DI.value_and_jacobian(
193215
end
194216

195217
function DI.jacobian!(
196-
f,
218+
f::F,
197219
jac,
198220
extras::EnzymeForwardOneArgJacobianExtras,
199221
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
200222
x,
201-
)
223+
) where {F}
202224
return copyto!(jac, DI.jacobian(f, extras, backend, x))
203225
end
204226

205227
function DI.value_and_jacobian!(
206-
f,
228+
f::F,
207229
jac,
208230
extras::EnzymeForwardOneArgJacobianExtras,
209231
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
210232
x,
211-
)
233+
) where {F}
212234
y, new_jac = DI.value_and_jacobian(f, extras, backend, x)
213235
return y, copyto!(jac, new_jac)
214236
end
Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,62 @@
11
## Pushforward
22

33
function DI.prepare_pushforward(
4-
f!, y, ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::Tangents
5-
)
4+
f!::F,
5+
y,
6+
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
7+
x,
8+
tx::Tangents,
9+
contexts::Vararg{Context,C},
10+
) where {F,C}
611
return NoPushforwardExtras()
712
end
813

914
function DI.value_and_pushforward(
10-
f!,
15+
f!::F,
1116
y,
1217
::NoPushforwardExtras,
1318
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
1419
x,
1520
tx::Tangents{1},
16-
)
21+
contexts::Vararg{Context,C},
22+
) where {F,C}
1723
f!_and_df! = get_f_and_df(f!, backend)
1824
dx_sametype = convert(typeof(x), only(tx))
1925
dy_sametype = make_zero(y)
2026
x_and_dx = Duplicated(x, dx_sametype)
2127
y_and_dy = Duplicated(y, dy_sametype)
22-
autodiff(forward_mode_noprimal(backend), f!_and_df!, Const, y_and_dy, x_and_dx)
28+
autodiff(
29+
forward_mode_noprimal(backend),
30+
f!_and_df!,
31+
Const,
32+
y_and_dy,
33+
x_and_dx,
34+
map(translate, contexts)...,
35+
)
2336
return y, Tangents(dy_sametype)
2437
end
2538

2639
function DI.value_and_pushforward(
27-
f!,
40+
f!::F,
2841
y,
2942
::NoPushforwardExtras,
3043
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
3144
x,
3245
tx::Tangents{B},
33-
) where {B}
46+
contexts::Vararg{Context,C},
47+
) where {F,B,C}
3448
f!_and_df! = get_f_and_df(f!, backend, Val(B))
3549
dxs_sametype = map(Fix1(convert, typeof(x)), tx.d)
3650
dys_sametype = ntuple(_ -> make_zero(y), Val(B))
3751
x_and_dxs = BatchDuplicated(x, dxs_sametype)
3852
y_and_dys = BatchDuplicated(y, dys_sametype)
39-
autodiff(forward_mode_noprimal(backend), f!_and_df!, Const, y_and_dys, x_and_dxs)
53+
autodiff(
54+
forward_mode_noprimal(backend),
55+
f!_and_df!,
56+
Const,
57+
y_and_dys,
58+
x_and_dxs,
59+
map(translate, contexts)...,
60+
)
4061
return y, Tangents(dys_sametype...)
4162
end

0 commit comments

Comments
 (0)