Skip to content

Commit fc7aabc

Browse files
committed
Fixes
1 parent 1de1253 commit fc7aabc

5 files changed

Lines changed: 20 additions & 12 deletions

File tree

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,13 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
162162
for (model, x) in models_and_xs
163163
Flux.trainmode!(model)
164164
g = gradient_finite_differences(square_loss, model, x)
165-
scen = DIT.Scenario{:gradient,:out}(square_loss, model, DI.Constant(x); res1=g)
165+
scen = DIT.Scenario{:gradient,:out}(
166+
square_loss,
167+
model,
168+
DI.Constant(x);
169+
prep_args=(; x=model, contexts=(DI.Constant(x),)),
170+
res1=g,
171+
)
166172
push!(scens, scen)
167173
end
168174

@@ -189,7 +195,11 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
189195
Flux.trainmode!(model)
190196
g = gradient_finite_differences(square_loss_iterated, model, x)
191197
scen = DIT.Scenario{:gradient,:out}(
192-
square_loss_iterated, model, DI.Constant(x); res1=g
198+
square_loss_iterated,
199+
model,
200+
DI.Constant(x);
201+
prep_args=(; x=model, contexts=(DI.Constant(x),)),
202+
res1=g,
193203
)
194204
push!(scens, scen)
195205
end

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@ myjl(x::DI.Cache{<:Union{Tuple,NamedTuple}}) = map(myjl, map(DI.Cache, DI.unwrap
2323
myjl(::Nothing) = nothing
2424

2525
function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
26-
(; f, x, y, t, contexts, res1, res2) = scen
26+
(; f, x, y, t, contexts, prep_args, res1, res2) = scen
2727
return DIT.Scenario{op,pl_op,pl_fun}(;
2828
f=myjl(f),
2929
x=myjl(x),
3030
y=myjl(y),
3131
t=myjl(t),
3232
contexts=myjl(contexts),
33+
prep_args=map(myjl, prep_args),
3334
res1=myjl(res1),
3435
res2=myjl(res2),
3536
)

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ function DIT.lux_scenarios(rng::AbstractRNG=default_rng())
203203
DI.Constant(model),
204204
DI.Constant(x),
205205
DI.Constant(st);
206+
prep_args=(
207+
x=ComponentArray(ps),
208+
contexts=(DI.Constant(model), DI.Constant(x), DI.Constant(st)),
209+
),
206210
res1=g,
207211
)
208212
push!(scens, scen)

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,14 @@ end
3636
mystatic(::Nothing) = nothing
3737

3838
function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
39-
(; f, x, y, t, contexts, res1, res2) = scen
39+
(; f, x, y, t, contexts, prep_args, res1, res2) = scen
4040
return DIT.Scenario{op,pl_op,pl_fun}(;
4141
f=mystatic(f),
4242
x=mystatic(x),
4343
y=pl_fun == :in ? mymutablestatic(y) : mystatic(y),
4444
t=mystatic(t),
4545
contexts=mystatic(contexts),
46+
prep_args=map(mystatic, prep_args),
4647
res1=mystatic(res1),
4748
res2=mystatic(res2),
4849
)

DifferentiationInterfaceTest/src/utils.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,6 @@ myzero(::Nothing) = nothing
66
mysimilar(x::Number) = one(x)
77
mysimilar(x::AbstractArray) = similar(x)
88
mysimilar(x::Union{Tuple,NamedTuple}) = map(mysimilar, x)
9-
mysimilar(x) = deepcopy(x)
10-
11-
myrandom(rng::AbstractRNG, x::Number) = randn(rng, typeof(x))
12-
myrandom(rng::AbstractRNG, x::AbstractArray) = map(Base.Fix1(myrandom, rng), x)
13-
myrandom(rng::AbstractRNG, x::Union{Tuple,NamedTuple}) = map(Base.Fix1(myrandom, rng), x)
14-
myrandom(rng::AbstractRNG, x) = deepcopy(x)
15-
16-
myrandom(x) = myrandom(default_rng(), x)
179

1810
mysize(x::Number) = size(x)
1911
mysize(x::AbstractArray) = size(x)

0 commit comments

Comments
 (0)