From 8ab3a92181be5651d381316f15431007de4c9db3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 9 May 2025 17:05:48 +0200 Subject: [PATCH 01/12] feat!: specify preparation arguments in DIT `Scenario` --- DifferentiationInterface/src/utils/prep.jl | 4 +- DifferentiationInterfaceTest/Project.toml | 4 +- ...ntiationInterfaceTestComponentArraysExt.jl | 6 +- .../DifferentiationInterfaceTestFluxExt.jl | 6 +- ...DifferentiationInterfaceTestJLArraysExt.jl | 8 +- .../DifferentiationInterfaceTestLuxExt.jl | 6 +- ...erentiationInterfaceTestStaticArraysExt.jl | 8 +- .../src/scenarios/allocfree.jl | 12 +- .../src/scenarios/complex.jl | 8 +- .../src/scenarios/default.jl | 48 +- .../src/scenarios/modify.jl | 120 +++-- .../src/scenarios/scenario.jl | 117 +++-- .../src/test_differentiation.jl | 7 +- .../src/tests/allocs_eval.jl | 175 ++++--- .../src/tests/benchmark_eval.jl | 249 +++++---- .../src/tests/correctness_eval.jl | 491 +++++------------- .../src/tests/type_stability_eval.jl | 144 ++--- 17 files changed, 636 insertions(+), 777 deletions(-) diff --git a/DifferentiationInterface/src/utils/prep.jl b/DifferentiationInterface/src/utils/prep.jl index 169afda86..2d0107d96 100644 --- a/DifferentiationInterface/src/utils/prep.jl +++ b/DifferentiationInterface/src/utils/prep.jl @@ -198,7 +198,7 @@ function check_prep( if SIG != EXEC_SIG throw( PreparationMismatchError( - SIG, EXEC_SIG; format=[:f, :backend, :x, :tang, :contexts] + SIG, EXEC_SIG; format=[:f, :backend, :x, :t, :contexts] ), ) end @@ -213,7 +213,7 @@ function check_prep( if SIG != EXEC_SIG throw( PreparationMismatchError( - SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :tang, :contexts] + SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :t, :contexts] ), ) end diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 26b7ccd4c..be5e850a6 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterfaceTest" uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.9.6" +version = "0.10.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -43,7 +43,7 @@ AllocCheck = "0.2" Chairmarks = "1.2.1" ComponentArrays = "0.15" DataFrames = "1.6.1" -DifferentiationInterface = "0.6.0" +DifferentiationInterface = "0.6.53" DocStringExtensions = "0.8,0.9" ExplicitImports = "1.10.1" FiniteDifferences = "0.12" diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestComponentArraysExt/DifferentiationInterfaceTestComponentArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestComponentArraysExt/DifferentiationInterfaceTestComponentArraysExt.jl index ec5e37b27..2b432ad3a 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestComponentArraysExt/DifferentiationInterfaceTestComponentArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestComponentArraysExt/DifferentiationInterfaceTestComponentArraysExt.jl @@ -33,15 +33,13 @@ function comp_to_num_scenarios_onearg(x::ComponentVector; dx::AbstractVector, dy append!( scens, [ - DIT.Scenario{:pullback,pl_op}(f, x; tang=(dy,), res1=(dx_from_dy,)), + DIT.Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,)), DIT.Scenario{:gradient,pl_op}(f, x; res1=grad), ], ) end for pl_op in (:out,) - append!( - scens, [DIT.Scenario{:pushforward,pl_op}(f, x; tang=(dx,), res1=(dy_from_dx,))] - ) + append!(scens, [DIT.Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,))]) end return scens end diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl index d0825cee3..040c4d6ba 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl @@ -162,9 +162,7 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) for (model, x) in models_and_xs Flux.trainmode!(model) g = gradient_finite_differences(square_loss, model, x) - scen = DIT.Scenario{:gradient,:out}( - square_loss, model; contexts=(DI.Constant(x),), res1=g - ) + scen = DIT.Scenario{:gradient,:out}(square_loss, model, DI.Constant(x); res1=g) push!(scens, scen) end @@ -191,7 +189,7 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) Flux.trainmode!(model) g = gradient_finite_differences(square_loss_iterated, model, x) scen = DIT.Scenario{:gradient,:out}( - square_loss_iterated, model; contexts=(DI.Constant(x),), res1=g + square_loss_iterated, model, DI.Constant(x); res1=g ) push!(scens, scen) end diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl index 11dca2cc5..16053fad5 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl @@ -23,12 +23,12 @@ myjl(x::DI.Cache{<:Union{Tuple,NamedTuple}}) = map(myjl, map(DI.Cache, DI.unwrap myjl(::Nothing) = nothing function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} - (; f, x, y, tang, contexts, res1, res2) = scen - return DIT.Scenario{op,pl_op,pl_fun}( - myjl(f); + (; f, x, y, t, contexts, res1, res2) = scen + return DIT.Scenario{op,pl_op,pl_fun}(; + f=myjl(f), x=myjl(x), y=myjl(y), - tang=myjl(tang), + t=myjl(t), contexts=myjl(contexts), res1=myjl(res1), res2=myjl(res2), diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl index 3b8507a2d..d5c05a40a 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl @@ -199,8 +199,10 @@ function DIT.lux_scenarios(rng::AbstractRNG=default_rng()) ) scen = DIT.Scenario{:gradient,:out}( square_loss, - ComponentArray(ps); - contexts=(DI.Constant(model), DI.Constant(x), DI.Constant(st)), + ComponentArray(ps), + DI.Constant(model), + DI.Constant(x), + DI.Constant(st); res1=g, ) push!(scens, scen) diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl index a7620b516..fa33c5818 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl @@ -36,12 +36,12 @@ end mystatic(::Nothing) = nothing function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} - (; f, x, y, tang, contexts, res1, res2) = scen - return DIT.Scenario{op,pl_op,pl_fun}( - mystatic(f); + (; f, x, y, t, contexts, res1, res2) = scen + return DIT.Scenario{op,pl_op,pl_fun}(; + f=mystatic(f), x=mystatic(x), y=pl_fun == :in ? mymutablestatic(y) : mystatic(y), - tang=mystatic(tang), + t=mystatic(t), contexts=mystatic(contexts), res1=mystatic(res1), res2=mystatic(res2), diff --git a/DifferentiationInterfaceTest/src/scenarios/allocfree.jl b/DifferentiationInterfaceTest/src/scenarios/allocfree.jl index f941c1fd1..2134faabc 100644 --- a/DifferentiationInterfaceTest/src/scenarios/allocfree.jl +++ b/DifferentiationInterfaceTest/src/scenarios/allocfree.jl @@ -5,8 +5,8 @@ function identity_scenarios(x::Number; dx::Number, dy::Number) der = one(x) return [ - Scenario{:pushforward,:out}(f, x; tang=(dx,), res1=(dy_from_dx,)), - Scenario{:pullback,:out}(f, x; tang=(dy,), res1=(dx_from_dy,)), + Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)), + Scenario{:pullback,:out}(f, x, (dy,); res1=(dx_from_dy,)), Scenario{:derivative,:out}(f, x; res1=der), ] end @@ -19,8 +19,8 @@ function sum_scenarios(x::AbstractArray; dx::AbstractArray, dy::Number) grad .= one(eltype(x)) return [ - Scenario{:pushforward,:out}(f, x; tang=(dx,), res1=(dy_from_dx,)), - Scenario{:pullback,:in}(f, x; tang=(dy,), res1=(dx_from_dy,)), + Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)), + Scenario{:pullback,:in}(f, x, (dy,); res1=(dx_from_dy,)), Scenario{:gradient,:in}(f, x; res1=grad), ] end @@ -34,8 +34,8 @@ function copyto!_scenarios(x::AbstractArray; dx::AbstractArray, dy::AbstractArra jac = Matrix(Diagonal(ones(eltype(x), length(x)))) return [ - Scenario{:pushforward,:in}(f!, y, x; tang=(dx,), res1=(dy_from_dx,)), - Scenario{:pullback,:in}(f!, y, x; tang=(dy,), res1=(dx_from_dy,)), + Scenario{:pushforward,:in}(f!, y, x, (dx,); res1=(dy_from_dx,)), + Scenario{:pullback,:in}(f!, y, x, (dy,); res1=(dx_from_dy,)), Scenario{:jacobian,:in}(f!, y, x; res1=jac), ] end diff --git a/DifferentiationInterfaceTest/src/scenarios/complex.jl b/DifferentiationInterfaceTest/src/scenarios/complex.jl index 2bb673487..39980d521 100644 --- a/DifferentiationInterfaceTest/src/scenarios/complex.jl +++ b/DifferentiationInterfaceTest/src/scenarios/complex.jl @@ -9,8 +9,8 @@ function complex_holomorphic_gradient_scenarios() scens = Scenario[ Scenario{:gradient,:out}(square_only, x; res1=grad), Scenario{:gradient,:in}(square_only, x; res1=grad), - Scenario{:pullback,:out}(square_only, x; tang=(dy,), res1=(grad,)), - Scenario{:pullback,:in}(square_only, x; tang=(dy,), res1=(grad,)), + Scenario{:pullback,:out}(square_only, x, (dy,); res1=(grad,)), + Scenario{:pullback,:in}(square_only, x, (dy,); res1=(grad,)), ] return scens end @@ -22,8 +22,8 @@ function complex_gradient_scenarios() scens = Scenario[ Scenario{:gradient,:out}(abs2_only, x; res1=grad), Scenario{:gradient,:in}(abs2_only, x; res1=grad), - Scenario{:pullback,:out}(abs2_only, x; tang=(dy,), res1=(grad,)), - Scenario{:pullback,:in}(abs2_only, x; tang=(dy,), res1=(grad,)), + Scenario{:pullback,:out}(abs2_only, x, (dy,); res1=(grad,)), + Scenario{:pullback,:in}(abs2_only, x, (dy,); res1=(grad,)), ] return scens end diff --git a/DifferentiationInterfaceTest/src/scenarios/default.jl b/DifferentiationInterfaceTest/src/scenarios/default.jl index 7523365c8..b601a3b70 100644 --- a/DifferentiationInterfaceTest/src/scenarios/default.jl +++ b/DifferentiationInterfaceTest/src/scenarios/default.jl @@ -27,8 +27,8 @@ function num_to_num_scenarios(x::Number; dx::Number, dy::Number) # everyone out of place scens = Scenario[ - Scenario{:pushforward,:out}(f, x; tang=(dx,), res1=(dy_from_dx,)), - Scenario{:pullback,:out}(f, x; tang=(dy,), res1=(dx_from_dy,)), + Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)), + Scenario{:pullback,:out}(f, x, (dy,); res1=(dx_from_dy,)), Scenario{:derivative,:out}(f, x; res1=der), Scenario{:second_derivative,:out}(f, x; res1=der, res2=der2), ] @@ -57,10 +57,10 @@ function onevec_to_onevec_scenarios_onearg(x::Number; dx::Number, dy::Number) scens, [ Scenario{:pushforward,pl_op}( - onevec_to_onevec, [x]; tang=([dx],), res1=([dy_from_dx],) + onevec_to_onevec, [x], ([dx],); res1=([dy_from_dx],) ), Scenario{:pullback,pl_op}( - onevec_to_onevec, [x]; tang=([dy],), res1=([dx_from_dy],) + onevec_to_onevec, [x], ([dy],); res1=([dx_from_dy],) ), Scenario{:jacobian,pl_op}(onevec_to_onevec, [x]; res1=jac), ], @@ -85,10 +85,10 @@ function onevec_to_onevec_scenarios_twoarg(x::Number; dx::Number, dy::Number) scens, [ Scenario{:pushforward,pl_op}( - onevec_to_onevec!, [y], [x]; tang=([dx],), res1=([dy_from_dx],) + onevec_to_onevec!, [y], [x], ([dx],); res1=([dy_from_dx],) ), Scenario{:pullback,pl_op}( - onevec_to_onevec!, [y], [x]; tang=([dy],), res1=([dx_from_dy],) + onevec_to_onevec!, [y], [x], ([dy],); res1=([dx_from_dy],) ), Scenario{:jacobian,pl_op}(onevec_to_onevec!, [y], [x]; res1=jac), ], @@ -137,14 +137,14 @@ function num_to_vec_scenarios_onearg(x::Number; dx::Number, dy::AbstractArray) append!( scens, [ - Scenario{:pushforward,pl_op}(f, x; tang=(dx,), res1=(dy_from_dx,)), + Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,)), Scenario{:derivative,pl_op}(f, x; res1=der), Scenario{:second_derivative,pl_op}(f, x; res1=der, res2=der2), ], ) end for pl_op in (:out,) - append!(scens, [Scenario{:pullback,pl_op}(f, x; tang=(dy,), res1=(dx_from_dy,))]) + append!(scens, [Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,))]) end return scens end @@ -163,15 +163,13 @@ function num_to_vec_scenarios_twoarg(x::Number; dx::Number, dy::AbstractArray) append!( scens, [ - Scenario{:pushforward,pl_op}(f!, y, x; tang=(dx,), res1=(dy_from_dx,)), + Scenario{:pushforward,pl_op}(f!, y, x, (dx,); res1=(dy_from_dx,)), Scenario{:derivative,pl_op}(f!, y, x; res1=der), ], ) end for pl_op in (:out,) - append!( - scens, [Scenario{:pullback,pl_op}(f!, y, x; tang=(dy,), res1=(dx_from_dy,))] - ) + append!(scens, [Scenario{:pullback,pl_op}(f!, y, x, (dy,); res1=(dx_from_dy,))]) end return scens end @@ -225,14 +223,14 @@ function num_to_mat_scenarios_onearg(x::Number; dx::Number, dy::AbstractArray) append!( scens, [ - Scenario{:pushforward,pl_op}(f, x; tang=(dx,), res1=(dy_from_dx,)), + Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,)), Scenario{:derivative,pl_op}(f, x; res1=der), Scenario{:second_derivative,pl_op}(f, x; res1=der, res2=der2), ], ) end for pl_op in (:out,) - append!(scens, [Scenario{:pullback,pl_op}(f, x; tang=(dy,), res1=(dx_from_dy,))]) + append!(scens, [Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,))]) end return scens end @@ -251,15 +249,13 @@ function num_to_mat_scenarios_twoarg(x::Number; dx::Number, dy::AbstractArray) append!( scens, [ - Scenario{:pushforward,pl_op}(f!, y, x; tang=(dx,), res1=(dy_from_dx,)), + Scenario{:pushforward,pl_op}(f!, y, x, (dx,); res1=(dy_from_dx,)), Scenario{:derivative,pl_op}(f!, y, x; res1=der), ], ) end for pl_op in (:out,) - append!( - scens, [Scenario{:pullback,pl_op}(f!, y, x; tang=(dy,), res1=(dx_from_dy,))] - ) + append!(scens, [Scenario{:pullback,pl_op}(f!, y, x, (dy,); res1=(dx_from_dy,))]) end return scens end @@ -330,15 +326,15 @@ function arr_to_num_scenarios_onearg( append!( scens, [ - Scenario{:pullback,pl_op}(f, x; tang=(dy,), res1=(dx_from_dy,)), + Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,)), Scenario{:gradient,pl_op}(f, x; res1=grad), - Scenario{:hvp,pl_op}(f, x; tang=(dx,), res1=grad, res2=(dg,)), + Scenario{:hvp,pl_op}(f, x, (dx,); res1=grad, res2=(dg,)), Scenario{:hessian,pl_op}(f, x; res1=grad, res2=hess), ], ) end for pl_op in (:out,) - append!(scens, [Scenario{:pushforward,pl_op}(f, x; tang=(dx,), res1=(dy_from_dx,))]) + append!(scens, [Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,))]) end return scens end @@ -351,8 +347,8 @@ function all_array_to_array_scenarios(f, x; dx, dy, dy_from_dx, dx_from_dy, jac) append!( scens, [ - Scenario{:pushforward,pl_op}(f, x; tang=(dx,), res1=(dy_from_dx,)), - Scenario{:pullback,pl_op}(f, x; tang=(dy,), res1=(dx_from_dy,)), + Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,)), + Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,)), Scenario{:jacobian,pl_op}(f, x; res1=jac), ], ) @@ -366,8 +362,8 @@ function all_array_to_array_scenarios(f!, y, x; dx, dy, dy_from_dx, dx_from_dy, append!( scens, [ - Scenario{:pushforward,pl_op}(f!, y, x; tang=(dx,), res1=(dy_from_dx,)), - Scenario{:pullback,pl_op}(f!, y, x; tang=(dy,), res1=(dx_from_dy,)), + Scenario{:pushforward,pl_op}(f!, y, x, (dx,); res1=(dy_from_dx,)), + Scenario{:pullback,pl_op}(f!, y, x, (dy,); res1=(dx_from_dy,)), Scenario{:jacobian,pl_op}(f!, y, x; res1=jac), ], ) @@ -628,7 +624,7 @@ function default_scenarios(; ) scens = map(initialscens, smallerscens) do s1, s2 - set_smaller(s1, s2) + s1 # TODO: readd smaller scens end include_batchified && append!(scens, batchify(scens)) diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index ee794892d..b6cb372cb 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -6,15 +6,15 @@ abstract type FunctionModifier end Return a new `Scenario` identical to `scen` except for the first- and second-order results which are set to zero. """ function Base.zero(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} - return Scenario{op,pl_op,pl_fun}( - scen.f; + return Scenario{op,pl_op,pl_fun}(; + f=scen.f, x=scen.x, y=scen.y, - tang=scen.tang, + t=scen.t, contexts=scen.contexts, res1=myzero(scen.res1), res2=myzero(scen.res2), - smaller=isnothing(scen.smaller) ? nothing : zero(scen.smaller), + prep_args=scen.prep_args, name=isnothing(scen.name) ? nothing : scen.name * " [zero]", ) end @@ -27,39 +27,19 @@ Return a new `Scenario` identical to `scen` except for the function `f` which is function change_function( scen::Scenario{op,pl_op,pl_fun}, new_f; keep_smaller ) where {op,pl_op,pl_fun} - return Scenario{op,pl_op,pl_fun}( - new_f; + return Scenario{op,pl_op,pl_fun}(; + f=new_f, x=scen.x, y=scen.y, - tang=scen.tang, + t=scen.t, contexts=scen.contexts, res1=scen.res1, res2=scen.res2, - smaller=if isnothing(scen.smaller) || !keep_smaller - nothing - else - change_function(scen.smaller, new_f; keep_smaller=false) - end, + prep_args=scen.prep_args, name=isnothing(scen.name) ? nothing : scen.name * " [new function]", ) end -function set_smaller( - scen::Scenario{op,pl_op,pl_fun}, smaller::Scenario -) where {op,pl_op,pl_fun} - @assert scen.f == smaller.f - return Scenario{op,pl_op,pl_fun}( - scen.f; - x=scen.x, - y=scen.y, - tang=scen.tang, - contexts=scen.contexts, - res1=scen.res1, - res2=scen.res2, - smaller=smaller, - ) -end - """ batchify(scen::Scenario) @@ -68,33 +48,46 @@ Return a new `Scenario` identical to `scen` except for the tangents `tang` and a Only works if `scen` is a `pushforward`, `pullback` or `hvp` scenario. """ function batchify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} - (; f, x, y, tang, contexts, res1, res2, smaller) = scen + (; f, x, y, t, contexts, res1, res2, prep_args) = scen + new_t = (only(t), -only(t)) + new_prep_args = if pl_fun == :out + (; + x=prep_args.x, + contexts=prep_args.contexts, + t=(only(prep_args.t), -only(prep_args.t)), + ) + else + (; + y=prep_args.y, + x=prep_args.x, + contexts=prep_args.contexts, + t=(only(prep_args.t), -only(prep_args.t)), + ) + end if op == :pushforward || op == :pullback - new_tang = (only(tang), -only(tang)) new_res1 = (only(res1), -only(res1)) - return Scenario{op,pl_op,pl_fun}( - f; + return Scenario{op,pl_op,pl_fun}(; + f, x, y, - tang=new_tang, + t=new_t, contexts, res1=new_res1, res2, - smaller=isnothing(smaller) ? nothing : batchify(smaller), + prep_args=new_prep_args, name=isnothing(scen.name) ? nothing : scen.name * " [batchified]", ) elseif op == :hvp - new_tang = (only(tang), -only(tang)) new_res2 = (only(res2), -only(res2)) - return Scenario{op,pl_op,pl_fun}( - f; + return Scenario{op,pl_op,pl_fun}(; + f, x, y, - tang=new_tang, + t=new_t, contexts, res1, res2=new_res2, - smaller=isnothing(smaller) ? nothing : batchify(smaller), + prep_args=new_prep_args, name=isnothing(scen.name) ? nothing : scen.name * " [batchified]", ) end @@ -145,16 +138,16 @@ function closurify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} a = 3.0 b = [4.0] closure_f = WritableClosure{pl_fun}(f, x_buffer, y_buffer, a, b) - return Scenario{op,pl_op,pl_fun}( - closure_f; - x = scen.x, - y = mymultiply(scen.y, a + only(b)), - tang = scen.tang, - contexts = scen.contexts, - res1 = mymultiply(scen.res1, a + only(b)), - res2 = mymultiply(scen.res2, a + only(b)), - smaller = nothing, - name = isnothing(scen.name) ? nothing : scen.name * " [closurified]", + return Scenario{op,pl_op,pl_fun}(; + f=closure_f, + x=scen.x, + y=mymultiply(scen.y, a + only(b)), + t=scen.t, + contexts=scen.contexts, + res1=mymultiply(scen.res1, a + only(b)), + res2=mymultiply(scen.res2, a + only(b)), + prep_args=scen.prep_args, + name=isnothing(scen.name) ? nothing : scen.name * " [closurified]", ) end @@ -188,15 +181,15 @@ function constantify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} @assert isempty(scen.contexts) multiply_f = MultiplyByConstant{pl_fun}(f) a = 3.0 - return Scenario{op,pl_op,pl_fun}( - multiply_f; + return Scenario{op,pl_op,pl_fun}(; + f=multiply_f, x=scen.x, y=mymultiply(scen.y, a), - tang=scen.tang, + t=scen.t, contexts=(Constant(a),), res1=mymultiply(scen.res1, a), res2=mymultiply(scen.res2, a), - smaller=isnothing(scen.smaller) ? nothing : constantify(scen.smaller), + prep_args=(; scen.prep_args..., contexts=(Constant(-a),)), name=isnothing(scen.name) ? nothing : scen.name * " [constantified]", ) end @@ -257,15 +250,15 @@ function cachify(scen::Scenario{op,pl_op,pl_fun}; use_tuples) where {op,pl_op,pl mysimilar(scen.y) end end - return Scenario{op,pl_op,pl_fun}( - cache_f; + return Scenario{op,pl_op,pl_fun}(; + f=cache_f, x=scen.x, y=scen.y, - tang=scen.tang, + t=scen.t, contexts=(Cache(y_cache),), res1=scen.res1, res2=scen.res2, - smaller=isnothing(scen.smaller) ? nothing : cachify(scen.smaller; use_tuples), + prep_args=(; scen.prep_args..., contexts=(Cache(y_cache),)), name=isnothing(scen.name) ? nothing : scen.name * " [cachified]", ) end @@ -332,15 +325,20 @@ function constantorcachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_f else (; cache=mysimilar(scen.y), constant=(; a, b)) end - return Scenario{op,pl_op,pl_fun}( - constantorcache_f; + prep_constantorcache = if scen.y isa Number + (; cache=[myzero(scen.y)], constant=(; a=2a, b=3b)) + else + (; cache=mysimilar(scen.y), constant=(; a=2a, b=3b)) + end + return Scenario{op,pl_op,pl_fun}(; + f=constantorcache_f, x=scen.x, y=mymultiply(scen.y, a + only(b)), - tang=scen.tang, + t=scen.t, contexts=(ConstantOrCache(constantorcache),), res1=mymultiply(scen.res1, a + only(b)), res2=mymultiply(scen.res2, a + only(b)), - smaller=isnothing(scen.smaller) ? nothing : constantorcachify(scen.smaller), + prep_args=(; scen.prep_args..., contexts=(ConstantOrCache(prep_constantorcache),)), name=isnothing(scen.name) ? nothing : scen.name * " [constantorcachified]", ) end diff --git a/DifferentiationInterfaceTest/src/scenarios/scenario.jl b/DifferentiationInterfaceTest/src/scenarios/scenario.jl index a36e1c1e5..c1d7e31cf 100644 --- a/DifferentiationInterfaceTest/src/scenarios/scenario.jl +++ b/DifferentiationInterfaceTest/src/scenarios/scenario.jl @@ -13,81 +13,118 @@ This generic type should never be used directly: use the specific constructor co # Constructors - Scenario{op,pl_op}(f, x; tang, contexts, res1, res2, name) - Scenario{op,pl_op}(f!, y, x; tang, contexts, res1, res2, name) + Scenario{op,pl_op}(f, x, [t], contexts...; res1, res2, name) + Scenario{op,pl_op}(f!, y, x, [t,] contexts...; res1, res2, name) # Fields $(TYPEDFIELDS) """ -struct Scenario{op,pl_op,pl_fun,F,X,Y,T<:Union{Nothing,NTuple},C<:Tuple,R1,R2,S} +struct Scenario{op,pl_op,pl_fun,F,X,Y,T<:Union{Nothing,NTuple},C<:Tuple,R1,R2,P<:NamedTuple} "function `f` (if `pl_fun==:out`) or `f!` (if `pl_fun==:in`) to apply" f::F - "primal input" - x::X "primal output" y::Y - "tangents for pushforward, pullback or HVP" - tang::T + "primal input" + x::X + "tangents (if applicable)" + t::T "contexts (if applicable)" contexts::C "first-order result of the operator (if applicable)" res1::R1 "second-order result of the operator (if applicable)" res2::R2 - "private field (not part of the public API) containing a variant of the scenario to test preparation resizing" - smaller::S + "named tuple of arguments passed to preparation, without the function" + prep_args::P "name of the scenario for display in test sets and dataframes" name::Union{String,Nothing} + + function Scenario{op,pl_op,pl_fun}(; + f::F, + y::Y, + x::X, + t::T, + contexts::C, + res1::R1, + res2::R2, + prep_args::P, + name::Union{String,Nothing}, + ) where {op,pl_op,pl_fun,F,X,Y,T,C,R1,R2,P} + @assert op in ALL_OPS + @assert pl_op in (:in, :out) + @assert pl_fun in (:in, :out) + return new{op,pl_op,pl_fun,F,X,Y,T,C,R1,R2,P}( + f, y, x, t, contexts, res1, res2, prep_args, name + ) + end end -function Scenario{op,pl_op,pl_fun}( - f::F; - x::X, - y::Y, - tang::T, - contexts::C, - res1::R1, - res2::R2, - smaller::S=nothing, +function myzero_contexts(contexts...) + rewrap = Rewrap(contexts...) + return rewrap(map(myzero ∘ unwrap, contexts)...) +end + +function Scenario{op,pl_op}( + f, + x, + contexts::Vararg{Context}; + res1=nothing, + res2=nothing, + prep_args=(; x=myzero(x), contexts=myzero_contexts(contexts...)), name=nothing, -) where {op,pl_op,pl_fun,F,X,Y,T,C,R1,R2,S<:Union{Nothing,Scenario}} - @assert smaller isa Union{Nothing,Scenario{op,pl_op,pl_fun,F,X,Y,T,C,R1,R2}} - return Scenario{op,pl_op,pl_fun,F,X,Y,T,C,R1,R2,S}( - f, x, y, tang, contexts, res1, res2, smaller, name +) where {op,pl_op} + y = f(x, map(unwrap, contexts)...) + return Scenario{op,pl_op,:out}(; + f, y, x, t=nothing, contexts, res1, res2, prep_args, name + ) +end + +function Scenario{op,pl_op}( + f, + y, + x, + contexts::Vararg{Context}; + res1=nothing, + res2=nothing, + prep_args=(; y=myzero(y), x=myzero(x), contexts=myzero_contexts(contexts...)), + name=nothing, +) where {op,pl_op} + f(y, x, map(unwrap, contexts)...) + return Scenario{op,pl_op,:in}(; + f, y, x, t=nothing, contexts, res1, res2, prep_args, name ) end function Scenario{op,pl_op}( f, - x; - tang=nothing, - contexts=(), + x, + t::NTuple, + contexts::Vararg{Context}; res1=nothing, res2=nothing, - smaller=nothing, + prep_args=(; x=myzero(x), t=map(myzero, t), contexts=myzero_contexts(contexts...)), name=nothing, ) where {op,pl_op} - @assert op in ALL_OPS - @assert pl_op in (:in, :out) y = f(x, map(unwrap, contexts)...) - return Scenario{op,pl_op,:out}(f; x, y, tang, contexts, res1, res2, smaller, name) + return Scenario{op,pl_op,:out}(; f, y, x, t, contexts, res1, res2, prep_args, name) end function Scenario{op,pl_op}( - f!, + f, y, - x; - tang=nothing, - contexts=(), + x, + t::NTuple, + contexts::Vararg{Context}; res1=nothing, res2=nothing, - smaller=nothing, + prep_args=(; + y=myzero(y), x=myzero(x), t=map(myzero, t), contexts=myzero_contexts(contexts...) + ), name=nothing, ) where {op,pl_op} - @assert op in ALL_OPS - @assert pl_op in (:in, :out) - return Scenario{op,pl_op,:in}(f!; x, y, tang, contexts, res1, res2, smaller, name) + f(y, x, map(unwrap, contexts)...) + return Scenario{op,pl_op,:in}(; f, y, x, t, contexts, res1, res2, prep_args, name) end Base.:(==)(scen1::Scenario, scen2::Scenario) = false @@ -98,7 +135,7 @@ function Base.:(==)( eq_f = scen1.f == scen2.f eq_x = scen1.x == scen2.x eq_y = scen1.y == scen2.y - eq_tang = scen1.tang == scen2.tang + eq_t = scen1.t == scen2.t eq_contexts = all( map(scen1.contexts, scen2.contexts) do c1, c2 if c1 isa Union{Cache,ConstantOrCache} || c2 isa Union{Cache,ConstantOrCache} @@ -111,7 +148,7 @@ function Base.:(==)( eq_res1 = scen1.res1 == scen2.res1 eq_res2 = scen1.res2 == scen2.res2 eq_name = scen1.name == scen2.name - return (eq_x && eq_y && eq_tang && eq_contexts && eq_res1 && eq_res2 && eq_name) + return (eq_x && eq_y && eq_t && eq_contexts && eq_res1 && eq_res2 && eq_name) end operator(::Scenario{op}) where {op} = op @@ -152,7 +189,7 @@ function Base.show( if isnothing(scen.name) print(io, "Scenario{$(repr(op)),$(repr(pl_op))} $(string(scen.f)) : $X -> $Y") if op in (:pushforward, :pullback, :hvp) - print(io, " ($(length(scen.tang)) tangents)") + print(io, " ($(length(scen.t)) tangents)") end if length(scen.contexts) > 0 print(io, " ($(length(scen.contexts)) contexts)") diff --git a/DifferentiationInterfaceTest/src/test_differentiation.jl b/DifferentiationInterfaceTest/src/test_differentiation.jl index d36cea933..3fd4023b0 100644 --- a/DifferentiationInterfaceTest/src/test_differentiation.jl +++ b/DifferentiationInterfaceTest/src/test_differentiation.jl @@ -6,7 +6,7 @@ Apply a list of `backends` on a list of `scenarios`, running a variety of differ # Return This function always creates and runs a `@testset`, though its contents may vary. - + - if `benchmark == :none`, it returns `nothing`. - if `benchmark != :none`, it returns a `DataFrame` of benchmark results, whose columns correspond to the fields of [`DifferentiationBenchmarkDataRow`](@ref). @@ -142,10 +142,7 @@ function test_differentiation( (:input_size, mysize(scen.x)), (:output_type, typeof(scen.y)), (:output_size, mysize(scen.y)), - ( - :nb_tangents, - scen.tang isa NTuple ? length(scen.tang) : nothing, - ), + (:nb_tangents, scen.t isa NTuple ? length(scen.t) : nothing), (:nb_contexts, length(scen.contexts)), ], ) diff --git a/DifferentiationInterfaceTest/src/tests/allocs_eval.jl b/DifferentiationInterfaceTest/src/tests/allocs_eval.jl index 94b5df67f..8ba00e2a9 100644 --- a/DifferentiationInterfaceTest/src/tests/allocs_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/allocs_eval.jl @@ -27,9 +27,10 @@ for op in ALL_OPS @eval function test_alloccheck( ba::AbstractADType, scen::$S1out; subset::Symbol, skip::Bool ) - (; f, x, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, contexts...) - (subset == :full) && test_noallocs(skip, $prep_op, f, ba, x, contexts...) + (; f, x, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) + (subset == :full) && + test_noallocs(skip, $prep_op, f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && test_noallocs(skip, $op, f, ba, x, contexts...) (subset == :full) && test_noallocs(skip, $val_and_op, f, ba, x, contexts...) (subset != :none) && test_noallocs(skip, $op, f, prep, ba, x, contexts...) @@ -41,10 +42,11 @@ for op in ALL_OPS @eval function test_alloccheck( ba::AbstractADType, scen::$S1in; subset::Symbol, skip::Bool ) - (; f, x, res1, contexts) = deepcopy(scen) + (; f, x, res1, contexts, prep_args) = deepcopy(scen) res1_sim = mysimilar(res1) - prep = $prep_op(f, ba, x, contexts...) - (subset == :full) && test_noallocs(skip, $prep_op, f, ba, x, contexts...) + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) + (subset == :full) && + test_noallocs(skip, $prep_op, f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && test_noallocs(skip, $op!, f, res1_sim, prep, ba, x, contexts...) (subset == :full) && @@ -61,9 +63,11 @@ for op in ALL_OPS @eval function test_alloccheck( ba::AbstractADType, scen::$S2out; subset::Symbol, skip::Bool ) - (; f, x, y, contexts) = deepcopy(scen) - prep = $prep_op(f, y, ba, x, contexts...) - (subset == :full) && test_noallocs(skip, $prep_op, f, y, ba, x, contexts...) + (; f, x, y, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) + (subset == :full) && test_noallocs( + skip, $prep_op, f, prep_args.y, ba, prep_args.x, prep_args.contexts... + ) (subset == :full) && test_noallocs(skip, $op, f, y, ba, x, contexts...) (subset == :full) && test_noallocs(skip, $val_and_op, f, y, ba, x, contexts...) (subset != :none) && test_noallocs(skip, $op, f, y, prep, ba, x, contexts...) @@ -75,10 +79,12 @@ for op in ALL_OPS @eval function test_alloccheck( ba::AbstractADType, scen::$S2in; subset::Symbol, skip::Bool ) - (; f, x, y, res1, contexts) = deepcopy(scen) + (; f, x, y, res1, contexts, prep_args) = deepcopy(scen) res1_sim = mysimilar(res1) - prep = $prep_op(f, y, ba, x, contexts...) - (subset == :full) && test_noallocs(skip, $prep_op, f, y, ba, x, contexts...) + prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) + (subset == :full) && test_noallocs( + skip, $prep_op, f, prep_args.y, ba, prep_args.x, prep_args.contexts... + ) (subset == :full) && test_noallocs(skip, $op!, f, y, res1_sim, ba, x, contexts...) (subset == :full) && @@ -94,9 +100,10 @@ for op in ALL_OPS @eval function test_alloccheck( ba::AbstractADType, scen::$S1out; subset::Symbol, skip::Bool ) - (; f, x, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, contexts...) - (subset == :full) && test_noallocs(skip, $prep_op, f, ba, x, contexts...) + (; f, x, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) + (subset == :full) && + test_noallocs(skip, $prep_op, f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && test_noallocs(skip, $op, f, ba, x, contexts...) (subset == :full) && test_noallocs(skip, $val_and_op, f, ba, x, contexts...) (subset != :none) && test_noallocs(skip, $op, f, prep, ba, x, contexts...) @@ -108,10 +115,11 @@ for op in ALL_OPS @eval function test_alloccheck( ba::AbstractADType, scen::$S1in; subset::Symbol, skip::Bool ) - (; f, x, res1, res2, contexts) = deepcopy(scen) + (; f, x, res1, res2, contexts, prep_args) = deepcopy(scen) res1_sim, res2_sim = mysimilar(res1), mysimilar(res2) - prep = $prep_op(f, ba, x, contexts...) - (subset == :full) && test_noallocs(skip, $prep_op, f, ba, x, contexts...) + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) + (subset == :full) && + test_noallocs(skip, $prep_op, f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && test_noallocs(skip, $op!, f, res2_sim, ba, x, contexts...) (subset == :full) && test_noallocs(skip, $val_and_op!, f, res1_sim, res2_sim, ba, x, contexts...) @@ -127,70 +135,91 @@ for op in ALL_OPS @eval function test_alloccheck( ba::AbstractADType, scen::$S1out; subset::Symbol, skip::Bool ) - (; f, x, tang, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, tang, contexts...) - (subset == :full) && test_noallocs(skip, $prep_op, f, ba, x, tang, contexts...) - (subset == :full) && test_noallocs(skip, $op, f, ba, x, tang, contexts...) - (subset == :full) && - test_noallocs(skip, $val_and_op, f, ba, x, tang, contexts...) - (subset != :none) && test_noallocs(skip, $op, f, prep, ba, x, tang, contexts...) + (; f, x, t, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) + (subset == :full) && test_noallocs( + skip, $prep_op, f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) + (subset == :full) && test_noallocs(skip, $op, f, ba, x, t, contexts...) + (subset == :full) && test_noallocs(skip, $val_and_op, f, ba, x, t, contexts...) + (subset != :none) && test_noallocs(skip, $op, f, prep, ba, x, t, contexts...) (subset != :none) && - test_noallocs(skip, $val_and_op, f, prep, ba, x, tang, contexts...) + test_noallocs(skip, $val_and_op, f, prep, ba, x, t, contexts...) return nothing end @eval function test_alloccheck( ba::AbstractADType, scen::$S1in; subset::Symbol, skip::Bool ) - (; f, x, tang, res1, contexts) = deepcopy(scen) + (; f, x, t, res1, contexts, prep_args) = deepcopy(scen) res1_sim = mysimilar(res1) - prep = $prep_op(f, ba, x, tang, contexts...) - (subset == :full) && test_noallocs(skip, $prep_op, f, ba, x, tang, contexts...) + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) + (subset == :full) && test_noallocs( + skip, $prep_op, f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && - test_noallocs(skip, $op!, f, res1_sim, ba, x, tang, contexts...) + test_noallocs(skip, $op!, f, res1_sim, ba, x, t, contexts...) (subset == :full) && - test_noallocs(skip, $val_and_op!, f, res1_sim, ba, x, tang, contexts...) + test_noallocs(skip, $val_and_op!, f, res1_sim, ba, x, t, contexts...) (subset != :none) && - test_noallocs(skip, $op!, f, res1_sim, prep, ba, x, tang, contexts...) - (subset != :none) && test_noallocs( - skip, $val_and_op!, f, res1_sim, prep, ba, x, tang, contexts... - ) + test_noallocs(skip, $op!, f, res1_sim, prep, ba, x, t, contexts...) + (subset != :none) && + test_noallocs(skip, $val_and_op!, f, res1_sim, prep, ba, x, t, contexts...) return nothing end @eval function test_alloccheck( ba::AbstractADType, scen::$S2out; subset::Symbol, skip::Bool ) - (; f, x, y, tang, contexts) = deepcopy(scen) - prep = $prep_op(f, y, ba, x, tang, contexts...) - (subset == :full) && - test_noallocs(skip, $prep_op, f, y, ba, x, tang, contexts...) - (subset == :full) && test_noallocs(skip, $op, f, y, ba, x, tang, contexts...) + (; f, x, y, t, contexts, prep_args) = deepcopy(scen) + prep = $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) + (subset == :full) && test_noallocs( + skip, + $prep_op, + f, + pep_args.y, + ba, + pep_args.x, + pep_args.t, + pep_args.contexts..., + ) + (subset == :full) && test_noallocs(skip, $op, f, y, ba, x, t, contexts...) (subset == :full) && - test_noallocs(skip, $val_and_op, f, y, ba, x, tang, contexts...) + test_noallocs(skip, $val_and_op, f, y, ba, x, t, contexts...) + (subset != :none) && test_noallocs(skip, $op, f, y, prep, ba, x, t, contexts...) (subset != :none) && - test_noallocs(skip, $op, f, y, prep, ba, x, tang, contexts...) - (subset != :none) && - test_noallocs(skip, $val_and_op, f, y, prep, ba, x, tang, contexts...) + test_noallocs(skip, $val_and_op, f, y, prep, ba, x, t, contexts...) return nothing end @eval function test_alloccheck( ba::AbstractADType, scen::$S2in; subset::Symbol, skip::Bool ) - (; f, x, y, tang, res1, contexts) = deepcopy(scen) + (; f, x, y, t, res1, contexts, prep_args) = deepcopy(scen) res1_sim = mysimilar(res1) - prep = $prep_op(f, y, ba, x, tang, contexts...) - (subset == :full) && - test_noallocs(skip, $prep_op, f, y, ba, x, tang, contexts...) + prep = $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) + (subset == :full) && test_noallocs( + skip, + $prep_op, + f, + prep_args.y, + ba, + prep_args.x, + prep_args.t, + prep_args.contexts..., + ) (subset == :full) && - test_noallocs(skip, $op!, f, y, res1_sim, ba, x, tang, contexts...) + test_noallocs(skip, $op!, f, y, res1_sim, ba, x, t, contexts...) (subset == :full) && - test_noallocs(skip, $val_and_op!, f, y, res1_sim, ba, x, tang, contexts...) + test_noallocs(skip, $val_and_op!, f, y, res1_sim, ba, x, t, contexts...) (subset != :none) && - test_noallocs(skip, $op!, f, y, res1_sim, prep, ba, x, tang, contexts...) + test_noallocs(skip, $op!, f, y, res1_sim, prep, ba, x, t, contexts...) (subset != :none) && test_noallocs( - skip, $val_and_op!, f, y, res1_sim, prep, ba, x, tang, contexts... + skip, $val_and_op!, f, y, res1_sim, prep, ba, x, t, contexts... ) return nothing end @@ -199,43 +228,37 @@ for op in ALL_OPS @eval function test_alloccheck( ba::AbstractADType, scen::$S1out; subset::Symbol, skip::Bool ) - (; f, x, tang, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, tang, contexts...) - (subset == :full) && test_noallocs(skip, $prep_op, f, ba, x, tang, contexts...) - (subset == :full) && - test_noallocs(skip, $val_and_op, f, ba, x, tang, contexts...) - (subset == :full) && test_noallocs(skip, $op, f, ba, x, tang, contexts...) + (; f, x, t, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) + (subset == :full) && test_noallocs( + skip, $prep_op, f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) + (subset == :full) && test_noallocs(skip, $val_and_op, f, ba, x, t, contexts...) + (subset == :full) && test_noallocs(skip, $op, f, ba, x, t, contexts...) (subset != :none) && - test_noallocs(skip, $val_and_op, f, prep, ba, x, tang, contexts...) - (subset != :none) && test_noallocs(skip, $op, f, prep, ba, x, tang, contexts...) + test_noallocs(skip, $val_and_op, f, prep, ba, x, t, contexts...) + (subset != :none) && test_noallocs(skip, $op, f, prep, ba, x, t, contexts...) return nothing end @eval function test_alloccheck( ba::AbstractADType, scen::$S1in; subset::Symbol, skip::Bool ) - (; f, x, tang, res1, res2, contexts) = deepcopy(scen) + (; f, x, t, res1, res2, contexts, prep_args) = deepcopy(scen) res1_sim, res2_sim = mysimilar(res1), mysimilar(res2) - prep = $prep_op(f, ba, x, tang, contexts...) - (subset == :full) && test_noallocs(skip, $prep_op, f, ba, x, tang, contexts...) + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) + (subset == :full) && test_noallocs( + skip, $prep_op, f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && - test_noallocs(skip, $op!, f, res2_sim, ba, x, tang, contexts...) + test_noallocs(skip, $op!, f, res2_sim, ba, x, t, contexts...) (subset == :full) && test_noallocs( - skip, $val_and_op!, f, res1_sim, res2_sim, ba, x, tang, contexts... + skip, $val_and_op!, f, res1_sim, res2_sim, ba, x, t, contexts... ) (subset != :none) && - test_noallocs(skip, $op!, f, res2_sim, prep, ba, x, tang, contexts...) + test_noallocs(skip, $op!, f, res2_sim, prep, ba, x, t, contexts...) (subset != :none) && test_noallocs( - skip, - $val_and_op!, - f, - res1_sim, - res2_sim, - prep, - ba, - x, - tang, - contexts..., + skip, $val_and_op!, f, res1_sim, res2_sim, prep, ba, x, t, contexts... ) return nothing end diff --git a/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl b/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl index 9a7c7e8ee..64464ba2c 100644 --- a/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl @@ -137,12 +137,13 @@ for op in ALL_OPS if op in [:derivative, :gradient, :jacobian] @eval function benchmark_aux(ba::AbstractADType, scen::$S1out; subset::Symbol, s) - (; f, x, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, contexts...) + (; f, x, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) prepared_valop = @be prep $val_and_op(f, _, ba, x, contexts...) seconds = s prepared_op = @be prep $op(f, _, ba, x, contexts...) seconds = s if subset == :full - preparation = @be $prep_op(f, ba, x, contexts...) seconds = s + preparation = @be $prep_op(f, ba, prep_args.x, prep_args.contexts...) seconds = + s unprepared_valop = @be $val_and_op(f, ba, x, contexts...) seconds = s unprepared_op = @be $op(f, ba, x, contexts...) seconds = s return BenchmarkResult(; @@ -158,9 +159,9 @@ for op in ALL_OPS end @eval function calls_aux(ba::AbstractADType, scen::$S1out; subset::Symbol, s) - (; f, x, contexts) = deepcopy(scen) + (; f, x, contexts, prep_args) = deepcopy(scen) cc = CallCounter(f) - prep = $prep_op(cc, ba, x, contexts...) + prep = $prep_op(cc, ba, prep_args.x, prep_args.contexts...) preparation = reset_count!(cc) $val_and_op(cc, prep, ba, x, contexts...) prepared_valop = reset_count!(cc) @@ -176,8 +177,8 @@ for op in ALL_OPS end @eval function benchmark_aux(ba::AbstractADType, scen::$S1in; subset::Symbol, s) - (; f, x, res1, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, contexts...) + (; f, x, res1, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) prepared_valop = @be (mysimilar(res1), prep) $val_and_op!( f, _[1], _[2], ba, x, contexts... ) seconds = s @@ -185,7 +186,8 @@ for op in ALL_OPS f, _[1], _[2], ba, x, contexts... ) seconds = s if subset == :full - preparation = @be $prep_op(f, ba, x, contexts...) seconds = s + preparation = @be $prep_op(f, ba, prep_args.x, prep_args.contexts...) seconds = + s unprepared_valop = @be mysimilar(res1) $val_and_op!( f, _, ba, x, contexts... ) seconds = s @@ -204,9 +206,9 @@ for op in ALL_OPS end @eval function calls_aux(ba::AbstractADType, scen::$S1in; subset::Symbol, s) - (; f, x, res1, contexts) = deepcopy(scen) + (; f, x, res1, contexts, prep_args) = deepcopy(scen) cc = CallCounter(f) - prep = $prep_op(cc, ba, x, contexts...) + prep = $prep_op(cc, ba, prep_args.x, prep_args.contexts...) preparation = reset_count!(cc) $val_and_op!(cc, mysimilar(res1), prep, ba, x, contexts...) prepared_valop = reset_count!(cc) @@ -224,13 +226,15 @@ for op in ALL_OPS op == :gradient && continue @eval function benchmark_aux(ba::AbstractADType, scen::$S2out; subset::Symbol, s) - (; f, x, y, contexts) = deepcopy(scen) - prep = $prep_op(f, y, ba, x, contexts...) + (; f, x, y, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) prepared_valop = @be (y, prep) $val_and_op(f, _[1], _[2], ba, x, contexts...) seconds = s prepared_op = @be (y, prep) $op(f, _[1], _[2], ba, x, contexts...) seconds = s if subset == :full - preparation = @be $prep_op(f, y, ba, x, contexts...) seconds = s + preparation = @be $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.contexts... + ) seconds = s unprepared_valop = @be y $val_and_op(f, _, ba, x, contexts...) seconds = s unprepared_op = @be y $op(f, _, ba, x, contexts...) seconds = s return BenchmarkResult(; @@ -246,9 +250,9 @@ for op in ALL_OPS end @eval function calls_aux(ba::AbstractADType, scen::$S2out; subset::Symbol, s) - (; f, x, y, contexts) = deepcopy(scen) + (; f, x, y, contexts, prep_args) = deepcopy(scen) cc = CallCounter(f) - prep = $prep_op(cc, y, ba, x, contexts...) + prep = $prep_op(cc, prep_args.y, ba, prep_args.x, prep_args.contexts...) preparation = reset_count!(cc) $val_and_op(cc, y, prep, ba, x, contexts...) prepared_valop = reset_count!(cc) @@ -264,8 +268,8 @@ for op in ALL_OPS end @eval function benchmark_aux(ba::AbstractADType, scen::$S2in; subset::Symbol, s) - (; f, x, y, res1, contexts) = deepcopy(scen) - prep = $prep_op(f, y, ba, x, contexts...) + (; f, x, y, res1, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) prepared_valop = @be (y, mysimilar(res1), prep) $val_and_op!( f, _[1], _[2], _[3], ba, x, contexts... ) seconds = s @@ -273,7 +277,9 @@ for op in ALL_OPS f, _[1], _[2], _[3], ba, x, contexts... ) seconds = s if subset == :full - preparation = @be $prep_op(f, y, ba, x, contexts...) seconds = s + preparation = @be $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.contexts... + ) seconds = s unprepared_valop = @be (y, mysimilar(res1)) $val_and_op!( f, _[1], _[2], ba, x, contexts... ) seconds = s @@ -293,9 +299,9 @@ for op in ALL_OPS end @eval function calls_aux(ba::AbstractADType, scen::$S2in; subset::Symbol, s) - (; f, x, y, res1, contexts) = deepcopy(scen) + (; f, x, y, res1, contexts, prep_args) = deepcopy(scen) cc = CallCounter(f) - prep = $prep_op(cc, y, ba, x, contexts...) + prep = $prep_op(cc, prep_args.y, ba, prep_args.x, prep_args.contexts...) preparation = reset_count!(cc) $val_and_op!(cc, y, mysimilar(res1), prep, ba, x, contexts...) prepared_valop = reset_count!(cc) @@ -312,12 +318,13 @@ for op in ALL_OPS elseif op in [:hessian, :second_derivative] @eval function benchmark_aux(ba::AbstractADType, scen::$S1out; subset::Symbol, s) - (; f, x, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, contexts...) + (; f, x, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) prepared_valop = @be prep $val_and_op(f, _, ba, x, contexts...) seconds = s prepared_op = @be prep $op(f, _, ba, x, contexts...) seconds = s if subset == :full - preparation = @be $prep_op(f, ba, x, contexts...) seconds = s + preparation = @be $prep_op(f, ba, prep_args.x, prep_args.contexts...) seconds = + s unprepared_valop = @be $val_and_op(f, ba, x, contexts...) seconds = s unprepared_op = @be $op(f, ba, x, contexts...) seconds = s return BenchmarkResult(; @@ -333,9 +340,9 @@ for op in ALL_OPS end @eval function calls_aux(ba::AbstractADType, scen::$S1out; subset::Symbol, s) - (; f, x, contexts) = deepcopy(scen) + (; f, x, contexts, prep_args) = deepcopy(scen) cc = CallCounter(f) - prep = $prep_op(cc, ba, x, contexts...) + prep = $prep_op(cc, ba, prep_args.x, prep_args.contexts...) preparation = reset_count!(cc) $val_and_op(cc, prep, ba, x, contexts...) prepared_valop = reset_count!(cc) @@ -351,9 +358,8 @@ for op in ALL_OPS end @eval function benchmark_aux(ba::AbstractADType, scen::$S1in; subset::Symbol, s) - (; f, x, res1, res2, contexts) = deepcopy(scen) - - prep = $prep_op(f, ba, x, contexts...) + (; f, x, res1, res2, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) prepared_valop = @be (mysimilar(res1), mysimilar(res2), prep) $val_and_op!( f, _[1], _[2], _[3], ba, x, contexts... ) seconds = s @@ -361,7 +367,8 @@ for op in ALL_OPS f, _[1], _[2], ba, x, contexts... ) seconds = s if subset == :full - preparation = @be $prep_op(f, ba, x, contexts...) seconds = s + preparation = @be $prep_op(f, ba, prep_args.x, prep_args.contexts...) seconds = + s unprepared_valop = @be (mysimilar(res1), mysimilar(res2)) $val_and_op!( f, _[1], _[2], ba, x, contexts... ) seconds = s @@ -380,9 +387,9 @@ for op in ALL_OPS end @eval function calls_aux(ba::AbstractADType, scen::$S1in; subset::Symbol, s) - (; f, x, res1, res2, contexts) = deepcopy(scen) + (; f, x, res1, res2, contexts, prep_args) = deepcopy(scen) cc = CallCounter(f) - prep = $prep_op(cc, ba, x, contexts...) + prep = $prep_op(cc, ba, prep_args.x, prep_args.contexts...) preparation = reset_count!(cc) $val_and_op!(cc, mysimilar(res1), mysimilar(res2), prep, ba, x, contexts...) prepared_valop = reset_count!(cc) @@ -399,15 +406,16 @@ for op in ALL_OPS elseif op in [:pushforward, :pullback] @eval function benchmark_aux(ba::AbstractADType, scen::$S1out; subset::Symbol, s) - (; f, x, tang, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, tang, contexts...) - prepared_valop = @be prep $val_and_op(f, _, ba, x, tang, contexts...) seconds = - s - prepared_op = @be prep $op(f, _, ba, x, tang, contexts...) seconds = s + (; f, x, t, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) + prepared_valop = @be prep $val_and_op(f, _, ba, x, t, contexts...) seconds = s + prepared_op = @be prep $op(f, _, ba, x, t, contexts...) seconds = s if subset == :full - preparation = @be $prep_op(f, ba, x, tang, contexts...) seconds = s - unprepared_valop = @be $val_and_op(f, ba, x, tang, contexts...) seconds = s - unprepared_op = @be $op(f, ba, x, tang, contexts...) seconds = s + preparation = @be $prep_op( + f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) seconds = s + unprepared_valop = @be $val_and_op(f, ba, x, t, contexts...) seconds = s + unprepared_op = @be $op(f, ba, x, t, contexts...) seconds = s return BenchmarkResult(; prepared_valop, prepared_op, @@ -420,17 +428,17 @@ for op in ALL_OPS end end @eval function calls_aux(ba::AbstractADType, scen::$S1out; subset::Symbol, s) - (; f, x, tang, contexts) = deepcopy(scen) + (; f, x, t, contexts, prep_args) = deepcopy(scen) cc = CallCounter(f) - prep = $prep_op(cc, ba, x, tang, contexts...) + prep = $prep_op(cc, ba, prep_args.x, prep_args.t, prep_args.contexts...) preparation = reset_count!(cc) - $val_and_op(cc, prep, ba, x, tang, contexts...) + $val_and_op(cc, prep, ba, x, t, contexts...) prepared_valop = reset_count!(cc) - $op(cc, prep, ba, x, tang, contexts...) + $op(cc, prep, ba, x, t, contexts...) prepared_op = reset_count!(cc) - $val_and_op(cc, ba, x, tang, contexts...) + $val_and_op(cc, ba, x, t, contexts...) unprepared_valop = reset_count!(cc) - $op(cc, ba, x, tang, contexts...) + $op(cc, ba, x, t, contexts...) unprepared_op = reset_count!(cc) return CallsResult(; prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op @@ -438,20 +446,22 @@ for op in ALL_OPS end @eval function benchmark_aux(ba::AbstractADType, scen::$S1in; subset::Symbol, s) - (; f, x, tang, res1, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, tang, contexts...) + (; f, x, t, res1, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) prepared_valop = @be (mysimilar(res1), prep) $val_and_op!( - f, _[1], _[2], ba, x, tang, contexts... + f, _[1], _[2], ba, x, t, contexts... ) seconds = s prepared_op = @be (mysimilar(res1), prep) $op!( - f, _[1], _[2], ba, x, tang, contexts... + f, _[1], _[2], ba, x, t, contexts... ) seconds = s if subset == :full - preparation = @be $prep_op(f, ba, x, tang, contexts...) seconds = s + preparation = @be $prep_op( + f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) seconds = s unprepared_valop = @be mysimilar(res1) $val_and_op!( - f, _, ba, x, tang, contexts... + f, _, ba, x, t, contexts... ) seconds = s - unprepared_op = @be mysimilar(res1) $op!(f, _, ba, x, tang, contexts...) seconds = + unprepared_op = @be mysimilar(res1) $op!(f, _, ba, x, t, contexts...) seconds = s return BenchmarkResult(; prepared_valop, @@ -466,17 +476,17 @@ for op in ALL_OPS end @eval function calls_aux(ba::AbstractADType, scen::$S1in; subset::Symbol, s) - (; f, x, tang, res1, contexts) = deepcopy(scen) + (; f, x, t, res1, contexts, prep_args) = deepcopy(scen) cc = CallCounter(f) - prep = $prep_op(cc, ba, x, tang, contexts...) + prep = $prep_op(cc, ba, prep_args.x, prep_args.t, prep_args.contexts...) preparation = reset_count!(cc) - $val_and_op!(cc, mysimilar(res1), prep, ba, x, tang, contexts...) + $val_and_op!(cc, mysimilar(res1), prep, ba, x, t, contexts...) prepared_valop = reset_count!(cc) - $op!(cc, mysimilar(res1), prep, ba, x, tang, contexts...) + $op!(cc, mysimilar(res1), prep, ba, x, t, contexts...) prepared_op = reset_count!(cc) - $val_and_op!(cc, mysimilar(res1), ba, x, tang, contexts...) + $val_and_op!(cc, mysimilar(res1), ba, x, t, contexts...) unprepared_valop = reset_count!(cc) - $op!(cc, mysimilar(res1), ba, x, tang, contexts...) + $op!(cc, mysimilar(res1), ba, x, t, contexts...) unprepared_op = reset_count!(cc) return CallsResult(; prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op @@ -484,17 +494,19 @@ for op in ALL_OPS end @eval function benchmark_aux(ba::AbstractADType, scen::$S2out; subset::Symbol, s) - (; f, x, y, tang, contexts) = deepcopy(scen) - prep = $prep_op(f, y, ba, x, tang, contexts...) - prepared_valop = @be (y, prep) $val_and_op( - f, _[1], _[2], ba, x, tang, contexts... + (; f, x, y, t, contexts, prep_args) = deepcopy(scen) + prep = $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... ) - prepared_op = @be (y, prep) $op(f, _[1], _[2], ba, x, tang, contexts...) + prepared_valop = @be (y, prep) $val_and_op(f, _[1], _[2], ba, x, t, contexts...) + prepared_op = @be (y, prep) $op(f, _[1], _[2], ba, x, t, contexts...) if subset == :full - preparation = @be $prep_op(f, y, ba, x, tang, contexts...) seconds = s - unprepared_valop = @be y $val_and_op(f, _, ba, x, tang, contexts...) seconds = + preparation = @be $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) seconds = s + unprepared_valop = @be y $val_and_op(f, _, ba, x, t, contexts...) seconds = s - unprepared_op = @be y $op(f, _, ba, x, tang, contexts...) seconds = s + unprepared_op = @be y $op(f, _, ba, x, t, contexts...) seconds = s return BenchmarkResult(; prepared_valop, prepared_op, @@ -508,17 +520,19 @@ for op in ALL_OPS end @eval function calls_aux(ba::AbstractADType, scen::$S2out; subset::Symbol, s) - (; f, x, y, tang, contexts) = deepcopy(scen) + (; f, x, y, t, contexts, prep_args) = deepcopy(scen) cc = CallCounter(f) - prep = $prep_op(cc, y, ba, x, tang, contexts...) + prep = $prep_op( + cc, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) preparation = reset_count!(cc) - $val_and_op(cc, y, prep, ba, x, tang, contexts...) + $val_and_op(cc, y, prep, ba, x, t, contexts...) prepared_valop = reset_count!(cc) - $op(cc, y, prep, ba, x, tang, contexts...) + $op(cc, y, prep, ba, x, t, contexts...) prepared_op = reset_count!(cc) - $val_and_op(cc, y, ba, x, tang, contexts...) + $val_and_op(cc, y, ba, x, t, contexts...) unprepared_valop = reset_count!(cc) - $op(cc, y, ba, x, tang, contexts...) + $op(cc, y, ba, x, t, contexts...) unprepared_op = reset_count!(cc) return CallsResult(; prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op @@ -526,21 +540,25 @@ for op in ALL_OPS end @eval function benchmark_aux(ba::AbstractADType, scen::$S2in; subset::Symbol, s) - (; f, x, y, tang, res1, contexts) = deepcopy(scen) - prep = $prep_op(f, y, ba, x, tang, contexts...) + (; f, x, y, t, res1, contexts, prep_args) = deepcopy(scen) + prep = $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) prepared_valop = @be (y, mysimilar(res1), prep) $val_and_op!( - f, _[1], _[2], _[3], ba, x, tang, contexts... + f, _[1], _[2], _[3], ba, x, t, contexts... ) seconds = s prepared_op = @be (y, mysimilar(res1), prep) $op!( - f, _[1], _[2], _[3], ba, x, tang, contexts... + f, _[1], _[2], _[3], ba, x, t, contexts... ) seconds = s if subset == :full - preparation = @be $prep_op(f, y, ba, x, tang, contexts...) seconds = s + preparation = @be $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) seconds = s unprepared_valop = @be (y, mysimilar(res1)) $val_and_op!( - f, _[1], _[2], ba, x, tang, contexts... + f, _[1], _[2], ba, x, t, contexts... ) seconds = s unprepared_op = @be (y, mysimilar(res1)) $op!( - f, _[1], _[2], ba, x, tang, contexts... + f, _[1], _[2], ba, x, t, contexts... ) seconds = s return BenchmarkResult(; prepared_valop, @@ -555,17 +573,19 @@ for op in ALL_OPS end @eval function calls_aux(ba::AbstractADType, scen::$S2in; subset::Symbol, s) - (; f, x, y, tang, res1, contexts) = deepcopy(scen) + (; f, x, y, t, res1, contexts, prep_args) = deepcopy(scen) cc = CallCounter(f) - prep = $prep_op(cc, y, ba, x, tang, contexts...) + prep = $prep_op( + cc, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) preparation = reset_count!(cc) - $val_and_op!(cc, y, mysimilar(res1), prep, ba, x, tang, contexts...) + $val_and_op!(cc, y, mysimilar(res1), prep, ba, x, t, contexts...) prepared_valop = reset_count!(cc) - $op!(cc, y, mysimilar(res1), prep, ba, x, tang, contexts...) + $op!(cc, y, mysimilar(res1), prep, ba, x, t, contexts...) prepared_op = reset_count!(cc) - $val_and_op!(cc, y, mysimilar(res1), ba, x, tang, contexts...) + $val_and_op!(cc, y, mysimilar(res1), ba, x, t, contexts...) unprepared_valop = reset_count!(cc) - $op!(cc, y, mysimilar(res1), ba, x, tang, contexts...) + $op!(cc, y, mysimilar(res1), ba, x, t, contexts...) unprepared_op = reset_count!(cc) return CallsResult(; prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op @@ -574,15 +594,16 @@ for op in ALL_OPS elseif op in [:hvp] @eval function benchmark_aux(ba::AbstractADType, scen::$S1out; subset::Symbol, s) - (; f, x, tang, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, tang, contexts...) - prepared_valop = @be prep $val_and_op(f, _, ba, x, tang, contexts...) seconds = - s - prepared_op = @be prep $op(f, _, ba, x, tang, contexts...) seconds = s + (; f, x, t, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) + prepared_valop = @be prep $val_and_op(f, _, ba, x, t, contexts...) seconds = s + prepared_op = @be prep $op(f, _, ba, x, t, contexts...) seconds = s if subset == :full - preparation = @be $prep_op(f, ba, x, tang, contexts...) seconds = s - unprepared_valop = @be $val_and_op(f, ba, x, tang, contexts...) seconds = s - unprepared_op = @be $op(f, ba, x, tang, contexts...) seconds = s + preparation = @be $prep_op( + f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) seconds = s + unprepared_valop = @be $val_and_op(f, ba, x, t, contexts...) seconds = s + unprepared_op = @be $op(f, ba, x, t, contexts...) seconds = s return BenchmarkResult(; prepared_valop, prepared_op, @@ -596,17 +617,17 @@ for op in ALL_OPS end @eval function calls_aux(ba::AbstractADType, scen::$S1out; subset::Symbol, s) - (; f, x, tang, contexts) = deepcopy(scen) + (; f, x, t, contexts, prep_args) = deepcopy(scen) cc = CallCounter(f) - prep = $prep_op(cc, ba, x, tang, contexts...) + prep = $prep_op(cc, ba, prep_args.x, prep_args.t, prep_args.contexts...) preparation = reset_count!(cc) - $val_and_op(cc, prep, ba, x, tang, contexts...) + $val_and_op(cc, prep, ba, x, t, contexts...) prepared_valop = reset_count!(cc) - $op(cc, prep, ba, x, tang, contexts...) + $op(cc, prep, ba, x, t, contexts...) prepared_op = reset_count!(cc) - $val_and_op(cc, ba, x, tang, contexts...) + $val_and_op(cc, ba, x, t, contexts...) unprepared_valop = reset_count!(cc) - $op(cc, ba, x, tang, contexts...) + $op(cc, ba, x, t, contexts...) unprepared_op = reset_count!(cc) return CallsResult(; prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op @@ -614,20 +635,22 @@ for op in ALL_OPS end @eval function benchmark_aux(ba::AbstractADType, scen::$S1in; subset::Symbol, s) - (; f, x, tang, res1, res2, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, tang, contexts...) + (; f, x, t, res1, res2, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) prepared_valop = @be (mysimilar(res1), mysimilar(res2), prep) $val_and_op!( - f, _[1], _[2], _[3], ba, x, tang, contexts... + f, _[1], _[2], _[3], ba, x, t, contexts... ) seconds = s prepared_op = @be (mysimilar(res2), prep) $op!( - f, _[1], _[2], ba, x, tang, contexts... + f, _[1], _[2], ba, x, t, contexts... ) seconds = s if subset == :full - preparation = @be $prep_op(f, ba, x, tang, contexts...) seconds = s + preparation = @be $prep_op( + f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) seconds = s unprepared_valop = @be (mysimilar(res1), mysimilar(res2)) $val_and_op!( - f, _[1], _[2], ba, x, tang, contexts... + f, _[1], _[2], ba, x, t, contexts... ) seconds = s - unprepared_op = @be mysimilar(res2) $op!(f, _, ba, x, tang, contexts...) seconds = + unprepared_op = @be mysimilar(res2) $op!(f, _, ba, x, t, contexts...) seconds = s return BenchmarkResult(; prepared_valop, @@ -642,19 +665,17 @@ for op in ALL_OPS end @eval function calls_aux(ba::AbstractADType, scen::$S1in; subset::Symbol, s) - (; f, x, tang, res1, res2, contexts) = deepcopy(scen) + (; f, x, t, res1, res2, contexts, prep_args) = deepcopy(scen) cc = CallCounter(f) - prep = $prep_op(cc, ba, x, tang, contexts...) + prep = $prep_op(cc, ba, prep_args.x, prep_args.t, prep_args.contexts...) preparation = reset_count!(cc) - $val_and_op!( - cc, mysimilar(res1), mysimilar(res2), prep, ba, x, tang, contexts... - ) + $val_and_op!(cc, mysimilar(res1), mysimilar(res2), prep, ba, x, t, contexts...) prepared_valop = reset_count!(cc) - $op!(cc, mysimilar(res2), prep, ba, x, tang, contexts...) + $op!(cc, mysimilar(res2), prep, ba, x, t, contexts...) prepared_op = reset_count!(cc) - $val_and_op!(cc, mysimilar(res1), mysimilar(res2), ba, x, tang, contexts...) + $val_and_op!(cc, mysimilar(res1), mysimilar(res2), ba, x, t, contexts...) unprepared_valop = reset_count!(cc) - $op!(cc, mysimilar(res2), ba, x, tang, contexts...) + $op!(cc, mysimilar(res2), ba, x, t, contexts...) unprepared_op = reset_count!(cc) return CallsResult(; prepared_valop, prepared_op, preparation, unprepared_valop, unprepared_op diff --git a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl index 8b183059e..e7e0d49a2 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl @@ -50,30 +50,12 @@ for op in ALL_OPS scenario_intact::Bool, sparsity::Bool, ) - (; f, x, y, res1, contexts, smaller) = new_scen = deepcopy(scen) - xrand = myrandom(x) - rewrap = Rewrap(contexts...) - contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ - new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba - deepcopy(scen) - else - deepcopy(smaller) - end - prep = $prep_op(f, ba, xrand, contextsrand...) - prepstrict = $prep_op!( - f, - $prep_op( - new_smaller.f, - ba, - new_smaller.x, - new_smaller.contexts...; - strict=Val(true), - ), - ba, - xrand, - contextsrand..., + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) + prepstrict = $prep_op( + f, ba, prep_args.x, prep_args.contexts...; strict=Val(true) ) [(), (prep,), (prepstrict,)] end @@ -117,30 +99,12 @@ for op in ALL_OPS scenario_intact::Bool, sparsity::Bool, ) - (; f, x, y, res1, contexts, smaller) = new_scen = deepcopy(scen) - xrand = myrandom(x) - rewrap = Rewrap(contexts...) - contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ - new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba - deepcopy(scen) - else - deepcopy(smaller) - end - prep = $prep_op(f, ba, xrand, contextsrand...) - prepstrict = $prep_op!( - f, - $prep_op( - new_smaller.f, - ba, - new_smaller.x, - new_smaller.contexts...; - strict=Val(true), - ), - ba, - xrand, - contextsrand..., + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) + prepstrict = $prep_op( + f, ba, prep_args.x, prep_args.contexts...; strict=Val(true) ) [(), (prep,), (prepstrict,)] end @@ -200,32 +164,17 @@ for op in ALL_OPS scenario_intact::Bool, sparsity::Bool, ) - (; f, x, y, res1, contexts, smaller) = new_scen = deepcopy(scen) - xrand, yrand = myrandom(x), myrandom(y) - rewrap = Rewrap(contexts...) - contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ - new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba - deepcopy(scen) - else - deepcopy(smaller) - end - prep = $prep_op(f, copy(yrand), ba, xrand, contextsrand...) - prepstrict = $prep_op!( + prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) + prepstrict = $prep_op( f, - copy(yrand), - $prep_op( - new_smaller.f, - copy(new_smaller.y), - ba, - new_smaller.x, - new_smaller.contexts...; - strict=Val(true), - ), + prep_args.y, ba, - xrand, - contextsrand..., + prep_args.x, + prep_args.contexts...; + strict=Val(true), ) [(), (prep,), (prepstrict,)] end @@ -277,32 +226,17 @@ for op in ALL_OPS scenario_intact::Bool, sparsity::Bool, ) - (; f, x, y, res1, contexts, smaller) = new_scen = deepcopy(scen) - xrand, yrand = myrandom(x), myrandom(y) - rewrap = Rewrap(contexts...) - contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ - new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba - deepcopy(scen) - else - deepcopy(smaller) - end - prep = $prep_op(f, copy(yrand), ba, xrand, contextsrand...) - prepstrict = $prep_op!( + prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) + prepstrict = $prep_op( f, - copy(yrand), - $prep_op( - new_smaller.f, - copy(new_smaller.y), - ba, - new_smaller.x, - new_smaller.contexts...; - strict=Val(true), - ), + prep_args.y, ba, - xrand, - contextsrand..., + prep_args.x, + prep_args.contexts...; + strict=Val(true), ) [(), (prep,), (prepstrict,)] end @@ -365,30 +299,12 @@ for op in ALL_OPS scenario_intact::Bool, sparsity::Bool, ) - (; f, x, y, res1, res2, contexts, smaller) = new_scen = deepcopy(scen) - xrand = myrandom(x) - rewrap = Rewrap(contexts...) - contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + (; f, x, y, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ - new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba - deepcopy(scen) - else - deepcopy(smaller) - end - prep = $prep_op(f, ba, xrand, contextsrand...) - prepstrict = $prep_op!( - f, - $prep_op( - new_smaller.f, - ba, - new_smaller.x, - new_smaller.contexts...; - strict=Val(true), - ), - ba, - xrand, - contextsrand..., + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) + prepstrict = $prep_op( + f, ba, prep_args.x, prep_args.contexts...; strict=Val(true) ) [(), (prep,), (prepstrict,)] end @@ -434,30 +350,12 @@ for op in ALL_OPS scenario_intact::Bool, sparsity::Bool, ) - (; f, x, y, res1, res2, contexts, smaller) = new_scen = deepcopy(scen) - xrand = myrandom(x) - rewrap = Rewrap(contexts...) - contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + (; f, x, y, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ - new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba - deepcopy(scen) - else - deepcopy(smaller) - end - prep = $prep_op(f, ba, xrand, contextsrand...) - prepstrict = $prep_op!( - f, - $prep_op( - new_smaller.f, - ba, - new_smaller.x, - new_smaller.contexts...; - strict=Val(true), - ), - ba, - xrand, - contextsrand..., + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) + prepstrict = $prep_op( + f, ba, prep_args.x, prep_args.contexts...; strict=Val(true) ) [(), (prep,), (prepstrict,)] end @@ -520,45 +418,30 @@ for op in ALL_OPS scenario_intact::Bool, sparsity::Bool, ) - (; f, x, y, tang, res1, contexts, smaller) = new_scen = deepcopy(scen) - xrand, tangrand = myrandom(x), myrandom(tang) - rewrap = Rewrap(contexts...) - contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ - new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba - deepcopy(scen) - else - deepcopy(smaller) - end - prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) - prepstrict = $prep_op!( + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) + prepstrict = $prep_op( f, - $prep_op( - new_smaller.f, - ba, - new_smaller.x, - new_smaller.tang, - new_smaller.contexts...; - strict=Val(true), - ), ba, - xrand, - tangrand, - contextsrand..., + prep_args.x, + prep_args.t, + prep_args.contexts...; + strict=Val(true), ) - prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) + prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...) [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_out1_val, res1_out1_val = $val_and_op( - f, preptup_val..., ba, x, tang, contexts... + f, preptup_val..., ba, x, t, contexts... ) y_out2_val, res1_out2_val = $val_and_op( - f, preptup_val..., ba, x, tang, contexts... + f, preptup_val..., ba, x, t, contexts... ) - res1_out1_noval = $op(f, preptup_noval..., ba, x, tang, contexts...) - res1_out2_noval = $op(f, preptup_noval..., ba, x, tang, contexts...) + res1_out1_noval = $op(f, preptup_noval..., ba, x, t, contexts...) + res1_out2_noval = $op(f, preptup_noval..., ba, x, t, contexts...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @test isempty(preptup_noval) || only(preptup_noval) isa $P @test y_out1_val ≈ scen.y @@ -571,8 +454,8 @@ for op in ALL_OPS end end end - @test_throws PME $val_and_op(nothing, prepstrict, ba, x, tang, contexts...) - @test_throws PME $op(nothing, prepstrict, ba, x, tang, contexts...) + @test_throws PME $val_and_op(nothing, prepstrict, ba, x, t, contexts...) + @test_throws PME $op(nothing, prepstrict, ba, x, t, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -586,34 +469,19 @@ for op in ALL_OPS scenario_intact::Bool, sparsity::Bool, ) - (; f, x, y, tang, res1, contexts, smaller) = new_scen = deepcopy(scen) - xrand, tangrand = myrandom(x), myrandom(tang) - rewrap = Rewrap(contexts...) - contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ - new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba - deepcopy(scen) - else - deepcopy(smaller) - end - prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) - prepstrict = $prep_op!( + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) + prepstrict = $prep_op( f, - $prep_op( - new_smaller.f, - ba, - new_smaller.x, - new_smaller.tang, - new_smaller.contexts...; - strict=Val(true), - ), ba, - xrand, - tangrand, - contextsrand..., + prep_args.x, + prep_args.t, + prep_args.contexts...; + strict=Val(true), ) - prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) + prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...) [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) @@ -622,38 +490,38 @@ for op in ALL_OPS res1_in1_noval = mysimilar(res1) res1_in2_noval = mysimilar(res1) y_out1_val, res1_out1_val = $val_and_op!( - f, res1_in1_val, preptup_val..., ba, x, tang, contexts... + f, res1_in1_val, preptup_val..., ba, x, t, contexts... ) y_out2_val, res1_out2_val = $val_and_op!( - f, res1_in2_val, preptup_val..., ba, x, tang, contexts... + f, res1_in2_val, preptup_val..., ba, x, t, contexts... ) res1_out1_noval = $op!( - f, res1_in1_noval, preptup_noval..., ba, x, tang, contexts... + f, res1_in1_noval, preptup_noval..., ba, x, t, contexts... ) res1_out2_noval = $op!( - f, res1_in2_noval, preptup_noval..., ba, x, tang, contexts... + f, res1_in2_noval, preptup_noval..., ba, x, t, contexts... ) let (≈)(x, y) = isapprox(x, y; atol, rtol) @test isempty(preptup_noval) || only(preptup_noval) isa $P @test y_out1_val ≈ scen.y @test y_out2_val ≈ scen.y + @test res1_in1_val === res1_out1_val + @test res1_in2_val === res1_out2_val + @test res1_in1_noval === res1_out1_noval + @test res1_in2_noval === res1_out2_noval for b in eachindex(scen.res1) - @test res1_in1_val[b] === res1_out1_val[b] - @test res1_in2_val[b] === res1_out2_val[b] @test res1_out1_val[b] ≈ scen.res1[b] @test res1_out2_val[b] ≈ scen.res1[b] - @test res1_in1_noval[b] === res1_out1_noval[b] - @test res1_in2_noval[b] === res1_out2_noval[b] @test res1_out1_noval[b] ≈ scen.res1[b] @test res1_out2_noval[b] ≈ scen.res1[b] end end end @test_throws PME $val_and_op!( - nothing, mysimilar(res1), prepstrict, ba, x, tang, contexts... + nothing, mysimilar(res1), prepstrict, ba, x, t, contexts... ) @test_throws PME $op!( - nothing, mysimilar(res1), prepstrict, ba, x, tang, contexts... + nothing, mysimilar(res1), prepstrict, ba, x, t, contexts... ) scenario_intact && @test new_scen == scen return nothing @@ -668,36 +536,22 @@ for op in ALL_OPS scenario_intact::Bool, sparsity::Bool, ) - (; f, x, y, tang, res1, contexts, smaller) = new_scen = deepcopy(scen) - xrand, yrand, tangrand = myrandom(x), myrandom(y), myrandom(tang) - rewrap = Rewrap(contexts...) - contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ - new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba - deepcopy(scen) - else - deepcopy(smaller) - end - prep = $prep_op(f, copy(yrand), ba, xrand, tangrand, contextsrand...) - prepstrict = $prep_op!( + prep = $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) + prepstrict = $prep_op( f, - copy(yrand), - $prep_op( - new_smaller.f, - copy(new_smaller.y), - ba, - new_smaller.x, - new_smaller.tang, - new_smaller.contexts...; - strict=Val(true), - ), + prep_args.y, ba, - xrand, - tangrand, - contextsrand..., + prep_args.x, + prep_args.t, + prep_args.contexts...; + strict=Val(true), ) - prep_same = $prep_op_same(f, copy(yrand), ba, x, tangrand, contexts...) + prep_same = $prep_op_same(f, prep_args.y, ba, x, prep_args.t, contexts...) [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) @@ -706,16 +560,16 @@ for op in ALL_OPS y_in1_noval = mysimilar(y) y_in2_noval = mysimilar(y) y_out1_val, res1_out1_val = $val_and_op( - f, y_in1_val, preptup_val..., ba, x, tang, contexts... + f, y_in1_val, preptup_val..., ba, x, t, contexts... ) y_out2_val, res1_out2_val = $val_and_op( - f, y_in2_val, preptup_val..., ba, x, tang, contexts... + f, y_in2_val, preptup_val..., ba, x, t, contexts... ) res1_out1_noval = $op( - f, y_in1_noval, preptup_noval..., ba, x, tang, contexts... + f, y_in1_noval, preptup_noval..., ba, x, t, contexts... ) res1_out2_noval = $op( - f, y_in2_noval, preptup_noval..., ba, x, tang, contexts... + f, y_in2_noval, preptup_noval..., ba, x, t, contexts... ) let (≈)(x, y) = isapprox(x, y; atol, rtol) @test isempty(preptup_noval) || only(preptup_noval) isa $P @@ -732,11 +586,9 @@ for op in ALL_OPS end end @test_throws PME $val_and_op( - nothing, mysimilar(y), prepstrict, ba, x, tang, contexts... - ) - @test_throws PME $op( - nothing, mysimilar(y), prepstrict, ba, x, tang, contexts... + nothing, mysimilar(y), prepstrict, ba, x, t, contexts... ) + @test_throws PME $op(nothing, mysimilar(y), prepstrict, ba, x, t, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -750,36 +602,22 @@ for op in ALL_OPS scenario_intact::Bool, sparsity::Bool, ) - (; f, x, y, tang, res1, contexts, smaller) = new_scen = deepcopy(scen) - xrand, yrand, tangrand = myrandom(x), myrandom(y), myrandom(tang) - rewrap = Rewrap(contexts...) - contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ - new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba - deepcopy(scen) - else - deepcopy(smaller) - end - prep = $prep_op(f, copy(yrand), ba, xrand, tangrand, contextsrand...) - prepstrict = $prep_op!( + prep = $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) + prepstrict = $prep_op( f, - copy(yrand), - $prep_op( - new_smaller.f, - copy(new_smaller.y), - ba, - new_smaller.x, - new_smaller.tang, - new_smaller.contexts...; - strict=Val(true), - ), + prep_args.y, ba, - xrand, - tangrand, - contextsrand..., + prep_args.x, + prep_args.t, + prep_args.contexts...; + strict=Val(true), ) - prep_same = $prep_op_same(f, copy(yrand), ba, x, tangrand, contexts...) + prep_same = $prep_op_same(f, prep_args.y, ba, x, prep_args.t, contexts...) [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) @@ -788,30 +626,16 @@ for op in ALL_OPS y_in1_noval, res1_in1_noval = mysimilar(y), mysimilar(res1) y_in2_noval, res1_in2_noval = mysimilar(y), mysimilar(res1) y_out1_val, res1_out1_val = $val_and_op!( - f, y_in1_val, res1_in1_val, preptup_val..., ba, x, tang, contexts... + f, y_in1_val, res1_in1_val, preptup_val..., ba, x, t, contexts... ) y_out2_val, res1_out2_val = $val_and_op!( - f, y_in2_val, res1_in2_val, preptup_val..., ba, x, tang, contexts... + f, y_in2_val, res1_in2_val, preptup_val..., ba, x, t, contexts... ) res1_out1_noval = $op!( - f, - y_in1_noval, - res1_in1_noval, - preptup_noval..., - ba, - x, - tang, - contexts..., + f, y_in1_noval, res1_in1_noval, preptup_noval..., ba, x, t, contexts... ) res1_out2_noval = $op!( - f, - y_in2_noval, - res1_in2_noval, - preptup_noval..., - ba, - x, - tang, - contexts..., + f, y_in2_noval, res1_in2_noval, preptup_noval..., ba, x, t, contexts... ) let (≈)(x, y) = isapprox(x, y; atol, rtol) @test isempty(preptup_noval) || only(preptup_noval) isa $P @@ -819,23 +643,23 @@ for op in ALL_OPS @test y_in2_val === y_out2_val @test y_out1_val ≈ scen.y @test y_out2_val ≈ scen.y + @test res1_in1_val === res1_out1_val + @test res1_in2_val === res1_out2_val + @test res1_in1_noval === res1_out1_noval + @test res1_in2_noval === res1_out2_noval for b in eachindex(scen.res1) - @test res1_in1_val[b] === res1_out1_val[b] - @test res1_in2_val[b] === res1_out2_val[b] @test res1_out1_val[b] ≈ scen.res1[b] @test res1_out2_val[b] ≈ scen.res1[b] - @test res1_in1_noval[b] === res1_out1_noval[b] - @test res1_in2_noval[b] === res1_out2_noval[b] @test res1_out1_noval[b] ≈ scen.res1[b] @test res1_out2_noval[b] ≈ scen.res1[b] end end end @test_throws PME $val_and_op!( - nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, tang, contexts... + nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, t, contexts... ) @test_throws PME $op!( - nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, tang, contexts... + nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, t, contexts... ) scenario_intact && @test new_scen == scen return nothing @@ -851,44 +675,29 @@ for op in ALL_OPS scenario_intact::Bool, sparsity::Bool, ) - (; f, x, y, tang, res1, res2, contexts, smaller) = new_scen = deepcopy(scen) - xrand, tangrand = myrandom(x), myrandom(tang) - rewrap = Rewrap(contexts...) - contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + (; f, x, y, t, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ - new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba - deepcopy(scen) - else - deepcopy(smaller) - end - prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) - prepstrict = $prep_op!( + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) + prepstrict = $prep_op( f, - $prep_op( - new_smaller.f, - ba, - new_smaller.x, - new_smaller.tang, - new_smaller.contexts...; - strict=Val(true), - ), ba, - xrand, - tangrand, - contextsrand..., + prep_args.x, + prep_args.t, + prep_args.contexts...; + strict=Val(true), ) - prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) + prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...) [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) - res2_out1_noval = $op(f, preptup_noval..., ba, x, tang, contexts...) - res2_out2_noval = $op(f, preptup_noval..., ba, x, tang, contexts...) + res2_out1_noval = $op(f, preptup_noval..., ba, x, t, contexts...) + res2_out2_noval = $op(f, preptup_noval..., ba, x, t, contexts...) res1_out1_val, res2_out1_val = $val_and_op( - f, preptup_noval..., ba, x, tang, contexts... + f, preptup_noval..., ba, x, t, contexts... ) res1_out2_val, res2_out2_val = $val_and_op( - f, preptup_noval..., ba, x, tang, contexts... + f, preptup_noval..., ba, x, t, contexts... ) let (≈)(x, y) = isapprox(x, y; atol, rtol) @test isempty(preptup_noval) || only(preptup_noval) isa $P @@ -902,8 +711,8 @@ for op in ALL_OPS end end end - @test_throws PME $val_and_op(nothing, prepstrict, ba, x, tang, contexts...) - @test_throws PME $op(nothing, prepstrict, ba, x, tang, contexts...) + @test_throws PME $val_and_op(nothing, prepstrict, ba, x, t, contexts...) + @test_throws PME $op(nothing, prepstrict, ba, x, t, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -917,34 +726,19 @@ for op in ALL_OPS scenario_intact::Bool, sparsity::Bool, ) - (; f, x, y, tang, res1, res2, contexts, smaller) = new_scen = deepcopy(scen) - xrand, tangrand = myrandom(x), myrandom(tang) - rewrap = Rewrap(contexts...) - contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + (; f, x, y, t, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ - new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba - deepcopy(scen) - else - deepcopy(smaller) - end - prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) - prepstrict = $prep_op!( + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) + prepstrict = $prep_op( f, - $prep_op( - new_smaller.f, - ba, - new_smaller.x, - new_smaller.tang, - new_smaller.contexts...; - strict=Val(true), - ), ba, - xrand, - tangrand, - contextsrand..., + prep_args.x, + prep_args.t, + prep_args.contexts...; + strict=Val(true), ) - prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) + prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...) [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) @@ -953,30 +747,16 @@ for op in ALL_OPS res1_in1_val, res2_in1_val = mysimilar(res1), mysimilar(res2) res1_in2_val, res2_in2_val = mysimilar(res1), mysimilar(res2) res2_out1_noval = $op!( - f, res2_in1_noval, preptup_noval..., ba, x, tang, contexts... + f, res2_in1_noval, preptup_noval..., ba, x, t, contexts... ) res2_out2_noval = $op!( - f, res2_in2_noval, preptup_noval..., ba, x, tang, contexts... + f, res2_in2_noval, preptup_noval..., ba, x, t, contexts... ) res1_out1_val, res2_out1_val = $val_and_op!( - f, - res1_in1_val, - res2_in1_val, - preptup_noval..., - ba, - x, - tang, - contexts..., + f, res1_in1_val, res2_in1_val, preptup_noval..., ba, x, t, contexts... ) res1_out2_val, res2_out2_val = $val_and_op!( - f, - res1_in2_val, - res2_in2_val, - preptup_noval..., - ba, - x, - tang, - contexts..., + f, res1_in2_val, res2_in2_val, preptup_noval..., ba, x, t, contexts... ) let (≈)(x, y) = isapprox(x, y; atol, rtol) @test isempty(preptup_noval) || only(preptup_noval) isa $P @@ -984,30 +764,23 @@ for op in ALL_OPS @test res1_in2_val === res1_out2_val @test res1_out1_val ≈ scen.res1 @test res1_out2_val ≈ scen.res1 + @test res2_in1_noval === res2_out1_noval + @test res2_in2_noval === res2_out2_noval + @test res2_in1_val === res2_out1_val + @test res2_in2_val === res2_out2_val for b in eachindex(scen.res2) - @test res2_in1_noval[b] === res2_out1_noval[b] - @test res2_in2_noval[b] === res2_out2_noval[b] @test res2_out1_noval[b] ≈ scen.res2[b] @test res2_out2_noval[b] ≈ scen.res2[b] - @test res2_in1_val[b] === res2_out1_val[b] - @test res2_in2_val[b] === res2_out2_val[b] @test res2_out1_val[b] ≈ scen.res2[b] @test res2_out2_val[b] ≈ scen.res2[b] end end end @test_throws PME $op!( - nothing, mysimilar(res2), prepstrict, ba, x, tang, contexts... + nothing, mysimilar(res2), prepstrict, ba, x, t, contexts... ) @test_throws PME $val_and_op!( - nothing, - mysimilar(res1), - mysimilar(res2), - prepstrict, - ba, - x, - tang, - contexts..., + nothing, mysimilar(res1), mysimilar(res2), prepstrict, ba, x, t, contexts... ) scenario_intact && @test new_scen == scen return nothing diff --git a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl index bf02fb562..5219bf410 100644 --- a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl @@ -26,11 +26,11 @@ for op in ALL_OPS ignored_modules, function_filter, ) - (; f, x, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, contexts...) + (; f, x, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, ba, x, contexts...) + function_filter $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = function_filter $op(f, ba, x, contexts...) @@ -53,11 +53,11 @@ for op in ALL_OPS ignored_modules, function_filter, ) - (; f, x, res1, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, contexts...) + (; f, x, res1, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, ba, x, contexts...) + function_filter $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = function_filter $op!(f, mysimilar(res1), ba, x, contexts...) @@ -84,11 +84,13 @@ for op in ALL_OPS ignored_modules, function_filter, ) - (; f, x, y, contexts) = deepcopy(scen) - prep = $prep_op(f, y, ba, x, contexts...) + (; f, x, y, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, y, ba, x, contexts...) + function_filter $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = function_filter $op(f, y, ba, x, contexts...) @@ -111,11 +113,13 @@ for op in ALL_OPS ignored_modules, function_filter, ) - (; f, x, y, res1, contexts) = deepcopy(scen) - prep = $prep_op(f, y, ba, x, contexts...) + (; f, x, y, res1, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, y, ba, x, contexts...) + function_filter $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = function_filter $op!(f, y, mysimilar(res1), ba, x, contexts...) @@ -141,11 +145,11 @@ for op in ALL_OPS ignored_modules, function_filter, ) - (; f, x, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, contexts...) + (; f, x, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, ba, x, contexts...) + function_filter $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = function_filter $op(f, ba, x, contexts...) @@ -168,11 +172,11 @@ for op in ALL_OPS ignored_modules, function_filter, ) - (; f, x, res1, res2, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, contexts...) + (; f, x, res1, res2, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, ba, x, contexts...) + function_filter $prep_op(f, ba, prep_args.x, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = function_filter $op!(f, mysimilar(res2), ba, x, contexts...) @@ -200,23 +204,25 @@ for op in ALL_OPS ignored_modules, function_filter, ) - (; f, x, tang, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, tang, contexts...) + (; f, x, t, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, ba, x, tang, contexts...) + function_filter $prep_op( + f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, ba, x, tang, contexts...) + function_filter $op(f, ba, x, t, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, ba, x, tang, contexts...) + function_filter $val_and_op(f, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, prep, ba, x, tang, contexts...) + function_filter $op(f, prep, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, prep, ba, x, tang, contexts...) + function_filter $val_and_op(f, prep, ba, x, t, contexts...) return nothing end @@ -227,26 +233,26 @@ for op in ALL_OPS ignored_modules, function_filter, ) - (; f, x, tang, res1, res2, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, tang, contexts...) + (; f, x, t, res1, res2, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, ba, x, tang, contexts...) + function_filter $prep_op( + f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, mysimilar(res1), ba, x, tang, contexts...) + function_filter $op!(f, mysimilar(res1), ba, x, t, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op!( - f, mysimilar(res1), ba, x, tang, contexts... - ) + function_filter $val_and_op!(f, mysimilar(res1), ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, mysimilar(res1), prep, ba, x, tang, contexts...) + function_filter $op!(f, mysimilar(res1), prep, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = function_filter $val_and_op!( - f, mysimilar(res1), prep, ba, x, tang, contexts... + f, mysimilar(res1), prep, ba, x, t, contexts... ) return nothing end @@ -258,23 +264,27 @@ for op in ALL_OPS ignored_modules, function_filter, ) - (; f, x, y, tang, contexts) = deepcopy(scen) - prep = $prep_op(f, y, ba, x, tang, contexts...) + (; f, x, y, t, contexts, prep_args) = deepcopy(scen) + prep = $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, y, ba, x, tang, contexts...) + function_filter $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, y, ba, x, tang, contexts...) + function_filter $op(f, y, ba, x, t, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, y, ba, x, tang, contexts...) + function_filter $val_and_op(f, y, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, y, prep, ba, x, tang, contexts...) + function_filter $op(f, y, prep, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, y, prep, ba, x, tang, contexts...) + function_filter $val_and_op(f, y, prep, ba, x, t, contexts...) return nothing end @@ -285,28 +295,30 @@ for op in ALL_OPS ignored_modules, function_filter, ) - (; f, x, y, tang, res1, contexts) = deepcopy(scen) - prep = $prep_op(f, y, ba, x, tang, contexts...) + (; f, x, y, t, res1, contexts, prep_args) = deepcopy(scen) + prep = $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, y, ba, x, tang, contexts...) + function_filter $prep_op( + f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, y, mysimilar(res1), ba, x, tang, contexts...) + function_filter $op!(f, y, mysimilar(res1), ba, x, t, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = function_filter $val_and_op!( - f, y, mysimilar(res1), ba, x, tang, contexts... + f, y, mysimilar(res1), ba, x, t, contexts... ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!( - f, y, mysimilar(res1), prep, ba, x, tang, contexts... - ) + function_filter $op!(f, y, mysimilar(res1), prep, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = function_filter $val_and_op!( - f, y, mysimilar(res1), prep, ba, x, tang, contexts... + f, y, mysimilar(res1), prep, ba, x, t, contexts... ) return nothing end @@ -319,23 +331,25 @@ for op in ALL_OPS ignored_modules, function_filter, ) - (; f, x, tang, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, tang, contexts...) + (; f, x, t, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, ba, x, tang, contexts...) + function_filter $prep_op( + f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, ba, x, tang, contexts...) + function_filter $op(f, ba, x, t, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, ba, x, tang, contexts...) + function_filter $val_and_op(f, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op(f, prep, ba, x, tang, contexts...) + function_filter $op(f, prep, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $val_and_op(f, prep, ba, x, tang, contexts...) + function_filter $val_and_op(f, prep, ba, x, t, contexts...) return nothing end @@ -346,26 +360,28 @@ for op in ALL_OPS ignored_modules, function_filter, ) - (; f, x, tang, res1, res2, contexts) = deepcopy(scen) - prep = $prep_op(f, ba, x, tang, contexts...) + (; f, x, t, res1, res2, contexts, prep_args) = deepcopy(scen) + prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $prep_op(f, ba, x, tang, contexts...) + function_filter $prep_op( + f, ba, prep_args.x, prep_args.t, prep_args.contexts... + ) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, mysimilar(res2), ba, x, tang, contexts...) + function_filter $op!(f, mysimilar(res2), ba, x, t, contexts...) (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = function_filter $val_and_op!( - f, mysimilar(res1), mysimilar(res2), ba, x, tang, contexts... + f, mysimilar(res1), mysimilar(res2), ba, x, t, contexts... ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = - function_filter $op!(f, mysimilar(res2), prep, ba, x, tang, contexts...) + function_filter $op!(f, mysimilar(res2), prep, ba, x, t, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = function_filter $val_and_op!( - f, mysimilar(res1), mysimilar(res2), prep, ba, x, tang, contexts... + f, mysimilar(res1), mysimilar(res2), prep, ba, x, t, contexts... ) return nothing end From 1de1253ebe1a18196b6f1eee27eb2d44b4480a45 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 9 May 2025 17:23:40 +0200 Subject: [PATCH 02/12] Fix --- DifferentiationInterface/test/Core/Internals/signature.jl | 4 ++-- DifferentiationInterfaceTest/CHANGELOG.md | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/test/Core/Internals/signature.jl b/DifferentiationInterface/test/Core/Internals/signature.jl index 2d28ce7dc..be2e144c1 100644 --- a/DifferentiationInterface/test/Core/Internals/signature.jl +++ b/DifferentiationInterface/test/Core/Internals/signature.jl @@ -98,7 +98,7 @@ end - exec: Nothing - backend: ✅ - x: ✅ - - tang: ✅ + - t: ✅ - contexts: ✅ """ pushforward(nothing, prep, backend, x, (x,), Constant(c)) end @@ -119,7 +119,7 @@ end - y: ✅ - backend: ✅ - x: ✅ - - tang: ✅ + - t: ✅ - contexts: ✅ """ pushforward(nothing, y, prep, backend, x, (x,)) end diff --git a/DifferentiationInterfaceTest/CHANGELOG.md b/DifferentiationInterfaceTest/CHANGELOG.md index e9e121dd1..8f336510f 100644 --- a/DifferentiationInterfaceTest/CHANGELOG.md +++ b/DifferentiationInterfaceTest/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- Specify preparation arguments in DIT Scenario ([#786]) + ## [0.9.6] - 2025-03-28 ### Added @@ -18,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [unreleased]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.9.6...main [0.9.6]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.9.5...DifferentiationInterfaceTest-v0.9.6 +[#786]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/786 [#749]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/749 [#748]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/748 [#745]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/745 \ No newline at end of file From fc7aabc84f98cf7fe485290aa981f2c8bc91bf85 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 9 May 2025 18:59:10 +0200 Subject: [PATCH 03/12] Fixes --- .../DifferentiationInterfaceTestFluxExt.jl | 14 ++++++++++++-- .../DifferentiationInterfaceTestJLArraysExt.jl | 3 ++- .../DifferentiationInterfaceTestLuxExt.jl | 4 ++++ .../DifferentiationInterfaceTestStaticArraysExt.jl | 3 ++- DifferentiationInterfaceTest/src/utils.jl | 8 -------- 5 files changed, 20 insertions(+), 12 deletions(-) diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl index 040c4d6ba..0c0e6a623 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl @@ -162,7 +162,13 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) for (model, x) in models_and_xs Flux.trainmode!(model) g = gradient_finite_differences(square_loss, model, x) - scen = DIT.Scenario{:gradient,:out}(square_loss, model, DI.Constant(x); res1=g) + scen = DIT.Scenario{:gradient,:out}( + square_loss, + model, + DI.Constant(x); + prep_args=(; x=model, contexts=(DI.Constant(x),)), + res1=g, + ) push!(scens, scen) end @@ -189,7 +195,11 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) Flux.trainmode!(model) g = gradient_finite_differences(square_loss_iterated, model, x) scen = DIT.Scenario{:gradient,:out}( - square_loss_iterated, model, DI.Constant(x); res1=g + square_loss_iterated, + model, + DI.Constant(x); + prep_args=(; x=model, contexts=(DI.Constant(x),)), + res1=g, ) push!(scens, scen) end diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl index 16053fad5..00bdc359f 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl @@ -23,13 +23,14 @@ myjl(x::DI.Cache{<:Union{Tuple,NamedTuple}}) = map(myjl, map(DI.Cache, DI.unwrap myjl(::Nothing) = nothing function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} - (; f, x, y, t, contexts, res1, res2) = scen + (; f, x, y, t, contexts, prep_args, res1, res2) = scen return DIT.Scenario{op,pl_op,pl_fun}(; f=myjl(f), x=myjl(x), y=myjl(y), t=myjl(t), contexts=myjl(contexts), + prep_args=map(myjl, prep_args), res1=myjl(res1), res2=myjl(res2), ) diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl index d5c05a40a..8119b073f 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl @@ -203,6 +203,10 @@ function DIT.lux_scenarios(rng::AbstractRNG=default_rng()) DI.Constant(model), DI.Constant(x), DI.Constant(st); + prep_args=( + x=ComponentArray(ps), + contexts=(DI.Constant(model), DI.Constant(x), DI.Constant(st)), + ), res1=g, ) push!(scens, scen) diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl index fa33c5818..63ca9dff7 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl @@ -36,13 +36,14 @@ end mystatic(::Nothing) = nothing function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} - (; f, x, y, t, contexts, res1, res2) = scen + (; f, x, y, t, contexts, prep_args, res1, res2) = scen return DIT.Scenario{op,pl_op,pl_fun}(; f=mystatic(f), x=mystatic(x), y=pl_fun == :in ? mymutablestatic(y) : mystatic(y), t=mystatic(t), contexts=mystatic(contexts), + prep_args=map(mystatic, prep_args), res1=mystatic(res1), res2=mystatic(res2), ) diff --git a/DifferentiationInterfaceTest/src/utils.jl b/DifferentiationInterfaceTest/src/utils.jl index 81369e1ac..dbf1a2a1d 100644 --- a/DifferentiationInterfaceTest/src/utils.jl +++ b/DifferentiationInterfaceTest/src/utils.jl @@ -6,14 +6,6 @@ myzero(::Nothing) = nothing mysimilar(x::Number) = one(x) mysimilar(x::AbstractArray) = similar(x) mysimilar(x::Union{Tuple,NamedTuple}) = map(mysimilar, x) -mysimilar(x) = deepcopy(x) - -myrandom(rng::AbstractRNG, x::Number) = randn(rng, typeof(x)) -myrandom(rng::AbstractRNG, x::AbstractArray) = map(Base.Fix1(myrandom, rng), x) -myrandom(rng::AbstractRNG, x::Union{Tuple,NamedTuple}) = map(Base.Fix1(myrandom, rng), x) -myrandom(rng::AbstractRNG, x) = deepcopy(x) - -myrandom(x) = myrandom(default_rng(), x) mysize(x::Number) = size(x) mysize(x::AbstractArray) = size(x) From d95953b6b503e14c04717b7f623c68d96814eba3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 9 May 2025 21:53:23 +0200 Subject: [PATCH 04/12] Fixes --- .../DifferentiationInterfaceTestJLArraysExt.jl | 3 ++- .../DifferentiationInterfaceTestStaticArraysExt.jl | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl index 00bdc359f..9800889f9 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl @@ -23,7 +23,7 @@ myjl(x::DI.Cache{<:Union{Tuple,NamedTuple}}) = map(myjl, map(DI.Cache, DI.unwrap myjl(::Nothing) = nothing function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} - (; f, x, y, t, contexts, prep_args, res1, res2) = scen + (; f, x, y, t, contexts, prep_args, res1, res2, name) = scen return DIT.Scenario{op,pl_op,pl_fun}(; f=myjl(f), x=myjl(x), @@ -33,6 +33,7 @@ function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} prep_args=map(myjl, prep_args), res1=myjl(res1), res2=myjl(res2), + name, ) end diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl index 63ca9dff7..6bb90d7e0 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl @@ -36,7 +36,7 @@ end mystatic(::Nothing) = nothing function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} - (; f, x, y, t, contexts, prep_args, res1, res2) = scen + (; f, x, y, t, contexts, prep_args, res1, res2, name) = scen return DIT.Scenario{op,pl_op,pl_fun}(; f=mystatic(f), x=mystatic(x), @@ -46,6 +46,7 @@ function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} prep_args=map(mystatic, prep_args), res1=mystatic(res1), res2=mystatic(res2), + name=name, ) end From 27fbc23e80d8950f4eb44a22af9b7309d5936054 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 9 May 2025 22:00:08 +0200 Subject: [PATCH 05/12] Fixes --- .../DifferentiationInterfaceTestFluxExt.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl index 0c0e6a623..88e3ef490 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl @@ -92,7 +92,11 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) g = gradient_finite_differences(square_loss, model, x) scen = DIT.Scenario{:gradient,:out}( - square_loss, model; contexts=(DI.Constant(x),), res1=g + square_loss, + model, + DI.Constant(x); + prep_args=(x=model, contexts=(DI.Constant(x),)), + res1=g, ) push!(scens, scen) From 944355aac2279f563293655b62dfbed307d277ff Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 10 May 2025 07:53:14 +0200 Subject: [PATCH 06/12] Fix static arrays --- .../DifferentiationInterfaceTestStaticArraysExt.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl index 6bb90d7e0..9aac9fd44 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl @@ -3,7 +3,7 @@ module DifferentiationInterfaceTestStaticArraysExt import DifferentiationInterface as DI import DifferentiationInterfaceTest as DIT using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm -using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector +using StaticArrays: StaticArray, MArray, MMatrix, MVector, SArray, SMatrix, SVector static_num_to_vec(x::Number) = sin.(SVector(1, 2) .* x) static_num_to_mat(x::Number) = hcat(static_num_to_vec(x), static_num_to_vec(3x)) @@ -37,13 +37,19 @@ mystatic(::Nothing) = nothing function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} (; f, x, y, t, contexts, prep_args, res1, res2, name) = scen + new_prep_args = (; + x=mystatic(prep_args.x), contexts=map(mystatic, prep_args.contexts), t=mystatic(t) + ) + if pl_fun == :in + new_prep_args = (; new_prep_args..., y=mymutablestatic(prep_args.y)) + end return DIT.Scenario{op,pl_op,pl_fun}(; f=mystatic(f), x=mystatic(x), y=pl_fun == :in ? mymutablestatic(y) : mystatic(y), t=mystatic(t), contexts=mystatic(contexts), - prep_args=map(mystatic, prep_args), + prep_args=new_prep_args, res1=mystatic(res1), res2=mystatic(res2), name=name, From 12d98ed9d8e5798433b80a7ebf47038b874fbbbe Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 10 May 2025 08:21:31 +0200 Subject: [PATCH 07/12] Fix --- DifferentiationInterfaceTest/src/tests/allocs_eval.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterfaceTest/src/tests/allocs_eval.jl b/DifferentiationInterfaceTest/src/tests/allocs_eval.jl index 8ba00e2a9..f30454de3 100644 --- a/DifferentiationInterfaceTest/src/tests/allocs_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/allocs_eval.jl @@ -179,11 +179,11 @@ for op in ALL_OPS skip, $prep_op, f, - pep_args.y, + prep_args.y, ba, - pep_args.x, - pep_args.t, - pep_args.contexts..., + prep_args.x, + prep_args.t, + prep_args.contexts..., ) (subset == :full) && test_noallocs(skip, $op, f, y, ba, x, t, contexts...) (subset == :full) && From d2671509a32f960ea774c71ef008d5861566ecbf Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 10 May 2025 09:09:41 +0200 Subject: [PATCH 08/12] Fix sparse and complex --- .../src/scenarios/sparse.jl | 73 +++++++++++++++---- 1 file changed, 60 insertions(+), 13 deletions(-) diff --git a/DifferentiationInterfaceTest/src/scenarios/sparse.jl b/DifferentiationInterfaceTest/src/scenarios/sparse.jl index d3821c7e4..3c914c053 100644 --- a/DifferentiationInterfaceTest/src/scenarios/sparse.jl +++ b/DifferentiationInterfaceTest/src/scenarios/sparse.jl @@ -34,14 +34,19 @@ function sparse_vec_to_vec_scenarios(x::AbstractVector) f! = diffsquare! y = f(x) jac = diffsquare_jacobian(x) + x_prep = reshape(eltype(x).(1:length(x)) .^ 3, size(x)) scens = Scenario[] for pl_op in (:out, :in) append!( scens, [ - Scenario{:jacobian,pl_op}(f, x; res1=jac), - Scenario{:jacobian,pl_op}(f!, y, x; res1=jac), + Scenario{:jacobian,pl_op}( + f, x; prep_args=(; x=x_prep, contexts=()), res1=jac + ), + Scenario{:jacobian,pl_op}( + f!, y, x; prep_args=(; y=zero(y), x=x_prep, contexts=()), res1=jac + ), ], ) end @@ -70,14 +75,19 @@ function sparse_mat_to_vec_scenarios(x::AbstractMatrix) f! = diffsquarecube_matvec! y = f(x) jac = diffsquarecube_matvec_jacobian(x) + x_prep = reshape(eltype(x).(1:length(x)) .^ 3, size(x)) scens = Scenario[] for pl_op in (:out, :in) append!( scens, [ - Scenario{:jacobian,pl_op}(f, x; res1=jac), - Scenario{:jacobian,pl_op}(f!, y, x; res1=jac), + Scenario{:jacobian,pl_op}( + f, x; prep_args=(; x=x_prep, contexts=()), res1=jac + ), + Scenario{:jacobian,pl_op}( + f!, y, x; prep_args=(; y=zero(y), x=x_prep, contexts=()), res1=jac + ), ], ) end @@ -103,14 +113,19 @@ function sparse_vec_to_mat_scenarios(x::AbstractVector) f! = diffsquarecube_vecmat! y = f(x) jac = diffsquarecube_vecmat_jacobian(vec(x)) + x_prep = reshape(eltype(x).(1:length(x)) .^ 3, size(x)) scens = Scenario[] for pl_op in (:out, :in) append!( scens, [ - Scenario{:jacobian,pl_op}(f, x; res1=jac), - Scenario{:jacobian,pl_op}(f!, y, x; res1=jac), + Scenario{:jacobian,pl_op}( + f, x; prep_args=(; x=x_prep, contexts=()), res1=jac + ), + Scenario{:jacobian,pl_op}( + f!, y, x; prep_args=(; y=zero(y), x=x_prep, contexts=()), res1=jac + ), ], ) end @@ -138,14 +153,19 @@ function sparse_mat_to_mat_scenarios(x::AbstractMatrix) f! = diffsquarecube_matmat! y = f(x) jac = diffsquarecube_matmat_jacobian(x) + x_prep = reshape(eltype(x).(1:length(x)) .^ 3, size(x)) scens = Scenario[] for pl_op in (:out, :in) append!( scens, [ - Scenario{:jacobian,pl_op}(f, x; res1=jac), - Scenario{:jacobian,pl_op}(f!, y, x; res1=jac), + Scenario{:jacobian,pl_op}( + f, x; prep_args=(; x=x_prep, contexts=()), res1=jac + ), + Scenario{:jacobian,pl_op}( + f!, y, x; prep_args=(; y=zero(y), x=x_prep, contexts=()), res1=jac + ), ], ) end @@ -180,10 +200,18 @@ function sparse_vec_to_num_scenarios(x::AbstractVector) f = sumdiffcube grad = sumdiffcube_gradient(x) hess = sumdiffcube_hessian(x) + x_prep = reshape(eltype(x).(1:length(x)) .^ 3, size(x)) scens = Scenario[] for pl_op in (:out, :in) - append!(scens, [Scenario{:hessian,pl_op}(f, x; res1=grad, res2=hess)]) + append!( + scens, + [ + Scenario{:hessian,pl_op}( + f, x; prep_args=(; x=x_prep, contexts=()), res1=grad, res2=hess + ), + ], + ) end return scens end @@ -204,10 +232,18 @@ function sparse_mat_to_num_scenarios(x::AbstractMatrix) f = sumdiffcube_mat grad = sumdiffcube_mat_gradient(x) hess = sumdiffcube_mat_hessian(x) + x_prep = reshape(eltype(x).(1:length(x)) .^ 3, size(x)) scens = Scenario[] for pl_op in (:out, :in) - append!(scens, [Scenario{:hessian,pl_op}(f, x; res1=grad, res2=hess)]) + append!( + scens, + [ + Scenario{:hessian,pl_op}( + f, x; prep_args=(; x=x_prep, contexts=()), res1=grad, res2=hess + ), + ], + ) end return scens end @@ -251,12 +287,17 @@ function squarelinearmap_scenarios(x::AbstractVector, band_sizes) f! = f y = f(x) jac = sparse(squarelinearmap_jacobian(x, A)) + x_prep = reshape(eltype(x).(1:length(x)) .^ 3, size(x)) for pl_op in (:out, :in) append!( scens, [ - Scenario{:jacobian,pl_op}(f, x; res1=jac), - Scenario{:jacobian,pl_op}(f!, y, x; res1=jac), + Scenario{:jacobian,pl_op}( + f, x; prep_args=(; x=x_prep, contexts=()), res1=jac + ), + Scenario{:jacobian,pl_op}( + f!, y, x; prep_args=(; y=zero(y), x=x_prep, contexts=()), res1=jac + ), ], ) end @@ -306,12 +347,18 @@ end function squarequadraticform_scenarios(x::AbstractVector, band_sizes) n = length(x) scens = Scenario[] + x_prep = reshape(eltype(x).(1:length(x)) .^ 3, size(x)) for A in banded_matrix.(eltype(x), n, band_sizes) f = SquareQuadraticForm(A) grad = squarequadraticform_gradient(x, A) hess = sparse(squarequadraticform_hessian(x, A)) for pl_op in (:out, :in) - push!(scens, Scenario{:hessian,pl_op}(f, x; res1=grad, res2=hess)) + push!( + scens, + Scenario{:hessian,pl_op}( + f, x; prep_args=(; x=x_prep, contexts=()), res1=grad, res2=hess + ), + ) end end return scens From 5e26c1c77bac0fbfc5383a6c985031fd2c717454 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 10 May 2025 19:40:30 +0200 Subject: [PATCH 09/12] All works except HVP --- .../test/Back/FiniteDiff/test.jl | 1 + .../test/Back/ForwardDiff/test.jl | 1 + .../test/Core/SimpleFiniteDiff/test.jl | 2 +- .../src/scenarios/default.jl | 18 ++++- .../src/test_differentiation.jl | 2 + .../src/tests/correctness_eval.jl | 81 +++++++++++++++++++ DifferentiationInterfaceTest/test/standard.jl | 2 +- 7 files changed, 102 insertions(+), 5 deletions(-) diff --git a/DifferentiationInterface/test/Back/FiniteDiff/test.jl b/DifferentiationInterface/test/Back/FiniteDiff/test.jl index e4be24ebc..911dab203 100644 --- a/DifferentiationInterface/test/Back/FiniteDiff/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDiff/test.jl @@ -27,6 +27,7 @@ end include_cachified=true, include_constantorcachified=true, use_tuples=true, + include_smaller=true, ); excluded=[:second_derivative, :hvp], logging=LOGGING, diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index ada83edf2..8a99c08f4 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -41,6 +41,7 @@ end include_cachified=true, include_constantorcachified=true, use_tuples=true, + include_smaller=true, ); logging=LOGGING, ) diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index 3ceeb8ee2..beec9841f 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -63,7 +63,7 @@ end @testset "Dense" begin test_differentiation( vcat(backends, second_order_backends), - default_scenarios(; include_constantified=true); + default_scenarios(; include_constantified=true, include_smaller=true); logging=LOGGING, ) diff --git a/DifferentiationInterfaceTest/src/scenarios/default.jl b/DifferentiationInterfaceTest/src/scenarios/default.jl index b601a3b70..39b8f4ff4 100644 --- a/DifferentiationInterfaceTest/src/scenarios/default.jl +++ b/DifferentiationInterfaceTest/src/scenarios/default.jl @@ -557,6 +557,7 @@ function default_scenarios(; include_cachified=false, include_constantorcachified=false, use_tuples=false, + include_smaller=false, ) x_ = 0.42 dx_ = 3.14 @@ -575,7 +576,7 @@ function default_scenarios(; dy_2_3 = float.(reshape(-5:2:5, 2, 3)) dy_6_2 = float.(reshape(-11:2:11, 6, 2)) - initialscens = vcat( + scens = vcat( # one argument num_to_num_scenarios(x_; dx=dx_, dy=dy_), onevec_to_onevec_scenarios_onearg(x_; dx=dx_, dy=dy_), @@ -623,8 +624,18 @@ function default_scenarios(; ), ) - scens = map(initialscens, smallerscens) do s1, s2 - s1 # TODO: readd smaller scens + scens_smaller_prep = map(scens, smallerscens) do s1, s2 + Scenario{operator(s1),operator_place(s1),function_place(s1)}(; + f=s1.f, + y=s1.y, + x=s1.x, + t=s1.t, + contexts=s1.contexts, + res1=s1.res1, + res2=s1.res2, + name=isnothing(s1.name) ? nothing : s1.name * " [smaller prep]", + prep_args=s2.prep_args, + ) end include_batchified && append!(scens, batchify(scens)) @@ -635,6 +646,7 @@ function default_scenarios(; include_constantified && append!(final_scens, constantify(scens)) include_cachified && append!(final_scens, cachify(scens; use_tuples=use_tuples)) include_constantorcachified && append!(final_scens, constantorcachify(scens)) + include_smaller && append!(final_scens, scens_smaller_prep) return final_scens end diff --git a/DifferentiationInterfaceTest/src/test_differentiation.jl b/DifferentiationInterfaceTest/src/test_differentiation.jl index 3fd4023b0..bc77ba038 100644 --- a/DifferentiationInterfaceTest/src/test_differentiation.jl +++ b/DifferentiationInterfaceTest/src/test_differentiation.jl @@ -84,6 +84,7 @@ function test_differentiation( rtol::Real=1e-3, scenario_intact::Bool=true, sparsity::Bool=false, + reprepare::Bool=true, # type stability options ignored_modules=nothing, function_filter=if VERSION >= v"1.11" @@ -160,6 +161,7 @@ function test_differentiation( rtol, scenario_intact, sparsity, + reprepare, ) end yield() diff --git a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl index e7e0d49a2..bbedf3550 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl @@ -1,3 +1,6 @@ +has_size(::Union{Number,AbstractArray}) = true +has_size(_x) = false + const PME = PreparationMismatchError for op in ALL_OPS @@ -49,6 +52,7 @@ for op in ALL_OPS rtol::Real, scenario_intact::Bool, sparsity::Bool, + reprepare::Bool, ) (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict @@ -57,6 +61,10 @@ for op in ALL_OPS prepstrict = $prep_op( f, ba, prep_args.x, prep_args.contexts...; strict=Val(true) ) + if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) + prep = $prep_op!(f, prep, ba, x, contexts...) + prepstrict = $prep_op!(f, prepstrict, ba, x, contexts...) + end [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) @@ -98,6 +106,7 @@ for op in ALL_OPS rtol::Real, scenario_intact::Bool, sparsity::Bool, + reprepare::Bool, ) (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict @@ -106,6 +115,10 @@ for op in ALL_OPS prepstrict = $prep_op( f, ba, prep_args.x, prep_args.contexts...; strict=Val(true) ) + if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) + prep = $prep_op!(f, prep, ba, x, contexts...) + prepstrict = $prep_op!(f, prepstrict, ba, x, contexts...) + end [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) @@ -163,6 +176,7 @@ for op in ALL_OPS rtol::Real, scenario_intact::Bool, sparsity::Bool, + reprepare::Bool, ) (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict @@ -176,6 +190,13 @@ for op in ALL_OPS prep_args.contexts...; strict=Val(true), ) + if reprepare && + has_size(x) && + has_size(y) && + (size(x) != size(prep_args.x) || size(y) != prep_args.y) + prep = $prep_op!(f, y, prep, ba, x, contexts...) + prepstrict = $prep_op!(f, y, prepstrict, ba, x, contexts...) + end [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) @@ -225,6 +246,7 @@ for op in ALL_OPS rtol::Real, scenario_intact::Bool, sparsity::Bool, + reprepare::Bool, ) (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict @@ -238,6 +260,13 @@ for op in ALL_OPS prep_args.contexts...; strict=Val(true), ) + if reprepare && + has_size(x) && + has_size(y) && + (size(x) != size(prep_args.x) || size(y) != prep_args.y) + prep = $prep_op!(f, y, prep, ba, x, contexts...) + prepstrict = $prep_op!(f, y, prepstrict, ba, x, contexts...) + end [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) @@ -298,6 +327,7 @@ for op in ALL_OPS rtol::Real, scenario_intact::Bool, sparsity::Bool, + reprepare::Bool, ) (; f, x, y, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict @@ -306,6 +336,10 @@ for op in ALL_OPS prepstrict = $prep_op( f, ba, prep_args.x, prep_args.contexts...; strict=Val(true) ) + if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) + prep = $prep_op!(f, prep, ba, x, contexts...) + prepstrict = $prep_op!(f, prepstrict, ba, x, contexts...) + end [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) @@ -349,6 +383,7 @@ for op in ALL_OPS rtol::Real, scenario_intact::Bool, sparsity::Bool, + reprepare::Bool, ) (; f, x, y, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict @@ -357,6 +392,10 @@ for op in ALL_OPS prepstrict = $prep_op( f, ba, prep_args.x, prep_args.contexts...; strict=Val(true) ) + if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) + prep = $prep_op!(f, prep, ba, x, contexts...) + prepstrict = $prep_op!(f, prepstrict, ba, x, contexts...) + end [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) @@ -417,6 +456,7 @@ for op in ALL_OPS rtol::Real, scenario_intact::Bool, sparsity::Bool, + reprepare::Bool, ) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict @@ -431,6 +471,11 @@ for op in ALL_OPS strict=Val(true), ) prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...) + if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) + prep = $prep_op!(f, prep, ba, x, t, contexts...) + prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...) + prep_same = $prep_op_same(f, ba, x, t, contexts...) + end [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) @@ -468,6 +513,7 @@ for op in ALL_OPS rtol::Real, scenario_intact::Bool, sparsity::Bool, + reprepare::Bool, ) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict @@ -482,6 +528,11 @@ for op in ALL_OPS strict=Val(true), ) prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...) + if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) + prep = $prep_op!(f, prep, ba, x, t, contexts...) + prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...) + prep_same = $prep_op_same(f, ba, x, t, contexts...) + end [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) @@ -535,6 +586,7 @@ for op in ALL_OPS rtol::Real, scenario_intact::Bool, sparsity::Bool, + reprepare::Bool, ) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict @@ -552,6 +604,14 @@ for op in ALL_OPS strict=Val(true), ) prep_same = $prep_op_same(f, prep_args.y, ba, x, prep_args.t, contexts...) + if reprepare && + has_size(x) && + has_size(y) && + (size(x) != size(prep_args.x) || size(y) != prep_args.y) + prep = $prep_op!(f, y, prep, ba, x, t, contexts...) + prepstrict = $prep_op!(f, y, prepstrict, ba, x, t, contexts...) + prep_same = $prep_op_same(f, y, ba, x, t, contexts...) + end [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) @@ -601,6 +661,7 @@ for op in ALL_OPS rtol::Real, scenario_intact::Bool, sparsity::Bool, + reprepare::Bool, ) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict @@ -618,6 +679,14 @@ for op in ALL_OPS strict=Val(true), ) prep_same = $prep_op_same(f, prep_args.y, ba, x, prep_args.t, contexts...) + if reprepare && + has_size(x) && + has_size(y) && + (size(x) != size(prep_args.x) || size(y) != prep_args.y) + prep = $prep_op!(f, y, prep, ba, x, t, contexts...) + prepstrict = $prep_op!(f, y, prepstrict, ba, x, t, contexts...) + prep_same = $prep_op_same(f, y, ba, x, t, contexts...) + end [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) @@ -674,6 +743,7 @@ for op in ALL_OPS rtol::Real, scenario_intact::Bool, sparsity::Bool, + reprepare::Bool, ) (; f, x, y, t, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict @@ -688,6 +758,11 @@ for op in ALL_OPS strict=Val(true), ) prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...) + if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) + prep = $prep_op!(f, prep, ba, x, t, contexts...) + prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...) + prep_same = $prep_op_same(f, ba, x, t, contexts...) + end [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) @@ -725,6 +800,7 @@ for op in ALL_OPS rtol::Real, scenario_intact::Bool, sparsity::Bool, + reprepare::Bool, ) (; f, x, y, t, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) local prepstrict @@ -739,6 +815,11 @@ for op in ALL_OPS strict=Val(true), ) prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...) + if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) + prep = $prep_op!(f, prep, ba, x, t, contexts...) + prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...) + prep_same = $prep_op_same(f, ba, x, t, contexts...) + end [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) diff --git a/DifferentiationInterfaceTest/test/standard.jl b/DifferentiationInterfaceTest/test/standard.jl index d85ab0603..b8cd1f3c4 100644 --- a/DifferentiationInterfaceTest/test/standard.jl +++ b/DifferentiationInterfaceTest/test/standard.jl @@ -13,7 +13,7 @@ LOGGING = get(ENV, "CI", "false") == "false" test_differentiation( [AutoForwardDiff(), AutoForwardDiff(; chunksize=100)], - default_scenarios(; include_constantified=true); + default_scenarios(; include_smaller=true, include_constantified=true); logging=LOGGING, ) From 401096dd765be9816f0687b626027c9693360c63 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 11 May 2025 19:56:36 +0200 Subject: [PATCH 10/12] Fix tangents for prep same point --- .../src/scenarios/modify.jl | 32 +++++++++++++------ .../src/scenarios/scenario.jl | 14 ++++---- .../src/tests/correctness_eval.jl | 18 ++++------- DifferentiationInterfaceTest/src/utils.jl | 5 --- 4 files changed, 34 insertions(+), 35 deletions(-) diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index b6cb372cb..256fa538f 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -6,14 +6,26 @@ abstract type FunctionModifier end Return a new `Scenario` identical to `scen` except for the first- and second-order results which are set to zero. """ function Base.zero(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} + zero_res1 = if op in (:pushforward, :pullback) + map(zero, scen.res1) + else + zero(scen.res1) + end + zero_res2 = if isnothing(scen.res2) + nothing + elseif op == :hvp + map(zero, scen.res2) + else + zero(scen.res2) + end return Scenario{op,pl_op,pl_fun}(; f=scen.f, x=scen.x, y=scen.y, t=scen.t, contexts=scen.contexts, - res1=myzero(scen.res1), - res2=myzero(scen.res2), + res1=zero_res1, + res2=zero_res2, prep_args=scen.prep_args, name=isnothing(scen.name) ? nothing : scen.name * " [zero]", ) @@ -239,15 +251,15 @@ function cachify(scen::Scenario{op,pl_op,pl_fun}; use_tuples) where {op,pl_op,pl cache_f = StoreInCache{pl_fun}(f) if use_tuples y_cache = if scen.y isa Number - (; useful_cache=([myzero(scen.y)],), useless_cache=[myzero(scen.y)]) + (; useful_cache=([zero(scen.y)],), useless_cache=[zero(scen.y)]) else - (; useful_cache=(mysimilar(scen.y),), useless_cache=mysimilar(scen.y)) + (; useful_cache=(similar(scen.y),), useless_cache=similar(scen.y)) end else y_cache = if scen.y isa Number - [myzero(scen.y)] + [zero(scen.y)] else - mysimilar(scen.y) + similar(scen.y) end end return Scenario{op,pl_op,pl_fun}(; @@ -321,14 +333,14 @@ function constantorcachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_f a = 3.0 b = [4.0] constantorcache = if scen.y isa Number - (; cache=[myzero(scen.y)], constant=(; a, b)) + (; cache=[zero(scen.y)], constant=(; a, b)) else - (; cache=mysimilar(scen.y), constant=(; a, b)) + (; cache=similar(scen.y), constant=(; a, b)) end prep_constantorcache = if scen.y isa Number - (; cache=[myzero(scen.y)], constant=(; a=2a, b=3b)) + (; cache=[zero(scen.y)], constant=(; a=2a, b=3b)) else - (; cache=mysimilar(scen.y), constant=(; a=2a, b=3b)) + (; cache=similar(scen.y), constant=(; a=2a, b=3b)) end return Scenario{op,pl_op,pl_fun}(; f=constantorcache_f, diff --git a/DifferentiationInterfaceTest/src/scenarios/scenario.jl b/DifferentiationInterfaceTest/src/scenarios/scenario.jl index c1d7e31cf..929bb4999 100644 --- a/DifferentiationInterfaceTest/src/scenarios/scenario.jl +++ b/DifferentiationInterfaceTest/src/scenarios/scenario.jl @@ -60,9 +60,9 @@ struct Scenario{op,pl_op,pl_fun,F,X,Y,T<:Union{Nothing,NTuple},C<:Tuple,R1,R2,P< end end -function myzero_contexts(contexts...) +function zero_contexts(contexts...) rewrap = Rewrap(contexts...) - return rewrap(map(myzero ∘ unwrap, contexts)...) + return rewrap(map(zero ∘ unwrap, contexts)...) end function Scenario{op,pl_op}( @@ -71,7 +71,7 @@ function Scenario{op,pl_op}( contexts::Vararg{Context}; res1=nothing, res2=nothing, - prep_args=(; x=myzero(x), contexts=myzero_contexts(contexts...)), + prep_args=(; x=zero(x), contexts=zero_contexts(contexts...)), name=nothing, ) where {op,pl_op} y = f(x, map(unwrap, contexts)...) @@ -87,7 +87,7 @@ function Scenario{op,pl_op}( contexts::Vararg{Context}; res1=nothing, res2=nothing, - prep_args=(; y=myzero(y), x=myzero(x), contexts=myzero_contexts(contexts...)), + prep_args=(; y=zero(y), x=zero(x), contexts=zero_contexts(contexts...)), name=nothing, ) where {op,pl_op} f(y, x, map(unwrap, contexts)...) @@ -103,7 +103,7 @@ function Scenario{op,pl_op}( contexts::Vararg{Context}; res1=nothing, res2=nothing, - prep_args=(; x=myzero(x), t=map(myzero, t), contexts=myzero_contexts(contexts...)), + prep_args=(; x=zero(x), t=map(zero, t), contexts=zero_contexts(contexts...)), name=nothing, ) where {op,pl_op} y = f(x, map(unwrap, contexts)...) @@ -118,9 +118,7 @@ function Scenario{op,pl_op}( contexts::Vararg{Context}; res1=nothing, res2=nothing, - prep_args=(; - y=myzero(y), x=myzero(x), t=map(myzero, t), contexts=myzero_contexts(contexts...) - ), + prep_args=(; y=zero(y), x=zero(x), t=map(zero, t), contexts=zero_contexts(contexts...)), name=nothing, ) where {op,pl_op} f(y, x, map(unwrap, contexts)...) diff --git a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl index bbedf3550..dbecdbbc7 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl @@ -470,11 +470,10 @@ for op in ALL_OPS prep_args.contexts...; strict=Val(true), ) - prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...) + prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) prep = $prep_op!(f, prep, ba, x, t, contexts...) prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...) - prep_same = $prep_op_same(f, ba, x, t, contexts...) end [(), (prep,), (prepstrict,), (prep_same,)] end @@ -527,11 +526,10 @@ for op in ALL_OPS prep_args.contexts...; strict=Val(true), ) - prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...) + prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) prep = $prep_op!(f, prep, ba, x, t, contexts...) prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...) - prep_same = $prep_op_same(f, ba, x, t, contexts...) end [(), (prep,), (prepstrict,), (prep_same,)] end @@ -603,14 +601,13 @@ for op in ALL_OPS prep_args.contexts...; strict=Val(true), ) - prep_same = $prep_op_same(f, prep_args.y, ba, x, prep_args.t, contexts...) + prep_same = $prep_op_same(f, y, ba, x, map(zero, t), contexts...) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x) || size(y) != prep_args.y) prep = $prep_op!(f, y, prep, ba, x, t, contexts...) prepstrict = $prep_op!(f, y, prepstrict, ba, x, t, contexts...) - prep_same = $prep_op_same(f, y, ba, x, t, contexts...) end [(), (prep,), (prepstrict,), (prep_same,)] end @@ -678,14 +675,13 @@ for op in ALL_OPS prep_args.contexts...; strict=Val(true), ) - prep_same = $prep_op_same(f, prep_args.y, ba, x, prep_args.t, contexts...) + prep_same = $prep_op_same(f, y, ba, x, map(zero, t), contexts...) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x) || size(y) != prep_args.y) prep = $prep_op!(f, y, prep, ba, x, t, contexts...) prepstrict = $prep_op!(f, y, prepstrict, ba, x, t, contexts...) - prep_same = $prep_op_same(f, y, ba, x, t, contexts...) end [(), (prep,), (prepstrict,), (prep_same,)] end @@ -757,11 +753,10 @@ for op in ALL_OPS prep_args.contexts...; strict=Val(true), ) - prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...) + prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) prep = $prep_op!(f, prep, ba, x, t, contexts...) prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...) - prep_same = $prep_op_same(f, ba, x, t, contexts...) end [(), (prep,), (prepstrict,), (prep_same,)] end @@ -814,11 +809,10 @@ for op in ALL_OPS prep_args.contexts...; strict=Val(true), ) - prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...) + prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) prep = $prep_op!(f, prep, ba, x, t, contexts...) prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...) - prep_same = $prep_op_same(f, ba, x, t, contexts...) end [(), (prep,), (prepstrict,), (prep_same,)] end diff --git a/DifferentiationInterfaceTest/src/utils.jl b/DifferentiationInterfaceTest/src/utils.jl index dbf1a2a1d..06e1747f2 100644 --- a/DifferentiationInterfaceTest/src/utils.jl +++ b/DifferentiationInterfaceTest/src/utils.jl @@ -1,8 +1,3 @@ -myzero(x::Number) = zero(x) -myzero(x::AbstractArray) = zero(x) -myzero(x::Union{Tuple,NamedTuple}) = map(myzero, x) -myzero(::Nothing) = nothing - mysimilar(x::Number) = one(x) mysimilar(x::AbstractArray) = similar(x) mysimilar(x::Union{Tuple,NamedTuple}) = map(mysimilar, x) From d970dafaf09c39d8f9b8cbbb0a3475a7e6b6753b Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 11 May 2025 22:19:20 +0200 Subject: [PATCH 11/12] Fixes --- .../test/Back/DifferentiateWith/test.jl | 4 +--- DifferentiationInterfaceTest/src/scenarios/modify.jl | 6 +++--- DifferentiationInterfaceTest/src/scenarios/scenario.jl | 10 ++++++---- .../src/test_differentiation.jl | 1 + DifferentiationInterfaceTest/test/standard.jl | 7 +++++-- DifferentiationInterfaceTest/test/zero_backends.jl | 8 +++++--- 6 files changed, 21 insertions(+), 15 deletions(-) diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index dbc41f548..c8ea57c0b 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -16,9 +16,7 @@ function differentiatewith_scenarios() DIT.function_place(scen) == :out end good_scens = map(bad_scens) do scen - DIT.change_function( - scen, DifferentiateWith(scen.f, AutoFiniteDiff()); keep_smaller=false - ) + DIT.change_function(scen, DifferentiateWith(scen.f, AutoFiniteDiff())) end return good_scens end diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index 256fa538f..7677d6d15 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -36,9 +36,7 @@ end Return a new `Scenario` identical to `scen` except for the function `f` which is changed to `new_f`. """ -function change_function( - scen::Scenario{op,pl_op,pl_fun}, new_f; keep_smaller -) where {op,pl_op,pl_fun} +function change_function(scen::Scenario{op,pl_op,pl_fun}, new_f) where {op,pl_op,pl_fun} return Scenario{op,pl_op,pl_fun}(; f=new_f, x=scen.x, @@ -52,6 +50,8 @@ function change_function( ) end +same_function(scen) = change_function(scen, scen.f) + """ batchify(scen::Scenario) diff --git a/DifferentiationInterfaceTest/src/scenarios/scenario.jl b/DifferentiationInterfaceTest/src/scenarios/scenario.jl index 929bb4999..9eaf5325c 100644 --- a/DifferentiationInterfaceTest/src/scenarios/scenario.jl +++ b/DifferentiationInterfaceTest/src/scenarios/scenario.jl @@ -199,13 +199,15 @@ function Base.show( end function adapt_batchsize(backend::AbstractADType, scen::Scenario) - (; x, y) = scen + (; x, y, prep_args) = scen + xprep = prep_args.x + yprep = hasproperty(prep_args, :y) ? prep_args.y : y Bmax = if x isa AbstractArray && y isa AbstractArray - min(length(x), length(y)) + min(length(x), length(y), length(xprep), length(yprep)) elseif x isa AbstractArray - length(x) + min(length(x), length(xprep)) elseif y isa AbstractArray - length(y) + min(length(y), length(yprep)) else typemax(Int) end diff --git a/DifferentiationInterfaceTest/src/test_differentiation.jl b/DifferentiationInterfaceTest/src/test_differentiation.jl index bc77ba038..e80d4e3d6 100644 --- a/DifferentiationInterfaceTest/src/test_differentiation.jl +++ b/DifferentiationInterfaceTest/src/test_differentiation.jl @@ -48,6 +48,7 @@ Each setting tests/benchmarks a different subset of calls: - `rtol=1e-3`: relative precision for correctness testing (when comparing to the reference outputs) - `scenario_intact=true`: whether to check that the scenario remains unchanged after the operators are applied - `sparsity=false`: whether to check sparsity patterns for Jacobians / Hessians +- `reprepare::Bool=true`: whether to modify preparation before testing when the preparation arguments have the wrong size **Type stability options:** diff --git a/DifferentiationInterfaceTest/test/standard.jl b/DifferentiationInterfaceTest/test/standard.jl index b8cd1f3c4..c4c351db4 100644 --- a/DifferentiationInterfaceTest/test/standard.jl +++ b/DifferentiationInterfaceTest/test/standard.jl @@ -20,7 +20,10 @@ test_differentiation( test_differentiation( [AutoForwardDiff(), AutoFiniteDiff(; relstep=1e-5)], default_scenarios(; - include_batchified=false, include_normal=false, include_constantorcachified=true + include_batchified=false, + include_normal=false, + include_cachified=true, + include_constantorcachified=true, ); logging=LOGGING, ) @@ -35,7 +38,7 @@ sparse_backend = AutoSparse( test_differentiation( sparse_backend, - sparse_scenarios(; include_cachified=true, use_tuples=true); + sparse_scenarios(; include_cachified=true, use_tuples=false); sparsity=true, logging=LOGGING, ) diff --git a/DifferentiationInterfaceTest/test/zero_backends.jl b/DifferentiationInterfaceTest/test/zero_backends.jl index e8f5ace6c..935de0cd1 100644 --- a/DifferentiationInterfaceTest/test/zero_backends.jl +++ b/DifferentiationInterfaceTest/test/zero_backends.jl @@ -11,15 +11,17 @@ LOGGING = get(ENV, "CI", "false") == "false" test_differentiation( AutoZeroForward(), - default_scenarios(; include_batchified=false); - correctness=false, + map(zero, default_scenarios(; include_batchified=false)); type_stability=:full, logging=LOGGING, ) test_differentiation( AutoZeroReverse(), - default_scenarios(; include_batchified=false); + map( + DifferentiationInterfaceTest.same_function, + default_scenarios(; include_batchified=false), + ); correctness=false, type_stability=:prepared, logging=LOGGING, From 8e093bb15898774d0a70642e38407776d3d425ff Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 11 May 2025 23:44:56 +0200 Subject: [PATCH 12/12] Update DifferentiationInterfaceTest/src/scenarios/scenario.jl --- DifferentiationInterfaceTest/src/scenarios/scenario.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterfaceTest/src/scenarios/scenario.jl b/DifferentiationInterfaceTest/src/scenarios/scenario.jl index 9eaf5325c..9a9dea1e6 100644 --- a/DifferentiationInterfaceTest/src/scenarios/scenario.jl +++ b/DifferentiationInterfaceTest/src/scenarios/scenario.jl @@ -35,7 +35,7 @@ struct Scenario{op,pl_op,pl_fun,F,X,Y,T<:Union{Nothing,NTuple},C<:Tuple,R1,R2,P< res1::R1 "second-order result of the operator (if applicable)" res2::R2 - "named tuple of arguments passed to preparation, without the function" + "named tuple of arguments passed to preparation, without the function - the required keys are a subset of `(; y, x, t, contexts)` depending on the operator" prep_args::P "name of the scenario for display in test sets and dataframes" name::Union{String,Nothing}