Skip to content
Merged
4 changes: 2 additions & 2 deletions DifferentiationInterface/src/utils/prep.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ function check_prep(
if SIG != EXEC_SIG
throw(
PreparationMismatchError(
SIG, EXEC_SIG; format=[:f, :backend, :x, :tang, :contexts]
SIG, EXEC_SIG; format=[:f, :backend, :x, :t, :contexts]
),
)
end
Expand All @@ -213,7 +213,7 @@ function check_prep(
if SIG != EXEC_SIG
throw(
PreparationMismatchError(
SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :tang, :contexts]
SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :t, :contexts]
),
)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ function differentiatewith_scenarios()
DIT.function_place(scen) == :out
end
good_scens = map(bad_scens) do scen
DIT.change_function(
scen, DifferentiateWith(scen.f, AutoFiniteDiff()); keep_smaller=false
)
DIT.change_function(scen, DifferentiateWith(scen.f, AutoFiniteDiff()))
end
return good_scens
end
Expand Down
1 change: 1 addition & 0 deletions DifferentiationInterface/test/Back/FiniteDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ end
include_cachified=true,
include_constantorcachified=true,
use_tuples=true,
include_smaller=true,
);
excluded=[:second_derivative, :hvp],
logging=LOGGING,
Expand Down
1 change: 1 addition & 0 deletions DifferentiationInterface/test/Back/ForwardDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ end
include_cachified=true,
include_constantorcachified=true,
use_tuples=true,
include_smaller=true,
);
logging=LOGGING,
)
Expand Down
4 changes: 2 additions & 2 deletions DifferentiationInterface/test/Core/Internals/signature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ end
- exec: Nothing
- backend: ✅
- x: ✅
- tang: ✅
- t: ✅
- contexts: ✅
""" pushforward(nothing, prep, backend, x, (x,), Constant(c))
end
Expand All @@ -119,7 +119,7 @@ end
- y: ✅
- backend: ✅
- x: ✅
- tang: ✅
- t: ✅
- contexts: ✅
""" pushforward(nothing, y, prep, backend, x, (x,))
end
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ end
@testset "Dense" begin
test_differentiation(
vcat(backends, second_order_backends),
default_scenarios(; include_constantified=true);
default_scenarios(; include_constantified=true, include_smaller=true);
logging=LOGGING,
)

Expand Down
5 changes: 5 additions & 0 deletions DifferentiationInterfaceTest/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed

