Skip to content

Commit 26517dc

Browse files
committed
feat!: ensure consistency between preparation result and current signature
1 parent bac2d02 commit 26517dc

52 files changed

Lines changed: 2035 additions & 1050 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/Test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
actions: write
2626
contents: read
2727
strategy:
28-
fail-fast: true # TODO: toggle
28+
fail-fast: false # TODO: toggle
2929
matrix:
3030
version:
3131
- "1.10"

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
22
(; f, backend) = dw
33
y = f(x)
4-
prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,))
4+
prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,); strict=true)
55
function pullbackfunc(dy)
66
tx = DI.pullback(f, prep_same, backend, x, (dy,))
77
return (NoTangent(), only(tx))

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,48 @@
11
## Pullback
22

3-
struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
3+
struct ChainRulesPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG}
4+
_sig::Type{SIG}
45
y::Y
56
pb::PB
67
end
78

89
function DI.prepare_pullback(
9-
f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}
10+
f,
11+
backend::AutoReverseChainRules,
12+
x,
13+
ty::NTuple,
14+
contexts::Vararg{DI.GeneralizedConstant,C};
15+
strict::Bool=false,
1016
) where {C}
11-
return DI.NoPullbackPrep()
17+
SIG = DI.signature(f, backend, x, ty, contexts...; strict)
18+
return DI.NoPullbackPrep{SIG}()
1219
end
1320

1421
function DI.prepare_pullback_same_point(
1522
f,
16-
::DI.NoPullbackPrep,
23+
prep::DI.NoPullbackPrep,
1724
backend::AutoReverseChainRules,
1825
x,
1926
ty::NTuple,
20-
contexts::Vararg{DI.GeneralizedConstant,C},
27+
contexts::Vararg{DI.GeneralizedConstant,C};
28+
strict::Bool=false,
2129
) where {C}
30+
DI.check_prep(f, prep, backend, x, ty, contexts...)
31+
SIG = DI.signature(f, backend, x, ty, contexts...; strict)
2232
rc = ruleconfig(backend)
2333
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
24-
return ChainRulesPullbackPrepSamePoint(y, pb)
34+
return ChainRulesPullbackPrepSamePoint(SIG, y, pb)
2535
end
2636

2737
function DI.value_and_pullback(
2838
f,
29-
::DI.NoPullbackPrep,
39+
prep::DI.NoPullbackPrep,
3040
backend::AutoReverseChainRules,
3141
x,
3242
ty::NTuple,
3343
contexts::Vararg{DI.GeneralizedConstant,C},
3444
) where {C}
45+
DI.check_prep(f, prep, backend, x, ty, contexts...)
3546
rc = ruleconfig(backend)
3647
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
3748
tx = map(ty) do dy
@@ -43,11 +54,12 @@ end
4354
function DI.value_and_pullback(
4455
f,
4556
prep::ChainRulesPullbackPrepSamePoint,
46-
::AutoReverseChainRules,
57+
backend::AutoReverseChainRules,
4758
x,
4859
ty::NTuple,
4960
contexts::Vararg{DI.GeneralizedConstant,C},
5061
) where {C}
62+
DI.check_prep(f, prep, backend, x, ty, contexts...)
5163
(; y, pb) = prep
5264
tx = map(ty) do dy
5365
unthunk(pb(dy)[2])
@@ -58,11 +70,12 @@ end
5870
function DI.pullback(
5971
f,
6072
prep::ChainRulesPullbackPrepSamePoint,
61-
::AutoReverseChainRules,
73+
backend::AutoReverseChainRules,
6274
x,
6375
ty::NTuple,
6476
contexts::Vararg{DI.GeneralizedConstant,C},
6577
) where {C}
78+
DI.check_prep(f, prep, backend, x, ty, contexts...)
6679
(; pb) = prep
6780
tx = map(ty) do dy
6881
unthunk(pb(dy)[2])

DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,15 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()
1010

1111
## Pushforward
1212

13-
DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::NTuple) = DI.NoPushforwardPrep()
13+
function DI.prepare_pushforward(f, backend::AutoDiffractor, x, tx::NTuple)
14+
SIG = DI.signature(f, backend, x, tx)
15+
return DI.NoPushforwardPrep{SIG}()
16+
end
1417

15-
function DI.pushforward(f, ::DI.NoPushforwardPrep, ::AutoDiffractor, x, tx::NTuple)
18+
function DI.pushforward(
19+
f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple
20+
)
21+
DI.check_prep(f, prep, backend, x, tx)
1622
ty = map(tx) do dx
1723
# code copied from Diffractor.jl
1824
z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx))
@@ -25,6 +31,7 @@ end
2531
function DI.value_and_pushforward(
2632
f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple
2733
)
34+
DI.check_prep(f, prep, backend, x, tx)
2835
return f(x), DI.pushforward(f, prep, backend, x, tx)
2936
end
3037

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,25 @@
22

