Skip to content

Commit 0f3d2c1

Browse files
authored
fix: handle function contexts differently from constant contexts (#660)
* fix: handle function contexts differently from constant contexts * Typos * Typo * Fix Enzyme translation * Typo * Forc eannot * Coverage * Pass mode object to translator in Enzyme * Typo * Cleaner error
1 parent 6806fef commit 0f3d2c1

17 files changed

Lines changed: 496 additions & 279 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.26"
4+
version = "0.6.27"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
66
end
77

88
function DI.prepare_pullback(
9-
f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.Constant,C}
9+
f,
10+
::AutoReverseChainRules,
11+
x,
12+
ty::NTuple,
13+
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
1014
) where {C}
1115
return DI.NoPullbackPrep()
1216
end
@@ -17,7 +21,7 @@ function DI.prepare_pullback_same_point(
1721
backend::AutoReverseChainRules,
1822
x,
1923
ty::NTuple,
20-
contexts::Vararg{DI.Constant,C},
24+
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
2125
) where {C}
2226
rc = ruleconfig(backend)
2327
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
@@ -30,7 +34,7 @@ function DI.value_and_pullback(
3034
backend::AutoReverseChainRules,
3135
x,
3236
ty::NTuple,
33-
contexts::Vararg{DI.Constant,C},
37+
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
3438
) where {C}
3539
rc = ruleconfig(backend)
3640
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
@@ -46,7 +50,7 @@ function DI.value_and_pullback(
4650
::AutoReverseChainRules,
4751
x,
4852
ty::NTuple,
49-
contexts::Vararg{DI.Constant,C},
53+
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
5054
) where {C}
5155
(; y, pb) = prep
5256
tx = map(ty) do dy
@@ -61,7 +65,7 @@ function DI.pullback(
6165
::AutoReverseChainRules,
6266
x,
6367
ty::NTuple,
64-
contexts::Vararg{DI.Constant,C},
68+
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
6569
) where {C}
6670
(; pb) = prep
6771
tx = map(ty) do dy

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using EnzymeCore:
77
Active,
88
Annotation,
99
BatchDuplicated,
10+
BatchDuplicatedNoNeed,
1011
BatchMixedDuplicated,
1112
Combined,
1213
Const,

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ function DI.value_and_pushforward(
1818
tx::NTuple{1},
1919
contexts::Vararg{DI.Context,C},
2020
) where {F,C}
21-
f_and_df = get_f_and_df(f, backend)
21+
mode = forward_withprimal(backend)
22+
f_and_df = get_f_and_df(f, backend, mode)
2223
dx_sametype = convert(typeof(x), only(tx))
2324
x_and_dx = Duplicated(x, dx_sametype)
24-
dy, y = autodiff(
25-
forward_withprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...
26-
)
25+
annotated_contexts = translate(backend, mode, Val(1), contexts...)
26+
dy, y = autodiff(mode, f_and_df, x_and_dx, annotated_contexts...)
2727
return y, (dy,)
2828
end
2929

@@ -35,12 +35,12 @@ function DI.value_and_pushforward(
3535
tx::NTuple{B},
3636
contexts::Vararg{DI.Context,C},
3737
) where {F,B,C}
38-
f_and_df = get_f_and_df(f, backend, Val(B))
38+
mode = forward_withprimal(backend)
39+
f_and_df = get_f_and_df(f, backend, mode, Val(B))
3940
tx_sametype = map(Fix1(convert, typeof(x)), tx)
4041
x_and_tx = BatchDuplicated(x, tx_sametype)
41-
ty, y = autodiff(
42-
forward_withprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...
43-
)
42+
annotated_contexts = translate(backend, mode, Val(B), contexts...)
43+
ty, y = autodiff(mode, f_and_df, x_and_tx, annotated_contexts...)
4444
return y, values(ty)
4545
end
4646

@@ -52,12 +52,12 @@ function DI.pushforward(
5252
tx::NTuple{1},
5353
contexts::Vararg{DI.Context,C},
5454
) where {F,C}
55-
f_and_df = get_f_and_df(f, backend)
55+
mode = forward_noprimal(backend)
56+
f_and_df = get_f_and_df(f, backend, mode)
5657
dx_sametype = convert(typeof(x), only(tx))
5758
x_and_dx = Duplicated(x, dx_sametype)
58-
dy = only(
59-
autodiff(forward_noprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...)
60-
)
59+
annotated_contexts = translate(backend, mode, Val(1), contexts...)
60+
dy = only(autodiff(mode, f_and_df, x_and_dx, annotated_contexts...))
6161
return (dy,)
6262
end
6363

