From 9d1ca76472c4c7590a16e384ac100f6783f0f5b9 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 12 May 2025 11:29:33 +0200 Subject: [PATCH 1/5] fix!: make strict preparation the default --- DifferentiationInterface/CHANGELOG.md | 2 + DifferentiationInterface/src/docstrings.jl | 3 +- .../src/first_order/derivative.jl | 16 +- .../src/first_order/gradient.jl | 4 +- .../src/first_order/jacobian.jl | 14 +- .../src/first_order/pullback.jl | 16 +- .../src/first_order/pushforward.jl | 16 +- .../src/second_order/hessian.jl | 4 +- .../src/second_order/hvp.jl | 8 +- .../src/second_order/second_derivative.jl | 4 +- DifferentiationInterface/src/utils/prep.jl | 2 +- .../test/Core/Internals/signature.jl | 18 +- .../src/tests/correctness_eval.jl | 176 +++++++++--------- 13 files changed, 132 insertions(+), 151 deletions(-) diff --git a/DifferentiationInterface/CHANGELOG.md b/DifferentiationInterface/CHANGELOG.md index a792b1039..57784b388 100644 --- a/DifferentiationInterface/CHANGELOG.md +++ b/DifferentiationInterface/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Preparation is now strict by default ([#799]) - New Arxiv preprint for citation ([#795]) ## [0.6.54] - 2025-05-11 @@ -28,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [0.6.54]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.53...DifferentiationInterface-v0.6.54 [0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53 +[#799]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/799 [#795]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/795 [#790]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/790 [#788]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/788 diff --git a/DifferentiationInterface/src/docstrings.jl b/DifferentiationInterface/src/docstrings.jl index e14ccfb52..6e22c2f1c 100644 --- a/DifferentiationInterface/src/docstrings.jl +++ b/DifferentiationInterface/src/docstrings.jl @@ -26,7 +26,8 @@ function docstring_prepare(operator; samepoint=false, inplace=false) Otherwise, preparation becomes invalid and you need to run it again. In some settings, invalid preparations may still give correct results (e.g. for backends that require no preparation), but this is not a semantic guarantee and should not be relied upon. - When `strict=Val(true)`, type checking is enforced between preparation and execution (but size checking is left to the user). + When `strict=Val(true)` (the default), type checking is enforced between preparation and execution (but size checking is left to the user). + While your code may work for different types by setting `strict=Val(false)`, this is not guaranteed by the API and can break without warning. """ end diff --git a/DifferentiationInterface/src/first_order/derivative.jl b/DifferentiationInterface/src/first_order/derivative.jl index 1784247d8..4883f1c82 100644 --- a/DifferentiationInterface/src/first_order/derivative.jl +++ b/DifferentiationInterface/src/first_order/derivative.jl @@ -1,24 +1,19 @@ ## Docstrings """ - prepare_derivative(f, backend, x, [contexts...]; strict=Val(false)) -> prep - prepare_derivative(f!, y, backend, x, [contexts...]; strict=Val(false)) -> prep + prepare_derivative(f, backend, x, [contexts...]; strict=Val(true)) -> prep + prepare_derivative(f!, y, backend, x, [contexts...]; strict=Val(true)) -> prep $(docstring_prepare("derivative"; inplace=true)) """ function prepare_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true) ) where {F,C} return prepare_derivative_nokwarg(strict, f, backend, x, contexts...) end function prepare_derivative( - f!::F, - y, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}; - strict::Val=Val(false), + f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true) ) where {F,C} return prepare_derivative_nokwarg(strict, f!, y, backend, x, contexts...) end @@ -42,8 +37,7 @@ function prepare!_derivative( old_prep::DerivativePrep, backend::AbstractADType, x, - contexts::Vararg{Context,C}; - strict::Val=Val(false), + contexts::Vararg{Context,C}, ) where {F,C} check_prep(f!, y, old_prep, backend, x, contexts...) return prepare_derivative_nokwarg(is_strict(old_prep), f!, y, backend, x, contexts...) diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 426fb205b..448adba16 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -1,12 +1,12 @@ ## Docstrings """ - prepare_gradient(f, backend, x, [contexts...]; strict=Val(false)) -> prep + prepare_gradient(f, backend, x, [contexts...]; strict=Val(true)) -> prep $(docstring_prepare("gradient")) """ function prepare_gradient( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true) ) where {F,C} return prepare_gradient_nokwarg(strict, f, backend, x, contexts...) end diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 997a4dc75..5e6a07280 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -1,24 +1,19 @@ ## Docstrings """ - prepare_jacobian(f, backend, x, [contexts...]; strict=Val(false)) -> prep - prepare_jacobian(f!, y, backend, x, [contexts...]; strict=Val(false)) -> prep + prepare_jacobian(f, backend, x, [contexts...]; strict=Val(true)) -> prep + prepare_jacobian(f!, y, backend, x, [contexts...]; strict=Val(true)) -> prep $(docstring_prepare("jacobian"; inplace=true)) """ function prepare_jacobian( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true) ) where {F,C} return prepare_jacobian_nokwarg(strict, f, backend, x, contexts...) end function prepare_jacobian( - f!::F, - y, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}; - strict::Val=Val(false), + f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true) ) where {F,C} return prepare_jacobian_nokwarg(strict, f!, y, backend, x, contexts...) end @@ -43,7 +38,6 @@ function prepare!_jacobian( backend::AbstractADType, x, contexts::Vararg{Context,C}; - strict::Val=Val(false), ) where {F,C} check_prep(f!, y, old_prep, backend, x, contexts...) return prepare_jacobian_nokwarg(is_strict(old_prep), f!, y, backend, x, contexts...) diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 2ac1265c0..9d207d9e8 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -1,8 +1,8 @@ ## Docstrings """ - prepare_pullback(f, backend, x, ty, [contexts...]; strict=Val(false)) -> prep - prepare_pullback(f!, y, backend, x, ty, [contexts...]; strict=Val(false)) -> prep + prepare_pullback(f, backend, x, ty, [contexts...]; strict=Val(true)) -> prep + prepare_pullback(f!, y, backend, x, ty, [contexts...]; strict=Val(true)) -> prep $(docstring_prepare("pullback"; inplace=true)) """ @@ -12,7 +12,7 @@ function prepare_pullback( x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val=Val(true), ) where {F,C} return prepare_pullback_nokwarg(strict, f, backend, x, ty, contexts...) end @@ -24,7 +24,7 @@ function prepare_pullback( x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val=Val(true), ) where {F,C} return prepare_pullback_nokwarg(strict, f!, y, backend, x, ty, contexts...) end @@ -61,8 +61,8 @@ function prepare!_pullback( end """ - prepare_pullback_same_point(f, backend, x, ty, [contexts...]; strict=Val(false)) -> prep_same - prepare_pullback_same_point(f!, y, backend, x, ty, [contexts...]; strict=Val(false)) -> prep_same + prepare_pullback_same_point(f, backend, x, ty, [contexts...]; strict=Val(true)) -> prep_same + prepare_pullback_same_point(f!, y, backend, x, ty, [contexts...]; strict=Val(true)) -> prep_same $(docstring_prepare("pullback"; samepoint=true, inplace=true)) """ @@ -72,7 +72,7 @@ function prepare_pullback_same_point( x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val=Val(true), ) where {F,C} return prepare_pullback_same_point_nokwarg(strict, f, backend, x, ty, contexts...) end @@ -84,7 +84,7 @@ function prepare_pullback_same_point( x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val=Val(true), ) where {F,C} return prepare_pullback_same_point_nokwarg(strict, f!, y, backend, x, ty, contexts...) end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index edba42655..11cfbf185 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -1,8 +1,8 @@ ## Docstrings """ - prepare_pushforward(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep - prepare_pushforward(f!, y, backend, x, tx, [contexts...]; strict=Val(false)) -> prep + prepare_pushforward(f, backend, x, tx, [contexts...]; strict=Val(true)) -> prep + prepare_pushforward(f!, y, backend, x, tx, [contexts...]; strict=Val(true)) -> prep $(docstring_prepare("pushforward"; inplace=true)) """ @@ -12,7 +12,7 @@ function prepare_pushforward( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val=Val(true), ) where {F,C} return prepare_pushforward_nokwarg(strict, f, backend, x, tx, contexts...) end @@ -24,7 +24,7 @@ function prepare_pushforward( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val=Val(true), ) where {F,C} return prepare_pushforward_nokwarg(strict, f!, y, backend, x, tx, contexts...) end @@ -63,8 +63,8 @@ function prepare!_pushforward( end """ - prepare_pushforward_same_point(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep_same - prepare_pushforward_same_point(f!, y, backend, x, tx, [contexts...]; strict=Val(false)) -> prep_same + prepare_pushforward_same_point(f, backend, x, tx, [contexts...]; strict=Val(true)) -> prep_same + prepare_pushforward_same_point(f!, y, backend, x, tx, [contexts...]; strict=Val(true)) -> prep_same $(docstring_prepare("pushforward"; samepoint=true, inplace=true)) """ @@ -74,7 +74,7 @@ function prepare_pushforward_same_point( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val=Val(true), ) where {F,C} return prepare_pushforward_same_point_nokwarg(strict, f, backend, x, tx, contexts...) end @@ -86,7 +86,7 @@ function prepare_pushforward_same_point( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val=Val(true), ) where {F,C} return prepare_pushforward_same_point_nokwarg( strict, f!, y, backend, x, tx, contexts... diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 8e17cff99..ffb6ba840 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -1,12 +1,12 @@ ## Docstrings """ - prepare_hessian(f, backend, x, [contexts...]; strict=Val(false)) -> prep + prepare_hessian(f, backend, x, [contexts...]; strict=Val(true)) -> prep $(docstring_prepare("hessian")) """ function prepare_hessian( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true) ) where {F,C} return prepare_hessian_nokwarg(strict, f, backend, x, contexts...) end diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 1c4f3292a..8a85bf1c1 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -1,7 +1,7 @@ ## Docstrings """ - prepare_hvp(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep + prepare_hvp(f, backend, x, tx, [contexts...]; strict=Val(true)) -> prep $(docstring_prepare("hvp")) """ @@ -11,7 +11,7 @@ function prepare_hvp( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val=Val(true), ) where {F,C} return prepare_hvp_nokwarg(strict, f, backend, x, tx, contexts...) end @@ -34,7 +34,7 @@ function prepare!_hvp( end """ - prepare_hvp_same_point(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep_same + prepare_hvp_same_point(f, backend, x, tx, [contexts...]; strict=Val(true)) -> prep_same $(docstring_prepare("hvp"; samepoint=true)) """ @@ -44,7 +44,7 @@ function prepare_hvp_same_point( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val=Val(true), ) where {F,C} return prepare_hvp_same_point_nokwarg(strict, f, backend, x, tx, contexts...) end diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index 5c0903153..3d5b8c905 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -1,12 +1,12 @@ ## Docstrings """ - prepare_second_derivative(f, backend, x, [contexts...]; strict=Val(false)) -> prep + prepare_second_derivative(f, backend, x, [contexts...]; strict=Val(true)) -> prep $(docstring_prepare("second_derivative")) """ function prepare_second_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(true) ) where {F,C} return prepare_second_derivative_nokwarg(strict, f, backend, x, contexts...) end diff --git a/DifferentiationInterface/src/utils/prep.jl b/DifferentiationInterface/src/utils/prep.jl index 2d0107d96..fc8700b39 100644 --- a/DifferentiationInterface/src/utils/prep.jl +++ b/DifferentiationInterface/src/utils/prep.jl @@ -109,7 +109,7 @@ function Base.showerror( end println( io, - "If you are confident that this check is superfluous, you can disable it by running preparation with the keyword argument `strict=Val(false)` inside DifferentiationInterface.", + "If you are confident that this check is superfluous, you can disable it by running preparation with the keyword argument `strict=Val(true)` inside DifferentiationInterface.", ) return nothing end diff --git a/DifferentiationInterface/test/Core/Internals/signature.jl b/DifferentiationInterface/test/Core/Internals/signature.jl index be2e144c1..da59872dc 100644 --- a/DifferentiationInterface/test/Core/Internals/signature.jl +++ b/DifferentiationInterface/test/Core/Internals/signature.jl @@ -11,8 +11,8 @@ y = zeros(2) c = 2.0 @testset "Out of place, no tangents" begin - prep = prepare_derivative(f, backend, x, Constant(c); strict=Val(true)) - prep_chill = prepare_derivative(f, backend, x, Constant(c); strict=Val(false)) + prep = prepare_derivative(f, backend, x, Constant(c)) + prep_chill = prepare_derivative(f, backend, x, Constant(c)) @test_throws MethodError derivative(nothing, prep_chill, backend, x, Constant(c)) @@ -68,8 +68,8 @@ c = 2.0 end @testset "In place, no tangents" begin - prep = prepare_derivative(f!, y, backend, x; strict=Val(true)) - prep_chill = prepare_derivative(f!, y, backend, x; strict=Val(false)) + prep = prepare_derivative(f!, y, backend, x) + prep_chill = prepare_derivative(f!, y, backend, x) @test_throws MethodError derivative(nothing, y, prep_chill, backend, x, Constant(c)) @@ -86,8 +86,8 @@ end end @testset "Out of place, with tangents" begin - prep = prepare_pushforward(f, backend, x, (x,), Constant(c); strict=Val(true)) - prep_chill = prepare_pushforward(f, backend, x, (x,), Constant(c); strict=Val(false)) + prep = prepare_pushforward(f, backend, x, (x,), Constant(c)) + prep_chill = prepare_pushforward(f, backend, x, (x,), Constant(c)) @test_throws MethodError pushforward(nothing, prep_chill, backend, x, (x,)) @@ -104,10 +104,8 @@ end end @testset "In place, with tangents" begin - prep = prepare_pushforward(f!, y, backend, x, (x,); strict=Val(true)) - prep_chill = prepare_pushforward( - f!, y, backend, x, (x,), Constant(c); strict=Val(false) - ) + prep = prepare_pushforward(f!, y, backend, x, (x,)) + prep_chill = prepare_pushforward(f!, y, backend, x, (x,), Constant(c)) @test_throws MethodError pushforward(nothing, y, prep_chill, backend, x, (x,)) diff --git a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl index dbecdbbc7..22208bb43 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl @@ -55,17 +55,17 @@ for op in ALL_OPS reprepare::Bool, ) (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) - local prepstrict + local prep preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) - prepstrict = $prep_op( - f, ba, prep_args.x, prep_args.contexts...; strict=Val(true) + prep_nostrict = $prep_op( + f, ba, prep_args.x, prep_args.contexts...; strict=Val(false) ) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) prep = $prep_op!(f, prep, ba, x, contexts...) - prepstrict = $prep_op!(f, prepstrict, ba, x, contexts...) + prep_nostrict = $prep_op!(f, prep_nostrict, ba, x, contexts...) end - [(), (prep,), (prepstrict,)] + [(), (prep,), (prep_nostrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_out1_val, res1_out1_val = $val_and_op( @@ -92,8 +92,8 @@ for op in ALL_OPS @test mynnz(res1_out2_noval) == mynnz(scen.res1) end end - @test_throws PME $val_and_op(nothing, prepstrict, ba, x, contexts...) - @test_throws PME $op(nothing, prepstrict, ba, x, contexts...) + @test_throws PME $val_and_op(nothing, prep, ba, x, contexts...) + @test_throws PME $op(nothing, prep, ba, x, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -109,17 +109,17 @@ for op in ALL_OPS reprepare::Bool, ) (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) - local prepstrict + local prep preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) - prepstrict = $prep_op( - f, ba, prep_args.x, prep_args.contexts...; strict=Val(true) + prep_nostrict = $prep_op( + f, ba, prep_args.x, prep_args.contexts...; strict=Val(false) ) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) prep = $prep_op!(f, prep, ba, x, contexts...) - prepstrict = $prep_op!(f, prepstrict, ba, x, contexts...) + prep_nostrict = $prep_op!(f, prep_nostrict, ba, x, contexts...) end - [(), (prep,), (prepstrict,)] + [(), (prep,), (prep_nostrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res1_in1_val = mysimilar(res1) @@ -159,9 +159,9 @@ for op in ALL_OPS end end @test_throws PME $val_and_op!( - nothing, mysimilar(res1), prepstrict, ba, x, contexts... + nothing, mysimilar(res1), prep, ba, x, contexts... ) - @test_throws PME $op!(nothing, mysimilar(res1), prepstrict, ba, x, contexts...) + @test_throws PME $op!(nothing, mysimilar(res1), prep, ba, x, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -179,25 +179,25 @@ for op in ALL_OPS reprepare::Bool, ) (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) - local prepstrict + local prep preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) - prepstrict = $prep_op( + prep_nostrict = $prep_op( f, prep_args.y, ba, prep_args.x, prep_args.contexts...; - strict=Val(true), + strict=Val(false), ) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x) || size(y) != prep_args.y) prep = $prep_op!(f, y, prep, ba, x, contexts...) - prepstrict = $prep_op!(f, y, prepstrict, ba, x, contexts...) + prep_nostrict = $prep_op!(f, y, prep_nostrict, ba, x, contexts...) end - [(), (prep,), (prepstrict,)] + [(), (prep,), (prep_nostrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val = mysimilar(y) @@ -230,10 +230,8 @@ for op in ALL_OPS @test mynnz(res1_out2_noval) == mynnz(scen.res1) end end - @test_throws PME $val_and_op( - nothing, mysimilar(y), prepstrict, ba, x, contexts... - ) - @test_throws PME $op(nothing, mysimilar(y), prepstrict, ba, x, contexts...) + @test_throws PME $val_and_op(nothing, mysimilar(y), prep, ba, x, contexts...) + @test_throws PME $op(nothing, mysimilar(y), prep, ba, x, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -249,25 +247,25 @@ for op in ALL_OPS reprepare::Bool, ) (; f, x, y, res1, contexts, prep_args) = new_scen = deepcopy(scen) - local prepstrict + local prep preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep = $prep_op(f, prep_args.y, ba, prep_args.x, prep_args.contexts...) - prepstrict = $prep_op( + prep_nostrict = $prep_op( f, prep_args.y, ba, prep_args.x, prep_args.contexts...; - strict=Val(true), + strict=Val(false), ) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x) || size(y) != prep_args.y) prep = $prep_op!(f, y, prep, ba, x, contexts...) - prepstrict = $prep_op!(f, y, prepstrict, ba, x, contexts...) + prep_nostrict = $prep_op!(f, y, prep_nostrict, ba, x, contexts...) end - [(), (prep,), (prepstrict,)] + [(), (prep,), (prep_nostrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val, res1_in1_val = mysimilar(y), mysimilar(res1) @@ -309,10 +307,10 @@ for op in ALL_OPS end end @test_throws PME $val_and_op!( - nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, contexts... + nothing, mysimilar(y), mysimilar(res1), prep, ba, x, contexts... ) @test_throws PME $op!( - nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, contexts... + nothing, mysimilar(y), mysimilar(res1), prep, ba, x, contexts... ) scenario_intact && @test new_scen == scen return nothing @@ -330,17 +328,17 @@ for op in ALL_OPS reprepare::Bool, ) (; f, x, y, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) - local prepstrict + local prep preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) - prepstrict = $prep_op( - f, ba, prep_args.x, prep_args.contexts...; strict=Val(true) + prep_nostrict = $prep_op( + f, ba, prep_args.x, prep_args.contexts...; strict=Val(false) ) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) prep = $prep_op!(f, prep, ba, x, contexts...) - prepstrict = $prep_op!(f, prepstrict, ba, x, contexts...) + prep_nostrict = $prep_op!(f, prep_nostrict, ba, x, contexts...) end - [(), (prep,), (prepstrict,)] + [(), (prep,), (prep_nostrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_out1_val, res1_out1_val, res2_out1_val = $val_and_op( @@ -369,8 +367,8 @@ for op in ALL_OPS @test mynnz(res2_out2_noval) == mynnz(scen.res2) end end - @test_throws PME $val_and_op(nothing, prepstrict, ba, x, contexts...) - @test_throws PME $op(nothing, prepstrict, ba, x, contexts...) + @test_throws PME $val_and_op(nothing, prep, ba, x, contexts...) + @test_throws PME $op(nothing, prep, ba, x, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -386,17 +384,17 @@ for op in ALL_OPS reprepare::Bool, ) (; f, x, y, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) - local prepstrict + local prep preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep = $prep_op(f, ba, prep_args.x, prep_args.contexts...) - prepstrict = $prep_op( - f, ba, prep_args.x, prep_args.contexts...; strict=Val(true) + prep_nostrict = $prep_op( + f, ba, prep_args.x, prep_args.contexts...; strict=Val(false) ) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) prep = $prep_op!(f, prep, ba, x, contexts...) - prepstrict = $prep_op!(f, prepstrict, ba, x, contexts...) + prep_nostrict = $prep_op!(f, prep_nostrict, ba, x, contexts...) end - [(), (prep,), (prepstrict,)] + [(), (prep,), (prep_nostrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res1_in1_val, res2_in1_val = mysimilar(res1), mysimilar(res2) @@ -440,9 +438,9 @@ for op in ALL_OPS end end @test_throws PME $val_and_op!( - nothing, mysimilar(res1), mysimilar(res2), prepstrict, ba, x, contexts... + nothing, mysimilar(res1), mysimilar(res2), prep, ba, x, contexts... ) - @test_throws PME $op!(nothing, mysimilar(res2), prepstrict, ba, x, contexts...) + @test_throws PME $op!(nothing, mysimilar(res2), prep, ba, x, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -459,23 +457,23 @@ for op in ALL_OPS reprepare::Bool, ) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) - local prepstrict + local prep preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) - prepstrict = $prep_op( + prep_nostrict = $prep_op( f, ba, prep_args.x, prep_args.t, prep_args.contexts...; - strict=Val(true), + strict=Val(false), ) prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) prep = $prep_op!(f, prep, ba, x, t, contexts...) - prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...) + prep_nostrict = $prep_op!(f, prep_nostrict, ba, x, t, contexts...) end - [(), (prep,), (prepstrict,), (prep_same,)] + [(), (prep,), (prep_nostrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_out1_val, res1_out1_val = $val_and_op( @@ -498,8 +496,8 @@ for op in ALL_OPS end end end - @test_throws PME $val_and_op(nothing, prepstrict, ba, x, t, contexts...) - @test_throws PME $op(nothing, prepstrict, ba, x, t, contexts...) + @test_throws PME $val_and_op(nothing, prep, ba, x, t, contexts...) + @test_throws PME $op(nothing, prep, ba, x, t, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -515,23 +513,23 @@ for op in ALL_OPS reprepare::Bool, ) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) - local prepstrict + local prep preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) - prepstrict = $prep_op( + prep_nostrict = $prep_op( f, ba, prep_args.x, prep_args.t, prep_args.contexts...; - strict=Val(true), + strict=Val(false), ) prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) prep = $prep_op!(f, prep, ba, x, t, contexts...) - prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...) + prep_nostrict = $prep_op!(f, prep_nostrict, ba, x, t, contexts...) end - [(), (prep,), (prepstrict,), (prep_same,)] + [(), (prep,), (prep_nostrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res1_in1_val = mysimilar(res1) @@ -567,11 +565,9 @@ for op in ALL_OPS end end @test_throws PME $val_and_op!( - nothing, mysimilar(res1), prepstrict, ba, x, t, contexts... - ) - @test_throws PME $op!( - nothing, mysimilar(res1), prepstrict, ba, x, t, contexts... + nothing, mysimilar(res1), prep, ba, x, t, contexts... ) + @test_throws PME $op!(nothing, mysimilar(res1), prep, ba, x, t, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -587,19 +583,19 @@ for op in ALL_OPS reprepare::Bool, ) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) - local prepstrict + local prep preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep = $prep_op( f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... ) - prepstrict = $prep_op( + prep_nostrict = $prep_op( f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts...; - strict=Val(true), + strict=Val(false), ) prep_same = $prep_op_same(f, y, ba, x, map(zero, t), contexts...) if reprepare && @@ -607,9 +603,9 @@ for op in ALL_OPS has_size(y) && (size(x) != size(prep_args.x) || size(y) != prep_args.y) prep = $prep_op!(f, y, prep, ba, x, t, contexts...) - prepstrict = $prep_op!(f, y, prepstrict, ba, x, t, contexts...) + prep_nostrict = $prep_op!(f, y, prep_nostrict, ba, x, t, contexts...) end - [(), (prep,), (prepstrict,), (prep_same,)] + [(), (prep,), (prep_nostrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val = mysimilar(y) @@ -642,10 +638,8 @@ for op in ALL_OPS end end end - @test_throws PME $val_and_op( - nothing, mysimilar(y), prepstrict, ba, x, t, contexts... - ) - @test_throws PME $op(nothing, mysimilar(y), prepstrict, ba, x, t, contexts...) + @test_throws PME $val_and_op(nothing, mysimilar(y), prep, ba, x, t, contexts...) + @test_throws PME $op(nothing, mysimilar(y), prep, ba, x, t, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -661,19 +655,19 @@ for op in ALL_OPS reprepare::Bool, ) (; f, x, y, t, res1, contexts, prep_args) = new_scen = deepcopy(scen) - local prepstrict + local prep preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep = $prep_op( f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts... ) - prepstrict = $prep_op( + prep_nostrict = $prep_op( f, prep_args.y, ba, prep_args.x, prep_args.t, prep_args.contexts...; - strict=Val(true), + strict=Val(false), ) prep_same = $prep_op_same(f, y, ba, x, map(zero, t), contexts...) if reprepare && @@ -681,9 +675,9 @@ for op in ALL_OPS has_size(y) && (size(x) != size(prep_args.x) || size(y) != prep_args.y) prep = $prep_op!(f, y, prep, ba, x, t, contexts...) - prepstrict = $prep_op!(f, y, prepstrict, ba, x, t, contexts...) + prep_nostrict = $prep_op!(f, y, prep_nostrict, ba, x, t, contexts...) end - [(), (prep,), (prepstrict,), (prep_same,)] + [(), (prep,), (prep_nostrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val, res1_in1_val = mysimilar(y), mysimilar(res1) @@ -721,10 +715,10 @@ for op in ALL_OPS end end @test_throws PME $val_and_op!( - nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, t, contexts... + nothing, mysimilar(y), mysimilar(res1), prep, ba, x, t, contexts... ) @test_throws PME $op!( - nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, t, contexts... + nothing, mysimilar(y), mysimilar(res1), prep, ba, x, t, contexts... ) scenario_intact && @test new_scen == scen return nothing @@ -742,23 +736,23 @@ for op in ALL_OPS reprepare::Bool, ) (; f, x, y, t, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) - local prepstrict + local prep preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) - prepstrict = $prep_op( + prep_nostrict = $prep_op( f, ba, prep_args.x, prep_args.t, prep_args.contexts...; - strict=Val(true), + strict=Val(false), ) prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) prep = $prep_op!(f, prep, ba, x, t, contexts...) - prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...) + prep_nostrict = $prep_op!(f, prep_nostrict, ba, x, t, contexts...) end - [(), (prep,), (prepstrict,), (prep_same,)] + [(), (prep,), (prep_nostrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res2_out1_noval = $op(f, preptup_noval..., ba, x, t, contexts...) @@ -781,8 +775,8 @@ for op in ALL_OPS end end end - @test_throws PME $val_and_op(nothing, prepstrict, ba, x, t, contexts...) - @test_throws PME $op(nothing, prepstrict, ba, x, t, contexts...) + @test_throws PME $val_and_op(nothing, prep, ba, x, t, contexts...) + @test_throws PME $op(nothing, prep, ba, x, t, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -798,23 +792,23 @@ for op in ALL_OPS reprepare::Bool, ) (; f, x, y, t, res1, res2, contexts, prep_args) = new_scen = deepcopy(scen) - local prepstrict + local prep preptup_cands_val, preptup_cands_noval = map(1:2) do _ prep = $prep_op(f, ba, prep_args.x, prep_args.t, prep_args.contexts...) - prepstrict = $prep_op( + prep_nostrict = $prep_op( f, ba, prep_args.x, prep_args.t, prep_args.contexts...; - strict=Val(true), + strict=Val(false), ) prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...) if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x)) prep = $prep_op!(f, prep, ba, x, t, contexts...) - prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...) + prep_nostrict = $prep_op!(f, prep_nostrict, ba, x, t, contexts...) end - [(), (prep,), (prepstrict,), (prep_same,)] + [(), (prep,), (prep_nostrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res2_in1_noval = mysimilar(res2) @@ -851,11 +845,9 @@ for op in ALL_OPS end end end - @test_throws PME $op!( - nothing, mysimilar(res2), prepstrict, ba, x, t, contexts... - ) + @test_throws PME $op!(nothing, mysimilar(res2), prep, ba, x, t, contexts...) @test_throws PME $val_and_op!( - nothing, mysimilar(res1), mysimilar(res2), prepstrict, ba, x, t, contexts... + nothing, mysimilar(res1), mysimilar(res2), prep, ba, x, t, contexts... ) scenario_intact && @test new_scen == scen return nothing From ac0fb129c1378897905f4396bfb24280a8256d8f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 12 May 2025 11:34:23 +0200 Subject: [PATCH 2/5] Fix compat --- DifferentiationInterface/Project.toml | 2 +- DifferentiationInterfaceTest/Project.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 8d39ac708..486b91a5b 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.54" +version = "0.7.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 0ff69f682..666ec721b 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterfaceTest" uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.10.0" +version = "0.10.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -44,7 +44,7 @@ AllocCheck = "0.2" Chairmarks = "1.2.1" ComponentArrays = "0.15" DataFrames = "1.6.1" -DifferentiationInterface = "0.6.53" +DifferentiationInterface = "0.7.0" DocStringExtensions = "0.8,0.9" ExplicitImports = "1.10.1" FiniteDiff = "2.27.0" From 36f5a4a9567738a39c305b3674d8d3f733a48f5e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 12 May 2025 12:32:46 +0200 Subject: [PATCH 3/5] Fix --- .../test/Core/Internals/signature.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/test/Core/Internals/signature.jl b/DifferentiationInterface/test/Core/Internals/signature.jl index da59872dc..3c326c4c7 100644 --- a/DifferentiationInterface/test/Core/Internals/signature.jl +++ b/DifferentiationInterface/test/Core/Internals/signature.jl @@ -12,7 +12,7 @@ c = 2.0 @testset "Out of place, no tangents" begin prep = prepare_derivative(f, backend, x, Constant(c)) - prep_chill = prepare_derivative(f, backend, x, Constant(c)) + prep_chill = prepare_derivative(f, backend, x, Constant(c); strict=Val(false)) @test_throws MethodError derivative(nothing, prep_chill, backend, x, Constant(c)) @@ -69,7 +69,7 @@ end @testset "In place, no tangents" begin prep = prepare_derivative(f!, y, backend, x) - prep_chill = prepare_derivative(f!, y, backend, x) + prep_chill = prepare_derivative(f!, y, backend, x; strict=Val(false)) @test_throws MethodError derivative(nothing, y, prep_chill, backend, x, Constant(c)) @@ -87,7 +87,7 @@ end @testset "Out of place, with tangents" begin prep = prepare_pushforward(f, backend, x, (x,), Constant(c)) - prep_chill = prepare_pushforward(f, backend, x, (x,), Constant(c)) + prep_chill = prepare_pushforward(f, backend, x, (x,), Constant(c); strict=Val(false)) @test_throws MethodError pushforward(nothing, prep_chill, backend, x, (x,)) @@ -105,7 +105,9 @@ end @testset "In place, with tangents" begin prep = prepare_pushforward(f!, y, backend, x, (x,)) - prep_chill = prepare_pushforward(f!, y, backend, x, (x,), Constant(c)) + prep_chill = prepare_pushforward( + f!, y, backend, x, (x,), Constant(c); strict=Val(false) + ) @test_throws MethodError pushforward(nothing, y, prep_chill, backend, x, (x,)) From 85d1b3aba64a28ef9597af27569e709acd0920c3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 12 May 2025 15:48:56 +0200 Subject: [PATCH 4/5] Fix error message --- DifferentiationInterface/src/utils/prep.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/src/utils/prep.jl b/DifferentiationInterface/src/utils/prep.jl index fc8700b39..2d0107d96 100644 --- a/DifferentiationInterface/src/utils/prep.jl +++ b/DifferentiationInterface/src/utils/prep.jl @@ -109,7 +109,7 @@ function Base.showerror( end println( io, - "If you are confident that this check is superfluous, you can disable it by running preparation with the keyword argument `strict=Val(true)` inside DifferentiationInterface.", + "If you are confident that this check is superfluous, you can disable it by running preparation with the keyword argument `strict=Val(false)` inside DifferentiationInterface.", ) return nothing end From 08390756ba13881f9919a55385f9a5a33dbf840e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 12 May 2025 15:52:15 +0200 Subject: [PATCH 5/5] Update changelogs --- DifferentiationInterface/CHANGELOG.md | 5 ++++- DifferentiationInterfaceTest/CHANGELOG.md | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/CHANGELOG.md b/DifferentiationInterface/CHANGELOG.md index 57784b388..a3da3a10c 100644 --- a/DifferentiationInterface/CHANGELOG.md +++ b/DifferentiationInterface/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.7.0] + ### Changed - Preparation is now strict by default ([#799]) @@ -25,7 +27,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Allocate Enzyme shadow memory during preparation ([#782]) -[unreleased]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.54...main +[unreleased]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.0...main +[0.7.0]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.54...DifferentiationInterface-v0.7.0 [0.6.54]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.53...DifferentiationInterface-v0.6.54 [0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53 diff --git a/DifferentiationInterfaceTest/CHANGELOG.md b/DifferentiationInterfaceTest/CHANGELOG.md index fb687131b..764ffd6ff 100644 --- a/DifferentiationInterfaceTest/CHANGELOG.md +++ b/DifferentiationInterfaceTest/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.10.0] + ### Changed - Specify preparation arguments in DIT Scenario ([#786]) @@ -23,7 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support nested tuples of arrays as Caches ([#748]) - Test type consistency between preparation and execution ([#745]) -[unreleased]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.9.6...main +[unreleased]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.10.0...main +[0.10.0]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.9.6...DifferentiationInterfaceTest-v0.10.0 [0.9.6]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.9.5...DifferentiationInterfaceTest-v0.9.6 [#796]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/796