33
function DI.prepare_pushforward(
44
f::F,
5-
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
5+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
66
x,
77
tx::NTuple,
8-
contexts::Vararg{DI.Context,C},
8+
contexts::Vararg{DI.Context,C};
9+
strict::Bool=false,
910
) where {F,C}
10-
return DI.NoPushforwardPrep()
11+
SIG = DI.signature(f, backend, x, tx, contexts...; strict)
12+
return DI.NoPushforwardPrep{SIG}()
1113
end
1214

1315
function DI.value_and_pushforward(
1416
f::F,
15-
::DI.NoPushforwardPrep,
17+
prep::DI.NoPushforwardPrep,
1618
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
1719
x,
1820
tx::NTuple{1},
1921
contexts::Vararg{DI.Context,C},
2022
) where {F,C}
23+
DI.check_prep(f, prep, backend, x, tx, contexts...)
2124
mode = forward_withprimal(backend)
2225
f_and_df = get_f_and_df(f, backend, mode)
2326
dx = only(tx)
@@ -29,12 +32,13 @@ end
2932

3033
function DI.value_and_pushforward(
3134
f::F,
32-
::DI.NoPushforwardPrep,
35+
prep::DI.NoPushforwardPrep,
3336
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
3437
x,
3538
tx::NTuple{B},
3639
contexts::Vararg{DI.Context,C},
3740
) where {F,B,C}
41+
DI.check_prep(f, prep, backend, x, tx, contexts...)
3842
mode = forward_withprimal(backend)
3943
f_and_df = get_f_and_df(f, backend, mode, Val(B))
4044
x_and_tx = BatchDuplicated(x, tx)
@@ -45,12 +49,13 @@ end
4549

4650
function DI.pushforward(
4751
f::F,
48-
::DI.NoPushforwardPrep,
52+
prep::DI.NoPushforwardPrep,
4953
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
5054
x,
5155
tx::NTuple{1},
5256
contexts::Vararg{DI.Context,C},
5357
) where {F,C}
58+
DI.check_prep(f, prep, backend, x, tx, contexts...)
5459
mode = forward_noprimal(backend)
5560
f_and_df = get_f_and_df(f, backend, mode)
5661
dx = only(tx)
@@ -62,12 +67,13 @@ end
6267

6368
function DI.pushforward(
6469
f::F,
65-
::DI.NoPushforwardPrep,
70+
prep::DI.NoPushforwardPrep,
6671
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
6772
x,
6873
tx::NTuple{B},
6974
contexts::Vararg{DI.Context,C},
7075
) where {F,B,C}
76+
DI.check_prep(f, prep, backend, x, tx, contexts...)
7177
mode = forward_noprimal(backend)
7278
f_and_df = get_f_and_df(f, backend, mode, Val(B))
7379
x_and_tx = BatchDuplicated(x, tx)
@@ -85,6 +91,7 @@ function DI.value_and_pushforward!(
8591
tx::NTuple,
8692
contexts::Vararg{DI.Context,C},
8793
) where {F,C}
94+
DI.check_prep(f, prep, backend, x, tx, contexts...)
8895
# dy cannot be passed anyway
8996
y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
9097
foreach(copyto!, ty, new_ty)
@@ -100,6 +107,7 @@ function DI.pushforward!(
100107
tx::NTuple,
101108
contexts::Vararg{DI.Context,C},
102109
) where {F,C}
110+
DI.check_prep(f, prep, backend, x, tx, contexts...)
103111
# dy cannot be passed anyway
104112
new_ty = DI.pushforward(f, prep, backend, x, tx, contexts...)
105113
foreach(copyto!, ty, new_ty)
@@ -108,23 +116,25 @@ end
108116

109117
## Gradient
110118

111-
struct EnzymeForwardGradientPrep{B,O} <: DI.GradientPrep
119+
struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG}
112120
shadows::O
113121
end
114122

115-
function EnzymeForwardGradientPrep(::Val{B}, shadows::O) where {B,O}
116-
return EnzymeForwardGradientPrep{B,O}(shadows)
123+
function EnzymeForwardGradientPrep(::Type{SIG}, ::Val{B}, shadows::O) where {SIG,B,O}
124+
return EnzymeForwardGradientPrep{SIG,B,O}(shadows)
117125
end
118126

119127
function DI.prepare_gradient(
120128
f::F,
121129
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
122130
x,
123-
contexts::Vararg{DI.Constant,C},
131+
contexts::Vararg{DI.Constant,C};
132+
strict::Bool=false,
124133
) where {F,C}
134+
SIG = DI.signature(f, backend, x, contexts...; strict)
125135
valB = to_val(DI.pick_batchsize(backend, x))
126136
shadows = create_shadows(valB, x)
127-
return EnzymeForwardGradientPrep(valB, shadows)
137+
return EnzymeForwardGradientPrep(SIG, valB, shadows)
128138
end
129139