@@ -69,12 +69,12 @@ function DI.pushforward(
6969
tx::NTuple{B},
7070
contexts::Vararg{DI.Context,C},
7171
) where {F,B,C}
72-
f_and_df = get_f_and_df(f, backend, Val(B))
72+
mode = forward_noprimal(backend)
73+
f_and_df = get_f_and_df(f, backend, mode, Val(B))
7374
tx_sametype = map(Fix1(convert, typeof(x)), tx)
7475
x_and_tx = BatchDuplicated(x, tx_sametype)
75-
ty = only(
76-
autodiff(forward_noprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...)
77-
)
76+
annotated_contexts = translate(backend, mode, Val(B), contexts...)
77+
ty = only(autodiff(mode, f_and_df, x_and_tx, annotated_contexts...))
7878
return values(ty)
7979
end
8080

@@ -132,10 +132,9 @@ function DI.gradient(
132132
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
133133
x,
134134
) where {F,B}
135-
f_and_df = get_f_and_df(f, backend)
136-
derivs = gradient(
137-
forward_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
138-
)
135+
mode = forward_noprimal(backend)
136+
f_and_df = get_f_and_df(f, backend, mode)
137+
derivs = gradient(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows)
139138
return only(derivs)
140139
end
141140

@@ -145,10 +144,9 @@ function DI.value_and_gradient(
145144
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
146145
x,
147146
) where {F,B}
148-
f_and_df = get_f_and_df(f, backend)
149-
(; derivs, val) = gradient(
150-
forward_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
151-
)
147+
mode = forward_withprimal(backend)
148+
f_and_df = get_f_and_df(f, backend, mode)
149+
(; derivs, val) = gradient(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows)
152150
return val, only(derivs)
153151
end
154152

@@ -201,10 +199,9 @@ function DI.jacobian(
201199
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
202200
x,
203201
) where {F,B}
204-
f_and_df = get_f_and_df(f, backend)
205-
derivs = jacobian(
206-
forward_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
207-
)
202+
mode = forward_noprimal(backend)
203+
f_and_df = get_f_and_df(f, backend, mode)
204+
derivs = jacobian(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows)
208205
jac_tensor = only(derivs)
209206
return maybe_reshape(jac_tensor, prep.output_length, length(x))
210207
end
@@ -215,10 +212,9 @@ function DI.value_and_jacobian(
215212
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
216213
x,
217214
) where {F,B}
218-
f_and_df = get_f_and_df(f, backend)
219-
(; derivs, val) = jacobian(
220-
forward_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
221-
)
215+
mode = forward_withprimal(backend)
216+
f_and_df = get_f_and_df(f, backend, mode)
217+
(; derivs, val) = jacobian(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows)
222218
jac_tensor = only(derivs)
223219
return val, maybe_reshape(jac_tensor, prep.output_length, length(x))
224220
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,14 @@ function DI.value_and_pushforward(
2020
tx::NTuple{1},
2121
contexts::Vararg{DI.Context,C},
2222
) where {F,C}
23-
f!_and_df! = get_f_and_df(f!, backend)
23+
mode = forward_noprimal(backend)
24+
f!_and_df! = get_f_and_df(f!, backend, mode)
2425
dx_sametype = convert(typeof(x), only(tx))
2526
dy_sametype = make_zero(y)
2627
x_and_dx = Duplicated(x, dx_sametype)
2728
y_and_dy = Duplicated(y, dy_sametype)
28-
autodiff(
29-
forward_noprimal(backend),
30-
f!_and_df!,
31-
Const,
32-
y_and_dy,
33-
x_and_dx,
34-
map(translate, contexts)...,
35-
)
29+
annotated_contexts = translate(backend, mode, Val(1), contexts...)
30+
autodiff(mode, f!_and_df!, Const, y_and_dy, x_and_dx, annotated_contexts...)
3631
return y, (dy_sametype,)
3732
end
3833

