From 025367c22d9788b7b288ab7c0188b193f138f154 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 18 Mar 2025 09:17:11 +0100 Subject: [PATCH 01/14] fix: separate `prepare` from the hidden `prepare_nokwarg` --- .../docs/src/dev_guide.md | 10 ++-- .../differentiate_with.jl | 2 +- .../reverse_onearg.jl | 2 +- .../DifferentiationInterfaceDiffractorExt.jl | 4 +- .../forward_onearg.jl | 6 +- .../forward_twoarg.jl | 2 +- .../reverse_onearg.jl | 4 +- .../reverse_twoarg.jl | 2 +- .../onearg.jl | 24 ++++---- .../twoarg.jl | 8 +-- .../onearg.jl | 10 ++-- .../twoarg.jl | 14 +++-- ...rentiationInterfaceFiniteDifferencesExt.jl | 8 +-- .../onearg.jl | 40 +++++++------ .../twoarg.jl | 30 +++++----- .../onearg.jl | 14 ++--- .../twoarg.jl | 4 +- .../onearg.jl | 4 +- .../twoarg.jl | 2 +- .../onearg.jl | 24 ++++---- .../twoarg.jl | 12 ++-- .../onearg.jl | 18 +++--- .../twoarg.jl | 6 +- .../hessian.jl | 10 +++- .../jacobian.jl | 8 +-- .../jacobian_mixed.jl | 8 +-- .../onearg.jl | 22 +++---- .../twoarg.jl | 6 +- .../DifferentiationInterfaceTrackerExt.jl | 4 +- .../DifferentiationInterfaceZygoteExt.jl | 12 ++-- .../src/fallbacks/change_prep.jl | 22 ++++--- .../src/fallbacks/no_prep.jl | 44 +++++++------- .../src/first_order/derivative.jl | 22 +++++-- .../src/first_order/gradient.jl | 12 ++-- .../src/first_order/jacobian.jl | 20 +++++-- .../src/first_order/pullback.jl | 58 ++++++++++++++++--- .../src/first_order/pushforward.jl | 56 +++++++++++++++--- .../src/misc/from_primitive.jl | 24 +++++--- .../src/misc/simple_finite_diff.jl | 4 +- .../src/misc/zero_backends.jl | 8 +-- .../src/second_order/hessian.jl | 12 ++-- .../src/second_order/hvp.jl | 56 ++++++++++++------ .../src/second_order/second_derivative.jl | 10 ++-- .../test/Back/ForwardDiff/test.jl | 4 +- .../test/Back/ReverseDiff/test.jl | 2 +- 45 files changed, 418 insertions(+), 256 deletions(-) diff --git a/DifferentiationInterface/docs/src/dev_guide.md 
b/DifferentiationInterface/docs/src/dev_guide.md index 10e74a905..a24bb7eb4 100644 --- a/DifferentiationInterface/docs/src/dev_guide.md +++ b/DifferentiationInterface/docs/src/dev_guide.md @@ -23,10 +23,10 @@ Most operators have 4 variants, which look like this in the first order: `operat To implement a new operator for an existing backend, you need to write 5 methods: 1 for [preparation](@ref Preparation) and 4 corresponding to the variants of the operator (see above). For first-order operators, you may also want to support [in-place functions](@ref "Mutation and signatures"), which requires another 5 methods (defined on `f!` instead of `f`). -The method `prepare_operator` must output a `prep` object of the correct type. -For instance, `prepare_gradient(f, backend, x)` must return a [`DifferentiationInterface.GradientPrep`](@ref). -Assuming you don't need any preparation for said operator, you can use the trivial prep that are already defined, like `DifferentiationInterface.NoGradientPrep`. -Otherwise, define a custom struct like `MyGradientPrep <: DifferentiationInterface.GradientPrep` and put the necessary storage in there. +The method `prepare_operator_nokwarg` must output a `prep` object of the correct type. +For instance, `prepare_gradient_nokwarg(strict, f, backend, x)` must return a [`DifferentiationInterface.GradientPrep`](@ref). +Assuming you don't need any preparation for said operator, you can use the trivial preps that are already defined, like `DifferentiationInterface.NoGradientPrep{SIG}`. +Otherwise, define a custom struct like `MyGradientPrep{SIG} <: DifferentiationInterface.GradientPrep{SIG}` and put the necessary storage in there. ## New backend @@ -75,4 +75,4 @@ GROUP = get(ENV, "JULIA_DI_TEST_GROUP", "Back/SuperDiff") but don't forget to switch it back before pushing. -Finally, you need to add your backend to the documentation, modifying every page that involves a list of backends (including the `README.md`). 
\ No newline at end of file +Finally, you need to add your backend to the documentation, modifying every page that involves a list of backends (including the `README.md`). diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl index be85d1f24..23f9c9b0c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl @@ -1,7 +1,7 @@ function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x) (; f, backend) = dw y = f(x) - prep_same = DI.prepare_pullback_same_point(Val(true), f, backend, x, (y,)) + prep_same = DI.prepare_pullback_same_point_nokwarg(Val(true), f, backend, x, (y,)) function pullbackfunc(dy) tx = DI.pullback(f, prep_same, backend, x, (dy,)) return (NoTangent(), only(tx)) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index 7c68b20e3..07f74edbb 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -6,7 +6,7 @@ struct ChainRulesPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} pb::PB end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f, backend::AutoReverseChainRules, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl index baab8a75f..bd286dc34 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl +++ 
b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl @@ -10,7 +10,9 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow() ## Pushforward -function DI.prepare_pushforward(strict::Val, f, backend::AutoDiffractor, x, tx::NTuple) +function DI.prepare_pushforward_nokwarg( + strict::Val, f, backend::AutoDiffractor, x, tx::NTuple +) _sig = DI.signature(f, backend, x, tx; strict) return DI.NoPushforwardPrep(_sig) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 96b2fff76..9ce785d99 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -1,6 +1,6 @@ ## Pushforward -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, @@ -122,7 +122,7 @@ struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG} shadows::O end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, @@ -203,7 +203,7 @@ struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG} output_length::Int end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index 33e5ce1aa..4d8328c3e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -1,6 +1,6 @@ ## Pushforward -function 
DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f!::F, y, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 205a86837..b0a52fb92 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -52,7 +52,7 @@ struct EnzymeReverseOneArgPullbackPrep{SIG,Y} <: DI.PullbackPrep{SIG} y_example::Y # useful to create return activity end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f::F, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, @@ -191,7 +191,7 @@ end ## Gradient -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f::F, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index 30562fb7f..ae2b33923 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -5,7 +5,7 @@ struct EnzymeReverseTwoArgPullbackPrep{SIG,TY} <: DI.PullbackPrep{SIG} ty_copy::TY end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f!::F, y, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index 0ebfdb7dc..dfbbf3549 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -7,7 +7,7 @@ struct FastDifferentiationOneArgPushforwardPrep{SIG,Y,E1,E1!} <: 
DI.PushforwardP jvp_exe!::E1! end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f, backend::AutoFastDifferentiation, @@ -105,7 +105,7 @@ struct FastDifferentiationOneArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG} vjp_exe!::E1! end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f, backend::AutoFastDifferentiation, @@ -204,7 +204,7 @@ struct FastDifferentiationOneArgDerivativePrep{SIG,Y,E1,E1!} <: DI.DerivativePre der_exe!::E1! end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -284,7 +284,7 @@ struct FastDifferentiationOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG} jac_exe!::E1! end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -360,7 +360,7 @@ struct FastDifferentiationOneArgJacobianPrep{SIG,Y,E1,E1!} <: DI.JacobianPrep{SI jac_exe!::E1! end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, @@ -445,7 +445,7 @@ struct FastDifferentiationAllocatingSecondDerivativePrep{SIG,Y,D,E2,E2!} <: der2_exe!::E2! end -function DI.prepare_second_derivative( +function DI.prepare_second_derivative_nokwarg( strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -462,7 +462,7 @@ function DI.prepare_second_derivative( der2_exe = make_function(der2_vec_var, x_vec_var, context_vec_vars...; in_place=false) der2_exe! 
= make_function(der2_vec_var, x_vec_var, context_vec_vars...; in_place=true) - derivative_prep = DI.prepare_derivative(f, backend, x, contexts...) + derivative_prep = DI.prepare_derivative_nokwarg(strict, f, backend, x, contexts...) return FastDifferentiationAllocatingSecondDerivativePrep( _sig, y_prototype, derivative_prep, der2_exe, der2_exe! ) @@ -534,7 +534,7 @@ struct FastDifferentiationHVPPrep{SIG,E2,E2!,E1} <: DI.HVPPrep{SIG} gradient_prep::E1 end -function DI.prepare_hvp( +function DI.prepare_hvp_nokwarg( strict::Val, f, backend::AutoFastDifferentiation, @@ -557,7 +557,7 @@ function DI.prepare_hvp( hv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true ) - gradient_prep = DI.prepare_gradient(f, backend, x, contexts...) + gradient_prep = DI.prepare_gradient_nokwarg(strict, f, backend, x, contexts...) return FastDifferentiationHVPPrep(_sig, hvp_exe, hvp_exe!, gradient_prep) end @@ -633,7 +633,7 @@ struct FastDifferentiationHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG} hess_exe!::E2! end -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, @@ -656,7 +656,9 @@ function DI.prepare_hessian( hess_exe = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=false) hess_exe! = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=true) - gradient_prep = DI.prepare_gradient(f, dense_ad(backend), x, contexts...) + gradient_prep = DI.prepare_gradient_nokwarg( + strict, f, dense_ad(backend), x, contexts... + ) return FastDifferentiationHessianPrep(_sig, gradient_prep, hess_exe, hess_exe!) 
end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl index f67ed4324..fb4b560f8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl @@ -6,7 +6,7 @@ struct FastDifferentiationTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPre jvp_exe!::E1! end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f!, y, @@ -107,7 +107,7 @@ struct FastDifferentiationTwoArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG} vjp_exe!::E1! end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f!, y, @@ -213,7 +213,7 @@ struct FastDifferentiationTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{ der_exe!::E1! end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f!, y, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) @@ -295,7 +295,7 @@ struct FastDifferentiationTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} jac_exe!::E1! 
end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!, y, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl index 7ebdb6e49..fa3d73aa2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl @@ -8,7 +8,7 @@ struct FiniteDiffOneArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} dir::D end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) @@ -124,7 +124,7 @@ struct FiniteDiffOneArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} dir::D end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -253,7 +253,7 @@ struct FiniteDiffGradientPrep{SIG,C,R,A,D} <: DI.GradientPrep{SIG} dir::D end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -341,7 +341,7 @@ struct FiniteDiffOneArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} dir::D end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -446,7 +446,7 @@ struct FiniteDiffHessianPrep{SIG,C1,C2,RG,AG,RH,AH} <: DI.HessianPrep{SIG} absstep_h::AH end -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::AutoFiniteDiff, x, 
contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl index 259ebbdd3..11bbcfbb9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl @@ -8,7 +8,7 @@ struct FiniteDiffTwoArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} dir::D end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f!, y, @@ -161,7 +161,7 @@ struct FiniteDiffTwoArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} dir::D end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) @@ -198,7 +198,9 @@ function DI.prepare!_derivative( cache.c3 isa Union{Number,Nothing} || resize!(cache.c3, length(y)) return old_prep else - return DI.prepare_derivative(DI.is_strict(old_prep), f!, y, backend, x, contexts...) + return DI.prepare_derivative_nokwarg( + DI.is_strict(old_prep), f!, y, backend, x, contexts... + ) end end @@ -277,7 +279,7 @@ struct FiniteDiffTwoArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} dir::D end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) @@ -318,7 +320,9 @@ function DI.prepare!_jacobian( cache.sparsity = nothing return old_prep else - return DI.prepare_jacobian(DI.is_strict(old_prep), f!, y, backend, x, contexts...) + return DI.prepare_jacobian_nokwarg( + DI.is_strict(old_prep), f!, y, backend, x, contexts... 
+ ) end end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl index 31bcd5961..68bd2918f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl @@ -11,7 +11,7 @@ DI.inner_preparation_behavior(::AutoFiniteDifferences) = DI.PrepareInnerSimple() ## Pushforward -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f, backend::AutoFiniteDifferences, @@ -54,7 +54,7 @@ end ## Pullback -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f, backend::AutoFiniteDifferences, @@ -97,7 +97,7 @@ end ## Gradient -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -154,7 +154,7 @@ end ## Jacobian -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 703fe6be6..30f769417 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -67,7 +67,7 @@ struct ForwardDiffOneArgPushforwardPrep{SIG,T,X,CD} <: DI.PushforwardPrep{SIG} contexts_dual::CD end -function DI.prepare_pushforward( +function 
DI.prepare_pushforward_nokwarg( strict::Val, f::F, backend::AutoForwardDiff, @@ -215,11 +215,13 @@ end ### Prepared -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) - pushforward_prep = DI.prepare_pushforward(strict, f, backend, x, (one(x),), contexts...) + pushforward_prep = DI.prepare_pushforward_nokwarg( + strict, f, backend, x, (one(x),), contexts... + ) return ForwardDiffOneArgDerivativePrep(_sig, pushforward_prep) end @@ -297,7 +299,7 @@ function DI.value_and_gradient!( grad === DR.gradient(result) || copyto!(grad, DR.gradient(result)) return y, grad else - prep = DI.prepare_gradient(f, backend, x, contexts...) + prep = DI.prepare_gradient_nokwarg(Val(true), f, backend, x, contexts...) return DI.value_and_gradient!(f, grad, prep, backend, x, contexts...) end end @@ -315,7 +317,7 @@ function DI.value_and_gradient( result = gradient!(result, fc, x) return DR.value(result), DR.gradient(result) else - prep = DI.prepare_gradient(f, backend, x, contexts...) + prep = DI.prepare_gradient_nokwarg(Val(true), f, backend, x, contexts...) return DI.value_and_gradient(f, prep, backend, x, contexts...) end end @@ -331,7 +333,7 @@ function DI.gradient!( fc = DI.with_contexts(f, contexts...) return gradient!(grad, fc, x) else - prep = DI.prepare_gradient(f, backend, x, contexts...) + prep = DI.prepare_gradient_nokwarg(Val(true), f, backend, x, contexts...) return DI.gradient!(f, grad, prep, backend, x, contexts...) end end @@ -347,7 +349,7 @@ function DI.gradient( fc = DI.with_contexts(f, contexts...) return gradient(fc, x) else - prep = DI.prepare_gradient(f, backend, x, contexts...) + prep = DI.prepare_gradient_nokwarg(Val(true), f, backend, x, contexts...) return DI.gradient(f, prep, backend, x, contexts...) 
end end @@ -360,7 +362,7 @@ struct ForwardDiffGradientPrep{SIG,C,CD} <: DI.GradientPrep{SIG} contexts_dual::CD end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f::F, backend::AutoForwardDiff, @@ -471,7 +473,7 @@ function DI.value_and_jacobian!( jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result)) return y, jac else - prep = DI.prepare_jacobian(f, backend, x, contexts...) + prep = DI.prepare_jacobian_nokwarg(Val(true), f, backend, x, contexts...) return DI.value_and_jacobian!(f, jac, prep, backend, x, contexts...) end end @@ -487,7 +489,7 @@ function DI.value_and_jacobian( fc = DI.with_contexts(f, contexts...) return fc(x), jacobian(fc, x) else - prep = DI.prepare_jacobian(f, backend, x, contexts...) + prep = DI.prepare_jacobian_nokwarg(Val(true), f, backend, x, contexts...) return DI.value_and_jacobian(f, prep, backend, x, contexts...) end end @@ -503,7 +505,7 @@ function DI.jacobian!( fc = DI.with_contexts(f, contexts...) return jacobian!(jac, fc, x) else - prep = DI.prepare_jacobian(f, backend, x, contexts...) + prep = DI.prepare_jacobian_nokwarg(Val(true), f, backend, x, contexts...) return DI.jacobian!(f, jac, prep, backend, x, contexts...) end end @@ -519,7 +521,7 @@ function DI.jacobian( fc = DI.with_contexts(f, contexts...) return jacobian(fc, x) else - prep = DI.prepare_jacobian(f, backend, x, contexts...) + prep = DI.prepare_jacobian_nokwarg(Val(true), f, backend, x, contexts...) return DI.jacobian(f, prep, backend, x, contexts...) 
end end @@ -532,7 +534,7 @@ struct ForwardDiffOneArgJacobianPrep{SIG,C,CD} <: DI.JacobianPrep{SIG} contexts_dual::CD end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -620,7 +622,7 @@ end ## Second derivative -function DI.prepare_second_derivative( +function DI.prepare_second_derivative_nokwarg( strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -719,7 +721,7 @@ function DI.hessian!( fc = DI.with_contexts(f, contexts...) return hessian!(hess, fc, x) else - prep = DI.prepare_hessian(f, backend, x, contexts...) + prep = DI.prepare_hessian_nokwarg(Val(true), f, backend, x, contexts...) return DI.hessian!(f, hess, prep, backend, x, contexts...) end end @@ -735,7 +737,7 @@ function DI.hessian( fc = DI.with_contexts(f, contexts...) return hessian(fc, x) else - prep = DI.prepare_hessian(f, backend, x, contexts...) + prep = DI.prepare_hessian_nokwarg(Val(true), f, backend, x, contexts...) return DI.hessian(f, prep, backend, x, contexts...) end end @@ -761,7 +763,7 @@ function DI.value_gradient_and_hessian!( hess === DR.hessian(result) || copyto!(hess, DR.hessian(result)) return (y, grad, hess) else - prep = DI.prepare_hessian(f, backend, x, contexts...) + prep = DI.prepare_hessian_nokwarg(Val(true), f, backend, x, contexts...) return DI.value_gradient_and_hessian!(f, grad, hess, prep, backend, x, contexts...) end end @@ -779,7 +781,7 @@ function DI.value_gradient_and_hessian( result = hessian!(result, fc, x) return (DR.value(result), DR.gradient(result), DR.hessian(result)) else - prep = DI.prepare_hessian(f, backend, x, contexts...) + prep = DI.prepare_hessian_nokwarg(Val(true), f, backend, x, contexts...) return DI.value_gradient_and_hessian(f, prep, backend, x, contexts...) 
end end @@ -793,7 +795,7 @@ struct ForwardDiffHessianPrep{SIG,C1,C2,CD} <: DI.HessianPrep{SIG} contexts_dual::CD end -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index a919d8965..5073838b4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -8,7 +8,7 @@ struct ForwardDiffTwoArgPushforwardPrep{SIG,T,X,Y,CD} <: DI.PushforwardPrep{SIG} contexts_dual::CD end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f!::F, y, @@ -137,7 +137,7 @@ function DI.value_and_derivative( result = derivative!(result, fc!, y, x) return DiffResults.value(result), DiffResults.derivative(result) else - prep = DI.prepare_derivative(f!, y, backend, x, contexts...) + prep = DI.prepare_derivative_nokwarg(Val(true), f!, y, backend, x, contexts...) return DI.value_and_derivative(f!, y, prep, backend, x, contexts...) end end @@ -151,7 +151,7 @@ function DI.value_and_derivative!( result = derivative!(result, fc!, y, x) return DiffResults.value(result), DiffResults.derivative(result) else - prep = DI.prepare_derivative(f!, y, backend, x, contexts...) + prep = DI.prepare_derivative_nokwarg(Val(true), f!, y, backend, x, contexts...) return DI.value_and_derivative!(f!, y, der, prep, backend, x, contexts...) end end @@ -163,7 +163,7 @@ function DI.derivative( fc! = DI.with_contexts(f!, contexts...) return derivative(fc!, y, x) else - prep = DI.prepare_derivative(f!, y, backend, x, contexts...) + prep = DI.prepare_derivative_nokwarg(Val(true), f!, y, backend, x, contexts...) 
return DI.derivative(f!, y, prep, backend, x, contexts...) end end @@ -175,7 +175,7 @@ function DI.derivative!( fc! = DI.with_contexts(f!, contexts...) return derivative!(der, fc!, y, x) else - prep = DI.prepare_derivative(f!, y, backend, x, contexts...) + prep = DI.prepare_derivative_nokwarg(Val(true), f!, y, backend, x, contexts...) return DI.derivative!(f!, y, der, prep, backend, x, contexts...) end end @@ -188,7 +188,7 @@ struct ForwardDiffTwoArgDerivativePrep{SIG,C,CD} <: DI.DerivativePrep{SIG} contexts_dual::CD end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) @@ -212,7 +212,9 @@ function DI.prepare!_derivative( resize!(config.duals, length(y)) return old_prep else - return DI.prepare_derivative(DI.is_strict(old_prep), f!, y, backend, x, contexts...) + return DI.prepare_derivative_nokwarg( + DI.is_strict(old_prep), f!, y, backend, x, contexts... + ) end end @@ -312,7 +314,7 @@ function DI.value_and_jacobian( result = jacobian!(result, fc!, y, x) return DiffResults.value(result), DiffResults.jacobian(result) else - prep = DI.prepare_jacobian(f!, y, backend, x, contexts...) + prep = DI.prepare_jacobian_nokwarg(Val(true), f!, y, backend, x, contexts...) return DI.value_and_jacobian(f!, y, prep, backend, x, contexts...) end end @@ -330,7 +332,7 @@ function DI.value_and_jacobian!( result = jacobian!(result, fc!, y, x) return DiffResults.value(result), DiffResults.jacobian(result) else - prep = DI.prepare_jacobian(f!, y, backend, x, contexts...) + prep = DI.prepare_jacobian_nokwarg(Val(true), f!, y, backend, x, contexts...) return DI.value_and_jacobian!(f!, y, jac, prep, backend, x, contexts...) end end @@ -346,7 +348,7 @@ function DI.jacobian( fc! = DI.with_contexts(f!, contexts...) return jacobian(fc!, y, x) else - prep = DI.prepare_jacobian(f!, y, backend, x, contexts...) 
+ prep = DI.prepare_jacobian_nokwarg(Val(true), f!, y, backend, x, contexts...) return DI.jacobian(f!, y, prep, backend, x, contexts...) end end @@ -362,7 +364,7 @@ function DI.jacobian!( fc! = DI.with_contexts(f!, contexts...) return jacobian!(jac, fc!, y, x) else - prep = DI.prepare_jacobian(f!, y, backend, x, contexts...) + prep = DI.prepare_jacobian_nokwarg(Val(true), f!, y, backend, x, contexts...) return DI.jacobian!(f!, y, jac, prep, backend, x, contexts...) end end @@ -375,7 +377,7 @@ struct ForwardDiffTwoArgJacobianPrep{SIG,C,CD} <: DI.JacobianPrep{SIG} contexts_dual::CD end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {F,C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) @@ -402,7 +404,9 @@ function DI.prepare!_jacobian( resize!(xduals, length(x)) return old_prep else - return DI.prepare_jacobian(DI.is_strict(old_prep), f!, y, backend, x, contexts...) + return DI.prepare_jacobian_nokwarg( + DI.is_strict(old_prep), f!, y, backend, x, contexts... 
+ ) end end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl index 83be97ed1..514550d26 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl @@ -7,7 +7,7 @@ struct GTPSAOneArgPushforwardPrep{SIG,X} <: DI.PushforwardPrep{SIG} xt::X end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f::F, backend::AutoGTPSA{D}, @@ -115,7 +115,7 @@ struct GTPSAOneArgGradientPrep{SIG,X} <: DI.GradientPrep{SIG} end # Unlike JVP, this requires us to use all variables -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -198,7 +198,7 @@ struct GTPSAOneArgJacobianPrep{SIG,X} <: DI.JacobianPrep{SIG} end # To materialize the entire Jacobian we use all variables -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -284,7 +284,7 @@ struct GTPSAOneArgSecondDerivativePrep{SIG,X} <: DI.SecondDerivativePrep{SIG} xt::X end -function DI.prepare_second_derivative( +function DI.prepare_second_derivative_nokwarg( strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -414,7 +414,7 @@ struct GTPSAOneArgHessianPrep{SIG,X,M} <: DI.HessianPrep{SIG} m::M end -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -549,11 +549,11 @@ struct GTPSAOneArgHVPPrep{SIG,E,H} <: 
DI.HVPPrep{SIG} hess::H end -function DI.prepare_hvp( +function DI.prepare_hvp_nokwarg( strict::Val, f, backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant,C} ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) - hessprep = DI.prepare_hessian(strict, f, backend, x, contexts...) + hessprep = DI.prepare_hessian_nokwarg(strict, f, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) hess = similar(x, typeof(fc(x)), (length(x), length(x))) return GTPSAOneArgHVPPrep(_sig, hessprep, hess) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl index e77c4900d..b106923b5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl @@ -10,7 +10,7 @@ struct GTPSATwoArgPushforwardPrep{SIG,X,Y} <: DI.PushforwardPrep{SIG} yt::Y end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f!::F, y, @@ -125,7 +125,7 @@ struct GTPSATwoArgJacobianPrep{SIG,X,Y} <: DI.JacobianPrep{SIG} yt::Y end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!, y, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index e1e0c580b..2e46bc7c4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -6,7 +6,7 @@ struct MooncakeOneArgPullbackPrep{SIG,Tcache,DY} <: DI.PullbackPrep{SIG} dy_righttype::DY end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f::F, backend::AutoMooncake, x, ty::NTuple, 
contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) @@ -106,7 +106,7 @@ struct MooncakeGradientPrep{SIG,Tcache} <: DI.GradientPrep{SIG} cache::Tcache end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f::F, backend::AutoMooncake, x, contexts::Vararg{DI.Context,C} ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 43bed9857..e89fbc37e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -5,7 +5,7 @@ struct MooncakeTwoArgPullbackPrep{SIG,Tcache,DY,F} <: DI.PullbackPrep{SIG} target_function::F end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f!::F, y, diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index 2e9380fd4..a646488c5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -6,7 +6,7 @@ struct PolyesterForwardDiffOneArgPushforwardPrep{SIG,P} <: DI.PushforwardPrep{SI single_threaded_prep::P end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f, backend::AutoPolyesterForwardDiff, @@ -15,7 +15,7 @@ function DI.prepare_pushforward( contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) - single_threaded_prep = DI.prepare_pushforward( + single_threaded_prep = DI.prepare_pushforward_nokwarg( strict, f, single_threaded(backend), x, tx, contexts... 
) return PolyesterForwardDiffOneArgPushforwardPrep(_sig, single_threaded_prep) @@ -86,11 +86,11 @@ struct PolyesterForwardDiffOneArgDerivativePrep{SIG,P} <: DI.DerivativePrep{SIG} single_threaded_prep::P end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) - single_threaded_prep = DI.prepare_derivative( + single_threaded_prep = DI.prepare_derivative_nokwarg( strict, f, single_threaded(backend), x, contexts... ) return PolyesterForwardDiffOneArgDerivativePrep(_sig, single_threaded_prep) @@ -158,7 +158,7 @@ struct PolyesterForwardDiffGradientPrep{SIG,chunksize,P} <: DI.GradientPrep{SIG} single_threaded_prep::P end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoPolyesterForwardDiff{chunksize}, @@ -171,7 +171,7 @@ function DI.prepare_gradient( else chunk = Chunk{chunksize}() end - single_threaded_prep = DI.prepare_gradient( + single_threaded_prep = DI.prepare_gradient_nokwarg( strict, f, single_threaded(backend), x, contexts... ) return PolyesterForwardDiffGradientPrep(_sig, chunk, single_threaded_prep) @@ -249,7 +249,7 @@ struct PolyesterForwardDiffOneArgJacobianPrep{SIG,chunksize,P} <: DI.JacobianPre single_threaded_prep::P end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::AutoPolyesterForwardDiff{chunksize}, @@ -262,7 +262,7 @@ function DI.prepare_jacobian( else chunk = Chunk{chunksize}() end - single_threaded_prep = DI.prepare_jacobian( + single_threaded_prep = DI.prepare_jacobian_nokwarg( strict, f, single_threaded(backend), x, contexts... 
) return PolyesterForwardDiffOneArgJacobianPrep(_sig, chunk, single_threaded_prep) @@ -339,11 +339,11 @@ struct PolyesterForwardDiffHessianPrep{SIG,P} <: DI.HessianPrep{SIG} single_threaded_prep::P end -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) - single_threaded_prep = DI.prepare_hessian( + single_threaded_prep = DI.prepare_hessian_nokwarg( strict, f, single_threaded(backend), x, contexts... ) return PolyesterForwardDiffHessianPrep(_sig, single_threaded_prep) @@ -411,11 +411,11 @@ struct PolyesterForwardDiffOneArgSecondDerivativePrep{SIG,P} <: DI.SecondDerivat single_threaded_prep::P end -function DI.prepare_second_derivative( +function DI.prepare_second_derivative_nokwarg( strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) - single_threaded_prep = DI.prepare_second_derivative( + single_threaded_prep = DI.prepare_second_derivative_nokwarg( strict, f, single_threaded(backend), x, contexts... 
) return PolyesterForwardDiffOneArgSecondDerivativePrep(_sig, single_threaded_prep) diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl index 74d460cc3..65baba690 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl @@ -5,7 +5,7 @@ struct PolyesterForwardDiffTwoArgPushforwardPrep{SIG,P} <: DI.PushforwardPrep{SI single_threaded_prep::P end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f!, y, @@ -15,7 +15,7 @@ function DI.prepare_pushforward( contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) - single_threaded_prep = DI.prepare_pushforward( - f!, y, single_threaded(backend), x, tx, contexts... + single_threaded_prep = DI.prepare_pushforward_nokwarg( + strict, f!, y, single_threaded(backend), x, tx, contexts... ) return PolyesterForwardDiffTwoArgPushforwardPrep(_sig, single_threaded_prep) @@ -90,11 +90,11 @@ struct PolyesterForwardDiffTwoArgDerivativePrep{SIG,P} <: DI.DerivativePrep{SIG} single_threaded_prep::P end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f!, y, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) - single_threaded_prep = DI.prepare_derivative( + single_threaded_prep = DI.prepare_derivative_nokwarg( strict, f!, y, single_threaded(backend), x, contexts... 
) return PolyesterForwardDiffTwoArgDerivativePrep(_sig, single_threaded_prep) @@ -166,7 +166,7 @@ struct PolyesterForwardDiffTwoArgJacobianPrep{SIG,chunksize,P} <: DI.JacobianPre single_threaded_prep::P end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!, y, @@ -180,7 +180,7 @@ function DI.prepare_jacobian( else chunk = Chunk{chunksize}() end - single_threaded_prep = DI.prepare_jacobian( + single_threaded_prep = DI.prepare_jacobian_nokwarg( strict, f!, y, single_threaded(backend), x, contexts... ) return PolyesterForwardDiffTwoArgJacobianPrep(_sig, chunk, single_threaded_prep) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index 13b673ac5..c9d10490a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -1,6 +1,6 @@ ## Pullback -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f, backend::AutoReverseDiff, x, ty::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) @@ -79,7 +79,7 @@ struct ReverseDiffGradientPrep{SIG,C,T} <: DI.GradientPrep{SIG} tape::T end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoReverseDiff{compile}, x ) where {compile} _sig = DI.signature(f, backend, x; strict) @@ -143,7 +143,7 @@ end ### With contexts -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -216,7 +216,7 @@ struct ReverseDiffOneArgJacobianPrep{SIG,C,T} <: DI.JacobianPrep{SIG} tape::T end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, 
backend::AutoReverseDiff{compile}, x ) where {compile} _sig = DI.signature(f, backend, x; strict) @@ -280,7 +280,7 @@ end ### With contexts -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -354,11 +354,11 @@ struct ReverseDiffHessianPrep{SIG,G<:ReverseDiffGradientPrep,HC,HT} <: DI.Hessia hessian_tape::HT end -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::AutoReverseDiff{compile}, x ) where {compile} _sig = DI.signature(f, backend, x; strict) - gradient_prep = DI.prepare_gradient(strict, f, backend, x) + gradient_prep = DI.prepare_gradient_nokwarg(strict, f, backend, x) if compile hessian_tape = ReverseDiff.compile(HessianTape(f, x)) return ReverseDiffHessianPrep(_sig, gradient_prep, nothing, hessian_tape) @@ -412,11 +412,11 @@ end ### With contexts -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) - gradient_prep = DI.prepare_gradient(strict, f, backend, x, contexts...) + gradient_prep = DI.prepare_gradient_nokwarg(strict, f, backend, x, contexts...) 
hessian_config = HessianConfig(x) return ReverseDiffHessianPrep(_sig, gradient_prep, hessian_config, nothing) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl index 531212fe5..40553d456 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl @@ -1,6 +1,6 @@ ## Pullback -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f!, y, @@ -139,7 +139,7 @@ struct ReverseDiffTwoArgJacobianPrep{SIG,C,T} <: DI.JacobianPrep{SIG} tape::T end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!, y, backend::AutoReverseDiff{compile}, x ) where {compile} _sig = DI.signature(f!, y, backend, x; strict) @@ -205,7 +205,7 @@ end ### With contexts -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!, y, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index 1084583f2..8efeb3659 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -25,7 +25,7 @@ SMC.ncolors(prep::SparseHessianPrep) = ncolors(prep.coloring_result) ## Hessian, one argument -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} dense_backend = dense_ad(backend) @@ -62,8 +62,12 @@ function _prepare_sparse_hessian_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) 
for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - hvp_prep = DI.prepare_hvp(strict, f, dense_backend, x, batched_seeds[1], contexts...) - gradient_prep = DI.prepare_gradient(strict, f, DI.inner(dense_backend), x, contexts...) + hvp_prep = DI.prepare_hvp_nokwarg( + strict, f, dense_backend, x, batched_seeds[1], contexts... + ) + gradient_prep = DI.prepare_gradient_nokwarg( + strict, f, DI.inner(dense_backend), x, contexts... + ) return SparseHessianPrep( _sig, batch_size_settings, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 85eac3b0f..c5edc8cb2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -36,7 +36,7 @@ struct PullbackSparseJacobianPrep{ pullback_prep::E end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} dense_backend = dense_ad(backend) @@ -45,7 +45,7 @@ function DI.prepare_jacobian( return _prepare_sparse_jacobian_aux(strict, perf, y, (f,), backend, x, contexts...) end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!::F, y, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} dense_backend = dense_ad(backend) @@ -108,7 +108,7 @@ function _prepare_sparse_jacobian_aux_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds] - pushforward_prep = DI.prepare_pushforward( + pushforward_prep = DI.prepare_pushforward_nokwarg( strict, f_or_f!y..., dense_backend, x, batched_seeds[1], contexts... 
) return PushforwardSparseJacobianPrep( @@ -142,7 +142,7 @@ function _prepare_sparse_jacobian_aux_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - pullback_prep = DI.prepare_pullback( + pullback_prep = DI.prepare_pullback_nokwarg( strict, f_or_f!y..., dense_backend, x, batched_seeds[1], contexts... ) return PullbackSparseJacobianPrep( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl index c5d0d2847..42b4970db 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl @@ -27,7 +27,7 @@ struct MixedModeSparseJacobianPrep{ pullback_prep::Er end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f::F, backend::AutoSparse{<:DI.MixedMode}, @@ -38,7 +38,7 @@ function DI.prepare_jacobian( return _prepare_mixed_sparse_jacobian_aux(strict, y, (f,), backend, x, contexts...) 
end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!::F, y, @@ -127,7 +127,7 @@ function _prepare_mixed_sparse_jacobian_aux_aux( ntuple(b -> similar(x), Val(Br)) for _ in batched_seeds_reverse ] - pushforward_prep = DI.prepare_pushforward( + pushforward_prep = DI.prepare_pushforward_nokwarg( strict, f_or_f!y..., DI.forward_backend(dense_backend), @@ -135,7 +135,7 @@ function _prepare_mixed_sparse_jacobian_aux_aux( batched_seeds_forward[1], contexts...; ) - pullback_prep = DI.prepare_pullback( + pullback_prep = DI.prepare_pullback_nokwarg( strict, f_or_f!y..., DI.reverse_backend(dense_backend), diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 105d8a6a1..48314542d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -6,7 +6,7 @@ struct SymbolicsOneArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} pf_exe!::E1! end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) @@ -94,7 +94,7 @@ struct SymbolicsOneArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} der_exe!::E1! end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -168,7 +168,7 @@ struct SymbolicsOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG} grad_exe!::E1! 
end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -238,7 +238,7 @@ struct SymbolicsOneArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} jac_exe!::E1! end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, @@ -316,7 +316,7 @@ struct SymbolicsOneArgHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG} hess_exe!::E2! end -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, @@ -336,7 +336,9 @@ function DI.prepare_hessian( res = build_function(hess_var, vec(x_var), context_vars...; expression=Val(false)) (hess_exe, hess_exe!) = res - gradient_prep = DI.prepare_gradient(strict, f, dense_ad(backend), x, contexts...) + gradient_prep = DI.prepare_gradient_nokwarg( + strict, f, dense_ad(backend), x, contexts... + ) return SymbolicsOneArgHessianPrep(_sig, gradient_prep, hess_exe, hess_exe!) end @@ -405,7 +407,7 @@ struct SymbolicsOneArgHVPPrep{SIG,G,E2,E2!} <: DI.HVPPrep{SIG} hvp_exe!::E2! end -function DI.prepare_hvp( +function DI.prepare_hvp_nokwarg( strict::Val, f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) @@ -422,7 +424,7 @@ function DI.prepare_hvp( ) (hvp_exe, hvp_exe!) = res - gradient_prep = DI.prepare_gradient(strict, f, backend, x, contexts...) + gradient_prep = DI.prepare_gradient_nokwarg(strict, f, backend, x, contexts...) return SymbolicsOneArgHVPPrep(_sig, gradient_prep, hvp_exe, hvp_exe!) end @@ -497,7 +499,7 @@ struct SymbolicsOneArgSecondDerivativePrep{SIG,D,E1,E1!} <: DI.SecondDerivativeP der2_exe!::E1! 
end -function DI.prepare_second_derivative( +function DI.prepare_second_derivative_nokwarg( strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -512,7 +514,7 @@ function DI.prepare_second_derivative( elseif res isa RuntimeGeneratedFunction res, nothing end - derivative_prep = DI.prepare_derivative(strict, f, backend, x, contexts...) + derivative_prep = DI.prepare_derivative_nokwarg(strict, f, backend, x, contexts...) return SymbolicsOneArgSecondDerivativePrep(_sig, derivative_prep, der2_exe, der2_exe!) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl index 58623720e..5237a64d4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl @@ -6,7 +6,7 @@ struct SymbolicsTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} pushforward_exe!::E1! end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f!, y, @@ -104,7 +104,7 @@ struct SymbolicsTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} der_exe!::E1! end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f!, y, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) @@ -182,7 +182,7 @@ struct SymbolicsTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} jac_exe!::E1! 
end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!, y, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index aaeae6cc7..d3e05294d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -15,7 +15,7 @@ struct TrackerPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} pb::PB end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f, backend::AutoTracker, @@ -91,7 +91,7 @@ end ## Gradient -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index c7a246211..dc400b03e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -27,7 +27,7 @@ struct ZygotePullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} pb::PB end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f, backend::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) @@ -98,7 +98,7 @@ end ## Gradient -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} 
_sig = DI.signature(f, backend, x, contexts...; strict) @@ -138,7 +138,7 @@ end ## Jacobian -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -187,11 +187,11 @@ struct ZygoteHVPPrep{SIG,P} <: DI.HVPPrep{SIG} fd_prep::P end -function DI.prepare_hvp( +function DI.prepare_hvp_nokwarg( strict::Val, f, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) - fd_prep = DI.prepare_hvp( + fd_prep = DI.prepare_hvp_nokwarg( strict, f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... ) return ZygoteHVPPrep(_sig, fd_prep) @@ -265,7 +265,7 @@ end ## Hessian -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/src/fallbacks/change_prep.jl b/DifferentiationInterface/src/fallbacks/change_prep.jl index 492813391..504d156c5 100644 --- a/DifferentiationInterface/src/fallbacks/change_prep.jl +++ b/DifferentiationInterface/src/fallbacks/change_prep.jl @@ -20,8 +20,10 @@ for op in [ end val_and_op! = Symbol(val_and_op, "!") prep_op = Symbol("prepare_", op) + prep_op_nokwarg = Symbol("prepare_", op, "_nokwarg") prep_op! = Symbol("prepare!_", op) prep_op_same_point = Symbol("prepare_", op, "_same_point") + prep_op_same_point_nokwarg = Symbol("prepare_", op, "_same_point_nokwarg") P = if op == :derivative DerivativePrep elseif op == :gradient @@ -46,7 +48,7 @@ for op in [ f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} check_prep(f, old_prep, backend, x, contexts...) - return $prep_op(is_strict(old_prep), f, backend, x, contexts...) 
+ return $prep_op_nokwarg(is_strict(old_prep), f, backend, x, contexts...) end op == :gradient && continue # 2-arg @@ -54,7 +56,7 @@ for op in [ f!::F, y, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} check_prep(f!, y, old_prep, backend, x, contexts...) - return $prep_op(is_strict(old_prep), f!, y, backend, x, contexts...) + return $prep_op_nokwarg(is_strict(old_prep), f!, y, backend, x, contexts...) end elseif op in (:second_derivative, :hessian) @@ -63,7 +65,7 @@ for op in [ f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} check_prep(f, old_prep, backend, x, contexts...) - return $prep_op(is_strict(old_prep), f, backend, x, contexts...) + return $prep_op_nokwarg(is_strict(old_prep), f, backend, x, contexts...) end elseif op in (:pushforward, :pullback, :hvp) @@ -77,7 +79,7 @@ for op in [ contexts::Vararg{Context,C}; ) where {F,C} check_prep(f, old_prep, backend, x, seed, contexts...) - return $prep_op(is_strict(old_prep), f, backend, x, seed, contexts...) + return $prep_op_nokwarg(is_strict(old_prep), f, backend, x, seed, contexts...) end @eval function $prep_op_same_point( f::F, @@ -90,7 +92,7 @@ for op in [ check_prep(f, prep, backend, x, seed, contexts...) return prep end - @eval function $prep_op_same_point( + @eval function $prep_op_same_point_nokwarg( strict::Val, f::F, backend::AbstractADType, @@ -98,7 +100,7 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}; ) where {F,C} - prep = $prep_op(strict, f, backend, x, seed, contexts...) + prep = $prep_op_nokwarg(strict, f, backend, x, seed, contexts...) return $prep_op_same_point(f, prep, backend, x, seed, contexts...) end op == :hvp && continue @@ -113,7 +115,9 @@ for op in [ contexts::Vararg{Context,C}, ) where {F,C} check_prep(f!, y, old_prep, backend, x, seed, contexts...) - return $prep_op(is_strict(old_prep), f!, y, backend, x, seed, contexts...) 
+ return $prep_op_nokwarg( + is_strict(old_prep), f!, y, backend, x, seed, contexts... + ) end @eval function $prep_op_same_point( f!::F, @@ -127,7 +131,7 @@ for op in [ check_prep(f!, y, prep, backend, x, seed, contexts...) return prep end - @eval function $prep_op_same_point( + @eval function $prep_op_same_point_nokwarg( strict::Val, f!::F, y, @@ -136,7 +140,7 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}; ) where {F,C} - prep = $prep_op(strict, f!, y, backend, x, seed, contexts...) + prep = $prep_op_nokwarg(strict, f!, y, backend, x, seed, contexts...) return $prep_op_same_point(f!, y, prep, backend, x, seed, contexts...) end end diff --git a/DifferentiationInterface/src/fallbacks/no_prep.jl b/DifferentiationInterface/src/fallbacks/no_prep.jl index 81c7b2bd0..f7dc633d6 100644 --- a/DifferentiationInterface/src/fallbacks/no_prep.jl +++ b/DifferentiationInterface/src/fallbacks/no_prep.jl @@ -20,8 +20,10 @@ for op in [ end val_and_op! = Symbol(val_and_op, "!") prep_op = Symbol("prepare_", op) + prep_op_nokwarg = Symbol("prepare_", op, "_nokwarg") prep_op! = Symbol("prepare!_", op) prep_op_same_point = Symbol("prepare_", op, "_same_point") + prep_op_same_point_nokwarg = Symbol("prepare_", op, "_same_point_nokwarg") P = if op == :derivative DerivativePrep elseif op == :gradient @@ -45,25 +47,25 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $op(f, prep, backend, x, contexts...) end @eval function $op!( f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $op!(f, result, prep, backend, x, contexts...) 
end @eval function $val_and_op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $val_and_op(f, prep, backend, x, contexts...) end @eval function $val_and_op!( f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $val_and_op!(f, result, prep, backend, x, contexts...) end op == :gradient && continue @@ -71,25 +73,25 @@ for op in [ @eval function $op( f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, contexts...) return $op(f!, y, prep, backend, x, contexts...) end @eval function $op!( f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, contexts...) return $op!(f!, y, result, prep, backend, x, contexts...) end @eval function $val_and_op( f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, contexts...) return $val_and_op(f!, y, prep, backend, x, contexts...) end @eval function $val_and_op!( f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, contexts...) return $val_and_op!(f!, y, result, prep, backend, x, contexts...) 
end @@ -98,25 +100,25 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $op(f, prep, backend, x, contexts...) end @eval function $op!( f::F, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $op!(f, result2, prep, backend, x, contexts...) end @eval function $val_and_op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $val_and_op(f, prep, backend, x, contexts...) end @eval function $val_and_op!( f::F, result1, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $val_and_op!(f, result1, result2, prep, backend, x, contexts...) end @@ -124,7 +126,7 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, tang, contexts...) return $op(f, prep, backend, x, tang, contexts...) end @eval function $op!( @@ -135,13 +137,13 @@ for op in [ tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, tang, contexts...) return $op!(f, result, prep, backend, x, tang, contexts...) 
end @eval function $val_and_op( f::F, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, tang, contexts...) return $val_and_op(f, prep, backend, x, tang, contexts...) end @@ -154,7 +156,7 @@ for op in [ tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, tang, contexts...) return $val_and_op!(f, result, prep, backend, x, tang, contexts...) end elseif op == :hvp @@ -167,7 +169,7 @@ for op in [ tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, tang, contexts...) return $val_and_op!( f, result1, result2, prep, backend, x, tang, contexts... ) @@ -179,7 +181,7 @@ for op in [ @eval function $op( f!::F, y, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, tang, contexts...) return $op(f!, y, prep, backend, x, tang, contexts...) end @eval function $op!( @@ -191,13 +193,13 @@ for op in [ tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, tang, contexts...) return $op!(f!, y, result, prep, backend, x, tang, contexts...) end @eval function $val_and_op( f!::F, y, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, tang, contexts...) return $val_and_op(f!, y, prep, backend, x, tang, contexts...) 
end @eval function $val_and_op!( @@ -209,7 +211,7 @@ for op in [ tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, tang, contexts...) return $val_and_op!(f!, y, result, prep, backend, x, tang, contexts...) end end diff --git a/DifferentiationInterface/src/first_order/derivative.jl b/DifferentiationInterface/src/first_order/derivative.jl index c8753432d..74f90e91e 100644 --- a/DifferentiationInterface/src/first_order/derivative.jl +++ b/DifferentiationInterface/src/first_order/derivative.jl @@ -6,8 +6,16 @@ $(docstring_prepare("derivative"; inplace=true)) """ -function prepare_derivative(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_derivative(strict, args...) +function prepare_derivative( + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict=Val(false) +) where {F,C} + return prepare_derivative_nokwarg(strict, f, backend, x, contexts...) +end + +function prepare_derivative( + f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict=Val(false) +) where {F,C} + return prepare_derivative_nokwarg(strict, f!, y, backend, x, contexts...) end """ @@ -65,19 +73,21 @@ struct PushforwardDerivativePrep{SIG,E<:PushforwardPrep} <: DerivativePrep{SIG} pushforward_prep::E end -function prepare_derivative( +function prepare_derivative_nokwarg( strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} _sig = signature(f, backend, x, contexts...; strict) - pushforward_prep = prepare_pushforward(strict, f, backend, x, (one(x),), contexts...) + pushforward_prep = prepare_pushforward_nokwarg( + strict, f, backend, x, (one(x),), contexts... 
+ ) return PushforwardDerivativePrep(_sig, pushforward_prep) end -function prepare_derivative( +function prepare_derivative_nokwarg( strict::Val, f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f!, y, backend, x, contexts...; strict) - pushforward_prep = prepare_pushforward( + pushforward_prep = prepare_pushforward_nokwarg( strict, f!, y, backend, x, (one(x),), contexts... ) return PushforwardDerivativePrep(_sig, pushforward_prep) diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 8ddfb7137..86b29270e 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -5,8 +5,10 @@ $(docstring_prepare("gradient")) """ -function prepare_gradient(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_gradient(strict, args...) +function prepare_gradient( + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict=Val(false) +) where {F,C} + return prepare_gradient_nokwarg(strict, f, backend, x, contexts...) end """ @@ -60,12 +62,14 @@ struct PullbackGradientPrep{SIG,Y,E<:PullbackPrep} <: GradientPrep{SIG} pullback_prep::E end -function prepare_gradient( +function prepare_gradient_nokwarg( strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} _sig = signature(f, backend, x, contexts...; strict) y = f(x, map(unwrap, contexts)...) # TODO: replace with output type inference? - pullback_prep = prepare_pullback(strict, f, backend, x, (one(typeof(y)),), contexts...) + pullback_prep = prepare_pullback_nokwarg( + strict, f, backend, x, (one(typeof(y)),), contexts... 
+ ) return PullbackGradientPrep(_sig, y, pullback_prep) end diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 201cbea23..20e6b3048 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -6,8 +6,16 @@ $(docstring_prepare("jacobian"; inplace=true)) """ -function prepare_jacobian(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_jacobian(strict, args...) +function prepare_jacobian( + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict=Val(false) +) where {F,C} + return prepare_jacobian_nokwarg(strict, f, backend, x, contexts...) +end + +function prepare_jacobian( + f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict=Val(false) +) where {F,C} + return prepare_jacobian_nokwarg(strict, f!, y, backend, x, contexts...) end """ @@ -90,7 +98,7 @@ struct PullbackJacobianPrep{ pullback_prep::E end -function prepare_jacobian( +function prepare_jacobian_nokwarg( strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} y = f(x, map(unwrap, contexts)...) @@ -107,7 +115,7 @@ function prepare_jacobian( ) end -function prepare_jacobian( +function prepare_jacobian_nokwarg( strict::Val, f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} perf = pushforward_performance(backend) @@ -140,7 +148,7 @@ function _prepare_jacobian_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds] - pushforward_prep = prepare_pushforward( + pushforward_prep = prepare_pushforward_nokwarg( strict, f_or_f!y..., backend, x, batched_seeds[1], contexts... 
) return PushforwardJacobianPrep( @@ -165,7 +173,7 @@ function _prepare_jacobian_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - pullback_prep = prepare_pullback( + pullback_prep = prepare_pullback_nokwarg( strict, f_or_f!y..., backend, x, batched_seeds[1], contexts... ) return PullbackJacobianPrep( diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 63afa0656..e50684932 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -6,8 +6,27 @@ $(docstring_prepare("pullback"; inplace=true)) """ -function prepare_pullback(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_pullback(strict, args...) +function prepare_pullback( + f::F, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context,C}; + strict=Val(false), +) where {F,C} + return prepare_pullback_nokwarg(strict, f, backend, x, ty, contexts...) +end + +function prepare_pullback( + f!::F, + y, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context,C}; + strict=Val(false), +) where {F,C} + return prepare_pullback_nokwarg(strict, f!, y, backend, x, ty, contexts...) end """ @@ -24,8 +43,27 @@ function prepare!_pullback end $(docstring_prepare("pullback"; samepoint=true, inplace=true)) """ -function prepare_pullback_same_point(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_pullback_same_point(strict, args...) +function prepare_pullback_same_point( + f::F, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context,C}; + strict=Val(false), +) where {F,C} + return prepare_pullback_same_point_nokwarg(strict, f, backend, x, ty, contexts...) 
+end + +function prepare_pullback_same_point( + f!::F, + y, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context,C}; + strict=Val(false), +) where {F,C} + return prepare_pullback_same_point_nokwarg(strict, f!, y, backend, x, ty, contexts...) end """ @@ -94,7 +132,7 @@ struct PushforwardPullbackPrep{SIG,E} <: PullbackPrep{SIG} pushforward_prep::E end -function prepare_pullback( +function prepare_pullback_nokwarg( strict::Val, f::F, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context,C}; ) where {F,C} return _prepare_pullback_aux( @@ -102,7 +140,7 @@ function prepare_pullback( ) end -function prepare_pullback( +function prepare_pullback_nokwarg( strict::Val, f!::F, y, @@ -127,7 +165,9 @@ function _prepare_pullback_aux( ) where {F,C} _sig = signature(f, backend, x, ty, contexts...; strict) dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x))) - pushforward_prep = prepare_pushforward(strict, f, backend, x, (dx,), contexts...) + pushforward_prep = prepare_pushforward_nokwarg( + strict, f, backend, x, (dx,), contexts... + ) return PushforwardPullbackPrep(_sig, pushforward_prep) end @@ -143,7 +183,9 @@ function _prepare_pullback_aux( ) where {F,C} _sig = signature(f!, y, backend, x, ty, contexts...; strict) dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x))) - pushforward_prep = prepare_pushforward(strict, f!, y, backend, x, (dx,), contexts...) + pushforward_prep = prepare_pushforward_nokwarg( + strict, f!, y, backend, x, (dx,), contexts... 
+ ) return PushforwardPullbackPrep(_sig, pushforward_prep) end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 883681031..1cb510a5e 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -6,8 +6,27 @@ $(docstring_prepare("pushforward"; inplace=true)) """ -function prepare_pushforward(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_pushforward(strict, args...) +function prepare_pushforward( + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict=Val(false), +) where {F,C} + return prepare_pushforward_nokwarg(strict, f, backend, x, tx, contexts...) +end + +function prepare_pushforward( + f!::F, + y, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict=Val(false), +) where {F,C} + return prepare_pushforward_nokwarg(strict, f!, y, backend, x, tx, contexts...) end """ @@ -24,8 +43,29 @@ function prepare!_pushforward end $(docstring_prepare("pushforward"; samepoint=true, inplace=true)) """ -function prepare_pushforward_same_point(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_pushforward_same_point(strict, args...) +function prepare_pushforward_same_point( + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict=Val(false), +) where {F,C} + return prepare_pushforward_same_point_nokwarg(strict, f, backend, x, tx, contexts...) +end + +function prepare_pushforward_same_point( + f!::F, + y, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict=Val(false), +) where {F,C} + return prepare_pushforward_same_point_nokwarg( + strict, f!, y, backend, x, tx, contexts... 
+ ) end """ @@ -94,7 +134,7 @@ struct PullbackPushforwardPrep{SIG,E} <: PushforwardPrep{SIG} pullback_prep::E end -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}; ) where {F,C} return _prepare_pushforward_aux( @@ -102,7 +142,7 @@ function prepare_pushforward( ) end -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f!::F, y, @@ -128,7 +168,7 @@ function _prepare_pushforward_aux( _sig = signature(f, backend, x, tx, contexts...; strict) y = f(x, map(unwrap, contexts)...) dy = y isa Number ? one(y) : basis(y, first(CartesianIndices(y))) - pullback_prep = prepare_pullback(strict, f, backend, x, (dy,), contexts...) + pullback_prep = prepare_pullback_nokwarg(strict, f, backend, x, (dy,), contexts...) return PullbackPushforwardPrep(_sig, pullback_prep) end @@ -144,7 +184,7 @@ function _prepare_pushforward_aux( ) where {F,C} _sig = signature(f!, y, backend, x, tx, contexts...; strict) dy = y isa Number ? one(y) : basis(y, first(CartesianIndices(y))) - pullback_prep = prepare_pullback(strict, f!, y, backend, x, (dy,), contexts...) + pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...) 
return PullbackPushforwardPrep(_sig, pullback_prep) end diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index 35c21a71e..00dd6b445 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -41,7 +41,7 @@ struct FromPrimitivePushforwardPrep{SIG,E<:PushforwardPrep} <: PushforwardPrep{S pushforward_prep::E end -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f::F, backend::AutoForwardFromPrimitive, @@ -50,11 +50,13 @@ function prepare_pushforward( contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f, backend, x, tx, contexts...; strict) - primitive_prep = prepare_pushforward(strict, f, backend.backend, x, tx, contexts...) + primitive_prep = prepare_pushforward_nokwarg( + strict, f, backend.backend, x, tx, contexts... + ) return FromPrimitivePushforwardPrep(_sig, primitive_prep) end -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f!::F, y, @@ -64,7 +66,9 @@ function prepare_pushforward( contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f!, y, backend, x, tx, contexts...; strict) - primitive_prep = prepare_pushforward(strict, f!, y, backend.backend, x, tx, contexts...) + primitive_prep = prepare_pushforward_nokwarg( + strict, f!, y, backend.backend, x, tx, contexts... + ) return FromPrimitivePushforwardPrep(_sig, primitive_prep) end @@ -158,7 +162,7 @@ struct FromPrimitivePullbackPrep{SIG,E<:PullbackPrep} <: PullbackPrep{SIG} pullback_prep::E end -function prepare_pullback( +function prepare_pullback_nokwarg( strict::Val, f::F, backend::AutoReverseFromPrimitive, @@ -167,11 +171,13 @@ function prepare_pullback( contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f, backend, x, ty, contexts...; strict) - primitive_prep = prepare_pullback(strict, f, backend.backend, x, ty, contexts...) 
+ primitive_prep = prepare_pullback_nokwarg( + strict, f, backend.backend, x, ty, contexts... + ) return FromPrimitivePullbackPrep(_sig, primitive_prep) end -function prepare_pullback( +function prepare_pullback_nokwarg( strict::Val, f!::F, y, @@ -181,7 +187,9 @@ function prepare_pullback( contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f!, y, backend, x, ty, contexts...; strict) - primitive_prep = prepare_pullback(strict, f!, y, backend.backend, x, ty, contexts...) + primitive_prep = prepare_pullback_nokwarg( + strict, f!, y, backend.backend, x, ty, contexts... + ) return FromPrimitivePullbackPrep(_sig, primitive_prep) end diff --git a/DifferentiationInterface/src/misc/simple_finite_diff.jl b/DifferentiationInterface/src/misc/simple_finite_diff.jl index bfa6656f4..8481cad97 100644 --- a/DifferentiationInterface/src/misc/simple_finite_diff.jl +++ b/DifferentiationInterface/src/misc/simple_finite_diff.jl @@ -36,7 +36,7 @@ function threshold_batchsize( return AutoSimpleFiniteDiff(backend.ε; chunksize) end -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f::F, backend::AutoSimpleFiniteDiff, @@ -48,7 +48,7 @@ function prepare_pushforward( return NoPushforwardPrep(_sig) end -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f!::F, y, diff --git a/DifferentiationInterface/src/misc/zero_backends.jl b/DifferentiationInterface/src/misc/zero_backends.jl index 1b3c6c0eb..e74449f3c 100644 --- a/DifferentiationInterface/src/misc/zero_backends.jl +++ b/DifferentiationInterface/src/misc/zero_backends.jl @@ -20,14 +20,14 @@ ADTypes.mode(::AutoZeroForward) = ForwardMode() check_available(::AutoZeroForward) = true inplace_support(::AutoZeroForward) = InPlaceSupported() -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f::F, backend::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f, backend, x, tx, contexts...; strict) return 
NoPushforwardPrep(_sig) end -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f!::F, y, @@ -118,14 +118,14 @@ ADTypes.mode(::AutoZeroReverse) = ReverseMode() check_available(::AutoZeroReverse) = true inplace_support(::AutoZeroReverse) = InPlaceSupported() -function prepare_pullback( +function prepare_pullback_nokwarg( strict::Val, f::F, backend::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f, backend, x, ty, contexts...; strict) return NoPullbackPrep(_sig) end -function prepare_pullback( +function prepare_pullback_nokwarg( strict::Val, f!::F, y, diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 290adb525..21d82cc17 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -5,8 +5,10 @@ $(docstring_prepare("hessian")) """ -function prepare_hessian(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_hessian(strict, args...) +function prepare_hessian( + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict=Val(false) +) where {F,C} + return prepare_hessian_nokwarg(strict, f, backend, x, contexts...) end """ @@ -70,7 +72,7 @@ struct HVPGradientHessianPrep{ gradient_prep::E1 end -function prepare_hessian( +function prepare_hessian_nokwarg( strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} # type-unstable @@ -94,8 +96,8 @@ function _prepare_hessian_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - hvp_prep = prepare_hvp(strict, f, backend, x, batched_seeds[1], contexts...) - gradient_prep = prepare_gradient(strict, f, inner(backend), x, contexts...) + hvp_prep = prepare_hvp_nokwarg(strict, f, backend, x, batched_seeds[1], contexts...) 
+ gradient_prep = prepare_gradient_nokwarg(strict, f, inner(backend), x, contexts...) return HVPGradientHessianPrep( _sig, batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep ) diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 8caeb2b64..3edb30e93 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -5,8 +5,15 @@ $(docstring_prepare("hvp")) """ -function prepare_hvp(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_hvp(strict, args...) +function prepare_hvp( + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict=Val(false), +) where {F,C} + return prepare_hvp_nokwarg(strict, f, backend, x, tx, contexts...) end """ @@ -21,8 +28,15 @@ function prepare!_hvp end $(docstring_prepare("hvp"; samepoint=true)) """ -function prepare_hvp_same_point(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_hvp_same_point(strict, args...) +function prepare_hvp_same_point( + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict=Val(false), +) where {F,C} + return prepare_hvp_same_point_nokwarg(strict, f, backend, x, tx, contexts...) end """ @@ -61,7 +75,7 @@ $(docstring_preparation_hint("hvp"; same_point=true)) """ function gradient_and_hvp! end -function prepare_hvp( +function prepare_hvp_nokwarg( strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}; ) where {F,C} return _prepare_hvp_aux( @@ -105,11 +119,11 @@ function _prepare_hvp_aux( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - outer_pushforward_prep = prepare_pushforward( + outer_pushforward_prep = prepare_pushforward_nokwarg( strict, shuffled_gradient, outer(backend), x, tx, new_contexts... 
) outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported - prepare_pushforward( + prepare_pushforward_nokwarg( strict, shuffled_gradient!, grad_buffer, @@ -140,7 +154,9 @@ function _prepare_hvp_aux( grad_buffer = similar(x) rewrap = Rewrap(contexts...) # Inner gradient - inner_gradient_prep = prepare_gradient(strict, f, inner(backend), x, contexts...) + inner_gradient_prep = prepare_gradient_nokwarg( + strict, f, inner(backend), x, contexts... + ) inner_gradient_in_prep = inner_gradient_prep # Outer pushforward new_contexts = ( @@ -157,11 +173,11 @@ function _prepare_hvp_aux( Constant(rewrap), contexts..., ) - outer_pushforward_prep = prepare_pushforward( + outer_pushforward_prep = prepare_pushforward_nokwarg( strict, shuffled_gradient, outer(backend), x, tx, new_contexts... ) outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported - prepare_pushforward( + prepare_pushforward_nokwarg( strict, shuffled_gradient!, grad_buffer, @@ -203,8 +219,12 @@ function _prepare_hvp_aux( ) contextso = adapt_eltype.(contexts, Ref(eltype(xo))) contextsoi = adapt_eltype.(contexts, Ref(eltype(xoi))) - inner_gradient_prep = prepare_gradient(strict, f, inner(backend), xo, contextso...) - inner_gradient_in_prep = prepare_gradient(strict, f, inner(backend), xoi, contextsoi...) + inner_gradient_prep = prepare_gradient_nokwarg( + strict, f, inner(backend), xo, contextso... + ) + inner_gradient_in_prep = prepare_gradient_nokwarg( + strict, f, inner(backend), xoi, contextsoi... + ) # Outer pushforward new_contexts = ( FunctionContext(f), @@ -220,11 +240,11 @@ function _prepare_hvp_aux( Constant(rewrap), contexts..., ) - outer_pushforward_prep = prepare_pushforward( + outer_pushforward_prep = prepare_pushforward_nokwarg( strict, shuffled_gradient, outer(backend), x, tx, new_contexts... 
) outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported - prepare_pushforward( + prepare_pushforward_nokwarg( strict, shuffled_gradient!, grad_buffer, @@ -477,10 +497,10 @@ function _prepare_hvp_aux( Constant(rewrap), contexts..., ) - outer_gradient_prep = prepare_gradient( + outer_gradient_prep = prepare_gradient_nokwarg( strict, shuffled_single_pushforward, outer(backend), x, new_contexts... ) - gradient_prep = prepare_gradient(strict, f, inner(backend), x, contexts...) + gradient_prep = prepare_gradient_nokwarg(strict, f, inner(backend), x, contexts...) return ReverseOverForwardHVPPrep(_sig, outer_gradient_prep, gradient_prep) end @@ -596,11 +616,11 @@ function _prepare_hvp_aux( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) grad_buffer = similar(x) - outer_pullback_prep = prepare_pullback( + outer_pullback_prep = prepare_pullback_nokwarg( strict, shuffled_gradient, outer(backend), x, tx, new_contexts... ) outer_pullback_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported - prepare_pullback( + prepare_pullback_nokwarg( strict, shuffled_gradient!, grad_buffer, diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index f26f335e6..649de7cb9 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -5,8 +5,10 @@ $(docstring_prepare("second_derivative")) """ -function prepare_second_derivative(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_second_derivative(strict, args...) +function prepare_second_derivative( + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict=Val(false) +) where {F,C} + return prepare_second_derivative_nokwarg(strict, f, backend, x, contexts...) 
end """ @@ -59,7 +61,7 @@ struct DerivativeSecondDerivativePrep{SIG,E<:DerivativePrep} <: SecondDerivative outer_derivative_prep::E end -function prepare_second_derivative( +function prepare_second_derivative_nokwarg( strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} _sig = signature(f, backend, x, contexts...; strict) @@ -67,7 +69,7 @@ function prepare_second_derivative( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - outer_derivative_prep = prepare_derivative( + outer_derivative_prep = prepare_derivative_nokwarg( strict, shuffled_derivative, outer(backend), x, new_contexts... ) return DerivativeSecondDerivativePrep(_sig, outer_derivative_prep) diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index f4024afbd..36a296f89 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -90,9 +90,9 @@ end # Derivative x = 1.0 y = [1.0, 1.0] - @test DI.overloaded_input_type(prepare_derivative(copy, backend, x)) == + @test DI.overloaded_input_type(prepare_derivative_nokwarg(copy, backend, x)) == ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy),Float64},Float64,1} - @test DI.overloaded_input_type(prepare_derivative(copyto!, y, backend, x)) == + @test DI.overloaded_input_type(prepare_derivative_nokwarg(copyto!, y, backend, x)) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,1}} # Gradient diff --git a/DifferentiationInterface/test/Back/ReverseDiff/test.jl b/DifferentiationInterface/test/Back/ReverseDiff/test.jl index c1b902d14..f1bcd1d25 100644 --- a/DifferentiationInterface/test/Back/ReverseDiff/test.jl +++ b/DifferentiationInterface/test/Back/ReverseDiff/test.jl @@ -47,7 +47,7 @@ test_differentiation( # Derivative x = 1.0 - @test_skip DI.overloaded_input_type(prepare_derivative(copy, backend, x)) == + @test_skip 
DI.overloaded_input_type(prepare_derivative_nokwarg(copy, backend, x)) == ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} # Gradient From efc8fcff841a68142ff0be5f586d71f9a4872f49 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 18 Mar 2025 09:17:45 +0100 Subject: [PATCH 02/14] DOcs --- DifferentiationInterface/docs/src/dev_guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/docs/src/dev_guide.md b/DifferentiationInterface/docs/src/dev_guide.md index a24bb7eb4..3fa1b66b9 100644 --- a/DifferentiationInterface/docs/src/dev_guide.md +++ b/DifferentiationInterface/docs/src/dev_guide.md @@ -23,7 +23,7 @@ Most operators have 4 variants, which look like this in the first order: `operat To implement a new operator for an existing backend, you need to write 5 methods: 1 for [preparation](@ref Preparation) and 4 corresponding to the variants of the operator (see above). For first-order operators, you may also want to support [in-place functions](@ref "Mutation and signatures"), which requires another 5 methods (defined on `f!` instead of `f`). -The method `prepare_operator_kwarg` must output a `prep` object of the correct type. +The method `prepare_operator_nokwarg` must output a `prep` object of the correct type. For instance, `prepare_gradient(strict, f, backend, x)` must return a [`DifferentiationInterface.GradientPrep`](@ref). Assuming you don't need any preparation for said operator, you can use the trivial prep that are already defined, like `DifferentiationInterface.NoGradientPrep{SIG}`. Otherwise, define a custom struct like `MyGradientPrep{SIG} <: DifferentiationInterface.GradientPrep{SIG}` and put the necessary storage in there. 
From a06f5dedf3134e9a905b21c3ace632577e1c82eb Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 18 Mar 2025 09:24:17 +0100 Subject: [PATCH 03/14] Typing --- DifferentiationInterface/src/first_order/derivative.jl | 9 +++++++-- DifferentiationInterface/src/first_order/gradient.jl | 2 +- DifferentiationInterface/src/first_order/jacobian.jl | 9 +++++++-- DifferentiationInterface/src/first_order/pullback.jl | 8 ++++---- DifferentiationInterface/src/first_order/pushforward.jl | 8 ++++---- DifferentiationInterface/src/second_order/hessian.jl | 2 +- DifferentiationInterface/src/second_order/hvp.jl | 4 ++-- .../src/second_order/second_derivative.jl | 2 +- DifferentiationInterface/test/Back/ForwardDiff/test.jl | 4 ++-- DifferentiationInterface/test/Back/ReverseDiff/test.jl | 2 +- 10 files changed, 30 insertions(+), 20 deletions(-) diff --git a/DifferentiationInterface/src/first_order/derivative.jl b/DifferentiationInterface/src/first_order/derivative.jl index 74f90e91e..69075fc3a 100644 --- a/DifferentiationInterface/src/first_order/derivative.jl +++ b/DifferentiationInterface/src/first_order/derivative.jl @@ -7,13 +7,18 @@ $(docstring_prepare("derivative"; inplace=true)) """ function prepare_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict=Val(false) + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) ) where {F,C} return prepare_derivative_nokwarg(strict, f, backend, x, contexts...) end function prepare_derivative( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict=Val(false) + f!::F, + y, + backend::AbstractADType, + x, + contexts::Vararg{Context,C}; + strict::Val=Val(false), ) where {F,C} return prepare_derivative_nokwarg(strict, f!, y, backend, x, contexts...) 
end diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 86b29270e..d15344b0d 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -6,7 +6,7 @@ $(docstring_prepare("gradient")) """ function prepare_gradient( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict=Val(false) + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) ) where {F,C} return prepare_gradient_nokwarg(strict, f, backend, x, contexts...) end diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 20e6b3048..24932b289 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -7,13 +7,18 @@ $(docstring_prepare("jacobian"; inplace=true)) """ function prepare_jacobian( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict=Val(false) + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) ) where {F,C} return prepare_jacobian_nokwarg(strict, f, backend, x, contexts...) end function prepare_jacobian( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict=Val(false) + f!::F, + y, + backend::AbstractADType, + x, + contexts::Vararg{Context,C}; + strict::Val=Val(false), ) where {F,C} return prepare_jacobian_nokwarg(strict, f!, y, backend, x, contexts...) 
end diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index e50684932..a22277a6f 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -12,7 +12,7 @@ function prepare_pullback( x, ty::NTuple, contexts::Vararg{Context,C}; - strict=Val(false), + strict::Val=Val(false), ) where {F,C} return prepare_pullback_nokwarg(strict, f, backend, x, ty, contexts...) end @@ -24,7 +24,7 @@ function prepare_pullback( x, ty::NTuple, contexts::Vararg{Context,C}; - strict=Val(false), + strict::Val=Val(false), ) where {F,C} return prepare_pullback_nokwarg(strict, f!, y, backend, x, ty, contexts...) end @@ -49,7 +49,7 @@ function prepare_pullback_same_point( x, ty::NTuple, contexts::Vararg{Context,C}; - strict=Val(false), + strict::Val=Val(false), ) where {F,C} return prepare_pullback_same_point_nokwarg(strict, f, backend, x, ty, contexts...) end @@ -61,7 +61,7 @@ function prepare_pullback_same_point( x, ty::NTuple, contexts::Vararg{Context,C}; - strict=Val(false), + strict::Val=Val(false), ) where {F,C} return prepare_pullback_same_point_nokwarg(strict, f!, y, backend, x, ty, contexts...) end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 1cb510a5e..601461d6a 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -12,7 +12,7 @@ function prepare_pushforward( x, tx::NTuple, contexts::Vararg{Context,C}; - strict=Val(false), + strict::Val=Val(false), ) where {F,C} return prepare_pushforward_nokwarg(strict, f, backend, x, tx, contexts...) end @@ -24,7 +24,7 @@ function prepare_pushforward( x, tx::NTuple, contexts::Vararg{Context,C}; - strict=Val(false), + strict::Val=Val(false), ) where {F,C} return prepare_pushforward_nokwarg(strict, f!, y, backend, x, tx, contexts...) 
end @@ -49,7 +49,7 @@ function prepare_pushforward_same_point( x, tx::NTuple, contexts::Vararg{Context,C}; - strict=Val(false), + strict::Val=Val(false), ) where {F,C} return prepare_pushforward_same_point_nokwarg(strict, f, backend, x, tx, contexts...) end @@ -61,7 +61,7 @@ function prepare_pushforward_same_point( x, tx::NTuple, contexts::Vararg{Context,C}; - strict=Val(false), + strict::Val=Val(false), ) where {F,C} return prepare_pushforward_same_point_nokwarg( strict, f!, y, backend, x, tx, contexts... diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 21d82cc17..dbdd16da6 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -6,7 +6,7 @@ $(docstring_prepare("hessian")) """ function prepare_hessian( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict=Val(false) + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) ) where {F,C} return prepare_hessian_nokwarg(strict, f, backend, x, contexts...) end diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 3edb30e93..11faa2548 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -11,7 +11,7 @@ function prepare_hvp( x, tx::NTuple, contexts::Vararg{Context,C}; - strict=Val(false), + strict::Val=Val(false), ) where {F,C} return prepare_hvp_nokwarg(strict, f, backend, x, tx, contexts...) end @@ -34,7 +34,7 @@ function prepare_hvp_same_point( x, tx::NTuple, contexts::Vararg{Context,C}; - strict=Val(false), + strict::Val=Val(false), ) where {F,C} return prepare_hvp_same_point_nokwarg(strict, f, backend, x, tx, contexts...) 
end diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index 649de7cb9..09e81fbac 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -6,7 +6,7 @@ $(docstring_prepare("second_derivative")) """ function prepare_second_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict=Val(false) + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) ) where {F,C} return prepare_second_derivative_nokwarg(strict, f, backend, x, contexts...) end diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 36a296f89..f4024afbd 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -90,9 +90,9 @@ end # Derivative x = 1.0 y = [1.0, 1.0] - @test DI.overloaded_input_type(prepare_derivative_nokwarg(copy, backend, x)) == + @test DI.overloaded_input_type(prepare_derivative(copy, backend, x)) == ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy),Float64},Float64,1} - @test DI.overloaded_input_type(prepare_derivative_nokwarg(copyto!, y, backend, x)) == + @test DI.overloaded_input_type(prepare_derivative(copyto!, y, backend, x)) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,1}} # Gradient diff --git a/DifferentiationInterface/test/Back/ReverseDiff/test.jl b/DifferentiationInterface/test/Back/ReverseDiff/test.jl index f1bcd1d25..c1b902d14 100644 --- a/DifferentiationInterface/test/Back/ReverseDiff/test.jl +++ b/DifferentiationInterface/test/Back/ReverseDiff/test.jl @@ -47,7 +47,7 @@ test_differentiation( # Derivative x = 1.0 - @test_skip DI.overloaded_input_type(prepare_derivative_nokwarg(copy, backend, x)) == + @test_skip DI.overloaded_input_type(prepare_derivative(copy, backend, 
x)) == ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} # Gradient From 223b36287bfa493c3e1d8f4f9412f22a68bc2b4d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 18 Mar 2025 09:41:20 +0100 Subject: [PATCH 04/14] Fix --- .github/workflows/Test.yml | 2 +- .../DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index f91ce0853..13de31d96 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -25,7 +25,7 @@ jobs: actions: write contents: read strategy: - fail-fast: true # TODO: toggle + fail-fast: false # TODO: toggle matrix: version: - "1.10" diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl index 65baba690..78609ea0e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl @@ -16,7 +16,7 @@ function DI.prepare_pushforward_nokwarg( ) where {C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) single_threaded_prep = DI.prepare_pushforward_nokwarg( - f!, y, single_threaded(backend), x, tx, contexts... + strict, f!, y, single_threaded(backend), x, tx, contexts... 
) return PolyesterForwardDiffTwoArgPushforwardPrep(_sig, single_threaded_prep) end From c83ba21b459a4ec41300eac2443c6b493d25af3e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 18 Mar 2025 10:15:33 +0100 Subject: [PATCH 05/14] Toggle fail fast --- .github/workflows/Test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 13de31d96..f91ce0853 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -25,7 +25,7 @@ jobs: actions: write contents: read strategy: - fail-fast: false # TODO: toggle + fail-fast: true # TODO: toggle matrix: version: - "1.10" From 4dd20e16039519dcdccc3e91f1e2ae6aee00c180 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 18 Mar 2025 10:31:52 +0100 Subject: [PATCH 06/14] feat: recursive similar for caches --- .../DifferentiationInterfaceForwardDiffExt/utils.jl | 4 ++-- DifferentiationInterface/src/utils/linalg.jl | 12 ++++++++++++ .../test/Core/Internals/linalg.jl | 9 +++++++++ 3 files changed, 23 insertions(+), 2 deletions(-) create mode 100644 DifferentiationInterface/test/Core/Internals/linalg.jl diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index 2dd4409cc..44afdd9ce 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -89,7 +89,7 @@ function _translate( end function _translate(::Type{D}, c::DI.Cache) where {D<:Dual} c0 = DI.unwrap(c) - return similar(c0, D) + return DI.recursive_similar(c0, D) end function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C} @@ -106,7 +106,7 @@ function _translate_toprep( end function _translate_toprep(::Type{D}, c::DI.Cache) where {D<:Dual} c0 = 
DI.unwrap(c) - return similar(c0, D) + return DI.recursive_similar(c0, D) end function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C} diff --git a/DifferentiationInterface/src/utils/linalg.jl b/DifferentiationInterface/src/utils/linalg.jl index fcabdb143..b7dd4a42a 100644 --- a/DifferentiationInterface/src/utils/linalg.jl +++ b/DifferentiationInterface/src/utils/linalg.jl @@ -10,3 +10,15 @@ At the moment, this only returns `false` for `StaticArrays.SArray`. """ ismutable_array(::Type) = true ismutable_array(x) = ismutable_array(typeof(x)) + +""" + recursive_similar(x, T) + +Apply `similar(_, T)` recursively to `x` or its components. + +Works if `x` is an `AbstractArray` or a (nested) `NTuple` / `NamedTuple` of `AbstractArray`s. +""" +recursive_similar(x::AbstractArray, ::Type{T}) where {T} = similar(x, T) +function recursive_similar(x::Union{Tuple,NamedTuple}, ::Type{T}) where {T} + return map(xi -> recursive_similar(xi, T), x) +end diff --git a/DifferentiationInterface/test/Core/Internals/linalg.jl b/DifferentiationInterface/test/Core/Internals/linalg.jl new file mode 100644 index 000000000..03798da87 --- /dev/null +++ b/DifferentiationInterface/test/Core/Internals/linalg.jl @@ -0,0 +1,9 @@ +using DifferentiationInterface: recursive_similar +using Test + +@test recursive_similar(ones(Int, 2), Float32) isa Vector{Float32} +@test recursive_similar((ones(Int, 2), ones(Bool, 3, 4)), Float32) isa + Tuple{Vector{Float32},Matrix{Float32}} +@test recursive_similar((a=ones(Int, 2), b=(ones(Bool, 3, 4),)), Float32) isa + @NamedTuple{a::Vector{Float32}, b::Tuple{Matrix{Float32}}} +@test_throws MethodError recursive_similar(1, Float32) From 8cfc5016ef68c94d92bd6924a3c7c986fa6998e3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 18 Mar 2025 13:11:34 +0100 Subject: [PATCH 07/14] Recursive caches --- .github/workflows/Test.yml | 2 +- .../utils.jl | 4 +-- 
...ntiationInterfaceFastDifferentiationExt.jl | 5 ++++ ...ionInterfaceSparseConnectivityTracerExt.jl | 25 ++++++++----------- .../DifferentiationInterfaceZygoteExt.jl | 5 +++- DifferentiationInterface/src/utils/context.jl | 11 +++----- .../test/Core/SimpleFiniteDiff/test.jl | 6 +++++ .../src/scenarios/modify.jl | 16 +++++++----- 8 files changed, 42 insertions(+), 32 deletions(-) diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index f91ce0853..13de31d96 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -25,7 +25,7 @@ jobs: actions: write contents: read strategy: - fail-fast: true # TODO: toggle + fail-fast: false # TODO: toggle matrix: version: - "1.10" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 8b1550532..891964579 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -53,9 +53,7 @@ force_annotation(f::F) where {F} = Const(f) return Const(DI.unwrap(c)) end -@inline function _translate( - backend::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedCache -) where {B} +@inline function _translate(backend::AutoEnzyme, ::Mode, ::Val{B}, c::DI.Cache) where {B} if B == 1 return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c))) else diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl index 430cc95f9..e5e979f82 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl @@ -23,6 
+23,11 @@ myvec(x::AbstractArray) = vec(x) variablize(::Number, name::Symbol) = only(make_variables(name)) variablize(x::AbstractArray, name::Symbol) = make_variables(name, size(x)...) +function variablize(x::Union{Tuple,NamedTuple}, name::Symbol) + return map(x) do xi + variablize(xi, gensym()) # TODO: fix symbol? + end +end function variablize(contexts::NTuple{C,DI.Context}) where {C} map(enumerate(contexts)) do (k, c) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl index f5a315c3a..6bc11a1bc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl @@ -5,26 +5,23 @@ import DifferentiationInterface as DI using SparseConnectivityTracer: TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer -@inline _jacobian_translate(detector, c::DI.Constant) = DI.unwrap(c) -@inline function _jacobian_translate(detector, c::DI.Cache{<:AbstractArray}) - return jacobian_buffer(DI.unwrap(c), detector) +@inline _translate(::Type, c::DI.Constant) = DI.unwrap(c) +@inline function _translate(::Type{T}, c::DI.Cache) where {T} + return DI.recursive_similar(DI.unwrap(c), T) end -function jacobian_translate(detector, contexts::Vararg{DI.Context,C}) where {C} +function jacobian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C} + T = eltype(jacobian_buffer(x, detector)) new_contexts = map(contexts) do c - _jacobian_translate(detector, c) + _translate(T, c) end return new_contexts end -@inline _hessian_translate(detector, c::DI.Constant) = DI.unwrap(c) -@inline function 
_hessian_translate(detector, c::DI.Cache{<:AbstractArray}) - return hessian_buffer(DI.unwrap(c), detector) -end - -function hessian_translate(detector, contexts::Vararg{DI.Context,C}) where {C} +function hessian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C} + T = eltype(hessian_buffer(x, detector)) new_contexts = map(contexts) do c - _hessian_translate(detector, c) + _translate(T, c) end return new_contexts end @@ -35,7 +32,7 @@ function DI.jacobian_sparsity_with_contexts( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - contexts_tracer = jacobian_translate(detector, contexts...) + contexts_tracer = jacobian_translate(detector, x, contexts...) fc = DI.FixTail(f, contexts_tracer...) return jacobian_sparsity(fc, x, detector) end @@ -47,7 +44,7 @@ function DI.jacobian_sparsity_with_contexts( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - contexts_tracer = jacobian_translate(detector, contexts...) + contexts_tracer = jacobian_translate(detector, x, contexts...) fc! = DI.FixTail(f!, contexts_tracer...) 
return jacobian_sparsity(fc!, y, x, detector) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index dc400b03e..a5d11a07d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -17,7 +17,10 @@ DI.check_available(::AutoZygote) = true DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported() translate(c::DI.Context) = DI.unwrap(c) -translate(c::DI.Cache) = Buffer(DI.unwrap(c)) +translate(c::DI.Cache{<:AbstractArray}) = Buffer(DI.unwrap(c)) +function translate(c::DI.Cache{<:Union{NTuple,NamedTuple}}) + return map(translate, map(DI.Cache, DI.unwrap(c))) +end ## Pullback diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 65edbde0f..4d42e4836 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -23,7 +23,6 @@ Abstract supertype for additional context arguments, which can be passed to diff abstract type Context end abstract type GeneralizedConstant <: Context end -abstract type GeneralizedCache <: Context end unwrap(c::Context) = c.data Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2) @@ -78,7 +77,7 @@ The initial values present inside the cache do not matter. For some backends, preparation allocates the required memory for `Cache` contexts with the right element type, similar to [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl). !!! warning - Most backends require any `Cache` context to be an `AbstractArray`. + Some backends require any `Cache` context to be an `AbstractArray or a (named) tuple of `AbstractArray`s. 
# Example @@ -97,7 +96,7 @@ julia> gradient(f, prep, AutoForwardDiff(), [3.0, 4.0], Cache(zeros(2))) 1.0 ```` """ -struct Cache{T} <: GeneralizedCache +struct Cache{T} <: Context data::T end @@ -114,12 +113,10 @@ struct BackendContext{T} <: GeneralizedConstant data::T end -struct PrepContext{T} <: GeneralizedCache +struct PrepContext{T} <: Context data::T end -struct UnknownContext <: Context end - ## Context manipulation struct Rewrap{C,T} @@ -146,4 +143,4 @@ function with_contexts(f::F, contexts::Vararg{Context,N}) where {F,N} end adapt_eltype(c::Constant, ::Type) = c -adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(similar(unwrap(c), T)) +adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(recursive_similar(unwrap(c), T)) diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index d716f136b..111a15f88 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -7,6 +7,12 @@ using DifferentiationInterface: using SparseMatrixColorings using Test +test_differentiation( + AutoSimpleFiniteDiff(), + default_scenarios(; include_normal=false, include_cachified=true); + logging=true, +) + LOGGING = get(ENV, "CI", "false") == "false" backends = [ # diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index 6f306dc45..a6b37c81c 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -147,7 +147,7 @@ end """ constantify(scen::Scenario) -Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional constant argument `a` by which the output is multiplied. +Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional constant argument by which the output is multiplied. 
The output and result fields are updated accordingly. """ function constantify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} @@ -178,7 +178,8 @@ end Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))") -function (sc::StoreInCache{:out})(x, y_cache) +function (sc::StoreInCache{:out})(x, y_cache_tup) + y_cache = y_cache_tup.cache y = sc.f(x) if y isa Number y_cache[1] = y @@ -189,7 +190,8 @@ function (sc::StoreInCache{:out})(x, y_cache) end end -function (sc::StoreInCache{:in})(y, x, y_cache) +function (sc::StoreInCache{:in})(y, x, y_cache_tup) + y_cache = y_cache_tup.cache sc.f(y_cache, x) copyto!(y, y_cache) return nothing @@ -198,16 +200,18 @@ end """ cachify(scen::Scenario) -Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional cache argument `a` to store the result before it is returned. +Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional cache argument to store the result before it is returned. + +If `tup=true` the cache is a tuple of arrays, otherwise just an array. 
""" function cachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} (; f,) = scen @assert isempty(scen.contexts) cache_f = StoreInCache{pl_fun}(f) y_cache = if scen.y isa Number - [myzero(scen.y)] + (; cache=[myzero(scen.y)]) else - mysimilar(scen.y) + (; cache=mysimilar(scen.y)) end return Scenario{op,pl_op,pl_fun}( cache_f; From d3f5dd0b11f6ae8843f12f1a5410a6153f714aae Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 18 Mar 2025 13:13:39 +0100 Subject: [PATCH 08/14] Enzyme --- .../ext/DifferentiationInterfaceEnzymeExt/utils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 891964579..5575cbb3e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -53,7 +53,9 @@ force_annotation(f::F) where {F} = Const(f) return Const(DI.unwrap(c)) end -@inline function _translate(backend::AutoEnzyme, ::Mode, ::Val{B}, c::DI.Cache) where {B} +@inline function _translate( + backend::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Cache,DI.PrepContext} +) where {B} if B == 1 return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c))) else From 133c04c1a5ea9508a72927c6535b5b8088410b65 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 18 Mar 2025 13:14:48 +0100 Subject: [PATCH 09/14] Remove new tests --- DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index 111a15f88..d716f136b 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ 
-7,12 +7,6 @@ using DifferentiationInterface: using SparseMatrixColorings using Test -test_differentiation( - AutoSimpleFiniteDiff(), - default_scenarios(; include_normal=false, include_cachified=true); - logging=true, -) - LOGGING = get(ENV, "CI", "false") == "false" backends = [ # From 0e5b11e958446a58ca215e54cccb1ea91e95e644 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 18 Mar 2025 13:20:20 +0100 Subject: [PATCH 10/14] SCT fix --- .../DifferentiationInterfaceSparseConnectivityTracerExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl index 6bc11a1bc..a01a804ef 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl @@ -55,7 +55,7 @@ function DI.hessian_sparsity_with_contexts( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - contexts_tracer = hessian_translate(detector, contexts...) + contexts_tracer = hessian_translate(detector, x, contexts...) fc = DI.FixTail(f, contexts_tracer...) 
return hessian_sparsity(fc, x, detector) end From 3d3f5517fdebf3e2a0cb3f4649e9927e619ec7c8 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 18 Mar 2025 13:35:05 +0100 Subject: [PATCH 11/14] Nesting in test scens --- .../DifferentiationInterfaceZygoteExt.jl | 2 +- .../DifferentiationInterfaceTestJLArraysExt.jl | 3 ++- .../DifferentiationInterfaceTestStaticArraysExt.jl | 5 ++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index a5d11a07d..72763eb6a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -18,7 +18,7 @@ DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported() translate(c::DI.Context) = DI.unwrap(c) translate(c::DI.Cache{<:AbstractArray}) = Buffer(DI.unwrap(c)) -function translate(c::DI.Cache{<:Union{NTuple,NamedTuple}}) +function translate(c::DI.Cache{<:Union{Tuple,NamedTuple}}) return map(translate, map(DI.Cache, DI.unwrap(c))) end diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl index 305bee48b..11dca2cc5 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl @@ -18,7 +18,8 @@ myjl(x::Number) = x myjl(x::AbstractArray) = jl(x) myjl(x::Tuple) = map(myjl, x) myjl(x::DI.Constant) = DI.Constant(myjl(DI.unwrap(x))) -myjl(x::DI.Cache) = 
DI.Cache(myjl(DI.unwrap(x))) +myjl(x::DI.Cache{<:AbstractArray}) = DI.Cache(myjl(DI.unwrap(x))) +myjl(x::DI.Cache{<:Union{Tuple,NamedTuple}}) = map(myjl, map(DI.Cache, DI.unwrap(x))) myjl(::Nothing) = nothing function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl index c52849ad7..a7620b516 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl @@ -29,7 +29,10 @@ end mystatic(x::Tuple) = map(mystatic, x) mystatic(x::DI.Constant) = DI.Constant(mystatic(DI.unwrap(x))) -mystatic(x::DI.Cache) = DI.Cache(mymutablestatic(DI.unwrap(x))) +mystatic(x::DI.Cache{<:AbstractArray}) = DI.Cache(mymutablestatic(DI.unwrap(x))) +function mystatic(x::DI.Cache{<:Union{Tuple,NamedTuple}}) + return map(mystatic, map(DI.Cache, DI.unwrap(x))) +end mystatic(::Nothing) = nothing function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} From 7f63ad91eec6c400b7ac5ab261b6b7cfd4addc47 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 18 Mar 2025 15:17:43 +0100 Subject: [PATCH 12/14] More sophisticated testing --- ...ntiationInterfaceFastDifferentiationExt.jl | 5 --- DifferentiationInterface/src/utils/context.jl | 2 +- .../test/Back/Enzyme/test.jl | 2 +- .../test/Back/FiniteDiff/test.jl | 4 ++- .../test/Back/FiniteDifferences/test.jl | 4 ++- .../test/Back/ForwardDiff/test.jl | 5 ++- .../test/Back/Mooncake/test.jl | 4 ++- .../test/Back/PolyesterForwardDiff/test.jl | 4 ++- .../SymbolicBackends/fastdifferentiation.jl | 4 ++- 
.../test/Back/SymbolicBackends/symbolics.jl | 2 +- .../test/Back/Zygote/test.jl | 4 ++- .../test/Core/SimpleFiniteDiff/test.jl | 4 ++- .../src/scenarios/default.jl | 3 +- .../src/scenarios/modify.jl | 31 +++++++++++++------ .../src/scenarios/sparse.jl | 7 +++-- DifferentiationInterfaceTest/test/standard.jl | 2 +- DifferentiationInterfaceTest/test/weird.jl | 7 ++++- 17 files changed, 63 insertions(+), 31 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl index e5e979f82..430cc95f9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl @@ -23,11 +23,6 @@ myvec(x::AbstractArray) = vec(x) variablize(::Number, name::Symbol) = only(make_variables(name)) variablize(x::AbstractArray, name::Symbol) = make_variables(name, size(x)...) -function variablize(x::Union{Tuple,NamedTuple}, name::Symbol) - return map(x) do xi - variablize(xi, gensym()) # TODO: fix symbol? - end -end function variablize(contexts::NTuple{C,DI.Context}) where {C} map(enumerate(contexts)) do (k, c) diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 4d42e4836..201c0ae2e 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -77,7 +77,7 @@ The initial values present inside the cache do not matter. For some backends, preparation allocates the required memory for `Cache` contexts with the right element type, similar to [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl). !!! 
warning - Some backends require any `Cache` context to be an `AbstractArray or a (named) tuple of `AbstractArray`s. + Some backends require any `Cache` context to be an `AbstractArray`, others accept nested (named) tuples of `AbstractArray`s. # Example diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index 4fbb40fa0..2aa2c3268 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -55,7 +55,7 @@ end; test_differentiation( backends[2], - default_scenarios(; include_normal=false, include_cachified=true); + default_scenarios(; include_normal=false, include_cachified=true, use_tuples=true); excluded=SECOND_ORDER, logging=LOGGING, ) diff --git a/DifferentiationInterface/test/Back/FiniteDiff/test.jl b/DifferentiationInterface/test/Back/FiniteDiff/test.jl index aa92743b6..dc111f45f 100644 --- a/DifferentiationInterface/test/Back/FiniteDiff/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDiff/test.jl @@ -22,7 +22,9 @@ end @testset "Dense" begin test_differentiation( AutoFiniteDiff(), - default_scenarios(; include_constantified=true, include_cachified=true); + default_scenarios(; + include_constantified=true, include_cachified=true, use_tuples=true + ); excluded=[:second_derivative, :hvp], logging=LOGGING, ) diff --git a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl index d512ee8d4..c55ee3bd2 100644 --- a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl @@ -19,7 +19,9 @@ end test_differentiation( AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)), - default_scenarios(; include_constantified=true, include_cachified=true); + default_scenarios(; + include_constantified=true, include_cachified=true, use_tuples=true + ); excluded=SECOND_ORDER, logging=LOGGING, ); diff --git 
a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index f4024afbd..0b9ff0d50 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -36,7 +36,10 @@ end test_differentiation( AutoForwardDiff(), default_scenarios(; - include_normal=false, include_batchified=false, include_cachified=true + include_normal=false, + include_batchified=false, + include_cachified=true, + use_tuples=true, ); logging=LOGGING, ) diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 1bc485cdf..8c9ab839a 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -19,7 +19,9 @@ end test_differentiation( backends, - default_scenarios(; include_constantified=true, include_cachified=true); + default_scenarios(; + include_constantified=true, include_cachified=true, use_tuples=true + ); excluded=SECOND_ORDER, logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl index 4f38af5b1..34b59d46a 100644 --- a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl @@ -28,7 +28,9 @@ end test_differentiation( backends, - default_scenarios(; include_constantified=true, include_cachified=true); + default_scenarios(; + include_constantified=true, include_cachified=true, use_tuples=true + ); logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl b/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl index c8efbdaa8..db6d2215b 100644 --- a/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl +++ 
b/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl @@ -17,7 +17,9 @@ end test_differentiation( AutoFastDifferentiation(), - default_scenarios(; include_constantified=true, include_cachified=true); + default_scenarios(; + include_constantified=true, include_cachified=true, use_tuples=false + ); logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl b/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl index 31f8316a0..91625b700 100644 --- a/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl +++ b/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl @@ -21,7 +21,7 @@ test_differentiation( test_differentiation( AutoSymbolics(), - default_scenarios(; include_normal=false, include_cachified=true); + default_scenarios(; include_normal=false, include_cachified=true, use_tuples=false); excluded=[:jacobian], # TODO: figure out why this fails logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/Zygote/test.jl b/DifferentiationInterface/test/Back/Zygote/test.jl index 6b30e924b..882777e20 100644 --- a/DifferentiationInterface/test/Back/Zygote/test.jl +++ b/DifferentiationInterface/test/Back/Zygote/test.jl @@ -27,7 +27,9 @@ end @testset "Dense" begin test_differentiation( backends, - default_scenarios(; include_constantified=true, include_cachified=true); + default_scenarios(; + include_constantified=true, include_cachified=true, use_tuples=true + ); excluded=[:second_derivative], logging=LOGGING, ) diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index d716f136b..34d93c16a 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -86,7 +86,9 @@ end MyAutoSparse.( vcat(adaptive_backends, MixedMode(adaptive_backends[1], adaptive_backends[2])) ), - sparse_scenarios(; 
include_constantified=true, include_cachified=true); + sparse_scenarios(; + include_constantified=true, include_cachified=true, use_tuples=true + ); sparsity=true, logging=LOGGING, ) diff --git a/DifferentiationInterfaceTest/src/scenarios/default.jl b/DifferentiationInterfaceTest/src/scenarios/default.jl index ef9767df4..7a44f7390 100644 --- a/DifferentiationInterfaceTest/src/scenarios/default.jl +++ b/DifferentiationInterfaceTest/src/scenarios/default.jl @@ -559,6 +559,7 @@ function default_scenarios(; include_closurified=false, include_constantified=false, include_cachified=false, + use_tuples=false, ) x_ = 0.42 dx_ = 3.14 @@ -635,7 +636,7 @@ function default_scenarios(; include_normal && append!(final_scens, scens) include_closurified && append!(final_scens, closurify(scens)) include_constantified && append!(final_scens, constantify(scens)) - include_cachified && append!(final_scens, cachify(scens)) + include_cachified && append!(final_scens, cachify(scens; use_tuples=use_tuples)) return final_scens end diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index a6b37c81c..0e35771e7 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -178,8 +178,12 @@ end Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))") -function (sc::StoreInCache{:out})(x, y_cache_tup) - y_cache = y_cache_tup.cache +(sc::StoreInCache{:out})(x, y_cache::NamedTuple) = sc(x, y_cache.useful_cache) +(sc::StoreInCache{:in})(y, x, y_cache::NamedTuple) = sc(y, x, y_cache.useful_cache) +(sc::StoreInCache{:out})(x, y_cache::Tuple) = sc(x, first(y_cache)) +(sc::StoreInCache{:in})(y, x, y_cache::Tuple) = sc(y, x, first(y_cache)) + +function (sc::StoreInCache{:out})(x, y_cache::AbstractArray) y = sc.f(x) if y isa Number y_cache[1] = y @@ -190,8 +194,7 @@ function (sc::StoreInCache{:out})(x, y_cache_tup) end end -function 
(sc::StoreInCache{:in})(y, x, y_cache_tup) - y_cache = y_cache_tup.cache +function (sc::StoreInCache{:in})(y, x, y_cache::AbstractArray) sc.f(y_cache, x) copyto!(y, y_cache) return nothing @@ -204,14 +207,22 @@ Return a new `Scenario` identical to `scen` except for the function `f`, which i If `tup=true` the cache is a tuple of arrays, otherwise just an array. """ -function cachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} +function cachify(scen::Scenario{op,pl_op,pl_fun}; use_tuples) where {op,pl_op,pl_fun} (; f,) = scen @assert isempty(scen.contexts) cache_f = StoreInCache{pl_fun}(f) - y_cache = if scen.y isa Number - (; cache=[myzero(scen.y)]) + if use_tuples + y_cache = if scen.y isa Number + (; useful_cache=[myzero(scen.y)], useless_cache=([myzero(scen.y)],)) + else + (; useful_cache=mysimilar(scen.y), useless_cache=(mysimilar(scen.y),)) + end else - (; cache=mysimilar(scen.y)) + y_cache = if scen.y isa Number + [myzero(scen.y)] + else + mysimilar(scen.y) + end end return Scenario{op,pl_op,pl_fun}( cache_f; @@ -221,7 +232,7 @@ function cachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} contexts=(Cache(y_cache),), res1=scen.res1, res2=scen.res2, - smaller=isnothing(scen.smaller) ? nothing : cachify(scen.smaller), + smaller=isnothing(scen.smaller) ? nothing : cachify(scen.smaller; use_tuples), name=isnothing(scen.name) ? 
nothing : scen.name * " [cachified]", ) end @@ -233,7 +244,7 @@ end closurify(scens::AbstractVector{<:Scenario}) = closurify.(scens) constantify(scens::AbstractVector{<:Scenario}) = constantify.(scens) -cachify(scens::AbstractVector{<:Scenario}) = cachify.(scens) +cachify(scens::AbstractVector{<:Scenario}; use_tuples) = cachify.(scens; use_tuples) function set_smaller( scen::Scenario{op,pl_op,pl_fun}, smaller::Scenario diff --git a/DifferentiationInterfaceTest/src/scenarios/sparse.jl b/DifferentiationInterfaceTest/src/scenarios/sparse.jl index 5f1cad7fa..99ec02140 100644 --- a/DifferentiationInterfaceTest/src/scenarios/sparse.jl +++ b/DifferentiationInterfaceTest/src/scenarios/sparse.jl @@ -325,7 +325,10 @@ end Create a vector of [`Scenario`](@ref)s with sparse array types, focused on sparse Jacobians and Hessians. """ function sparse_scenarios(; - band_sizes=[5, 10, 20], include_constantified=false, include_cachified=false + band_sizes=[5, 10, 20], + include_constantified=false, + include_cachified=false, + use_tuples=false, ) x_6 = float.(1:6) x_2_3 = float.(reshape(1:6, 2, 3)) @@ -347,6 +350,6 @@ function sparse_scenarios(; final_scens = Scenario[] append!(final_scens, scens) include_constantified && append!(final_scens, constantify(scens)) - include_cachified && append!(final_scens, cachify(scens)) + include_cachified && append!(final_scens, cachify(scens; use_tuples)) return final_scens end diff --git a/DifferentiationInterfaceTest/test/standard.jl b/DifferentiationInterfaceTest/test/standard.jl index f06f3cd75..5a1705e00 100644 --- a/DifferentiationInterfaceTest/test/standard.jl +++ b/DifferentiationInterfaceTest/test/standard.jl @@ -33,7 +33,7 @@ sparse_backend = AutoSparse( test_differentiation( sparse_backend, - sparse_scenarios(; include_constantified=true); + sparse_scenarios(; include_cachified=true); sparsity=true, logging=LOGGING, ) diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index 
b2b27a8df..1cb83a7d0 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -28,12 +28,14 @@ gpu_scenarios(; include_closurified=true, include_batchified=true, include_cachified=true, + use_tuples=true, ) static_scenarios(; include_constantified=true, include_closurified=true, include_batchified=true, include_cachified=true, + use_tuples=false, ) ## Weird arrays @@ -54,7 +56,10 @@ test_differentiation(AutoZygote(), gpu_scenarios(); excluded=SECOND_ORDER, loggi test_differentiation( AutoFiniteDiff(), default_scenarios(; - include_normal=false, include_closurified=true, include_cachified=true + include_normal=false, + include_closurified=true, + include_cachified=true, + use_tuples=true, ); excluded=SECOND_ORDER, logging=LOGGING, From 8a0c9fdaa384e8d8bbc2416563966eed23466fcd Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 18 Mar 2025 16:05:42 +0100 Subject: [PATCH 13/14] Fix --- .github/workflows/Test.yml | 2 +- DifferentiationInterface/Project.toml | 2 +- DifferentiationInterfaceTest/src/scenarios/modify.jl | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 13de31d96..f91ce0853 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -25,7 +25,7 @@ jobs: actions: write contents: read strategy: - fail-fast: false # TODO: toggle + fail-fast: true # TODO: toggle matrix: version: - "1.10" diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index e29a979c0..1a1f819dd 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.45" +version = "0.6.46" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git 
a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index 0e35771e7..dad79603d 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -183,7 +183,7 @@ Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))") (sc::StoreInCache{:out})(x, y_cache::Tuple) = sc(x, first(y_cache)) (sc::StoreInCache{:in})(y, x, y_cache::Tuple) = sc(y, x, first(y_cache)) -function (sc::StoreInCache{:out})(x, y_cache::AbstractArray) +function (sc::StoreInCache{:out})(x, y_cache) y = sc.f(x) if y isa Number y_cache[1] = y @@ -194,7 +194,7 @@ function (sc::StoreInCache{:out})(x, y_cache::AbstractArray) end end -function (sc::StoreInCache{:in})(y, x, y_cache::AbstractArray) +function (sc::StoreInCache{:in})(y, x, y_cache) sc.f(y_cache, x) copyto!(y, y_cache) return nothing From 5a31a6a48b12f7bd6edbfbbd7fcd64844059a898 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 18 Mar 2025 18:01:54 +0100 Subject: [PATCH 14/14] Coverage --- DifferentiationInterfaceTest/src/scenarios/modify.jl | 4 ++-- DifferentiationInterfaceTest/test/standard.jl | 2 +- DifferentiationInterfaceTest/test/weird.jl | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index dad79603d..e991885f8 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -213,9 +213,9 @@ function cachify(scen::Scenario{op,pl_op,pl_fun}; use_tuples) where {op,pl_op,pl cache_f = StoreInCache{pl_fun}(f) if use_tuples y_cache = if scen.y isa Number - (; useful_cache=[myzero(scen.y)], useless_cache=([myzero(scen.y)],)) + (; useful_cache=([myzero(scen.y)],), useless_cache=[myzero(scen.y)]) else - (; useful_cache=mysimilar(scen.y), useless_cache=(mysimilar(scen.y),)) + 
(; useful_cache=(mysimilar(scen.y),), useless_cache=mysimilar(scen.y)) end else y_cache = if scen.y isa Number diff --git a/DifferentiationInterfaceTest/test/standard.jl b/DifferentiationInterfaceTest/test/standard.jl index 5a1705e00..5324dd580 100644 --- a/DifferentiationInterfaceTest/test/standard.jl +++ b/DifferentiationInterfaceTest/test/standard.jl @@ -33,7 +33,7 @@ sparse_backend = AutoSparse( test_differentiation( sparse_backend, - sparse_scenarios(; include_cachified=true); + sparse_scenarios(; include_cachified=true, use_tuples=true); sparsity=true, logging=LOGGING, ) diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index 1cb83a7d0..f5e88e1ff 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -35,7 +35,7 @@ static_scenarios(; include_closurified=true, include_batchified=true, include_cachified=true, - use_tuples=false, + use_tuples=true, ) ## Weird arrays