130140
function DI.gradient(
@@ -134,6 +144,7 @@ function DI.gradient(
134144
x,
135145
contexts::Vararg{DI.Constant,C},
136146
) where {F,B,C}
147+
DI.check_prep(f, prep, backend, x, contexts...)
137148
mode = forward_noprimal(backend)
138149
f_and_df = get_f_and_df(f, backend, mode)
139150
annotated_contexts = translate(backend, mode, Val(B), contexts...)
@@ -150,6 +161,7 @@ function DI.value_and_gradient(
150161
x,
151162
contexts::Vararg{DI.Constant,C},
152163
) where {F,B,C}
164+
DI.check_prep(f, prep, backend, x, contexts...)
153165
mode = forward_withprimal(backend)
154166
f_and_df = get_f_and_df(f, backend, mode)
155167
annotated_contexts = translate(backend, mode, Val(B), contexts...)
@@ -167,6 +179,7 @@ function DI.gradient!(
167179
x,
168180
contexts::Vararg{DI.Constant,C},
169181
) where {F,B,C}
182+
DI.check_prep(f, prep, backend, x, contexts...)
170183
return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...))
171184
end
172185

@@ -178,33 +191,36 @@ function DI.value_and_gradient!(
178191
x,
179192
contexts::Vararg{DI.Constant,C},
180193
) where {F,B,C}
194+
DI.check_prep(f, prep, backend, x, contexts...)
181195
y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...)
182196
return y, copyto!(grad, new_grad)
183197
end
184198

185199
## Jacobian
186200

187-
struct EnzymeForwardOneArgJacobianPrep{B,O} <: DI.JacobianPrep
201+
struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG}
188202
shadows::O
189203
output_length::Int
190204
end
191205

192206
function EnzymeForwardOneArgJacobianPrep(
193-
::Val{B}, shadows::O, output_length::Integer
194-
) where {B,O}
195-
return EnzymeForwardOneArgJacobianPrep{B,O}(shadows, output_length)
207+
::Type{SIG}, ::Val{B}, shadows::O, output_length::Integer
208+
) where {SIG,B,O}
209+
return EnzymeForwardOneArgJacobianPrep{SIG,B,O}(shadows, output_length)
196210
end
197211

198212
function DI.prepare_jacobian(
199213
f::F,
200214
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
201215
x,
202-
contexts::Vararg{DI.Constant,C},
216+
contexts::Vararg{DI.Constant,C};
217+
strict::Bool=false,
203218
) where {F,C}
219+
SIG = DI.signature(f, backend, x, contexts...; strict)
204220
y = f(x, map(DI.unwrap, contexts)...)
205221
valB = to_val(DI.pick_batchsize(backend, x))
206222
shadows = create_shadows(valB, x)
207-
return EnzymeForwardOneArgJacobianPrep(valB, shadows, length(y))
223+
return EnzymeForwardOneArgJacobianPrep(SIG, valB, shadows, length(y))
208224
end
209225

210226
function DI.jacobian(
@@ -214,6 +230,7 @@ function DI.jacobian(
214230
x,
215231
contexts::Vararg{DI.Constant,C},
216232
) where {F,B,C}
233+
DI.check_prep(f, prep, backend, contexts...)
217234
mode = forward_noprimal(backend)
218235
f_and_df = get_f_and_df(f, backend, mode)
219236
annotated_contexts = translate(backend, mode, Val(B), contexts...)
@@ -231,6 +248,7 @@ function DI.value_and_jacobian(
231248
x,
232249
contexts::Vararg{DI.Constant,C},
233250
) where {F,B,C}
251+
DI.check_prep(f, prep, backend, contexts...)
234252
mode = forward_withprimal(backend)
235253
f_and_df = get_f_and_df(f, backend, mode)
236254
annotated_contexts = translate(backend, mode, Val(B), contexts...)
@@ -249,6 +267,7 @@ function DI.jacobian!(
249267
x,
250268
contexts::Vararg{DI.Constant,C},
251269
) where {F,C}
270+
DI.check_prep(f, prep, backend, contexts...)
252271
return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...))
253272
end
254273

@@ -260,6 +279,7 @@ function DI.value_and_jacobian!(
260279
x,
261280
contexts::Vararg{DI.Constant,C},
262281
) where {F,C}
282+
DI.check_prep(f, prep, backend, contexts...)
263283
y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...)
264284
return y, copyto!(jac, new_jac)
265285
end

0 commit comments

Comments
 (0)