@@ -45,19 +40,14 @@ function DI.value_and_pushforward(
4540
tx::NTuple{B},
4641
contexts::Vararg{DI.Context,C},
4742
) where {F,B,C}
48-
f!_and_df! = get_f_and_df(f!, backend, Val(B))
43+
mode = forward_noprimal(backend)
44+
f!_and_df! = get_f_and_df(f!, backend, mode, Val(B))
4945
tx_sametype = map(Fix1(convert, typeof(x)), tx)
5046
ty_sametype = ntuple(_ -> make_zero(y), Val(B))
5147
x_and_tx = BatchDuplicated(x, tx_sametype)
5248
y_and_ty = BatchDuplicated(y, ty_sametype)
53-
autodiff(
54-
forward_noprimal(backend),
55-
f!_and_df!,
56-
Const,
57-
y_and_ty,
58-
x_and_tx,
59-
map(translate, contexts)...,
60-
)
49+
annotated_contexts = translate(backend, mode, Val(B), contexts...)
50+
autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...)
6151
return y, ty_sametype
6252
end
6353

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 26 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,14 @@ function DI.value_and_pullback(
6969
ty::NTuple{1},
7070
contexts::Vararg{DI.Context,C},
7171
) where {F,C}
72-
f_and_df = force_annotation(get_f_and_df(f, backend))
7372
mode = reverse_split_withprimal(backend)
73+
f_and_df = force_annotation(get_f_and_df(f, backend, mode))
7474
IA = guess_activity(typeof(x), mode)
7575
RA = guess_activity(eltype(ty), mode)
7676
dx = make_zero(x)
77+
annotated_contexts = translate(backend, mode, Val(1), contexts...)
7778
dinputs, result = seeded_autodiff_thunk(
78-
mode, only(ty), f_and_df, RA, annotate(IA, x, dx), map(translate, contexts)...
79+
mode, only(ty), f_and_df, RA, annotate(IA, x, dx), annotated_contexts...
7980
)
8081
new_dx = first(dinputs)
8182
if isnothing(new_dx)
@@ -93,13 +94,14 @@ function DI.value_and_pullback(
9394
ty::NTuple{B},
9495
contexts::Vararg{DI.Context,C},
9596
) where {F,B,C}
96-
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
9797
mode = reverse_split_withprimal(backend)
98+
f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B)))
9899
IA = batchify_activity(guess_activity(typeof(x), mode), Val(B))
99100
RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B))
100101
tx = ntuple(_ -> make_zero(x), Val(B))
102+
annotated_contexts = translate(backend, mode, Val(B), contexts...)
101103
dinputs, result = batch_seeded_autodiff_thunk(
102-
mode, ty, f_and_df, RA, annotate(IA, x, tx), map(translate, contexts)...
104+
mode, ty, f_and_df, RA, annotate(IA, x, tx), annotated_contexts...
103105
)
104106
new_tx = values(first(dinputs))
105107
if isnothing(new_tx)
@@ -131,18 +133,14 @@ function DI.value_and_pullback!(
131133
ty::NTuple{1},
132134
contexts::Vararg{DI.Context,C},
133135
) where {F,C}
134-
f_and_df = force_annotation(get_f_and_df(f, backend))
135136
mode = reverse_split_withprimal(backend)
137+
f_and_df = force_annotation(get_f_and_df(f, backend, mode))
136138
RA = guess_activity(eltype(ty), mode)
137139
dx_righttype = convert(typeof(x), only(tx))
138140
make_zero!(dx_righttype)
141+
annotated_contexts = translate(backend, mode, Val(1), contexts...)
139142
_, result = seeded_autodiff_thunk(
140-
mode,
141-
only(ty),
142-
f_and_df,
143-
RA,
144-
Duplicated(x, dx_righttype),
145-
map(translate, contexts)...,
143+
mode, only(ty), f_and_df, RA, Duplicated(x, dx_righttype), annotated_contexts...
146144
)
147145
only(tx) === dx_righttype || copyto!(only(tx), dx_righttype)
148146
return result, tx
@@ -157,18 +155,14 @@ function DI.value_and_pullback!(
157155
ty::NTuple{B},
158156
contexts::Vararg{DI.Context,C},
159157
) where {F,B,C}
160-
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
161158
mode = reverse_split_withprimal(backend)
159+
f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B)))
162160
RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B))
163161
tx_righttype = map(Fix1(convert, typeof(x)), tx)
164162
make_zero!(tx_righttype)
163+
annotated_contexts = translate(backend, mode, Val(B), contexts...)
165164
_, result = batch_seeded_autodiff_thunk(
166-
mode,
167-
ty,
168-
f_and_df,
169-
RA,
170-
BatchDuplicated(x, tx_righttype),
171-
map(translate, contexts)...,
165+
mode, ty, f_and_df, RA, BatchDuplicated(x, tx_righttype), annotated_contexts...
172166
)
173167
foreach(copyto!, tx, tx_righttype)
174168
return result, tx
@@ -196,12 +190,13 @@ function DI.gradient(
196190
x,
197191
contexts::Vararg{DI.Context,C},
198192
) where {F,C}
199-
f_and_df = get_f_and_df(f, backend)
200193
mode = reverse_noprimal(backend)
194+
f_and_df = get_f_and_df(f, backend, mode)
201195
IA = guess_activity(typeof(x), mode)
202196
grad = make_zero(x)
197+
annotated_contexts = translate(backend, mode, Val(1), contexts...)
203198
dinputs = only(
204-
autodiff(mode, f_and_df, Active, annotate(IA, x, grad), map(translate, contexts)...)
199+
autodiff(mode, f_and_df, Active, annotate(IA, x, grad), annotated_contexts...)
205200
)
206201
new_grad = first(dinputs)
207202
if isnothing(new_grad)
@@ -217,12 +212,13 @@ function DI.value_and_gradient(
217212
x,
218213
contexts::Vararg{DI.Context,C},
219214
) where {F,C}
220-
f_and_df = get_f_and_df(f, backend)
221215
mode = reverse_withprimal(backend)
216+
f_and_df = get_f_and_df(f, backend, mode)
222217
IA = guess_activity(typeof(x), mode)
223218
grad = make_zero(x)
219+
annotated_contexts = translate(backend, mode, Val(1), contexts...)
224220
dinputs, result = autodiff(
225-
mode, f_and_df, Active, annotate(IA, x, grad), map(translate, contexts)...
221+
mode, f_and_df, Active, annotate(IA, x, grad), annotated_contexts...
226222
)
227223
new_grad = first(dinputs)
228224
if isnothing(new_grad)
@@ -263,16 +259,12 @@ function DI.gradient!(
263259
x,
264260
contexts::Vararg{DI.Context,C},
265261
) where {F,C}
266-
f_and_df = get_f_and_df(f, backend)
262+
mode = reverse_noprimal(backend)
263+
f_and_df = get_f_and_df(f, backend, mode)
267264
grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype
268265
make_zero!(grad_righttype)
269-
autodiff(
270-
reverse_noprimal(backend),
271-
f_and_df,
272-
Active,
273-
Duplicated(x, grad_righttype),
274-
map(translate, contexts)...,
275-
)
266+
annotated_contexts = translate(backend, mode, Val(1), contexts...)
267+
autodiff(mode, f_and_df, Active, Duplicated(x, grad_righttype), annotated_contexts...)
276268
grad === grad_righttype || copyto!(grad, grad_righttype)
277269
return grad
278270
end
@@ -295,15 +287,13 @@ function DI.value_and_gradient!(
295287
x,
296288
contexts::Vararg{DI.Context,C},
297289
) where {F,C}
298-
f_and_df = get_f_and_df(f, backend)
290+
mode = reverse_withprimal(backend)
291+
f_and_df = get_f_and_df(f, backend, mode)
299292
grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype
300293
make_zero!(grad_righttype)
294+
annotated_contexts = translate(backend, mode, Val(1), contexts...)
301295
_, y = autodiff(
302-
reverse_withprimal(backend),
303-
f_and_df,
304-
Active,
305-
Duplicated(x, grad_righttype),
306-
map(translate, contexts)...,
296+
mode, f_and_df, Active, Duplicated(x, grad_righttype), annotated_contexts...
307297
)
308298
grad === grad_righttype || copyto!(grad, grad_righttype)
309299
return y, grad

0 commit comments

Comments
 (0)