Skip to content

Commit 8ab3a92

Browse files
committed
feat!: specify preparation arguments in DIT Scenario
1 parent bb50d0b commit 8ab3a92

17 files changed

Lines changed: 636 additions & 777 deletions

File tree

DifferentiationInterface/src/utils/prep.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ function check_prep(
198198
if SIG != EXEC_SIG
199199
throw(
200200
PreparationMismatchError(
201-
SIG, EXEC_SIG; format=[:f, :backend, :x, :tang, :contexts]
201+
SIG, EXEC_SIG; format=[:f, :backend, :x, :t, :contexts]
202202
),
203203
)
204204
end
@@ -213,7 +213,7 @@ function check_prep(
213213
if SIG != EXEC_SIG
214214
throw(
215215
PreparationMismatchError(
216-
SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :tang, :contexts]
216+
SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :t, :contexts]
217217
),
218218
)
219219
end

DifferentiationInterfaceTest/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterfaceTest"
22
uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.9.6"
4+
version = "0.10.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -43,7 +43,7 @@ AllocCheck = "0.2"
4343
Chairmarks = "1.2.1"
4444
ComponentArrays = "0.15"
4545
DataFrames = "1.6.1"
46-
DifferentiationInterface = "0.6.0"
46+
DifferentiationInterface = "0.6.53"
4747
DocStringExtensions = "0.8,0.9"
4848
ExplicitImports = "1.10.1"
4949
FiniteDifferences = "0.12"

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestComponentArraysExt/DifferentiationInterfaceTestComponentArraysExt.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,13 @@ function comp_to_num_scenarios_onearg(x::ComponentVector; dx::AbstractVector, dy
3333
append!(
3434
scens,
3535
[
36-
DIT.Scenario{:pullback,pl_op}(f, x; tang=(dy,), res1=(dx_from_dy,)),
36+
DIT.Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,)),
3737
DIT.Scenario{:gradient,pl_op}(f, x; res1=grad),
3838
],
3939
)
4040
end
4141
for pl_op in (:out,)
42-
append!(
43-
scens, [DIT.Scenario{:pushforward,pl_op}(f, x; tang=(dx,), res1=(dy_from_dx,))]
44-
)
42+
append!(scens, [DIT.Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,))])
4543
end
4644
return scens
4745
end

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,7 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
162162
for (model, x) in models_and_xs
163163
Flux.trainmode!(model)
164164
g = gradient_finite_differences(square_loss, model, x)
165-
scen = DIT.Scenario{:gradient,:out}(
166-
square_loss, model; contexts=(DI.Constant(x),), res1=g
167-
)
165+
scen = DIT.Scenario{:gradient,:out}(square_loss, model, DI.Constant(x); res1=g)
168166
push!(scens, scen)
169167
end
170168