- Specify preparation arguments in DIT Scenario ([#786])

## [0.9.6] - 2025-03-28

### Added
Expand All @@ -18,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[unreleased]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.9.6...main
[0.9.6]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.9.5...DifferentiationInterfaceTest-v0.9.6

[#786]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/786
[#749]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/749
[#748]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/748
[#745]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/745
4 changes: 2 additions & 2 deletions DifferentiationInterfaceTest/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterfaceTest"
uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.9.6"
version = "0.10.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -44,7 +44,7 @@ AllocCheck = "0.2"
Chairmarks = "1.2.1"
ComponentArrays = "0.15"
DataFrames = "1.6.1"
DifferentiationInterface = "0.6.0"
DifferentiationInterface = "0.6.53"
DocStringExtensions = "0.8,0.9"
ExplicitImports = "1.10.1"
FiniteDiff = "2.27.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,13 @@ function comp_to_num_scenarios_onearg(x::ComponentVector; dx::AbstractVector, dy
append!(
scens,
[
DIT.Scenario{:pullback,pl_op}(f, x; tang=(dy,), res1=(dx_from_dy,)),
DIT.Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,)),
DIT.Scenario{:gradient,pl_op}(f, x; res1=grad),
],
)
end
for pl_op in (:out,)
append!(
scens, [DIT.Scenario{:pushforward,pl_op}(f, x; tang=(dx,), res1=(dy_from_dx,))]
)
append!(scens, [DIT.Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,))])
end
return scens
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,11 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
g = gradient_finite_differences(square_loss, model, x)

scen = DIT.Scenario{:gradient,:out}(
square_loss, model; contexts=(DI.Constant(x),), res1=g
square_loss,
model,
DI.Constant(x);
prep_args=(x=model, contexts=(DI.Constant(x),)),
res1=g,
)
push!(scens, scen)

Expand Down Expand Up @@ -163,7 +167,11 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
Flux.trainmode!(model)
g = gradient_finite_differences(square_loss, model, x)
scen = DIT.Scenario{:gradient,:out}(
square_loss, model; contexts=(DI.Constant(x),), res1=g
square_loss,
model,
DI.Constant(x);
prep_args=(; x=model, contexts=(DI.Constant(x),)),
res1=g,
)
push!(scens, scen)
end
Expand Down Expand Up @@ -191,7 +199,11 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
Flux.trainmode!(model)
g = gradient_finite_differences(square_loss_iterated, model, x)
scen = DIT.Scenario{:gradient,:out}(
square_loss_iterated, model; contexts=(DI.Constant(x),), res1=g
square_loss_iterated,
model,
DI.Constant(x);
prep_args=(; x=model, contexts=(DI.Constant(x),)),
res1=g,
)
push!(scens, scen)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,17 @@ myjl(x::DI.Cache{<:Union{Tuple,NamedTuple}}) = map(myjl, map(DI.Cache, DI.unwrap
myjl(::Nothing) = nothing

function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
(; f, x, y, tang, contexts, res1, res2) = scen
return DIT.Scenario{op,pl_op,pl_fun}(
myjl(f);
(; f, x, y, t, contexts, prep_args, res1, res2, name) = scen
return DIT.Scenario{op,pl_op,pl_fun}(;
f=myjl(f),
x=myjl(x),
y=myjl(y),
tang=myjl(tang),
t=myjl(t),
contexts=myjl(contexts),
prep_args=map(myjl, prep_args),
res1=myjl(res1),
res2=myjl(res2),
name,
)
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,14 @@ function DIT.lux_scenarios(rng::AbstractRNG=default_rng())
)
scen = DIT.Scenario{:gradient,:out}(
square_loss,
ComponentArray(ps);
contexts=(DI.Constant(model), DI.Constant(x), DI.Constant(st)),
ComponentArray(ps),
DI.Constant(model),
DI.Constant(x),
DI.Constant(st);
prep_args=(
x=ComponentArray(ps),
contexts=(DI.Constant(model), DI.Constant(x), DI.Constant(st)),
),
res1=g,
)
push!(scens, scen)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module DifferentiationInterfaceTestStaticArraysExt
import DifferentiationInterface as DI
import DifferentiationInterfaceTest as DIT
using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm
using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector
using StaticArrays: StaticArray, MArray, MMatrix, MVector, SArray, SMatrix, SVector

static_num_to_vec(x::Number) = sin.(SVector(1, 2) .* x)
static_num_to_mat(x::Number) = hcat(static_num_to_vec(x), static_num_to_vec(3x))
Expand Down Expand Up @@ -36,15 +36,23 @@ end
mystatic(::Nothing) = nothing

function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
(; f, x, y, tang, contexts, res1, res2) = scen
return DIT.Scenario{op,pl_op,pl_fun}(
mystatic(f);
(; f, x, y, t, contexts, prep_args, res1, res2, name) = scen
new_prep_args = (;
x=mystatic(prep_args.x), contexts=map(mystatic, prep_args.contexts), t=mystatic(t)
)
if pl_fun == :in
new_prep_args = (; new_prep_args..., y=mymutablestatic(prep_args.y))
end
return DIT.Scenario{op,pl_op,pl_fun}(;
f=mystatic(f),
x=mystatic(x),
y=pl_fun == :in ? mymutablestatic(y) : mystatic(y),
tang=mystatic(tang),
t=mystatic(t),
contexts=mystatic(contexts),
prep_args=new_prep_args,
res1=mystatic(res1),
res2=mystatic(res2),
name=name,
)
end

Expand Down
12 changes: 6 additions & 6 deletions DifferentiationInterfaceTest/src/scenarios/allocfree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ function identity_scenarios(x::Number; dx::Number, dy::Number)
der = one(x)

return [
Scenario{:pushforward,:out}(f, x; tang=(dx,), res1=(dy_from_dx,)),
Scenario{:pullback,:out}(f, x; tang=(dy,), res1=(dx_from_dy,)),
Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)),
Scenario{:pullback,:out}(f, x, (dy,); res1=(dx_from_dy,)),
Scenario{:derivative,:out}(f, x; res1=der),
]
end
Expand All @@ -19,8 +19,8 @@ function sum_scenarios(x::AbstractArray; dx::AbstractArray, dy::Number)
grad .= one(eltype(x))

return [
Scenario{:pushforward,:out}(f, x; tang=(dx,), res1=(dy_from_dx,)),
Scenario{:pullback,:in}(f, x; tang=(dy,), res1=(dx_from_dy,)),
Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)),
Scenario{:pullback,:in}(f, x, (dy,); res1=(dx_from_dy,)),
Scenario{:gradient,:in}(f, x; res1=grad),
]
end
Expand All @@ -34,8 +34,8 @@ function copyto!_scenarios(x::AbstractArray; dx::AbstractArray, dy::AbstractArra
jac = Matrix(Diagonal(ones(eltype(x), length(x))))

return [
Scenario{:pushforward,:in}(f!, y, x; tang=(dx,), res1=(dy_from_dx,)),
Scenario{:pullback,:in}(f!, y, x; tang=(dy,), res1=(dx_from_dy,)),
Scenario{:pushforward,:in}(f!, y, x, (dx,); res1=(dy_from_dx,)),
Scenario{:pullback,:in}(f!, y, x, (dy,); res1=(dx_from_dy,)),
Scenario{:jacobian,:in}(f!, y, x; res1=jac),
]
end
Expand Down
8 changes: 4 additions & 4 deletions DifferentiationInterfaceTest/src/scenarios/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ function complex_holomorphic_gradient_scenarios()
scens = Scenario[
Scenario{:gradient,:out}(square_only, x; res1=grad),
Scenario{:gradient,:in}(square_only, x; res1=grad),
Scenario{:pullback,:out}(square_only, x; tang=(dy,), res1=(grad,)),
Scenario{:pullback,:in}(square_only, x; tang=(dy,), res1=(grad,)),
Scenario{:pullback,:out}(square_only, x, (dy,); res1=(grad,)),
Scenario{:pullback,:in}(square_only, x, (dy,); res1=(grad,)),
]
return scens
end
Expand All @@ -22,8 +22,8 @@ function complex_gradient_scenarios()
scens = Scenario[
Scenario{:gradient,:out}(abs2_only, x; res1=grad),
Scenario{:gradient,:in}(abs2_only, x; res1=grad),
Scenario{:pullback,:out}(abs2_only, x; tang=(dy,), res1=(grad,)),
Scenario{:pullback,:in}(abs2_only, x; tang=(dy,), res1=(grad,)),
Scenario{:pullback,:out}(abs2_only, x, (dy,); res1=(grad,)),
Scenario{:pullback,:in}(abs2_only, x, (dy,); res1=(grad,)),
]
return scens
end
Expand Down
Loading