diff --git a/DifferentiationInterface/docs/src/dev_guide.md b/DifferentiationInterface/docs/src/dev_guide.md index 10e74a905..3fa1b66b9 100644 --- a/DifferentiationInterface/docs/src/dev_guide.md +++ b/DifferentiationInterface/docs/src/dev_guide.md @@ -23,10 +23,10 @@ Most operators have 4 variants, which look like this in the first order: `operat To implement a new operator for an existing backend, you need to write 5 methods: 1 for [preparation](@ref Preparation) and 4 corresponding to the variants of the operator (see above). For first-order operators, you may also want to support [in-place functions](@ref "Mutation and signatures"), which requires another 5 methods (defined on `f!` instead of `f`). -The method `prepare_operator` must output a `prep` object of the correct type. -For instance, `prepare_gradient(f, backend, x)` must return a [`DifferentiationInterface.GradientPrep`](@ref). -Assuming you don't need any preparation for said operator, you can use the trivial prep that are already defined, like `DifferentiationInterface.NoGradientPrep`. -Otherwise, define a custom struct like `MyGradientPrep <: DifferentiationInterface.GradientPrep` and put the necessary storage in there. +The method `prepare_operator_nokwarg` must output a `prep` object of the correct type. +For instance, `prepare_gradient_nokwarg(strict, f, backend, x)` must return a [`DifferentiationInterface.GradientPrep`](@ref). +Assuming you don't need any preparation for said operator, you can use the trivial preps that are already defined, like `DifferentiationInterface.NoGradientPrep{SIG}`. +Otherwise, define a custom struct like `MyGradientPrep{SIG} <: DifferentiationInterface.GradientPrep{SIG}` and put the necessary storage in there. ## New backend @@ -75,4 +75,4 @@ GROUP = get(ENV, "JULIA_DI_TEST_GROUP", "Back/SuperDiff") but don't forget to switch it back before pushing. 
-Finally, you need to add your backend to the documentation, modifying every page that involves a list of backends (including the `README.md`). \ No newline at end of file +Finally, you need to add your backend to the documentation, modifying every page that involves a list of backends (including the `README.md`). diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl index be85d1f24..23f9c9b0c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl @@ -1,7 +1,7 @@ function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x) (; f, backend) = dw y = f(x) - prep_same = DI.prepare_pullback_same_point(Val(true), f, backend, x, (y,)) + prep_same = DI.prepare_pullback_same_point_nokwarg(Val(true), f, backend, x, (y,)) function pullbackfunc(dy) tx = DI.pullback(f, prep_same, backend, x, (dy,)) return (NoTangent(), only(tx)) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index 7c68b20e3..07f74edbb 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -6,7 +6,7 @@ struct ChainRulesPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} pb::PB end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f, backend::AutoReverseChainRules, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl index 
baab8a75f..bd286dc34 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl @@ -10,7 +10,9 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow() ## Pushforward -function DI.prepare_pushforward(strict::Val, f, backend::AutoDiffractor, x, tx::NTuple) +function DI.prepare_pushforward_nokwarg( + strict::Val, f, backend::AutoDiffractor, x, tx::NTuple +) _sig = DI.signature(f, backend, x, tx; strict) return DI.NoPushforwardPrep(_sig) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 96b2fff76..9ce785d99 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -1,6 +1,6 @@ ## Pushforward -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, @@ -122,7 +122,7 @@ struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG} shadows::O end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, @@ -203,7 +203,7 @@ struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG} output_length::Int end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index 33e5ce1aa..4d8328c3e 100644 --- 
a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -1,6 +1,6 @@ ## Pushforward -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f!::F, y, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 205a86837..b0a52fb92 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -52,7 +52,7 @@ struct EnzymeReverseOneArgPullbackPrep{SIG,Y} <: DI.PullbackPrep{SIG} y_example::Y # useful to create return activity end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f::F, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, @@ -191,7 +191,7 @@ end ## Gradient -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f::F, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index 30562fb7f..ae2b33923 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -5,7 +5,7 @@ struct EnzymeReverseTwoArgPullbackPrep{SIG,TY} <: DI.PullbackPrep{SIG} ty_copy::TY end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f!::F, y, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index 0ebfdb7dc..dfbbf3549 100644 --- 
a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -7,7 +7,7 @@ struct FastDifferentiationOneArgPushforwardPrep{SIG,Y,E1,E1!} <: DI.PushforwardP jvp_exe!::E1! end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f, backend::AutoFastDifferentiation, @@ -105,7 +105,7 @@ struct FastDifferentiationOneArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG} vjp_exe!::E1! end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f, backend::AutoFastDifferentiation, @@ -204,7 +204,7 @@ struct FastDifferentiationOneArgDerivativePrep{SIG,Y,E1,E1!} <: DI.DerivativePre der_exe!::E1! end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -284,7 +284,7 @@ struct FastDifferentiationOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG} jac_exe!::E1! end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -360,7 +360,7 @@ struct FastDifferentiationOneArgJacobianPrep{SIG,Y,E1,E1!} <: DI.JacobianPrep{SI jac_exe!::E1! end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, @@ -445,7 +445,7 @@ struct FastDifferentiationAllocatingSecondDerivativePrep{SIG,Y,D,E2,E2!} <: der2_exe!::E2! 
end -function DI.prepare_second_derivative( +function DI.prepare_second_derivative_nokwarg( strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -462,7 +462,7 @@ function DI.prepare_second_derivative( der2_exe = make_function(der2_vec_var, x_vec_var, context_vec_vars...; in_place=false) der2_exe! = make_function(der2_vec_var, x_vec_var, context_vec_vars...; in_place=true) - derivative_prep = DI.prepare_derivative(f, backend, x, contexts...) + derivative_prep = DI.prepare_derivative_nokwarg(strict, f, backend, x, contexts...) return FastDifferentiationAllocatingSecondDerivativePrep( _sig, y_prototype, derivative_prep, der2_exe, der2_exe! ) @@ -534,7 +534,7 @@ struct FastDifferentiationHVPPrep{SIG,E2,E2!,E1} <: DI.HVPPrep{SIG} gradient_prep::E1 end -function DI.prepare_hvp( +function DI.prepare_hvp_nokwarg( strict::Val, f, backend::AutoFastDifferentiation, @@ -557,7 +557,7 @@ function DI.prepare_hvp( hv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true ) - gradient_prep = DI.prepare_gradient(f, backend, x, contexts...) + gradient_prep = DI.prepare_gradient_nokwarg(strict, f, backend, x, contexts...) return FastDifferentiationHVPPrep(_sig, hvp_exe, hvp_exe!, gradient_prep) end @@ -633,7 +633,7 @@ struct FastDifferentiationHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG} hess_exe!::E2! end -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, @@ -656,7 +656,9 @@ function DI.prepare_hessian( hess_exe = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=false) hess_exe! = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=true) - gradient_prep = DI.prepare_gradient(f, dense_ad(backend), x, contexts...) + gradient_prep = DI.prepare_gradient_nokwarg( + strict, f, dense_ad(backend), x, contexts... 
+ ) return FastDifferentiationHessianPrep(_sig, gradient_prep, hess_exe, hess_exe!) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl index f67ed4324..fb4b560f8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl @@ -6,7 +6,7 @@ struct FastDifferentiationTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPre jvp_exe!::E1! end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f!, y, @@ -107,7 +107,7 @@ struct FastDifferentiationTwoArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG} vjp_exe!::E1! end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f!, y, @@ -213,7 +213,7 @@ struct FastDifferentiationTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{ der_exe!::E1! end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f!, y, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) @@ -295,7 +295,7 @@ struct FastDifferentiationTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} jac_exe!::E1! 
end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!, y, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl index 7ebdb6e49..fa3d73aa2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl @@ -8,7 +8,7 @@ struct FiniteDiffOneArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} dir::D end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) @@ -124,7 +124,7 @@ struct FiniteDiffOneArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} dir::D end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -253,7 +253,7 @@ struct FiniteDiffGradientPrep{SIG,C,R,A,D} <: DI.GradientPrep{SIG} dir::D end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -341,7 +341,7 @@ struct FiniteDiffOneArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} dir::D end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -446,7 +446,7 @@ struct FiniteDiffHessianPrep{SIG,C1,C2,RG,AG,RH,AH} <: DI.HessianPrep{SIG} absstep_h::AH end -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::AutoFiniteDiff, x, 
contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl index 259ebbdd3..11bbcfbb9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl @@ -8,7 +8,7 @@ struct FiniteDiffTwoArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} dir::D end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f!, y, @@ -161,7 +161,7 @@ struct FiniteDiffTwoArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} dir::D end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) @@ -198,7 +198,9 @@ function DI.prepare!_derivative( cache.c3 isa Union{Number,Nothing} || resize!(cache.c3, length(y)) return old_prep else - return DI.prepare_derivative(DI.is_strict(old_prep), f!, y, backend, x, contexts...) + return DI.prepare_derivative_nokwarg( + DI.is_strict(old_prep), f!, y, backend, x, contexts... + ) end end @@ -277,7 +279,7 @@ struct FiniteDiffTwoArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} dir::D end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) @@ -318,7 +320,9 @@ function DI.prepare!_jacobian( cache.sparsity = nothing return old_prep else - return DI.prepare_jacobian(DI.is_strict(old_prep), f!, y, backend, x, contexts...) + return DI.prepare_jacobian_nokwarg( + DI.is_strict(old_prep), f!, y, backend, x, contexts... 
+ ) end end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl index 31bcd5961..68bd2918f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl @@ -11,7 +11,7 @@ DI.inner_preparation_behavior(::AutoFiniteDifferences) = DI.PrepareInnerSimple() ## Pushforward -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f, backend::AutoFiniteDifferences, @@ -54,7 +54,7 @@ end ## Pullback -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f, backend::AutoFiniteDifferences, @@ -97,7 +97,7 @@ end ## Gradient -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -154,7 +154,7 @@ end ## Jacobian -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 703fe6be6..30f769417 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -67,7 +67,7 @@ struct ForwardDiffOneArgPushforwardPrep{SIG,T,X,CD} <: DI.PushforwardPrep{SIG} contexts_dual::CD end -function DI.prepare_pushforward( +function 
DI.prepare_pushforward_nokwarg( strict::Val, f::F, backend::AutoForwardDiff, @@ -215,11 +215,13 @@ end ### Prepared -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) - pushforward_prep = DI.prepare_pushforward(strict, f, backend, x, (one(x),), contexts...) + pushforward_prep = DI.prepare_pushforward_nokwarg( + strict, f, backend, x, (one(x),), contexts... + ) return ForwardDiffOneArgDerivativePrep(_sig, pushforward_prep) end @@ -297,7 +299,7 @@ function DI.value_and_gradient!( grad === DR.gradient(result) || copyto!(grad, DR.gradient(result)) return y, grad else - prep = DI.prepare_gradient(f, backend, x, contexts...) + prep = DI.prepare_gradient_nokwarg(Val(true), f, backend, x, contexts...) return DI.value_and_gradient!(f, grad, prep, backend, x, contexts...) end end @@ -315,7 +317,7 @@ function DI.value_and_gradient( result = gradient!(result, fc, x) return DR.value(result), DR.gradient(result) else - prep = DI.prepare_gradient(f, backend, x, contexts...) + prep = DI.prepare_gradient_nokwarg(Val(true), f, backend, x, contexts...) return DI.value_and_gradient(f, prep, backend, x, contexts...) end end @@ -331,7 +333,7 @@ function DI.gradient!( fc = DI.with_contexts(f, contexts...) return gradient!(grad, fc, x) else - prep = DI.prepare_gradient(f, backend, x, contexts...) + prep = DI.prepare_gradient_nokwarg(Val(true), f, backend, x, contexts...) return DI.gradient!(f, grad, prep, backend, x, contexts...) end end @@ -347,7 +349,7 @@ function DI.gradient( fc = DI.with_contexts(f, contexts...) return gradient(fc, x) else - prep = DI.prepare_gradient(f, backend, x, contexts...) + prep = DI.prepare_gradient_nokwarg(Val(true), f, backend, x, contexts...) return DI.gradient(f, prep, backend, x, contexts...) 
end end @@ -360,7 +362,7 @@ struct ForwardDiffGradientPrep{SIG,C,CD} <: DI.GradientPrep{SIG} contexts_dual::CD end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f::F, backend::AutoForwardDiff, @@ -471,7 +473,7 @@ function DI.value_and_jacobian!( jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result)) return y, jac else - prep = DI.prepare_jacobian(f, backend, x, contexts...) + prep = DI.prepare_jacobian_nokwarg(Val(true), f, backend, x, contexts...) return DI.value_and_jacobian!(f, jac, prep, backend, x, contexts...) end end @@ -487,7 +489,7 @@ function DI.value_and_jacobian( fc = DI.with_contexts(f, contexts...) return fc(x), jacobian(fc, x) else - prep = DI.prepare_jacobian(f, backend, x, contexts...) + prep = DI.prepare_jacobian_nokwarg(Val(true), f, backend, x, contexts...) return DI.value_and_jacobian(f, prep, backend, x, contexts...) end end @@ -503,7 +505,7 @@ function DI.jacobian!( fc = DI.with_contexts(f, contexts...) return jacobian!(jac, fc, x) else - prep = DI.prepare_jacobian(f, backend, x, contexts...) + prep = DI.prepare_jacobian_nokwarg(Val(true), f, backend, x, contexts...) return DI.jacobian!(f, jac, prep, backend, x, contexts...) end end @@ -519,7 +521,7 @@ function DI.jacobian( fc = DI.with_contexts(f, contexts...) return jacobian(fc, x) else - prep = DI.prepare_jacobian(f, backend, x, contexts...) + prep = DI.prepare_jacobian_nokwarg(Val(true), f, backend, x, contexts...) return DI.jacobian(f, prep, backend, x, contexts...) 
end end @@ -532,7 +534,7 @@ struct ForwardDiffOneArgJacobianPrep{SIG,C,CD} <: DI.JacobianPrep{SIG} contexts_dual::CD end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -620,7 +622,7 @@ end ## Second derivative -function DI.prepare_second_derivative( +function DI.prepare_second_derivative_nokwarg( strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -719,7 +721,7 @@ function DI.hessian!( fc = DI.with_contexts(f, contexts...) return hessian!(hess, fc, x) else - prep = DI.prepare_hessian(f, backend, x, contexts...) + prep = DI.prepare_hessian_nokwarg(Val(true), f, backend, x, contexts...) return DI.hessian!(f, hess, prep, backend, x, contexts...) end end @@ -735,7 +737,7 @@ function DI.hessian( fc = DI.with_contexts(f, contexts...) return hessian(fc, x) else - prep = DI.prepare_hessian(f, backend, x, contexts...) + prep = DI.prepare_hessian_nokwarg(Val(true), f, backend, x, contexts...) return DI.hessian(f, prep, backend, x, contexts...) end end @@ -761,7 +763,7 @@ function DI.value_gradient_and_hessian!( hess === DR.hessian(result) || copyto!(hess, DR.hessian(result)) return (y, grad, hess) else - prep = DI.prepare_hessian(f, backend, x, contexts...) + prep = DI.prepare_hessian_nokwarg(Val(true), f, backend, x, contexts...) return DI.value_gradient_and_hessian!(f, grad, hess, prep, backend, x, contexts...) end end @@ -779,7 +781,7 @@ function DI.value_gradient_and_hessian( result = hessian!(result, fc, x) return (DR.value(result), DR.gradient(result), DR.hessian(result)) else - prep = DI.prepare_hessian(f, backend, x, contexts...) + prep = DI.prepare_hessian_nokwarg(Val(true), f, backend, x, contexts...) return DI.value_gradient_and_hessian(f, prep, backend, x, contexts...) 
end end @@ -793,7 +795,7 @@ struct ForwardDiffHessianPrep{SIG,C1,C2,CD} <: DI.HessianPrep{SIG} contexts_dual::CD end -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index a919d8965..5073838b4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -8,7 +8,7 @@ struct ForwardDiffTwoArgPushforwardPrep{SIG,T,X,Y,CD} <: DI.PushforwardPrep{SIG} contexts_dual::CD end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f!::F, y, @@ -137,7 +137,7 @@ function DI.value_and_derivative( result = derivative!(result, fc!, y, x) return DiffResults.value(result), DiffResults.derivative(result) else - prep = DI.prepare_derivative(f!, y, backend, x, contexts...) + prep = DI.prepare_derivative_nokwarg(Val(true), f!, y, backend, x, contexts...) return DI.value_and_derivative(f!, y, prep, backend, x, contexts...) end end @@ -151,7 +151,7 @@ function DI.value_and_derivative!( result = derivative!(result, fc!, y, x) return DiffResults.value(result), DiffResults.derivative(result) else - prep = DI.prepare_derivative(f!, y, backend, x, contexts...) + prep = DI.prepare_derivative_nokwarg(Val(true), f!, y, backend, x, contexts...) return DI.value_and_derivative!(f!, y, der, prep, backend, x, contexts...) end end @@ -163,7 +163,7 @@ function DI.derivative( fc! = DI.with_contexts(f!, contexts...) return derivative(fc!, y, x) else - prep = DI.prepare_derivative(f!, y, backend, x, contexts...) + prep = DI.prepare_derivative_nokwarg(Val(true), f!, y, backend, x, contexts...) 
return DI.derivative(f!, y, prep, backend, x, contexts...) end end @@ -175,7 +175,7 @@ function DI.derivative!( fc! = DI.with_contexts(f!, contexts...) return derivative!(der, fc!, y, x) else - prep = DI.prepare_derivative(f!, y, backend, x, contexts...) + prep = DI.prepare_derivative_nokwarg(Val(true), f!, y, backend, x, contexts...) return DI.derivative!(f!, y, der, prep, backend, x, contexts...) end end @@ -188,7 +188,7 @@ struct ForwardDiffTwoArgDerivativePrep{SIG,C,CD} <: DI.DerivativePrep{SIG} contexts_dual::CD end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) @@ -212,7 +212,9 @@ function DI.prepare!_derivative( resize!(config.duals, length(y)) return old_prep else - return DI.prepare_derivative(DI.is_strict(old_prep), f!, y, backend, x, contexts...) + return DI.prepare_derivative_nokwarg( + DI.is_strict(old_prep), f!, y, backend, x, contexts... + ) end end @@ -312,7 +314,7 @@ function DI.value_and_jacobian( result = jacobian!(result, fc!, y, x) return DiffResults.value(result), DiffResults.jacobian(result) else - prep = DI.prepare_jacobian(f!, y, backend, x, contexts...) + prep = DI.prepare_jacobian_nokwarg(Val(true), f!, y, backend, x, contexts...) return DI.value_and_jacobian(f!, y, prep, backend, x, contexts...) end end @@ -330,7 +332,7 @@ function DI.value_and_jacobian!( result = jacobian!(result, fc!, y, x) return DiffResults.value(result), DiffResults.jacobian(result) else - prep = DI.prepare_jacobian(f!, y, backend, x, contexts...) + prep = DI.prepare_jacobian_nokwarg(Val(true), f!, y, backend, x, contexts...) return DI.value_and_jacobian!(f!, y, jac, prep, backend, x, contexts...) end end @@ -346,7 +348,7 @@ function DI.jacobian( fc! = DI.with_contexts(f!, contexts...) return jacobian(fc!, y, x) else - prep = DI.prepare_jacobian(f!, y, backend, x, contexts...) 
+ prep = DI.prepare_jacobian_nokwarg(Val(true), f!, y, backend, x, contexts...) return DI.jacobian(f!, y, prep, backend, x, contexts...) end end @@ -362,7 +364,7 @@ function DI.jacobian!( fc! = DI.with_contexts(f!, contexts...) return jacobian!(jac, fc!, y, x) else - prep = DI.prepare_jacobian(f!, y, backend, x, contexts...) + prep = DI.prepare_jacobian_nokwarg(Val(true), f!, y, backend, x, contexts...) return DI.jacobian!(f!, y, jac, prep, backend, x, contexts...) end end @@ -375,7 +377,7 @@ struct ForwardDiffTwoArgJacobianPrep{SIG,C,CD} <: DI.JacobianPrep{SIG} contexts_dual::CD end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {F,C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) @@ -402,7 +404,9 @@ function DI.prepare!_jacobian( resize!(xduals, length(x)) return old_prep else - return DI.prepare_jacobian(DI.is_strict(old_prep), f!, y, backend, x, contexts...) + return DI.prepare_jacobian_nokwarg( + DI.is_strict(old_prep), f!, y, backend, x, contexts... 
+ ) end end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl index 83be97ed1..514550d26 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl @@ -7,7 +7,7 @@ struct GTPSAOneArgPushforwardPrep{SIG,X} <: DI.PushforwardPrep{SIG} xt::X end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f::F, backend::AutoGTPSA{D}, @@ -115,7 +115,7 @@ struct GTPSAOneArgGradientPrep{SIG,X} <: DI.GradientPrep{SIG} end # Unlike JVP, this requires us to use all variables -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -198,7 +198,7 @@ struct GTPSAOneArgJacobianPrep{SIG,X} <: DI.JacobianPrep{SIG} end # To materialize the entire Jacobian we use all variables -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -284,7 +284,7 @@ struct GTPSAOneArgSecondDerivativePrep{SIG,X} <: DI.SecondDerivativePrep{SIG} xt::X end -function DI.prepare_second_derivative( +function DI.prepare_second_derivative_nokwarg( strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -414,7 +414,7 @@ struct GTPSAOneArgHessianPrep{SIG,X,M} <: DI.HessianPrep{SIG} m::M end -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -549,11 +549,11 @@ struct GTPSAOneArgHVPPrep{SIG,E,H} <: 
DI.HVPPrep{SIG} hess::H end -function DI.prepare_hvp( +function DI.prepare_hvp_nokwarg( strict::Val, f, backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant,C} ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) - hessprep = DI.prepare_hessian(strict, f, backend, x, contexts...) + hessprep = DI.prepare_hessian_nokwarg(strict, f, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) hess = similar(x, typeof(fc(x)), (length(x), length(x))) return GTPSAOneArgHVPPrep(_sig, hessprep, hess) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl index e77c4900d..b106923b5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl @@ -10,7 +10,7 @@ struct GTPSATwoArgPushforwardPrep{SIG,X,Y} <: DI.PushforwardPrep{SIG} yt::Y end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f!::F, y, @@ -125,7 +125,7 @@ struct GTPSATwoArgJacobianPrep{SIG,X,Y} <: DI.JacobianPrep{SIG} yt::Y end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!, y, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index e1e0c580b..2e46bc7c4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -6,7 +6,7 @@ struct MooncakeOneArgPullbackPrep{SIG,Tcache,DY} <: DI.PullbackPrep{SIG} dy_righttype::DY end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f::F, backend::AutoMooncake, x, ty::NTuple, 
contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) @@ -106,7 +106,7 @@ struct MooncakeGradientPrep{SIG,Tcache} <: DI.GradientPrep{SIG} cache::Tcache end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f::F, backend::AutoMooncake, x, contexts::Vararg{DI.Context,C} ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 43bed9857..e89fbc37e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -5,7 +5,7 @@ struct MooncakeTwoArgPullbackPrep{SIG,Tcache,DY,F} <: DI.PullbackPrep{SIG} target_function::F end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f!::F, y, diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index 2e9380fd4..a646488c5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -6,7 +6,7 @@ struct PolyesterForwardDiffOneArgPushforwardPrep{SIG,P} <: DI.PushforwardPrep{SI single_threaded_prep::P end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f, backend::AutoPolyesterForwardDiff, @@ -15,7 +15,7 @@ function DI.prepare_pushforward( contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) - single_threaded_prep = DI.prepare_pushforward( + single_threaded_prep = DI.prepare_pushforward_nokwarg( strict, f, single_threaded(backend), x, tx, contexts... 
) return PolyesterForwardDiffOneArgPushforwardPrep(_sig, single_threaded_prep) @@ -86,11 +86,11 @@ struct PolyesterForwardDiffOneArgDerivativePrep{SIG,P} <: DI.DerivativePrep{SIG} single_threaded_prep::P end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) - single_threaded_prep = DI.prepare_derivative( + single_threaded_prep = DI.prepare_derivative_nokwarg( strict, f, single_threaded(backend), x, contexts... ) return PolyesterForwardDiffOneArgDerivativePrep(_sig, single_threaded_prep) @@ -158,7 +158,7 @@ struct PolyesterForwardDiffGradientPrep{SIG,chunksize,P} <: DI.GradientPrep{SIG} single_threaded_prep::P end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoPolyesterForwardDiff{chunksize}, @@ -171,7 +171,7 @@ function DI.prepare_gradient( else chunk = Chunk{chunksize}() end - single_threaded_prep = DI.prepare_gradient( + single_threaded_prep = DI.prepare_gradient_nokwarg( strict, f, single_threaded(backend), x, contexts... ) return PolyesterForwardDiffGradientPrep(_sig, chunk, single_threaded_prep) @@ -249,7 +249,7 @@ struct PolyesterForwardDiffOneArgJacobianPrep{SIG,chunksize,P} <: DI.JacobianPre single_threaded_prep::P end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::AutoPolyesterForwardDiff{chunksize}, @@ -262,7 +262,7 @@ function DI.prepare_jacobian( else chunk = Chunk{chunksize}() end - single_threaded_prep = DI.prepare_jacobian( + single_threaded_prep = DI.prepare_jacobian_nokwarg( strict, f, single_threaded(backend), x, contexts... 
) return PolyesterForwardDiffOneArgJacobianPrep(_sig, chunk, single_threaded_prep) @@ -339,11 +339,11 @@ struct PolyesterForwardDiffHessianPrep{SIG,P} <: DI.HessianPrep{SIG} single_threaded_prep::P end -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) - single_threaded_prep = DI.prepare_hessian( + single_threaded_prep = DI.prepare_hessian_nokwarg( strict, f, single_threaded(backend), x, contexts... ) return PolyesterForwardDiffHessianPrep(_sig, single_threaded_prep) @@ -411,11 +411,11 @@ struct PolyesterForwardDiffOneArgSecondDerivativePrep{SIG,P} <: DI.SecondDerivat single_threaded_prep::P end -function DI.prepare_second_derivative( +function DI.prepare_second_derivative_nokwarg( strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) - single_threaded_prep = DI.prepare_second_derivative( + single_threaded_prep = DI.prepare_second_derivative_nokwarg( strict, f, single_threaded(backend), x, contexts... 
) return PolyesterForwardDiffOneArgSecondDerivativePrep(_sig, single_threaded_prep) diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl index 74d460cc3..78609ea0e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl @@ -5,7 +5,7 @@ struct PolyesterForwardDiffTwoArgPushforwardPrep{SIG,P} <: DI.PushforwardPrep{SI single_threaded_prep::P end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f!, y, @@ -15,8 +15,8 @@ function DI.prepare_pushforward( contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) - single_threaded_prep = DI.prepare_pushforward( - f!, y, single_threaded(backend), x, tx, contexts... + single_threaded_prep = DI.prepare_pushforward_nokwarg( + strict, f!, y, single_threaded(backend), x, tx, contexts... ) return PolyesterForwardDiffTwoArgPushforwardPrep(_sig, single_threaded_prep) end @@ -90,11 +90,11 @@ struct PolyesterForwardDiffTwoArgDerivativePrep{SIG,P} <: DI.DerivativePrep{SIG} single_threaded_prep::P end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f!, y, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) - single_threaded_prep = DI.prepare_derivative( + single_threaded_prep = DI.prepare_derivative_nokwarg( strict, f!, y, single_threaded(backend), x, contexts... 
) return PolyesterForwardDiffTwoArgDerivativePrep(_sig, single_threaded_prep) @@ -166,7 +166,7 @@ struct PolyesterForwardDiffTwoArgJacobianPrep{SIG,chunksize,P} <: DI.JacobianPre single_threaded_prep::P end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!, y, @@ -180,7 +180,7 @@ function DI.prepare_jacobian( else chunk = Chunk{chunksize}() end - single_threaded_prep = DI.prepare_jacobian( + single_threaded_prep = DI.prepare_jacobian_nokwarg( strict, f!, y, single_threaded(backend), x, contexts... ) return PolyesterForwardDiffTwoArgJacobianPrep(_sig, chunk, single_threaded_prep) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index 13b673ac5..c9d10490a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -1,6 +1,6 @@ ## Pullback -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f, backend::AutoReverseDiff, x, ty::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) @@ -79,7 +79,7 @@ struct ReverseDiffGradientPrep{SIG,C,T} <: DI.GradientPrep{SIG} tape::T end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoReverseDiff{compile}, x ) where {compile} _sig = DI.signature(f, backend, x; strict) @@ -143,7 +143,7 @@ end ### With contexts -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -216,7 +216,7 @@ struct ReverseDiffOneArgJacobianPrep{SIG,C,T} <: DI.JacobianPrep{SIG} tape::T end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, 
backend::AutoReverseDiff{compile}, x ) where {compile} _sig = DI.signature(f, backend, x; strict) @@ -280,7 +280,7 @@ end ### With contexts -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -354,11 +354,11 @@ struct ReverseDiffHessianPrep{SIG,G<:ReverseDiffGradientPrep,HC,HT} <: DI.Hessia hessian_tape::HT end -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::AutoReverseDiff{compile}, x ) where {compile} _sig = DI.signature(f, backend, x; strict) - gradient_prep = DI.prepare_gradient(strict, f, backend, x) + gradient_prep = DI.prepare_gradient_nokwarg(strict, f, backend, x) if compile hessian_tape = ReverseDiff.compile(HessianTape(f, x)) return ReverseDiffHessianPrep(_sig, gradient_prep, nothing, hessian_tape) @@ -412,11 +412,11 @@ end ### With contexts -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) - gradient_prep = DI.prepare_gradient(strict, f, backend, x, contexts...) + gradient_prep = DI.prepare_gradient_nokwarg(strict, f, backend, x, contexts...) 
hessian_config = HessianConfig(x) return ReverseDiffHessianPrep(_sig, gradient_prep, hessian_config, nothing) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl index 531212fe5..40553d456 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl @@ -1,6 +1,6 @@ ## Pullback -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f!, y, @@ -139,7 +139,7 @@ struct ReverseDiffTwoArgJacobianPrep{SIG,C,T} <: DI.JacobianPrep{SIG} tape::T end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!, y, backend::AutoReverseDiff{compile}, x ) where {compile} _sig = DI.signature(f!, y, backend, x; strict) @@ -205,7 +205,7 @@ end ### With contexts -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!, y, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index 1084583f2..8efeb3659 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -25,7 +25,7 @@ SMC.ncolors(prep::SparseHessianPrep) = ncolors(prep.coloring_result) ## Hessian, one argument -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} dense_backend = dense_ad(backend) @@ -62,8 +62,12 @@ function _prepare_sparse_hessian_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) 
for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - hvp_prep = DI.prepare_hvp(strict, f, dense_backend, x, batched_seeds[1], contexts...) - gradient_prep = DI.prepare_gradient(strict, f, DI.inner(dense_backend), x, contexts...) + hvp_prep = DI.prepare_hvp_nokwarg( + strict, f, dense_backend, x, batched_seeds[1], contexts... + ) + gradient_prep = DI.prepare_gradient_nokwarg( + strict, f, DI.inner(dense_backend), x, contexts... + ) return SparseHessianPrep( _sig, batch_size_settings, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 85eac3b0f..c5edc8cb2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -36,7 +36,7 @@ struct PullbackSparseJacobianPrep{ pullback_prep::E end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} dense_backend = dense_ad(backend) @@ -45,7 +45,7 @@ function DI.prepare_jacobian( return _prepare_sparse_jacobian_aux(strict, perf, y, (f,), backend, x, contexts...) end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!::F, y, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} dense_backend = dense_ad(backend) @@ -108,7 +108,7 @@ function _prepare_sparse_jacobian_aux_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds] - pushforward_prep = DI.prepare_pushforward( + pushforward_prep = DI.prepare_pushforward_nokwarg( strict, f_or_f!y..., dense_backend, x, batched_seeds[1], contexts... 
) return PushforwardSparseJacobianPrep( @@ -142,7 +142,7 @@ function _prepare_sparse_jacobian_aux_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - pullback_prep = DI.prepare_pullback( + pullback_prep = DI.prepare_pullback_nokwarg( strict, f_or_f!y..., dense_backend, x, batched_seeds[1], contexts... ) return PullbackSparseJacobianPrep( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl index c5d0d2847..42b4970db 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl @@ -27,7 +27,7 @@ struct MixedModeSparseJacobianPrep{ pullback_prep::Er end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f::F, backend::AutoSparse{<:DI.MixedMode}, @@ -38,7 +38,7 @@ function DI.prepare_jacobian( return _prepare_mixed_sparse_jacobian_aux(strict, y, (f,), backend, x, contexts...) 
end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!::F, y, @@ -127,7 +127,7 @@ function _prepare_mixed_sparse_jacobian_aux_aux( ntuple(b -> similar(x), Val(Br)) for _ in batched_seeds_reverse ] - pushforward_prep = DI.prepare_pushforward( + pushforward_prep = DI.prepare_pushforward_nokwarg( strict, f_or_f!y..., DI.forward_backend(dense_backend), @@ -135,7 +135,7 @@ function _prepare_mixed_sparse_jacobian_aux_aux( batched_seeds_forward[1], contexts...; ) - pullback_prep = DI.prepare_pullback( + pullback_prep = DI.prepare_pullback_nokwarg( strict, f_or_f!y..., DI.reverse_backend(dense_backend), diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 105d8a6a1..48314542d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -6,7 +6,7 @@ struct SymbolicsOneArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} pf_exe!::E1! end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) @@ -94,7 +94,7 @@ struct SymbolicsOneArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} der_exe!::E1! end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -168,7 +168,7 @@ struct SymbolicsOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG} grad_exe!::E1! 
end -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -238,7 +238,7 @@ struct SymbolicsOneArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} jac_exe!::E1! end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, @@ -316,7 +316,7 @@ struct SymbolicsOneArgHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG} hess_exe!::E2! end -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, @@ -336,7 +336,9 @@ function DI.prepare_hessian( res = build_function(hess_var, vec(x_var), context_vars...; expression=Val(false)) (hess_exe, hess_exe!) = res - gradient_prep = DI.prepare_gradient(strict, f, dense_ad(backend), x, contexts...) + gradient_prep = DI.prepare_gradient_nokwarg( + strict, f, dense_ad(backend), x, contexts... + ) return SymbolicsOneArgHessianPrep(_sig, gradient_prep, hess_exe, hess_exe!) end @@ -405,7 +407,7 @@ struct SymbolicsOneArgHVPPrep{SIG,G,E2,E2!} <: DI.HVPPrep{SIG} hvp_exe!::E2! end -function DI.prepare_hvp( +function DI.prepare_hvp_nokwarg( strict::Val, f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) @@ -422,7 +424,7 @@ function DI.prepare_hvp( ) (hvp_exe, hvp_exe!) = res - gradient_prep = DI.prepare_gradient(strict, f, backend, x, contexts...) + gradient_prep = DI.prepare_gradient_nokwarg(strict, f, backend, x, contexts...) return SymbolicsOneArgHVPPrep(_sig, gradient_prep, hvp_exe, hvp_exe!) end @@ -497,7 +499,7 @@ struct SymbolicsOneArgSecondDerivativePrep{SIG,D,E1,E1!} <: DI.SecondDerivativeP der2_exe!::E1! 
end -function DI.prepare_second_derivative( +function DI.prepare_second_derivative_nokwarg( strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -512,7 +514,7 @@ function DI.prepare_second_derivative( elseif res isa RuntimeGeneratedFunction res, nothing end - derivative_prep = DI.prepare_derivative(strict, f, backend, x, contexts...) + derivative_prep = DI.prepare_derivative_nokwarg(strict, f, backend, x, contexts...) return SymbolicsOneArgSecondDerivativePrep(_sig, derivative_prep, der2_exe, der2_exe!) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl index 58623720e..5237a64d4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl @@ -6,7 +6,7 @@ struct SymbolicsTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} pushforward_exe!::E1! end -function DI.prepare_pushforward( +function DI.prepare_pushforward_nokwarg( strict::Val, f!, y, @@ -104,7 +104,7 @@ struct SymbolicsTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} der_exe!::E1! end -function DI.prepare_derivative( +function DI.prepare_derivative_nokwarg( strict::Val, f!, y, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) @@ -182,7 +182,7 @@ struct SymbolicsTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} jac_exe!::E1! 
end -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f!, y, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index aaeae6cc7..d3e05294d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -15,7 +15,7 @@ struct TrackerPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} pb::PB end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f, backend::AutoTracker, @@ -91,7 +91,7 @@ end ## Gradient -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index c7a246211..dc400b03e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -27,7 +27,7 @@ struct ZygotePullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} pb::PB end -function DI.prepare_pullback( +function DI.prepare_pullback_nokwarg( strict::Val, f, backend::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) @@ -98,7 +98,7 @@ end ## Gradient -function DI.prepare_gradient( +function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} 
_sig = DI.signature(f, backend, x, contexts...; strict) @@ -138,7 +138,7 @@ end ## Jacobian -function DI.prepare_jacobian( +function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) @@ -187,11 +187,11 @@ struct ZygoteHVPPrep{SIG,P} <: DI.HVPPrep{SIG} fd_prep::P end -function DI.prepare_hvp( +function DI.prepare_hvp_nokwarg( strict::Val, f, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) - fd_prep = DI.prepare_hvp( + fd_prep = DI.prepare_hvp_nokwarg( strict, f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... ) return ZygoteHVPPrep(_sig, fd_prep) @@ -265,7 +265,7 @@ end ## Hessian -function DI.prepare_hessian( +function DI.prepare_hessian_nokwarg( strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) diff --git a/DifferentiationInterface/src/fallbacks/change_prep.jl b/DifferentiationInterface/src/fallbacks/change_prep.jl index 492813391..504d156c5 100644 --- a/DifferentiationInterface/src/fallbacks/change_prep.jl +++ b/DifferentiationInterface/src/fallbacks/change_prep.jl @@ -20,8 +20,10 @@ for op in [ end val_and_op! = Symbol(val_and_op, "!") prep_op = Symbol("prepare_", op) + prep_op_nokwarg = Symbol("prepare_", op, "_nokwarg") prep_op! = Symbol("prepare!_", op) prep_op_same_point = Symbol("prepare_", op, "_same_point") + prep_op_same_point_nokwarg = Symbol("prepare_", op, "_same_point_nokwarg") P = if op == :derivative DerivativePrep elseif op == :gradient @@ -46,7 +48,7 @@ for op in [ f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} check_prep(f, old_prep, backend, x, contexts...) - return $prep_op(is_strict(old_prep), f, backend, x, contexts...) 
+ return $prep_op_nokwarg(is_strict(old_prep), f, backend, x, contexts...) end op == :gradient && continue # 2-arg @@ -54,7 +56,7 @@ for op in [ f!::F, y, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} check_prep(f!, y, old_prep, backend, x, contexts...) - return $prep_op(is_strict(old_prep), f!, y, backend, x, contexts...) + return $prep_op_nokwarg(is_strict(old_prep), f!, y, backend, x, contexts...) end elseif op in (:second_derivative, :hessian) @@ -63,7 +65,7 @@ for op in [ f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} check_prep(f, old_prep, backend, x, contexts...) - return $prep_op(is_strict(old_prep), f, backend, x, contexts...) + return $prep_op_nokwarg(is_strict(old_prep), f, backend, x, contexts...) end elseif op in (:pushforward, :pullback, :hvp) @@ -77,7 +79,7 @@ for op in [ contexts::Vararg{Context,C}; ) where {F,C} check_prep(f, old_prep, backend, x, seed, contexts...) - return $prep_op(is_strict(old_prep), f, backend, x, seed, contexts...) + return $prep_op_nokwarg(is_strict(old_prep), f, backend, x, seed, contexts...) end @eval function $prep_op_same_point( f::F, @@ -90,7 +92,7 @@ for op in [ check_prep(f, prep, backend, x, seed, contexts...) return prep end - @eval function $prep_op_same_point( + @eval function $prep_op_same_point_nokwarg( strict::Val, f::F, backend::AbstractADType, @@ -98,7 +100,7 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}; ) where {F,C} - prep = $prep_op(strict, f, backend, x, seed, contexts...) + prep = $prep_op_nokwarg(strict, f, backend, x, seed, contexts...) return $prep_op_same_point(f, prep, backend, x, seed, contexts...) end op == :hvp && continue @@ -113,7 +115,9 @@ for op in [ contexts::Vararg{Context,C}, ) where {F,C} check_prep(f!, y, old_prep, backend, x, seed, contexts...) - return $prep_op(is_strict(old_prep), f!, y, backend, x, seed, contexts...) 
+ return $prep_op_nokwarg( + is_strict(old_prep), f!, y, backend, x, seed, contexts... + ) end @eval function $prep_op_same_point( f!::F, @@ -127,7 +131,7 @@ for op in [ check_prep(f!, y, prep, backend, x, seed, contexts...) return prep end - @eval function $prep_op_same_point( + @eval function $prep_op_same_point_nokwarg( strict::Val, f!::F, y, @@ -136,7 +140,7 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}; ) where {F,C} - prep = $prep_op(strict, f!, y, backend, x, seed, contexts...) + prep = $prep_op_nokwarg(strict, f!, y, backend, x, seed, contexts...) return $prep_op_same_point(f!, y, prep, backend, x, seed, contexts...) end end diff --git a/DifferentiationInterface/src/fallbacks/no_prep.jl b/DifferentiationInterface/src/fallbacks/no_prep.jl index 81c7b2bd0..f7dc633d6 100644 --- a/DifferentiationInterface/src/fallbacks/no_prep.jl +++ b/DifferentiationInterface/src/fallbacks/no_prep.jl @@ -20,8 +20,10 @@ for op in [ end val_and_op! = Symbol(val_and_op, "!") prep_op = Symbol("prepare_", op) + prep_op_nokwarg = Symbol("prepare_", op, "_nokwarg") prep_op! = Symbol("prepare!_", op) prep_op_same_point = Symbol("prepare_", op, "_same_point") + prep_op_same_point_nokwarg = Symbol("prepare_", op, "_same_point_nokwarg") P = if op == :derivative DerivativePrep elseif op == :gradient @@ -45,25 +47,25 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $op(f, prep, backend, x, contexts...) end @eval function $op!( f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $op!(f, result, prep, backend, x, contexts...) 
end @eval function $val_and_op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $val_and_op(f, prep, backend, x, contexts...) end @eval function $val_and_op!( f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $val_and_op!(f, result, prep, backend, x, contexts...) end op == :gradient && continue @@ -71,25 +73,25 @@ for op in [ @eval function $op( f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, contexts...) return $op(f!, y, prep, backend, x, contexts...) end @eval function $op!( f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, contexts...) return $op!(f!, y, result, prep, backend, x, contexts...) end @eval function $val_and_op( f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, contexts...) return $val_and_op(f!, y, prep, backend, x, contexts...) end @eval function $val_and_op!( f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, contexts...) return $val_and_op!(f!, y, result, prep, backend, x, contexts...) 
end @@ -98,25 +100,25 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $op(f, prep, backend, x, contexts...) end @eval function $op!( f::F, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $op!(f, result2, prep, backend, x, contexts...) end @eval function $val_and_op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $val_and_op(f, prep, backend, x, contexts...) end @eval function $val_and_op!( f::F, result1, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, contexts...) return $val_and_op!(f, result1, result2, prep, backend, x, contexts...) end @@ -124,7 +126,7 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, tang, contexts...) return $op(f, prep, backend, x, tang, contexts...) end @eval function $op!( @@ -135,13 +137,13 @@ for op in [ tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, tang, contexts...) return $op!(f, result, prep, backend, x, tang, contexts...) 
end @eval function $val_and_op( f::F, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, tang, contexts...) return $val_and_op(f, prep, backend, x, tang, contexts...) end @@ -154,7 +156,7 @@ for op in [ tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, tang, contexts...) return $val_and_op!(f, result, prep, backend, x, tang, contexts...) end elseif op == :hvp @@ -167,7 +169,7 @@ for op in [ tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f, backend, x, tang, contexts...) return $val_and_op!( f, result1, result2, prep, backend, x, tang, contexts... ) @@ -179,7 +181,7 @@ for op in [ @eval function $op( f!::F, y, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, tang, contexts...) return $op(f!, y, prep, backend, x, tang, contexts...) end @eval function $op!( @@ -191,13 +193,13 @@ for op in [ tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, tang, contexts...) return $op!(f!, y, result, prep, backend, x, tang, contexts...) end @eval function $val_and_op( f!::F, y, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, tang, contexts...) return $val_and_op(f!, y, prep, backend, x, tang, contexts...) 
end @eval function $val_and_op!( @@ -209,7 +211,7 @@ for op in [ tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + prep = $prep_op_nokwarg(Val(true), f!, y, backend, x, tang, contexts...) return $val_and_op!(f!, y, result, prep, backend, x, tang, contexts...) end end diff --git a/DifferentiationInterface/src/first_order/derivative.jl b/DifferentiationInterface/src/first_order/derivative.jl index c8753432d..69075fc3a 100644 --- a/DifferentiationInterface/src/first_order/derivative.jl +++ b/DifferentiationInterface/src/first_order/derivative.jl @@ -6,8 +6,21 @@ $(docstring_prepare("derivative"; inplace=true)) """ -function prepare_derivative(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_derivative(strict, args...) +function prepare_derivative( + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) +) where {F,C} + return prepare_derivative_nokwarg(strict, f, backend, x, contexts...) +end + +function prepare_derivative( + f!::F, + y, + backend::AbstractADType, + x, + contexts::Vararg{Context,C}; + strict::Val=Val(false), +) where {F,C} + return prepare_derivative_nokwarg(strict, f!, y, backend, x, contexts...) end """ @@ -65,19 +78,21 @@ struct PushforwardDerivativePrep{SIG,E<:PushforwardPrep} <: DerivativePrep{SIG} pushforward_prep::E end -function prepare_derivative( +function prepare_derivative_nokwarg( strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} _sig = signature(f, backend, x, contexts...; strict) - pushforward_prep = prepare_pushforward(strict, f, backend, x, (one(x),), contexts...) + pushforward_prep = prepare_pushforward_nokwarg( + strict, f, backend, x, (one(x),), contexts... 
+ ) return PushforwardDerivativePrep(_sig, pushforward_prep) end -function prepare_derivative( +function prepare_derivative_nokwarg( strict::Val, f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f!, y, backend, x, contexts...; strict) - pushforward_prep = prepare_pushforward( + pushforward_prep = prepare_pushforward_nokwarg( strict, f!, y, backend, x, (one(x),), contexts... ) return PushforwardDerivativePrep(_sig, pushforward_prep) diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 8ddfb7137..d15344b0d 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -5,8 +5,10 @@ $(docstring_prepare("gradient")) """ -function prepare_gradient(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_gradient(strict, args...) +function prepare_gradient( + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) +) where {F,C} + return prepare_gradient_nokwarg(strict, f, backend, x, contexts...) end """ @@ -60,12 +62,14 @@ struct PullbackGradientPrep{SIG,Y,E<:PullbackPrep} <: GradientPrep{SIG} pullback_prep::E end -function prepare_gradient( +function prepare_gradient_nokwarg( strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} _sig = signature(f, backend, x, contexts...; strict) y = f(x, map(unwrap, contexts)...) # TODO: replace with output type inference? - pullback_prep = prepare_pullback(strict, f, backend, x, (one(typeof(y)),), contexts...) + pullback_prep = prepare_pullback_nokwarg( + strict, f, backend, x, (one(typeof(y)),), contexts... 
+ ) return PullbackGradientPrep(_sig, y, pullback_prep) end diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 201cbea23..24932b289 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -6,8 +6,21 @@ $(docstring_prepare("jacobian"; inplace=true)) """ -function prepare_jacobian(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_jacobian(strict, args...) +function prepare_jacobian( + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) +) where {F,C} + return prepare_jacobian_nokwarg(strict, f, backend, x, contexts...) +end + +function prepare_jacobian( + f!::F, + y, + backend::AbstractADType, + x, + contexts::Vararg{Context,C}; + strict::Val=Val(false), +) where {F,C} + return prepare_jacobian_nokwarg(strict, f!, y, backend, x, contexts...) end """ @@ -90,7 +103,7 @@ struct PullbackJacobianPrep{ pullback_prep::E end -function prepare_jacobian( +function prepare_jacobian_nokwarg( strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} y = f(x, map(unwrap, contexts)...) @@ -107,7 +120,7 @@ function prepare_jacobian( ) end -function prepare_jacobian( +function prepare_jacobian_nokwarg( strict::Val, f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} perf = pushforward_performance(backend) @@ -140,7 +153,7 @@ function _prepare_jacobian_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds] - pushforward_prep = prepare_pushforward( + pushforward_prep = prepare_pushforward_nokwarg( strict, f_or_f!y..., backend, x, batched_seeds[1], contexts... 
) return PushforwardJacobianPrep( @@ -165,7 +178,7 @@ function _prepare_jacobian_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - pullback_prep = prepare_pullback( + pullback_prep = prepare_pullback_nokwarg( strict, f_or_f!y..., backend, x, batched_seeds[1], contexts... ) return PullbackJacobianPrep( diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 63afa0656..a22277a6f 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -6,8 +6,27 @@ $(docstring_prepare("pullback"; inplace=true)) """ -function prepare_pullback(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_pullback(strict, args...) +function prepare_pullback( + f::F, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context,C}; + strict::Val=Val(false), +) where {F,C} + return prepare_pullback_nokwarg(strict, f, backend, x, ty, contexts...) +end + +function prepare_pullback( + f!::F, + y, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context,C}; + strict::Val=Val(false), +) where {F,C} + return prepare_pullback_nokwarg(strict, f!, y, backend, x, ty, contexts...) end """ @@ -24,8 +43,27 @@ function prepare!_pullback end $(docstring_prepare("pullback"; samepoint=true, inplace=true)) """ -function prepare_pullback_same_point(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_pullback_same_point(strict, args...) +function prepare_pullback_same_point( + f::F, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context,C}; + strict::Val=Val(false), +) where {F,C} + return prepare_pullback_same_point_nokwarg(strict, f, backend, x, ty, contexts...) 
+end + +function prepare_pullback_same_point( + f!::F, + y, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context,C}; + strict::Val=Val(false), +) where {F,C} + return prepare_pullback_same_point_nokwarg(strict, f!, y, backend, x, ty, contexts...) end """ @@ -94,7 +132,7 @@ struct PushforwardPullbackPrep{SIG,E} <: PullbackPrep{SIG} pushforward_prep::E end -function prepare_pullback( +function prepare_pullback_nokwarg( strict::Val, f::F, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context,C}; ) where {F,C} return _prepare_pullback_aux( @@ -102,7 +140,7 @@ function prepare_pullback( ) end -function prepare_pullback( +function prepare_pullback_nokwarg( strict::Val, f!::F, y, @@ -127,7 +165,9 @@ function _prepare_pullback_aux( ) where {F,C} _sig = signature(f, backend, x, ty, contexts...; strict) dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x))) - pushforward_prep = prepare_pushforward(strict, f, backend, x, (dx,), contexts...) + pushforward_prep = prepare_pushforward_nokwarg( + strict, f, backend, x, (dx,), contexts... + ) return PushforwardPullbackPrep(_sig, pushforward_prep) end @@ -143,7 +183,9 @@ function _prepare_pullback_aux( ) where {F,C} _sig = signature(f!, y, backend, x, ty, contexts...; strict) dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x))) - pushforward_prep = prepare_pushforward(strict, f!, y, backend, x, (dx,), contexts...) + pushforward_prep = prepare_pushforward_nokwarg( + strict, f!, y, backend, x, (dx,), contexts... 
+ ) return PushforwardPullbackPrep(_sig, pushforward_prep) end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 883681031..601461d6a 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -6,8 +6,27 @@ $(docstring_prepare("pushforward"; inplace=true)) """ -function prepare_pushforward(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_pushforward(strict, args...) +function prepare_pushforward( + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict::Val=Val(false), +) where {F,C} + return prepare_pushforward_nokwarg(strict, f, backend, x, tx, contexts...) +end + +function prepare_pushforward( + f!::F, + y, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict::Val=Val(false), +) where {F,C} + return prepare_pushforward_nokwarg(strict, f!, y, backend, x, tx, contexts...) end """ @@ -24,8 +43,29 @@ function prepare!_pushforward end $(docstring_prepare("pushforward"; samepoint=true, inplace=true)) """ -function prepare_pushforward_same_point(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_pushforward_same_point(strict, args...) +function prepare_pushforward_same_point( + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict::Val=Val(false), +) where {F,C} + return prepare_pushforward_same_point_nokwarg(strict, f, backend, x, tx, contexts...) +end + +function prepare_pushforward_same_point( + f!::F, + y, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict::Val=Val(false), +) where {F,C} + return prepare_pushforward_same_point_nokwarg( + strict, f!, y, backend, x, tx, contexts... 
+ ) end """ @@ -94,7 +134,7 @@ struct PullbackPushforwardPrep{SIG,E} <: PushforwardPrep{SIG} pullback_prep::E end -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}; ) where {F,C} return _prepare_pushforward_aux( @@ -102,7 +142,7 @@ function prepare_pushforward( ) end -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f!::F, y, @@ -128,7 +168,7 @@ function _prepare_pushforward_aux( _sig = signature(f, backend, x, tx, contexts...; strict) y = f(x, map(unwrap, contexts)...) dy = y isa Number ? one(y) : basis(y, first(CartesianIndices(y))) - pullback_prep = prepare_pullback(strict, f, backend, x, (dy,), contexts...) + pullback_prep = prepare_pullback_nokwarg(strict, f, backend, x, (dy,), contexts...) return PullbackPushforwardPrep(_sig, pullback_prep) end @@ -144,7 +184,7 @@ function _prepare_pushforward_aux( ) where {F,C} _sig = signature(f!, y, backend, x, tx, contexts...; strict) dy = y isa Number ? one(y) : basis(y, first(CartesianIndices(y))) - pullback_prep = prepare_pullback(strict, f!, y, backend, x, (dy,), contexts...) + pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...) 
return PullbackPushforwardPrep(_sig, pullback_prep) end diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index 35c21a71e..00dd6b445 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -41,7 +41,7 @@ struct FromPrimitivePushforwardPrep{SIG,E<:PushforwardPrep} <: PushforwardPrep{S pushforward_prep::E end -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f::F, backend::AutoForwardFromPrimitive, @@ -50,11 +50,13 @@ function prepare_pushforward( contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f, backend, x, tx, contexts...; strict) - primitive_prep = prepare_pushforward(strict, f, backend.backend, x, tx, contexts...) + primitive_prep = prepare_pushforward_nokwarg( + strict, f, backend.backend, x, tx, contexts... + ) return FromPrimitivePushforwardPrep(_sig, primitive_prep) end -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f!::F, y, @@ -64,7 +66,9 @@ function prepare_pushforward( contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f!, y, backend, x, tx, contexts...; strict) - primitive_prep = prepare_pushforward(strict, f!, y, backend.backend, x, tx, contexts...) + primitive_prep = prepare_pushforward_nokwarg( + strict, f!, y, backend.backend, x, tx, contexts... + ) return FromPrimitivePushforwardPrep(_sig, primitive_prep) end @@ -158,7 +162,7 @@ struct FromPrimitivePullbackPrep{SIG,E<:PullbackPrep} <: PullbackPrep{SIG} pullback_prep::E end -function prepare_pullback( +function prepare_pullback_nokwarg( strict::Val, f::F, backend::AutoReverseFromPrimitive, @@ -167,11 +171,13 @@ function prepare_pullback( contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f, backend, x, ty, contexts...; strict) - primitive_prep = prepare_pullback(strict, f, backend.backend, x, ty, contexts...) 
+ primitive_prep = prepare_pullback_nokwarg( + strict, f, backend.backend, x, ty, contexts... + ) return FromPrimitivePullbackPrep(_sig, primitive_prep) end -function prepare_pullback( +function prepare_pullback_nokwarg( strict::Val, f!::F, y, @@ -181,7 +187,9 @@ function prepare_pullback( contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f!, y, backend, x, ty, contexts...; strict) - primitive_prep = prepare_pullback(strict, f!, y, backend.backend, x, ty, contexts...) + primitive_prep = prepare_pullback_nokwarg( + strict, f!, y, backend.backend, x, ty, contexts... + ) return FromPrimitivePullbackPrep(_sig, primitive_prep) end diff --git a/DifferentiationInterface/src/misc/simple_finite_diff.jl b/DifferentiationInterface/src/misc/simple_finite_diff.jl index bfa6656f4..8481cad97 100644 --- a/DifferentiationInterface/src/misc/simple_finite_diff.jl +++ b/DifferentiationInterface/src/misc/simple_finite_diff.jl @@ -36,7 +36,7 @@ function threshold_batchsize( return AutoSimpleFiniteDiff(backend.ε; chunksize) end -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f::F, backend::AutoSimpleFiniteDiff, @@ -48,7 +48,7 @@ function prepare_pushforward( return NoPushforwardPrep(_sig) end -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f!::F, y, diff --git a/DifferentiationInterface/src/misc/zero_backends.jl b/DifferentiationInterface/src/misc/zero_backends.jl index 1b3c6c0eb..e74449f3c 100644 --- a/DifferentiationInterface/src/misc/zero_backends.jl +++ b/DifferentiationInterface/src/misc/zero_backends.jl @@ -20,14 +20,14 @@ ADTypes.mode(::AutoZeroForward) = ForwardMode() check_available(::AutoZeroForward) = true inplace_support(::AutoZeroForward) = InPlaceSupported() -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f::F, backend::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f, backend, x, tx, contexts...; strict) return 
NoPushforwardPrep(_sig) end -function prepare_pushforward( +function prepare_pushforward_nokwarg( strict::Val, f!::F, y, @@ -118,14 +118,14 @@ ADTypes.mode(::AutoZeroReverse) = ReverseMode() check_available(::AutoZeroReverse) = true inplace_support(::AutoZeroReverse) = InPlaceSupported() -function prepare_pullback( +function prepare_pullback_nokwarg( strict::Val, f::F, backend::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f, backend, x, ty, contexts...; strict) return NoPullbackPrep(_sig) end -function prepare_pullback( +function prepare_pullback_nokwarg( strict::Val, f!::F, y, diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 290adb525..dbdd16da6 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -5,8 +5,10 @@ $(docstring_prepare("hessian")) """ -function prepare_hessian(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_hessian(strict, args...) +function prepare_hessian( + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) +) where {F,C} + return prepare_hessian_nokwarg(strict, f, backend, x, contexts...) end """ @@ -70,7 +72,7 @@ struct HVPGradientHessianPrep{ gradient_prep::E1 end -function prepare_hessian( +function prepare_hessian_nokwarg( strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} # type-unstable @@ -94,8 +96,8 @@ function _prepare_hessian_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - hvp_prep = prepare_hvp(strict, f, backend, x, batched_seeds[1], contexts...) - gradient_prep = prepare_gradient(strict, f, inner(backend), x, contexts...) + hvp_prep = prepare_hvp_nokwarg(strict, f, backend, x, batched_seeds[1], contexts...) 
+ gradient_prep = prepare_gradient_nokwarg(strict, f, inner(backend), x, contexts...) return HVPGradientHessianPrep( _sig, batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep ) diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 8caeb2b64..11faa2548 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -5,8 +5,15 @@ $(docstring_prepare("hvp")) """ -function prepare_hvp(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_hvp(strict, args...) +function prepare_hvp( + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict::Val=Val(false), +) where {F,C} + return prepare_hvp_nokwarg(strict, f, backend, x, tx, contexts...) end """ @@ -21,8 +28,15 @@ function prepare!_hvp end $(docstring_prepare("hvp"; samepoint=true)) """ -function prepare_hvp_same_point(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_hvp_same_point(strict, args...) +function prepare_hvp_same_point( + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict::Val=Val(false), +) where {F,C} + return prepare_hvp_same_point_nokwarg(strict, f, backend, x, tx, contexts...) end """ @@ -61,7 +75,7 @@ $(docstring_preparation_hint("hvp"; same_point=true)) """ function gradient_and_hvp! end -function prepare_hvp( +function prepare_hvp_nokwarg( strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}; ) where {F,C} return _prepare_hvp_aux( @@ -105,11 +119,11 @@ function _prepare_hvp_aux( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - outer_pushforward_prep = prepare_pushforward( + outer_pushforward_prep = prepare_pushforward_nokwarg( strict, shuffled_gradient, outer(backend), x, tx, new_contexts... 
) outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported - prepare_pushforward( + prepare_pushforward_nokwarg( strict, shuffled_gradient!, grad_buffer, @@ -140,7 +154,9 @@ function _prepare_hvp_aux( grad_buffer = similar(x) rewrap = Rewrap(contexts...) # Inner gradient - inner_gradient_prep = prepare_gradient(strict, f, inner(backend), x, contexts...) + inner_gradient_prep = prepare_gradient_nokwarg( + strict, f, inner(backend), x, contexts... + ) inner_gradient_in_prep = inner_gradient_prep # Outer pushforward new_contexts = ( @@ -157,11 +173,11 @@ function _prepare_hvp_aux( Constant(rewrap), contexts..., ) - outer_pushforward_prep = prepare_pushforward( + outer_pushforward_prep = prepare_pushforward_nokwarg( strict, shuffled_gradient, outer(backend), x, tx, new_contexts... ) outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported - prepare_pushforward( + prepare_pushforward_nokwarg( strict, shuffled_gradient!, grad_buffer, @@ -203,8 +219,12 @@ function _prepare_hvp_aux( ) contextso = adapt_eltype.(contexts, Ref(eltype(xo))) contextsoi = adapt_eltype.(contexts, Ref(eltype(xoi))) - inner_gradient_prep = prepare_gradient(strict, f, inner(backend), xo, contextso...) - inner_gradient_in_prep = prepare_gradient(strict, f, inner(backend), xoi, contextsoi...) + inner_gradient_prep = prepare_gradient_nokwarg( + strict, f, inner(backend), xo, contextso... + ) + inner_gradient_in_prep = prepare_gradient_nokwarg( + strict, f, inner(backend), xoi, contextsoi... + ) # Outer pushforward new_contexts = ( FunctionContext(f), @@ -220,11 +240,11 @@ function _prepare_hvp_aux( Constant(rewrap), contexts..., ) - outer_pushforward_prep = prepare_pushforward( + outer_pushforward_prep = prepare_pushforward_nokwarg( strict, shuffled_gradient, outer(backend), x, tx, new_contexts... 
) outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported - prepare_pushforward( + prepare_pushforward_nokwarg( strict, shuffled_gradient!, grad_buffer, @@ -477,10 +497,10 @@ function _prepare_hvp_aux( Constant(rewrap), contexts..., ) - outer_gradient_prep = prepare_gradient( + outer_gradient_prep = prepare_gradient_nokwarg( strict, shuffled_single_pushforward, outer(backend), x, new_contexts... ) - gradient_prep = prepare_gradient(strict, f, inner(backend), x, contexts...) + gradient_prep = prepare_gradient_nokwarg(strict, f, inner(backend), x, contexts...) return ReverseOverForwardHVPPrep(_sig, outer_gradient_prep, gradient_prep) end @@ -596,11 +616,11 @@ function _prepare_hvp_aux( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) grad_buffer = similar(x) - outer_pullback_prep = prepare_pullback( + outer_pullback_prep = prepare_pullback_nokwarg( strict, shuffled_gradient, outer(backend), x, tx, new_contexts... ) outer_pullback_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported - prepare_pullback( + prepare_pullback_nokwarg( strict, shuffled_gradient!, grad_buffer, diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index f26f335e6..09e81fbac 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -5,8 +5,10 @@ $(docstring_prepare("second_derivative")) """ -function prepare_second_derivative(args::Vararg{Any,N}; strict=Val(false)) where {N} - return prepare_second_derivative(strict, args...) +function prepare_second_derivative( + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) +) where {F,C} + return prepare_second_derivative_nokwarg(strict, f, backend, x, contexts...) 
end """ @@ -59,7 +61,7 @@ struct DerivativeSecondDerivativePrep{SIG,E<:DerivativePrep} <: SecondDerivative outer_derivative_prep::E end -function prepare_second_derivative( +function prepare_second_derivative_nokwarg( strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} _sig = signature(f, backend, x, contexts...; strict) @@ -67,7 +69,7 @@ function prepare_second_derivative( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - outer_derivative_prep = prepare_derivative( + outer_derivative_prep = prepare_derivative_nokwarg( strict, shuffled_derivative, outer(backend), x, new_contexts... ) return DerivativeSecondDerivativePrep(_sig, outer_derivative_prep)