@@ -191,7 +189,7 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
191189
Flux.trainmode!(model)
192190
g = gradient_finite_differences(square_loss_iterated, model, x)
193191
scen = DIT.Scenario{:gradient,:out}(
194-
square_loss_iterated, model; contexts=(DI.Constant(x),), res1=g
192+
square_loss_iterated, model, DI.Constant(x); res1=g
195193
)
196194
push!(scens, scen)
197195
end

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ myjl(x::DI.Cache{<:Union{Tuple,NamedTuple}}) = map(myjl, map(DI.Cache, DI.unwrap
2323
myjl(::Nothing) = nothing
2424

2525
function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
26-
(; f, x, y, tang, contexts, res1, res2) = scen
27-
return DIT.Scenario{op,pl_op,pl_fun}(
28-
myjl(f);
26+
(; f, x, y, t, contexts, res1, res2) = scen
27+
return DIT.Scenario{op,pl_op,pl_fun}(;
28+
f=myjl(f),
2929
x=myjl(x),
3030
y=myjl(y),
31-
tang=myjl(tang),
31+
t=myjl(t),
3232
contexts=myjl(contexts),
3333
res1=myjl(res1),
3434
res2=myjl(res2),

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,10 @@ function DIT.lux_scenarios(rng::AbstractRNG=default_rng())
199199
)
200200
scen = DIT.Scenario{:gradient,:out}(
201201
square_loss,
202-
ComponentArray(ps);
203-
contexts=(DI.Constant(model), DI.Constant(x), DI.Constant(st)),
202+
ComponentArray(ps),
203+
DI.Constant(model),
204+
DI.Constant(x),
205+
DI.Constant(st);
204206
res1=g,
205207
)
206208
push!(scens, scen)

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ end
3636
mystatic(::Nothing) = nothing
3737

3838
function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
39-
(; f, x, y, tang, contexts, res1, res2) = scen
40-
return DIT.Scenario{op,pl_op,pl_fun}(
41-
mystatic(f);
39+
(; f, x, y, t, contexts, res1, res2) = scen
40+
return DIT.Scenario{op,pl_op,pl_fun}(;
41+
f=mystatic(f),
4242
x=mystatic(x),
4343
y=pl_fun == :in ? mymutablestatic(y) : mystatic(y),
44-
tang=mystatic(tang),
44+
t=mystatic(t),
4545
contexts=mystatic(contexts),
4646
res1=mystatic(res1),
4747
res2=mystatic(res2),

DifferentiationInterfaceTest/src/scenarios/allocfree.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ function identity_scenarios(x::Number; dx::Number, dy::Number)
55
der = one(x)
66

77
return [
8-
Scenario{:pushforward,:out}(f, x; tang=(dx,), res1=(dy_from_dx,)),
9-
Scenario{:pullback,:out}(f, x; tang=(dy,), res1=(dx_from_dy,)),
8+
Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)),
9+
Scenario{:pullback,:out}(f, x, (dy,); res1=(dx_from_dy,)),
1010
Scenario{:derivative,:out}(f, x; res1=der),
1111
]
1212
end
@@ -19,8 +19,8 @@ function sum_scenarios(x::AbstractArray; dx::AbstractArray, dy::Number)
1919
grad .= one(eltype(x))
2020

2121
return [
22-
Scenario{:pushforward,:out}(f, x; tang=(dx,), res1=(dy_from_dx,)),
23-
Scenario{:pullback,:in}(f, x; tang=(dy,), res1=(dx_from_dy,)),
22+
Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)),
23+
Scenario{:pullback,:in}(f, x, (dy,); res1=(dx_from_dy,)),
2424
Scenario{:gradient,:in}(f, x; res1=grad),
2525
]
2626
end
@@ -34,8 +34,8 @@ function copyto!_scenarios(x::AbstractArray; dx::AbstractArray, dy::AbstractArra
3434
jac = Matrix(Diagonal(ones(eltype(x), length(x))))
3535

3636
return [
37-
Scenario{:pushforward,:in}(f!, y, x; tang=(dx,), res1=(dy_from_dx,)),
38-
Scenario{:pullback,:in}(f!, y, x; tang=(dy,), res1=(dx_from_dy,)),
37+
Scenario{:pushforward,:in}(f!, y, x, (dx,); res1=(dy_from_dx,)),
38+
Scenario{:pullback,:in}(f!, y, x, (dy,); res1=(dx_from_dy,)),
3939
Scenario{:jacobian,:in}(f!, y, x; res1=jac),
4040
]
4141
end

DifferentiationInterfaceTest/src/scenarios/complex.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ function complex_holomorphic_gradient_scenarios()
99
scens = Scenario[
1010
Scenario{:gradient,:out}(square_only, x; res1=grad),
1111
Scenario{:gradient,:in}(square_only, x; res1=grad),
12-
Scenario{:pullback,:out}(square_only, x; tang=(dy,), res1=(grad,)),
13-
Scenario{:pullback,:in}(square_only, x; tang=(dy,), res1=(grad,)),
12+
Scenario{:pullback,:out}(square_only, x, (dy,); res1=(grad,)),
13+
Scenario{:pullback,:in}(square_only, x, (dy,); res1=(grad,)),
1414
]
1515
return scens
1616
end
@@ -22,8 +22,8 @@ function complex_gradient_scenarios()
2222
scens = Scenario[
2323
Scenario{:gradient,:out}(abs2_only, x; res1=grad),
2424
Scenario{:gradient,:in}(abs2_only, x; res1=grad),
25-
Scenario{:pullback,:out}(abs2_only, x; tang=(dy,), res1=(grad,)),
26-
Scenario{:pullback,:in}(abs2_only, x; tang=(dy,), res1=(grad,)),
25+
Scenario{:pullback,:out}(abs2_only, x, (dy,); res1=(grad,)),
26+
Scenario{:pullback,:in}(abs2_only, x, (dy,); res1=(grad,)),
2727
]
2828
return scens
2929
end

DifferentiationInterfaceTest/src/scenarios/default.jl

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ function num_to_num_scenarios(x::Number; dx::Number, dy::Number)
2727

2828
# everyone out of place
2929
scens = Scenario[
30-
Scenario{:pushforward,:out}(f, x; tang=(dx,), res1=(dy_from_dx,)),
31-
Scenario{:pullback,:out}(f, x; tang=(dy,), res1=(dx_from_dy,)),
30+
Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)),
31+
Scenario{:pullback,:out}(f, x, (dy,); res1=(dx_from_dy,)),
3232
Scenario{:derivative,:out}(f, x; res1=der),
3333
Scenario{:second_derivative,:out}(f, x; res1=der, res2=der2),
3434
]
@@ -57,10 +57,10 @@ function onevec_to_onevec_scenarios_onearg(x::Number; dx::Number, dy::Number)
5757
scens,
5858
[
5959
Scenario{:pushforward,pl_op}(
60-
onevec_to_onevec, [x]; tang=([dx],), res1=([dy_from_dx],)
60+
onevec_to_onevec, [x], ([dx],); res1=([dy_from_dx],)
6161
),
6262
Scenario{:pullback,pl_op}(
63-
onevec_to_onevec, [x]; tang=([dy],), res1=([dx_from_dy],)
63+
onevec_to_onevec, [x], ([dy],); res1=([dx_from_dy],)
6464
),
6565
Scenario{:jacobian,pl_op}(onevec_to_onevec, [x]; res1=jac),
6666
],
@@ -85,10 +85,10 @@ function onevec_to_onevec_scenarios_twoarg(x::Number; dx::Number, dy::Number)
8585
scens,
8686
[
8787
Scenario{:pushforward,pl_op}(
88-
onevec_to_onevec!, [y], [x]; tang=([dx],), res1=([dy_from_dx],)
88+
onevec_to_onevec!, [y], [x], ([dx],); res1=([dy_from_dx],)
8989
),
9090
Scenario{:pullback,pl_op}(
91-
onevec_to_onevec!, [y], [x]; tang=([dy],), res1=([dx_from_dy],)
91+
onevec_to_onevec!, [y], [x], ([dy],); res1=([dx_from_dy],)
9292
),
9393
Scenario{:jacobian,pl_op}(onevec_to_onevec!, [y], [x]; res1=jac),
9494
],
@@ -137,14 +137,14 @@ function num_to_vec_scenarios_onearg(x::Number; dx::Number, dy::AbstractArray)
137137
append!(
138138
scens,
139139
[
140-
Scenario{:pushforward,pl_op}(f, x; tang=(dx,), res1=(dy_from_dx,)),
140+
Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,)),
141141
Scenario{:derivative,pl_op}(f, x; res1=der),
142142
Scenario{:second_derivative,pl_op}(f, x; res1=der, res2=der2),
143143
],
144144
)
145145
end
146146
for pl_op in (:out,)
147-
append!(scens, [Scenario{:pullback,pl_op}(f, x; tang=(dy,), res1=(dx_from_dy,))])
147+
append!(scens, [Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,))])
148148
end
149149
return scens
150150
end
@@ -163,15 +163,13 @@ function num_to_vec_scenarios_twoarg(x::Number; dx::Number, dy::AbstractArray)
163163
append!(
164164
scens,
165165
[
166-
Scenario{:pushforward,pl_op}(f!, y, x; tang=(dx,), res1=(dy_from_dx,)),
166+
Scenario{:pushforward,pl_op}(f!, y, x, (dx,); res1=(dy_from_dx,)),
167167
Scenario{:derivative,pl_op}(f!, y, x; res1=der),
168168
],
169169
)
170170
end
171171
for pl_op in (:out,)
172-
append!(
173-
scens, [Scenario{:pullback,pl_op}(f!, y, x; tang=(dy,), res1=(dx_from_dy,))]
174-
)
172+
append!(scens, [Scenario{:pullback,pl_op}(f!, y, x, (dy,); res1=(dx_from_dy,))])
175173
end
176174
return scens
177175
end
@@ -225,14 +223,14 @@ function num_to_mat_scenarios_onearg(x::Number; dx::Number, dy::AbstractArray)
225223
append!(
226224
scens,
227225
[
228-
Scenario{:pushforward,pl_op}(f, x; tang=(dx,), res1=(dy_from_dx,)),
226+
Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,)),
229227
Scenario{:derivative,pl_op}(f, x; res1=der),
230228
Scenario{:second_derivative,pl_op}(f, x; res1=der, res2=der2),
231229
],
232230
)
233231
end
234232
for pl_op in (:out,)
235-
append!(scens, [Scenario{:pullback,pl_op}(f, x; tang=(dy,), res1=(dx_from_dy,))])
233+
append!(scens, [Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,))])
236234
end
237235
return scens
238236
end
@@ -251,15 +249,13 @@ function num_to_mat_scenarios_twoarg(x::Number; dx::Number, dy::AbstractArray)
251249
append!(
252250
scens,
253251
[
254-
Scenario{:pushforward,pl_op}(f!, y, x; tang=(dx,), res1=(dy_from_dx,)),
252+
Scenario{:pushforward,pl_op}(f!, y, x, (dx,); res1=(dy_from_dx,)),
255253
Scenario{:derivative,pl_op}(f!, y, x; res1=der),
256254
],
257255
)
258256
end
259257
for pl_op in (:out,)
260-
append!(
261-
scens, [Scenario{:pullback,pl_op}(f!, y, x; tang=(dy,), res1=(dx_from_dy,))]
262-
)
258+
append!(scens, [Scenario{:pullback,pl_op}(f!, y, x, (dy,); res1=(dx_from_dy,))])
263259
end
264260
return scens
265261
end
@@ -330,15 +326,15 @@ function arr_to_num_scenarios_onearg(
330326
append!(
331327
scens,
332328
[
333-
Scenario{:pullback,pl_op}(f, x; tang=(dy,), res1=(dx_from_dy,)),
329+
Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,)),
334330
Scenario{:gradient,pl_op}(f, x; res1=grad),
335-
Scenario{:hvp,pl_op}(f, x; tang=(dx,), res1=grad, res2=(dg,)),
331+
Scenario{:hvp,pl_op}(f, x, (dx,); res1=grad, res2=(dg,)),
336332
Scenario{:hessian,pl_op}(f, x; res1=grad, res2=hess),
337333
],
338334
)
339335
end
340336
for pl_op in (:out,)
341-
append!(scens, [Scenario{:pushforward,pl_op}(f, x; tang=(dx,), res1=(dy_from_dx,))])
337+
append!(scens, [Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,))])
342338
end
343339
return scens
344340
end
@@ -351,8 +347,8 @@ function all_array_to_array_scenarios(f, x; dx, dy, dy_from_dx, dx_from_dy, jac)
351347
append!(
352348
scens,
353349
[
354-
Scenario{:pushforward,pl_op}(f, x; tang=(dx,), res1=(dy_from_dx,)),
355-
Scenario{:pullback,pl_op}(f, x; tang=(dy,), res1=(dx_from_dy,)),
350+
Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,)),
351+
Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,)),
356352
Scenario{:jacobian,pl_op}(f, x; res1=jac),
357353
],
358354
)
@@ -366,8 +362,8 @@ function all_array_to_array_scenarios(f!, y, x; dx, dy, dy_from_dx, dx_from_dy,
366362
append!(
367363
scens,
368364
[
369-
Scenario{:pushforward,pl_op}(f!, y, x; tang=(dx,), res1=(dy_from_dx,)),
370-
Scenario{:pullback,pl_op}(f!, y, x; tang=(dy,), res1=(dx_from_dy,)),
365+
Scenario{:pushforward,pl_op}(f!, y, x, (dx,); res1=(dy_from_dx,)),
366+
Scenario{:pullback,pl_op}(f!, y, x, (dy,); res1=(dx_from_dy,)),
371367
Scenario{:jacobian,pl_op}(f!, y, x; res1=jac),
372368
],
373369
)
@@ -628,7 +624,7 @@ function default_scenarios(;
628624
)
629625

630626
scens = map(initialscens, smallerscens) do s1, s2
631-
set_smaller(s1, s2)
627+
s1 # TODO: readd smaller scens
632628
end
633629

634630
include_batchified && append!(scens, batchify(scens))

0 commit comments

Comments
 (0)