From 26517dc2412cd879eb05c918ebb147b6a14a3709 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 16 Mar 2025 19:22:38 +0100 Subject: [PATCH 01/22] feat!: ensure consistency between preparation result and current signature --- .github/workflows/Test.yml | 2 +- .../differentiate_with.jl | 2 +- .../reverse_onearg.jl | 31 +- .../DifferentiationInterfaceDiffractorExt.jl | 11 +- .../forward_onearg.jl | 56 ++-- .../forward_twoarg.jl | 19 +- .../reverse_onearg.jl | 49 +++- .../reverse_twoarg.jl | 17 +- .../utils.jl | 16 +- .../onearg.jl | 155 +++++++--- .../twoarg.jl | 90 ++++-- .../onearg.jl | 111 +++++--- .../twoarg.jl | 77 +++-- ...rentiationInterfaceFiniteDifferencesExt.jl | 50 +++- .../DifferentiationInterfaceForwardDiffExt.jl | 1 - .../onearg.jl | 155 +++++----- .../secondorder.jl | 144 ---------- .../twoarg.jl | 69 ++++- .../utils.jl | 2 +- .../onearg.jl | 141 ++++++--- .../twoarg.jl | 44 ++- .../onearg.jl | 41 ++- .../twoarg.jl | 23 +- .../onearg.jl | 268 ++++++++++-------- .../twoarg.jl | 104 +++++-- .../onearg.jl | 154 ++++++---- .../twoarg.jl | 75 +++-- ...iationInterfaceSparseMatrixColoringsExt.jl | 2 +- .../hessian.jl | 17 +- .../jacobian.jl | 37 ++- .../jacobian_mixed.jl | 36 ++- .../onearg.jl | 38 +-- .../twoarg.jl | 16 +- .../DifferentiationInterfaceTrackerExt.jl | 41 ++- .../DifferentiationInterfaceZygoteExt.jl | 102 +++++-- .../src/DifferentiationInterface.jl | 2 +- DifferentiationInterface/src/docstrings.jl | 1 + .../src/fallbacks/change_prep.jl | 46 ++- .../src/fallbacks/no_prep.jl | 42 +-- .../src/first_order/derivative.jl | 30 +- .../src/first_order/gradient.jl | 31 +- .../src/first_order/jacobian.jl | 72 +++-- .../src/first_order/pullback.jl | 53 +++- .../src/first_order/pushforward.jl | 53 +++- .../src/misc/from_primitive.jl | 112 +++++--- .../src/misc/simple_finite_diff.jl | 27 +- .../src/misc/zero_backends.jl | 85 ++++-- .../src/second_order/hessian.jl | 37 ++- .../src/second_order/hvp.jl | 115 
++++++-- .../src/second_order/second_derivative.jl | 17 +- DifferentiationInterface/src/utils/prep.jl | 110 +++++-- .../src/tests/correctness_eval.jl | 56 +++- 52 files changed, 2035 insertions(+), 1050 deletions(-) delete mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index f91ce0853..13de31d96 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -25,7 +25,7 @@ jobs: actions: write contents: read strategy: - fail-fast: true # TODO: toggle + fail-fast: false # TODO: toggle matrix: version: - "1.10" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl index 2eb8d2e93..528ba061a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl @@ -1,7 +1,7 @@ function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x) (; f, backend) = dw y = f(x) - prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,)) + prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,); strict=true) function pullbackfunc(dy) tx = DI.pullback(f, prep_same, backend, x, (dy,)) return (NoTangent(), only(tx)) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index 079e5dd1a..ee691774f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -1,37 +1,48 @@ ## Pullback -struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep +struct ChainRulesPullbackPrepSamePoint{SIG,Y,PB} <: 
DI.PullbackPrep{SIG} + _sig::Type{SIG} y::Y pb::PB end function DI.prepare_pullback( - f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C} + f, + backend::AutoReverseChainRules, + x, + ty::NTuple, + contexts::Vararg{DI.GeneralizedConstant,C}; + strict::Bool=false, ) where {C} - return DI.NoPullbackPrep() + SIG = DI.signature(f, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep{SIG}() end function DI.prepare_pullback_same_point( f, - ::DI.NoPullbackPrep, + prep::DI.NoPullbackPrep, backend::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.GeneralizedConstant,C}, + contexts::Vararg{DI.GeneralizedConstant,C}; + strict::Bool=false, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) + SIG = DI.signature(f, backend, x, ty, contexts...; strict) rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) - return ChainRulesPullbackPrepSamePoint(y, pb) + return ChainRulesPullbackPrepSamePoint(SIG, y, pb) end function DI.value_and_pullback( f, - ::DI.NoPullbackPrep, + prep::DI.NoPullbackPrep, backend::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) tx = map(ty) do dy @@ -43,11 +54,12 @@ end function DI.value_and_pullback( f, prep::ChainRulesPullbackPrepSamePoint, - ::AutoReverseChainRules, + backend::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) (; y, pb) = prep tx = map(ty) do dy unthunk(pb(dy)[2]) @@ -58,11 +70,12 @@ end function DI.pullback( f, prep::ChainRulesPullbackPrepSamePoint, - ::AutoReverseChainRules, + backend::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) 
(; pb) = prep tx = map(ty) do dy unthunk(pb(dy)[2]) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl index 2973f3b37..bede2d064 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl @@ -10,9 +10,15 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow() ## Pushforward -DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::NTuple) = DI.NoPushforwardPrep() +function DI.prepare_pushforward(f, backend::AutoDiffractor, x, tx::NTuple) + SIG = DI.signature(f, backend, x, tx) + return DI.NoPushforwardPrep{SIG}() +end -function DI.pushforward(f, ::DI.NoPushforwardPrep, ::AutoDiffractor, x, tx::NTuple) +function DI.pushforward( + f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple +) + DI.check_prep(f, prep, backend, x, tx) ty = map(tx) do dx # code copied from Diffractor.jl z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx)) @@ -25,6 +31,7 @@ end function DI.value_and_pushforward( f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple ) + DI.check_prep(f, prep, backend, x, tx) return f(x), DI.pushforward(f, prep, backend, x, tx) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 72c6df2df..25e6c4801 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -2,22 +2,25 @@ function DI.prepare_pushforward( f::F, - ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, + backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple, - 
contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {F,C} - return DI.NoPushforwardPrep() + SIG = DI.signature(f, backend, x, tx, contexts...; strict) + return DI.NoPushforwardPrep{SIG}() end function DI.value_and_pushforward( f::F, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) mode = forward_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) dx = only(tx) @@ -29,12 +32,13 @@ end function DI.value_and_pushforward( f::F, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) mode = forward_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode, Val(B)) x_and_tx = BatchDuplicated(x, tx) @@ -45,12 +49,13 @@ end function DI.pushforward( f::F, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) mode = forward_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) dx = only(tx) @@ -62,12 +67,13 @@ end function DI.pushforward( f::F, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) mode = forward_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode, Val(B)) x_and_tx = BatchDuplicated(x, tx) @@ -85,6 +91,7 @@ function DI.value_and_pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) 
# dy cannot be passed anyway y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...) foreach(copyto!, ty, new_ty) @@ -100,6 +107,7 @@ function DI.pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) # dy cannot be passed anyway new_ty = DI.pushforward(f, prep, backend, x, tx, contexts...) foreach(copyto!, ty, new_ty) @@ -108,23 +116,25 @@ end ## Gradient -struct EnzymeForwardGradientPrep{B,O} <: DI.GradientPrep +struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG} shadows::O end -function EnzymeForwardGradientPrep(::Val{B}, shadows::O) where {B,O} - return EnzymeForwardGradientPrep{B,O}(shadows) +function EnzymeForwardGradientPrep(::Type{SIG}, ::Val{B}, shadows::O) where {SIG,B,O} + return EnzymeForwardGradientPrep{SIG,B,O}(shadows) end function DI.prepare_gradient( f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.Constant,C}; + strict::Bool=false, ) where {F,C} + SIG = DI.signature(f, backend, x, contexts...; strict) valB = to_val(DI.pick_batchsize(backend, x)) shadows = create_shadows(valB, x) - return EnzymeForwardGradientPrep(valB, shadows) + return EnzymeForwardGradientPrep(SIG, valB, shadows) end function DI.gradient( @@ -134,6 +144,7 @@ function DI.gradient( x, contexts::Vararg{DI.Constant,C}, ) where {F,B,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = forward_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(B), contexts...) @@ -150,6 +161,7 @@ function DI.value_and_gradient( x, contexts::Vararg{DI.Constant,C}, ) where {F,B,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = forward_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(B), contexts...) 
@@ -167,6 +179,7 @@ function DI.gradient!( x, contexts::Vararg{DI.Constant,C}, ) where {F,B,C} + DI.check_prep(f, prep, backend, x, contexts...) return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end @@ -178,33 +191,36 @@ function DI.value_and_gradient!( x, contexts::Vararg{DI.Constant,C}, ) where {F,B,C} + DI.check_prep(f, prep, backend, x, contexts...) y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) end ## Jacobian -struct EnzymeForwardOneArgJacobianPrep{B,O} <: DI.JacobianPrep +struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG} shadows::O output_length::Int end function EnzymeForwardOneArgJacobianPrep( - ::Val{B}, shadows::O, output_length::Integer -) where {B,O} - return EnzymeForwardOneArgJacobianPrep{B,O}(shadows, output_length) + ::Type{SIG}, ::Val{B}, shadows::O, output_length::Integer +) where {SIG,B,O} + return EnzymeForwardOneArgJacobianPrep{SIG,B,O}(shadows, output_length) end function DI.prepare_jacobian( f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.Constant,C}; + strict::Bool=false, ) where {F,C} + SIG = DI.signature(f, backend, x, contexts...; strict) y = f(x, map(DI.unwrap, contexts)...) valB = to_val(DI.pick_batchsize(backend, x)) shadows = create_shadows(valB, x) - return EnzymeForwardOneArgJacobianPrep(valB, shadows, length(y)) + return EnzymeForwardOneArgJacobianPrep(SIG, valB, shadows, length(y)) end function DI.jacobian( @@ -214,6 +230,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Constant,C}, ) where {F,B,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = forward_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(B), contexts...) @@ -231,6 +248,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Constant,C}, ) where {F,B,C} + DI.check_prep(f, prep, backend, x, contexts...) 
mode = forward_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(B), contexts...) @@ -249,6 +267,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Constant,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...)) end @@ -260,6 +279,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Constant,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...) return y, copyto!(jac, new_jac) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index 5055a23af..698e73c67 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -3,23 +3,26 @@ function DI.prepare_pushforward( f!::F, y, - ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, + backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {F,C} - return DI.NoPushforwardPrep() + SIG = DI.signature(f!, y, backend, x, tx, contexts...; strict) + return DI.NoPushforwardPrep{SIG}() end function DI.value_and_pushforward( f!::F, y, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) mode = forward_noprimal(backend) f!_and_df! 
= get_f_and_df(f!, backend, mode) dx = only(tx) @@ -34,12 +37,13 @@ end function DI.value_and_pushforward( f!::F, y, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) mode = forward_noprimal(backend) f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) ty = ntuple(_ -> make_zero(y), Val(B)) @@ -59,6 +63,7 @@ function DI.pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) _, ty = DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...) return ty end @@ -67,12 +72,13 @@ function DI.value_and_pushforward!( f!::F, y, ty::NTuple{B}, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) mode = forward_noprimal(backend) f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) x_and_tx = BatchDuplicated(x, tx) @@ -92,6 +98,7 @@ function DI.pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) DI.value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) 
return ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index a6ceda51a..2f0a0a8c6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -47,19 +47,21 @@ end ## Pullback -struct EnzymeReverseOneArgPullbackPrep{Y} <: DI.PullbackPrep +struct EnzymeReverseOneArgPullbackPrep{SIG,Y} <: DI.PullbackPrep{SIG} y_example::Y # useful to create return activity end function DI.prepare_pullback( f::F, - ::AutoEnzyme{<:Union{ReverseMode,Nothing}}, + backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ty::NTuple, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {F,C} + SIG = DI.signature(f, backend, x, ty, contexts...; strict) y = f(x, map(DI.unwrap, contexts)...) - return EnzymeReverseOneArgPullbackPrep(y) + return EnzymeReverseOneArgPullbackPrep{SIG,typeof(y)}(y) end ### Out-of-place @@ -72,6 +74,7 @@ function DI.value_and_pullback( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) mode = reverse_split_withprimal(backend) f_and_df = force_annotation(get_f_and_df(f, backend, mode)) IA = guess_activity(typeof(x), mode) @@ -97,6 +100,7 @@ function DI.value_and_pullback( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) mode = reverse_split_withprimal(backend) f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B))) IA = batchify_activity(guess_activity(typeof(x), mode), Val(B)) @@ -122,6 +126,7 @@ function DI.pullback( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) 
return last(DI.value_and_pullback(f, prep, backend, x, ty, contexts...)) end @@ -136,6 +141,7 @@ function DI.value_and_pullback!( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) mode = reverse_split_withprimal(backend) f_and_df = force_annotation(get_f_and_df(f, backend, mode)) RA = guess_activity(typeof(prep.y_example), mode) @@ -157,6 +163,7 @@ function DI.value_and_pullback!( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) mode = reverse_split_withprimal(backend) f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B))) RA = batchify_activity(guess_activity(typeof(prep.y_example), mode), Val(B)) @@ -177,26 +184,33 @@ function DI.pullback!( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) return last(DI.value_and_pullback!(f, tx, prep, backend, x, ty, contexts...)) end ## Gradient function DI.prepare_gradient( - f::F, ::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C} + f::F, + backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, + x, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {F,C} - return DI.NoGradientPrep() + SIG = DI.signature(f, backend, x, contexts...; strict) + return DI.NoGradientPrep{SIG}() end ### Enzyme gradient API (only constants) function DI.gradient( f::F, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = reverse_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(1), contexts...) 
@@ -206,11 +220,12 @@ end function DI.value_and_gradient( f::F, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = reverse_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(1), contexts...) @@ -221,10 +236,11 @@ end function DI.gradient!( f::F, grad, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, x, ) where {F} + DI.check_prep(f, prep, backend, x) mode = reverse_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) gradient!(mode, grad, f_and_df, x) @@ -234,10 +250,11 @@ end function DI.value_and_gradient!( f::F, grad, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, x, ) where {F} + DI.check_prep(f, prep, backend, x) mode = reverse_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) _, result = gradient!(mode, grad, f_and_df, x) @@ -248,11 +265,12 @@ end function DI.gradient( f::F, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = reverse_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) IA = guess_activity(typeof(x), mode) @@ -271,11 +289,12 @@ end function DI.value_and_gradient( f::F, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) 
mode = reverse_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) IA = guess_activity(typeof(x), mode) @@ -295,11 +314,12 @@ end function DI.gradient!( f::F, grad, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = reverse_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(1), contexts...) @@ -311,11 +331,12 @@ end function DI.value_and_gradient!( f::F, grad, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = reverse_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(1), contexts...) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index 091b695e4..aca9494ca 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -1,18 +1,21 @@ ## Pullback -struct EnzymeReverseTwoArgPullbackPrep{TY} <: DI.PullbackPrep +struct EnzymeReverseTwoArgPullbackPrep{SIG,TY} <: DI.PullbackPrep{SIG} ty_copy::TY end function DI.prepare_pullback( f!::F, y, - ::AutoEnzyme{<:Union{ReverseMode,Nothing}}, + backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ty::NTuple, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {F,C} - return EnzymeReverseTwoArgPullbackPrep(map(copy, ty)) + SIG = DI.signature(f!, y, backend, x, ty, contexts...; strict) + ty_copy = map(copy, ty) + return EnzymeReverseTwoArgPullbackPrep{SIG,typeof(ty_copy)}(ty_copy) end function 
DI.value_and_pullback( @@ -24,6 +27,7 @@ function DI.value_and_pullback( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) copyto!(only(prep.ty_copy), only(ty)) mode = reverse_noprimal(backend) f!_and_df! = get_f_and_df(f!, backend, mode) @@ -46,6 +50,7 @@ function DI.value_and_pullback( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) foreach(copyto!, prep.ty_copy, ty) mode = reverse_noprimal(backend) f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) @@ -68,6 +73,7 @@ function DI.value_and_pullback( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) copyto!(only(prep.ty_copy), only(ty)) mode = reverse_noprimal(backend) f!_and_df! = get_f_and_df(f!, backend, mode) @@ -89,6 +95,7 @@ function DI.value_and_pullback( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) foreach(copyto!, prep.ty_copy, ty) mode = reverse_noprimal(backend) f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) @@ -111,6 +118,7 @@ function DI.value_and_pullback!( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) copyto!(only(prep.ty_copy), only(ty)) mode = reverse_noprimal(backend) f!_and_df! = get_f_and_df(f!, backend, mode) @@ -134,6 +142,7 @@ function DI.value_and_pullback!( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) foreach(copyto!, prep.ty_copy, ty) mode = reverse_noprimal(backend) f!_and_df! 
= get_f_and_df(f!, backend, mode, Val(B)) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index c189d09d9..8b1550532 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -9,20 +9,20 @@ to_val(::DI.BatchSizeSettings{B}) where {B} = Val(B) ## Annotations @inline function get_f_and_df( - f::F, ::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1) + f::F, backend::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1) ) where {F,M,B} return f end @inline function get_f_and_df( - f::F, ::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1) + f::F, backend::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1) ) where {F,M,B} return Const(f) end @inline function get_f_and_df( f::F, - ::AutoEnzyme{ + backend::AutoEnzyme{ M, <:Union{ Duplicated, @@ -48,13 +48,13 @@ force_annotation(f::F) where {F<:Annotation} = f force_annotation(f::F) where {F} = Const(f) @inline function _translate( - ::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedConstant + backend::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedConstant ) where {B} return Const(DI.unwrap(c)) end @inline function _translate( - ::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedCache + backend::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedCache ) where {B} if B == 1 return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c))) @@ -100,8 +100,10 @@ function reverse_split_withprimal(backend::AutoEnzyme{Nothing}) return set_err(ReverseSplitWithPrimal, backend) end -set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode) -set_err(mode::Mode, ::AutoEnzyme{<:Any,<:Annotation}) = mode +function set_err(mode::Mode, backend::AutoEnzyme{<:Any,Nothing}) + return EnzymeCore.set_err_if_func_written(mode) +end +set_err(mode::Mode, backend::AutoEnzyme{<:Any,<:Annotation}) = mode function 
maybe_reshape(A::AbstractMatrix, m, n) @assert size(A) == (m, n) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index 6d19c4ce3..f873d6c35 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -1,14 +1,21 @@ ## Pushforward -struct FastDifferentiationOneArgPushforwardPrep{Y,E1,E1!} <: DI.PushforwardPrep +struct FastDifferentiationOneArgPushforwardPrep{SIG,Y,E1,E1!} <: DI.PushforwardPrep{SIG} + _sig::Type{SIG} y_prototype::Y jvp_exe::E1 jvp_exe!::E1! end function DI.prepare_pushforward( - f, ::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} + SIG = DI.signature(f, backend, x, tx, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -23,17 +30,18 @@ function DI.prepare_pushforward( jvp_exe! = make_function( jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true ) - return FastDifferentiationOneArgPushforwardPrep(y_prototype, jvp_exe, jvp_exe!) + return FastDifferentiationOneArgPushforwardPrep(SIG, y_prototype, jvp_exe, jvp_exe!) end function DI.pushforward( f, prep::FastDifferentiationOneArgPushforwardPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) ty = map(tx) do dx result = prep.jvp_exe(myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) 
if prep.y_prototype isa Number @@ -49,11 +57,12 @@ function DI.pushforward!( f, ty::NTuple, prep::FastDifferentiationOneArgPushforwardPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] prep.jvp_exe!(myvec(dy), myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) @@ -69,6 +78,7 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pushforward(f, prep, backend, x, tx, contexts...) end @@ -82,20 +92,28 @@ function DI.value_and_pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pushforward!(f, ty, prep, backend, x, tx, contexts...) end ## Pullback -struct FastDifferentiationOneArgPullbackPrep{E1,E1!} <: DI.PullbackPrep +struct FastDifferentiationOneArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG} + _sig::Type{SIG} vjp_exe::E1 vjp_exe!::E1! end function DI.prepare_pullback( - f, ::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C} + f, + backend::AutoFastDifferentiation, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} + SIG = DI.signature(f, backend, x, ty, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = f(x_var, context_vars...) @@ -110,17 +128,18 @@ function DI.prepare_pullback( vjp_exe! = make_function( vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true ) - return FastDifferentiationOneArgPullbackPrep(vjp_exe, vjp_exe!) + return FastDifferentiationOneArgPullbackPrep(SIG, vjp_exe, vjp_exe!) 
end function DI.pullback( f, prep::FastDifferentiationOneArgPullbackPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) tx = map(ty) do dy result = prep.vjp_exe(myvec(x), myvec(dy), map(myvec_unwrap, contexts)...) if x isa Number @@ -136,11 +155,12 @@ function DI.pullback!( f, tx::NTuple, prep::FastDifferentiationOneArgPullbackPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] prep.vjp_exe!(myvec(dx), myvec(x), myvec(dy), map(myvec_unwrap, contexts)...) @@ -156,6 +176,7 @@ function DI.value_and_pullback( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pullback(f, prep, backend, x, ty, contexts...) end @@ -169,21 +190,28 @@ function DI.value_and_pullback!( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pullback!(f, tx, prep, backend, x, ty, contexts...) end ## Derivative -struct FastDifferentiationOneArgDerivativePrep{Y,E1,E1!} <: DI.DerivativePrep +struct FastDifferentiationOneArgDerivativePrep{SIG,Y,E1,E1!} <: DI.DerivativePrep{SIG} + _sig::Type{SIG} y_prototype::Y der_exe::E1 der_exe!::E1! end function DI.prepare_derivative( - f, ::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C} + f, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} + SIG = DI.signature(f, backend, x, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) 
x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -195,16 +223,17 @@ function DI.prepare_derivative( der_vec_var = derivative(y_vec_var, x_var) der_exe = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=false) der_exe! = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=true) - return FastDifferentiationOneArgDerivativePrep(y_prototype, der_exe, der_exe!) + return FastDifferentiationOneArgDerivativePrep(SIG, y_prototype, der_exe, der_exe!) end function DI.derivative( f, prep::FastDifferentiationOneArgDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) result = prep.der_exe(myvec(x), map(myvec_unwrap, contexts)...) if prep.y_prototype isa Number return only(result) @@ -217,10 +246,11 @@ function DI.derivative!( f, der, prep::FastDifferentiationOneArgDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.der_exe!(myvec(der), myvec(x), map(myvec_unwrap, contexts)...) return der end @@ -232,6 +262,7 @@ function DI.value_and_derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.derivative(f, prep, backend, x, contexts...) end @@ -244,20 +275,27 @@ function DI.value_and_derivative!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.derivative!(f, der, prep, backend, x, contexts...) end ## Gradient -struct FastDifferentiationOneArgGradientPrep{E1,E1!} <: DI.GradientPrep +struct FastDifferentiationOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG} + _sig::Type{SIG} jac_exe::E1 jac_exe!::E1! 
end function DI.prepare_gradient( - f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C} + f, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} + SIG = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = f(x_var, context_vars...) @@ -268,16 +306,17 @@ function DI.prepare_gradient( jac_var = jacobian(y_vec_var, x_vec_var) jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false) jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true) - return FastDifferentiationOneArgGradientPrep(jac_exe, jac_exe!) + return FastDifferentiationOneArgGradientPrep(SIG, jac_exe, jac_exe!) end function DI.gradient( f, prep::FastDifferentiationOneArgGradientPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) jac = prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) grad_vec = @view jac[1, :] return reshape(grad_vec, size(x)) @@ -287,10 +326,11 @@ function DI.gradient!( f, grad, prep::FastDifferentiationOneArgGradientPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.jac_exe!(reshape(grad, 1, length(grad)), myvec(x), map(myvec_unwrap, contexts)...) return grad end @@ -302,6 +342,7 @@ function DI.value_and_gradient( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.gradient(f, prep, backend, x, contexts...) end @@ -313,13 +354,15 @@ function DI.value_and_gradient!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.gradient!(f, grad, prep, backend, x, contexts...) 
end ## Jacobian -struct FastDifferentiationOneArgJacobianPrep{Y,E1,E1!} <: DI.JacobianPrep +struct FastDifferentiationOneArgJacobianPrep{SIG,Y,E1,E1!} <: DI.JacobianPrep{SIG} + _sig::Type{SIG} y_prototype::Y jac_exe::E1 jac_exe!::E1! @@ -329,8 +372,10 @@ function DI.prepare_jacobian( f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} + SIG = DI.signature(f, backend, x, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -346,16 +391,17 @@ function DI.prepare_jacobian( end jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false) jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true) - return FastDifferentiationOneArgJacobianPrep(y_prototype, jac_exe, jac_exe!) + return FastDifferentiationOneArgJacobianPrep(SIG, y_prototype, jac_exe, jac_exe!) end function DI.jacobian( f, prep::FastDifferentiationOneArgJacobianPrep, - ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) end @@ -363,10 +409,11 @@ function DI.jacobian!( f, jac, prep::FastDifferentiationOneArgJacobianPrep, - ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.jac_exe!(jac, myvec(x), map(myvec_unwrap, contexts)...) return jac end @@ -378,6 +425,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...) end @@ -389,14 +437,16 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian!(f, jac, prep, backend, x, contexts...) end ## Second derivative -struct FastDifferentiationAllocatingSecondDerivativePrep{Y,D,E2,E2!} <: - DI.SecondDerivativePrep +struct FastDifferentiationAllocatingSecondDerivativePrep{SIG,Y,D,E2,E2!} <: + DI.SecondDerivativePrep{SIG} + _sig::Type{SIG} y_prototype::Y derivative_prep::D der2_exe::E2 @@ -404,8 +454,13 @@ struct FastDifferentiationAllocatingSecondDerivativePrep{Y,D,E2,E2!} <: end function DI.prepare_second_derivative( - f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C} + f, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} + SIG = DI.signature(f, backend, x, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -421,17 +476,18 @@ function DI.prepare_second_derivative( derivative_prep = DI.prepare_derivative(f, backend, x, contexts...) return FastDifferentiationAllocatingSecondDerivativePrep( - y_prototype, derivative_prep, der2_exe, der2_exe! + SIG, y_prototype, derivative_prep, der2_exe, der2_exe! ) end function DI.second_derivative( f, prep::FastDifferentiationAllocatingSecondDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) result = prep.der2_exe(myvec(x), map(myvec_unwrap, contexts)...) 
if prep.y_prototype isa Number return only(result) @@ -444,10 +500,11 @@ function DI.second_derivative!( f, der2, prep::FastDifferentiationAllocatingSecondDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.der2_exe!(myvec(der2), myvec(x), map(myvec_unwrap, contexts)...) return der2 end @@ -459,6 +516,7 @@ function DI.value_derivative_and_second_derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, der = DI.value_and_derivative(f, prep.derivative_prep, backend, x, contexts...) der2 = DI.second_derivative(f, prep, backend, x, contexts...) return y, der, der2 @@ -473,6 +531,7 @@ function DI.value_derivative_and_second_derivative!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_derivative!(f, der, prep.derivative_prep, backend, x, contexts...) DI.second_derivative!(f, der2, prep, backend, x, contexts...) return y, der, der2 @@ -480,15 +539,22 @@ end ## HVP -struct FastDifferentiationHVPPrep{E2,E2!,E1} <: DI.HVPPrep +struct FastDifferentiationHVPPrep{SIG,E2,E2!,E1} <: DI.HVPPrep{SIG} + sig::Type{SIG} hvp_exe::E2 hvp_exe!::E2! gradient_prep::E1 end function DI.prepare_hvp( - f, backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} + SIG = DI.signature(f, backend, x, tx, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = f(x_var, context_vars...) @@ -504,17 +570,18 @@ function DI.prepare_hvp( ) gradient_prep = DI.prepare_gradient(f, backend, x, contexts...) 
- return FastDifferentiationHVPPrep(hvp_exe, hvp_exe!, gradient_prep) + return FastDifferentiationHVPPrep(SIG, hvp_exe, hvp_exe!, gradient_prep) end function DI.hvp( f, prep::FastDifferentiationHVPPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) tg = map(tx) do dx dg_vec = prep.hvp_exe(myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) return reshape(dg_vec, size(x)) @@ -526,11 +593,12 @@ function DI.hvp!( f, tg::NTuple, prep::FastDifferentiationHVPPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) for b in eachindex(tx, tg) dx, dg = tx[b], tg[b] prep.hvp_exe!(myvec(dg), myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) @@ -546,6 +614,7 @@ function DI.gradient_and_hvp( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) tg = DI.hvp(f, prep, backend, x, tx, contexts...) grad = DI.gradient(f, prep.gradient_prep, backend, x, contexts...) return grad, tg @@ -561,6 +630,7 @@ function DI.gradient_and_hvp!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) DI.hvp!(f, tg, prep, backend, x, tx, contexts...) DI.gradient!(f, grad, prep.gradient_prep, backend, x, contexts...) return grad, tg @@ -568,7 +638,8 @@ end ## Hessian -struct FastDifferentiationHessianPrep{G,E2,E2!} <: DI.HessianPrep +struct FastDifferentiationHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG} + _sig::Type{SIG} gradient_prep::G hess_exe::E2 hess_exe!::E2! 
@@ -578,8 +649,10 @@ function DI.prepare_hessian( f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} + SIG = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = f(x_var, context_vars...) @@ -596,7 +669,7 @@ function DI.prepare_hessian( hess_exe! = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=true) gradient_prep = DI.prepare_gradient(f, dense_ad(backend), x, contexts...) - return FastDifferentiationHessianPrep(gradient_prep, hess_exe, hess_exe!) + return FastDifferentiationHessianPrep(SIG, gradient_prep, hess_exe, hess_exe!) end function DI.hessian( @@ -606,6 +679,7 @@ function DI.hessian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return prep.hess_exe(myvec(x), map(myvec_unwrap, contexts)...) end @@ -617,6 +691,7 @@ function DI.hessian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.hess_exe!(hess, myvec(x), map(myvec_unwrap, contexts)...) return hess end @@ -628,6 +703,7 @@ function DI.value_gradient_and_hessian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, grad = DI.value_and_gradient( f, prep.gradient_prep, dense_ad(backend), x, contexts... ) @@ -644,6 +720,7 @@ function DI.value_gradient_and_hessian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_gradient!( f, grad, prep.gradient_prep, dense_ad(backend), x, contexts... 
) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl index 6133213e3..4e5e59c73 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl @@ -1,13 +1,21 @@ ## Pushforward -struct FastDifferentiationTwoArgPushforwardPrep{E1,E1!} <: DI.PushforwardPrep +struct FastDifferentiationTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} + _sig::Type{SIG} jvp_exe::E1 jvp_exe!::E1! end function DI.prepare_pushforward( - f!, y, ::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f!, + y, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} + SIG = DI.signature(f!, y, backend, x, tx, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = variablize(y, :y) @@ -23,18 +31,19 @@ function DI.prepare_pushforward( jvp_exe! = make_function( jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true ) - return FastDifferentiationTwoArgPushforwardPrep(jvp_exe, jvp_exe!) + return FastDifferentiationTwoArgPushforwardPrep(SIG, jvp_exe, jvp_exe!) end function DI.pushforward( f!, y, prep::FastDifferentiationTwoArgPushforwardPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) 
ty = map(tx) do dx reshape(prep.jvp_exe(myvec(x), myvec(dx), map(myvec_unwrap, contexts)...), size(y)) end @@ -46,11 +55,12 @@ function DI.pushforward!( y, ty::NTuple, prep::FastDifferentiationTwoArgPushforwardPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] prep.jvp_exe!(myvec(dy), myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) @@ -67,6 +77,7 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, ty @@ -82,6 +93,7 @@ function DI.value_and_pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, ty @@ -89,14 +101,22 @@ end ## Pullback -struct FastDifferentiationTwoArgPullbackPrep{E1,E1!} <: DI.PullbackPrep +struct FastDifferentiationTwoArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG} + _sig::Type{SIG} vjp_exe::E1 vjp_exe!::E1! end function DI.prepare_pullback( - f!, y, ::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C} + f!, + y, + backend::AutoFastDifferentiation, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} + SIG = DI.signature(f!, y, backend, x, ty, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = variablize(y, :y) @@ -112,18 +132,19 @@ function DI.prepare_pullback( vjp_exe! = make_function( vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true ) - return FastDifferentiationTwoArgPullbackPrep(vjp_exe, vjp_exe!) 
+ return FastDifferentiationTwoArgPullbackPrep(SIG, vjp_exe, vjp_exe!) end function DI.pullback( f!, y, prep::FastDifferentiationTwoArgPullbackPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) tx = map(ty) do dy result = prep.vjp_exe(myvec(x), myvec(dy), map(myvec_unwrap, contexts)...) if x isa Number @@ -140,11 +161,12 @@ function DI.pullback!( y, tx::NTuple, prep::FastDifferentiationTwoArgPullbackPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] prep.vjp_exe!(myvec(dx), myvec(x), myvec(dy), map(myvec_unwrap, contexts)...) @@ -161,6 +183,7 @@ function DI.value_and_pullback( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) tx = DI.pullback(f!, y, prep, backend, x, ty, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, tx @@ -176,6 +199,7 @@ function DI.value_and_pullback!( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) DI.pullback!(f!, y, tx, prep, backend, x, ty, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, tx @@ -183,14 +207,21 @@ end ## Derivative -struct FastDifferentiationTwoArgDerivativePrep{E1,E1!} <: DI.DerivativePrep +struct FastDifferentiationTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} + _sig::Type{SIG} der_exe::E1 der_exe!::E1! 
end function DI.prepare_derivative( - f!, y, ::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C} + f!, + y, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} + SIG = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = variablize(y, :y) @@ -202,17 +233,18 @@ function DI.prepare_derivative( der_vec_var = derivative(y_vec_var, x_var) der_exe = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=false) der_exe! = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=true) - return FastDifferentiationTwoArgDerivativePrep(der_exe, der_exe!) + return FastDifferentiationTwoArgDerivativePrep(SIG, der_exe, der_exe!) end function DI.value_and_derivative( f!, y, prep::FastDifferentiationTwoArgDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) der = reshape(prep.der_exe(myvec(x), map(myvec_unwrap, contexts)...), size(y)) return y, der @@ -223,10 +255,11 @@ function DI.value_and_derivative!( y, der, prep::FastDifferentiationTwoArgDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) prep.der_exe!(myvec(der), myvec(x), map(myvec_unwrap, contexts)...) return y, der @@ -236,10 +269,11 @@ function DI.derivative( f!, y, prep::FastDifferentiationTwoArgDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) 
der = reshape(prep.der_exe(myvec(x), map(myvec_unwrap, contexts)...), size(y)) return der end @@ -249,17 +283,19 @@ function DI.derivative!( y, der, prep::FastDifferentiationTwoArgDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) prep.der_exe!(myvec(der), myvec(x), map(myvec_unwrap, contexts)...) return der end ## Jacobian -struct FastDifferentiationTwoArgJacobianPrep{E1,E1!} <: DI.JacobianPrep +struct FastDifferentiationTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} + _sig::Type{SIG} jac_exe::E1 jac_exe!::E1! end @@ -269,8 +305,10 @@ function DI.prepare_jacobian( y, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} + SIG = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = variablize(y, :y) @@ -286,17 +324,18 @@ function DI.prepare_jacobian( end jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false) jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true) - return FastDifferentiationTwoArgJacobianPrep(jac_exe, jac_exe!) + return FastDifferentiationTwoArgJacobianPrep(SIG, jac_exe, jac_exe!) end function DI.value_and_jacobian( f!, y, prep::FastDifferentiationTwoArgJacobianPrep, - ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) jac = prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) 
return y, jac @@ -307,10 +346,11 @@ function DI.value_and_jacobian!( y, jac, prep::FastDifferentiationTwoArgJacobianPrep, - ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) prep.jac_exe!(jac, myvec(x), map(myvec_unwrap, contexts)...) return y, jac @@ -320,10 +360,11 @@ function DI.jacobian( f!, y, prep::FastDifferentiationTwoArgJacobianPrep, - ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) jac = prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) return jac end @@ -333,10 +374,11 @@ function DI.jacobian!( y, jac, prep::FastDifferentiationTwoArgJacobianPrep, - ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) prep.jac_exe!(jac, myvec(x), map(myvec_unwrap, contexts)...) 
return jac end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl index 05cf3a999..1024f2119 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl @@ -1,6 +1,7 @@ ## Pushforward -struct FiniteDiffOneArgPushforwardPrep{C,R,A,D} <: DI.PushforwardPrep +struct FiniteDiffOneArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} + _sig::Type{SIG} cache::C relstep::R absstep::A @@ -8,8 +9,14 @@ struct FiniteDiffOneArgPushforwardPrep{C,R,A,D} <: DI.PushforwardPrep end function DI.prepare_pushforward( - f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} + SIG = DI.signature(f, backend, x, tx, contexts...; strict) fc = DI.with_contexts(f, contexts...) y = fc(x) cache = if x isa Number || y isa Number @@ -28,7 +35,7 @@ function DI.prepare_pushforward( backend.relstep end dir = backend.dir - return FiniteDiffOneArgPushforwardPrep(cache, relstep, absstep, dir) + return FiniteDiffOneArgPushforwardPrep(SIG, cache, relstep, absstep, dir) end function DI.pushforward( @@ -39,6 +46,7 @@ function DI.pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep step(t::Number, dx) = f(x .+ t .* dx, map(DI.unwrap, contexts)...) ty = map(tx) do dx @@ -57,6 +65,7 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep step(t::Number, dx) = f(x .+ t .* dx, map(DI.unwrap, contexts)...) y = f(x, map(DI.unwrap, contexts)...) 
@@ -78,11 +87,12 @@ end function DI.pushforward( f, prep::FiniteDiffOneArgPushforwardPrep{<:JVPCache}, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) ty = map(tx) do dx @@ -94,11 +104,12 @@ end function DI.value_and_pushforward( f, prep::FiniteDiffOneArgPushforwardPrep{<:JVPCache}, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) y = fc(x) @@ -110,7 +121,8 @@ end ## Derivative -struct FiniteDiffOneArgDerivativePrep{C,R,A,D} <: DI.DerivativePrep +struct FiniteDiffOneArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} + _sig::Type{SIG} cache::C relstep::R absstep::A @@ -118,8 +130,9 @@ struct FiniteDiffOneArgDerivativePrep{C,R,A,D} <: DI.DerivativePrep end function DI.prepare_derivative( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} + f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {C} + SIG = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) y = fc(x) cache = if y isa Number @@ -139,7 +152,7 @@ function DI.prepare_derivative( backend.relstep end dir = backend.dir - return FiniteDiffOneArgDerivativePrep(cache, relstep, absstep, dir) + return FiniteDiffOneArgDerivativePrep(SIG, cache, relstep, absstep, dir) end ### Scalar to scalar @@ -151,6 +164,7 @@ function DI.derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) 
return finite_difference_derivative(fc, x, fdtype(backend); relstep, absstep, dir) @@ -163,6 +177,7 @@ function DI.value_and_derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) y = fc(x) @@ -179,10 +194,11 @@ end function DI.derivative( f, prep::FiniteDiffOneArgDerivativePrep{<:GradientCache}, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir) @@ -192,10 +208,11 @@ function DI.derivative!( f, der, prep::FiniteDiffOneArgDerivativePrep{<:GradientCache}, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return finite_difference_gradient!(der, fc, x, prep.cache; relstep, absstep, dir) @@ -204,10 +221,11 @@ end function DI.value_and_derivative( f, prep::FiniteDiffOneArgDerivativePrep{<:GradientCache}, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) (; relstep, absstep, dir) = prep y = fc(x) @@ -218,10 +236,11 @@ function DI.value_and_derivative!( f, der, prep::FiniteDiffOneArgDerivativePrep{<:GradientCache}, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) 
return ( @@ -231,7 +250,8 @@ end ## Gradient -struct FiniteDiffGradientPrep{C,R,A,D} <: DI.GradientPrep +struct FiniteDiffGradientPrep{SIG,C,R,A,D} <: DI.GradientPrep{SIG} + _sig::Type{SIG} cache::C relstep::R absstep::A @@ -239,8 +259,9 @@ struct FiniteDiffGradientPrep{C,R,A,D} <: DI.GradientPrep end function DI.prepare_gradient( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} + f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {C} + SIG = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) y = fc(x) df = zero(y) .* x @@ -256,16 +277,17 @@ function DI.prepare_gradient( backend.relstep end dir = backend.dir - return FiniteDiffGradientPrep(cache, relstep, absstep, dir) + return FiniteDiffGradientPrep(SIG, cache, relstep, absstep, dir) end function DI.gradient( f, prep::FiniteDiffGradientPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x::AbstractArray, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir) @@ -274,10 +296,11 @@ end function DI.value_and_gradient( f, prep::FiniteDiffGradientPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x::AbstractArray, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return fc(x), finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir) @@ -287,10 +310,11 @@ function DI.gradient!( f, grad, prep::FiniteDiffGradientPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x::AbstractArray, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) 
return finite_difference_gradient!(grad, fc, x, prep.cache; relstep, absstep, dir) @@ -300,10 +324,11 @@ function DI.value_and_gradient!( f, grad, prep::FiniteDiffGradientPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x::AbstractArray, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return ( @@ -313,7 +338,8 @@ end ## Jacobian -struct FiniteDiffOneArgJacobianPrep{C,R,A,D} <: DI.JacobianPrep +struct FiniteDiffOneArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} + _sig::Type{SIG} cache::C relstep::R absstep::A @@ -321,8 +347,9 @@ struct FiniteDiffOneArgJacobianPrep{C,R,A,D} <: DI.JacobianPrep end function DI.prepare_jacobian( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} + f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {C} + SIG = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) y = fc(x) x1 = similar(x) @@ -340,16 +367,17 @@ function DI.prepare_jacobian( backend.relstep end dir = backend.dir - return FiniteDiffOneArgJacobianPrep(cache, relstep, absstep, dir) + return FiniteDiffOneArgJacobianPrep(SIG, cache, relstep, absstep, dir) end function DI.jacobian( f, prep::FiniteDiffOneArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return finite_difference_jacobian(fc, x, prep.cache; relstep, absstep, dir) @@ -358,10 +386,11 @@ end function DI.value_and_jacobian( f, prep::FiniteDiffOneArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) 
(; relstep, absstep, dir) = prep y = fc(x) @@ -372,10 +401,11 @@ function DI.jacobian!( f, jac, prep::FiniteDiffOneArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return copyto!( @@ -390,10 +420,11 @@ function DI.value_and_jacobian!( f, jac, prep::FiniteDiffOneArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) y = fc(x) @@ -410,7 +441,8 @@ end ## Hessian -struct FiniteDiffHessianPrep{C1,C2,RG,AG,RH,AH} <: DI.HessianPrep +struct FiniteDiffHessianPrep{SIG,C1,C2,RG,AG,RH,AH} <: DI.HessianPrep{SIG} + _sig::Type{SIG} gradient_cache::C1 hessian_cache::C2 relstep_g::RG @@ -420,8 +452,9 @@ struct FiniteDiffHessianPrep{C1,C2,RG,AG,RH,AH} <: DI.HessianPrep end function DI.prepare_hessian( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} + f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {C} + SIG = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) y = fc(x) df = zero(y) .* x @@ -440,13 +473,18 @@ function DI.prepare_hessian( absstep_g = isnothing(backend.absstep) ? relstep_g : backend.absstep absstep_h = isnothing(backend.absstep) ? relstep_h : backend.absstep return FiniteDiffHessianPrep( - gradient_cache, hessian_cache, relstep_g, absstep_g, relstep_h, absstep_h + SIG, gradient_cache, hessian_cache, relstep_g, absstep_g, relstep_h, absstep_h ) end function DI.hessian( - f, prep::FiniteDiffHessianPrep, ::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} + f, + prep::FiniteDiffHessianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
(; relstep_h, absstep_h) = prep fc = DI.with_contexts(f, contexts...) return finite_difference_hessian( @@ -458,10 +496,11 @@ function DI.hessian!( f, hess, prep::FiniteDiffHessianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep_h, absstep_h) = prep fc = DI.with_contexts(f, contexts...) return finite_difference_hessian!( @@ -470,8 +509,13 @@ function DI.hessian!( end function DI.value_gradient_and_hessian( - f, prep::FiniteDiffHessianPrep, ::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} + f, + prep::FiniteDiffHessianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep_g, absstep_g, relstep_h, absstep_h) = prep fc = DI.with_contexts(f, contexts...) grad = finite_difference_gradient( @@ -488,10 +532,11 @@ function DI.value_gradient_and_hessian!( grad, hess, prep::FiniteDiffHessianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep_g, absstep_g, relstep_h, absstep_h) = prep fc = DI.with_contexts(f, contexts...) 
finite_difference_gradient!( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl index abd76ed68..b36adf5d7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl @@ -1,6 +1,7 @@ ## Pushforward -struct FiniteDiffTwoArgPushforwardPrep{C,R,A,D} <: DI.PushforwardPrep +struct FiniteDiffTwoArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} + _sig::Type{SIG} cache::C relstep::R absstep::A @@ -8,8 +9,15 @@ struct FiniteDiffTwoArgPushforwardPrep{C,R,A,D} <: DI.PushforwardPrep end function DI.prepare_pushforward( - f!, y, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f!, + y, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} + SIG = DI.signature(f!, y, backend, x, tx, contexts...; strict) cache = if x isa Number nothing else @@ -26,7 +34,7 @@ function DI.prepare_pushforward( backend.relstep end dir = backend.dir - return FiniteDiffTwoArgPushforwardPrep(cache, relstep, absstep, dir) + return FiniteDiffTwoArgPushforwardPrep(SIG, cache, relstep, absstep, dir) end function DI.value_and_pushforward( @@ -38,6 +46,7 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep function step(t::Number, dx) new_y = similar(y) @@ -64,11 +73,12 @@ function DI.pushforward( f!, y, prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache}, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) 
ty = map(tx) do dx @@ -83,11 +93,12 @@ function DI.value_and_pushforward( f!, y, prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache}, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) ty = map(tx) do dx @@ -104,11 +115,12 @@ function DI.pushforward!( y, ty::NTuple, prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache}, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) for b in eachindex(tx, ty) @@ -123,11 +135,12 @@ function DI.value_and_pushforward!( y, ty::NTuple, prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache}, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) 
for b in eachindex(tx, ty) @@ -140,7 +153,8 @@ end ## Derivative -struct FiniteDiffTwoArgDerivativePrep{C,R,A,D} <: DI.DerivativePrep +struct FiniteDiffTwoArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} + _sig::Type{SIG} cache::C relstep::R absstep::A @@ -148,8 +162,9 @@ struct FiniteDiffTwoArgDerivativePrep{C,R,A,D} <: DI.DerivativePrep end function DI.prepare_derivative( - f!, y, backend::AutoFiniteDiff, x, ::Vararg{DI.Context,C} + f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {C} + SIG = DI.signature(f!, y, backend, x, contexts...; strict) df = similar(y) cache = GradientCache(df, x, fdtype(backend), eltype(y), FUNCTION_INPLACE) relstep = if isnothing(backend.relstep) @@ -163,7 +178,7 @@ function DI.prepare_derivative( backend.relstep end dir = backend.dir - return FiniteDiffTwoArgDerivativePrep(cache, relstep, absstep, dir) + return FiniteDiffTwoArgDerivativePrep(SIG, cache, relstep, absstep, dir) end function DI.prepare!_derivative( @@ -174,6 +189,7 @@ function DI.prepare!_derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, old_prep, backend, x, contexts...) if y isa Vector (; cache) = old_prep cache.fx isa Union{Number,Nothing} || resize!(cache.fx, length(y)) @@ -182,7 +198,9 @@ function DI.prepare!_derivative( cache.c3 isa Union{Number,Nothing} || resize!(cache.c3, length(y)) return old_prep else - return DI.prepare_derivative(f!, y, backend, x, contexts...) + return DI.prepare_derivative( + f!, y, backend, x, contexts...; strict=DI.is_strict(old_prep) + ) end end @@ -190,10 +208,11 @@ function DI.value_and_derivative( f!, y, prep::FiniteDiffTwoArgDerivativePrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) 
fc!(y, x) @@ -206,10 +225,11 @@ function DI.value_and_derivative!( y, der, prep::FiniteDiffTwoArgDerivativePrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) fc!(y, x) @@ -221,10 +241,11 @@ function DI.derivative( f!, y, prep::FiniteDiffTwoArgDerivativePrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) fc!(y, x) @@ -237,10 +258,11 @@ function DI.derivative!( y, der, prep::FiniteDiffTwoArgDerivativePrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) finite_difference_gradient!(der, fc!, x, prep.cache; relstep, absstep, dir) @@ -249,7 +271,8 @@ end ## Jacobian -struct FiniteDiffTwoArgJacobianPrep{C,R,A,D} <: DI.JacobianPrep +struct FiniteDiffTwoArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} + _sig::Type{SIG} cache::C relstep::R absstep::A @@ -257,8 +280,9 @@ struct FiniteDiffTwoArgJacobianPrep{C,R,A,D} <: DI.JacobianPrep end function DI.prepare_jacobian( - f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} + f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {C} + SIG = DI.signature(f!, y, backend, x, contexts...; strict) x1 = similar(x) fx = similar(y) fx1 = similar(y) @@ -274,7 +298,7 @@ function DI.prepare_jacobian( backend.relstep end dir = backend.dir - return FiniteDiffTwoArgJacobianPrep(cache, relstep, absstep, dir) + return FiniteDiffTwoArgJacobianPrep(SIG, cache, relstep, absstep, dir) end function DI.prepare!_jacobian( @@ -285,6 +309,7 @@ function DI.prepare!_jacobian( x, 
contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, old_prep, backend, x, contexts...) if x isa Vector && y isa Vector (; cache) = old_prep cache.x1 isa Union{Number,Nothing} || resize!(cache.x1, length(x)) @@ -295,7 +320,9 @@ function DI.prepare!_jacobian( cache.sparsity = nothing return old_prep else - return DI.prepare_jacobian(f!, y, backend, x, contexts...) + return DI.prepare_jacobian( + f!, y, backend, x, contexts...; strict=DI.is_strict(old_prep) + ) end end @@ -303,10 +330,11 @@ function DI.value_and_jacobian( f!, y, prep::FiniteDiffTwoArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) jac = similar(y, length(y), length(x)) @@ -320,10 +348,11 @@ function DI.value_and_jacobian!( y, jac, prep::FiniteDiffTwoArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir) @@ -335,10 +364,11 @@ function DI.jacobian( f!, y, prep::FiniteDiffTwoArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) jac = similar(y, length(y), length(x)) @@ -351,10 +381,11 @@ function DI.jacobian!( y, jac, prep::FiniteDiffTwoArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) 
finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl index f2f692002..fca32bdc8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl @@ -12,19 +12,26 @@ DI.inner_preparation_behavior(::AutoFiniteDifferences) = DI.PrepareInnerSimple() ## Pushforward function DI.prepare_pushforward( - f, ::AutoFiniteDifferences, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f, + backend::AutoFiniteDifferences, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} - return DI.NoPushforwardPrep() + SIG = DI.signature(f, backend, x, tx, contexts...; strict) + return DI.NoPushforwardPrep{SIG}() end function DI.pushforward( f, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoFiniteDifferences, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) fc = DI.with_contexts(f, contexts...) ty = map(tx) do dx jvp(backend.fdm, fc, (x, dx)) @@ -40,6 +47,7 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pushforward(f, prep, backend, x, tx, contexts...) 
end @@ -47,19 +55,26 @@ end ## Pullback function DI.prepare_pullback( - f, ::AutoFiniteDifferences, x, ty::NTuple, contexts::Vararg{DI.Context,C} + f, + backend::AutoFiniteDifferences, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} - return DI.NoPullbackPrep() + SIG = DI.signature(f, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep{SIG}() end function DI.pullback( f, - ::DI.NoPullbackPrep, + prep::DI.NoPullbackPrep, backend::AutoFiniteDifferences, x, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) fc = DI.with_contexts(f, contexts...) tx = map(ty) do dy only(j′vp(backend.fdm, fc, dy, x)) @@ -75,6 +90,7 @@ function DI.value_and_pullback( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pullback(f, prep, backend, x, ty, contexts...) end @@ -82,18 +98,20 @@ end ## Gradient function DI.prepare_gradient( - f, ::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C} + f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {C} - return DI.NoGradientPrep() + SIG = DI.signature(f, backend, x, contexts...; strict) + return DI.NoGradientPrep{SIG}() end function DI.gradient( f, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) return only(grad(backend.fdm, fc, x)) end @@ -105,6 +123,7 @@ function DI.value_and_gradient( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.gradient(f, prep, backend, x, contexts...) end @@ -116,6 +135,7 @@ function DI.gradient!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end @@ -127,6 +147,7 @@ function DI.value_and_gradient!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) end @@ -134,18 +155,20 @@ end ## Jacobian function DI.prepare_jacobian( - f, ::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C} + f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {C} - return DI.NoJacobianPrep() + SIG = DI.signature(f, backend, x, contexts...; strict) + return DI.NoJacobianPrep{SIG}() end function DI.jacobian( f, - ::DI.NoJacobianPrep, + prep::DI.NoJacobianPrep, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) return only(jacobian(backend.fdm, fc, x)) end @@ -157,6 +180,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...) end @@ -168,6 +192,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...)) end @@ -179,6 +204,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...) 
return y, copyto!(jac, new_jac) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index 2e46031a0..996c659c6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -33,7 +33,6 @@ DI.inner_preparation_behavior(::AutoForwardDiff) = DI.PrepareInnerOverload() include("utils.jl") include("onearg.jl") include("twoarg.jl") -# include("secondorder.jl") include("differentiate_with.jl") include("misc.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 43c0f44be..35133ae8e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -60,14 +60,20 @@ end ### Prepared -struct ForwardDiffOneArgPushforwardPrep{T,X,CD} <: DI.PushforwardPrep +struct ForwardDiffOneArgPushforwardPrep{SIG,T,X,CD} <: DI.PushforwardPrep{SIG} xdual_tmp::X contexts_dual::CD end function DI.prepare_pushforward( - f::F, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C} + f::F, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {F,B,C} + SIG = DI.signature(f, backend, x, tx, contexts...; strict) T = tag_type(f, backend, x) if DI.ismutable_array(x) xdual_tmp = make_dual_similar(T, x, tx) @@ -75,7 +81,7 @@ function DI.prepare_pushforward( xdual_tmp = nothing end contexts_dual = translate_toprep(Dual{T,eltype(x),B}, contexts) - return ForwardDiffOneArgPushforwardPrep{T,typeof(xdual_tmp),typeof(contexts_dual)}( + 
return ForwardDiffOneArgPushforwardPrep{SIG,T,typeof(xdual_tmp),typeof(contexts_dual)}( xdual_tmp, contexts_dual ) end @@ -114,11 +120,12 @@ end function DI.value_and_pushforward( f::F, prep::ForwardDiffOneArgPushforwardPrep{T}, - ::AutoForwardDiff, + backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,T,B,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) y = myvalue(T, ydual) ty = mypartials(T, Val(B), ydual) @@ -129,11 +136,12 @@ function DI.value_and_pushforward!( f::F, ty::NTuple, prep::ForwardDiffOneArgPushforwardPrep{T}, - ::AutoForwardDiff, + backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,T,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) y = myvalue(T, ydual) mypartials!(T, ty, ydual) @@ -143,11 +151,12 @@ end function DI.pushforward( f::F, prep::ForwardDiffOneArgPushforwardPrep{T}, - ::AutoForwardDiff, + backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,T,B,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) ty = mypartials(T, Val(B), ydual) return ty @@ -157,11 +166,12 @@ function DI.pushforward!( f::F, ty::NTuple, prep::ForwardDiffOneArgPushforwardPrep{T}, - ::AutoForwardDiff, + backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,T,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) 
mypartials!(T, ty, ydual) return ty @@ -169,7 +179,7 @@ end ## Derivative -struct ForwardDiffOneArgDerivativePrep{E} <: DI.DerivativePrep +struct ForwardDiffOneArgDerivativePrep{SIG,E} <: DI.DerivativePrep{SIG} pushforward_prep::E end @@ -205,10 +215,11 @@ end ### Prepared function DI.prepare_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {F,C} - pushforward_prep = DI.prepare_pushforward(f, backend, x, (one(x),), contexts...) - return ForwardDiffOneArgDerivativePrep(pushforward_prep) + SIG = DI.signature(f, backend, x, contexts...; strict) + pushforward_prep = DI.prepare_pushforward(f, backend, x, (one(x),), contexts...; strict) + return ForwardDiffOneArgDerivativePrep{SIG,typeof(pushforward_prep)}(pushforward_prep) end function DI.value_and_derivative( @@ -218,6 +229,7 @@ function DI.value_and_derivative( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) y, ty = DI.value_and_pushforward( f, prep.pushforward_prep, backend, x, (one(x),), contexts... ) @@ -232,6 +244,7 @@ function DI.value_and_derivative!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_pushforward!( f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts... ) @@ -245,6 +258,7 @@ function DI.derivative( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) return only( DI.pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...) ) @@ -258,6 +272,7 @@ function DI.derivative!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) DI.pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...) 
return der end @@ -338,19 +353,25 @@ end ### Prepared -struct ForwardDiffGradientPrep{C,CD} <: DI.GradientPrep +struct ForwardDiffGradientPrep{SIG,C,CD} <: DI.GradientPrep{SIG} + _sig::Type{SIG} config::C contexts_dual::CD end function DI.prepare_gradient( - f::F, backend::AutoForwardDiff, x::AbstractArray, contexts::Vararg{DI.Context,C} + f::F, + backend::AutoForwardDiff, + x::AbstractArray, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {F,C} + SIG = DI.signature(f, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f, backend, x) config = GradientConfig(nothing, x, chunk, tag) contexts_dual = translate_toprep(dual_type(config), contexts) - return ForwardDiffGradientPrep(config, contexts_dual) + return ForwardDiffGradientPrep(SIG, config, contexts_dual) end function DI.value_and_gradient!( @@ -361,6 +382,7 @@ function DI.value_and_gradient!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) result = DiffResult(zero(eltype(x)), (grad,)) @@ -381,6 +403,7 @@ function DI.value_and_gradient( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) result = GradientResult(x) @@ -400,6 +423,7 @@ function DI.gradient!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -416,6 +440,7 @@ function DI.gradient( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) 
CHK = tag_type(backend) === Nothing @@ -500,19 +525,21 @@ end ### Prepared -struct ForwardDiffOneArgJacobianPrep{C,CD} <: DI.JacobianPrep +struct ForwardDiffOneArgJacobianPrep{SIG,C,CD} <: DI.JacobianPrep{SIG} + _sig::Type{SIG} config::C contexts_dual::CD end function DI.prepare_jacobian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {F,C} + SIG = DI.signature(f, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f, backend, x) config = JacobianConfig(nothing, x, chunk, tag) contexts_dual = translate_toprep(dual_type(config), contexts) - return ForwardDiffOneArgJacobianPrep(config, contexts_dual) + return ForwardDiffOneArgJacobianPrep(SIG, config, contexts_dual) end function DI.value_and_jacobian!( @@ -523,6 +550,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) y = fc(x) @@ -544,6 +572,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -561,6 +590,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -577,6 +607,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) 
CHK = tag_type(backend) === Nothing @@ -589,18 +620,20 @@ end ## Second derivative function DI.prepare_second_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {F,C} - return DI.NoSecondDerivativePrep() + SIG = DI.signature(f, backend, x, contexts...; strict) + return DI.NoSecondDerivativePrep{SIG}() end function DI.second_derivative( f::F, - ::DI.NoSecondDerivativePrep, + prep::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) T2 = tag_type(f, backend, xdual) @@ -613,11 +646,12 @@ end function DI.second_derivative!( f::F, der2, - ::DI.NoSecondDerivativePrep, + prep::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) T2 = tag_type(f, backend, xdual) @@ -629,11 +663,12 @@ end function DI.value_derivative_and_second_derivative( f::F, - ::DI.NoSecondDerivativePrep, + prep::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) T2 = tag_type(f, backend, xdual) @@ -650,11 +685,12 @@ function DI.value_derivative_and_second_derivative!( f::F, der, der2, - ::DI.NoSecondDerivativePrep, + prep::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) 
T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) T2 = tag_type(f, backend, xdual) @@ -667,67 +703,6 @@ function DI.value_derivative_and_second_derivative!( return y, der, der2 end -## HVP - -#= -function DI.prepare_hvp( - f::F, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C} -) where {F,C} - return DI.prepare_hvp(f, DI.SecondOrder(backend, backend), x, tx, contexts...) -end - -function DI.hvp( - f::F, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - return DI.hvp(f, prep, DI.SecondOrder(backend, backend), x, tx, contexts...) -end - -function DI.hvp!( - f::F, - tg::NTuple, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - return DI.hvp!(f, tg, prep, DI.SecondOrder(backend, backend), x, tx, contexts...) -end - -function DI.gradient_and_hvp( - f::F, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - return DI.gradient_and_hvp( - f, prep, DI.SecondOrder(backend, backend), x, tx, contexts... - ) -end - -function DI.gradient_and_hvp!( - f::F, - grad, - tg::NTuple, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - return DI.gradient_and_hvp!( - f, grad, tg, prep, DI.SecondOrder(backend, backend), x, tx, contexts... 
- ) -end -=# - ## Hessian ### Unprepared, only when chunk size and tag are not specified @@ -810,22 +785,24 @@ end ### Prepared -struct ForwardDiffHessianPrep{C1,C2,CD} <: DI.HessianPrep +struct ForwardDiffHessianPrep{SIG,C1,C2,CD} <: DI.HessianPrep{SIG} + _sig::Type{SIG} array_config::C1 result_config::C2 contexts_dual::CD end function DI.prepare_hessian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {F,C} + SIG = DI.signature(f, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f, backend, x) result = HessianResult(x) array_config = HessianConfig(nothing, x, chunk, tag) result_config = HessianConfig(nothing, result, x, chunk, tag) contexts_dual = translate_toprep(dual_type(array_config), contexts) - return ForwardDiffHessianPrep(array_config, result_config, contexts_dual) + return ForwardDiffHessianPrep(SIG, array_config, result_config, contexts_dual) end function DI.hessian!( @@ -836,6 +813,7 @@ function DI.hessian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -852,6 +830,7 @@ function DI.hessian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -870,6 +849,7 @@ function DI.value_gradient_and_hessian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) 
result = DiffResult(one(eltype(x)), (grad, hess)) @@ -891,6 +871,7 @@ function DI.value_gradient_and_hessian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) result = HessianResult(x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl deleted file mode 100644 index d01efd074..000000000 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl +++ /dev/null @@ -1,144 +0,0 @@ -struct ForwardDiffOverSomethingHVPPrep{E1<:DI.GradientPrep,E2<:DI.PushforwardPrep} <: - DI.HVPPrep - inner_gradient_prep::E1 - outer_pushforward_prep::E2 -end - -function DI.prepare_hvp( - f::F, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - T = tag_type(DI.shuffled_gradient, DI.outer(backend), x) - xdual = make_dual(T, x, tx) - inner_gradient_prep = DI.prepare_gradient(f, DI.inner(backend), xdual, contexts...) - rewrap = DI.Rewrap(contexts...) - new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - outer_pushforward_prep = DI.prepare_pushforward( - DI.shuffled_gradient, DI.outer(backend), x, tx, new_contexts... - ) - return ForwardDiffOverSomethingHVPPrep(inner_gradient_prep, outer_pushforward_prep) -end - -function DI.hvp( - f::F, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep - rewrap = DI.Rewrap(contexts...) 
- new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - return DI.pushforward( - DI.shuffled_gradient, - outer_pushforward_prep, - DI.outer(backend), - x, - tx, - new_contexts..., - ) -end - -function DI.hvp!( - f::F, - tg::NTuple, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep - rewrap = DI.Rewrap(contexts...) - new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - return DI.pushforward!( - DI.shuffled_gradient, - tg, - outer_pushforward_prep, - DI.outer(backend), - x, - tx, - new_contexts..., - ) - return tg -end - -function DI.gradient_and_hvp( - f::F, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep - rewrap = DI.Rewrap(contexts...) - new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - return DI.value_and_pushforward( - DI.shuffled_gradient, - outer_pushforward_prep, - DI.outer(backend), - x, - tx, - new_contexts..., - ) -end - -function DI.gradient_and_hvp!( - f::F, - grad, - tg::NTuple, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep - rewrap = DI.Rewrap(contexts...) 
- new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - new_grad, _ = DI.value_and_pushforward!( - DI.shuffled_gradient, - tg, - outer_pushforward_prep, - DI.outer(backend), - x, - tx, - new_contexts..., - ) - return copyto!(grad, new_grad), tg -end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index f52ae69d9..47fb69b6d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -1,20 +1,27 @@ ## Pushforward -struct ForwardDiffTwoArgPushforwardPrep{T,X,Y,CD} <: DI.PushforwardPrep +struct ForwardDiffTwoArgPushforwardPrep{SIG,T,X,Y,CD} <: DI.PushforwardPrep{SIG} xdual_tmp::X ydual_tmp::Y contexts_dual::CD end function DI.prepare_pushforward( - f!::F, y, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C} + f!::F, + y, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {F,B,C} + SIG = DI.signature(f!, y, backend, x, tx, contexts...; strict) T = tag_type(f!, backend, x) xdual_tmp = make_dual_similar(T, x, tx) ydual_tmp = make_dual_similar(T, y, tx) # tx only for batch size contexts_dual = translate_toprep(eltype(xdual_tmp), contexts) return ForwardDiffTwoArgPushforwardPrep{ - T,typeof(xdual_tmp),typeof(ydual_tmp),typeof(contexts_dual) + SIG,T,typeof(xdual_tmp),typeof(ydual_tmp),typeof(contexts_dual) }( xdual_tmp, ydual_tmp, contexts_dual ) @@ -54,11 +61,12 @@ function DI.value_and_pushforward( f!::F, y, prep::ForwardDiffTwoArgPushforwardPrep{T}, - ::AutoForwardDiff, + backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,T,B,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) 
ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) myvalue!(T, y, ydual_tmp) ty = mypartials(T, Val(B), ydual_tmp) @@ -70,11 +78,12 @@ function DI.value_and_pushforward!( y, ty::NTuple, prep::ForwardDiffTwoArgPushforwardPrep{T}, - ::AutoForwardDiff, + backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,T,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) myvalue!(T, y, ydual_tmp) mypartials!(T, ty, ydual_tmp) @@ -85,11 +94,12 @@ function DI.pushforward( f!::F, y, prep::ForwardDiffTwoArgPushforwardPrep{T}, - ::AutoForwardDiff, + backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,T,B,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) ty = mypartials(T, Val(B), ydual_tmp) return ty @@ -100,11 +110,12 @@ function DI.pushforward!( y, ty::NTuple, prep::ForwardDiffTwoArgPushforwardPrep{T}, - ::AutoForwardDiff, + backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,T,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) 
mypartials!(T, ty, ydual_tmp) return ty @@ -168,18 +179,25 @@ end ### Prepared -struct ForwardDiffTwoArgDerivativePrep{C,CD} <: DI.DerivativePrep +struct ForwardDiffTwoArgDerivativePrep{SIG,C,CD} <: DI.DerivativePrep{SIG} + _sig::Type{SIG} config::C contexts_dual::CD end function DI.prepare_derivative( - f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} + f!::F, + y, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {F,C} + SIG = DI.signature(f!, y, backend, x, contexts...; strict) tag = get_tag(f!, backend, x) config = DerivativeConfig(nothing, y, x, tag) contexts_dual = translate_toprep(dual_type(config), contexts) - return ForwardDiffTwoArgDerivativePrep(config, contexts_dual) + return ForwardDiffTwoArgDerivativePrep(SIG, config, contexts_dual) end function DI.prepare!_derivative( @@ -190,12 +208,15 @@ function DI.prepare!_derivative( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} + DI.check_prep(f!, y, old_prep, backend, x, contexts...) if y isa Vector (; config) = old_prep resize!(config.duals, length(y)) return old_prep else - return DI.prepare_derivative(f!, y, backend, x, contexts...) + return DI.prepare_derivative( + f!, y, backend, x, contexts...; strict=DI.is_strict(old_prep) + ) end end @@ -207,6 +228,7 @@ function DI.value_and_derivative( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) result = MutableDiffResult(y, (similar(y),)) @@ -227,6 +249,7 @@ function DI.value_and_derivative!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) 
result = MutableDiffResult(y, (der,)) @@ -246,6 +269,7 @@ function DI.derivative( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -264,6 +288,7 @@ function DI.derivative!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -348,19 +373,26 @@ end ### Prepared -struct ForwardDiffTwoArgJacobianPrep{C,CD} <: DI.JacobianPrep +struct ForwardDiffTwoArgJacobianPrep{SIG,C,CD} <: DI.JacobianPrep{SIG} + _sig::Type{SIG} config::C contexts_dual::CD end function DI.prepare_jacobian( - f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} + f!::F, + y, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {F,C} + SIG = DI.signature(f!, y, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f!, backend, x) config = JacobianConfig(nothing, y, x, chunk, tag) contexts_dual = translate_toprep(dual_type(config), contexts) - return ForwardDiffTwoArgJacobianPrep(config, contexts_dual) + return ForwardDiffTwoArgJacobianPrep(SIG, config, contexts_dual) end function DI.prepare!_jacobian( @@ -371,6 +403,7 @@ function DI.prepare!_jacobian( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} + DI.check_prep(f!, y, old_prep, backend, x, contexts...) if x isa Vector && y isa Vector (; config) = old_prep (yduals, xduals) = config.duals @@ -378,7 +411,9 @@ function DI.prepare!_jacobian( resize!(xduals, length(x)) return old_prep else - return DI.prepare_jacobian(f!, y, backend, x, contexts...) 
+ return DI.prepare_jacobian( + f!, y, backend, x, contexts...; strict=DI.is_strict(old_prep) + ) end end @@ -390,6 +425,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) jac = similar(y, length(y), length(x)) @@ -411,6 +447,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) result = MutableDiffResult(y, (jac,)) @@ -430,6 +467,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -448,6 +486,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) 
CHK = tag_type(backend) === Nothing diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index c094bd14a..2dd4409cc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -19,7 +19,7 @@ choose_chunk(::AutoForwardDiff{chunksize}, x) where {chunksize} = Chunk{chunksiz get_tag(f, backend::AutoForwardDiff, x) = backend.tag -function get_tag(f::F, ::AutoForwardDiff{chunksize,Nothing}, x) where {F,chunksize} +function get_tag(f::F, backend::AutoForwardDiff{chunksize,Nothing}, x) where {F,chunksize} return Tag(f, eltype(x)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl index ca7937d63..59f081340 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl @@ -2,14 +2,20 @@ # Contains either a single pre-allocated initial TPS # or a vector of pre-allocated TPSs. -struct GTPSAOneArgPushforwardPrep{X} <: DI.PushforwardPrep +struct GTPSAOneArgPushforwardPrep{SIG,X} <: DI.PushforwardPrep{SIG} + _sig::Type{SIG} xt::X end function DI.prepare_pushforward( - ::F, backend::AutoGTPSA{D}, x, tx::NTuple, ::Vararg{DI.Constant,C} + f::F, + backend::AutoGTPSA{D}, + x, + tx::NTuple, + contexts::Vararg{DI.Constant,C}; + strict::Bool=false, ) where {F,D,C} - + SIG = DI.signature(f, backend, x, tx, contexts...; strict) # For pushforward/JVP, we only actually need 1 single variable (in the GTPSA sense) # because we even if we did multiple we will add up the derivatives of each at the end. 
if D != Nothing @@ -25,18 +31,19 @@ function DI.prepare_pushforward( for i in eachindex(xt) xt[i] = TPS{promote_type(eltype(first(tx)), eltype(x), Float64)}(; use=d) end - return GTPSAOneArgPushforwardPrep(xt) + return GTPSAOneArgPushforwardPrep(SIG, xt) end end function DI.pushforward( f, prep::GTPSAOneArgPushforwardPrep, - ::AutoGTPSA, + backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) fc = DI.with_contexts(f, contexts...) ty = map(tx) do dx foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx) @@ -55,11 +62,12 @@ function DI.pushforward!( f, ty::NTuple, prep::GTPSAOneArgPushforwardPrep, - ::AutoGTPSA, + backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) fc = DI.with_contexts(f, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] @@ -78,6 +86,7 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) fc = DI.with_contexts(f, contexts...) ty = DI.pushforward(fc, prep, backend, x, tx) y = fc(x) # TO-DO: optimize @@ -93,6 +102,7 @@ function DI.value_and_pushforward!( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) fc = DI.with_contexts(f, contexts...) DI.pushforward!(fc, ty, prep, backend, x, tx) y = fc(x) # TO-DO: optimize @@ -101,14 +111,16 @@ end ## Gradient # Contains a vector of pre-allocated TPSs. 
-struct GTPSAOneArgGradientPrep{X} <: DI.GradientPrep +struct GTPSAOneArgGradientPrep{SIG,X} <: DI.GradientPrep{SIG} + _sig::Type{SIG} xt::X end # Unlike JVP, this requires us to use all variables function DI.prepare_gradient( - f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} + f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Bool=false ) where {D,C} + SIG = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor else @@ -121,12 +133,13 @@ function DI.prepare_gradient( xt[i][j] = 1 j += 1 end - return GTPSAOneArgGradientPrep(xt) + return GTPSAOneArgGradientPrep(SIG, xt) end function DI.gradient( - f, prep::GTPSAOneArgGradientPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, prep::GTPSAOneArgGradientPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part (slopes set in prepare) fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -136,8 +149,14 @@ function DI.gradient( end function DI.gradient!( - f, grad, prep::GTPSAOneArgGradientPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, + grad, + prep::GTPSAOneArgGradientPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -146,8 +165,9 @@ function DI.gradient!( end function DI.value_and_gradient( - f, prep::GTPSAOneArgGradientPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, prep::GTPSAOneArgGradientPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part (slopes set in prepare) fc = DI.with_contexts(f, contexts...) 
yt = fc(prep.xt) @@ -157,8 +177,14 @@ function DI.value_and_gradient( end function DI.value_and_gradient!( - f, grad, prep::GTPSAOneArgGradientPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, + grad, + prep::GTPSAOneArgGradientPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part (slopes set in prepare) fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -168,14 +194,16 @@ end ## Jacobian # Contains a vector of pre-allocated TPSs -struct GTPSAOneArgJacobianPrep{X} <: DI.JacobianPrep +struct GTPSAOneArgJacobianPrep{SIG,X} <: DI.JacobianPrep{SIG} + _sig::Type{SIG} xt::X end # To materialize the entire Jacobian we use all variables function DI.prepare_jacobian( - f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} + f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Bool=false ) where {D,C} + SIG = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor else @@ -190,12 +218,13 @@ function DI.prepare_jacobian( xt[i][j] = 1 j += 1 end - return GTPSAOneArgJacobianPrep(xt) + return GTPSAOneArgJacobianPrep(SIG, xt) end function DI.jacobian( - f, prep::GTPSAOneArgJacobianPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, prep::GTPSAOneArgJacobianPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -205,8 +234,14 @@ function DI.jacobian( end function DI.jacobian!( - f, jac, prep::GTPSAOneArgJacobianPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, + jac, + prep::GTPSAOneArgJacobianPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -215,8 +250,9 @@ function DI.jacobian!( end function DI.value_and_jacobian( - f, prep::GTPSAOneArgJacobianPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, prep::GTPSAOneArgJacobianPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -227,8 +263,14 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f, jac, prep::GTPSAOneArgJacobianPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, + jac, + prep::GTPSAOneArgJacobianPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -239,13 +281,15 @@ end ## Second derivative # Contains single pre-allocated TPS -struct GTPSAOneArgSecondDerivativePrep{X} <: DI.SecondDerivativePrep +struct GTPSAOneArgSecondDerivativePrep{SIG,X} <: DI.SecondDerivativePrep{SIG} + _sig::Type{SIG} xt::X end function DI.prepare_second_derivative( - f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} + f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Bool=false ) where {D,C} + SIG = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor else @@ -253,7 +297,7 @@ function DI.prepare_second_derivative( end xt = TPS{promote_type(typeof(x), Float64)}(; use=d) xt[1] = 1 # Set slope - return GTPSAOneArgSecondDerivativePrep(xt) + return GTPSAOneArgSecondDerivativePrep(SIG, xt) end function DI.second_derivative( @@ -263,6 +307,7 @@ function DI.second_derivative( x, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) 
prep.xt[0] = x fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -291,6 +336,7 @@ function DI.second_derivative!( x, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) prep.xt[0] = x fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -312,6 +358,7 @@ function DI.value_derivative_and_second_derivative( x, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) prep.xt[0] = x fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -343,6 +390,7 @@ function DI.value_derivative_and_second_derivative!( x, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) prep.xt[0] = x fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -362,14 +410,16 @@ end ## Hessian # Stores allocated array of TPS and an array for the monomial coefficient # indexing in GTPSA.cycle! (which is used if a Descriptor is specified) -struct GTPSAOneArgHessianPrep{X,M} <: DI.HessianPrep +struct GTPSAOneArgHessianPrep{SIG,X,M} <: DI.HessianPrep{SIG} + _sig::Type{SIG} xt::X m::M end function DI.prepare_hessian( - f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} + f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Bool=false ) where {D,C} + SIG = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor m = Vector{UInt8}(undef, length(x)) @@ -392,12 +442,17 @@ function DI.prepare_hessian( j += 1 end - return GTPSAOneArgHessianPrep(xt, m) + return GTPSAOneArgHessianPrep(SIG, xt, m) end function DI.hessian( - f, prep::GTPSAOneArgHessianPrep, ::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} + f, + prep::GTPSAOneArgHessianPrep, + backend::AutoGTPSA{D}, + x, + contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) 
yt = fc(prep.xt) @@ -418,10 +473,11 @@ function DI.hessian!( f, hess, prep::GTPSAOneArgHessianPrep, - ::AutoGTPSA{D}, + backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -438,8 +494,13 @@ function DI.hessian!( end function DI.value_gradient_and_hessian( - f, prep::GTPSAOneArgHessianPrep, ::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} + f, + prep::GTPSAOneArgHessianPrep, + backend::AutoGTPSA{D}, + x, + contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -463,10 +524,11 @@ function DI.value_gradient_and_hessian!( grad, hess, prep::GTPSAOneArgHessianPrep, - ::AutoGTPSA{D}, + backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -483,18 +545,25 @@ function DI.value_gradient_and_hessian!( return yt[0], grad, hess end -struct GTPSAOneArgHVPPrep{E,H} <: DI.HVPPrep +struct GTPSAOneArgHVPPrep{SIG,E,H} <: DI.HVPPrep{SIG} + _sig::Type{SIG} hessprep::E hess::H end function DI.prepare_hvp( - f, backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant,C} + f, + backend::AutoGTPSA, + x, + tx::NTuple, + contexts::Vararg{DI.Constant,C}; + strict::Bool=false, ) where {C} - hessprep = DI.prepare_hessian(f, backend, x) + SIG = DI.signature(f, backend, x, tx, contexts...; strict) + hessprep = DI.prepare_hessian(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) 
hess = similar(x, typeof(fc(x)), (length(x), length(x))) - return GTPSAOneArgHVPPrep(hessprep, hess) + return GTPSAOneArgHVPPrep(SIG, hessprep, hess) end function DI.hvp( @@ -505,6 +574,7 @@ function DI.hvp( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) DI.hessian!(f, prep.hess, prep.hessprep, backend, x, contexts...) tg = map(tx) do dx dg = similar(x, eltype(prep.hess)) @@ -530,6 +600,7 @@ function DI.hvp!( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) DI.hessian!(f, prep.hess, prep.hessprep, backend, x, contexts...) for b in eachindex(tg) dx, dg = tx[b], tg[b] @@ -553,6 +624,7 @@ function DI.gradient_and_hvp( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) grad = similar(x, eltype(prep.hess)) DI.value_gradient_and_hessian!( f, grad, prep.hess, prep.hessprep, backend, x, contexts... @@ -582,6 +654,7 @@ function DI.gradient_and_hvp!( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) DI.value_gradient_and_hessian!( f, grad, prep.hess, prep.hessprep, backend, x, contexts... ) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl index 3e2e88de5..57833fdd2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl @@ -4,15 +4,22 @@ # or a vector of pre-allocated TPSs. 
# # Output: Contains a vector of pre-allocated TPSs -struct GTPSATwoArgPushforwardPrep{X,Y} <: DI.PushforwardPrep +struct GTPSATwoArgPushforwardPrep{SIG,X,Y} <: DI.PushforwardPrep{SIG} + _sig::Type{SIG} xt::X yt::Y end function DI.prepare_pushforward( - ::F, y, backend::AutoGTPSA{D}, x, tx::NTuple, ::Vararg{DI.Constant,C} + f!::F, + y, + backend::AutoGTPSA{D}, + x, + tx::NTuple, + contexts::Vararg{DI.Constant,C}; + strict::Bool=false, ) where {F,D,C} - + SIG = DI.signature(f!, y, backend, x, tx, contexts...; strict) # For pushforward/JVP, we only actually need 1 single variable (in the GTPSA sense) # because we even if we did multiple we will add up the derivatives of each at the end. if D != Nothing @@ -34,18 +41,19 @@ function DI.prepare_pushforward( for i in eachindex(yt) yt[i] = TPS{promote_type(eltype(y), Float64)}(; use=d) end - return GTPSATwoArgPushforwardPrep(xt, yt) + return GTPSATwoArgPushforwardPrep(SIG, xt, yt) end function DI.pushforward( f!, y, prep::GTPSATwoArgPushforwardPrep, - ::AutoGTPSA, + backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) fc! = DI.with_contexts(f!, contexts...) ty = map(tx) do dx foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx) @@ -62,11 +70,12 @@ function DI.pushforward!( y, ty::NTuple, prep::GTPSATwoArgPushforwardPrep, - ::AutoGTPSA, + backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) fc! = DI.with_contexts(f!, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] @@ -87,6 +96,7 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...) 
return y, ty end @@ -101,6 +111,7 @@ function DI.value_and_pushforward!( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) return y, ty end @@ -108,14 +119,16 @@ end ## Jacobian # Input: Contains a vector of pre-allocated TPSs # Output: Contains a vector of pre-allocated TPSs -struct GTPSATwoArgJacobianPrep{X,Y} <: DI.JacobianPrep +struct GTPSATwoArgJacobianPrep{SIG,X,Y} <: DI.JacobianPrep{SIG} + _sig::Type{SIG} xt::X yt::Y end function DI.prepare_jacobian( - f, y, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} + f!, y, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Bool=false ) where {D,C} + SIG = DI.signature(f!, y, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor else @@ -137,12 +150,18 @@ function DI.prepare_jacobian( yt[i] = TPS{promote_type(eltype(y), Float64)}(; use=d) end - return GTPSATwoArgJacobianPrep(xt, yt) + return GTPSATwoArgJacobianPrep(SIG, xt, yt) end function DI.jacobian( - f!, y, prep::GTPSATwoArgJacobianPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f!, + y, + prep::GTPSATwoArgJacobianPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc! = DI.with_contexts(f!, contexts...) fc!(prep.yt, prep.xt) @@ -157,10 +176,11 @@ function DI.jacobian!( y, jac, prep::GTPSATwoArgJacobianPrep, - ::AutoGTPSA, + backend::AutoGTPSA, x, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc! = DI.with_contexts(f!, contexts...) fc!(prep.yt, prep.xt) @@ -177,6 +197,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) 
jac = DI.jacobian(f!, y, prep, backend, x, contexts...) # y set on line 151 return y, jac end @@ -190,6 +211,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) DI.jacobian!(f!, y, jac, prep, backend, x, contexts...) return y, jac end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 81acd8673..a8760d57a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -1,20 +1,27 @@ ## Pullback -struct MooncakeOneArgPullbackPrep{Tcache,DY} <: DI.PullbackPrep +struct MooncakeOneArgPullbackPrep{SIG,Tcache,DY} <: DI.PullbackPrep{SIG} + _sig::Type{SIG} cache::Tcache dy_righttype::DY end function DI.prepare_pullback( - f::F, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C} + f::F, + backend::AutoMooncake, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {F,C} + SIG = DI.signature(f, backend, x, ty, contexts...; strict) config = get_config(backend) cache = prepare_pullback_cache( f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages ) y = f(x, map(DI.unwrap, contexts)...) dy_righttype = zero_tangent(y) - prep = MooncakeOneArgPullbackPrep(cache, dy_righttype) + prep = MooncakeOneArgPullbackPrep(SIG, cache, dy_righttype) DI.value_and_pullback(f, prep, backend, x, ty, contexts...) return prep end @@ -22,11 +29,12 @@ end function DI.value_and_pullback( f::F, prep::MooncakeOneArgPullbackPrep{Y}, - ::AutoMooncake, + backend::AutoMooncake, x, ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,Y,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) dy = only(ty) dy_righttype = dy isa tangent_type(Y) ? 
dy : copyto!!(prep.dy_righttype, dy) new_y, (_, new_dx) = value_and_pullback!!( @@ -44,6 +52,7 @@ function DI.value_and_pullback!( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,Y,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) y, (new_dx,) = DI.value_and_pullback(f, prep, backend, x, ty, contexts...) copyto!(only(tx), new_dx) return y, tx @@ -57,6 +66,7 @@ function DI.value_and_pullback( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) ys_and_tx = map(ty) do dy y, tx = DI.value_and_pullback(f, prep, backend, x, (dy,), contexts...) y, only(tx) @@ -75,6 +85,7 @@ function DI.value_and_pullback!( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) ys = map(tx, ty) do dx, dy y, _ = DI.value_and_pullback!(f, (dx,), prep, backend, x, (dy,), contexts...) y @@ -91,6 +102,7 @@ function DI.pullback( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) return DI.value_and_pullback(f, prep, backend, x, ty, contexts...)[2] end @@ -103,30 +115,38 @@ function DI.pullback!( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) 
return DI.value_and_pullback!(f, tx, prep, backend, x, ty, contexts...)[2] end ## Gradient -struct MooncakeGradientPrep{Tcache} <: DI.GradientPrep +struct MooncakeGradientPrep{SIG,Tcache} <: DI.GradientPrep{SIG} + _sig::Type{SIG} cache::Tcache end function DI.prepare_gradient( - f::F, backend::AutoMooncake, x, contexts::Vararg{DI.Context,C} + f::F, backend::AutoMooncake, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {F,C} + SIG = DI.signature(f, backend, x, contexts...; strict) config = get_config(backend) cache = prepare_pullback_cache( f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages ) - prep = MooncakeGradientPrep(cache) + prep = MooncakeGradientPrep(SIG, cache) DI.value_and_gradient(f, prep, backend, x, contexts...) return prep end function DI.value_and_gradient( - f::F, prep::MooncakeGradientPrep, ::AutoMooncake, x, contexts::Vararg{DI.Context,C} + f::F, + prep::MooncakeGradientPrep, + backend::AutoMooncake, + x, + contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...) return y, mycopy(new_grad) end @@ -135,10 +155,11 @@ function DI.value_and_gradient!( f::F, grad, prep::MooncakeGradientPrep, - ::AutoMooncake, + backend::AutoMooncake, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...) copyto!(grad, new_grad) return y, grad @@ -151,6 +172,7 @@ function DI.gradient( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) _, grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return grad end @@ -163,6 +185,7 @@ function DI.gradient!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) DI.value_and_gradient!(f, grad, prep, backend, x, contexts...) 
return grad end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index bd86b34f8..11d8ecd4e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -1,12 +1,20 @@ -struct MooncakeTwoArgPullbackPrep{Tcache,DY,F} <: DI.PullbackPrep +struct MooncakeTwoArgPullbackPrep{SIG,Tcache,DY,F} <: DI.PullbackPrep{SIG} + _sig::Type{SIG} cache::Tcache dy_righttype::DY target_function::F end function DI.prepare_pullback( - f!::F, y, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C} + f!::F, + y, + backend::AutoMooncake, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {F,C} + SIG = DI.signature(f!, y, backend, x, ty, contexts...; strict) target_function = function (f!, y, x, contexts...) f!(y, x, contexts...) return y @@ -22,18 +30,21 @@ function DI.prepare_pullback( silence_debug_messages=config.silence_debug_messages, ) dy_righttype_after = zero_tangent(y) - return MooncakeTwoArgPullbackPrep(cache, dy_righttype_after, target_function) + prep = MooncakeTwoArgPullbackPrep(SIG, cache, dy_righttype_after, target_function) + DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...) + return prep end function DI.value_and_pullback( f!::F, y, prep::MooncakeTwoArgPullbackPrep, - ::AutoMooncake, + backend::AutoMooncake, x, ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) # Prepare cotangent to add after the forward pass. dy = only(ty) dy_righttype_after = copyto!(prep.dy_righttype, dy) @@ -56,6 +67,7 @@ function DI.value_and_pullback( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) 
tx = map(ty) do dy _, tx = DI.value_and_pullback(f!, y, prep, backend, x, (dy,), contexts...) only(tx) @@ -73,6 +85,7 @@ function DI.value_and_pullback!( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) _, new_tx = DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...) foreach(copyto!, tx, new_tx) return y, tx @@ -87,6 +100,7 @@ function DI.pullback( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) return DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...)[2] end @@ -100,5 +114,6 @@ function DI.pullback!( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) return DI.value_and_pullback!(f!, y, tx, prep, backend, x, ty, contexts...)[2] end diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index 17cf1ab6c..eee0beb5c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -1,127 +1,185 @@ ## Pushforward +struct PolyesterForwardDiffOneArgPushforwardPrep{SIG,P} <: + DI.PushforwardPrep{SIG} + _sig::Type{SIG} + single_threaded_prep::P +end + function DI.prepare_pushforward( - f, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} - return DI.prepare_pushforward(f, single_threaded(backend), x, tx, contexts...)
+ SIG = DI.signature(f, backend, x, tx, contexts...; strict) + single_threaded_prep = DI.prepare_pushforward( + f, single_threaded(backend), x, tx, contexts...; strict + ) + return PolyesterForwardDiffOneArgPushforwardPrep(SIG, single_threaded_prep) end function DI.value_and_pushforward( f, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffOneArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.value_and_pushforward(f, prep, single_threaded(backend), x, tx, contexts...) + DI.check_prep(f, prep, backend, x, tx, contexts...) + return DI.value_and_pushforward( + f, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... + ) end function DI.value_and_pushforward!( f, ty::NTuple, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffOneArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.value_and_pushforward!( - f, ty, prep, single_threaded(backend), x, tx, contexts... + f, ty, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... ) end function DI.pushforward( f, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffOneArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.pushforward(f, prep, single_threaded(backend), x, tx, contexts...) + DI.check_prep(f, prep, backend, x, tx, contexts...) + return DI.pushforward( + f, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... + ) end function DI.pushforward!( f, ty::NTuple, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffOneArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.pushforward!(f, ty, prep, single_threaded(backend), x, tx, contexts...) + DI.check_prep(f, prep, backend, x, tx, contexts...) 
+ return DI.pushforward!( + f, ty, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... + ) end ## Derivative +struct PolyesterForwardDiffOneArgDerivativePrep{SIG,P} <: DI.DerivativePrep{SIG} + _sig::Type{SIG} + single_threaded_prep::P +end + function DI.prepare_derivative( - f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} + f, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} - return DI.prepare_derivative(f, single_threaded(backend), x, contexts...) + SIG = DI.signature(f, backend, x, contexts...; strict) + single_threaded_prep = DI.prepare_derivative( + f, single_threaded(backend), x, contexts...; strict + ) + return PolyesterForwardDiffOneArgDerivativePrep(SIG, single_threaded_prep) end function DI.value_and_derivative( f, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffOneArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.value_and_derivative(f, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + return DI.value_and_derivative( + f, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.value_and_derivative!( f, der, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffOneArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.value_and_derivative!(f, der, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + return DI.value_and_derivative!( + f, der, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.derivative( f, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffOneArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.derivative(f, prep, single_threaded(backend), x, contexts...) 
+ DI.check_prep(f, prep, backend, x, contexts...) + return DI.derivative( + f, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.derivative!( f, der, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffOneArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.derivative!(f, der, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + return DI.derivative!( + f, der, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end ## Gradient -struct PolyesterForwardDiffGradientPrep{chunksize,P} <: DI.GradientPrep +struct PolyesterForwardDiffGradientPrep{SIG,chunksize,P} <: DI.GradientPrep{SIG} + _sig::Type{SIG} chunk::Chunk{chunksize} single_threaded_prep::P end function DI.prepare_gradient( - f, backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C} + f, + backend::AutoPolyesterForwardDiff{chunksize}, + x, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {chunksize,C} + SIG = DI.signature(f, backend, x, contexts...; strict) if isnothing(chunksize) chunk = Chunk(x) else chunk = Chunk{chunksize}() end - single_threaded_prep = DI.prepare_gradient(f, single_threaded(backend), x, contexts...) - return PolyesterForwardDiffGradientPrep(chunk, single_threaded_prep) + single_threaded_prep = DI.prepare_gradient( + f, single_threaded(backend), x, contexts...; strict + ) + return PolyesterForwardDiffGradientPrep(SIG, chunk, single_threaded_prep) end function DI.value_and_gradient!( @@ -132,6 +190,7 @@ function DI.value_and_gradient!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} fc = DI.with_contexts(f, contexts...) 
threaded_gradient!(fc, grad, x, prep.chunk) @@ -152,6 +211,7 @@ function DI.gradient!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} fc = DI.with_contexts(f, contexts...) threaded_gradient!(fc, grad, x, prep.chunk) @@ -171,6 +231,7 @@ function DI.value_and_gradient( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return DI.value_and_gradient!(f, similar(x), prep, backend, x, contexts...) end @@ -181,26 +242,35 @@ function DI.gradient( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return DI.gradient!(f, similar(x), prep, backend, x, contexts...) end ## Jacobian -struct PolyesterForwardDiffOneArgJacobianPrep{chunksize,P} <: DI.JacobianPrep +struct PolyesterForwardDiffOneArgJacobianPrep{SIG,chunksize,P} <: DI.JacobianPrep{SIG} + _sig::Type{SIG} chunk::Chunk{chunksize} single_threaded_prep::P end function DI.prepare_jacobian( - f, backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C} + f, + backend::AutoPolyesterForwardDiff{chunksize}, + x, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {chunksize,C} + SIG = DI.signature(f, backend, x, contexts...; strict) if isnothing(chunksize) chunk = Chunk(x) else chunk = Chunk{chunksize}() end - single_threaded_prep = DI.prepare_jacobian(f, single_threaded(backend), x, contexts...) - return PolyesterForwardDiffOneArgJacobianPrep(chunk, single_threaded_prep) + single_threaded_prep = DI.prepare_jacobian( + f, single_threaded(backend), x, contexts...; strict + ) + return PolyesterForwardDiffOneArgJacobianPrep(SIG, chunk, single_threaded_prep) end function DI.value_and_jacobian!( @@ -211,6 +281,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
if contexts isa NTuple{C,DI.GeneralizedConstant} fc = DI.with_contexts(f, contexts...) return fc(x), threaded_jacobian!(fc, jac, x, prep.chunk) @@ -229,6 +300,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} fc = DI.with_contexts(f, contexts...) return threaded_jacobian!(fc, jac, x, prep.chunk) @@ -246,6 +318,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y = f(x, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) return DI.value_and_jacobian!(f, jac, prep, backend, x, contexts...) @@ -258,6 +331,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y = f(x, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) return DI.jacobian!(f, jac, prep, backend, x, contexts...) @@ -265,146 +339,109 @@ end ## Hessian +struct PolyesterForwardDiffHessianPrep{SIG,P} <: DI.HessianPrep{SIG} + _sig::Type{SIG} + single_threaded_prep::P +end + function DI.prepare_hessian( - f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} + f, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} - return DI.prepare_hessian(f, single_threaded(backend), x, contexts...) + SIG = DI.signature(f, backend, x, contexts...; strict) + single_threaded_prep = DI.prepare_hessian( + f, single_threaded(backend), x, contexts...; strict + ) + return PolyesterForwardDiffHessianPrep(SIG, single_threaded_prep) end function DI.hessian( f, - prep::DI.HessianPrep, + prep::PolyesterForwardDiffHessianPrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.hessian(f, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) 
+ return DI.hessian( + f, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.hessian!( f, hess, - prep::DI.HessianPrep, + prep::PolyesterForwardDiffHessianPrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.hessian!(f, hess, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + return DI.hessian!( + f, hess, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.value_gradient_and_hessian( f, - prep::DI.HessianPrep, + prep::PolyesterForwardDiffHessianPrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.value_gradient_and_hessian(f, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + return DI.value_gradient_and_hessian( + f, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.value_gradient_and_hessian!( f, grad, hess, - prep::DI.HessianPrep, + prep::PolyesterForwardDiffHessianPrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return DI.value_gradient_and_hessian!( - f, grad, hess, prep, single_threaded(backend), x, contexts... + f, grad, hess, prep.single_threaded_prep, single_threaded(backend), x, contexts... ) end -## HVP - -#= -function DI.prepare_hvp( - f, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C} -) where {C} - return DI.prepare_hvp( - f, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... - ) -end - -function DI.hvp( - f, - prep::DI.ForwardOverAnythingHVPPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return DI.hvp( - f, prep, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... 
- ) -end - -function DI.hvp!( - f, - tg::NTuple, - prep::DI.ForwardOverAnythingHVPPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return DI.hvp!( - f, tg, prep, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... - ) -end +## Second derivative -function DI.gradient_and_hvp( - f, - prep::DI.ForwardOverAnythingHVPPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return DI.gradient_and_hvp( - f, prep, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... - ) +struct PolyesterForwardDiffOneArgSecondDerivativePrep{SIG,P} <: DI.SecondDerivativePrep{SIG} + _sig::Type{SIG} + single_threaded_prep::P end -function DI.gradient_and_hvp!( +function DI.prepare_second_derivative( f, - grad, - tg::NTuple, - prep::DI.ForwardOverAnythingHVPPrep, backend::AutoPolyesterForwardDiff, x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} - return DI.gradient_and_hvp!( - f, - grad, - tg, - prep, - DI.SecondOrder(single_threaded(backend), backend), - x, - tx, - contexts..., + SIG = DI.signature(f, backend, x, contexts...; strict) + single_threaded_prep = DI.prepare_second_derivative( + f, single_threaded(backend), x, contexts...; strict ) -end -=# - -## Second derivative - -function DI.prepare_second_derivative( - f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} -) where {C} - return DI.prepare_second_derivative(f, single_threaded(backend), x, contexts...) + return PolyesterForwardDiffOneArgSecondDerivativePrep(SIG, single_threaded_prep) end function DI.value_derivative_and_second_derivative( f, - prep::DI.SecondDerivativePrep, + prep::PolyesterForwardDiffOneArgSecondDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
return DI.value_derivative_and_second_derivative( - f, prep, single_threaded(backend), x, contexts... + f, prep.single_threaded_prep, single_threaded(backend), x, contexts... ) @@ -414,11 +451,12 @@ function DI.value_derivative_and_second_derivative!( f, der, der2, - prep::DI.SecondDerivativePrep, + prep::PolyesterForwardDiffOneArgSecondDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return DI.value_derivative_and_second_derivative!( - f, der, der2, prep, single_threaded(backend), x, contexts... + f, der, der2, prep.single_threaded_prep, single_threaded(backend), x, contexts... ) @@ -426,21 +464,23 @@ end function DI.second_derivative( f, - prep::DI.SecondDerivativePrep, + prep::PolyesterForwardDiffOneArgSecondDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) - return DI.second_derivative(f, prep, single_threaded(backend), x, contexts...) + return DI.second_derivative(f, prep.single_threaded_prep, single_threaded(backend), x, contexts...) end function DI.second_derivative!( f, der2, - prep::DI.SecondDerivativePrep, + prep::PolyesterForwardDiffOneArgSecondDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) - return DI.second_derivative!(f, der2, prep, single_threaded(backend), x, contexts...) + return DI.second_derivative!(f, der2, prep.single_threaded_prep, single_threaded(backend), x, contexts...)
end diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl index 4f63a65c1..1bd32b5b8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl @@ -1,22 +1,38 @@ ## Pushforward +struct PolyesterForwardDiffTwoArgPushforwardPrep{SIG,P} <: DI.PushforwardPrep{SIG} + _sig::Type{SIG} + single_threaded_prep::P +end + function DI.prepare_pushforward( - f!, y, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f!, + y, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} - return DI.prepare_pushforward(f!, y, single_threaded(backend), x, tx, contexts...) + SIG = DI.signature(f!, y, backend, x, tx, contexts...; strict) + single_threaded_prep = DI.prepare_pushforward( + f!, y, single_threaded(backend), x, tx, contexts...; strict + ) + return PolyesterForwardDiffTwoArgPushforwardPrep(SIG, single_threaded_prep) end function DI.value_and_pushforward( f!, y, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffTwoArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) return DI.value_and_pushforward( - f!, y, prep, single_threaded(backend), x, tx, contexts... + f!, y, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... ) end @@ -24,108 +40,146 @@ function DI.value_and_pushforward!( f!, y, ty::NTuple, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffTwoArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
return DI.value_and_pushforward!( - f!, y, ty, prep, single_threaded(backend), x, tx, contexts... + f!, y, ty, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... ) end function DI.pushforward( f!, y, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffTwoArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.pushforward(f!, y, prep, single_threaded(backend), x, tx, contexts...) + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) + return DI.pushforward( + f!, y, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... + ) end function DI.pushforward!( f!, y, ty::NTuple, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffTwoArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.pushforward!(f!, y, ty, prep, single_threaded(backend), x, tx, contexts...) + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) + return DI.pushforward!( + f!, y, ty, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... + ) end ## Derivative +struct PolyesterForwardDiffTwoArgDerivativePrep{SIG,P} <: DI.DerivativePrep{SIG} + _sig::Type{SIG} + single_threaded_prep::P +end + function DI.prepare_derivative( - f!, y, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} + f!, + y, + backend::AutoPolyesterForwardDiff, + x, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} - return DI.prepare_derivative(f!, y, single_threaded(backend), x, contexts...) + SIG = DI.signature(f!, y, backend, x, contexts...; strict) + single_threaded_prep = DI.prepare_derivative( + f!, y, single_threaded(backend), x, contexts...; strict
+ ) + return PolyesterForwardDiffTwoArgDerivativePrep(SIG, single_threaded_prep) end function DI.value_and_derivative( f!, y, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffTwoArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.value_and_derivative(f!, y, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f!, y, prep, backend, x, contexts...) + return DI.value_and_derivative( + f!, y, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.value_and_derivative!( f!, y, der, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffTwoArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) return DI.value_and_derivative!( - f!, y, der, prep, single_threaded(backend), x, contexts... + f!, y, der, prep.single_threaded_prep, single_threaded(backend), x, contexts... ) end function DI.derivative( f!, y, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffTwoArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.derivative(f!, y, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f!, y, prep, backend, x, contexts...) + return DI.derivative( + f!, y, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.derivative!( f!, y, der, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffTwoArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.derivative!(f!, y, der, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f!, y, prep, backend, x, contexts...) + return DI.derivative!( + f!, y, der, prep.single_threaded_prep, single_threaded(backend), x, contexts... 
+ ) end ## Jacobian -struct PolyesterForwardDiffTwoArgJacobianPrep{chunksize,P} <: DI.JacobianPrep +struct PolyesterForwardDiffTwoArgJacobianPrep{SIG,chunksize,P} <: DI.JacobianPrep{SIG} + _sig::Type{SIG} chunk::Chunk{chunksize} single_threaded_prep::P end function DI.prepare_jacobian( - f!, y, backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C} + f!, + y, + backend::AutoPolyesterForwardDiff{chunksize}, + x, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {chunksize,C} + SIG = DI.signature(f!, y, backend, x, contexts...; strict) if isnothing(chunksize) chunk = Chunk(x) else @@ -134,7 +188,7 @@ function DI.prepare_jacobian( single_threaded_prep = DI.prepare_jacobian( - f!, y, single_threaded(backend), x, contexts... + f!, y, single_threaded(backend), x, contexts...; strict ) - return PolyesterForwardDiffTwoArgJacobianPrep(chunk, single_threaded_prep) + return PolyesterForwardDiffTwoArgJacobianPrep(SIG, chunk, single_threaded_prep) end function DI.value_and_jacobian( @@ -145,6 +199,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {K,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} fc! = DI.with_contexts(f!, contexts...) jac = similar(y, length(y), length(x)) @@ -167,6 +222,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {K,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} fc! = DI.with_contexts(f!, contexts...) threaded_jacobian!(fc!, y, jac, x, prep.chunk) @@ -187,6 +243,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} fc! = DI.with_contexts(f!, contexts...) jac = similar(y, length(y), length(x)) @@ -208,6 +265,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...)
if contexts isa NTuple{C,DI.GeneralizedConstant} fc! = DI.with_contexts(f!, contexts...) threaded_jacobian!(fc!, y, jac, x, prep.chunk) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index fadca806a..e74ee00ca 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -1,19 +1,26 @@ ## Pullback function DI.prepare_pullback( - f, ::AutoReverseDiff, x, ty::NTuple, contexts::Vararg{DI.Context,C} + f, + backend::AutoReverseDiff, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} - return DI.NoPullbackPrep() + SIG = DI.signature(f, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep{SIG}() end function DI.value_and_pullback( f, - ::DI.NoPullbackPrep, - ::AutoReverseDiff, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, x::AbstractArray, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) fc = DI.with_contexts(f, contexts...) y = fc(x) dotclosure(z, dy) = dot(fc(z), dy) @@ -30,12 +37,13 @@ end function DI.value_and_pullback!( f, tx::NTuple, - ::DI.NoPullbackPrep, - ::AutoReverseDiff, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, x::AbstractArray, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) fc = DI.with_contexts(f, contexts...) y = fc(x) dotclosure(z, dy) = dot(fc(z), dy) @@ -53,12 +61,13 @@ end function DI.value_and_pullback( f, - ::DI.NoPullbackPrep, + prep::DI.NoPullbackPrep, backend::AutoReverseDiff, x::Number, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) x_array = [x] f_array(x_array, args...) = f(only(x_array), args...) 
y, tx_array = DI.value_and_pullback(f_array, backend, x_array, ty, contexts...) @@ -69,24 +78,29 @@ end ### Without contexts -@kwdef struct ReverseDiffGradientPrep{C,T} <: DI.GradientPrep +struct ReverseDiffGradientPrep{SIG,C,T} <: DI.GradientPrep{SIG} + _sig::Type{SIG} config::C tape::T end -function DI.prepare_gradient(f, ::AutoReverseDiff{compile}, x) where {compile} +function DI.prepare_gradient( + f, backend::AutoReverseDiff{compile}, x; strict::Bool=false +) where {compile} + SIG = DI.signature(f, backend, x; strict) if compile tape = ReverseDiff.compile(GradientTape(f, x)) - return ReverseDiffGradientPrep(; config=nothing, tape=tape) + return ReverseDiffGradientPrep(SIG, nothing, tape) else config = GradientConfig(x) - return ReverseDiffGradientPrep(; config=config, tape=nothing) + return ReverseDiffGradientPrep(SIG, config, nothing) end end function DI.value_and_gradient!( - f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff{compile}, x + f, grad, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) result = MutableDiffResult(zero(eltype(x)), (grad,)) # ReverseDiff#251 if compile result = gradient!(result, prep.tape, x) @@ -99,6 +113,7 @@ end function DI.value_and_gradient( f, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) # GradientResult tries to mutate an SArray result = MutableDiffResult(zero(eltype(x)), (similar(x),)) if compile @@ -110,8 +125,9 @@ function DI.value_and_gradient( end function DI.gradient!( - f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff{compile}, x + f, grad, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) if compile return gradient!(grad, prep.tape, x) else @@ -120,8 +136,9 @@ function DI.gradient!( end function DI.gradient( - f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff{compile}, x + f,
prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) if compile return gradient!(prep.tape, x) else @@ -132,20 +149,22 @@ end ### With contexts function DI.prepare_gradient( - f, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} + f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {C} + SIG = DI.signature(f, backend, x, contexts...; strict) config = GradientConfig(x) - return ReverseDiffGradientPrep(; config=config, tape=nothing) + return ReverseDiffGradientPrep(SIG, config, nothing) end function DI.value_and_gradient!( f, grad, prep::ReverseDiffGradientPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) result = MutableDiffResult(zero(eltype(x)), (grad,)) # ReverseDiff#251 result = gradient!(result, fc, x, prep.config) @@ -153,8 +172,13 @@ function DI.value_and_gradient!( end function DI.value_and_gradient( - f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} + f, + prep::ReverseDiffGradientPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) # GradientResult tries to mutate an SArray result = MutableDiffResult(zero(eltype(x)), (similar(x),)) @@ -166,17 +190,23 @@ function DI.gradient!( f, grad, prep::ReverseDiffGradientPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) 
return gradient!(grad, fc, x, prep.config) end function DI.gradient( - f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} + f, + prep::ReverseDiffGradientPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) return gradient(fc, x, prep.config) end @@ -185,24 +215,29 @@ end ### Without contexts -@kwdef struct ReverseDiffOneArgJacobianPrep{C,T} <: DI.JacobianPrep +struct ReverseDiffOneArgJacobianPrep{SIG,C,T} <: DI.JacobianPrep{SIG} + _sig::Type{SIG} config::C tape::T end -function DI.prepare_jacobian(f, ::AutoReverseDiff{compile}, x) where {compile} +function DI.prepare_jacobian( + f, backend::AutoReverseDiff{compile}, x; strict::Bool=false +) where {compile} + SIG = DI.signature(f, backend, x; strict) if compile tape = ReverseDiff.compile(JacobianTape(f, x)) - return ReverseDiffOneArgJacobianPrep(; config=nothing, tape=tape) + return ReverseDiffOneArgJacobianPrep(SIG, nothing, tape) else config = JacobianConfig(x) - return ReverseDiffOneArgJacobianPrep(; config=config, tape=nothing) + return ReverseDiffOneArgJacobianPrep(SIG, config, nothing) end end function DI.value_and_jacobian!( - f, jac, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff{compile}, x + f, jac, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) y = f(x) result = DiffResult(y, (jac,)) if compile @@ -216,8 +251,9 @@ function DI.value_and_jacobian!( end function DI.value_and_jacobian( - f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff{compile}, x + f, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) if compile return f(x), jacobian!(prep.tape, x) else @@ -226,8 +262,9 @@ function DI.value_and_jacobian( end function DI.jacobian!( - f, jac, prep::ReverseDiffOneArgJacobianPrep, 
::AutoReverseDiff{compile}, x + f, jac, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) if compile return jacobian!(jac, prep.tape, x) else @@ -236,8 +273,9 @@ function DI.jacobian!( end function DI.jacobian( - f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff{compile}, x + f, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) if compile return jacobian!(prep.tape, x) else @@ -248,20 +286,22 @@ end ### With contexts function DI.prepare_jacobian( - f, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} + f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {C} + SIG = DI.signature(f, backend, x, contexts...; strict) config = JacobianConfig(x) - return ReverseDiffOneArgJacobianPrep(; config=config, tape=nothing) + return ReverseDiffOneArgJacobianPrep(SIG, config, nothing) end function DI.value_and_jacobian!( f, jac, prep::ReverseDiffOneArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) y = fc(x) result = DiffResult(y, (jac,)) @@ -274,10 +314,11 @@ end function DI.value_and_jacobian( f, prep::ReverseDiffOneArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) return fc(x), jacobian(fc, x, prep.config) end @@ -286,10 +327,11 @@ function DI.jacobian!( f, jac, prep::ReverseDiffOneArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) 
return jacobian!(jac, fc, x, prep.config) end @@ -297,10 +339,11 @@ end function DI.jacobian( f, prep::ReverseDiffOneArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) return jacobian(fc, x, prep.config) end @@ -309,30 +352,31 @@ end ### Without contexts -@kwdef struct ReverseDiffHessianPrep{G<:ReverseDiffGradientPrep,HC,HT} <: DI.HessianPrep +struct ReverseDiffHessianPrep{SIG,G<:ReverseDiffGradientPrep,HC,HT} <: DI.HessianPrep{SIG} + _sig::Type{SIG} gradient_prep::G hessian_config::HC hessian_tape::HT end -function DI.prepare_hessian(f, backend::AutoReverseDiff{compile}, x) where {compile} +function DI.prepare_hessian( + f, backend::AutoReverseDiff{compile}, x; strict::Bool=false +) where {compile} + SIG = DI.signature(f, backend, x; strict) gradient_prep = DI.prepare_gradient(f, backend, x) if compile hessian_tape = ReverseDiff.compile(HessianTape(f, x)) - return ReverseDiffHessianPrep(; - gradient_prep, hessian_config=nothing, hessian_tape=hessian_tape - ) + return ReverseDiffHessianPrep(SIG, gradient_prep, nothing, hessian_tape) else hessian_config = HessianConfig(x) - return ReverseDiffHessianPrep(; - gradient_prep, hessian_config=hessian_config, hessian_tape=nothing - ) + return ReverseDiffHessianPrep(SIG, gradient_prep, hessian_config, nothing) end end function DI.hessian!( - f, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff{compile}, x + f, hess, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) if compile return hessian!(hess, prep.hessian_tape, x) else @@ -341,8 +385,9 @@ function DI.hessian!( end function DI.hessian( - f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff{compile}, x + f, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) if compile return 
hessian!(prep.hessian_tape, x) else @@ -353,6 +398,7 @@ end function DI.value_gradient_and_hessian!( f, grad, hess, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) y = f(x) DI.gradient!(f, grad, prep.gradient_prep, backend, x) DI.hessian!(f, hess, prep, backend, x) @@ -362,6 +408,7 @@ end function DI.value_gradient_and_hessian( f, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) y = f(x) grad = DI.gradient(f, prep.gradient_prep, backend, x) hess = DI.hessian(f, prep, backend, x) @@ -371,30 +418,35 @@ end ### With contexts function DI.prepare_hessian( - f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} + f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {C} + SIG = DI.signature(f, backend, x, contexts...; strict) gradient_prep = DI.prepare_gradient(f, backend, x, contexts...) hessian_config = HessianConfig(x) - return ReverseDiffHessianPrep(; - gradient_prep, hessian_config=hessian_config, hessian_tape=nothing - ) + return ReverseDiffHessianPrep(SIG, gradient_prep, hessian_config, nothing) end function DI.hessian!( f, hess, prep::ReverseDiffHessianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) return hessian!(hess, fc, x, prep.hessian_config) end function DI.hessian( - f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} + f, + prep::ReverseDiffHessianPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) 
return hessian(fc, x, prep.hessian_config) end @@ -408,6 +460,7 @@ function DI.value_gradient_and_hessian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y = f(x, map(DI.unwrap, contexts)...) DI.gradient!(f, grad, prep.gradient_prep, backend, x, contexts...) DI.hessian!(f, hess, prep, backend, x, contexts...) @@ -421,6 +474,7 @@ function DI.value_gradient_and_hessian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y = f(x, map(DI.unwrap, contexts)...) grad = DI.gradient(f, prep.gradient_prep, backend, x, contexts...) hess = DI.hessian(f, prep, backend, x, contexts...) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl index f76e7c824..2bf3fa7e9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl @@ -1,9 +1,16 @@ ## Pullback function DI.prepare_pullback( - f!, y, ::AutoReverseDiff, x, ty::NTuple, contexts::Vararg{DI.Context,C} + f!, + y, + backend::AutoReverseDiff, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} - return DI.NoPullbackPrep() + SIG = DI.signature(f!, y, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep{SIG}() end ### Array in @@ -11,12 +18,13 @@ end function DI.value_and_pullback( f!, y, - ::DI.NoPullbackPrep, - ::AutoReverseDiff, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, x::AbstractArray, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) fc! = DI.with_contexts(f!, contexts...)
function dotclosure(x, dy) y_copy = similar(y, eltype(x)) @@ -34,12 +42,13 @@ function DI.value_and_pullback!( f!, y, tx::NTuple, - ::DI.NoPullbackPrep, - ::AutoReverseDiff, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, x::AbstractArray, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) fc! = DI.with_contexts(f!, contexts...) function dotclosure(x, dy) y_copy = similar(y, eltype(x)) @@ -57,12 +66,13 @@ end function DI.pullback( f!, y, - ::DI.NoPullbackPrep, - ::AutoReverseDiff, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, x::AbstractArray, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) fc! = DI.with_contexts(f!, contexts...) function dotclosure(x, dy) y_copy = similar(y, eltype(x)) @@ -79,12 +89,13 @@ function DI.pullback!( f!, y, tx::NTuple, - ::DI.NoPullbackPrep, - ::AutoReverseDiff, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, x::AbstractArray, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) fc! = DI.with_contexts(f!, contexts...) function dotclosure(x, dy) y_copy = similar(y, eltype(x)) @@ -103,12 +114,13 @@ end function DI.value_and_pullback( f!, y, - ::DI.NoPullbackPrep, + prep::DI.NoPullbackPrep, backend::AutoReverseDiff, x::Number, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) x_array = [x] function f!_array(_y::AbstractArray, _x_array, args...) return f!(_y, only(_x_array), args...) 
@@ -121,24 +133,29 @@ end ### Without contexts -@kwdef struct ReverseDiffTwoArgJacobianPrep{C,T} <: DI.JacobianPrep +struct ReverseDiffTwoArgJacobianPrep{SIG,C,T} <: DI.JacobianPrep{SIG} + _sig::Type{SIG} config::C tape::T end -function DI.prepare_jacobian(f!, y, ::AutoReverseDiff{compile}, x) where {compile} +function DI.prepare_jacobian( + f!, y, backend::AutoReverseDiff{compile}, x; strict::Bool=false +) where {compile} + SIG = DI.signature(f!, y, backend, x; strict) if compile tape = ReverseDiff.compile(JacobianTape(f!, y, x)) - return ReverseDiffTwoArgJacobianPrep(; config=nothing, tape=tape) + return ReverseDiffTwoArgJacobianPrep(SIG, nothing, tape) else config = JacobianConfig(y, x) - return ReverseDiffTwoArgJacobianPrep(; config=config, tape=nothing) + return ReverseDiffTwoArgJacobianPrep(SIG, config, nothing) end end function DI.value_and_jacobian( - f!, y, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff{compile}, x + f!, y, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f!, y, prep, backend, x) jac = similar(y, length(y), length(x)) result = MutableDiffResult(y, (jac,)) if compile @@ -150,8 +167,9 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f!, y, jac, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff{compile}, x + f!, y, jac, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f!, y, prep, backend, x) result = MutableDiffResult(y, (jac,)) if compile result = jacobian!(result, prep.tape, x) @@ -162,8 +180,9 @@ function DI.value_and_jacobian!( end function DI.jacobian( - f!, y, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff{compile}, x + f!, y, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f!, y, prep, backend, x) if compile jac = jacobian!(prep.tape, x) else @@ -173,8 +192,9 @@ function DI.jacobian( end function DI.jacobian!( - 
f!, y, jac, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff{compile}, x + f!, y, jac, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f!, y, prep, backend, x) if compile jac = jacobian!(jac, prep.tape, x) else @@ -186,20 +206,22 @@ end ### With contexts function DI.prepare_jacobian( - f!, y, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} + f!, y, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {C} + SIG = DI.signature(f!, y, backend, x, contexts...; strict) config = JacobianConfig(y, x) - return ReverseDiffTwoArgJacobianPrep(; config=config, tape=nothing) + return ReverseDiffTwoArgJacobianPrep(SIG, config, nothing) end function DI.value_and_jacobian( f!, y, prep::ReverseDiffTwoArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) fc! = DI.with_contexts(f!, contexts...) jac = similar(y, length(y), length(x)) result = MutableDiffResult(y, (jac,)) @@ -212,10 +234,11 @@ function DI.value_and_jacobian!( y, jac, prep::ReverseDiffTwoArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (jac,)) result = jacobian!(result, fc!, y, x, prep.config) @@ -226,10 +249,11 @@ function DI.jacobian( f!, y, prep::ReverseDiffTwoArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) fc! = DI.with_contexts(f!, contexts...) 
jac = jacobian(fc!, y, x, prep.config) return jac @@ -240,10 +264,11 @@ function DI.jacobian!( y, jac, prep::ReverseDiffTwoArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) fc! = DI.with_contexts(f!, contexts...) jac = jacobian!(jac, fc!, y, x, prep.config) return jac diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl index ea8ef94b9..2de265807 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl @@ -15,7 +15,7 @@ using SparseMatrixColorings: decompress! 
import SparseMatrixColorings as SMC -abstract type SparseJacobianPrep <: DI.JacobianPrep end +abstract type SparseJacobianPrep{SIG} <: DI.JacobianPrep{SIG} end SMC.sparsity_pattern(prep::SparseJacobianPrep) = sparsity_pattern(prep.coloring_result) SMC.column_colors(prep::SparseJacobianPrep) = column_colors(prep.coloring_result) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index 61131dc71..39e81e30d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -1,4 +1,5 @@ struct SparseHessianPrep{ + SIG, BS<:DI.BatchSizeSettings, C<:AbstractColoringResult{:symmetric,:column}, M<:AbstractMatrix{<:Number}, @@ -6,7 +7,8 @@ struct SparseHessianPrep{ R<:AbstractVector{<:NTuple}, E2<:DI.HVPPrep, E1<:DI.GradientPrep, -} <: DI.HessianPrep +} <: DI.HessianPrep{SIG} + _sig::Type{SIG} batch_size_settings::BS coloring_result::C compressed_matrix::M @@ -24,7 +26,7 @@ SMC.ncolors(prep::SparseHessianPrep) = ncolors(prep.coloring_result) ## Hessian, one argument function DI.prepare_hessian( - f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} + f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {F,C} dense_backend = dense_ad(backend) sparsity = DI.hessian_sparsity_with_contexts( @@ -37,7 +39,7 @@ function DI.prepare_hessian( N = length(column_groups(coloring_result)) batch_size_settings = DI.pick_batchsize(DI.outer(dense_backend), N) return _prepare_sparse_hessian_aux( - batch_size_settings, coloring_result, f, backend, x, contexts... 
+ batch_size_settings, coloring_result, f, backend, x, contexts...; strict ) end @@ -47,8 +49,10 @@ function _prepare_sparse_hessian_aux( f::F, backend::AutoSparse, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; + strict::Bool, ) where {B,F,C} + SIG = DI.signature(f, backend, x, contexts...; strict) (; N, A) = batch_size_settings dense_backend = dense_ad(backend) groups = column_groups(coloring_result) @@ -61,6 +65,7 @@ function _prepare_sparse_hessian_aux( hvp_prep = DI.prepare_hvp(f, dense_backend, x, batched_seeds[1], contexts...) gradient_prep = DI.prepare_gradient(f, DI.inner(dense_backend), x, contexts...) return SparseHessianPrep( + SIG, batch_size_settings, coloring_result, compressed_matrix, @@ -79,6 +84,7 @@ function DI.hessian!( x, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f, prep, backend, x, contexts...) (; batch_size_settings, coloring_result, @@ -120,6 +126,7 @@ end function DI.hessian( f::F, prep::SparseHessianPrep{B}, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,B,C} + DI.check_prep(f, prep, backend, x, contexts...) hess = similar(sparsity_pattern(prep), eltype(x)) return DI.hessian!(f, hess, prep, backend, x, contexts...) end @@ -133,6 +140,7 @@ function DI.value_gradient_and_hessian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_gradient!( f, grad, prep.gradient_prep, DI.inner(dense_ad(backend)), x, contexts... ) @@ -143,6 +151,7 @@ end function DI.value_gradient_and_hessian( f::F, prep::SparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) y, grad = DI.value_and_gradient( f, prep.gradient_prep, DI.inner(dense_ad(backend)), x, contexts... 
) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 4193fd7a0..77b7a12b5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -1,13 +1,15 @@ ## Preparation struct PushforwardSparseJacobianPrep{ + SIG, BS<:DI.BatchSizeSettings, C<:AbstractColoringResult{:nonsymmetric,:column}, M<:AbstractMatrix{<:Number}, S<:AbstractVector{<:NTuple}, R<:AbstractVector{<:NTuple}, E<:DI.PushforwardPrep, -} <: SparseJacobianPrep +} <: SparseJacobianPrep{SIG} + _sig::Type{SIG} batch_size_settings::BS coloring_result::C compressed_matrix::M @@ -17,13 +19,15 @@ struct PushforwardSparseJacobianPrep{ end struct PullbackSparseJacobianPrep{ + SIG, BS<:DI.BatchSizeSettings, C<:AbstractColoringResult{:nonsymmetric,:row}, M<:AbstractMatrix{<:Number}, S<:AbstractVector{<:NTuple}, R<:AbstractVector{<:NTuple}, E<:DI.PullbackPrep, -} <: SparseJacobianPrep +} <: SparseJacobianPrep{SIG} + _sig::Type{SIG} batch_size_settings::BS coloring_result::C compressed_matrix::M @@ -33,12 +37,12 @@ struct PullbackSparseJacobianPrep{ end function DI.prepare_jacobian( - f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} + f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {F,C} dense_backend = dense_ad(backend) y = f(x, map(DI.unwrap, contexts)...) perf = DI.pushforward_performance(dense_backend) - return _prepare_sparse_jacobian_aux(perf, y, (f,), backend, x, contexts...) 
+ return _prepare_sparse_jacobian_aux(perf, y, (f,), backend, x, contexts...; strict) end function DI.prepare_jacobian( @@ -46,7 +50,7 @@ function DI.prepare_jacobian( ) where {F,C} dense_backend = dense_ad(backend) perf = DI.pushforward_performance(dense_backend) - return _prepare_sparse_jacobian_aux(perf, y, (f!, y), backend, x, contexts...) + return _prepare_sparse_jacobian_aux(perf, y, (f!, y), backend, x, contexts...; strict) end function _prepare_sparse_jacobian_aux( @@ -55,7 +59,8 @@ function _prepare_sparse_jacobian_aux( f_or_f!y::FY, backend::AutoSparse, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; + strict::Bool, ) where {FY,C} dense_backend = dense_ad(backend) sparsity = DI.jacobian_sparsity_with_contexts( @@ -79,7 +84,7 @@ function _prepare_sparse_jacobian_aux( end batch_size_settings = DI.pick_batchsize(dense_backend, N) return _prepare_sparse_jacobian_aux_aux( - batch_size_settings, coloring_result, y, f_or_f!y, backend, x, contexts... + batch_size_settings, coloring_result, y, f_or_f!y, backend, x, contexts...; strict ) end @@ -90,8 +95,10 @@ function _prepare_sparse_jacobian_aux_aux( f_or_f!y::FY, backend::AutoSparse, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; + strict::Bool, ) where {B,FY,C} + SIG = DI.signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings dense_backend = dense_ad(backend) groups = column_groups(coloring_result) @@ -105,6 +112,7 @@ function _prepare_sparse_jacobian_aux_aux( f_or_f!y..., dense_backend, x, batched_seeds[1], contexts... 
) return PushforwardSparseJacobianPrep( + SIG, batch_size_settings, coloring_result, compressed_matrix, @@ -121,8 +129,10 @@ function _prepare_sparse_jacobian_aux_aux( f_or_f!y::FY, backend::AutoSparse, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; + strict::Bool, ) where {B,FY,C} + SIG = DI.signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings dense_backend = dense_ad(backend) groups = row_groups(coloring_result) @@ -136,6 +146,7 @@ function _prepare_sparse_jacobian_aux_aux( f_or_f!y..., dense_backend, x, batched_seeds[1], contexts... ) return PullbackSparseJacobianPrep( + SIG, batch_size_settings, coloring_result, compressed_matrix, @@ -155,12 +166,14 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) return _sparse_jacobian_aux!((f,), jac, prep, backend, x, contexts...) end function DI.jacobian( f::F, prep::SparseJacobianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) jac = similar(sparsity_pattern(prep), eltype(x)) return DI.jacobian!(f, jac, prep, backend, x, contexts...) end @@ -168,6 +181,7 @@ end function DI.value_and_jacobian( f::F, prep::SparseJacobianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...) end @@ -179,6 +193,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian!(f, jac, prep, backend, x, contexts...) end @@ -194,6 +209,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) return _sparse_jacobian_aux!((f!, y), jac, prep, backend, x, contexts...) 
end @@ -205,6 +221,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) jac = similar(sparsity_pattern(prep), promote_type(eltype(x), eltype(y))) return DI.jacobian!(f!, y, jac, prep, backend, x, contexts...) end @@ -217,6 +234,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) jac = DI.jacobian(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, jac @@ -231,6 +249,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) DI.jacobian!(f!, y, jac, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, jac diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl index 4ab180aad..cc9267092 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl @@ -1,6 +1,7 @@ ## Preparation struct MixedModeSparseJacobianPrep{ + SIG, BSf<:DI.BatchSizeSettings, BSr<:DI.BatchSizeSettings, C<:AbstractColoringResult{:nonsymmetric,:bidirectional}, @@ -11,7 +12,8 @@ struct MixedModeSparseJacobianPrep{ Rr<:Vector{<:NTuple}, Ef<:DI.PushforwardPrep, Er<:DI.PullbackPrep, -} <: SparseJacobianPrep +} <: SparseJacobianPrep{SIG} + _sig::Type{SIG} batch_size_settings_forward::BSf batch_size_settings_reverse::BSr coloring_result::C @@ -26,20 +28,34 @@ struct MixedModeSparseJacobianPrep{ end function DI.prepare_jacobian( - f::F, backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C} + f::F, + backend::AutoSparse{<:DI.MixedMode}, + x, + contexts::Vararg{DI.Context,C}; + 
strict::Bool=false, ) where {F,C} y = f(x, map(DI.unwrap, contexts)...) - return _prepare_mixed_sparse_jacobian_aux(y, (f,), backend, x, contexts...) + return _prepare_mixed_sparse_jacobian_aux(y, (f,), backend, x, contexts...; strict) end function DI.prepare_jacobian( - f!::F, y, backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C} + f!::F, + y, + backend::AutoSparse{<:DI.MixedMode}, + x, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {F,C} - return _prepare_mixed_sparse_jacobian_aux(y, (f!, y), backend, x, contexts...) + return _prepare_mixed_sparse_jacobian_aux(y, (f!, y), backend, x, contexts...; strict) end function _prepare_mixed_sparse_jacobian_aux( - y, f_or_f!y::FY, backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C} + y, + f_or_f!y::FY, + backend::AutoSparse{<:DI.MixedMode}, + x, + contexts::Vararg{DI.Context,C}; + strict::Bool, ) where {FY,C} dense_backend = dense_ad(backend) sparsity = DI.jacobian_sparsity_with_contexts( @@ -66,7 +82,8 @@ function _prepare_mixed_sparse_jacobian_aux( f_or_f!y, backend, x, - contexts..., + contexts...; + strict, ) end @@ -78,8 +95,10 @@ function _prepare_mixed_sparse_jacobian_aux_aux( f_or_f!y::FY, backend::AutoSparse{<:DI.MixedMode}, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; + strict::Bool, ) where {Bf,Br,FY,C} + SIG = DI.signature(f_or_f!y..., backend, x, contexts...; strict) Nf, Af = batch_size_settings_forward.N, batch_size_settings_forward.A Nr, Ar = batch_size_settings_reverse.N, batch_size_settings_reverse.A @@ -124,6 +143,7 @@ function _prepare_mixed_sparse_jacobian_aux_aux( ) return MixedModeSparseJacobianPrep( + SIG, batch_size_settings_forward, batch_size_settings_reverse, coloring_result, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 9f0102a8d..dd0089969 100644 --- 
a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -6,7 +6,7 @@ struct SymbolicsOneArgPushforwardPrep{E1,E1!} <: DI.PushforwardPrep end function DI.prepare_pushforward( - f, ::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C} ) where {C} dx = first(tx) x_var = variablize(x, :x) @@ -28,7 +28,7 @@ end function DI.pushforward( f, prep::SymbolicsOneArgPushforwardPrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}, @@ -43,7 +43,7 @@ function DI.pushforward!( f, ty::NTuple, prep::SymbolicsOneArgPushforwardPrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}, @@ -88,7 +88,7 @@ struct SymbolicsOneArgDerivativePrep{E1,E1!} <: DI.DerivativePrep end function DI.prepare_derivative( - f, ::AutoSymbolics, x, contexts::Vararg{DI.Context,C} + f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} ) where {C} x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -106,7 +106,7 @@ end function DI.derivative( f, prep::SymbolicsOneArgDerivativePrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}, ) where {C} @@ -117,7 +117,7 @@ function DI.derivative!( f, der, prep::SymbolicsOneArgDerivativePrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}, ) where {C} @@ -156,7 +156,7 @@ struct SymbolicsOneArgGradientPrep{E1,E1!} <: DI.GradientPrep end function DI.prepare_gradient( - f, ::AutoSymbolics, x, contexts::Vararg{DI.Context,C} + f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} ) where {C} x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -169,7 +169,11 @@ function DI.prepare_gradient( end function DI.gradient( - f, prep::SymbolicsOneArgGradientPrep, ::AutoSymbolics, x, contexts::Vararg{DI.Context,C} + 
f, + prep::SymbolicsOneArgGradientPrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, ) where {C} return reshape(prep.grad_exe(vec(x), map(DI.unwrap, contexts)...), size(x)) end @@ -178,7 +182,7 @@ function DI.gradient!( f, grad, prep::SymbolicsOneArgGradientPrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}, ) where {C} @@ -237,7 +241,7 @@ end function DI.jacobian( f, prep::SymbolicsOneArgJacobianPrep, - ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}, ) where {C} @@ -248,7 +252,7 @@ function DI.jacobian!( f, jac, prep::SymbolicsOneArgJacobianPrep, - ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}, ) where {C} @@ -311,7 +315,7 @@ end function DI.hessian( f, prep::SymbolicsOneArgHessianPrep, - ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}, ) where {C} @@ -322,7 +326,7 @@ function DI.hessian!( f, hess, prep::SymbolicsOneArgHessianPrep, - ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}, ) where {C} @@ -391,7 +395,7 @@ end function DI.hvp( f, prep::SymbolicsOneArgHVPPrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}, @@ -406,7 +410,7 @@ function DI.hvp!( f, tg::NTuple, prep::SymbolicsOneArgHVPPrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}, @@ -475,7 +479,7 @@ end function DI.second_derivative( f, prep::SymbolicsOneArgSecondDerivativePrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}, ) where {C} @@ -486,7 +490,7 @@ function DI.second_derivative!( f, der2, prep::SymbolicsOneArgSecondDerivativePrep, 
- ::AutoSymbolics, + backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}, ) where {C} diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl index ffe6ee0f4..ab3e90928 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl @@ -6,7 +6,7 @@ struct SymbolicsTwoArgPushforwardPrep{E1,E1!} <: DI.PushforwardPrep end function DI.prepare_pushforward( - f!, y, ::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f!, y, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C} ) where {C} dx = first(tx) x_var = variablize(x, :x) @@ -27,7 +27,7 @@ function DI.pushforward( f!, y, prep::SymbolicsTwoArgPushforwardPrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}, @@ -43,7 +43,7 @@ function DI.pushforward!( y, ty::NTuple, prep::SymbolicsTwoArgPushforwardPrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}, @@ -92,7 +92,7 @@ struct SymbolicsTwoArgDerivativePrep{E1,E1!} <: DI.DerivativePrep end function DI.prepare_derivative( - f!, y, ::AutoSymbolics, x, contexts::Vararg{DI.Context,C} + f!, y, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} ) where {C} x_var = variablize(x, :x) y_var = variablize(y, :y) @@ -109,7 +109,7 @@ function DI.derivative( f!, y, prep::SymbolicsTwoArgDerivativePrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}, ) where {C} @@ -121,7 +121,7 @@ function DI.derivative!( y, der, prep::SymbolicsTwoArgDerivativePrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}, ) where {C} @@ -189,7 +189,7 @@ function DI.jacobian( f!, y, prep::SymbolicsTwoArgJacobianPrep, - ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + 
backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}, ) where {C} @@ -201,7 +201,7 @@ function DI.jacobian!( y, jac, prep::SymbolicsTwoArgJacobianPrep, - ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}, ) where {C} diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index 17d885ca7..dc7e63263 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -9,37 +9,46 @@ DI.inplace_support(::AutoTracker) = DI.InPlaceNotSupported() ## Pullback -struct TrackerPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep +struct TrackerPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} + _sig::Type{SIG} y::Y pb::PB end function DI.prepare_pullback( - f, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C} + f, + backend::AutoTracker, + x, + ty::NTuple, + contexts::Vararg{DI.GeneralizedConstant,C}; + strict::Bool=false, ) where {C} - return DI.NoPullbackPrep() + SIG = DI.signature(f, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep{SIG}() end function DI.prepare_pullback_same_point( f, - ::DI.NoPullbackPrep, - ::AutoTracker, + prep::DI.NoPullbackPrep{SIG}, + backend::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) y, pb = forward(f, x, map(DI.unwrap, contexts)...)
- return TrackerPullbackPrepSamePoint(y, pb) + return TrackerPullbackPrepSamePoint(SIG, y, pb) end function DI.value_and_pullback( f, - ::DI.NoPullbackPrep, - ::AutoTracker, + prep::DI.NoPullbackPrep, + backend::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) y, pb = forward(f, x, map(DI.unwrap, contexts)...) tx = map(ty) do dy data(first(pb(dy))) @@ -50,7 +59,7 @@ end function DI.value_and_pullback( f, prep::TrackerPullbackPrepSamePoint, - ::AutoTracker, + backend::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, @@ -65,7 +74,7 @@ end function DI.pullback( f, prep::TrackerPullbackPrepSamePoint, - ::AutoTracker, + backend::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, @@ -80,20 +89,28 @@ end ## Gradient function DI.prepare_gradient( - f, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} + f, backend::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoGradientPrep() end function DI.value_and_gradient( - f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} + f, + prep::DI.NoGradientPrep, + backend::AutoTracker, + x, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return val, data(first(grad)) end function DI.gradient( - f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} + f, + prep::DI.NoGradientPrep, + backend::AutoTracker, + x, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; grad) = withgradient(f, x, map(DI.unwrap, contexts)...)
return data(first(grad)) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index adf1c397e..0ef65e91b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -21,27 +21,47 @@ translate(c::DI.Cache) = Buffer(DI.unwrap(c)) ## Pullback -struct ZygotePullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep +struct ZygotePullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} + _sig::Type{SIG} y::Y pb::PB end function DI.prepare_pullback( - f, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C} + f, + backend::AutoZygote, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} - return DI.NoPullbackPrep() + SIG = DI.signature(f, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep{SIG}() end function DI.prepare_pullback_same_point( - f, ::DI.NoPullbackPrep, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C} + f, + prep::DI.NoPullbackPrep, + backend::AutoZygote, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) + SIG = DI.signature(f, backend, x, ty, contexts...; strict=DI.is_strict(prep)) y, pb = pullback(f, x, map(translate, contexts)...) - return ZygotePullbackPrepSamePoint(y, pb) + return ZygotePullbackPrepSamePoint(SIG, y, pb) end function DI.value_and_pullback( - f, ::DI.NoPullbackPrep, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C} + f, + prep::DI.NoPullbackPrep, + backend::AutoZygote, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) y, pb = pullback(f, x, map(translate, contexts)...)
tx = map(ty) do dy first(pb(dy)) @@ -52,11 +72,12 @@ end function DI.value_and_pullback( f, prep::ZygotePullbackPrepSamePoint, - ::AutoZygote, + backend::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) (; y, pb) = prep tx = map(ty) do dy first(pb(dy)) @@ -67,11 +88,12 @@ end function DI.pullback( f, prep::ZygotePullbackPrepSamePoint, - ::AutoZygote, + backend::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) (; pb) = prep tx = map(ty) do dy first(pb(dy)) @@ -81,20 +103,25 @@ end ## Gradient -function DI.prepare_gradient(f, ::AutoZygote, x, contexts::Vararg{DI.Context,C}) where {C} - return DI.NoGradientPrep() +function DI.prepare_gradient( + f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C}; strict::Bool=false +) where {C} + SIG = DI.signature(f, backend, x, contexts...; strict) + return DI.NoGradientPrep{SIG}() end function DI.value_and_gradient( - f, ::DI.NoGradientPrep, ::AutoZygote, x, contexts::Vararg{DI.Context,C} + f, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; val, grad) = withgradient(f, x, map(translate, contexts)...) return val, first(grad) end function DI.gradient( - f, ::DI.NoGradientPrep, ::AutoZygote, x, contexts::Vararg{DI.Context,C} + f, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) grad = gradient(f, x, map(translate, contexts)...) return first(grad) end @@ -102,6 +129,7 @@ end function DI.value_and_gradient!( f, grad, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) 
return y, copyto!(grad, new_grad) end @@ -109,18 +137,23 @@ end function DI.gradient!( f, grad, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end ## Jacobian -function DI.prepare_jacobian(f, ::AutoZygote, x, contexts::Vararg{DI.Context,C}) where {C} - return DI.NoJacobianPrep() +function DI.prepare_jacobian( + f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C}; strict::Bool=false +) where {C} + SIG = DI.signature(f, backend, x, contexts...; strict) + return DI.NoJacobianPrep{SIG}() end function DI.value_and_jacobian( - f, ::DI.NoJacobianPrep, ::AutoZygote, x, contexts::Vararg{DI.Context,C} + f, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y = f(x, map(translate, contexts)...) # https://github.com/FluxML/Zygote.jl/issues/1506 jac = jacobian(f, x, map(translate, contexts)...) @@ -128,8 +161,9 @@ function DI.value_and_jacobian( end function DI.jacobian( - f, ::DI.NoJacobianPrep, ::AutoZygote, x, contexts::Vararg{DI.Context,C} + f, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) jac = jacobian(f, x, map(translate, contexts)...) return first(jac) end @@ -137,6 +171,7 @@ end function DI.value_and_jacobian!( f, jac, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...) return y, copyto!(jac, new_jac) end @@ -144,6 +179,7 @@ end function DI.jacobian!( f, jac, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...)) end @@ -152,9 +188,16 @@ end # Beware, this uses ForwardDiff for the inner differentiation function DI.prepare_hvp( - f, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} - return DI.prepare_hvp(f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) + return DI.prepare_hvp( + f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...; strict + ) end function DI.hvp( @@ -165,6 +208,7 @@ function DI.hvp( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.hvp(f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) end @@ -177,6 +221,7 @@ function DI.hvp!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.hvp!( f, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... ) @@ -190,6 +235,7 @@ function DI.gradient_and_hvp( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.gradient_and_hvp( f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... ) @@ -205,6 +251,7 @@ function DI.gradient_and_hvp!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.gradient_and_hvp!( f, grad, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... 
) @@ -213,14 +260,24 @@ end ## Hessian function DI.prepare_hessian( - f, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} + f, + backend::AutoZygote, + x, + contexts::Vararg{DI.GeneralizedConstant,C}; + strict::Bool=false, ) where {C} - return DI.NoHessianPrep() + SIG = DI.signature(f, backend, x, contexts...; strict) + return DI.NoHessianPrep{SIG}() end function DI.hessian( - f, ::DI.NoHessianPrep, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} + f, + prep::DI.NoHessianPrep, + backend::AutoZygote, + x, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) hess = hessian(fc, x) return hess @@ -234,6 +291,7 @@ function DI.hessian!( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return copyto!(hess, DI.hessian(f, prep, backend, x, contexts...)) end @@ -244,6 +302,7 @@ function DI.value_gradient_and_hessian( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, grad = DI.value_and_gradient(f, DI.NoGradientPrep(), backend, x, contexts...) hess = DI.hessian(f, prep, backend, x, contexts...) return y, grad, hess @@ -258,6 +317,7 @@ function DI.value_gradient_and_hessian!( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_gradient!(f, grad, DI.NoGradientPrep(), backend, x, contexts...) DI.hessian!(f, hess, prep, backend, x, contexts...) 
return y, grad, hess diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 2cbd312ff..984a194f7 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -41,13 +41,13 @@ include("docstrings.jl") include("first_order/mixed_mode.jl") include("second_order/second_order.jl") +include("utils/context.jl") include("utils/prep.jl") include("utils/traits.jl") include("utils/basis.jl") include("utils/batchsize.jl") include("utils/check.jl") include("utils/errors.jl") -include("utils/context.jl") include("utils/linalg.jl") include("utils/sparse.jl") diff --git a/DifferentiationInterface/src/docstrings.jl b/DifferentiationInterface/src/docstrings.jl index 66834d1e2..02d6068f1 100644 --- a/DifferentiationInterface/src/docstrings.jl +++ b/DifferentiationInterface/src/docstrings.jl @@ -23,6 +23,7 @@ function docstring_prepare(operator; samepoint=false, inplace=false) !!! warning The preparation result is only reusable as long as the arguments to `$operator` do not change type or size, and the function and backend themselves are not modified. Otherwise, preparation will be invalidated and you will need to run it again. + The keyword argument `strict` activates automatic type checking, but ensuring size consistency is up to the user. $(inplace ? "\nFor in-place functions, `y` is mutated by `f!` during preparation." 
: "") """ end diff --git a/DifferentiationInterface/src/fallbacks/change_prep.jl b/DifferentiationInterface/src/fallbacks/change_prep.jl index 48ea62583..8736deb40 100644 --- a/DifferentiationInterface/src/fallbacks/change_prep.jl +++ b/DifferentiationInterface/src/fallbacks/change_prep.jl @@ -43,24 +43,27 @@ for op in [ if op in (:derivative, :gradient, :jacobian) # 1-arg @eval function $prep_op!( - f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C} + f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} - return $prep_op(f, backend, x, contexts...) + check_prep(f, old_prep, backend, x, contexts...) + return $prep_op(f, backend, x, contexts...; strict=is_strict(old_prep)) end op == :gradient && continue # 2-arg @eval function $prep_op!( - f!::F, y, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C} + f!::F, y, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} - return $prep_op(f!, y, backend, x, contexts...) + check_prep(f!, y, old_prep, backend, x, contexts...) + return $prep_op(f!, y, backend, x, contexts...; strict=is_strict(old_prep)) end elseif op in (:second_derivative, :hessian) # 1-arg @eval function $prep_op!( - f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C} + f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} - return $prep_op(f, backend, x, contexts...) + check_prep(f, old_prep, backend, x, contexts...) + return $prep_op(f, backend, x, contexts...; strict=is_strict(old_prep)) end elseif op in (:pushforward, :pullback, :hvp) @@ -71,9 +74,10 @@ for op in [ backend::AbstractADType, x, seed::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {F,C} - return $prep_op(f, backend, x, seed, contexts...) + check_prep(f, old_prep, backend, x, seed, contexts...) 
+ return $prep_op(f, backend, x, seed, contexts...; strict=is_strict(old_prep)) end @eval function $prep_op_same_point( f::F, @@ -83,12 +87,18 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, seed, contexts...) return prep end @eval function $prep_op_same_point( - f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + f::F, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) + prep = $prep_op(f, backend, x, seed, contexts...; strict) return $prep_op_same_point(f, prep, backend, x, seed, contexts...) end op == :hvp && continue @@ -102,7 +112,10 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - return $prep_op(f!, y, backend, x, seed, contexts...) + check_prep(f!, y, old_prep, backend, x, seed, contexts...) + return $prep_op( + f!, y, backend, x, seed, contexts...; strict=is_strict(old_prep) + ) end @eval function $prep_op_same_point( f!::F, @@ -113,12 +126,19 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, seed, contexts...) return prep end @eval function $prep_op_same_point( - f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + f!::F, + y, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...) + prep = $prep_op(f!, y, backend, x, seed, contexts...; strict) return $prep_op_same_point(f!, y, prep, backend, x, seed, contexts...) 
end end diff --git a/DifferentiationInterface/src/fallbacks/no_prep.jl b/DifferentiationInterface/src/fallbacks/no_prep.jl index 0f2ecd4e6..5e6debc37 100644 --- a/DifferentiationInterface/src/fallbacks/no_prep.jl +++ b/DifferentiationInterface/src/fallbacks/no_prep.jl @@ -45,25 +45,25 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) + prep = $prep_op(f, backend, x, contexts...; strict=true) return $op(f, prep, backend, x, contexts...) end @eval function $op!( f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) + prep = $prep_op(f, backend, x, contexts...; strict=true) return $op!(f, result, prep, backend, x, contexts...) end @eval function $val_and_op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) + prep = $prep_op(f, backend, x, contexts...; strict=true) return $val_and_op(f, prep, backend, x, contexts...) end @eval function $val_and_op!( f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) + prep = $prep_op(f, backend, x, contexts...; strict=true) return $val_and_op!(f, result, prep, backend, x, contexts...) end op == :gradient && continue @@ -71,25 +71,25 @@ for op in [ @eval function $op( f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...) + prep = $prep_op(f!, y, backend, x, contexts...; strict=true) return $op(f!, y, prep, backend, x, contexts...) end @eval function $op!( f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...) + prep = $prep_op(f!, y, backend, x, contexts...; strict=true) return $op!(f!, y, result, prep, backend, x, contexts...) 
end @eval function $val_and_op( f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...) + prep = $prep_op(f!, y, backend, x, contexts...; strict=true) return $val_and_op(f!, y, prep, backend, x, contexts...) end @eval function $val_and_op!( f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...) + prep = $prep_op(f!, y, backend, x, contexts...; strict=true) return $val_and_op!(f!, y, result, prep, backend, x, contexts...) end @@ -98,25 +98,25 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) + prep = $prep_op(f, backend, x, contexts...; strict=true) return $op(f, prep, backend, x, contexts...) end @eval function $op!( f::F, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) + prep = $prep_op(f, backend, x, contexts...; strict=true) return $op!(f, result2, prep, backend, x, contexts...) end @eval function $val_and_op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) + prep = $prep_op(f, backend, x, contexts...; strict=true) return $val_and_op(f, prep, backend, x, contexts...) end @eval function $val_and_op!( f::F, result1, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) + prep = $prep_op(f, backend, x, contexts...; strict=true) return $val_and_op!(f, result1, result2, prep, backend, x, contexts...) end @@ -124,7 +124,7 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) 
+ prep = $prep_op(f, backend, x, seed, contexts...; strict=true) return $op(f, prep, backend, x, seed, contexts...) end @eval function $op!( @@ -135,13 +135,13 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) + prep = $prep_op(f, backend, x, seed, contexts...; strict=true) return $op!(f, result, prep, backend, x, seed, contexts...) end @eval function $val_and_op( f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) + prep = $prep_op(f, backend, x, seed, contexts...; strict=true) return $val_and_op(f, prep, backend, x, seed, contexts...) end @@ -154,7 +154,7 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) + prep = $prep_op(f, backend, x, seed, contexts...; strict=true) return $val_and_op!(f, result, prep, backend, x, seed, contexts...) end elseif op == :hvp @@ -167,7 +167,7 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) + prep = $prep_op(f, backend, x, seed, contexts...; strict=true) return $val_and_op!( f, result1, result2, prep, backend, x, seed, contexts... ) @@ -179,7 +179,7 @@ for op in [ @eval function $op( f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...) + prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=true) return $op(f!, y, prep, backend, x, seed, contexts...) end @eval function $op!( @@ -191,13 +191,13 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...) + prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=true) return $op!(f!, y, result, prep, backend, x, seed, contexts...) 
end @eval function $val_and_op( f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...) + prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=true) return $val_and_op(f!, y, prep, backend, x, seed, contexts...) end @eval function $val_and_op!( @@ -209,7 +209,7 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...) + prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=true) return $val_and_op!(f!, y, result, prep, backend, x, seed, contexts...) end end diff --git a/DifferentiationInterface/src/first_order/derivative.jl b/DifferentiationInterface/src/first_order/derivative.jl index 84b080ae5..25680a7f8 100644 --- a/DifferentiationInterface/src/first_order/derivative.jl +++ b/DifferentiationInterface/src/first_order/derivative.jl @@ -1,8 +1,8 @@ ## Docstrings """ - prepare_derivative(f, backend, x, [contexts...]) -> prep - prepare_derivative(f!, y, backend, x, [contexts...]) -> prep + prepare_derivative(f, backend, x, [contexts...]; strict=false) -> prep + prepare_derivative(f!, y, backend, x, [contexts...]; strict=false) -> prep $(docstring_prepare("derivative"; inplace=true)) """ @@ -58,22 +58,26 @@ function derivative! end ## Preparation -struct PushforwardDerivativePrep{E<:PushforwardPrep} <: DerivativePrep +struct PushforwardDerivativePrep{SIG,E<:PushforwardPrep} <: DerivativePrep{SIG} pushforward_prep::E end function prepare_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool=false ) where {F,C} - pushforward_prep = prepare_pushforward(f, backend, x, (one(x),), contexts...) 
- return PushforwardDerivativePrep(pushforward_prep) + SIG = signature(f, backend, x, contexts...; strict) + pushforward_prep = prepare_pushforward(f, backend, x, (one(x),), contexts...; strict) + return PushforwardDerivativePrep{SIG,typeof(pushforward_prep)}(pushforward_prep) end function prepare_derivative( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} + f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool=false ) where {F,C} - pushforward_prep = prepare_pushforward(f!, y, backend, x, (one(x),), contexts...) - return PushforwardDerivativePrep(pushforward_prep) + SIG = signature(f!, y, backend, x, contexts...; strict) + pushforward_prep = prepare_pushforward( + f!, y, backend, x, (one(x),), contexts...; strict + ) + return PushforwardDerivativePrep{SIG,typeof(pushforward_prep)}(pushforward_prep) end ## One argument @@ -85,6 +89,7 @@ function value_and_derivative( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) y, ty = value_and_pushforward( f, prep.pushforward_prep, backend, x, (one(x),), contexts... ) @@ -99,6 +104,7 @@ function value_and_derivative!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) y, _ = value_and_pushforward!( f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts... ) @@ -112,6 +118,7 @@ function derivative( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) ty = pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...) return only(ty) end @@ -124,6 +131,7 @@ function derivative!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...) return der end @@ -138,6 +146,7 @@ function value_and_derivative( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) 
y, ty = value_and_pushforward( f!, y, prep.pushforward_prep, backend, x, (one(x),), contexts... ) @@ -153,6 +162,7 @@ function value_and_derivative!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) y, _ = value_and_pushforward!( f!, y, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts... ) @@ -167,6 +177,7 @@ function derivative( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) ty = pushforward(f!, y, prep.pushforward_prep, backend, x, (one(x),), contexts...) return only(ty) end @@ -180,6 +191,7 @@ function derivative!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) pushforward!(f!, y, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...) return der end diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 17c39ac92..6f232fbda 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -1,7 +1,7 @@ ## Docstrings """ - prepare_gradient(f, backend, x, [contexts...]) -> prep + prepare_gradient(f, backend, x, [contexts...]; strict=false) -> prep $(docstring_prepare("gradient")) """ @@ -52,27 +52,29 @@ function gradient! end ## Preparation -struct PullbackGradientPrep{Y,E<:PullbackPrep} <: GradientPrep +struct PullbackGradientPrep{SIG,Y,E<:PullbackPrep} <: GradientPrep{SIG} pullback_prep::E end function prepare_gradient( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool=false ) where {F,C} + SIG = signature(f, backend, x, contexts...; strict) y = f(x, map(unwrap, contexts)...) # TODO: replace with output type inference? - pullback_prep = prepare_pullback(f, backend, x, (true,), contexts...) 
- return PullbackGradientPrep{typeof(y),typeof(pullback_prep)}(pullback_prep) + pullback_prep = prepare_pullback(f, backend, x, (true,), contexts...; strict) + return PullbackGradientPrep{SIG,typeof(y),typeof(pullback_prep)}(pullback_prep) end ## One argument function value_and_gradient( f::F, - prep::PullbackGradientPrep{Y}, + prep::PullbackGradientPrep{SIG,Y}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {F,Y,C} +) where {F,SIG,Y,C} + check_prep(f, prep, backend, x, contexts...) y, tx = value_and_pullback(f, prep.pullback_prep, backend, x, (one(Y),), contexts...) return y, only(tx) end @@ -80,11 +82,12 @@ end function value_and_gradient!( f::F, grad, - prep::PullbackGradientPrep{Y}, + prep::PullbackGradientPrep{SIG,Y}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {F,Y,C} +) where {F,SIG,Y,C} + check_prep(f, prep, backend, x, contexts...) y, _ = value_and_pullback!( f, (grad,), prep.pullback_prep, backend, x, (one(Y),), contexts... ) @@ -93,11 +96,12 @@ end function gradient( f::F, - prep::PullbackGradientPrep{Y}, + prep::PullbackGradientPrep{SIG,Y}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {F,Y,C} +) where {F,SIG,Y,C} + check_prep(f, prep, backend, x, contexts...) tx = pullback(f, prep.pullback_prep, backend, x, (one(Y),), contexts...) return only(tx) end @@ -105,11 +109,12 @@ end function gradient!( f::F, grad, - prep::PullbackGradientPrep{Y}, + prep::PullbackGradientPrep{SIG,Y}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {F,Y,C} +) where {F,SIG,Y,C} + check_prep(f, prep, backend, x, contexts...) pullback!(f, (grad,), prep.pullback_prep, backend, x, (one(Y),), contexts...) 
return grad end diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 546c2b80f..f1079a81f 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -1,8 +1,8 @@ ## Docstrings """ - prepare_jacobian(f, backend, x, [contexts...]) -> prep - prepare_jacobian(f!, y, backend, x, [contexts...]) -> prep + prepare_jacobian(f, backend, x, [contexts...]; strict=false) -> prep + prepare_jacobian(f!, y, backend, x, [contexts...]; strict=false) -> prep $(docstring_prepare("jacobian"; inplace=true)) """ @@ -58,14 +58,16 @@ function jacobian! end ## Preparation -abstract type StandardJacobianPrep <: JacobianPrep end +abstract type StandardJacobianPrep{SIG} <: JacobianPrep{SIG} end struct PushforwardJacobianPrep{ + SIG, BS<:BatchSizeSettings, S<:AbstractVector{<:NTuple}, R<:AbstractVector{<:NTuple}, E<:PushforwardPrep, -} <: StandardJacobianPrep +} <: StandardJacobianPrep{SIG} + _sig::Type{SIG} batch_size_settings::BS batched_seeds::S batched_results::R @@ -73,11 +75,13 @@ struct PushforwardJacobianPrep{ end struct PullbackJacobianPrep{ + SIG, BS<:BatchSizeSettings, S<:AbstractVector{<:NTuple}, R<:AbstractVector{<:NTuple}, E<:PullbackPrep, -} <: StandardJacobianPrep +} <: StandardJacobianPrep{SIG} + _sig::Type{SIG} batch_size_settings::BS batched_seeds::S batched_results::R @@ -85,7 +89,7 @@ struct PullbackJacobianPrep{ end function prepare_jacobian( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool=false ) where {F,C} y = f(x, map(unwrap, contexts)...) perf = pushforward_performance(backend) @@ -97,12 +101,12 @@ function prepare_jacobian( end # function barrier return _prepare_jacobian_aux( - perf, batch_size_settings, y, (f,), backend, x, contexts... 
+ perf, batch_size_settings, y, (f,), backend, x, contexts...; strict ) end function prepare_jacobian( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} + f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool=false ) where {F,C} perf = pushforward_performance(backend) # type-unstable @@ -113,7 +117,7 @@ function prepare_jacobian( end # function barrier return _prepare_jacobian_aux( - perf, batch_size_settings, y, (f!, y), backend, x, contexts... + perf, batch_size_settings, y, (f!, y), backend, x, contexts...; strict ) end @@ -124,8 +128,10 @@ function _prepare_jacobian_aux( f_or_f!y::FY, backend::AbstractADType, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; + strict::Bool, ) where {B,FY,C} + SIG = signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings seeds = [basis(x, ind) for ind in eachindex(x)] batched_seeds = [ @@ -133,10 +139,10 @@ function _prepare_jacobian_aux( ] batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds] pushforward_prep = prepare_pushforward( - f_or_f!y..., backend, x, batched_seeds[1], contexts... + f_or_f!y..., backend, x, batched_seeds[1], contexts...; strict ) return PushforwardJacobianPrep( - batch_size_settings, batched_seeds, batched_results, pushforward_prep + SIG, batch_size_settings, batched_seeds, batched_results, pushforward_prep ) end @@ -147,17 +153,21 @@ function _prepare_jacobian_aux( f_or_f!y::FY, backend::AbstractADType, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; + strict::Bool, ) where {B,FY,C} + SIG = signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings seeds = [basis(y, ind) for ind in eachindex(y)] batched_seeds = [ ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - pullback_prep = prepare_pullback(f_or_f!y..., backend, x, batched_seeds[1], contexts...) 
+ pullback_prep = prepare_pullback( + f_or_f!y..., backend, x, batched_seeds[1], contexts...; strict + ) return PullbackJacobianPrep( - batch_size_settings, batched_seeds, batched_results, pullback_prep + SIG, batch_size_settings, batched_seeds, batched_results, pullback_prep ) end @@ -170,6 +180,7 @@ function jacobian( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) return _jacobian_aux((f,), prep, backend, x, contexts...) end @@ -181,18 +192,21 @@ function jacobian!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) return _jacobian_aux!((f,), jac, prep, backend, x, contexts...) end function value_and_jacobian( f::F, prep::JacobianPrep, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} + check_prep(f, prep, backend, x, contexts...) return f(x, map(unwrap, contexts)...), jacobian(f, prep, backend, x, contexts...) end function value_and_jacobian!( f::F, jac, prep::JacobianPrep, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} + check_prep(f, prep, backend, x, contexts...) return f(x, map(unwrap, contexts)...), jacobian!(f, jac, prep, backend, x, contexts...) end @@ -206,6 +220,7 @@ function jacobian( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) return _jacobian_aux((f!, y), prep, backend, x, contexts...) end @@ -218,12 +233,14 @@ function jacobian!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) return _jacobian_aux!((f!, y), jac, prep, backend, x, contexts...) end function value_and_jacobian( f!::F, y, prep::JacobianPrep, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) jac = jacobian(f!, y, prep, backend, x, contexts...) f!(y, x, map(unwrap, contexts)...) 
return y, jac @@ -238,6 +255,7 @@ function value_and_jacobian!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) jacobian!(f!, y, jac, prep, backend, x, contexts...) f!(y, x, map(unwrap, contexts)...) return y, jac @@ -247,11 +265,11 @@ end function _jacobian_aux( f_or_f!y::FY, - prep::PushforwardJacobianPrep{<:BatchSizeSettings{B,true,aligned}}, + prep::PushforwardJacobianPrep{SIG,<:BatchSizeSettings{B,true,aligned}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {FY,B,aligned,C} +) where {FY,SIG,B,aligned,C} (; batch_size_settings, batched_seeds, pushforward_prep) = prep (; B_last) = batch_size_settings dy_batch = pushforward( @@ -267,11 +285,11 @@ end function _jacobian_aux( f_or_f!y::FY, - prep::PushforwardJacobianPrep{<:BatchSizeSettings{B,false,aligned}}, + prep::PushforwardJacobianPrep{SIG,<:BatchSizeSettings{B,false,aligned}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {FY,B,aligned,C} +) where {FY,SIG,B,aligned,C} (; batch_size_settings, batched_seeds, pushforward_prep) = prep (; A, B_last) = batch_size_settings @@ -300,11 +318,11 @@ end function _jacobian_aux( f_or_f!y::FY, - prep::PullbackJacobianPrep{<:BatchSizeSettings{B,true,aligned}}, + prep::PullbackJacobianPrep{SIG,<:BatchSizeSettings{B,true,aligned}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {FY,B,aligned,C} +) where {FY,SIG,B,aligned,C} (; batch_size_settings, batched_seeds, pullback_prep) = prep (; B_last) = batch_size_settings dx_batch = pullback( @@ -323,11 +341,11 @@ end function _jacobian_aux( f_or_f!y::FY, - prep::PullbackJacobianPrep{<:BatchSizeSettings{B,false,aligned}}, + prep::PullbackJacobianPrep{SIG,<:BatchSizeSettings{B,false,aligned}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {FY,B,aligned,C} +) where {FY,SIG,B,aligned,C} (; batch_size_settings, batched_seeds, pullback_prep) = prep (; A, B_last) = batch_size_settings @@ -355,11 +373,11 @@ end 
function _jacobian_aux!( f_or_f!y::FY, jac, - prep::PushforwardJacobianPrep{<:BatchSizeSettings{B}}, + prep::PushforwardJacobianPrep{SIG,<:BatchSizeSettings{B}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {FY,B,C} +) where {FY,SIG,B,C} (; batch_size_settings, batched_seeds, batched_results, pushforward_prep) = prep (; N) = batch_size_settings @@ -391,11 +409,11 @@ end function _jacobian_aux!( f_or_f!y::FY, jac, - prep::PullbackJacobianPrep{<:BatchSizeSettings{B}}, + prep::PullbackJacobianPrep{SIG,<:BatchSizeSettings{B}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {FY,B,C} +) where {FY,SIG,B,C} (; batch_size_settings, batched_seeds, batched_results, pullback_prep) = prep (; N) = batch_size_settings diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index d80a4ff7e..c7974686c 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -1,8 +1,8 @@ ## Docstrings """ - prepare_pullback(f, backend, x, ty, [contexts...]) -> prep - prepare_pullback(f!, y, backend, x, ty, [contexts...]) -> prep + prepare_pullback(f, backend, x, ty, [contexts...]; strict=false) -> prep + prepare_pullback(f!, y, backend, x, ty, [contexts...]; strict=false) -> prep $(docstring_prepare("pullback"; inplace=true)) """ @@ -17,8 +17,8 @@ $(docstring_prepare!("pullback")) function prepare!_pullback end """ - prepare_pullback_same_point(f, backend, x, ty, [contexts...]) -> prep_same - prepare_pullback_same_point(f!, y, backend, x, ty, [contexts...]) -> prep_same + prepare_pullback_same_point(f, backend, x, ty, [contexts...]; strict=false) -> prep_same + prepare_pullback_same_point(f!, y, backend, x, ty, [contexts...]; strict=false) -> prep_same $(docstring_prepare("pullback"; samepoint=true, inplace=true)) """ @@ -85,23 +85,34 @@ function pullback! 
end ## Preparation -struct PushforwardPullbackPrep{E} <: PullbackPrep +struct PushforwardPullbackPrep{SIG,E} <: PullbackPrep{SIG} pushforward_prep::E end function prepare_pullback( - f::F, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context,C} + f::F, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} return _prepare_pullback_aux( - pullback_performance(backend), f, backend, x, ty, contexts... + pullback_performance(backend), f, backend, x, ty, contexts...; strict ) end function prepare_pullback( - f!::F, y, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context,C} + f!::F, + y, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} return _prepare_pullback_aux( - pullback_performance(backend), f!, y, backend, x, ty, contexts... + pullback_performance(backend), f!, y, backend, x, ty, contexts...; strict ) end @@ -111,11 +122,13 @@ function _prepare_pullback_aux( backend::AbstractADType, x, ty::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; + strict::Bool, ) where {F,C} + SIG = signature(f, backend, x, ty, contexts...; strict) dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x))) - pushforward_prep = prepare_pushforward(f, backend, x, (dx,), contexts...) - return PushforwardPullbackPrep(pushforward_prep) + pushforward_prep = prepare_pushforward(f, backend, x, (dx,), contexts...; strict) + return PushforwardPullbackPrep{SIG,typeof(pushforward_prep)}(pushforward_prep) end function _prepare_pullback_aux( @@ -125,11 +138,13 @@ function _prepare_pullback_aux( backend::AbstractADType, x, ty::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; + strict::Bool, ) where {F,C} + SIG = signature(f!, y, backend, x, ty, contexts...; strict) dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x))) - pushforward_prep = prepare_pushforward(f!, y, backend, x, (dx,), contexts...) 
- return PushforwardPullbackPrep(pushforward_prep) + pushforward_prep = prepare_pushforward(f!, y, backend, x, (dx,), contexts...; strict) + return PushforwardPullbackPrep{SIG,typeof(pushforward_prep)}(pushforward_prep) end ## One argument @@ -202,6 +217,7 @@ function value_and_pullback( ty::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f, prep, backend, x, ty, contexts...) (; pushforward_prep) = prep y = f(x, map(unwrap, contexts)...) tx = ntuple( @@ -220,6 +236,7 @@ function value_and_pullback!( ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, ty, contexts...) y, new_tx = value_and_pullback(f, prep, backend, x, ty, contexts...) foreach(copyto!, tx, new_tx) return y, tx @@ -233,6 +250,7 @@ function pullback( ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, ty, contexts...) return value_and_pullback(f, prep, backend, x, ty, contexts...)[2] end @@ -245,6 +263,7 @@ function pullback!( ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, ty, contexts...) return value_and_pullback!(f, tx, prep, backend, x, ty, contexts...)[2] end @@ -325,6 +344,7 @@ function value_and_pullback( ty::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) (; pushforward_prep) = prep tx = ntuple( b -> _pullback_via_pushforward( @@ -346,6 +366,7 @@ function value_and_pullback!( ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) y, new_tx = value_and_pullback(f!, y, prep, backend, x, ty, contexts...) foreach(copyto!, tx, new_tx) return y, tx @@ -360,6 +381,7 @@ function pullback( ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) 
return value_and_pullback(f!, y, prep, backend, x, ty, contexts...)[2] end @@ -373,5 +395,6 @@ function pullback!( ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) return value_and_pullback!(f!, y, tx, prep, backend, x, ty, contexts...)[2] end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 91a967ad6..dd6f5cea8 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -1,8 +1,8 @@ ## Docstrings """ - prepare_pushforward(f, backend, x, tx, [contexts...]) -> prep - prepare_pushforward(f!, y, backend, x, tx, [contexts...]) -> prep + prepare_pushforward(f, backend, x, tx, [contexts...]; strict=false) -> prep + prepare_pushforward(f!, y, backend, x, tx, [contexts...]; strict=false) -> prep $(docstring_prepare("pushforward"; inplace=true)) """ @@ -17,8 +17,8 @@ $(docstring_prepare!("pushforward")) function prepare!_pushforward end """ - prepare_pushforward_same_point(f, backend, x, tx, [contexts...]) -> prep_same - prepare_pushforward_same_point(f!, y, backend, x, tx, [contexts...]) -> prep_same + prepare_pushforward_same_point(f, backend, x, tx, [contexts...]; strict=false) -> prep_same + prepare_pushforward_same_point(f!, y, backend, x, tx, [contexts...]; strict=false) -> prep_same $(docstring_prepare("pushforward"; samepoint=true, inplace=true)) """ @@ -85,23 +85,34 @@ function pushforward! 
end ## Preparation -struct PullbackPushforwardPrep{E} <: PushforwardPrep +struct PullbackPushforwardPrep{SIG,E} <: PushforwardPrep{SIG} pullback_prep::E end function prepare_pushforward( - f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C} + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} return _prepare_pushforward_aux( - pushforward_performance(backend), f, backend, x, tx, contexts... + pushforward_performance(backend), f, backend, x, tx, contexts...; strict ) end function prepare_pushforward( - f!::F, y, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C} + f!::F, + y, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} return _prepare_pushforward_aux( - pushforward_performance(backend), f!, y, backend, x, tx, contexts... + pushforward_performance(backend), f!, y, backend, x, tx, contexts...; strict ) end @@ -111,12 +122,14 @@ function _prepare_pushforward_aux( backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; + strict::Bool, ) where {F,C} + SIG = signature(f, backend, x, tx, contexts...; strict) y = f(x, map(unwrap, contexts)...) dy = y isa Number ? one(y) : basis(y, first(CartesianIndices(y))) - pullback_prep = prepare_pullback(f, backend, x, (dy,), contexts...) - return PullbackPushforwardPrep(pullback_prep) + pullback_prep = prepare_pullback(f, backend, x, (dy,), contexts...; strict) + return PullbackPushforwardPrep{SIG,typeof(pullback_prep)}(pullback_prep) end function _prepare_pushforward_aux( @@ -126,11 +139,13 @@ function _prepare_pushforward_aux( backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; + strict::Bool, ) where {F,C} + SIG = signature(f!, y, backend, x, tx, contexts...; strict) dy = y isa Number ? 
one(y) : basis(y, first(CartesianIndices(y))) - pullback_prep = prepare_pullback(f!, y, backend, x, (dy,), contexts...) - return PullbackPushforwardPrep(pullback_prep) + pullback_prep = prepare_pullback(f!, y, backend, x, (dy,), contexts...; strict) + return PullbackPushforwardPrep{SIG,typeof(pullback_prep)}(pullback_prep) end ## One argument @@ -205,6 +220,7 @@ function value_and_pushforward( tx::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f, prep, backend, x, tx, contexts...) (; pullback_prep) = prep y = f(x, map(unwrap, contexts)...) ty = ntuple( @@ -223,6 +239,7 @@ function value_and_pushforward!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) y, new_ty = value_and_pushforward(f, prep, backend, x, tx, contexts...) foreach(copyto!, ty, new_ty) return y, ty @@ -236,6 +253,7 @@ function pushforward( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return value_and_pushforward(f, prep, backend, x, tx, contexts...)[2] end @@ -248,6 +266,7 @@ function pushforward!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return value_and_pushforward!(f, ty, prep, backend, x, tx, contexts...)[2] end @@ -297,6 +316,7 @@ function value_and_pushforward( tx::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) (; pullback_prep) = prep ty = ntuple( b -> @@ -317,6 +337,7 @@ function value_and_pushforward!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) y, new_ty = value_and_pushforward(f!, y, prep, backend, x, tx, contexts...) foreach(copyto!, ty, new_ty) return y, ty @@ -331,6 +352,7 @@ function pushforward( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) 
return value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)[2] end @@ -344,6 +366,7 @@ function pushforward!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) return value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...)[2] end diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index 496675503..9fdbeeb52 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -1,14 +1,14 @@ abstract type FromPrimitive{inplace} <: AbstractADType end -check_available(fromprim::FromPrimitive) = check_available(fromprim.backend) +check_available(backend::FromPrimitive) = check_available(backend.backend) inplace_support(::FromPrimitive{true}) = InPlaceSupported() inplace_support(::FromPrimitive{false}) = InPlaceNotSupported() -function inner_preparation_behavior(fromprim::FromPrimitive) - return inner_preparation_behavior(fromprim.backend) +function inner_preparation_behavior(backend::FromPrimitive) + return inner_preparation_behavior(backend.backend) end -function pick_batchsize(fromprim::FromPrimitive, N::Integer) - return pick_batchsize(fromprim.backend, N) +function pick_batchsize(backend::FromPrimitive, N::Integer) + return pick_batchsize(backend.backend, N) end """ @@ -29,41 +29,55 @@ end ADTypes.mode(::AutoForwardFromPrimitive) = ADTypes.ForwardMode() function threshold_batchsize( - fromprim::AutoForwardFromPrimitive{inplace}, dimension::Integer + backend::AutoForwardFromPrimitive{inplace}, dimension::Integer ) where {inplace} return AutoForwardFromPrimitive( - threshold_batchsize(fromprim.backend, dimension); inplace + threshold_batchsize(backend.backend, dimension); inplace ) end -struct FromPrimitivePushforwardPrep{E<:PushforwardPrep} <: PushforwardPrep +struct FromPrimitivePushforwardPrep{SIG,E<:PushforwardPrep} <: PushforwardPrep{SIG} pushforward_prep::E 
end function prepare_pushforward( - f::F, fromprim::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C} + f::F, + backend::AutoForwardFromPrimitive, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} - primitive_prep = prepare_pushforward(f, fromprim.backend, x, tx, contexts...) - return FromPrimitivePushforwardPrep(primitive_prep) + SIG = signature(f, backend, x, tx, contexts...; strict) + primitive_prep = prepare_pushforward(f, backend.backend, x, tx, contexts...; strict) + return FromPrimitivePushforwardPrep{SIG,typeof(primitive_prep)}(primitive_prep) end function prepare_pushforward( - f!::F, y, fromprim::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C} + f!::F, + y, + backend::AutoForwardFromPrimitive, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} - primitive_prep = prepare_pushforward(f!, y, fromprim.backend, x, tx, contexts...) - return FromPrimitivePushforwardPrep(primitive_prep) + SIG = signature(f!, y, backend, x, tx, contexts...; strict) + primitive_prep = prepare_pushforward(f!, y, backend.backend, x, tx, contexts...; strict) + return FromPrimitivePushforwardPrep{SIG,typeof(primitive_prep)}(primitive_prep) end function value_and_pushforward( f::F, prep::FromPrimitivePushforwardPrep, - fromprim::AutoForwardFromPrimitive, + backend::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return value_and_pushforward( - f, prep.pushforward_prep, fromprim.backend, x, tx, contexts... + f, prep.pushforward_prep, backend.backend, x, tx, contexts... ) end @@ -71,13 +85,14 @@ function value_and_pushforward( f!::F, y, prep::FromPrimitivePushforwardPrep, - fromprim::AutoForwardFromPrimitive, + backend::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) 
return value_and_pushforward( - f!, y, prep.pushforward_prep, fromprim.backend, x, tx, contexts... + f!, y, prep.pushforward_prep, backend.backend, x, tx, contexts... ) end @@ -85,13 +100,14 @@ function value_and_pushforward!( f::F, ty::NTuple, prep::FromPrimitivePushforwardPrep, - fromprim::AutoForwardFromPrimitive, + backend::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return value_and_pushforward!( - f, ty, prep.pushforward_prep, fromprim.backend, x, tx, contexts... + f, ty, prep.pushforward_prep, backend.backend, x, tx, contexts... ) end @@ -100,13 +116,14 @@ function value_and_pushforward!( y, ty::NTuple, prep::FromPrimitivePushforwardPrep, - fromprim::AutoForwardFromPrimitive, + backend::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) return value_and_pushforward!( - f!, y, ty, prep.pushforward_prep, fromprim.backend, x, tx, contexts... + f!, y, ty, prep.pushforward_prep, backend.backend, x, tx, contexts... 
) end @@ -128,53 +145,68 @@ end ADTypes.mode(::AutoReverseFromPrimitive) = ADTypes.ReverseMode() function threshold_batchsize( - fromprim::AutoReverseFromPrimitive{inplace}, dimension::Integer + backend::AutoReverseFromPrimitive{inplace}, dimension::Integer ) where {inplace} return AutoReverseFromPrimitive( - threshold_batchsize(fromprim.backend, dimension); inplace + threshold_batchsize(backend.backend, dimension); inplace ) end -struct FromPrimitivePullbackPrep{E<:PullbackPrep} <: PullbackPrep +struct FromPrimitivePullbackPrep{SIG,E<:PullbackPrep} <: PullbackPrep{SIG} pullback_prep::E end function prepare_pullback( - f::F, fromprim::AutoReverseFromPrimitive, x, ty::NTuple, contexts::Vararg{Context,C} + f::F, + backend::AutoReverseFromPrimitive, + x, + ty::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} - primitive_prep = prepare_pullback(f, fromprim.backend, x, ty, contexts...) - return FromPrimitivePullbackPrep(primitive_prep) + SIG = signature(f, backend, x, ty, contexts...; strict) + primitive_prep = prepare_pullback(f, backend.backend, x, ty, contexts...; strict) + return FromPrimitivePullbackPrep{SIG,typeof(primitive_prep)}(primitive_prep) end function prepare_pullback( - f!::F, y, fromprim::AutoReverseFromPrimitive, x, ty::NTuple, contexts::Vararg{Context,C} + f!::F, + y, + backend::AutoReverseFromPrimitive, + x, + ty::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} - primitive_prep = prepare_pullback(f!, y, fromprim.backend, x, ty, contexts...) 
- return FromPrimitivePullbackPrep(primitive_prep) + SIG = signature(f!, y, backend, x, ty, contexts...; strict) + primitive_prep = prepare_pullback(f!, y, backend.backend, x, ty, contexts...; strict) + return FromPrimitivePullbackPrep{SIG,typeof(primitive_prep)}(primitive_prep) end function value_and_pullback( f::F, prep::FromPrimitivePullbackPrep, - fromprim::AutoReverseFromPrimitive, + backend::AutoReverseFromPrimitive, x, ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - return value_and_pullback(f, prep.pullback_prep, fromprim.backend, x, ty, contexts...) + check_prep(f, prep, backend, x, ty, contexts...) + return value_and_pullback(f, prep.pullback_prep, backend.backend, x, ty, contexts...) end function value_and_pullback( f!::F, y, prep::FromPrimitivePullbackPrep, - fromprim::AutoReverseFromPrimitive, + backend::AutoReverseFromPrimitive, x, ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) return value_and_pullback( - f!, y, prep.pullback_prep, fromprim.backend, x, ty, contexts... + f!, y, prep.pullback_prep, backend.backend, x, ty, contexts... ) end @@ -182,13 +214,14 @@ function value_and_pullback!( f::F, tx::NTuple, prep::FromPrimitivePullbackPrep, - fromprim::AutoReverseFromPrimitive, + backend::AutoReverseFromPrimitive, x, ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, ty, contexts...) return value_and_pullback!( - f, tx, prep.pullback_prep, fromprim.backend, x, ty, contexts... + f, tx, prep.pullback_prep, backend.backend, x, ty, contexts... ) end @@ -197,12 +230,13 @@ function value_and_pullback!( y, tx::NTuple, prep::FromPrimitivePullbackPrep, - fromprim::AutoReverseFromPrimitive, + backend::AutoReverseFromPrimitive, x, ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) return value_and_pullback!( - f!, y, tx, prep.pullback_prep, fromprim.backend, x, ty, contexts... 
+ f!, y, tx, prep.pullback_prep, backend.backend, x, ty, contexts... ) end diff --git a/DifferentiationInterface/src/misc/simple_finite_diff.jl b/DifferentiationInterface/src/misc/simple_finite_diff.jl index b36f36503..b34053474 100644 --- a/DifferentiationInterface/src/misc/simple_finite_diff.jl +++ b/DifferentiationInterface/src/misc/simple_finite_diff.jl @@ -37,25 +37,39 @@ function threshold_batchsize( end function prepare_pushforward( - f::F, ::AutoSimpleFiniteDiff, x, tx::NTuple, contexts::Vararg{Context,C} + f::F, + backend::AutoSimpleFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} - return NoPushforwardPrep() + SIG = signature(f, backend, x, tx, contexts...; strict) + return NoPushforwardPrep{SIG}() end function prepare_pushforward( - f!::F, y, ::AutoSimpleFiniteDiff, x, tx::NTuple, contexts::Vararg{Context,C} + f!::F, + y, + backend::AutoSimpleFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} - return NoPushforwardPrep() + SIG = signature(f!, y, backend, x, tx, contexts...; strict) + return NoPushforwardPrep{SIG}() end function value_and_pushforward( f::F, - ::NoPushforwardPrep, + prep::NoPushforwardPrep, backend::AutoSimpleFiniteDiff, x, tx::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f, prep, backend, x, tx, contexts...) ε = eltype(x)(backend.ε) y = f(x, map(unwrap, contexts)...) ty = map(tx) do dx @@ -69,12 +83,13 @@ end function value_and_pushforward( f!::F, y, - ::NoPushforwardPrep, + prep::NoPushforwardPrep, backend::AutoSimpleFiniteDiff, x, tx::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) ε = eltype(x)(backend.ε) ty = map(tx) do dx f!(y, x + ε * dx, map(unwrap, contexts)...) 
diff --git a/DifferentiationInterface/src/misc/zero_backends.jl b/DifferentiationInterface/src/misc/zero_backends.jl index a340edd8f..5dcbf47da 100644 --- a/DifferentiationInterface/src/misc/zero_backends.jl +++ b/DifferentiationInterface/src/misc/zero_backends.jl @@ -21,25 +21,39 @@ check_available(::AutoZeroForward) = true inplace_support(::AutoZeroForward) = InPlaceSupported() function prepare_pushforward( - f::F, ::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C} + f::F, + backend::AutoZeroForward, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} - return NoPushforwardPrep() + SIG = signature(f, backend, x, tx, contexts...; strict) + return NoPushforwardPrep{SIG}() end function prepare_pushforward( - f!::F, y, ::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C} + f!::F, + y, + backend::AutoZeroForward, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} - return NoPushforwardPrep() + SIG = signature(f!, y, backend, x, tx, contexts...; strict) + return NoPushforwardPrep{SIG}() end function value_and_pushforward( f::F, - ::NoPushforwardPrep, - ::AutoZeroForward, + prep::NoPushforwardPrep, + backend::AutoZeroForward, x, tx::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f, prep, backend, x, tx, contexts...) y = f(x, map(unwrap, contexts)...) ty = map(ReturnZero(y), tx) return y, ty @@ -48,12 +62,13 @@ end function value_and_pushforward( f!::F, y, - ::NoPushforwardPrep, - ::AutoZeroForward, + prep::NoPushforwardPrep, + backend::AutoZeroForward, x, tx::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) f!(y, x, map(unwrap, contexts)...) 
ty = map(ReturnZero(y), tx) return y, ty @@ -62,12 +77,13 @@ end function value_and_pushforward!( f::F, ty::NTuple, - ::NoPushforwardPrep, - ::AutoZeroForward, + prep::NoPushforwardPrep, + backend::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) y = f(x, map(unwrap, contexts)...) for b in eachindex(ty) _zero!(ty[b]) @@ -79,12 +95,13 @@ function value_and_pushforward!( f!::F, y, ty::NTuple, - ::NoPushforwardPrep, - ::AutoZeroForward, + prep::NoPushforwardPrep, + backend::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) f!(y, x, map(unwrap, contexts)...) for b in eachindex(ty) _zero!(ty[b]) @@ -107,20 +124,39 @@ check_available(::AutoZeroReverse) = true inplace_support(::AutoZeroReverse) = InPlaceSupported() function prepare_pullback( - f::F, ::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C} + f::F, + backend::AutoZeroReverse, + x, + ty::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} - return NoPullbackPrep() + SIG = signature(f, backend, x, ty, contexts...; strict) + return NoPullbackPrep{SIG}() end function prepare_pullback( - f!::F, y, ::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C} + f!::F, + y, + backend::AutoZeroReverse, + x, + ty::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} - return NoPullbackPrep() + SIG = signature(f!, y, backend, x, ty, contexts...; strict) + return NoPullbackPrep{SIG}() end function value_and_pullback( - f::F, ::NoPullbackPrep, ::AutoZeroReverse, x, ty::NTuple{B}, contexts::Vararg{Context,C} + f::F, + prep::NoPullbackPrep, + backend::AutoZeroReverse, + x, + ty::NTuple{B}, + contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f, prep, backend, x, ty, contexts...) y = f(x, map(unwrap, contexts)...) 
tx = ntuple(ReturnZero(x), Val(B)) return y, tx @@ -129,12 +165,13 @@ end function value_and_pullback( f!::F, y, - ::NoPullbackPrep, - ::AutoZeroReverse, + prep::NoPullbackPrep, + backend::AutoZeroReverse, x, ty::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) f!(y, x, map(unwrap, contexts)...) tx = ntuple(ReturnZero(x), Val(B)) return y, tx @@ -143,12 +180,13 @@ end function value_and_pullback!( f::F, tx::NTuple, - ::NoPullbackPrep, - ::AutoZeroReverse, + prep::NoPullbackPrep, + backend::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, ty, contexts...) y = f(x, map(unwrap, contexts)...) for b in eachindex(tx) _zero!(tx[b]) @@ -160,12 +198,13 @@ function value_and_pullback!( f!::F, y, tx::NTuple, - ::NoPullbackPrep, - ::AutoZeroReverse, + prep::NoPullbackPrep, + backend::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) f!(y, x, map(unwrap, contexts)...) for b in eachindex(tx) _zero!(tx[b]) diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 3dc8b7916..a4cf03b0b 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -1,7 +1,7 @@ ## Docstrings """ - prepare_hessian(f, backend, x, [contexts...]) -> prep + prepare_hessian(f, backend, x, [contexts...]; strict=false) -> prep $(docstring_prepare("hessian")) """ @@ -53,12 +53,14 @@ function value_gradient_and_hessian! 
end ## Preparation struct HVPGradientHessianPrep{ + SIG, BS<:BatchSizeSettings, S<:AbstractVector{<:NTuple}, R<:AbstractVector{<:NTuple}, E2<:HVPPrep, E1<:GradientPrep, -} <: HessianPrep +} <: HessianPrep{SIG} + _sig::Type{SIG} batch_size_settings::BS batched_seeds::S batched_results::R @@ -67,12 +69,12 @@ struct HVPGradientHessianPrep{ end function prepare_hessian( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool=false ) where {F,C} # type-unstable batch_size_settings = pick_batchsize(outer(backend), x) # function barrier - return _prepare_hessian_aux(batch_size_settings, f, backend, x, contexts...) + return _prepare_hessian_aux(batch_size_settings, f, backend, x, contexts...; strict) end function _prepare_hessian_aux( @@ -80,18 +82,20 @@ function _prepare_hessian_aux( f::F, backend::AbstractADType, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; + strict::Bool, ) where {B,F,C} + SIG = signature(f, backend, x, contexts...; strict) (; N, A) = batch_size_settings seeds = [basis(x, ind) for ind in eachindex(x)] batched_seeds = [ ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - hvp_prep = prepare_hvp(f, backend, x, batched_seeds[1], contexts...) - gradient_prep = prepare_gradient(f, inner(backend), x, contexts...) 
+ hvp_prep = prepare_hvp(f, backend, x, batched_seeds[1], contexts...; strict) + gradient_prep = prepare_gradient(f, inner(backend), x, contexts...; strict) return HVPGradientHessianPrep( - batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep + SIG, batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep ) end @@ -99,11 +103,12 @@ end function hessian( f::F, - prep::HVPGradientHessianPrep{<:BatchSizeSettings{B,true}}, + prep::HVPGradientHessianPrep{SIG,<:BatchSizeSettings{B,true}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {F,B,C} +) where {F,SIG,B,C} + check_prep(f, prep, backend, x, contexts...) (; batched_seeds, hvp_prep) = prep dg_batch = hvp(f, hvp_prep, backend, x, only(batched_seeds), contexts...) block = stack_vec_col(dg_batch) @@ -112,11 +117,12 @@ end function hessian( f::F, - prep::HVPGradientHessianPrep{<:BatchSizeSettings{B,false,aligned}}, + prep::HVPGradientHessianPrep{SIG,<:BatchSizeSettings{B,false,aligned}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {F,B,aligned,C} +) where {F,SIG,B,aligned,C} + check_prep(f, prep, backend, x, contexts...) (; batch_size_settings, batched_seeds, hvp_prep) = prep (; A, B_last) = batch_size_settings @@ -139,11 +145,12 @@ end function hessian!( f::F, hess, - prep::HVPGradientHessianPrep{<:BatchSizeSettings{B}}, + prep::HVPGradientHessianPrep{SIG,<:BatchSizeSettings{B}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {F,B,C} +) where {F,SIG,B,C} + check_prep(f, prep, backend, x, contexts...) (; batch_size_settings, batched_seeds, batched_results, hvp_prep) = prep (; N) = batch_size_settings @@ -173,6 +180,7 @@ function value_gradient_and_hessian( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) y, grad = value_and_gradient(f, prep.gradient_prep, inner(backend), x, contexts...) hess = hessian(f, prep, backend, x, contexts...) 
return y, grad, hess @@ -187,6 +195,7 @@ function value_gradient_and_hessian!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) y, _ = value_and_gradient!(f, grad, prep.gradient_prep, inner(backend), x, contexts...) hessian!(f, hess, prep, backend, x, contexts...) return y, grad, hess diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index edc1699f8..b3766dfe4 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -1,7 +1,7 @@ ## Docstrings """ - prepare_hvp(f, backend, x, tx, [contexts...]) -> prep + prepare_hvp(f, backend, x, tx, [contexts...]; strict=false) -> prep $(docstring_prepare("hvp")) """ @@ -15,7 +15,7 @@ $(docstring_prepare("hvp")) function prepare!_hvp end """ - prepare_hvp_same_point(f, backend, x, tx, [contexts...]) -> prep_same + prepare_hvp_same_point(f, backend, x, tx, [contexts...]; strict=false) -> prep_same $(docstring_prepare("hvp"; samepoint=true)) """ @@ -58,7 +58,12 @@ $(docstring_preparation_hint("hvp"; same_point=true)) function gradient_and_hvp! 
end function prepare_hvp( - f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C} + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; + strict::Bool=false, ) where {F,C} return _prepare_hvp_aux( hvp_mode(backend), @@ -67,14 +72,16 @@ function prepare_hvp( backend, x, tx, - contexts..., + contexts...; + strict, ) end ## Forward over anything -struct ForwardOverAnythingHVPPrep{G,GO,GI,PO,PI} <: HVPPrep +struct ForwardOverAnythingHVPPrep{SIG,G,GO,GI,PO,PI} <: HVPPrep{SIG} # pushforward of many pushforwards in theory, but pushforward of gradient in practice + _sig::Type{SIG} grad_buffer::G maybe_inner_gradient_prep::GO maybe_inner_gradient_in_prep::GI @@ -89,8 +96,10 @@ function _prepare_hvp_aux( backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; + strict::Bool, ) where {F,C} + SIG = signature(f, backend, x, tx, contexts...; strict) grad_buffer = similar(x) rewrap = Rewrap(contexts...) # Outer pushforward @@ -98,17 +107,23 @@ function _prepare_hvp_aux( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts... + shuffled_gradient, outer(backend), x, tx, new_contexts...; strict ) outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported prepare_pushforward( - shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... 
+ shuffled_gradient!, + grad_buffer, + outer(backend), + x, + tx, + new_contexts...; + strict, ) else nothing end return ForwardOverAnythingHVPPrep( - grad_buffer, (), (), outer_pushforward_prep, outer_pushforward_in_prep + SIG, grad_buffer, (), (), outer_pushforward_prep, outer_pushforward_in_prep ) end @@ -119,12 +134,14 @@ function _prepare_hvp_aux( backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; + strict::Bool, ) where {F,C} + SIG = signature(f, backend, x, tx, contexts...; strict) grad_buffer = similar(x) rewrap = Rewrap(contexts...) # Inner gradient - inner_gradient_prep = prepare_gradient(f, inner(backend), x, contexts...) + inner_gradient_prep = prepare_gradient(f, inner(backend), x, contexts...; strict) inner_gradient_in_prep = inner_gradient_prep # Outer pushforward new_contexts = ( @@ -142,16 +159,23 @@ function _prepare_hvp_aux( contexts..., ) outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts... + shuffled_gradient, outer(backend), x, tx, new_contexts...; strict ) outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported prepare_pushforward( - shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts_in... + shuffled_gradient!, + grad_buffer, + outer(backend), + x, + tx, + new_contexts_in...; + strict, ) else nothing end return ForwardOverAnythingHVPPrep( + SIG, grad_buffer, (inner_gradient_prep,), (inner_gradient_in_prep,), @@ -167,8 +191,10 @@ function _prepare_hvp_aux( backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; + strict::Bool, ) where {F,C} + SIG = signature(f, backend, x, tx, contexts...; strict) grad_buffer = similar(x) rewrap = Rewrap(contexts...) 
# Inner gradient @@ -176,8 +202,8 @@ function _prepare_hvp_aux( xoi = overloaded_input( pushforward, shuffled_gradient!, grad_buffer, outer(backend), x, tx ) - inner_gradient_prep = prepare_gradient(f, inner(backend), xo, contexts...) - inner_gradient_in_prep = prepare_gradient(f, inner(backend), xoi, contexts...) + inner_gradient_prep = prepare_gradient(f, inner(backend), xo, contexts...; strict) + inner_gradient_in_prep = prepare_gradient(f, inner(backend), xoi, contexts...; strict) # Outer pushforward new_contexts = ( FunctionContext(f), @@ -194,16 +220,23 @@ function _prepare_hvp_aux( contexts..., ) outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts... + shuffled_gradient, outer(backend), x, tx, new_contexts...; strict ) outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported prepare_pushforward( - shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts_in... + shuffled_gradient!, + grad_buffer, + outer(backend), + x, + tx, + new_contexts_in...; + strict, ) else nothing end return ForwardOverAnythingHVPPrep( + SIG, grad_buffer, (inner_gradient_prep,), (inner_gradient_in_prep,), @@ -220,6 +253,7 @@ function hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) (; maybe_inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -243,6 +277,7 @@ function hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return _hvp_aux!( inplace_support(outer(backend)), f, tg, prep, backend, x, tx, contexts... ) @@ -317,6 +352,7 @@ function gradient_and_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) (; maybe_inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) 
new_contexts = ( @@ -341,6 +377,7 @@ function gradient_and_hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return _gradient_and_hvp_aux!( inplace_support(outer(backend)), f, grad, tg, prep, backend, x, tx, contexts... ) @@ -413,8 +450,9 @@ end ## Reverse over forward -struct ReverseOverForwardHVPPrep{G2<:GradientPrep,G1<:GradientPrep} <: HVPPrep +struct ReverseOverForwardHVPPrep{SIG,G2<:GradientPrep,G1<:GradientPrep} <: HVPPrep{SIG} # gradient of pushforward + _sig::Type{SIG} outer_gradient_prep::G2 gradient_prep::G1 end @@ -426,8 +464,10 @@ function _prepare_hvp_aux( backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; + strict::Bool, ) where {F,C} + SIG = signature(f, backend, x, tx, contexts...; strict) rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), @@ -437,10 +477,10 @@ function _prepare_hvp_aux( contexts..., ) outer_gradient_prep = prepare_gradient( - shuffled_single_pushforward, outer(backend), x, new_contexts... + shuffled_single_pushforward, outer(backend), x, new_contexts...; strict ) - gradient_prep = prepare_gradient(f, inner(backend), x, contexts...) - return ReverseOverForwardHVPPrep(outer_gradient_prep, gradient_prep) + gradient_prep = prepare_gradient(f, inner(backend), x, contexts...; strict) + return ReverseOverForwardHVPPrep(SIG, outer_gradient_prep, gradient_prep) end function hvp( @@ -449,8 +489,9 @@ function hvp( backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) (; outer_gradient_prep) = prep rewrap = Rewrap(contexts...) tg = map(tx) do dx @@ -478,6 +519,7 @@ function hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) (; outer_gradient_prep) = prep rewrap = Rewrap(contexts...) 
for b in eachindex(tx, tg) @@ -505,6 +547,7 @@ function gradient_and_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) tg = hvp(f, prep, backend, x, tx, contexts...) grad = gradient(f, prep.gradient_prep, inner(backend), x, contexts...) return grad, tg @@ -520,6 +563,7 @@ function gradient_and_hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) hvp!(f, tg, prep, backend, x, tx, contexts...) gradient!(f, grad, prep.gradient_prep, inner(backend), x, contexts...) return grad, tg @@ -527,8 +571,9 @@ end ## Reverse over reverse -struct ReverseOverReverseHVPPrep{G,PO,PI} <: HVPPrep +struct ReverseOverReverseHVPPrep{SIG,G,PO,PI} <: HVPPrep{SIG} # pullback of gradient + _sig::Type{SIG} grad_buffer::G outer_pullback_prep::PO outer_pullback_in_prep::PI @@ -541,25 +586,33 @@ function _prepare_hvp_aux( backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; + strict::Bool, ) where {F,C} + SIG = signature(f, backend, x, tx, contexts...; strict) rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) grad_buffer = similar(x) outer_pullback_prep = prepare_pullback( - shuffled_gradient, outer(backend), x, tx, new_contexts... + shuffled_gradient, outer(backend), x, tx, new_contexts...; strict ) outer_pullback_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported prepare_pullback( - shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... 
+ shuffled_gradient!, + grad_buffer, + outer(backend), + x, + tx, + new_contexts...; + strict, ) else nothing end return ReverseOverReverseHVPPrep( - grad_buffer, outer_pullback_prep, outer_pullback_in_prep + SIG, grad_buffer, outer_pullback_prep, outer_pullback_in_prep ) end @@ -571,6 +624,7 @@ function hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -590,6 +644,7 @@ function hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return _hvp_aux!( inplace_support(outer(backend)), f, tg, prep, backend, x, tx, contexts... ) @@ -650,6 +705,7 @@ function gradient_and_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -670,6 +726,7 @@ function gradient_and_hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return _gradient_and_hvp_aux!( inplace_support(outer(backend)), f, grad, tg, prep, backend, x, tx, contexts... ) diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index 52f07d74b..6cda162a8 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -1,7 +1,7 @@ ## Docstrings """ - prepare_second_derivative(f, backend, x, [contexts...]) -> prep + prepare_second_derivative(f, backend, x, [contexts...]; strict=false) -> prep $(docstring_prepare("second_derivative")) """ @@ -52,21 +52,24 @@ function value_derivative_and_second_derivative! 
end ## Preparation -struct DerivativeSecondDerivativePrep{E<:DerivativePrep} <: SecondDerivativePrep +struct DerivativeSecondDerivativePrep{SIG,E<:DerivativePrep} <: SecondDerivativePrep{SIG} outer_derivative_prep::E end function prepare_second_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool=false ) where {F,C} + SIG = signature(f, backend, x, contexts...; strict) rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) outer_derivative_prep = prepare_derivative( - shuffled_derivative, outer(backend), x, new_contexts... + shuffled_derivative, outer(backend), x, new_contexts...; strict + ) + return DerivativeSecondDerivativePrep{SIG,typeof(outer_derivative_prep)}( + outer_derivative_prep ) - return DerivativeSecondDerivativePrep(outer_derivative_prep) end ## One argument @@ -78,6 +81,7 @@ function second_derivative( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -95,6 +99,7 @@ function value_derivative_and_second_derivative( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -115,6 +120,7 @@ function second_derivative!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -134,6 +140,7 @@ function value_derivative_and_second_derivative!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) 
new_contexts = ( diff --git a/DifferentiationInterface/src/utils/prep.jl b/DifferentiationInterface/src/utils/prep.jl index 0fb7af0df..b547e28f2 100644 --- a/DifferentiationInterface/src/utils/prep.jl +++ b/DifferentiationInterface/src/utils/prep.jl @@ -1,49 +1,125 @@ -abstract type Prep end +abstract type Prep{SIG} end """ $(docstring_preptype("PushforwardPrep", "pushforward")) """ -abstract type PushforwardPrep <: Prep end -struct NoPushforwardPrep <: PushforwardPrep end +abstract type PushforwardPrep{SIG} <: Prep{SIG} end +struct NoPushforwardPrep{SIG} <: PushforwardPrep{SIG} end """ $(docstring_preptype("PullbackPrep", "pullback")) """ -abstract type PullbackPrep <: Prep end -struct NoPullbackPrep <: PullbackPrep end +abstract type PullbackPrep{SIG} <: Prep{SIG} end +struct NoPullbackPrep{SIG} <: PullbackPrep{SIG} end """ $(docstring_preptype("DerivativePrep", "derivative")) """ -abstract type DerivativePrep <: Prep end -struct NoDerivativePrep <: DerivativePrep end +abstract type DerivativePrep{SIG} <: Prep{SIG} end +struct NoDerivativePrep{SIG} <: DerivativePrep{SIG} end """ $(docstring_preptype("GradientPrep", "gradient")) """ -abstract type GradientPrep <: Prep end -struct NoGradientPrep <: GradientPrep end +abstract type GradientPrep{SIG} <: Prep{SIG} end +struct NoGradientPrep{SIG} <: GradientPrep{SIG} end """ $(docstring_preptype("JacobianPrep", "jacobian")) """ -abstract type JacobianPrep <: Prep end -struct NoJacobianPrep <: JacobianPrep end +abstract type JacobianPrep{SIG} <: Prep{SIG} end +struct NoJacobianPrep{SIG} <: JacobianPrep{SIG} end """ $(docstring_preptype("HVPPrep", "hvp")) """ -abstract type HVPPrep <: Prep end -struct NoHVPPrep <: HVPPrep end +abstract type HVPPrep{SIG} <: Prep{SIG} end +struct NoHVPPrep{SIG} <: HVPPrep{SIG} end """ $(docstring_preptype("HessianPrep", "hessian")) """ -abstract type HessianPrep <: Prep end -struct NoHessianPrep <: HessianPrep end +abstract type HessianPrep{SIG} <: Prep{SIG} end +struct NoHessianPrep{SIG} 
<: HessianPrep{SIG} end """ $(docstring_preptype("SecondDerivativePrep", "second_derivative")) """ -abstract type SecondDerivativePrep <: Prep end -struct NoSecondDerivativePrep <: SecondDerivativePrep end +abstract type SecondDerivativePrep{SIG} <: Prep{SIG} end +struct NoSecondDerivativePrep{SIG} <: SecondDerivativePrep{SIG} end + +## Checks + +is_strict(::Prep{SIG}) where {SIG} = SIG !== Nothing + +function signature( + f, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool +) where {C} + if strict + return typeof((f, backend, x, contexts)) + else + return Nothing + end +end + +function signature( + f!, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool +) where {C} + if strict + return typeof((f!, y, backend, x, contexts)) + else + return Nothing + end +end + +function signature( + f, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C}; strict::Bool +) where {C} + if strict + return typeof((f, backend, x, t, contexts)) + else + return Nothing + end +end + +function signature( + f!, y, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C}; strict::Bool +) where {C} + if strict + return typeof((f!, y, backend, x, t, contexts)) + else + return Nothing + end +end + +function check_prep( + f, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context,C} +) where {SIG,C} + if SIG !== Nothing + @assert SIG == typeof((f, backend, x, contexts)) + end +end + +function check_prep( + f!, y, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context,C} +) where {SIG,C} + if SIG !== Nothing + @assert SIG == typeof((f!, y, backend, x, contexts)) + end +end + +function check_prep( + f, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C} +) where {SIG,C} + if SIG !== Nothing + @assert SIG == typeof((f, backend, x, t, contexts)) + end +end + +function check_prep( + f!, y, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C} +) where {SIG,C} + 
if SIG !== Nothing + @assert SIG == typeof((f!, y, backend, x, t, contexts)) + end +end diff --git a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl index 9500cf7aa..e0ad40df2 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl @@ -62,7 +62,13 @@ for op in ALL_OPS prep = $prep_op(f, ba, xrand, contextsrand...) prepprep = $prep_op!( f, - $prep_op(new_smaller.f, ba, new_smaller.x, new_smaller.contexts...), + $prep_op( + new_smaller.f, + ba, + new_smaller.x, + new_smaller.contexts...; + strict=true, + ), ba, xrand, contextsrand..., @@ -121,7 +127,13 @@ for op in ALL_OPS prep = $prep_op(f, ba, xrand, contextsrand...) prepprep = $prep_op!( f, - $prep_op(new_smaller.f, ba, new_smaller.x, new_smaller.contexts...), + $prep_op( + new_smaller.f, + ba, + new_smaller.x, + new_smaller.contexts...; + strict=true, + ), ba, xrand, contextsrand..., @@ -200,7 +212,8 @@ for op in ALL_OPS copy(new_smaller.y), ba, new_smaller.x, - new_smaller.contexts..., + new_smaller.contexts...; + strict=true, ), ba, xrand, @@ -272,7 +285,8 @@ for op in ALL_OPS copy(new_smaller.y), ba, new_smaller.x, - new_smaller.contexts..., + new_smaller.contexts...; + strict=true, ), ba, xrand, @@ -347,7 +361,13 @@ for op in ALL_OPS prep = $prep_op(f, ba, xrand, contextsrand...) prepprep = $prep_op!( f, - $prep_op(new_smaller.f, ba, new_smaller.x, new_smaller.contexts...), + $prep_op( + new_smaller.f, + ba, + new_smaller.x, + new_smaller.contexts...; + strict=true, + ), ba, xrand, contextsrand..., @@ -408,7 +428,13 @@ for op in ALL_OPS prep = $prep_op(f, ba, xrand, contextsrand...) 
prepprep = $prep_op!( f, - $prep_op(new_smaller.f, ba, new_smaller.x, new_smaller.contexts...), + $prep_op( + new_smaller.f, + ba, + new_smaller.x, + new_smaller.contexts...; + strict=true, + ), ba, xrand, contextsrand..., @@ -489,7 +515,8 @@ for op in ALL_OPS ba, new_smaller.x, new_smaller.tang, - new_smaller.contexts..., + new_smaller.contexts...; + strict=true, ), ba, xrand, @@ -552,7 +579,8 @@ for op in ALL_OPS ba, new_smaller.x, new_smaller.tang, - new_smaller.contexts..., + new_smaller.contexts...; + strict=true, ), ba, xrand, @@ -629,7 +657,8 @@ for op in ALL_OPS ba, new_smaller.x, new_smaller.tang, - new_smaller.contexts..., + new_smaller.contexts...; + strict=true, ), ba, xrand, @@ -704,7 +733,8 @@ for op in ALL_OPS ba, new_smaller.x, new_smaller.tang, - new_smaller.contexts..., + new_smaller.contexts...; + strict=true, ), ba, xrand, @@ -796,7 +826,8 @@ for op in ALL_OPS ba, new_smaller.x, new_smaller.tang, - new_smaller.contexts..., + new_smaller.contexts...; + strict=true, ), ba, xrand, @@ -859,7 +890,8 @@ for op in ALL_OPS ba, new_smaller.x, new_smaller.tang, - new_smaller.contexts..., + new_smaller.contexts...; + strict=true, ), ba, xrand, From de6d61492c4a814e1aca4bca23f727fbb8585358 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 16 Mar 2025 19:43:45 +0100 Subject: [PATCH 02/22] Fixes --- .../onearg.jl | 40 +++++++++---------- .../twoarg.jl | 20 +++++----- .../src/first_order/gradient.jl | 2 +- DifferentiationInterface/src/utils/prep.jl | 29 ++++++++++++-- 4 files changed, 56 insertions(+), 35 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl index 1024f2119..ba49b9f39 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl @@ -40,12 +40,12 @@ end 
function DI.pushforward( f, - prep::FiniteDiffOneArgPushforwardPrep{Nothing}, + prep::FiniteDiffOneArgPushforwardPrep{SIG,Nothing}, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep step(t::Number, dx) = f(x .+ t .* dx, map(DI.unwrap, contexts)...) @@ -59,12 +59,12 @@ end function DI.value_and_pushforward( f, - prep::FiniteDiffOneArgPushforwardPrep{Nothing}, + prep::FiniteDiffOneArgPushforwardPrep{SIG,Nothing}, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep step(t::Number, dx) = f(x .+ t .* dx, map(DI.unwrap, contexts)...) @@ -86,12 +86,12 @@ end function DI.pushforward( f, - prep::FiniteDiffOneArgPushforwardPrep{<:JVPCache}, + prep::FiniteDiffOneArgPushforwardPrep{SIG,<:JVPCache}, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) @@ -103,12 +103,12 @@ end function DI.value_and_pushforward( f, - prep::FiniteDiffOneArgPushforwardPrep{<:JVPCache}, + prep::FiniteDiffOneArgPushforwardPrep{SIG,<:JVPCache}, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) @@ -159,11 +159,11 @@ end function DI.derivative( f, - prep::FiniteDiffOneArgDerivativePrep{Nothing}, + prep::FiniteDiffOneArgDerivativePrep{SIG,Nothing}, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) 
@@ -172,11 +172,11 @@ end function DI.value_and_derivative( f, - prep::FiniteDiffOneArgDerivativePrep{Nothing}, + prep::FiniteDiffOneArgDerivativePrep{SIG,Nothing}, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) @@ -193,11 +193,11 @@ end function DI.derivative( f, - prep::FiniteDiffOneArgDerivativePrep{<:GradientCache}, + prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache}, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) @@ -207,11 +207,11 @@ end function DI.derivative!( f, der, - prep::FiniteDiffOneArgDerivativePrep{<:GradientCache}, + prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache}, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) @@ -220,11 +220,11 @@ end function DI.value_and_derivative( f, - prep::FiniteDiffOneArgDerivativePrep{<:GradientCache}, + prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache}, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) (; relstep, absstep, dir) = prep @@ -235,11 +235,11 @@ end function DI.value_and_derivative!( f, der, - prep::FiniteDiffOneArgDerivativePrep{<:GradientCache}, + prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache}, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) 
diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl index b36adf5d7..12c6dffbd 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl @@ -40,12 +40,12 @@ end function DI.value_and_pushforward( f!, y, - prep::FiniteDiffTwoArgPushforwardPrep{Nothing}, + prep::FiniteDiffTwoArgPushforwardPrep{SIG,Nothing}, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep function step(t::Number, dx) @@ -72,12 +72,12 @@ end function DI.pushforward( f!, y, - prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache}, + prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache}, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) @@ -92,12 +92,12 @@ end function DI.value_and_pushforward( f!, y, - prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache}, + prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache}, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) @@ -114,12 +114,12 @@ function DI.pushforward!( f!, y, ty::NTuple, - prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache}, + prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache}, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) 
@@ -134,12 +134,12 @@ function DI.value_and_pushforward!( f!, y, ty::NTuple, - prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache}, + prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache}, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 6f232fbda..223c3d2e5 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -61,7 +61,7 @@ function prepare_gradient( ) where {F,C} SIG = signature(f, backend, x, contexts...; strict) y = f(x, map(unwrap, contexts)...) # TODO: replace with output type inference? - pullback_prep = prepare_pullback(f, backend, x, (true,), contexts...; strict) + pullback_prep = prepare_pullback(f, backend, x, (one(typeof(y)),), contexts...; strict) return PullbackGradientPrep{SIG,typeof(y),typeof(pullback_prep)}(pullback_prep) end diff --git a/DifferentiationInterface/src/utils/prep.jl b/DifferentiationInterface/src/utils/prep.jl index b547e28f2..bfe6fb66a 100644 --- a/DifferentiationInterface/src/utils/prep.jl +++ b/DifferentiationInterface/src/utils/prep.jl @@ -52,6 +52,15 @@ struct NoSecondDerivativePrep{SIG} <: SecondDerivativePrep{SIG} end is_strict(::Prep{SIG}) where {SIG} = SIG !== Nothing +function inconsistent_signatures_error(SIG, RUNTIME_SIG) + msg = """ + Inconsistent signatures: + - at preparation time: $SIG + - at execution time: $RUNTIME_SIG + """ + return ArgumentError(msg) +end + function signature( f, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool ) where {C} @@ -96,7 +105,10 @@ function check_prep( f, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {SIG,C} if SIG !== Nothing - @assert SIG == typeof((f, 
backend, x, contexts)) + RUNTIME_SIG = typeof((f, backend, x, contexts)) + if SIG != RUNTIME_SIG + throw(inconsistent_signatures_error(SIG, RUNTIME_SIG)) + end end end @@ -104,7 +116,10 @@ function check_prep( f!, y, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {SIG,C} if SIG !== Nothing - @assert SIG == typeof((f!, y, backend, x, contexts)) + RUNTIME_SIG = typeof((f!, y, backend, x, contexts)) + if SIG != RUNTIME_SIG + throw(inconsistent_signatures_error(SIG, RUNTIME_SIG)) + end end end @@ -112,7 +127,10 @@ function check_prep( f, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C} ) where {SIG,C} if SIG !== Nothing - @assert SIG == typeof((f, backend, x, t, contexts)) + RUNTIME_SIG = typeof((f, backend, x, t, contexts)) + if SIG != RUNTIME_SIG + throw(inconsistent_signatures_error(SIG, RUNTIME_SIG)) + end end end @@ -120,6 +138,9 @@ function check_prep( f!, y, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C} ) where {SIG,C} if SIG !== Nothing - @assert SIG == typeof((f!, y, backend, x, t, contexts)) + RUNTIME_SIG = typeof((f!, y, backend, x, t, contexts)) + if SIG != RUNTIME_SIG + throw(inconsistent_signatures_error(SIG, RUNTIME_SIG)) + end end end From 960ea1f7a8ee8c63b909edec673269bc83847ee4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 16 Mar 2025 19:58:06 +0100 Subject: [PATCH 03/22] Fixes --- .../onearg.jl | 24 ++++++------ .../twoarg.jl | 32 +++++++-------- .../DifferentiationInterfaceZygoteExt.jl | 39 +++++++++++++------ 3 files changed, 56 insertions(+), 39 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 35133ae8e..d46914708 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ 
b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -88,11 +88,11 @@ end function compute_ydual_onearg( f::F, - prep::ForwardDiffOneArgPushforwardPrep{T}, + prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, x::Number, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} +) where {F,SIG,T,B,C} xdual = make_dual(T, x, tx) contexts_dual = translate_prepared(contexts, prep.contexts_dual) ydual = f(xdual, contexts_dual...) @@ -101,11 +101,11 @@ end function compute_ydual_onearg( f::F, - prep::ForwardDiffOneArgPushforwardPrep{T}, + prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} +) where {F,SIG,T,B,C} if DI.ismutable_array(x) make_dual!(T, prep.xdual_tmp, x, tx) xdual_tmp = prep.xdual_tmp @@ -119,12 +119,12 @@ end function DI.value_and_pushforward( f::F, - prep::ForwardDiffOneArgPushforwardPrep{T}, + prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} +) where {F,SIG,T,B,C} DI.check_prep(f, prep, backend, x, tx, contexts...) ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) y = myvalue(T, ydual) @@ -135,12 +135,12 @@ end function DI.value_and_pushforward!( f::F, ty::NTuple, - prep::ForwardDiffOneArgPushforwardPrep{T}, + prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {F,T,C} +) where {F,SIG,T,C} DI.check_prep(f, prep, backend, x, tx, contexts...) ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) y = myvalue(T, ydual) @@ -150,12 +150,12 @@ end function DI.pushforward( f::F, - prep::ForwardDiffOneArgPushforwardPrep{T}, + prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} +) where {F,SIG,T,B,C} DI.check_prep(f, prep, backend, x, tx, contexts...) 
ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) ty = mypartials(T, Val(B), ydual) @@ -165,12 +165,12 @@ end function DI.pushforward!( f::F, ty::NTuple, - prep::ForwardDiffOneArgPushforwardPrep{T}, + prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {F,T,C} +) where {F,SIG,T,C} DI.check_prep(f, prep, backend, x, tx, contexts...) ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) mypartials!(T, ty, ydual) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index 47fb69b6d..ab122f175 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -30,11 +30,11 @@ end function compute_ydual_twoarg( f!::F, y, - prep::ForwardDiffTwoArgPushforwardPrep{T}, + prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, x::Number, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} +) where {F,SIG,T,B,C} (; ydual_tmp) = prep xdual_tmp = make_dual(T, x, tx) contexts_dual = translate_prepared(contexts, prep.contexts_dual) @@ -45,11 +45,11 @@ end function compute_ydual_twoarg( f!::F, y, - prep::ForwardDiffTwoArgPushforwardPrep{T}, + prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} +) where {F,SIG,T,B,C} (; xdual_tmp, ydual_tmp) = prep make_dual!(T, xdual_tmp, x, tx) contexts_dual = translate_prepared(contexts, prep.contexts_dual) @@ -60,12 +60,12 @@ end function DI.value_and_pushforward( f!::F, y, - prep::ForwardDiffTwoArgPushforwardPrep{T}, + prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} +) where {F,SIG,T,B,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) 
ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) myvalue!(T, y, ydual_tmp) @@ -77,12 +77,12 @@ function DI.value_and_pushforward!( f!::F, y, ty::NTuple, - prep::ForwardDiffTwoArgPushforwardPrep{T}, + prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {F,T,C} +) where {F,SIG,T,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) myvalue!(T, y, ydual_tmp) @@ -93,12 +93,12 @@ end function DI.pushforward( f!::F, y, - prep::ForwardDiffTwoArgPushforwardPrep{T}, + prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} +) where {F,SIG,T,B,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) ty = mypartials(T, Val(B), ydual_tmp) @@ -109,12 +109,12 @@ function DI.pushforward!( f!::F, y, ty::NTuple, - prep::ForwardDiffTwoArgPushforwardPrep{T}, + prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {F,T,C} +) where {F,SIG,T,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) mypartials!(T, ty, ydual_tmp) @@ -425,7 +425,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - DI.check_prep(f!, y, old_prep, backend, x, contexts...) + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) jac = similar(y, length(y), length(x)) @@ -447,7 +447,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - DI.check_prep(f!, y, old_prep, backend, x, contexts...) + DI.check_prep(f!, y, prep, backend, x, contexts...) 
contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) result = MutableDiffResult(y, (jac,)) @@ -467,7 +467,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - DI.check_prep(f!, y, old_prep, backend, x, contexts...) + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -486,7 +486,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - DI.check_prep(f!, y, old_prep, backend, x, contexts...) + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) CHK = tag_type(backend) === Nothing diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 0ef65e91b..4e87d29eb 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -46,6 +46,7 @@ function DI.prepare_pullback_same_point( x, ty::NTuple, contexts::Vararg{DI.Context,C}; + strict::Bool=false, ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) 
SIG = DI.signature(f, backend, x, ty, contexts...; strict) @@ -187,6 +188,11 @@ end # Beware, this uses ForwardDiff for the inner differentiation +struct ZygoteHVPPrep{SIG,P} <: DI.HVPPrep{SIG} + _sig::Type{SIG} + fd_prep::P +end + function DI.prepare_hvp( f, backend::AutoZygote, @@ -195,27 +201,31 @@ function DI.prepare_hvp( contexts::Vararg{DI.Context,C}; strict::Bool=false, ) where {C} - return DI.prepare_hvp( + SIG = DI.signature(f, backend, x, tx, contexts...; strict) + fd_prep = DI.prepare_hvp( f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...; strict ) + return ZygoteHVPPrep(SIG, fd_prep) end function DI.hvp( f, - prep::DI.ForwardOverAnythingHVPPrep, + prep::ZygoteHVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) - return DI.hvp(f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) + return DI.hvp( + f, prep.fd_prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... + ) end function DI.hvp!( f, tg::NTuple, - prep::DI.ForwardOverAnythingHVPPrep, + prep::ZygoteHVPPrep, backend::AutoZygote, x, tx::NTuple, @@ -223,13 +233,13 @@ function DI.hvp!( ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.hvp!( - f, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... + f, tg, prep.fd_prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... ) end function DI.gradient_and_hvp( f, - prep::DI.ForwardOverAnythingHVPPrep, + prep::ZygoteHVPPrep, backend::AutoZygote, x, tx::NTuple, @@ -237,7 +247,7 @@ function DI.gradient_and_hvp( ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.gradient_and_hvp( - f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... + f, prep.fd_prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... 
) end @@ -245,7 +255,7 @@ function DI.gradient_and_hvp!( f, grad, tg::NTuple, - prep::DI.ForwardOverAnythingHVPPrep, + prep::ZygoteHVPPrep, backend::AutoZygote, x, tx::NTuple, @@ -253,7 +263,14 @@ function DI.gradient_and_hvp!( ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.gradient_and_hvp!( - f, grad, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... + f, + grad, + tg, + prep.fd_prep, + DI.SecondOrder(AutoForwardDiff(), backend), + x, + tx, + contexts..., ) end @@ -303,7 +320,7 @@ function DI.value_gradient_and_hessian( contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - y, grad = DI.value_and_gradient(f, DI.NoGradientPrep(), backend, x, contexts...) + y, grad = DI.value_and_gradient(f, backend, x, contexts...) hess = DI.hessian(f, prep, backend, x, contexts...) return y, grad, hess end @@ -318,7 +335,7 @@ function DI.value_gradient_and_hessian!( contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - y, _ = DI.value_and_gradient!(f, grad, DI.NoGradientPrep(), backend, x, contexts...) + y, _ = DI.value_and_gradient!(f, grad, backend, x, contexts...) DI.hessian!(f, hess, prep, backend, x, contexts...) 
return y, grad, hess end From 14cb10205d48c0980addbce875b1a20a73265bc0 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 16 Mar 2025 20:01:10 +0100 Subject: [PATCH 04/22] Fix --- .../jacobian.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 77b7a12b5..4edb8f47e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -46,7 +46,7 @@ function DI.prepare_jacobian( end function DI.prepare_jacobian( - f!::F, y, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} + f!::F, y, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; strict::Bool=false ) where {F,C} dense_backend = dense_ad(backend) perf = DI.pushforward_performance(dense_backend) From 0a4b355f3688fdc5a70dd1a4bfa8441436a96805 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 16 Mar 2025 20:10:43 +0100 Subject: [PATCH 05/22] Fixes --- .../hessian.jl | 8 ++++---- .../jacobian.jl | 8 ++++---- .../jacobian_mixed.jl | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index 39e81e30d..b26fd2076 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -79,11 +79,11 @@ end function DI.hessian!( f::F, hess, - prep::SparseHessianPrep{<:DI.BatchSizeSettings{B}}, + prep::SparseHessianPrep{SIG,<:DI.BatchSizeSettings{B}}, 
backend::AutoSparse, x, contexts::Vararg{DI.Context,C}, -) where {F,B,C} +) where {F,SIG,B,C} DI.check_prep(f, prep, backend, x, contexts...) (; batch_size_settings, @@ -124,8 +124,8 @@ function DI.hessian!( end function DI.hessian( - f::F, prep::SparseHessianPrep{B}, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} -) where {F,B,C} + f::F, prep::SparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} +) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) hess = similar(sparsity_pattern(prep), eltype(x)) return DI.hessian!(f, hess, prep, backend, x, contexts...) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 4edb8f47e..5411693c5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -260,11 +260,11 @@ end function _sparse_jacobian_aux!( f_or_f!y::FY, jac, - prep::PushforwardSparseJacobianPrep{<:DI.BatchSizeSettings{B}}, + prep::PushforwardSparseJacobianPrep{SIG,<:DI.BatchSizeSettings{B}}, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}, -) where {FY,B,C} +) where {FY,SIG,B,C} (; batch_size_settings, coloring_result, @@ -306,11 +306,11 @@ end function _sparse_jacobian_aux!( f_or_f!y::FY, jac, - prep::PullbackSparseJacobianPrep{<:DI.BatchSizeSettings{B}}, + prep::PullbackSparseJacobianPrep{SIG,<:DI.BatchSizeSettings{B}}, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}, -) where {FY,B,C} +) where {FY,SIG,B,C} (; batch_size_settings, coloring_result, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl index cc9267092..e3e375494 100644 --- 
a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl @@ -164,12 +164,12 @@ function _sparse_jacobian_aux!( f_or_f!y::FY, jac, prep::MixedModeSparseJacobianPrep{ - <:DI.BatchSizeSettings{Bf},<:DI.BatchSizeSettings{Br} + SIG,<:DI.BatchSizeSettings{Bf},<:DI.BatchSizeSettings{Br} }, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}, -) where {FY,Bf,Br,C} +) where {FY,SIG,Bf,Br,C} (; batch_size_settings_forward, batch_size_settings_reverse, From 1251518b7f0929b9dd9da96bfcbdc4a4602f7223 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 16 Mar 2025 22:48:09 +0100 Subject: [PATCH 06/22] Propagate with value types --- .../differentiate_with.jl | 2 +- .../reverse_onearg.jl | 14 +-- .../DifferentiationInterfaceDiffractorExt.jl | 4 +- .../forward_onearg.jl | 32 +++---- .../forward_twoarg.jl | 6 +- .../reverse_onearg.jl | 13 +-- .../reverse_twoarg.jl | 7 +- .../onearg.jl | 64 ++++++------- .../twoarg.jl | 32 +++---- .../onearg.jl | 40 ++++----- .../twoarg.jl | 34 ++++--- ...rentiationInterfaceFiniteDifferencesExt.jl | 32 ++++--- .../onearg.jl | 63 ++++++++----- .../twoarg.jl | 28 +++--- .../onearg.jl | 48 +++++----- .../twoarg.jl | 16 ++-- .../onearg.jl | 16 ++-- .../twoarg.jl | 8 +- .../onearg.jl | 48 +++++----- .../twoarg.jl | 24 ++--- .../onearg.jl | 54 +++++------ .../twoarg.jl | 27 +++--- .../sparsity_detector.jl | 12 +-- .../hessian.jl | 8 +- .../jacobian.jl | 18 ++-- .../jacobian_mixed.jl | 12 +-- .../DifferentiationInterfaceTrackerExt.jl | 26 ++++-- .../DifferentiationInterfaceZygoteExt.jl | 40 ++++----- .../src/fallbacks/change_prep.jl | 4 +- .../src/fallbacks/no_prep.jl | 42 ++++----- .../src/first_order/derivative.jl | 22 +++-- .../src/first_order/gradient.jl | 10 ++- .../src/first_order/jacobian.jl | 29 +++--- .../src/first_order/pullback.jl | 25 +++--- 
.../src/first_order/pushforward.jl | 25 +++--- .../src/misc/from_primitive.jl | 26 +++--- .../src/misc/simple_finite_diff.jl | 12 +-- .../src/misc/zero_backends.jl | 24 ++--- .../src/second_order/hessian.jl | 12 +-- .../src/second_order/hvp.jl | 42 ++++----- .../src/second_order/second_derivative.jl | 11 ++- DifferentiationInterface/src/utils/prep.jl | 89 +++++++++++++------ .../src/tests/correctness_eval.jl | 24 ++--- 43 files changed, 610 insertions(+), 515 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl index 528ba061a..f226298d7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl @@ -1,7 +1,7 @@ function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x) (; f, backend) = dw y = f(x) - prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,); strict=true) + prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,); strict=Val(true)) function pullbackfunc(dy) tx = DI.pullback(f, prep_same, backend, x, (dy,)) return (NoTangent(), only(tx)) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index ee691774f..6f6caba52 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -1,7 +1,7 @@ ## Pullback struct ChainRulesPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} y::Y pb::PB end @@ -12,10 +12,10 @@ function DI.prepare_pullback( x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}; - strict::Bool=false, + 
strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, ty, contexts...; strict) - return DI.NoPullbackPrep{SIG}() + _sig = DI.signature(f, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep(_sig) end function DI.prepare_pullback_same_point( @@ -25,13 +25,13 @@ function DI.prepare_pullback_same_point( x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) - SIG = DI.signature(f, backend, x, ty, contexts...; strict) + _sig = DI.signature(f, backend, x, ty, contexts...; strict) rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) - return ChainRulesPullbackPrepSamePoint(SIG, y, pb) + return ChainRulesPullbackPrepSamePoint(_sig, y, pb) end function DI.value_and_pullback( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl index bede2d064..92176fc11 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl @@ -11,8 +11,8 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow() ## Pushforward function DI.prepare_pushforward(f, backend::AutoDiffractor, x, tx::NTuple) - SIG = DI.signature(f, backend, x, tx) - return DI.NoPushforwardPrep{SIG}() + _sig = DI.signature(f, backend, x, tx) + return DI.NoPushforwardPrep(_sig) end function DI.pushforward( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 25e6c4801..13322d64e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl 
+++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -6,10 +6,10 @@ function DI.prepare_pushforward( x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = DI.signature(f, backend, x, tx, contexts...; strict) - return DI.NoPushforwardPrep{SIG}() + _sig = DI.signature(f, backend, x, tx, contexts...; strict) + return DI.NoPushforwardPrep(_sig) end function DI.value_and_pushforward( @@ -117,24 +117,22 @@ end ## Gradient struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG} + _sig::Val{SIG} + _valB::Val{B} shadows::O end -function EnzymeForwardGradientPrep(::Type{SIG}, ::Val{B}, shadows::O) where {SIG,B,O} - return EnzymeForwardGradientPrep{SIG,B,O}(shadows) -end - function DI.prepare_gradient( f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) valB = to_val(DI.pick_batchsize(backend, x)) shadows = create_shadows(valB, x) - return EnzymeForwardGradientPrep(SIG, valB, shadows) + return EnzymeForwardGradientPrep(_sig, valB, shadows) end function DI.gradient( @@ -199,28 +197,24 @@ end ## Jacobian struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} + _valB::Val{B} shadows::O output_length::Int end -function EnzymeForwardOneArgJacobianPrep( - ::Type{SIG}, ::Val{B}, shadows::O, output_length::Integer -) where {SIG,B,O} - return EnzymeForwardOneArgJacobianPrep{SIG,B,O}(shadows, output_length) -end - function DI.prepare_jacobian( f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; 
strict) y = f(x, map(DI.unwrap, contexts)...) valB = to_val(DI.pick_batchsize(backend, x)) shadows = create_shadows(valB, x) - return EnzymeForwardOneArgJacobianPrep(SIG, valB, shadows, length(y)) + return EnzymeForwardOneArgJacobianPrep(_sig, valB, shadows, length(y)) end function DI.jacobian( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index 698e73c67..f3d9e777c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -7,10 +7,10 @@ function DI.prepare_pushforward( x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = DI.signature(f!, y, backend, x, tx, contexts...; strict) - return DI.NoPushforwardPrep{SIG}() + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) + return DI.NoPushforwardPrep(_sig) end function DI.value_and_pushforward( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 2f0a0a8c6..f7f3a6a9b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -48,6 +48,7 @@ end ## Pullback struct EnzymeReverseOneArgPullbackPrep{SIG,Y} <: DI.PullbackPrep{SIG} + _sig::Val{SIG} y_example::Y # useful to create return activity end @@ -57,11 +58,11 @@ function DI.prepare_pullback( x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = DI.signature(f, backend, x, ty, contexts...; strict) + _sig = DI.signature(f, backend, x, ty, contexts...; strict) y = f(x, map(DI.unwrap, contexts)...) 
- return EnzymeReverseOneArgPullbackPrep{SIG,typeof(y)}(y) + return EnzymeReverseOneArgPullbackPrep(_sig, y) end ### Out-of-place @@ -195,10 +196,10 @@ function DI.prepare_gradient( backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = DI.signature(f, backend, x, contexts...; strict) - return DI.NoGradientPrep{SIG}() + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoGradientPrep(_sig) end ### Enzyme gradient API (only constants) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index aca9494ca..7b36e748d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -1,6 +1,7 @@ ## Pullback struct EnzymeReverseTwoArgPullbackPrep{SIG,TY} <: DI.PullbackPrep{SIG} + _sig::Val{SIG} ty_copy::TY end @@ -11,11 +12,11 @@ function DI.prepare_pullback( x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = DI.signature(f!, y, backend, x, ty, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) ty_copy = map(copy, ty) - return EnzymeReverseTwoArgPullbackPrep{SIG,typeof(ty_copy)}(ty_copy) + return EnzymeReverseTwoArgPullbackPrep(_sig, ty_copy) end function DI.value_and_pullback( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index f873d6c35..6a9c8f30d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -1,7 +1,7 @@ ## Pushforward struct 
FastDifferentiationOneArgPushforwardPrep{SIG,Y,E1,E1!} <: DI.PushforwardPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} y_prototype::Y jvp_exe::E1 jvp_exe!::E1! @@ -13,9 +13,9 @@ function DI.prepare_pushforward( x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, tx, contexts...; strict) + _sig = DI.signature(f, backend, x, tx, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -30,7 +30,7 @@ function DI.prepare_pushforward( jvp_exe! = make_function( jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true ) - return FastDifferentiationOneArgPushforwardPrep(SIG, y_prototype, jvp_exe, jvp_exe!) + return FastDifferentiationOneArgPushforwardPrep(_sig, y_prototype, jvp_exe, jvp_exe!) end function DI.pushforward( @@ -100,7 +100,7 @@ end ## Pullback struct FastDifferentiationOneArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} vjp_exe::E1 vjp_exe!::E1! end @@ -111,9 +111,9 @@ function DI.prepare_pullback( x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, ty, contexts...; strict) + _sig = DI.signature(f, backend, x, ty, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = f(x_var, context_vars...) @@ -128,7 +128,7 @@ function DI.prepare_pullback( vjp_exe! = make_function( vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true ) - return FastDifferentiationOneArgPullbackPrep(SIG, vjp_exe, vjp_exe!) + return FastDifferentiationOneArgPullbackPrep(_sig, vjp_exe, vjp_exe!) end function DI.pullback( @@ -198,7 +198,7 @@ end ## Derivative struct FastDifferentiationOneArgDerivativePrep{SIG,Y,E1,E1!} <: DI.DerivativePrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} y_prototype::Y der_exe::E1 der_exe!::E1! 
@@ -209,9 +209,9 @@ function DI.prepare_derivative( backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -223,7 +223,7 @@ function DI.prepare_derivative( der_vec_var = derivative(y_vec_var, x_var) der_exe = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=false) der_exe! = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=true) - return FastDifferentiationOneArgDerivativePrep(SIG, y_prototype, der_exe, der_exe!) + return FastDifferentiationOneArgDerivativePrep(_sig, y_prototype, der_exe, der_exe!) end function DI.derivative( @@ -283,7 +283,7 @@ end ## Gradient struct FastDifferentiationOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} jac_exe::E1 jac_exe!::E1! end @@ -293,9 +293,9 @@ function DI.prepare_gradient( backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = f(x_var, context_vars...) @@ -306,7 +306,7 @@ function DI.prepare_gradient( jac_var = jacobian(y_vec_var, x_vec_var) jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false) jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true) - return FastDifferentiationOneArgGradientPrep(SIG, jac_exe, jac_exe!) + return FastDifferentiationOneArgGradientPrep(_sig, jac_exe, jac_exe!) 
end function DI.gradient( @@ -362,7 +362,7 @@ end ## Jacobian struct FastDifferentiationOneArgJacobianPrep{SIG,Y,E1,E1!} <: DI.JacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} y_prototype::Y jac_exe::E1 jac_exe!::E1! @@ -373,9 +373,9 @@ function DI.prepare_jacobian( backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -391,7 +391,7 @@ function DI.prepare_jacobian( end jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false) jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true) - return FastDifferentiationOneArgJacobianPrep(SIG, y_prototype, jac_exe, jac_exe!) + return FastDifferentiationOneArgJacobianPrep(_sig, y_prototype, jac_exe, jac_exe!) end function DI.jacobian( @@ -446,7 +446,7 @@ end struct FastDifferentiationAllocatingSecondDerivativePrep{SIG,Y,D,E2,E2!} <: DI.SecondDerivativePrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} y_prototype::Y derivative_prep::D der2_exe::E2 @@ -458,9 +458,9 @@ function DI.prepare_second_derivative( backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -476,7 +476,7 @@ function DI.prepare_second_derivative( derivative_prep = DI.prepare_derivative(f, backend, x, contexts...) return FastDifferentiationAllocatingSecondDerivativePrep( - SIG, y_prototype, derivative_prep, der2_exe, der2_exe! 
+ _sig, y_prototype, derivative_prep, der2_exe, der2_exe! ) end @@ -540,7 +540,7 @@ end ## HVP struct FastDifferentiationHVPPrep{SIG,E2,E2!,E1} <: DI.HVPPrep{SIG} - sig::Type{SIG} + sig::Val{SIG} hvp_exe::E2 hvp_exe!::E2! gradient_prep::E1 @@ -552,9 +552,9 @@ function DI.prepare_hvp( x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, tx, contexts...; strict) + _sig = DI.signature(f, backend, x, tx, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = f(x_var, context_vars...) @@ -570,7 +570,7 @@ function DI.prepare_hvp( ) gradient_prep = DI.prepare_gradient(f, backend, x, contexts...) - return FastDifferentiationHVPPrep(SIG, hvp_exe, hvp_exe!, gradient_prep) + return FastDifferentiationHVPPrep(_sig, hvp_exe, hvp_exe!, gradient_prep) end function DI.hvp( @@ -639,7 +639,7 @@ end ## Hessian struct FastDifferentiationHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} gradient_prep::G hess_exe::E2 hess_exe!::E2! @@ -650,9 +650,9 @@ function DI.prepare_hessian( backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = f(x_var, context_vars...) @@ -669,7 +669,7 @@ function DI.prepare_hessian( hess_exe! = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=true) gradient_prep = DI.prepare_gradient(f, dense_ad(backend), x, contexts...) - return FastDifferentiationHessianPrep(SIG, gradient_prep, hess_exe, hess_exe!) + return FastDifferentiationHessianPrep(_sig, gradient_prep, hess_exe, hess_exe!) 
end function DI.hessian( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl index 4e5e59c73..7904f8789 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl @@ -1,7 +1,7 @@ ## Pushforward struct FastDifferentiationTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} jvp_exe::E1 jvp_exe!::E1! end @@ -13,9 +13,9 @@ function DI.prepare_pushforward( x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f!, y, backend, x, tx, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = variablize(y, :y) @@ -31,7 +31,7 @@ function DI.prepare_pushforward( jvp_exe! = make_function( jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true ) - return FastDifferentiationTwoArgPushforwardPrep(SIG, jvp_exe, jvp_exe!) + return FastDifferentiationTwoArgPushforwardPrep(_sig, jvp_exe, jvp_exe!) end function DI.pushforward( @@ -102,7 +102,7 @@ end ## Pullback struct FastDifferentiationTwoArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} vjp_exe::E1 vjp_exe!::E1! end @@ -114,9 +114,9 @@ function DI.prepare_pullback( x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f!, y, backend, x, ty, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = variablize(y, :y) @@ -132,7 +132,7 @@ function DI.prepare_pullback( vjp_exe! 
= make_function( vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true ) - return FastDifferentiationTwoArgPullbackPrep(SIG, vjp_exe, vjp_exe!) + return FastDifferentiationTwoArgPullbackPrep(_sig, vjp_exe, vjp_exe!) end function DI.pullback( @@ -208,7 +208,7 @@ end ## Derivative struct FastDifferentiationTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} der_exe::E1 der_exe!::E1! end @@ -219,9 +219,9 @@ function DI.prepare_derivative( backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f!, y, backend, x, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = variablize(y, :y) @@ -233,7 +233,7 @@ function DI.prepare_derivative( der_vec_var = derivative(y_vec_var, x_var) der_exe = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=false) der_exe! = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=true) - return FastDifferentiationTwoArgDerivativePrep(SIG, der_exe, der_exe!) + return FastDifferentiationTwoArgDerivativePrep(_sig, der_exe, der_exe!) end function DI.value_and_derivative( @@ -295,7 +295,7 @@ end ## Jacobian struct FastDifferentiationTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} jac_exe::E1 jac_exe!::E1! 
end @@ -306,9 +306,9 @@ function DI.prepare_jacobian( backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f!, y, backend, x, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = variablize(y, :y) @@ -324,7 +324,7 @@ function DI.prepare_jacobian( end jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false) jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true) - return FastDifferentiationTwoArgJacobianPrep(SIG, jac_exe, jac_exe!) + return FastDifferentiationTwoArgJacobianPrep(_sig, jac_exe, jac_exe!) end function DI.value_and_jacobian( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl index ba49b9f39..79b3d528a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl @@ -1,7 +1,7 @@ ## Pushforward struct FiniteDiffOneArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} cache::C relstep::R absstep::A @@ -14,9 +14,9 @@ function DI.prepare_pushforward( x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, tx, contexts...; strict) + _sig = DI.signature(f, backend, x, tx, contexts...; strict) fc = DI.with_contexts(f, contexts...) 
y = fc(x) cache = if x isa Number || y isa Number @@ -35,7 +35,7 @@ function DI.prepare_pushforward( backend.relstep end dir = backend.dir - return FiniteDiffOneArgPushforwardPrep(SIG, cache, relstep, absstep, dir) + return FiniteDiffOneArgPushforwardPrep(_sig, cache, relstep, absstep, dir) end function DI.pushforward( @@ -122,7 +122,7 @@ end ## Derivative struct FiniteDiffOneArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} cache::C relstep::R absstep::A @@ -130,9 +130,9 @@ struct FiniteDiffOneArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} end function DI.prepare_derivative( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) y = fc(x) cache = if y isa Number @@ -152,7 +152,7 @@ function DI.prepare_derivative( backend.relstep end dir = backend.dir - return FiniteDiffOneArgDerivativePrep(SIG, cache, relstep, absstep, dir) + return FiniteDiffOneArgDerivativePrep(_sig, cache, relstep, absstep, dir) end ### Scalar to scalar @@ -251,7 +251,7 @@ end ## Gradient struct FiniteDiffGradientPrep{SIG,C,R,A,D} <: DI.GradientPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} cache::C relstep::R absstep::A @@ -259,9 +259,9 @@ struct FiniteDiffGradientPrep{SIG,C,R,A,D} <: DI.GradientPrep{SIG} end function DI.prepare_gradient( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) 
y = fc(x) df = zero(y) .* x @@ -277,7 +277,7 @@ function DI.prepare_gradient( backend.relstep end dir = backend.dir - return FiniteDiffGradientPrep(SIG, cache, relstep, absstep, dir) + return FiniteDiffGradientPrep(_sig, cache, relstep, absstep, dir) end function DI.gradient( @@ -339,7 +339,7 @@ end ## Jacobian struct FiniteDiffOneArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} cache::C relstep::R absstep::A @@ -347,9 +347,9 @@ struct FiniteDiffOneArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) y = fc(x) x1 = similar(x) @@ -367,7 +367,7 @@ function DI.prepare_jacobian( backend.relstep end dir = backend.dir - return FiniteDiffOneArgJacobianPrep(SIG, cache, relstep, absstep, dir) + return FiniteDiffOneArgJacobianPrep(_sig, cache, relstep, absstep, dir) end function DI.jacobian( @@ -442,7 +442,7 @@ end ## Hessian struct FiniteDiffHessianPrep{SIG,C1,C2,RG,AG,RH,AH} <: DI.HessianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} gradient_cache::C1 hessian_cache::C2 relstep_g::RG @@ -452,9 +452,9 @@ struct FiniteDiffHessianPrep{SIG,C1,C2,RG,AG,RH,AH} <: DI.HessianPrep{SIG} end function DI.prepare_hessian( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) y = fc(x) df = zero(y) .* x @@ -473,7 +473,7 @@ function DI.prepare_hessian( absstep_g = isnothing(backend.absstep) ? 
relstep_g : backend.absstep absstep_h = isnothing(backend.absstep) ? relstep_h : backend.absstep return FiniteDiffHessianPrep( - SIG, gradient_cache, hessian_cache, relstep_g, absstep_g, relstep_h, absstep_h + _sig, gradient_cache, hessian_cache, relstep_g, absstep_g, relstep_h, absstep_h ) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl index 12c6dffbd..a81c49249 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl @@ -1,7 +1,7 @@ ## Pushforward struct FiniteDiffTwoArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} cache::C relstep::R absstep::A @@ -15,9 +15,9 @@ function DI.prepare_pushforward( x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f!, y, backend, x, tx, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) cache = if x isa Number nothing else @@ -34,7 +34,7 @@ function DI.prepare_pushforward( backend.relstep end dir = backend.dir - return FiniteDiffTwoArgPushforwardPrep(SIG, cache, relstep, absstep, dir) + return FiniteDiffTwoArgPushforwardPrep(_sig, cache, relstep, absstep, dir) end function DI.value_and_pushforward( @@ -154,7 +154,7 @@ end ## Derivative struct FiniteDiffTwoArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} cache::C relstep::R absstep::A @@ -162,9 +162,14 @@ struct FiniteDiffTwoArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} end function DI.prepare_derivative( - f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f!, + y, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context,C}; + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f!, y, 
backend, x, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, contexts...; strict) df = similar(y) cache = GradientCache(df, x, fdtype(backend), eltype(y), FUNCTION_INPLACE) relstep = if isnothing(backend.relstep) @@ -178,7 +183,7 @@ function DI.prepare_derivative( backend.relstep end dir = backend.dir - return FiniteDiffTwoArgDerivativePrep(SIG, cache, relstep, absstep, dir) + return FiniteDiffTwoArgDerivativePrep(_sig, cache, relstep, absstep, dir) end function DI.prepare!_derivative( @@ -272,7 +277,7 @@ end ## Jacobian struct FiniteDiffTwoArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} cache::C relstep::R absstep::A @@ -280,9 +285,14 @@ struct FiniteDiffTwoArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian( - f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f!, + y, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context,C}; + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f!, y, backend, x, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, contexts...; strict) x1 = similar(x) fx = similar(y) fx1 = similar(y) @@ -298,7 +308,7 @@ function DI.prepare_jacobian( backend.relstep end dir = backend.dir - return FiniteDiffTwoArgJacobianPrep(SIG, cache, relstep, absstep, dir) + return FiniteDiffTwoArgJacobianPrep(_sig, cache, relstep, absstep, dir) end function DI.prepare!_jacobian( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl index fca32bdc8..69cc37c1f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl +++ 
b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl @@ -17,10 +17,10 @@ function DI.prepare_pushforward( x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, tx, contexts...; strict) - return DI.NoPushforwardPrep{SIG}() + _sig = DI.signature(f, backend, x, tx, contexts...; strict) + return DI.NoPushforwardPrep(_sig) end function DI.pushforward( @@ -60,10 +60,10 @@ function DI.prepare_pullback( x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, ty, contexts...; strict) - return DI.NoPullbackPrep{SIG}() + _sig = DI.signature(f, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep(_sig) end function DI.pullback( @@ -98,10 +98,14 @@ end ## Gradient function DI.prepare_gradient( - f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f, + backend::AutoFiniteDifferences, + x, + contexts::Vararg{DI.Context,C}; + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) - return DI.NoGradientPrep{SIG}() + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoGradientPrep(_sig) end function DI.gradient( @@ -155,10 +159,14 @@ end ## Jacobian function DI.prepare_jacobian( - f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f, + backend::AutoFiniteDifferences, + x, + contexts::Vararg{DI.Context,C}; + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) - return DI.NoJacobianPrep{SIG}() + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoJacobianPrep(_sig) end function DI.jacobian( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl 
b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index d46914708..bd5a3f3cd 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -61,6 +61,8 @@ end ### Prepared struct ForwardDiffOneArgPushforwardPrep{SIG,T,X,CD} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} + _t::Type{T} xdual_tmp::X contexts_dual::CD end @@ -71,9 +73,9 @@ function DI.prepare_pushforward( x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,B,C} - SIG = DI.signature(f, backend, x, tx, contexts...; strict) + _sig = DI.signature(f, backend, x, tx, contexts...; strict) T = tag_type(f, backend, x) if DI.ismutable_array(x) xdual_tmp = make_dual_similar(T, x, tx) @@ -81,9 +83,7 @@ function DI.prepare_pushforward( xdual_tmp = nothing end contexts_dual = translate_toprep(Dual{T,eltype(x),B}, contexts) - return ForwardDiffOneArgPushforwardPrep{SIG,T,typeof(xdual_tmp),typeof(contexts_dual)}( - xdual_tmp, contexts_dual - ) + return ForwardDiffOneArgPushforwardPrep(_sig, T, xdual_tmp, contexts_dual) end function compute_ydual_onearg( @@ -180,6 +180,7 @@ end ## Derivative struct ForwardDiffOneArgDerivativePrep{SIG,E} <: DI.DerivativePrep{SIG} + _sig::Val{SIG} pushforward_prep::E end @@ -215,11 +216,15 @@ end ### Prepared function DI.prepare_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f::F, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context,C}; + strict::Val=Val(false), ) where {F,C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) pushforward_prep = DI.prepare_pushforward(f, backend, x, (one(x),), contexts...; strict) - return ForwardDiffOneArgDerivativePrep{SIG,typeof(pushforward_prep)}(pushforward_prep) + return ForwardDiffOneArgDerivativePrep(_sig, pushforward_prep) 
end function DI.value_and_derivative( @@ -354,7 +359,7 @@ end ### Prepared struct ForwardDiffGradientPrep{SIG,C,CD} <: DI.GradientPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} config::C contexts_dual::CD end @@ -364,14 +369,14 @@ function DI.prepare_gradient( backend::AutoForwardDiff, x::AbstractArray, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f, backend, x) config = GradientConfig(nothing, x, chunk, tag) contexts_dual = translate_toprep(dual_type(config), contexts) - return ForwardDiffGradientPrep(SIG, config, contexts_dual) + return ForwardDiffGradientPrep(_sig, config, contexts_dual) end function DI.value_and_gradient!( @@ -526,20 +531,24 @@ end ### Prepared struct ForwardDiffOneArgJacobianPrep{SIG,C,CD} <: DI.JacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} config::C contexts_dual::CD end function DI.prepare_jacobian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f::F, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context,C}; + strict::Val=Val(false), ) where {F,C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f, backend, x) config = JacobianConfig(nothing, x, chunk, tag) contexts_dual = translate_toprep(dual_type(config), contexts) - return ForwardDiffOneArgJacobianPrep(SIG, config, contexts_dual) + return ForwardDiffOneArgJacobianPrep(_sig, config, contexts_dual) end function DI.value_and_jacobian!( @@ -620,10 +629,14 @@ end ## Second derivative function DI.prepare_second_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f::F, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context,C}; + strict::Val=Val(false), ) where 
{F,C} - SIG = DI.signature(f, backend, x, contexts...; strict) - return DI.NoSecondDerivativePrep{SIG}() + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoSecondDerivativePrep(_sig) end function DI.second_derivative( @@ -786,23 +799,27 @@ end ### Prepared struct ForwardDiffHessianPrep{SIG,C1,C2,CD} <: DI.HessianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} array_config::C1 result_config::C2 contexts_dual::CD end function DI.prepare_hessian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f::F, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context,C}; + strict::Val=Val(false), ) where {F,C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f, backend, x) result = HessianResult(x) array_config = HessianConfig(nothing, x, chunk, tag) result_config = HessianConfig(nothing, result, x, chunk, tag) contexts_dual = translate_toprep(dual_type(array_config), contexts) - return ForwardDiffHessianPrep(SIG, array_config, result_config, contexts_dual) + return ForwardDiffHessianPrep(_sig, array_config, result_config, contexts_dual) end function DI.hessian!( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index ab122f175..a4c7aea44 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -1,6 +1,8 @@ ## Pushforward struct ForwardDiffTwoArgPushforwardPrep{SIG,T,X,Y,CD} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} + _t::Type{T} xdual_tmp::X ydual_tmp::Y contexts_dual::CD @@ -13,18 +15,14 @@ function DI.prepare_pushforward( x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,B,C} - SIG = DI.signature(f!, y, 
backend, x, tx, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) T = tag_type(f!, backend, x) xdual_tmp = make_dual_similar(T, x, tx) ydual_tmp = make_dual_similar(T, y, tx) # tx only for batch size contexts_dual = translate_toprep(eltype(xdual_tmp), contexts) - return ForwardDiffTwoArgPushforwardPrep{ - SIG,T,typeof(xdual_tmp),typeof(ydual_tmp),typeof(contexts_dual) - }( - xdual_tmp, ydual_tmp, contexts_dual - ) + return ForwardDiffTwoArgPushforwardPrep(_sig, T, xdual_tmp, ydual_tmp, contexts_dual) end function compute_ydual_twoarg( @@ -180,7 +178,7 @@ end ### Prepared struct ForwardDiffTwoArgDerivativePrep{SIG,C,CD} <: DI.DerivativePrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} config::C contexts_dual::CD end @@ -191,13 +189,13 @@ function DI.prepare_derivative( backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = DI.signature(f!, y, backend, x, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, contexts...; strict) tag = get_tag(f!, backend, x) config = DerivativeConfig(nothing, y, x, tag) contexts_dual = translate_toprep(dual_type(config), contexts) - return ForwardDiffTwoArgDerivativePrep(SIG, config, contexts_dual) + return ForwardDiffTwoArgDerivativePrep(_sig, config, contexts_dual) end function DI.prepare!_derivative( @@ -374,7 +372,7 @@ end ### Prepared struct ForwardDiffTwoArgJacobianPrep{SIG,C,CD} <: DI.JacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} config::C contexts_dual::CD end @@ -385,14 +383,14 @@ function DI.prepare_jacobian( backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = DI.signature(f!, y, backend, x, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f!, backend, x) config = JacobianConfig(nothing, y, x, chunk, tag) contexts_dual = 
translate_toprep(dual_type(config), contexts) - return ForwardDiffTwoArgJacobianPrep(SIG, config, contexts_dual) + return ForwardDiffTwoArgJacobianPrep(_sig, config, contexts_dual) end function DI.prepare!_jacobian( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl index 59f081340..1e880c395 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl @@ -3,7 +3,7 @@ # Contains either a single pre-allocated initial TPS # or a vector of pre-allocated TPSs. struct GTPSAOneArgPushforwardPrep{SIG,X} <: DI.PushforwardPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} xt::X end @@ -13,9 +13,9 @@ function DI.prepare_pushforward( x, tx::NTuple, contexts::Vararg{DI.Constant,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,D,C} - SIG = DI.signature(f, backend, x, tx, contexts...; strict) + _sig = DI.signature(f, backend, x, tx, contexts...; strict) # For pushforward/JVP, we only actually need 1 single variable (in the GTPSA sense) # because we even if we did multiple we will add up the derivatives of each at the end. if D != Nothing @@ -31,7 +31,7 @@ function DI.prepare_pushforward( for i in eachindex(xt) xt[i] = TPS{promote_type(eltype(first(tx)), eltype(x), Float64)}(; use=d) end - return GTPSAOneArgPushforwardPrep(SIG, xt) + return GTPSAOneArgPushforwardPrep(_sig, xt) end end @@ -112,15 +112,15 @@ end ## Gradient # Contains a vector of pre-allocated TPSs. 
struct GTPSAOneArgGradientPrep{SIG,X} <: DI.GradientPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} xt::X end # Unlike JVP, this requires us to use all variables function DI.prepare_gradient( - f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Bool=false + f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Val=Val(false) ) where {D,C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor else @@ -133,7 +133,7 @@ function DI.prepare_gradient( xt[i][j] = 1 j += 1 end - return GTPSAOneArgGradientPrep(SIG, xt) + return GTPSAOneArgGradientPrep(_sig, xt) end function DI.gradient( @@ -195,15 +195,15 @@ end ## Jacobian # Contains a vector of pre-allocated TPSs struct GTPSAOneArgJacobianPrep{SIG,X} <: DI.JacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} xt::X end # To materialize the entire Jacobian we use all variables function DI.prepare_jacobian( - f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Bool=false + f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Val=Val(false) ) where {D,C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor else @@ -218,7 +218,7 @@ function DI.prepare_jacobian( xt[i][j] = 1 j += 1 end - return GTPSAOneArgJacobianPrep(SIG, xt) + return GTPSAOneArgJacobianPrep(_sig, xt) end function DI.jacobian( @@ -282,14 +282,14 @@ end ## Second derivative # Contains single pre-allocated TPS struct GTPSAOneArgSecondDerivativePrep{SIG,X} <: DI.SecondDerivativePrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} xt::X end function DI.prepare_second_derivative( - f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Bool=false + f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Val=Val(false) ) where {D,C} - SIG = DI.signature(f, backend, x, contexts...; 
strict) + _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor else @@ -297,7 +297,7 @@ function DI.prepare_second_derivative( end xt = TPS{promote_type(typeof(x), Float64)}(; use=d) xt[1] = 1 # Set slope - return GTPSAOneArgSecondDerivativePrep(SIG, xt) + return GTPSAOneArgSecondDerivativePrep(_sig, xt) end function DI.second_derivative( @@ -411,15 +411,15 @@ end # Stores allocated array of TPS and an array for the monomial coefficient # indexing in GTPSA.cycle! (which is used if a Descriptor is specified) struct GTPSAOneArgHessianPrep{SIG,X,M} <: DI.HessianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} xt::X m::M end function DI.prepare_hessian( - f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Bool=false + f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Val=Val(false) ) where {D,C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor m = Vector{UInt8}(undef, length(x)) @@ -442,7 +442,7 @@ function DI.prepare_hessian( j += 1 end - return GTPSAOneArgHessianPrep(SIG, xt, m) + return GTPSAOneArgHessianPrep(_sig, xt, m) end function DI.hessian( @@ -546,7 +546,7 @@ function DI.value_gradient_and_hessian!( end struct GTPSAOneArgHVPPrep{SIG,E,H} <: DI.HVPPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} hessprep::E hess::H end @@ -557,13 +557,13 @@ function DI.prepare_hvp( x, tx::NTuple, contexts::Vararg{DI.Constant,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, tx, contexts...; strict) + _sig = DI.signature(f, backend, x, tx, contexts...; strict) hessprep = DI.prepare_hessian(f, backend, x; strict) fc = DI.with_contexts(f, contexts...) 
hess = similar(x, typeof(fc(x)), (length(x), length(x))) - return GTPSAOneArgHVPPrep(SIG, hessprep, hess) + return GTPSAOneArgHVPPrep(_sig, hessprep, hess) end function DI.hvp( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl index 57833fdd2..647662672 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl @@ -5,7 +5,7 @@ # # Output: Contains a vector of pre-allocated TPSs struct GTPSATwoArgPushforwardPrep{SIG,X,Y} <: DI.PushforwardPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} xt::X yt::Y end @@ -17,9 +17,9 @@ function DI.prepare_pushforward( x, tx::NTuple, contexts::Vararg{DI.Constant,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,D,C} - SIG = DI.signature(f!, y, backend, x, tx, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) # For pushforward/JVP, we only actually need 1 single variable (in the GTPSA sense) # because we even if we did multiple we will add up the derivatives of each at the end. 
if D != Nothing @@ -41,7 +41,7 @@ function DI.prepare_pushforward( for i in eachindex(yt) yt[i] = TPS{promote_type(eltype(y), Float64)}(; use=d) end - return GTPSATwoArgPushforwardPrep(SIG, xt, yt) + return GTPSATwoArgPushforwardPrep(_sig, xt, yt) end function DI.pushforward( @@ -120,15 +120,15 @@ end # Input: Contains a vector of pre-allocated TPSs # Output: Contains a vector of pre-allocated TPSs struct GTPSATwoArgJacobianPrep{SIG,X,Y} <: DI.JacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} xt::X yt::Y end function DI.prepare_jacobian( - f!, y, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Bool=false + f!, y, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Val=Val(false) ) where {D,C} - SIG = DI.signature(f!, y, backend, x, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor else @@ -150,7 +150,7 @@ function DI.prepare_jacobian( yt[i] = TPS{promote_type(eltype(y), Float64)}(; use=d) end - return GTPSATwoArgJacobianPrep(SIG, xt, yt) + return GTPSATwoArgJacobianPrep(_sig, xt, yt) end function DI.jacobian( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index a8760d57a..b84b0ccb3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -1,7 +1,7 @@ ## Pullback struct MooncakeOneArgPullbackPrep{SIG,Tcache,DY} <: DI.PullbackPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} cache::Tcache dy_righttype::DY end @@ -12,16 +12,16 @@ function DI.prepare_pullback( x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = DI.signature(f, backend, x, ty, contexts...; strict) + _sig = DI.signature(f, backend, x, ty, contexts...; strict) config = get_config(backend) cache = 
prepare_pullback_cache( f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages ) y = f(x, map(DI.unwrap, contexts)...) dy_righttype = zero_tangent(y) - prep = MooncakeOneArgPullbackPrep(SIG, cache, dy_righttype) + prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype) DI.value_and_pullback(f, prep, backend, x, ty, contexts...) return prep end @@ -122,19 +122,19 @@ end ## Gradient struct MooncakeGradientPrep{SIG,Tcache} <: DI.GradientPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} cache::Tcache end function DI.prepare_gradient( - f::F, backend::AutoMooncake, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f::F, backend::AutoMooncake, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {F,C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) config = get_config(backend) cache = prepare_pullback_cache( f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages ) - prep = MooncakeGradientPrep(SIG, cache) + prep = MooncakeGradientPrep(_sig, cache) DI.value_and_gradient(f, prep, backend, x, contexts...) 
return prep end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 11d8ecd4e..1f5247706 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -1,5 +1,5 @@ struct MooncakeTwoArgPullbackPrep{SIG,Tcache,DY,F} <: DI.PullbackPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} cache::Tcache dy_righttype::DY target_function::F @@ -12,9 +12,9 @@ function DI.prepare_pullback( x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = DI.signature(f!, y, backend, x, ty, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) target_function = function (f!, y, x, contexts...) f!(y, x, contexts...) return y @@ -30,7 +30,7 @@ function DI.prepare_pullback( silence_debug_messages=config.silence_debug_messages, ) dy_righttype_after = zero_tangent(y) - prep = MooncakeTwoArgPullbackPrep(SIG, cache, dy_righttype_after, target_function) + prep = MooncakeTwoArgPullbackPrep(_sig, cache, dy_righttype_after, target_function) DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...) 
return prep end diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index eee0beb5c..9aae16830 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -3,7 +3,7 @@ struct PolyesterForwardDiffOneArgPushforwardPrep{SIG,P} <: PolyesterForwardDiffOneArgPushforwardPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} single_threaded_prep::P end @@ -13,13 +13,13 @@ function DI.prepare_pushforward( x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, tx, contexts...; strict) + _sig = DI.signature(f, backend, x, tx, contexts...; strict) single_threaded_prep = DI.prepare_pushforward( f, single_threaded(backend), x, tx, contexts...; strict ) - return PolyesterForwardDiffOneArgPushforwardPrep(SIG, single_threaded_prep) + return PolyesterForwardDiffOneArgPushforwardPrep(_sig, single_threaded_prep) end function DI.value_and_pushforward( @@ -83,7 +83,7 @@ end ## Derivative struct PolyesterForwardDiffOneArgDerivativePrep{SIG,P} <: DI.DerivativePrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} single_threaded_prep::P end @@ -92,13 +92,13 @@ function DI.prepare_derivative( backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) single_threaded_prep = DI.prepare_derivative( f, single_threaded(backend), x, contexts...; strict ) - return PolyesterForwardDiffOneArgDerivativePrep(SIG, single_threaded_prep) + return PolyesterForwardDiffOneArgDerivativePrep(_sig, single_threaded_prep) end function DI.value_and_derivative( @@ -158,7 +158,7 @@ end ## Gradient 
struct PolyesterForwardDiffGradientPrep{SIG,chunksize,P} <: DI.GradientPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} chunk::Chunk{chunksize} single_threaded_prep::P end @@ -168,9 +168,9 @@ function DI.prepare_gradient( backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {chunksize,C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) if isnothing(chunksize) chunk = Chunk(x) else @@ -179,7 +179,7 @@ function DI.prepare_gradient( single_threaded_prep = DI.prepare_gradient( f, single_threaded(backend), x, contexts...; strict ) - return PolyesterForwardDiffGradientPrep(SIG, chunk, single_threaded_prep) + return PolyesterForwardDiffGradientPrep(_sig, chunk, single_threaded_prep) end function DI.value_and_gradient!( @@ -249,7 +249,7 @@ end ## Jacobian struct PolyesterForwardDiffOneArgJacobianPrep{SIG,chunksize,P} <: DI.JacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} chunk::Chunk{chunksize} single_threaded_prep::P end @@ -259,9 +259,9 @@ function DI.prepare_jacobian( backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {chunksize,C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) if isnothing(chunksize) chunk = Chunk(x) else @@ -270,7 +270,7 @@ function DI.prepare_jacobian( single_threaded_prep = DI.prepare_jacobian( f, single_threaded(backend), x, contexts...; strict ) - return PolyesterForwardDiffOneArgJacobianPrep(SIG, chunk, single_threaded_prep) + return PolyesterForwardDiffOneArgJacobianPrep(_sig, chunk, single_threaded_prep) end function DI.value_and_jacobian!( @@ -340,7 +340,7 @@ end ## Hessian struct PolyesterForwardDiffHessianPrep{SIG,P} <: DI.HessianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} single_threaded_prep::P end @@ -349,13 +349,13 @@ function 
DI.prepare_hessian( backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) single_threaded_prep = DI.prepare_hessian( f, single_threaded(backend), x, contexts...; strict ) - return PolyesterForwardDiffHessianPrep(SIG, single_threaded_prep) + return PolyesterForwardDiffHessianPrep(_sig, single_threaded_prep) end function DI.hessian( @@ -416,7 +416,7 @@ end ## Second derivative struct PolyesterForwardDiffOneArgSecondDerivativePrep{SIG,P} <: DI.SecondDerivativePrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} single_threaded_prep::P end @@ -425,13 +425,13 @@ function DI.prepare_second_derivative( backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) single_threaded_prep = DI.prepare_second_derivative( f, single_threaded(backend), x, contexts...; strict ) - return PolyesterForwardDiffOneArgSecondDerivativePrep(SIG, single_threaded_prep) + return PolyesterForwardDiffOneArgSecondDerivativePrep(_sig, single_threaded_prep) end function DI.value_derivative_and_second_derivative( diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl index 1bd32b5b8..818cc1e71 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl @@ -1,7 +1,7 @@ ## Pushforward struct PolyesterForwardDiffTwoArgPushforwardPrep{SIG,P} <: DI.PushforwardPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} single_threaded_prep::P end @@ -12,13 +12,13 @@ function DI.prepare_pushforward( 
x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f!, y, backend, x, tx, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) single_threaded_prep = DI.prepare_pushforward( f!, y, single_threaded(backend), x, tx, contexts... ) - return PolyesterForwardDiffTwoArgPushforwardPrep(SIG, single_threaded_prep) + return PolyesterForwardDiffTwoArgPushforwardPrep(_sig, single_threaded_prep) end function DI.value_and_pushforward( @@ -86,7 +86,7 @@ end ## Derivative struct PolyesterForwardDiffTwoArgDerivativePrep{SIG,P} <: DI.DerivativePrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} single_threaded_prep::P end @@ -96,13 +96,13 @@ function DI.prepare_derivative( backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f!, y, backend, x, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, contexts...; strict) single_threaded_prep = DI.prepare_derivative( f!, y, single_threaded(backend), x, contexts... 
) - return PolyesterForwardDiffTwoArgDerivativePrep(SIG, single_threaded_prep) + return PolyesterForwardDiffTwoArgDerivativePrep(_sig, single_threaded_prep) end function DI.value_and_derivative( @@ -166,7 +166,7 @@ end ## Jacobian struct PolyesterForwardDiffTwoArgJacobianPrep{SIG,chunksize,P} <: DI.JacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} chunk::Chunk{chunksize} single_threaded_prep::P end @@ -177,9 +177,9 @@ function DI.prepare_jacobian( backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {chunksize,C} - SIG = DI.signature(f!, y, backend, x, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, contexts...; strict) if isnothing(chunksize) chunk = Chunk(x) else @@ -188,7 +188,7 @@ function DI.prepare_jacobian( single_threaded_prep = DI.prepare_jacobian( f!, y, single_threaded(backend), x, contexts... ) - return PolyesterForwardDiffTwoArgJacobianPrep(SIG, chunk, single_threaded_prep) + return PolyesterForwardDiffTwoArgJacobianPrep(_sig, chunk, single_threaded_prep) end function DI.value_and_jacobian( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index e74ee00ca..2a98c70ef 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -6,10 +6,10 @@ function DI.prepare_pullback( x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, ty, contexts...; strict) - return DI.NoPullbackPrep{SIG}() + _sig = DI.signature(f, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep(_sig) end function DI.value_and_pullback( @@ -79,21 +79,21 @@ end ### Without contexts struct ReverseDiffGradientPrep{SIG,C,T} <: DI.GradientPrep{SIG} - 
_sig::Type{SIG} + _sig::Val{SIG} config::C tape::T end function DI.prepare_gradient( - f, backend::AutoReverseDiff{compile}, x; strict::Bool=false + f, backend::AutoReverseDiff{compile}, x; strict::Val=Val(false) ) where {compile} - SIG = DI.signature(f, backend, x) + _sig = DI.signature(f, backend, x) if compile tape = ReverseDiff.compile(GradientTape(f, x)) - return ReverseDiffGradientPrep(SIG, nothing, tape) + return ReverseDiffGradientPrep(_sig, nothing, tape) else config = GradientConfig(x) - return ReverseDiffGradientPrep(SIG, config, nothing) + return ReverseDiffGradientPrep(_sig, config, nothing) end end @@ -149,11 +149,11 @@ end ### With contexts function DI.prepare_gradient( - f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) config = GradientConfig(x) - return ReverseDiffGradientPrep(SIG, config, nothing) + return ReverseDiffGradientPrep(_sig, config, nothing) end function DI.value_and_gradient!( @@ -216,21 +216,21 @@ end ### Without contexts struct ReverseDiffOneArgJacobianPrep{SIG,C,T} <: DI.JacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} config::C tape::T end function DI.prepare_jacobian( - f, backend::AutoReverseDiff{compile}, x; strict::Bool=false + f, backend::AutoReverseDiff{compile}, x; strict::Val=Val(false) ) where {compile} - SIG = DI.signature(f, backend, x; strict) + _sig = DI.signature(f, backend, x; strict) if compile tape = ReverseDiff.compile(JacobianTape(f, x)) - return ReverseDiffOneArgJacobianPrep(SIG, nothing, tape) + return ReverseDiffOneArgJacobianPrep(_sig, nothing, tape) else config = JacobianConfig(x) - return ReverseDiffOneArgJacobianPrep(SIG, config, nothing) + return ReverseDiffOneArgJacobianPrep(_sig, config, nothing) end end @@ -286,11 +286,11 @@ end ### With contexts 
function DI.prepare_jacobian( - f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) config = JacobianConfig(x) - return ReverseDiffOneArgJacobianPrep(SIG, config, nothing) + return ReverseDiffOneArgJacobianPrep(_sig, config, nothing) end function DI.value_and_jacobian!( @@ -353,23 +353,23 @@ end ### Without contexts struct ReverseDiffHessianPrep{SIG,G<:ReverseDiffGradientPrep,HC,HT} <: DI.HessianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} gradient_prep::G hessian_config::HC hessian_tape::HT end function DI.prepare_hessian( - f, backend::AutoReverseDiff{compile}, x; strict::Bool=false + f, backend::AutoReverseDiff{compile}, x; strict::Val=Val(false) ) where {compile} - SIG = DI.signature(f, backend, x; strict) + _sig = DI.signature(f, backend, x; strict) gradient_prep = DI.prepare_gradient(f, backend, x) if compile hessian_tape = ReverseDiff.compile(HessianTape(f, x)) - return ReverseDiffHessianPrep(SIG, gradient_prep, nothing, hessian_tape) + return ReverseDiffHessianPrep(_sig, gradient_prep, nothing, hessian_tape) else hessian_config = HessianConfig(x) - return ReverseDiffHessianPrep(SIG, gradient_prep, hessian_config, nothing) + return ReverseDiffHessianPrep(_sig, gradient_prep, hessian_config, nothing) end end @@ -418,12 +418,12 @@ end ### With contexts function DI.prepare_hessian( - f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) gradient_prep = DI.prepare_gradient(f, backend, x, contexts...) 
hessian_config = HessianConfig(x) - return ReverseDiffHessianPrep(SIG, gradient_prep, hessian_config, nothing) + return ReverseDiffHessianPrep(_sig, gradient_prep, hessian_config, nothing) end function DI.hessian!( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl index 2bf3fa7e9..95701efc4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl @@ -7,10 +7,10 @@ function DI.prepare_pullback( x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f!, y, x, ty, contexts...; strict) - return DI.NoPullbackPrep{SIG}() + _sig = DI.signature(f!, y, x, ty, contexts...; strict) + return DI.NoPullbackPrep(_sig) end ### Array in @@ -134,21 +134,21 @@ end ### Without contexts struct ReverseDiffTwoArgJacobianPrep{SIG,C,T} <: DI.JacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} config::C tape::T end function DI.prepare_jacobian( - f!, y, backend::AutoReverseDiff{compile}, x; strict::Bool=false + f!, y, backend::AutoReverseDiff{compile}, x; strict::Val=Val(false) ) where {compile} - SIG = DI.signature(f!, y, backend, x; strict) + _sig = DI.signature(f!, y, backend, x; strict) if compile tape = ReverseDiff.compile(JacobianTape(f!, y, x)) - return ReverseDiffTwoArgJacobianPrep(SIG, nothing, tape) + return ReverseDiffTwoArgJacobianPrep(_sig, nothing, tape) else config = JacobianConfig(y, x) - return ReverseDiffTwoArgJacobianPrep(SIG, config, nothing) + return ReverseDiffTwoArgJacobianPrep(_sig, config, nothing) end end @@ -206,11 +206,16 @@ end ### With contexts function DI.prepare_jacobian( - f!, y, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f!, + y, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context,C}; + 
strict::Val=Val(false), ) where {C} - SIG = DI.signature(f!, y, backend, x, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, contexts...; strict) config = JacobianConfig(y, x) - return ReverseDiffTwoArgJacobianPrep(SIG, config, nothing) + return ReverseDiffTwoArgJacobianPrep(_sig, config, nothing) end function DI.value_and_jacobian( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl index 1bc99a3e7..203992008 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl @@ -28,7 +28,7 @@ function ADTypes.jacobian_sparsity(f, x, detector::DI.DenseSparsityDetector{:ite if DI.pushforward_performance(backend) isa DI.PushforwardFast p = similar(y) prep = DI.prepare_pushforward_same_point( - f, backend, x, (DI.basis(x, first(eachindex(x))),) + f, backend, x, (DI.basis(x, first(eachindex(x))),); strict=Val(true) ) for (kj, j) in enumerate(eachindex(x)) pushforward!(f, (p,), prep, backend, x, (DI.basis(x, j),)) @@ -42,7 +42,7 @@ function ADTypes.jacobian_sparsity(f, x, detector::DI.DenseSparsityDetector{:ite else p = similar(x) prep = DI.prepare_pullback_same_point( - f, backend, x, (DI.basis(y, first(eachindex(y))),) + f, backend, x, (DI.basis(y, first(eachindex(y))),); strict=Val(true) ) for (ki, i) in enumerate(eachindex(y)) pullback!(f, (p,), prep, backend, x, (DI.basis(y, i),)) @@ -64,7 +64,7 @@ function ADTypes.jacobian_sparsity(f!, y, x, detector::DI.DenseSparsityDetector{ if DI.pushforward_performance(backend) isa DI.PushforwardFast p = similar(y) prep = DI.prepare_pushforward_same_point( - f!, y, backend, x, (DI.basis(x, first(eachindex(x))),) + f!, y, backend, x, (DI.basis(x, first(eachindex(x))),); strict=Val(true) ) for (kj, j) in enumerate(eachindex(x)) pushforward!(f!, 
y, (p,), prep, backend, x, (DI.basis(x, j),)) @@ -78,7 +78,7 @@ function ADTypes.jacobian_sparsity(f!, y, x, detector::DI.DenseSparsityDetector{ else p = similar(x) prep = DI.prepare_pullback_same_point( - f!, y, backend, x, (DI.basis(y, first(eachindex(y))),) + f!, y, backend, x, (DI.basis(y, first(eachindex(y))),); strict=Val(true) ) for (ki, i) in enumerate(eachindex(y)) pullback!(f!, y, (p,), prep, backend, x, (DI.basis(y, i),)) @@ -98,7 +98,9 @@ function ADTypes.hessian_sparsity(f, x, detector::DI.DenseSparsityDetector{:iter n = length(x) I, J = Int[], Int[] p = similar(x) - prep = DI.prepare_hvp_same_point(f, backend, x, (DI.basis(x, first(eachindex(x))),)) + prep = DI.prepare_hvp_same_point( + f, backend, x, (DI.basis(x, first(eachindex(x))),); strict=Val(true) + ) for (kj, j) in enumerate(eachindex(x)) hvp!(f, (p,), prep, backend, x, (DI.basis(x, j),)) for ki in LinearIndices(p) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index b26fd2076..1bc4d2b7d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -8,7 +8,7 @@ struct SparseHessianPrep{ E2<:DI.HVPPrep, E1<:DI.GradientPrep, } <: DI.HessianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} batch_size_settings::BS coloring_result::C compressed_matrix::M @@ -26,7 +26,7 @@ SMC.ncolors(prep::SparseHessianPrep) = ncolors(prep.coloring_result) ## Hessian, one argument function DI.prepare_hessian( - f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {F,C} dense_backend = dense_ad(backend) sparsity = DI.hessian_sparsity_with_contexts( @@ -50,9 +50,9 @@ function _prepare_sparse_hessian_aux( 
backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; - strict::Bool, + strict::Val, ) where {B,F,C} - SIG = DI.signature(f, backend, x, contexts...; strict) + _sig = DI.signature(f, backend, x, contexts...; strict) (; N, A) = batch_size_settings dense_backend = dense_ad(backend) groups = column_groups(coloring_result) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 5411693c5..59802f066 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -9,7 +9,7 @@ struct PushforwardSparseJacobianPrep{ R<:AbstractVector{<:NTuple}, E<:DI.PushforwardPrep, } <: SparseJacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} batch_size_settings::BS coloring_result::C compressed_matrix::M @@ -27,7 +27,7 @@ struct PullbackSparseJacobianPrep{ R<:AbstractVector{<:NTuple}, E<:DI.PullbackPrep, } <: SparseJacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} batch_size_settings::BS coloring_result::C compressed_matrix::M @@ -37,7 +37,7 @@ struct PullbackSparseJacobianPrep{ end function DI.prepare_jacobian( - f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {F,C} dense_backend = dense_ad(backend) y = f(x, map(DI.unwrap, contexts)...) 
@@ -46,7 +46,7 @@ function DI.prepare_jacobian( end function DI.prepare_jacobian( - f!::F, y, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f!::F, y, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {F,C} dense_backend = dense_ad(backend) perf = DI.pushforward_performance(dense_backend) @@ -60,7 +60,7 @@ function _prepare_sparse_jacobian_aux( backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; - strict::Bool, + strict::Val, ) where {FY,C} dense_backend = dense_ad(backend) sparsity = DI.jacobian_sparsity_with_contexts( @@ -96,9 +96,9 @@ function _prepare_sparse_jacobian_aux_aux( backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; - strict::Bool, + strict::Val, ) where {B,FY,C} - SIG = DI.signature(f_or_f!y..., backend, x, contexts...; strict) + _sig = DI.signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings dense_backend = dense_ad(backend) groups = column_groups(coloring_result) @@ -130,9 +130,9 @@ function _prepare_sparse_jacobian_aux_aux( backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; - strict::Bool, + strict::Val, ) where {B,FY,C} - SIG = DI.signature(f_or_f!y..., backend, x, contexts...; strict) + _sig = DI.signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings dense_backend = dense_ad(backend) groups = row_groups(coloring_result) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl index e3e375494..ed8585ad8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl @@ -13,7 +13,7 @@ struct MixedModeSparseJacobianPrep{ Ef<:DI.PushforwardPrep, Er<:DI.PullbackPrep, } <: SparseJacobianPrep{SIG} - _sig::Type{SIG} + 
_sig::Val{SIG} batch_size_settings_forward::BSf batch_size_settings_reverse::BSr coloring_result::C @@ -32,7 +32,7 @@ function DI.prepare_jacobian( backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} y = f(x, map(DI.unwrap, contexts)...) return _prepare_mixed_sparse_jacobian_aux(y, (f,), backend, x, contexts...; strict) @@ -44,7 +44,7 @@ function DI.prepare_jacobian( backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} return _prepare_mixed_sparse_jacobian_aux(y, (f!, y), backend, x, contexts...; strict) end @@ -55,7 +55,7 @@ function _prepare_mixed_sparse_jacobian_aux( backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C}; - strict::Bool, + strict::Val, ) where {FY,C} dense_backend = dense_ad(backend) sparsity = DI.jacobian_sparsity_with_contexts( @@ -96,9 +96,9 @@ function _prepare_mixed_sparse_jacobian_aux_aux( backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C}; - strict::Bool, + strict::Val, ) where {Bf,Br,FY,C} - SIG = DI.signature(f_or_f!y..., backend, x, contexts...; strict) + _sig = DI.signature(f_or_f!y..., backend, x, contexts...; strict) Nf, Af = batch_size_settings_forward.N, batch_size_settings_forward.A Nr, Ar = batch_size_settings_reverse.N, batch_size_settings_reverse.A diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index dc7e63263..a2ad45c7e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -10,7 +10,7 @@ DI.inplace_support(::AutoTracker) = DI.InPlaceNotSupported() ## Pullback struct 
TrackerPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} y::Y pb::PB end @@ -21,10 +21,10 @@ function DI.prepare_pullback( x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, ty, contexts...; strict) - return DI.NoPullbackPrep{SIG}() + _sig = DI.signature(f, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep(_sig) end function DI.prepare_pullback_same_point( @@ -35,9 +35,10 @@ function DI.prepare_pullback_same_point( ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + _sig = DI.signature(f, prep, backend, x, ty, contexts...) DI.check_prep(f, prep, backend, x, ty, contexts...) y, pb = forward(f, x, map(DI.unwrap, contexts)...) - return TrackerPullbackPrepSamePoint(y, pb) + return TrackerPullbackPrepSamePoint(_sig, y, pb) end function DI.value_and_pullback( @@ -64,6 +65,7 @@ function DI.value_and_pullback( ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) (; y, pb) = prep tx = map(ty) do dy data(first(pb(dy))) @@ -79,6 +81,7 @@ function DI.pullback( ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) (; pb) = prep tx = map(ty) do dy data(first(pb(dy))) @@ -89,9 +92,14 @@ end ## Gradient function DI.prepare_gradient( - f, backend::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} + f, + backend::AutoTracker, + x, + contexts::Vararg{DI.GeneralizedConstant,C}; + strict::Val=Val(false), ) where {C} - return DI.NoGradientPrep() + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoGradientPrep(_sig) end function DI.value_and_gradient( @@ -101,6 +109,7 @@ function DI.value_and_gradient( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
(; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return val, data(first(grad)) end @@ -112,6 +121,7 @@ function DI.gradient( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return data(first(grad)) end @@ -124,6 +134,7 @@ function DI.value_and_gradient!( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) end @@ -136,6 +147,7 @@ function DI.gradient!( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 4e87d29eb..983af3748 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -22,7 +22,7 @@ translate(c::DI.Cache) = Buffer(DI.unwrap(c)) ## Pullback struct ZygotePullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} y::Y pb::PB end @@ -33,10 +33,10 @@ function DI.prepare_pullback( x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, ty, contexts...; strict) - return DI.NoPullbackPrep{SIG}() + _sig = DI.signature(f, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep(_sig) end function DI.prepare_pullback_same_point( @@ -46,12 +46,12 @@ function DI.prepare_pullback_same_point( x, ty::NTuple, contexts::Vararg{DI.Context,C}; - 
strict::Bool=false, + strict::Val=Val(false), ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) - SIG = DI.signature(f, backend, x, ty, contexts...; strict) + _sig = DI.signature(f, backend, x, ty, contexts...; strict) y, pb = pullback(f, x, map(translate, contexts)...) - return ZygotePullbackPrepSamePoint(SIG, y, pb) + return ZygotePullbackPrepSamePoint(_sig, y, pb) end function DI.value_and_pullback( @@ -105,10 +105,10 @@ end ## Gradient function DI.prepare_gradient( - f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) - return DI.NoGradientPrep{SIG}() + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoGradientPrep(_sig) end function DI.value_and_gradient( @@ -145,10 +145,10 @@ end ## Jacobian function DI.prepare_jacobian( - f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C}; strict::Bool=false + f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) - return DI.NoJacobianPrep{SIG}() + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoJacobianPrep(_sig) end function DI.value_and_jacobian( @@ -189,7 +189,7 @@ end # Beware, this uses ForwardDiff for the inner differentiation struct ZygoteHVPPrep{SIG,P} <: DI.HVPPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} fd_prep::P end @@ -199,13 +199,13 @@ function DI.prepare_hvp( x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, tx, contexts...; strict) + _sig = DI.signature(f, backend, x, tx, contexts...; strict) fd_prep = DI.prepare_hvp( f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...; strict ) - return ZygoteHVPPrep(SIG, fd_prep) + return ZygoteHVPPrep(_sig, fd_prep) end function 
DI.hvp( @@ -281,10 +281,10 @@ function DI.prepare_hessian( backend::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {C} - SIG = DI.signature(f, backend, x, contexts...; strict) - return DI.NoHessianPrep{SIG}() + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoHessianPrep(_sig) end function DI.hessian( diff --git a/DifferentiationInterface/src/fallbacks/change_prep.jl b/DifferentiationInterface/src/fallbacks/change_prep.jl index 8736deb40..bf5202c9e 100644 --- a/DifferentiationInterface/src/fallbacks/change_prep.jl +++ b/DifferentiationInterface/src/fallbacks/change_prep.jl @@ -96,7 +96,7 @@ for op in [ x, seed::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} prep = $prep_op(f, backend, x, seed, contexts...; strict) return $prep_op_same_point(f, prep, backend, x, seed, contexts...) @@ -136,7 +136,7 @@ for op in [ x, seed::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} prep = $prep_op(f!, y, backend, x, seed, contexts...; strict) return $prep_op_same_point(f!, y, prep, backend, x, seed, contexts...) diff --git a/DifferentiationInterface/src/fallbacks/no_prep.jl b/DifferentiationInterface/src/fallbacks/no_prep.jl index 5e6debc37..bfec092ed 100644 --- a/DifferentiationInterface/src/fallbacks/no_prep.jl +++ b/DifferentiationInterface/src/fallbacks/no_prep.jl @@ -45,25 +45,25 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=true) + prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) return $op(f, prep, backend, x, contexts...) 
end @eval function $op!( f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=true) + prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) return $op!(f, result, prep, backend, x, contexts...) end @eval function $val_and_op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=true) + prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) return $val_and_op(f, prep, backend, x, contexts...) end @eval function $val_and_op!( f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=true) + prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) return $val_and_op!(f, result, prep, backend, x, contexts...) end op == :gradient && continue @@ -71,25 +71,25 @@ for op in [ @eval function $op( f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...; strict=true) + prep = $prep_op(f!, y, backend, x, contexts...; strict=Val(true)) return $op(f!, y, prep, backend, x, contexts...) end @eval function $op!( f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...; strict=true) + prep = $prep_op(f!, y, backend, x, contexts...; strict=Val(true)) return $op!(f!, y, result, prep, backend, x, contexts...) end @eval function $val_and_op( f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...; strict=true) + prep = $prep_op(f!, y, backend, x, contexts...; strict=Val(true)) return $val_and_op(f!, y, prep, backend, x, contexts...) 
end @eval function $val_and_op!( f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...; strict=true) + prep = $prep_op(f!, y, backend, x, contexts...; strict=Val(true)) return $val_and_op!(f!, y, result, prep, backend, x, contexts...) end @@ -98,25 +98,25 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=true) + prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) return $op(f, prep, backend, x, contexts...) end @eval function $op!( f::F, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=true) + prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) return $op!(f, result2, prep, backend, x, contexts...) end @eval function $val_and_op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=true) + prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) return $val_and_op(f, prep, backend, x, contexts...) end @eval function $val_and_op!( f::F, result1, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=true) + prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) return $val_and_op!(f, result1, result2, prep, backend, x, contexts...) end @@ -124,7 +124,7 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...; strict=true) + prep = $prep_op(f, backend, x, seed, contexts...; strict=Val(true)) return $op(f, prep, backend, x, seed, contexts...) 
end @eval function $op!( @@ -135,13 +135,13 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...; strict=true) + prep = $prep_op(f, backend, x, seed, contexts...; strict=Val(true)) return $op!(f, result, prep, backend, x, seed, contexts...) end @eval function $val_and_op( f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...; strict=true) + prep = $prep_op(f, backend, x, seed, contexts...; strict=Val(true)) return $val_and_op(f, prep, backend, x, seed, contexts...) end @@ -154,7 +154,7 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...; strict=true) + prep = $prep_op(f, backend, x, seed, contexts...; strict=Val(true)) return $val_and_op!(f, result, prep, backend, x, seed, contexts...) end elseif op == :hvp @@ -167,7 +167,7 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...; strict=true) + prep = $prep_op(f, backend, x, seed, contexts...; strict=Val(true)) return $val_and_op!( f, result1, result2, prep, backend, x, seed, contexts... ) @@ -179,7 +179,7 @@ for op in [ @eval function $op( f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=true) + prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=Val(true)) return $op(f!, y, prep, backend, x, seed, contexts...) end @eval function $op!( @@ -191,13 +191,13 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=true) + prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=Val(true)) return $op!(f!, y, result, prep, backend, x, seed, contexts...) 
end @eval function $val_and_op( f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=true) + prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=Val(true)) return $val_and_op(f!, y, prep, backend, x, seed, contexts...) end @eval function $val_and_op!( @@ -209,7 +209,7 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=true) + prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=Val(true)) return $val_and_op!(f!, y, result, prep, backend, x, seed, contexts...) end end diff --git a/DifferentiationInterface/src/first_order/derivative.jl b/DifferentiationInterface/src/first_order/derivative.jl index 25680a7f8..f0329a639 100644 --- a/DifferentiationInterface/src/first_order/derivative.jl +++ b/DifferentiationInterface/src/first_order/derivative.jl @@ -1,8 +1,8 @@ ## Docstrings """ - prepare_derivative(f, backend, x, [contexts...]; strict=false) -> prep - prepare_derivative(f!, y, backend, x, [contexts...]; strict=false) -> prep + prepare_derivative(f, backend, x, [contexts...]; strict=Val(false)) -> prep + prepare_derivative(f!, y, backend, x, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("derivative"; inplace=true)) """ @@ -59,25 +59,31 @@ function derivative! 
end ## Preparation struct PushforwardDerivativePrep{SIG,E<:PushforwardPrep} <: DerivativePrep{SIG} + _sig::Val{SIG} pushforward_prep::E end function prepare_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool=false + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) ) where {F,C} - SIG = signature(f, backend, x, contexts...; strict) + _sig = signature(f, backend, x, contexts...; strict) pushforward_prep = prepare_pushforward(f, backend, x, (one(x),), contexts...; strict) - return PushforwardDerivativePrep{SIG,typeof(pushforward_prep)}(pushforward_prep) + return PushforwardDerivativePrep(_sig, pushforward_prep) end function prepare_derivative( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool=false + f!::F, + y, + backend::AbstractADType, + x, + contexts::Vararg{Context,C}; + strict::Val=Val(false), ) where {F,C} - SIG = signature(f!, y, backend, x, contexts...; strict) + _sig = signature(f!, y, backend, x, contexts...; strict) pushforward_prep = prepare_pushforward( f!, y, backend, x, (one(x),), contexts...; strict ) - return PushforwardDerivativePrep{SIG,typeof(pushforward_prep)}(pushforward_prep) + return PushforwardDerivativePrep(_sig, pushforward_prep) end ## One argument diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 223c3d2e5..e9481c9c5 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -1,7 +1,7 @@ ## Docstrings """ - prepare_gradient(f, backend, x, [contexts...]; strict=false) -> prep + prepare_gradient(f, backend, x, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("gradient")) """ @@ -53,16 +53,18 @@ function gradient! 
end ## Preparation struct PullbackGradientPrep{SIG,Y,E<:PullbackPrep} <: GradientPrep{SIG} + _sig::Val{SIG} + y::Y pullback_prep::E end function prepare_gradient( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool=false + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) ) where {F,C} - SIG = signature(f, backend, x, contexts...; strict) + _sig = signature(f, backend, x, contexts...; strict) y = f(x, map(unwrap, contexts)...) # TODO: replace with output type inference? pullback_prep = prepare_pullback(f, backend, x, (one(typeof(y)),), contexts...; strict) - return PullbackGradientPrep{SIG,typeof(y),typeof(pullback_prep)}(pullback_prep) + return PullbackGradientPrep(_sig, y, pullback_prep) end ## One argument diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index f1079a81f..3b4f5fff9 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -1,8 +1,8 @@ ## Docstrings """ - prepare_jacobian(f, backend, x, [contexts...]; strict=false) -> prep - prepare_jacobian(f!, y, backend, x, [contexts...]; strict=false) -> prep + prepare_jacobian(f, backend, x, [contexts...]; strict=Val(false)) -> prep + prepare_jacobian(f!, y, backend, x, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("jacobian"; inplace=true)) """ @@ -67,7 +67,7 @@ struct PushforwardJacobianPrep{ R<:AbstractVector{<:NTuple}, E<:PushforwardPrep, } <: StandardJacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} batch_size_settings::BS batched_seeds::S batched_results::R @@ -81,7 +81,7 @@ struct PullbackJacobianPrep{ R<:AbstractVector{<:NTuple}, E<:PullbackPrep, } <: StandardJacobianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} batch_size_settings::BS batched_seeds::S batched_results::R @@ -89,7 +89,7 @@ struct PullbackJacobianPrep{ end function prepare_jacobian( - f::F, backend::AbstractADType, x, 
contexts::Vararg{Context,C}; strict::Bool=false + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) ) where {F,C} y = f(x, map(unwrap, contexts)...) perf = pushforward_performance(backend) @@ -106,7 +106,12 @@ function prepare_jacobian( end function prepare_jacobian( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool=false + f!::F, + y, + backend::AbstractADType, + x, + contexts::Vararg{Context,C}; + strict::Val=Val(false), ) where {F,C} perf = pushforward_performance(backend) # type-unstable @@ -129,9 +134,9 @@ function _prepare_jacobian_aux( backend::AbstractADType, x, contexts::Vararg{Context,C}; - strict::Bool, + strict::Val, ) where {B,FY,C} - SIG = signature(f_or_f!y..., backend, x, contexts...; strict) + _sig = signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings seeds = [basis(x, ind) for ind in eachindex(x)] batched_seeds = [ @@ -142,7 +147,7 @@ function _prepare_jacobian_aux( f_or_f!y..., backend, x, batched_seeds[1], contexts...; strict ) return PushforwardJacobianPrep( - SIG, batch_size_settings, batched_seeds, batched_results, pushforward_prep + _sig, batch_size_settings, batched_seeds, batched_results, pushforward_prep ) end @@ -154,9 +159,9 @@ function _prepare_jacobian_aux( backend::AbstractADType, x, contexts::Vararg{Context,C}; - strict::Bool, + strict::Val, ) where {B,FY,C} - SIG = signature(f_or_f!y..., backend, x, contexts...; strict) + _sig = signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings seeds = [basis(y, ind) for ind in eachindex(y)] batched_seeds = [ @@ -167,7 +172,7 @@ function _prepare_jacobian_aux( f_or_f!y..., backend, x, batched_seeds[1], contexts...; strict ) return PullbackJacobianPrep( - SIG, batch_size_settings, batched_seeds, batched_results, pullback_prep + _sig, batch_size_settings, batched_seeds, batched_results, pullback_prep ) end diff --git 
a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index c7974686c..1837ef7d3 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -1,8 +1,8 @@ ## Docstrings """ - prepare_pullback(f, backend, x, ty, [contexts...]; strict=false) -> prep - prepare_pullback(f!, y, backend, x, ty, [contexts...]; strict=false) -> prep + prepare_pullback(f, backend, x, ty, [contexts...]; strict=Val(false)) -> prep + prepare_pullback(f!, y, backend, x, ty, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("pullback"; inplace=true)) """ @@ -17,8 +17,8 @@ $(docstring_prepare!("pullback")) function prepare!_pullback end """ - prepare_pullback_same_point(f, backend, x, ty, [contexts...]; strict=false) -> prep_same - prepare_pullback_same_point(f!, y, backend, x, ty, [contexts...]; strict=false) -> prep_same + prepare_pullback_same_point(f, backend, x, ty, [contexts...]; strict=Val(false)) -> prep_same + prepare_pullback_same_point(f!, y, backend, x, ty, [contexts...]; strict=Val(false)) -> prep_same $(docstring_prepare("pullback"; samepoint=true, inplace=true)) """ @@ -86,6 +86,7 @@ function pullback! 
end ## Preparation struct PushforwardPullbackPrep{SIG,E} <: PullbackPrep{SIG} + _sig::Val{SIG} pushforward_prep::E end @@ -95,7 +96,7 @@ function prepare_pullback( x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} return _prepare_pullback_aux( pullback_performance(backend), f, backend, x, ty, contexts...; strict @@ -109,7 +110,7 @@ function prepare_pullback( x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} return _prepare_pullback_aux( pullback_performance(backend), f!, y, backend, x, ty, contexts...; strict @@ -123,12 +124,12 @@ function _prepare_pullback_aux( x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Bool, + strict::Val, ) where {F,C} - SIG = signature(f, backend, x, ty, contexts...; strict) + _sig = signature(f, backend, x, ty, contexts...; strict) dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x))) pushforward_prep = prepare_pushforward(f, backend, x, (dx,), contexts...; strict) - return PushforwardPullbackPrep{SIG,typeof(pushforward_prep)}(pushforward_prep) + return PushforwardPullbackPrep(_sig, pushforward_prep) end function _prepare_pullback_aux( @@ -139,12 +140,12 @@ function _prepare_pullback_aux( x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Bool, + strict::Val, ) where {F,C} - SIG = signature(f!, y, backend, x, ty, contexts...; strict) + _sig = signature(f!, y, backend, x, ty, contexts...; strict) dx = x isa Number ? 
one(x) : basis(x, first(CartesianIndices(x))) pushforward_prep = prepare_pushforward(f!, y, backend, x, (dx,), contexts...; strict) - return PushforwardPullbackPrep{SIG,typeof(pushforward_prep)}(pushforward_prep) + return PushforwardPullbackPrep(_sig, pushforward_prep) end ## One argument diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index dd6f5cea8..304208dc5 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -1,8 +1,8 @@ ## Docstrings """ - prepare_pushforward(f, backend, x, tx, [contexts...]; strict=false) -> prep - prepare_pushforward(f!, y, backend, x, tx, [contexts...]; strict=false) -> prep + prepare_pushforward(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep + prepare_pushforward(f!, y, backend, x, tx, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("pushforward"; inplace=true)) """ @@ -17,8 +17,8 @@ $(docstring_prepare!("pushforward")) function prepare!_pushforward end """ - prepare_pushforward_same_point(f, backend, x, tx, [contexts...]; strict=false) -> prep_same - prepare_pushforward_same_point(f!, y, backend, x, tx, [contexts...]; strict=false) -> prep_same + prepare_pushforward_same_point(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep_same + prepare_pushforward_same_point(f!, y, backend, x, tx, [contexts...]; strict=Val(false)) -> prep_same $(docstring_prepare("pushforward"; samepoint=true, inplace=true)) """ @@ -86,6 +86,7 @@ function pushforward! 
end ## Preparation struct PullbackPushforwardPrep{SIG,E} <: PushforwardPrep{SIG} + _sig::Val{SIG} pullback_prep::E end @@ -95,7 +96,7 @@ function prepare_pushforward( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} return _prepare_pushforward_aux( pushforward_performance(backend), f, backend, x, tx, contexts...; strict @@ -109,7 +110,7 @@ function prepare_pushforward( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} return _prepare_pushforward_aux( pushforward_performance(backend), f!, y, backend, x, tx, contexts...; strict @@ -123,13 +124,13 @@ function _prepare_pushforward_aux( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool, + strict::Val, ) where {F,C} - SIG = signature(f, backend, x, tx, contexts...; strict) + _sig = signature(f, backend, x, tx, contexts...; strict) y = f(x, map(unwrap, contexts)...) dy = y isa Number ? one(y) : basis(y, first(CartesianIndices(y))) pullback_prep = prepare_pullback(f, backend, x, (dy,), contexts...; strict) - return PullbackPushforwardPrep{SIG,typeof(pullback_prep)}(pullback_prep) + return PullbackPushforwardPrep(_sig, pullback_prep) end function _prepare_pushforward_aux( @@ -140,12 +141,12 @@ function _prepare_pushforward_aux( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool, + strict::Val, ) where {F,C} - SIG = signature(f!, y, backend, x, tx, contexts...; strict) + _sig = signature(f!, y, backend, x, tx, contexts...; strict) dy = y isa Number ? 
one(y) : basis(y, first(CartesianIndices(y))) pullback_prep = prepare_pullback(f!, y, backend, x, (dy,), contexts...; strict) - return PullbackPushforwardPrep{SIG,typeof(pullback_prep)}(pullback_prep) + return PullbackPushforwardPrep(_sig, pullback_prep) end ## One argument diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index 9fdbeeb52..492fcae6d 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -37,6 +37,7 @@ function threshold_batchsize( end struct FromPrimitivePushforwardPrep{SIG,E<:PushforwardPrep} <: PushforwardPrep{SIG} + _sig::Val{SIG} pushforward_prep::E end @@ -46,11 +47,11 @@ function prepare_pushforward( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = signature(f, backend, x, tx, contexts...; strict) + _sig = signature(f, backend, x, tx, contexts...; strict) primitive_prep = prepare_pushforward(f, backend.backend, x, tx, contexts...; strict) - return FromPrimitivePushforwardPrep{SIG,typeof(primitive_prep)}(primitive_prep) + return FromPrimitivePushforwardPrep(_sig, primitive_prep) end function prepare_pushforward( @@ -60,11 +61,11 @@ function prepare_pushforward( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = signature(f!, y, backend, x, tx, contexts...; strict) + _sig = signature(f!, y, backend, x, tx, contexts...; strict) primitive_prep = prepare_pushforward(f!, y, backend.backend, x, tx, contexts...; strict) - return FromPrimitivePushforwardPrep{SIG,typeof(primitive_prep)}(primitive_prep) + return FromPrimitivePushforwardPrep(_sig, primitive_prep) end function value_and_pushforward( @@ -153,6 +154,7 @@ function threshold_batchsize( end struct FromPrimitivePullbackPrep{SIG,E<:PullbackPrep} <: PullbackPrep{SIG} + _sig::Val{SIG} pullback_prep::E end @@ -162,11 +164,11 @@ 
function prepare_pullback( x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = signature(f, backend, x, ty, contexts...; strict) + _sig = signature(f, backend, x, ty, contexts...; strict) primitive_prep = prepare_pullback(f, backend.backend, x, ty, contexts...; strict) - return FromPrimitivePullbackPrep{SIG,typeof(primitive_prep)}(primitive_prep) + return FromPrimitivePullbackPrep(_sig, primitive_prep) end function prepare_pullback( @@ -176,11 +178,11 @@ function prepare_pullback( x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = signature(f!, y, backend, x, ty, contexts...; strict) + _sig = signature(f!, y, backend, x, ty, contexts...; strict) primitive_prep = prepare_pullback(f!, y, backend.backend, x, ty, contexts...; strict) - return FromPrimitivePullbackPrep{SIG,typeof(primitive_prep)}(primitive_prep) + return FromPrimitivePullbackPrep(_sig, primitive_prep) end function value_and_pullback( diff --git a/DifferentiationInterface/src/misc/simple_finite_diff.jl b/DifferentiationInterface/src/misc/simple_finite_diff.jl index b34053474..65cbd3c81 100644 --- a/DifferentiationInterface/src/misc/simple_finite_diff.jl +++ b/DifferentiationInterface/src/misc/simple_finite_diff.jl @@ -42,10 +42,10 @@ function prepare_pushforward( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = signature(f, backend, x, tx, contexts...; strict) - return NoPushforwardPrep{SIG}() + _sig = signature(f, backend, x, tx, contexts...; strict) + return NoPushforwardPrep(_sig) end function prepare_pushforward( @@ -55,10 +55,10 @@ function prepare_pushforward( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = signature(f!, y, backend, x, tx, contexts...; strict) - return NoPushforwardPrep{SIG}() + _sig = signature(f!, y, backend, x, tx, contexts...; 
strict) + return NoPushforwardPrep(_sig) end function value_and_pushforward( diff --git a/DifferentiationInterface/src/misc/zero_backends.jl b/DifferentiationInterface/src/misc/zero_backends.jl index 5dcbf47da..94b64fa4c 100644 --- a/DifferentiationInterface/src/misc/zero_backends.jl +++ b/DifferentiationInterface/src/misc/zero_backends.jl @@ -26,10 +26,10 @@ function prepare_pushforward( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = signature(f, backend, x, tx, contexts...; strict) - return NoPushforwardPrep{SIG}() + _sig = signature(f, backend, x, tx, contexts...; strict) + return NoPushforwardPrep(_sig) end function prepare_pushforward( @@ -39,10 +39,10 @@ function prepare_pushforward( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = signature(f!, y, backend, x, tx, contexts...; strict) - return NoPushforwardPrep{SIG}() + _sig = signature(f!, y, backend, x, tx, contexts...; strict) + return NoPushforwardPrep(_sig) end function value_and_pushforward( @@ -129,10 +129,10 @@ function prepare_pullback( x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = signature(f, backend, x, ty, contexts...; strict) - return NoPullbackPrep{SIG}() + _sig = signature(f, backend, x, ty, contexts...; strict) + return NoPullbackPrep(_sig) end function prepare_pullback( @@ -142,10 +142,10 @@ function prepare_pullback( x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} - SIG = signature(f!, y, backend, x, ty, contexts...; strict) - return NoPullbackPrep{SIG}() + _sig = signature(f!, y, backend, x, ty, contexts...; strict) + return NoPullbackPrep(_sig) end function value_and_pullback( diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index a4cf03b0b..2d9e8c5ac 100644 --- 
a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -1,7 +1,7 @@ ## Docstrings """ - prepare_hessian(f, backend, x, [contexts...]; strict=false) -> prep + prepare_hessian(f, backend, x, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("hessian")) """ @@ -60,7 +60,7 @@ struct HVPGradientHessianPrep{ E2<:HVPPrep, E1<:GradientPrep, } <: HessianPrep{SIG} - _sig::Type{SIG} + _sig::Val{SIG} batch_size_settings::BS batched_seeds::S batched_results::R @@ -69,7 +69,7 @@ struct HVPGradientHessianPrep{ end function prepare_hessian( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool=false + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) ) where {F,C} # type-unstable batch_size_settings = pick_batchsize(outer(backend), x) @@ -83,9 +83,9 @@ function _prepare_hessian_aux( backend::AbstractADType, x, contexts::Vararg{Context,C}; - strict::Bool, + strict::Val, ) where {B,F,C} - SIG = signature(f, backend, x, contexts...; strict) + _sig = signature(f, backend, x, contexts...; strict) (; N, A) = batch_size_settings seeds = [basis(x, ind) for ind in eachindex(x)] batched_seeds = [ @@ -95,7 +95,7 @@ function _prepare_hessian_aux( hvp_prep = prepare_hvp(f, backend, x, batched_seeds[1], contexts...; strict) gradient_prep = prepare_gradient(f, inner(backend), x, contexts...; strict) return HVPGradientHessianPrep( - SIG, batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep + _sig, batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep ) end diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index b3766dfe4..7b9d2ca20 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -1,7 +1,7 @@ ## Docstrings """ - prepare_hvp(f, backend, x, tx, [contexts...]; strict=false) -> prep + 
prepare_hvp(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("hvp")) """ @@ -15,7 +15,7 @@ $(docstring_prepare("hvp")) function prepare!_hvp end """ - prepare_hvp_same_point(f, backend, x, tx, [contexts...]; strict=false) -> prep_same + prepare_hvp_same_point(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep_same $(docstring_prepare("hvp"; samepoint=true)) """ @@ -63,7 +63,7 @@ function prepare_hvp( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool=false, + strict::Val=Val(false), ) where {F,C} return _prepare_hvp_aux( hvp_mode(backend), @@ -81,7 +81,7 @@ end struct ForwardOverAnythingHVPPrep{SIG,G,GO,GI,PO,PI} <: HVPPrep{SIG} # pushforward of many pushforwards in theory, but pushforward of gradient in practice - _sig::Type{SIG} + _sig::Val{SIG} grad_buffer::G maybe_inner_gradient_prep::GO maybe_inner_gradient_in_prep::GI @@ -97,9 +97,9 @@ function _prepare_hvp_aux( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool, + strict::Val, ) where {F,C} - SIG = signature(f, backend, x, tx, contexts...; strict) + _sig = signature(f, backend, x, tx, contexts...; strict) grad_buffer = similar(x) rewrap = Rewrap(contexts...) # Outer pushforward @@ -123,7 +123,7 @@ function _prepare_hvp_aux( nothing end return ForwardOverAnythingHVPPrep( - SIG, grad_buffer, (), (), outer_pushforward_prep, outer_pushforward_in_prep + _sig, grad_buffer, (), (), outer_pushforward_prep, outer_pushforward_in_prep ) end @@ -135,9 +135,9 @@ function _prepare_hvp_aux( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool, + strict::Val, ) where {F,C} - SIG = signature(f, backend, x, tx, contexts...; strict) + _sig = signature(f, backend, x, tx, contexts...; strict) grad_buffer = similar(x) rewrap = Rewrap(contexts...) 
# Inner gradient @@ -175,7 +175,7 @@ function _prepare_hvp_aux( nothing end return ForwardOverAnythingHVPPrep( - SIG, + _sig, grad_buffer, (inner_gradient_prep,), (inner_gradient_in_prep,), @@ -192,9 +192,9 @@ function _prepare_hvp_aux( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool, + strict::Val, ) where {F,C} - SIG = signature(f, backend, x, tx, contexts...; strict) + _sig = signature(f, backend, x, tx, contexts...; strict) grad_buffer = similar(x) rewrap = Rewrap(contexts...) # Inner gradient @@ -236,7 +236,7 @@ function _prepare_hvp_aux( nothing end return ForwardOverAnythingHVPPrep( - SIG, + _sig, grad_buffer, (inner_gradient_prep,), (inner_gradient_in_prep,), @@ -452,7 +452,7 @@ end struct ReverseOverForwardHVPPrep{SIG,G2<:GradientPrep,G1<:GradientPrep} <: HVPPrep{SIG} # gradient of pushforward - _sig::Type{SIG} + _sig::Val{SIG} outer_gradient_prep::G2 gradient_prep::G1 end @@ -465,9 +465,9 @@ function _prepare_hvp_aux( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool, + strict::Val, ) where {F,C} - SIG = signature(f, backend, x, tx, contexts...; strict) + _sig = signature(f, backend, x, tx, contexts...; strict) rewrap = Rewrap(contexts...) 
new_contexts = ( FunctionContext(f), @@ -480,7 +480,7 @@ function _prepare_hvp_aux( shuffled_single_pushforward, outer(backend), x, new_contexts...; strict ) gradient_prep = prepare_gradient(f, inner(backend), x, contexts...; strict) - return ReverseOverForwardHVPPrep(SIG, outer_gradient_prep, gradient_prep) + return ReverseOverForwardHVPPrep(_sig, outer_gradient_prep, gradient_prep) end function hvp( @@ -573,7 +573,7 @@ end struct ReverseOverReverseHVPPrep{SIG,G,PO,PI} <: HVPPrep{SIG} # pullback of gradient - _sig::Type{SIG} + _sig::Val{SIG} grad_buffer::G outer_pullback_prep::PO outer_pullback_in_prep::PI @@ -587,9 +587,9 @@ function _prepare_hvp_aux( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Bool, + strict::Val, ) where {F,C} - SIG = signature(f, backend, x, tx, contexts...; strict) + _sig = signature(f, backend, x, tx, contexts...; strict) rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... @@ -612,7 +612,7 @@ function _prepare_hvp_aux( nothing end return ReverseOverReverseHVPPrep( - SIG, grad_buffer, outer_pullback_prep, outer_pullback_in_prep + _sig, grad_buffer, outer_pullback_prep, outer_pullback_in_prep ) end diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index 6cda162a8..72f4b8855 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -1,7 +1,7 @@ ## Docstrings """ - prepare_second_derivative(f, backend, x, [contexts...]; strict=false) -> prep + prepare_second_derivative(f, backend, x, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("second_derivative")) """ @@ -53,13 +53,14 @@ function value_derivative_and_second_derivative! 
end ## Preparation struct DerivativeSecondDerivativePrep{SIG,E<:DerivativePrep} <: SecondDerivativePrep{SIG} + _sig::Val{SIG} outer_derivative_prep::E end function prepare_second_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool=false + f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) ) where {F,C} - SIG = signature(f, backend, x, contexts...; strict) + _sig = signature(f, backend, x, contexts...; strict) rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... @@ -67,9 +68,7 @@ function prepare_second_derivative( outer_derivative_prep = prepare_derivative( shuffled_derivative, outer(backend), x, new_contexts...; strict ) - return DerivativeSecondDerivativePrep{SIG,typeof(outer_derivative_prep)}( - outer_derivative_prep - ) + return DerivativeSecondDerivativePrep(_sig, outer_derivative_prep) end ## One argument diff --git a/DifferentiationInterface/src/utils/prep.jl b/DifferentiationInterface/src/utils/prep.jl index bfe6fb66a..61e2a75a2 100644 --- a/DifferentiationInterface/src/utils/prep.jl +++ b/DifferentiationInterface/src/utils/prep.jl @@ -4,53 +4,78 @@ abstract type Prep{SIG} end $(docstring_preptype("PushforwardPrep", "pushforward")) """ abstract type PushforwardPrep{SIG} <: Prep{SIG} end -struct NoPushforwardPrep{SIG} <: PushforwardPrep{SIG} end + +struct NoPushforwardPrep{SIG} <: PushforwardPrep{SIG} + _sig::Val{SIG} +end """ $(docstring_preptype("PullbackPrep", "pullback")) """ abstract type PullbackPrep{SIG} <: Prep{SIG} end -struct NoPullbackPrep{SIG} <: PullbackPrep{SIG} end + +struct NoPullbackPrep{SIG} <: PullbackPrep{SIG} + _sig::Val{SIG} +end """ $(docstring_preptype("DerivativePrep", "derivative")) """ abstract type DerivativePrep{SIG} <: Prep{SIG} end -struct NoDerivativePrep{SIG} <: DerivativePrep{SIG} end + +struct NoDerivativePrep{SIG} <: DerivativePrep{SIG} + _sig::Val{SIG} +end """ 
$(docstring_preptype("GradientPrep", "gradient")) """ abstract type GradientPrep{SIG} <: Prep{SIG} end -struct NoGradientPrep{SIG} <: GradientPrep{SIG} end + +struct NoGradientPrep{SIG} <: GradientPrep{SIG} + _sig::Val{SIG} +end """ $(docstring_preptype("JacobianPrep", "jacobian")) """ abstract type JacobianPrep{SIG} <: Prep{SIG} end -struct NoJacobianPrep{SIG} <: JacobianPrep{SIG} end + +struct NoJacobianPrep{SIG} <: JacobianPrep{SIG} + _sig::Val{SIG} +end """ $(docstring_preptype("HVPPrep", "hvp")) """ abstract type HVPPrep{SIG} <: Prep{SIG} end -struct NoHVPPrep{SIG} <: HVPPrep{SIG} end + +struct NoHVPPrep{SIG} <: HVPPrep{SIG} + _sig::Val{SIG} +end """ $(docstring_preptype("HessianPrep", "hessian")) """ abstract type HessianPrep{SIG} <: Prep{SIG} end -struct NoHessianPrep{SIG} <: HessianPrep{SIG} end + +struct NoHessianPrep{SIG} <: HessianPrep{SIG} + _sig::Val{SIG} +end """ $(docstring_preptype("SecondDerivativePrep", "second_derivative")) """ abstract type SecondDerivativePrep{SIG} <: Prep{SIG} end -struct NoSecondDerivativePrep{SIG} <: SecondDerivativePrep{SIG} end + +struct NoSecondDerivativePrep{SIG} <: SecondDerivativePrep{SIG} + _sig::Val{SIG} +end ## Checks -is_strict(::Prep{SIG}) where {SIG} = SIG !== Nothing +is_strict(::Prep{Nothing}) = Val(false) +is_strict(::Prep) = Val(true) function inconsistent_signatures_error(SIG, RUNTIME_SIG) msg = """ @@ -62,42 +87,48 @@ function inconsistent_signatures_error(SIG, RUNTIME_SIG) end function signature( - f, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool -) where {C} - if strict - return typeof((f, backend, x, contexts)) + f, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val{S} +) where {C,S} + if S + return Val(typeof((f, backend, x, contexts))) else - return Nothing + return Val(Nothing) end end function signature( - f!, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool -) where {C} - if strict - return typeof((f!, y, backend, x, contexts)) + f!, 
y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val{S} +) where {C,S} + if S + return Val(typeof((f!, y, backend, x, contexts))) else - return Nothing + return Val(Nothing) end end function signature( - f, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C}; strict::Bool -) where {C} - if strict - return typeof((f, backend, x, t, contexts)) + f, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C}; strict::Val{S} +) where {C,S} + if S + return Val(typeof((f, backend, x, t, contexts))) else - return Nothing + return Val(Nothing) end end function signature( - f!, y, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C}; strict::Bool -) where {C} - if strict - return typeof((f!, y, backend, x, t, contexts)) + f!, + y, + backend::AbstractADType, + x, + t::NTuple, + contexts::Vararg{Context,C}; + strict::Val{S}, +) where {C,S} + if S + return Val(typeof((f!, y, backend, x, t, contexts))) else - return Nothing + return Val(Nothing) end end diff --git a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl index e0ad40df2..2b067c9ea 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl @@ -67,7 +67,7 @@ for op in ALL_OPS ba, new_smaller.x, new_smaller.contexts...; - strict=true, + strict=Val(true), ), ba, xrand, @@ -132,7 +132,7 @@ for op in ALL_OPS ba, new_smaller.x, new_smaller.contexts...; - strict=true, + strict=Val(true), ), ba, xrand, @@ -213,7 +213,7 @@ for op in ALL_OPS ba, new_smaller.x, new_smaller.contexts...; - strict=true, + strict=Val(true), ), ba, xrand, @@ -286,7 +286,7 @@ for op in ALL_OPS ba, new_smaller.x, new_smaller.contexts...; - strict=true, + strict=Val(true), ), ba, xrand, @@ -366,7 +366,7 @@ for op in ALL_OPS ba, new_smaller.x, new_smaller.contexts...; - strict=true, + strict=Val(true), ), ba, xrand, @@ -433,7 +433,7 @@ for op in ALL_OPS 
ba, new_smaller.x, new_smaller.contexts...; - strict=true, + strict=Val(true), ), ba, xrand, @@ -516,7 +516,7 @@ for op in ALL_OPS new_smaller.x, new_smaller.tang, new_smaller.contexts...; - strict=true, + strict=Val(true), ), ba, xrand, @@ -580,7 +580,7 @@ for op in ALL_OPS new_smaller.x, new_smaller.tang, new_smaller.contexts...; - strict=true, + strict=Val(true), ), ba, xrand, @@ -658,7 +658,7 @@ for op in ALL_OPS new_smaller.x, new_smaller.tang, new_smaller.contexts...; - strict=true, + strict=Val(true), ), ba, xrand, @@ -734,7 +734,7 @@ for op in ALL_OPS new_smaller.x, new_smaller.tang, new_smaller.contexts...; - strict=true, + strict=Val(true), ), ba, xrand, @@ -827,7 +827,7 @@ for op in ALL_OPS new_smaller.x, new_smaller.tang, new_smaller.contexts...; - strict=true, + strict=Val(true), ), ba, xrand, @@ -891,7 +891,7 @@ for op in ALL_OPS new_smaller.x, new_smaller.tang, new_smaller.contexts...; - strict=true, + strict=Val(true), ), ba, xrand, From 50d377005629dbf8e9a3137a5643c86e6b5da601 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 16 Mar 2025 22:57:39 +0100 Subject: [PATCH 07/22] Fixes --- .../ext/DifferentiationInterfaceReverseDiffExt/onearg.jl | 2 +- .../hessian.jl | 6 +++--- .../jacobian.jl | 8 ++++---- .../jacobian_mixed.jl | 8 +++++--- .../DifferentiationInterfaceTrackerExt.jl | 2 +- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index 2a98c70ef..4e094a306 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -87,7 +87,7 @@ end function DI.prepare_gradient( f, backend::AutoReverseDiff{compile}, x; strict::Val=Val(false) ) where {compile} - _sig = DI.signature(f, backend, x) + _sig = DI.signature(f, 
backend, x; strict) if compile tape = ReverseDiff.compile(GradientTape(f, x)) return ReverseDiffGradientPrep(_sig, nothing, tape) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index 1bc4d2b7d..487ad770f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -62,10 +62,10 @@ function _prepare_sparse_hessian_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - hvp_prep = DI.prepare_hvp(f, dense_backend, x, batched_seeds[1], contexts...) - gradient_prep = DI.prepare_gradient(f, DI.inner(dense_backend), x, contexts...) + hvp_prep = DI.prepare_hvp(f, dense_backend, x, batched_seeds[1], contexts...; strict) + gradient_prep = DI.prepare_gradient(f, DI.inner(dense_backend), x, contexts...; strict) return SparseHessianPrep( - SIG, + _sig, batch_size_settings, coloring_result, compressed_matrix, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 59802f066..82a8e540d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -109,10 +109,10 @@ function _prepare_sparse_jacobian_aux_aux( ] batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds] pushforward_prep = DI.prepare_pushforward( - f_or_f!y..., dense_backend, x, batched_seeds[1], contexts... 
+ f_or_f!y..., dense_backend, x, batched_seeds[1], contexts...; strict ) return PushforwardSparseJacobianPrep( - SIG, + _sig, batch_size_settings, coloring_result, compressed_matrix, @@ -143,10 +143,10 @@ function _prepare_sparse_jacobian_aux_aux( ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] pullback_prep = DI.prepare_pullback( - f_or_f!y..., dense_backend, x, batched_seeds[1], contexts... + f_or_f!y..., dense_backend, x, batched_seeds[1], contexts...; strict ) return PullbackSparseJacobianPrep( - SIG, + _sig, batch_size_settings, coloring_result, compressed_matrix, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl index ed8585ad8..a8447c56b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl @@ -132,18 +132,20 @@ function _prepare_mixed_sparse_jacobian_aux_aux( DI.forward_backend(dense_backend), x, batched_seeds_forward[1], - contexts..., + contexts...; + strict, ) pullback_prep = DI.prepare_pullback( f_or_f!y..., DI.reverse_backend(dense_backend), x, batched_seeds_reverse[1], - contexts..., + contexts...; + strict, ) return MixedModeSparseJacobianPrep( - SIG, + _sig, batch_size_settings_forward, batch_size_settings_reverse, coloring_result, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index a2ad45c7e..c0ab6be0c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -35,7 +35,7 @@ 
function DI.prepare_pullback_same_point( ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} - _sig = DI.signature(f, prep, backend, x, ty, contexts...) + _sig = DI.signature(f, prep, backend, x, ty, contexts...; strict) DI.check_prep(f, prep, backend, x, ty, contexts...) y, pb = forward(f, x, map(DI.unwrap, contexts)...) return TrackerPullbackPrepSamePoint(_sig, y, pb) From 3424a555e063a421af6ae77287f75e6aecd8f2df Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 16 Mar 2025 23:03:05 +0100 Subject: [PATCH 08/22] Fix --- .../ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl index 95701efc4..6aceb1c29 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl @@ -9,7 +9,7 @@ function DI.prepare_pullback( contexts::Vararg{DI.Context,C}; strict::Val=Val(false), ) where {C} - _sig = DI.signature(f!, y, x, ty, contexts...; strict) + _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) return DI.NoPullbackPrep(_sig) end From e05f4fda126bd6e2c2609f270c77da4e9f96d04f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 16 Mar 2025 23:13:22 +0100 Subject: [PATCH 09/22] Fixes --- .../onearg.jl | 3 +-- .../DifferentiationInterfaceTrackerExt.jl | 2 +- DifferentiationInterface/src/utils/prep.jl | 20 +++++++++++++------ 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index 9aae16830..f3dbcaa47 100644 --- 
a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -1,8 +1,7 @@ ## Pushforward -struct PolyesterForwardDiffOneArgPushforwardPrep{SIG,P} <: - PolyesterForwardDiffOneArgPushforwardPrep{SIG} +struct PolyesterForwardDiffOneArgPushforwardPrep{SIG,P} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} single_threaded_prep::P end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index c0ab6be0c..18c06b0bc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -35,7 +35,7 @@ function DI.prepare_pullback_same_point( ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} - _sig = DI.signature(f, prep, backend, x, ty, contexts...; strict) + _sig = DI.signature(f, prep, backend, x, ty, contexts...; strict=DI.is_strict(prep)) DI.check_prep(f, prep, backend, x, ty, contexts...) y, pb = forward(f, x, map(DI.unwrap, contexts)...) 
return TrackerPullbackPrepSamePoint(_sig, y, pb) diff --git a/DifferentiationInterface/src/utils/prep.jl b/DifferentiationInterface/src/utils/prep.jl index 61e2a75a2..746e8690c 100644 --- a/DifferentiationInterface/src/utils/prep.jl +++ b/DifferentiationInterface/src/utils/prep.jl @@ -77,13 +77,21 @@ end is_strict(::Prep{Nothing}) = Val(false) is_strict(::Prep) = Val(true) -function inconsistent_signatures_error(SIG, RUNTIME_SIG) +struct PreparationMismatchError{SIG,RUNTIME_SIG} <: Exception end + +function PreparationMismatchError(::Type{SIG}, ::Type{RUNTIME_SIG}) where {SIG,RUNTIME_SIG} + return PreparationMismatchError{SIG,RUNTIME_SIG}() +end + +function Base.showerror( + io::IO, e::PreparationMismatchError{SIG,RUNTIME_SIG} +) where {SIG,RUNTIME_SIG} msg = """ Inconsistent signatures: - at preparation time: $SIG - at execution time: $RUNTIME_SIG """ - return ArgumentError(msg) + return print(io, msg) end function signature( @@ -138,7 +146,7 @@ function check_prep( if SIG !== Nothing RUNTIME_SIG = typeof((f, backend, x, contexts)) if SIG != RUNTIME_SIG - throw(inconsistent_signatures_error(SIG, RUNTIME_SIG)) + throw(PreparationMismatchError(SIG, RUNTIME_SIG)) end end end @@ -149,7 +157,7 @@ function check_prep( if SIG !== Nothing RUNTIME_SIG = typeof((f!, y, backend, x, contexts)) if SIG != RUNTIME_SIG - throw(inconsistent_signatures_error(SIG, RUNTIME_SIG)) + throw(PreparationMismatchError(SIG, RUNTIME_SIG)) end end end @@ -160,7 +168,7 @@ function check_prep( if SIG !== Nothing RUNTIME_SIG = typeof((f, backend, x, t, contexts)) if SIG != RUNTIME_SIG - throw(inconsistent_signatures_error(SIG, RUNTIME_SIG)) + throw(PreparationMismatchError(SIG, RUNTIME_SIG)) end end end @@ -171,7 +179,7 @@ function check_prep( if SIG !== Nothing RUNTIME_SIG = typeof((f!, y, backend, x, t, contexts)) if SIG != RUNTIME_SIG - throw(inconsistent_signatures_error(SIG, RUNTIME_SIG)) + throw(PreparationMismatchError(SIG, RUNTIME_SIG)) end end end From 
47d0b12e27f65148c724521f6679aa1e414a2fe6 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 16 Mar 2025 23:28:37 +0100 Subject: [PATCH 10/22] Fix caches --- DifferentiationInterface/src/second_order/hvp.jl | 6 ++++-- DifferentiationInterface/src/utils/context.jl | 3 +++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 7b9d2ca20..3a13c6bf0 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -202,8 +202,10 @@ function _prepare_hvp_aux( xoi = overloaded_input( pushforward, shuffled_gradient!, grad_buffer, outer(backend), x, tx ) - inner_gradient_prep = prepare_gradient(f, inner(backend), xo, contexts...; strict) - inner_gradient_in_prep = prepare_gradient(f, inner(backend), xoi, contexts...; strict) + contextso = adapt_eltype.(contexts, Ref(eltype(xo))) + contextsoi = adapt_eltype.(contexts, Ref(eltype(xoi))) + inner_gradient_prep = prepare_gradient(f, inner(backend), xo, contextso...; strict) + inner_gradient_in_prep = prepare_gradient(f, inner(backend), xoi, contextsoi...; strict) # Outer pushforward new_contexts = ( FunctionContext(f), diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 6d85ec0b9..aae8eb8cd 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -144,3 +144,6 @@ function with_contexts(f::F, contexts::Vararg{Context,N}) where {F,N} tail_args = map(unwrap, contexts) return FixTail(f, tail_args...) 
end + +adapt_eltype(c::Constant, ::Type) where {T} = c +adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(similar(unwrap(c), T)) From be6caff60de9e932a6f4b05c96a638ea9fe15a5c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 16 Mar 2025 23:32:00 +0100 Subject: [PATCH 11/22] Fix --- DifferentiationInterface/src/utils/context.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index aae8eb8cd..65edbde0f 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -145,5 +145,5 @@ function with_contexts(f::F, contexts::Vararg{Context,N}) where {F,N} return FixTail(f, tail_args...) end -adapt_eltype(c::Constant, ::Type) where {T} = c +adapt_eltype(c::Constant, ::Type) = c adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(similar(unwrap(c), T)) From 031fb5239b0df9d46db1561ceb919e95039d0597 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 17 Mar 2025 08:08:23 +0100 Subject: [PATCH 12/22] Fixes --- .../forward_onearg.jl | 8 +- .../twoarg.jl | 9 +- .../onearg.jl | 2 +- .../onearg.jl | 33 ++----- .../twoarg.jl | 26 +++-- .../onearg.jl | 94 +++++++++++++++---- .../twoarg.jl | 43 +++++++-- .../DifferentiationInterfaceTrackerExt.jl | 2 +- DifferentiationInterface/src/utils/prep.jl | 53 ++++++++--- 9 files changed, 189 insertions(+), 81 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 13322d64e..d4d7b2eb0 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -224,7 +224,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Constant,C}, ) where 
{F,B,C} - DI.check_prep(f, prep, backend, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) mode = forward_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(B), contexts...) @@ -242,7 +242,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Constant,C}, ) where {F,B,C} - DI.check_prep(f, prep, backend, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) mode = forward_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(B), contexts...) @@ -261,7 +261,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Constant,C}, ) where {F,C} - DI.check_prep(f, prep, backend, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...)) end @@ -273,7 +273,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Constant,C}, ) where {F,C} - DI.check_prep(f, prep, backend, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...) return y, copyto!(jac, new_jac) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index a4c7aea44..950c41383 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -48,8 +48,13 @@ function compute_ydual_twoarg( tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,SIG,T,B,C} - (; xdual_tmp, ydual_tmp) = prep - make_dual!(T, xdual_tmp, x, tx) + (; ydual_tmp) = prep + if DI.ismutable_array(x) + make_dual!(T, prep.xdual_tmp, x, tx) + xdual_tmp = prep.xdual_tmp + else + xdual_tmp = make_dual(T, x, tx) + end contexts_dual = translate_prepared(contexts, prep.contexts_dual) f!(ydual_tmp, xdual_tmp, contexts_dual...) 
return ydual_tmp diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl index 1e880c395..3b5e92fe5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl @@ -25,7 +25,7 @@ function DI.prepare_pushforward( end if x isa Number xt = TPS{promote_type(typeof(first(tx)), typeof(x), Float64)}(; use=d) - return GTPSAOneArgPushforwardPrep(xt) + return GTPSAOneArgPushforwardPrep(_sig, xt) else xt = similar(x, TPS{promote_type(eltype(first(tx)), eltype(x), Float64)}) for i in eachindex(xt) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index b84b0ccb3..27dbab0b4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -43,33 +43,21 @@ function DI.value_and_pullback( return new_y, (mycopy(new_dx),) end -function DI.value_and_pullback!( - f::F, - tx::NTuple{1}, - prep::MooncakeOneArgPullbackPrep{Y}, - backend::AutoMooncake, - x, - ty::NTuple{1}, - contexts::Vararg{DI.Context,C}, -) where {F,Y,C} - DI.check_prep(f, prep, backend, x, ty, contexts...) - y, (new_dx,) = DI.value_and_pullback(f, prep, backend, x, ty, contexts...) - copyto!(only(tx), new_dx) - return y, tx -end - function DI.value_and_pullback( f::F, - prep::MooncakeOneArgPullbackPrep, + prep::MooncakeOneArgPullbackPrep{Y}, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C}, -) where {F,C} +) where {F,Y,C} DI.check_prep(f, prep, backend, x, ty, contexts...) ys_and_tx = map(ty) do dy - y, tx = DI.value_and_pullback(f, prep, backend, x, (dy,), contexts...) - y, only(tx) + dy_righttype = dy isa tangent_type(Y) ? 
dy : copyto!!(prep.dy_righttype, dy) + y, (_, new_dx) = value_and_pullback!!( + prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)... + ) + y, mycopy(new_dx) end y = first(ys_and_tx[1]) tx = last.(ys_and_tx) @@ -86,11 +74,8 @@ function DI.value_and_pullback!( contexts::Vararg{DI.Context,C}, ) where {F,C} DI.check_prep(f, prep, backend, x, ty, contexts...) - ys = map(tx, ty) do dx, dy - y, _ = DI.value_and_pullback!(f, (dx,), prep, backend, x, (dy,), contexts...) - y - end - y = ys[1] + y, new_tx = DI.value_and_pullback(f, prep, backend, x, ty, contexts...) + foreach(copyto!, tx, new_tx) return y, tx end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 1f5247706..29558cc55 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -45,14 +45,18 @@ function DI.value_and_pullback( contexts::Vararg{DI.Context,C}, ) where {F,C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) - # Prepare cotangent to add after the forward pass. dy = only(ty) + # Prepare cotangent to add after the forward pass. dy_righttype_after = copyto!(prep.dy_righttype, dy) - # Run the reverse-pass and return the results. - contexts = map(DI.unwrap, contexts) y_after, (_, _, _, dx) = value_and_pullback!!( - prep.cache, dy_righttype_after, prep.target_function, f!, y, x, contexts... + prep.cache, + dy_righttype_after, + prep.target_function, + f!, + y, + x, + map(DI.unwrap, contexts)..., ) copyto!(y, y_after) return y, (mycopy(dx),) @@ -69,8 +73,18 @@ function DI.value_and_pullback( ) where {F,C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) tx = map(ty) do dy - _, tx = DI.value_and_pullback(f!, y, prep, backend, x, (dy,), contexts...) 
- only(tx) + dy_righttype_after = copyto!(prep.dy_righttype, dy) + y_after, (_, _, _, dx) = value_and_pullback!!( + prep.cache, + dy_righttype_after, + prep.target_function, + f!, + y, + x, + map(DI.unwrap, contexts)..., + ) + copyto!(y, y_after) + mycopy(dx) end return y, tx end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index dd0089969..6009cd963 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -1,13 +1,20 @@ ## Pushforward -struct SymbolicsOneArgPushforwardPrep{E1,E1!} <: DI.PushforwardPrep +struct SymbolicsOneArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} pf_exe::E1 pf_exe!::E1! end function DI.prepare_pushforward( - f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Val=Val(false), ) where {C} + _sig = DI.signature(f, backend, x, tx, contexts...; strict) dx = first(tx) x_var = variablize(x, :x) dx_var = variablize(dx, :dx) @@ -22,7 +29,7 @@ function DI.prepare_pushforward( elseif res isa RuntimeGeneratedFunction res, nothing end - return SymbolicsOneArgPushforwardPrep(pf_exe, pf_exe!) + return SymbolicsOneArgPushforwardPrep(_sig, pf_exe, pf_exe!) end function DI.pushforward( @@ -33,6 +40,7 @@ function DI.pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) ty = map(tx) do dx dy = prep.pf_exe(x, dx, map(DI.unwrap, contexts)...) end @@ -48,6 +56,7 @@ function DI.pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] prep.pf_exe!(dy, x, dx, map(DI.unwrap, contexts)...) 
@@ -63,6 +72,7 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pushforward(f, prep, backend, x, tx, contexts...) end @@ -76,20 +86,23 @@ function DI.value_and_pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pushforward!(f, ty, prep, backend, x, tx, contexts...) end ## Derivative -struct SymbolicsOneArgDerivativePrep{E1,E1!} <: DI.DerivativePrep +struct SymbolicsOneArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} + _sig::Val{SIG} der_exe::E1 der_exe!::E1! end function DI.prepare_derivative( - f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} + f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) der_var = derivative(f(x_var, context_vars...), x_var) @@ -100,7 +113,7 @@ function DI.prepare_derivative( elseif res isa RuntimeGeneratedFunction res, nothing end - return SymbolicsOneArgDerivativePrep(der_exe, der_exe!) + return SymbolicsOneArgDerivativePrep(_sig, der_exe, der_exe!) end function DI.derivative( @@ -110,6 +123,7 @@ function DI.derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return prep.der_exe(x, map(DI.unwrap, contexts)...) end @@ -121,6 +135,7 @@ function DI.derivative!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.der_exe!(der, x, map(DI.unwrap, contexts)...) return der end @@ -132,6 +147,7 @@ function DI.value_and_derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
return f(x, map(DI.unwrap, contexts)...), DI.derivative(f, prep, backend, x, contexts...) end @@ -144,20 +160,23 @@ function DI.value_and_derivative!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.derivative!(f, der, prep, backend, x, contexts...) end ## Gradient -struct SymbolicsOneArgGradientPrep{E1,E1!} <: DI.GradientPrep +struct SymbolicsOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG} + _sig::Val{SIG} grad_exe::E1 grad_exe!::E1! end function DI.prepare_gradient( - f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} + f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) # Symbolic.gradient only accepts vectors @@ -165,7 +184,7 @@ function DI.prepare_gradient( res = build_function(grad_var, vec(x_var), context_vars...; expression=Val(false)) (grad_exe, grad_exe!) = res - return SymbolicsOneArgGradientPrep(grad_exe, grad_exe!) + return SymbolicsOneArgGradientPrep(_sig, grad_exe, grad_exe!) end function DI.gradient( @@ -175,6 +194,7 @@ function DI.gradient( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return reshape(prep.grad_exe(vec(x), map(DI.unwrap, contexts)...), size(x)) end @@ -186,6 +206,7 @@ function DI.gradient!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.grad_exe!(vec(grad), vec(x), map(DI.unwrap, contexts)...) return grad end @@ -197,6 +218,7 @@ function DI.value_and_gradient( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.gradient(f, prep, backend, x, contexts...) 
end @@ -208,13 +230,15 @@ function DI.value_and_gradient!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.gradient!(f, grad, prep, backend, x, contexts...) end ## Jacobian -struct SymbolicsOneArgJacobianPrep{E1,E1!} <: DI.JacobianPrep +struct SymbolicsOneArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} jac_exe::E1 jac_exe!::E1! end @@ -224,7 +248,9 @@ function DI.prepare_jacobian( backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}, + strict::Val=Val(false), ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) jac_var = if backend isa AutoSparse @@ -235,7 +261,7 @@ function DI.prepare_jacobian( res = build_function(jac_var, x_var, context_vars...; expression=Val(false)) (jac_exe, jac_exe!) = res - return SymbolicsOneArgJacobianPrep(jac_exe, jac_exe!) + return SymbolicsOneArgJacobianPrep(_sig, jac_exe, jac_exe!) end function DI.jacobian( @@ -245,6 +271,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, backend, x, contexts...) return prep.jac_exe(x, map(DI.unwrap, contexts)...) end @@ -256,6 +283,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, backend, x, contexts...) prep.jac_exe!(jac, x, map(DI.unwrap, contexts)...) return jac end @@ -267,6 +295,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...) end @@ -278,13 +307,15 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian!(f, jac, prep, backend, x, contexts...) 
end ## Hessian -struct SymbolicsOneArgHessianPrep{G,E2,E2!} <: DI.HessianPrep +struct SymbolicsOneArgHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG} + _sig::Val{SIG} gradient_prep::G hess_exe::E2 hess_exe!::E2! @@ -294,8 +325,10 @@ function DI.prepare_hessian( f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; + strict::Val=Val(false), ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) # Symbolic.hessian only accepts vectors @@ -309,7 +342,7 @@ function DI.prepare_hessian( (hess_exe, hess_exe!) = res gradient_prep = DI.prepare_gradient(f, dense_ad(backend), x, contexts...) - return SymbolicsOneArgHessianPrep(gradient_prep, hess_exe, hess_exe!) + return SymbolicsOneArgHessianPrep(_sig, gradient_prep, hess_exe, hess_exe!) end function DI.hessian( @@ -319,6 +352,7 @@ function DI.hessian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return prep.hess_exe(vec(x), map(DI.unwrap, contexts)...) end @@ -330,6 +364,7 @@ function DI.hessian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.hess_exe!(hess, vec(x), map(DI.unwrap, contexts)...) return hess end @@ -341,6 +376,7 @@ function DI.value_gradient_and_hessian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, grad = DI.value_and_gradient( f, prep.gradient_prep, dense_ad(backend), x, contexts... ) @@ -357,6 +393,7 @@ function DI.value_gradient_and_hessian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_gradient!( f, grad, prep.gradient_prep, dense_ad(backend), x, contexts... 
) @@ -366,15 +403,22 @@ end ## HVP -struct SymbolicsOneArgHVPPrep{G,E2,E2!} <: DI.HVPPrep +struct SymbolicsOneArgHVPPrep{SIG,G,E2,E2!} <: DI.HVPPrep{SIG} + _sig::Val{SIG} gradient_prep::G hvp_exe::E2 hvp_exe!::E2! end function DI.prepare_hvp( - f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Val=Val(false), ) where {C} + _sig = DI.signature(f, backend, x, tx, contexts...; strict) dx = first(tx) x_var = variablize(x, :x) dx_var = variablize(dx, :dx) @@ -389,7 +433,7 @@ function DI.prepare_hvp( (hvp_exe, hvp_exe!) = res gradient_prep = DI.prepare_gradient(f, backend, x, contexts...) - return SymbolicsOneArgHVPPrep(gradient_prep, hvp_exe, hvp_exe!) + return SymbolicsOneArgHVPPrep(_sig, gradient_prep, hvp_exe, hvp_exe!) end function DI.hvp( @@ -400,6 +444,7 @@ function DI.hvp( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return map(tx) do dx dg_vec = prep.hvp_exe(vec(x), vec(dx), map(DI.unwrap, contexts)...) reshape(dg_vec, size(x)) @@ -415,6 +460,7 @@ function DI.hvp!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) for b in eachindex(tx, tg) dx, dg = tx[b], tg[b] prep.hvp_exe!(vec(dg), vec(x), vec(dx), map(DI.unwrap, contexts)...) @@ -430,6 +476,7 @@ function DI.gradient_and_hvp( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) tg = DI.hvp(f, prep, backend, x, tx, contexts...) grad = DI.gradient(f, prep.gradient_prep, backend, x, contexts...) return grad, tg @@ -445,6 +492,7 @@ function DI.gradient_and_hvp!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) DI.hvp!(f, tg, prep, backend, x, tx, contexts...) DI.gradient!(f, grad, prep.gradient_prep, backend, x, contexts...) 
return grad, tg @@ -452,15 +500,17 @@ end ## Second derivative -struct SymbolicsOneArgSecondDerivativePrep{D,E1,E1!} <: DI.SecondDerivativePrep +struct SymbolicsOneArgSecondDerivativePrep{SIG,D,E1,E1!} <: DI.SecondDerivativePrep{SIG} + _sig::Val{SIG} derivative_prep::D der2_exe::E1 der2_exe!::E1! end function DI.prepare_second_derivative( - f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} + f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) der_var = derivative(f(x_var, context_vars...), x_var) @@ -473,7 +523,7 @@ function DI.prepare_second_derivative( res, nothing end derivative_prep = DI.prepare_derivative(f, backend, x, contexts...) - return SymbolicsOneArgSecondDerivativePrep(derivative_prep, der2_exe, der2_exe!) + return SymbolicsOneArgSecondDerivativePrep(_sig, derivative_prep, der2_exe, der2_exe!) end function DI.second_derivative( @@ -483,6 +533,7 @@ function DI.second_derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return prep.der2_exe(x, map(DI.unwrap, contexts)...) end @@ -494,6 +545,7 @@ function DI.second_derivative!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.der2_exe!(der2, x, map(DI.unwrap, contexts)...) return der2 end @@ -505,6 +557,7 @@ function DI.value_derivative_and_second_derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, der = DI.value_and_derivative(f, prep.derivative_prep, backend, x, contexts...) der2 = DI.second_derivative(f, prep, backend, x, contexts...) return y, der, der2 @@ -519,6 +572,7 @@ function DI.value_derivative_and_second_derivative!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
y, _ = DI.value_and_derivative!(f, der, prep.derivative_prep, backend, x, contexts...) DI.second_derivative!(f, der2, prep, backend, x, contexts...) return y, der, der2 diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl index ab3e90928..eeeb9ed73 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl @@ -1,13 +1,21 @@ ## Pushforward -struct SymbolicsTwoArgPushforwardPrep{E1,E1!} <: DI.PushforwardPrep +struct SymbolicsTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} pushforward_exe::E1 pushforward_exe!::E1! end function DI.prepare_pushforward( - f!, y, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f!, + y, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; + strict::Val=Val(false), ) where {C} + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) dx = first(tx) x_var = variablize(x, :x) dx_var = variablize(dx, :dx) @@ -20,7 +28,7 @@ function DI.prepare_pushforward( res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false)) (pushforward_exe, pushforward_exe!) = res - return SymbolicsTwoArgPushforwardPrep(pushforward_exe, pushforward_exe!) + return SymbolicsTwoArgPushforwardPrep(_sig, pushforward_exe, pushforward_exe!) end function DI.pushforward( @@ -32,6 +40,7 @@ function DI.pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = map(tx) do dx dy = prep.pushforward_exe(x, dx, map(DI.unwrap, contexts)...) end @@ -48,6 +57,7 @@ function DI.pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) 
for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] prep.pushforward_exe!(dy, x, dx, map(DI.unwrap, contexts)...) @@ -64,6 +74,7 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, ty @@ -79,6 +90,7 @@ function DI.value_and_pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, ty @@ -86,14 +98,16 @@ end ## Derivative -struct SymbolicsTwoArgDerivativePrep{E1,E1!} <: DI.DerivativePrep +struct SymbolicsTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} + _sig::Val{SIG} der_exe::E1 der_exe!::E1! end function DI.prepare_derivative( - f!, y, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} + f!, y, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) ) where {C} + _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) y_var = variablize(y, :y) context_vars = variablize(contexts) @@ -102,7 +116,7 @@ function DI.prepare_derivative( res = build_function(der_var, x_var, context_vars...; expression=Val(false)) (der_exe, der_exe!) = res - return SymbolicsTwoArgDerivativePrep(der_exe, der_exe!) + return SymbolicsTwoArgDerivativePrep(_sig, der_exe, der_exe!) end function DI.derivative( @@ -113,6 +127,7 @@ function DI.derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) return prep.der_exe(x, map(DI.unwrap, contexts)...) end @@ -125,6 +140,7 @@ function DI.derivative!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) prep.der_exe!(der, x, map(DI.unwrap, contexts)...) 
return der end @@ -137,6 +153,7 @@ function DI.value_and_derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) der = DI.derivative(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, der @@ -151,6 +168,7 @@ function DI.value_and_derivative!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) DI.derivative!(f!, y, der, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, der @@ -158,7 +176,8 @@ end ## Jacobian -struct SymbolicsTwoArgJacobianPrep{E1,E1!} <: DI.JacobianPrep +struct SymbolicsTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} jac_exe::E1 jac_exe!::E1! end @@ -168,8 +187,10 @@ function DI.prepare_jacobian( y, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; + strict::Val=Val(false), ) where {C} + _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) y_var = variablize(y, :y) context_vars = variablize(contexts) @@ -182,7 +203,7 @@ function DI.prepare_jacobian( res = build_function(jac_var, x_var, context_vars...; expression=Val(false)) (jac_exe, jac_exe!) = res - return SymbolicsTwoArgJacobianPrep(jac_exe, jac_exe!) + return SymbolicsTwoArgJacobianPrep(_sig, jac_exe, jac_exe!) end function DI.jacobian( @@ -193,6 +214,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) return prep.jac_exe(x, map(DI.unwrap, contexts)...) end @@ -205,6 +227,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) prep.jac_exe!(jac, x, map(DI.unwrap, contexts)...) return jac end @@ -217,6 +240,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) 
jac = DI.jacobian(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, jac @@ -231,6 +255,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) DI.jacobian!(f!, y, jac, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, jac diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index 18c06b0bc..fd771bb01 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -35,7 +35,7 @@ function DI.prepare_pullback_same_point( ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} - _sig = DI.signature(f, prep, backend, x, ty, contexts...; strict=DI.is_strict(prep)) + _sig = DI.signature(f, backend, x, ty, contexts...; strict=DI.is_strict(prep)) DI.check_prep(f, prep, backend, x, ty, contexts...) y, pb = forward(f, x, map(DI.unwrap, contexts)...) 
return TrackerPullbackPrepSamePoint(_sig, y, pb) diff --git a/DifferentiationInterface/src/utils/prep.jl b/DifferentiationInterface/src/utils/prep.jl index 746e8690c..99dc458a4 100644 --- a/DifferentiationInterface/src/utils/prep.jl +++ b/DifferentiationInterface/src/utils/prep.jl @@ -77,21 +77,30 @@ end is_strict(::Prep{Nothing}) = Val(false) is_strict(::Prep) = Val(true) -struct PreparationMismatchError{SIG,RUNTIME_SIG} <: Exception end +struct PreparationMismatchError{SIG,RUNTIME_SIG} <: Exception + format::Vector{Symbol} +end -function PreparationMismatchError(::Type{SIG}, ::Type{RUNTIME_SIG}) where {SIG,RUNTIME_SIG} - return PreparationMismatchError{SIG,RUNTIME_SIG}() +function PreparationMismatchError( + ::Type{SIG}, ::Type{RUNTIME_SIG}; format +) where {SIG,RUNTIME_SIG} + return PreparationMismatchError{SIG,RUNTIME_SIG}(format) end function Base.showerror( io::IO, e::PreparationMismatchError{SIG,RUNTIME_SIG} -) where {SIG,RUNTIME_SIG} - msg = """ - Inconsistent signatures: - - at preparation time: $SIG - - at execution time: $RUNTIME_SIG - """ - return print(io, msg) +) where {SIG<:Tuple,RUNTIME_SIG<:Tuple} + println( + io, + "PreparationMismatchError (inconsistent types between preparation and execution):", + ) + for (s, pt, et) in zip(e.format, SIG.types, RUNTIME_SIG.types) + if pt == et + println(io, " - $s: ✅") + else + println(io, " - $s: ❌\n - prep: $pt\n - exec: $et") + end + end end function signature( @@ -146,7 +155,11 @@ function check_prep( if SIG !== Nothing RUNTIME_SIG = typeof((f, backend, x, contexts)) if SIG != RUNTIME_SIG - throw(PreparationMismatchError(SIG, RUNTIME_SIG)) + throw( + PreparationMismatchError( + SIG, RUNTIME_SIG; format=[:f, :backend, :x, :contexts] + ), + ) end end end @@ -157,7 +170,11 @@ function check_prep( if SIG !== Nothing RUNTIME_SIG = typeof((f!, y, backend, x, contexts)) if SIG != RUNTIME_SIG - throw(PreparationMismatchError(SIG, RUNTIME_SIG)) + throw( + PreparationMismatchError( + SIG, RUNTIME_SIG; format=[:f!, :y, 
:backend, :x, :contexts] + ), + ) end end end @@ -168,7 +185,11 @@ function check_prep( if SIG !== Nothing RUNTIME_SIG = typeof((f, backend, x, t, contexts)) if SIG != RUNTIME_SIG - throw(PreparationMismatchError(SIG, RUNTIME_SIG)) + throw( + PreparationMismatchError( + SIG, RUNTIME_SIG; format=[:f, :backend, :x, :t, :contexts] + ), + ) end end end @@ -179,7 +200,11 @@ function check_prep( if SIG !== Nothing RUNTIME_SIG = typeof((f!, y, backend, x, t, contexts)) if SIG != RUNTIME_SIG - throw(PreparationMismatchError(SIG, RUNTIME_SIG)) + throw( + PreparationMismatchError( + SIG, RUNTIME_SIG; format=[:f!, :y, :backend, :x, :t, :contexts] + ), + ) end end end From 0aabc2bbc594e623d73790af51a7bb3ee5e4acf1 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 17 Mar 2025 08:28:31 +0100 Subject: [PATCH 13/22] Fixes --- .../ext/DifferentiationInterfaceGTPSAExt/onearg.jl | 2 +- .../ext/DifferentiationInterfaceSymbolicsExt/onearg.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl index 3b5e92fe5..2ca5797ac 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl @@ -560,7 +560,7 @@ function DI.prepare_hvp( strict::Val=Val(false), ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) - hessprep = DI.prepare_hessian(f, backend, x; strict) + hessprep = DI.prepare_hessian(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) 
hess = similar(x, typeof(fc(x)), (length(x), length(x))) return GTPSAOneArgHVPPrep(_sig, hessprep, hess) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 6009cd963..f8e6391c3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -247,7 +247,7 @@ function DI.prepare_jacobian( f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; strict::Val=Val(false), ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) From 3d17483b61dbdea5af55949ee5a2d8d87b0c03fe Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 17 Mar 2025 08:56:59 +0100 Subject: [PATCH 14/22] Fixes --- .../forward_onearg.jl | 24 +++++++++---------- .../onearg.jl | 10 ++++---- .../onearg.jl | 6 ++--- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index d4d7b2eb0..c424dc540 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -137,11 +137,11 @@ end function DI.gradient( f::F, - prep::EnzymeForwardGradientPrep{B}, + prep::EnzymeForwardGradientPrep{SIG,B}, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, -) where {F,B,C} +) where {F,SIG,B,C} DI.check_prep(f, prep, backend, x, contexts...) 
mode = forward_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) @@ -154,11 +154,11 @@ end function DI.value_and_gradient( f::F, - prep::EnzymeForwardGradientPrep{B}, + prep::EnzymeForwardGradientPrep{SIG,B}, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, -) where {F,B,C} +) where {F,SIG,B,C} DI.check_prep(f, prep, backend, x, contexts...) mode = forward_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) @@ -172,11 +172,11 @@ end function DI.gradient!( f::F, grad, - prep::EnzymeForwardGradientPrep{B}, + prep::EnzymeForwardGradientPrep{SIG,B}, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, -) where {F,B,C} +) where {F,SIG,B,C} DI.check_prep(f, prep, backend, x, contexts...) return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end @@ -184,11 +184,11 @@ end function DI.value_and_gradient!( f::F, grad, - prep::EnzymeForwardGradientPrep{B}, + prep::EnzymeForwardGradientPrep{SIG,B}, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, -) where {F,B,C} +) where {F,SIG,B,C} DI.check_prep(f, prep, backend, x, contexts...) y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) @@ -219,11 +219,11 @@ end function DI.jacobian( f::F, - prep::EnzymeForwardOneArgJacobianPrep{B}, + prep::EnzymeForwardOneArgJacobianPrep{SIG,B}, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, -) where {F,B,C} +) where {F,SIG,B,C} DI.check_prep(f, prep, backend, x, contexts...) 
mode = forward_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) @@ -237,11 +237,11 @@ end function DI.value_and_jacobian( f::F, - prep::EnzymeForwardOneArgJacobianPrep{B}, + prep::EnzymeForwardOneArgJacobianPrep{SIG,B}, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, -) where {F,B,C} +) where {F,SIG,B,C} DI.check_prep(f, prep, backend, x, contexts...) mode = forward_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl index 2ca5797ac..ae4161556 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl @@ -87,9 +87,8 @@ function DI.value_and_pushforward( contexts::Vararg{DI.Constant,C}, ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) - fc = DI.with_contexts(f, contexts...) - ty = DI.pushforward(fc, prep, backend, x, tx) - y = fc(x) # TO-DO: optimize + ty = DI.pushforward(f, prep, backend, x, tx, contexts...) + y = f(x, map(DI.unwrap, contexts)...) # TODO: optimize return y, ty end @@ -103,9 +102,8 @@ function DI.value_and_pushforward!( contexts::Vararg{DI.Constant,C}, ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) - fc = DI.with_contexts(f, contexts...) - DI.pushforward!(fc, ty, prep, backend, x, tx) - y = fc(x) # TO-DO: optimize + DI.pushforward!(f, ty, prep, backend, x, tx, contexts...) + y = f(x, map(DI.unwrap, contexts)...) 
# TODO: optimize return y, ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index f8e6391c3..2100b1319 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -460,7 +460,7 @@ function DI.hvp!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - DI.check_prep(f, prep, backend, x, contexts...) + DI.check_prep(f, prep, backend, x, tx, contexts...) for b in eachindex(tx, tg) dx, dg = tx[b], tg[b] prep.hvp_exe!(vec(dg), vec(x), vec(dx), map(DI.unwrap, contexts)...) @@ -476,7 +476,7 @@ function DI.gradient_and_hvp( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - DI.check_prep(f, prep, backend, x, contexts...) + DI.check_prep(f, prep, backend, x, tx, contexts...) tg = DI.hvp(f, prep, backend, x, tx, contexts...) grad = DI.gradient(f, prep.gradient_prep, backend, x, contexts...) return grad, tg @@ -492,7 +492,7 @@ function DI.gradient_and_hvp!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - DI.check_prep(f, prep, backend, x, contexts...) + DI.check_prep(f, prep, backend, x, tx, contexts...) DI.hvp!(f, tg, prep, backend, x, tx, contexts...) DI.gradient!(f, grad, prep.gradient_prep, backend, x, contexts...) 
return grad, tg From 618c8ae299b4120050c6e4a69ba61dcbfe1e7b7a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 17 Mar 2025 10:23:22 +0100 Subject: [PATCH 15/22] Add tests --- .../reverse_onearg.jl | 4 +- .../onearg.jl | 8 +- DifferentiationInterface/src/docstrings.jl | 9 +- DifferentiationInterface/src/utils/prep.jl | 43 +++--- .../test/Core/Internals/signature.jl | 125 ++++++++++++++++++ .../src/DifferentiationInterfaceTest.jl | 1 + .../src/tests/correctness_eval.jl | 96 ++++++++++---- 7 files changed, 234 insertions(+), 52 deletions(-) create mode 100644 DifferentiationInterface/test/Core/Internals/signature.jl diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index f7f3a6a9b..aa6465270 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -241,7 +241,7 @@ function DI.gradient!( backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, x, ) where {F} - DI.check_prep(f, prep, backend, x, contexts...) + DI.check_prep(f, prep, backend, x) mode = reverse_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) gradient!(mode, grad, f_and_df, x) @@ -255,7 +255,7 @@ function DI.value_and_gradient!( backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, x, ) where {F} - DI.check_prep(f, prep, backend, x, contexts...) 
+ DI.check_prep(f, prep, backend, x) mode = reverse_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) _, result = gradient!(mode, grad, f_and_df, x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 2100b1319..330a70455 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -271,7 +271,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} - DI.check_prep(f, backend, x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) return prep.jac_exe(x, map(DI.unwrap, contexts)...) end @@ -283,7 +283,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} - DI.check_prep(f, backend, x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) prep.jac_exe!(jac, x, map(DI.unwrap, contexts)...) return jac end @@ -295,7 +295,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} - DI.check_prep(f, backend, x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...) end @@ -307,7 +307,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} - DI.check_prep(f, backend, x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian!(f, jac, prep, backend, x, contexts...) 
end diff --git a/DifferentiationInterface/src/docstrings.jl b/DifferentiationInterface/src/docstrings.jl index 02d6068f1..e2d1ee006 100644 --- a/DifferentiationInterface/src/docstrings.jl +++ b/DifferentiationInterface/src/docstrings.jl @@ -21,9 +21,12 @@ function docstring_prepare(operator; samepoint=false, inplace=false) Depending on the backend, this can have several effects (preallocating memory, recording an execution trace) which are transparent to the user. !!! warning - The preparation result is only reusable as long as the arguments to `$operator` do not change type or size, and the function and backend themselves are not modified. - Otherwise, preparation will be invalidated and you will need to run it again. - The keyword argument `strict` activates automatic type checking, but ensuring size consistency is up to the user. + The preparation result `prep` is only reusable as long as the arguments to `$operator` do not change type or size, and the function and backend themselves are not modified. + Otherwise, preparation becomes invalid and you need to run it again. + In some settings, invalid preparations may still give correct results (e.g. for backends that require no preparation), but this is not a semantic guarantee and should not be relied upon. + + When `strict=Val(true)`, type checking is enforced between preparation and execution (but size checking is left to the user). + $(inplace ? "\nFor in-place functions, `y` is mutated by `f!` during preparation." 
: "") """ end diff --git a/DifferentiationInterface/src/utils/prep.jl b/DifferentiationInterface/src/utils/prep.jl index 99dc458a4..91450959c 100644 --- a/DifferentiationInterface/src/utils/prep.jl +++ b/DifferentiationInterface/src/utils/prep.jl @@ -77,30 +77,35 @@ end is_strict(::Prep{Nothing}) = Val(false) is_strict(::Prep) = Val(true) -struct PreparationMismatchError{SIG,RUNTIME_SIG} <: Exception +struct PreparationMismatchError{SIG,EXEC_SIG} <: Exception format::Vector{Symbol} end function PreparationMismatchError( - ::Type{SIG}, ::Type{RUNTIME_SIG}; format -) where {SIG,RUNTIME_SIG} - return PreparationMismatchError{SIG,RUNTIME_SIG}(format) + ::Type{SIG}, ::Type{EXEC_SIG}; format +) where {SIG,EXEC_SIG} + return PreparationMismatchError{SIG,EXEC_SIG}(format) end function Base.showerror( - io::IO, e::PreparationMismatchError{SIG,RUNTIME_SIG} -) where {SIG<:Tuple,RUNTIME_SIG<:Tuple} + io::IO, e::PreparationMismatchError{SIG,EXEC_SIG} +) where {SIG<:Tuple,EXEC_SIG<:Tuple} println( io, "PreparationMismatchError (inconsistent types between preparation and execution):", ) - for (s, pt, et) in zip(e.format, SIG.types, RUNTIME_SIG.types) + for (s, pt, et) in zip(e.format, SIG.types, EXEC_SIG.types) if pt == et println(io, " - $s: ✅") else println(io, " - $s: ❌\n - prep: $pt\n - exec: $et") end end + println( + io, + "To disable this check (not recommended), run preparation with the keyword argument `strict=Val(false)` when using DifferentiationInterface.", + ) + return nothing end function signature( @@ -153,11 +158,11 @@ function check_prep( f, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {SIG,C} if SIG !== Nothing - RUNTIME_SIG = typeof((f, backend, x, contexts)) - if SIG != RUNTIME_SIG + EXEC_SIG = typeof((f, backend, x, contexts)) + if SIG != EXEC_SIG throw( PreparationMismatchError( - SIG, RUNTIME_SIG; format=[:f, :backend, :x, :contexts] + SIG, EXEC_SIG; format=[:f, :backend, :x, :contexts] ), ) end @@ -168,11 +173,11 @@ function 
check_prep( f!, y, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {SIG,C} if SIG !== Nothing - RUNTIME_SIG = typeof((f!, y, backend, x, contexts)) - if SIG != RUNTIME_SIG + EXEC_SIG = typeof((f!, y, backend, x, contexts)) + if SIG != EXEC_SIG throw( PreparationMismatchError( - SIG, RUNTIME_SIG; format=[:f!, :y, :backend, :x, :contexts] + SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :contexts] ), ) end @@ -183,11 +188,11 @@ function check_prep( f, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C} ) where {SIG,C} if SIG !== Nothing - RUNTIME_SIG = typeof((f, backend, x, t, contexts)) - if SIG != RUNTIME_SIG + EXEC_SIG = typeof((f, backend, x, t, contexts)) + if SIG != EXEC_SIG throw( PreparationMismatchError( - SIG, RUNTIME_SIG; format=[:f, :backend, :x, :t, :contexts] + SIG, EXEC_SIG; format=[:f, :backend, :x, :tang, :contexts] ), ) end @@ -198,11 +203,11 @@ function check_prep( f!, y, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C} ) where {SIG,C} if SIG !== Nothing - RUNTIME_SIG = typeof((f!, y, backend, x, t, contexts)) - if SIG != RUNTIME_SIG + EXEC_SIG = typeof((f!, y, backend, x, t, contexts)) + if SIG != EXEC_SIG throw( PreparationMismatchError( - SIG, RUNTIME_SIG; format=[:f!, :y, :backend, :x, :t, :contexts] + SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :tang, :contexts] ), ) end diff --git a/DifferentiationInterface/test/Core/Internals/signature.jl b/DifferentiationInterface/test/Core/Internals/signature.jl new file mode 100644 index 000000000..2d28ce7dc --- /dev/null +++ b/DifferentiationInterface/test/Core/Internals/signature.jl @@ -0,0 +1,125 @@ +using DifferentiationInterface +using DifferentiationInterface: AutoZeroReverse, AutoZeroForward +using Test + +backend = AutoZeroForward() +other_backend = AutoZeroReverse() +f(x, c) = x + c +f!(y, x) = y .= x +x = 1.0 +y = zeros(2) +c = 2.0 + +@testset "Out of place, no tangents" begin + prep = prepare_derivative(f, 
backend, x, Constant(c); strict=Val(true)) + prep_chill = prepare_derivative(f, backend, x, Constant(c); strict=Val(false)) + + @test_throws MethodError derivative(nothing, prep_chill, backend, x, Constant(c)) + + @test_throws """ + PreparationMismatchError (inconsistent types between preparation and execution): + - f: ❌ + - prep: typeof(f) + - exec: Nothing + - backend: ✅ + - x: ✅ + - contexts: ✅ + """ derivative(nothing, prep, backend, x, Constant(c)) + + @test_throws """ + PreparationMismatchError (inconsistent types between preparation and execution): + - f: ✅ + - backend: ❌ + - prep: AutoZeroForward + - exec: AutoZeroReverse + - x: ✅ + - contexts: ✅ + """ derivative(f, prep, other_backend, x, Constant(c)) + + @test_throws """ + PreparationMismatchError (inconsistent types between preparation and execution): + - f: ✅ + - backend: ✅ + - x: ❌ + - prep: Float64 + - exec: Int64 + - contexts: ✅ + """ derivative(f, prep, backend, 1, Constant(c)) + + @test_throws """ + PreparationMismatchError (inconsistent types between preparation and execution): + - f: ✅ + - backend: ✅ + - x: ✅ + - contexts: ❌ + - prep: Tuple{Constant{Float64}} + - exec: Tuple{Constant{Int64}} + """ derivative(f, prep, backend, x, Constant(2)) + + @test_throws """ + PreparationMismatchError (inconsistent types between preparation and execution): + - f: ✅ + - backend: ✅ + - x: ✅ + - contexts: ❌ + - prep: Tuple{Constant{Float64}} + - exec: Tuple{Constant{Int64}, Constant{Int64}} + """ derivative(f, prep, backend, x, Constant(2), Constant(3)) +end + +@testset "In place, no tangents" begin + prep = prepare_derivative(f!, y, backend, x; strict=Val(true)) + prep_chill = prepare_derivative(f!, y, backend, x; strict=Val(false)) + + @test_throws MethodError derivative(nothing, y, prep_chill, backend, x, Constant(c)) + + @test_throws """ + PreparationMismatchError (inconsistent types between preparation and execution): + - f!: ❌ + - prep: typeof(f!) 
+ - exec: Nothing + - y: ✅ + - backend: ✅ + - x: ✅ + - contexts: ✅ + """ derivative(nothing, y, prep, backend, x) +end + +@testset "Out of place, with tangents" begin + prep = prepare_pushforward(f, backend, x, (x,), Constant(c); strict=Val(true)) + prep_chill = prepare_pushforward(f, backend, x, (x,), Constant(c); strict=Val(false)) + + @test_throws MethodError pushforward(nothing, prep_chill, backend, x, (x,)) + + @test_throws """ + PreparationMismatchError (inconsistent types between preparation and execution): + - f: ❌ + - prep: typeof(f) + - exec: Nothing + - backend: ✅ + - x: ✅ + - tang: ✅ + - contexts: ✅ + """ pushforward(nothing, prep, backend, x, (x,), Constant(c)) +end + +@testset "In place, with tangents" begin + prep = prepare_pushforward(f!, y, backend, x, (x,); strict=Val(true)) + prep_chill = prepare_pushforward( + f!, y, backend, x, (x,), Constant(c); strict=Val(false) + ) + + @test_throws MethodError pushforward(nothing, y, prep_chill, backend, x, (x,)) + + @test_throws """ + PreparationMismatchError (inconsistent types between preparation and execution): + - f!: ❌ + - prep: typeof(f!) 
+ - exec: Nothing + - y: ✅ + - backend: ✅ + - x: ✅ + - tang: ✅ + - contexts: ✅ + """ pushforward(nothing, y, prep, backend, x, (x,)) +end diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index 2c92546ed..58e9777c5 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -90,6 +90,7 @@ using DifferentiationInterface: pushforward_performance, pullback_performance using DifferentiationInterface: Rewrap, Context, Constant, Cache, unwrap +using DifferentiationInterface: PreparationMismatchError using DocStringExtensions: TYPEDFIELDS, TYPEDSIGNATURES using JET: @test_opt using LinearAlgebra: Adjoint, Diagonal, Transpose, dot, parent diff --git a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl index 2b067c9ea..a73dfbd5d 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl @@ -1,3 +1,5 @@ +const PME = PreparationMismatchError + for op in ALL_OPS op! = Symbol(op, "!") val_prefix = if op == :second_derivative @@ -60,7 +62,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, $prep_op( new_smaller.f, @@ -73,7 +75,7 @@ for op in ALL_OPS xrand, contextsrand..., ) - [(), (prep,), (prepprep,)] + [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_out1_val, res1_out1_val = $val_and_op( @@ -100,6 +102,8 @@ for op in ALL_OPS @test mynnz(res1_out2_noval) == mynnz(scen.res1) end end + @test_throws PME $val_and_op(nothing, prepstrict, ba, x, contexts...) + @test_throws PME $op(nothing, prepstrict, ba, x, contexts...) 
scenario_intact && @test new_scen == scen return nothing end @@ -125,7 +129,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, $prep_op( new_smaller.f, @@ -138,7 +142,7 @@ for op in ALL_OPS xrand, contextsrand..., ) - [(), (prep,), (prepprep,)] + [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res1_in1_val = mysimilar(res1) @@ -177,6 +181,10 @@ for op in ALL_OPS @test mynnz(res1_out2_noval) == mynnz(scen.res1) end end + @test_throws PME $val_and_op!( + nothing, res1_in1_val, prepstrict, ba, x, contexts... + ) + @test_throws PME $op!(nothing, res1_in1_noval, prepstrict, ba, x, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -204,7 +212,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, copy(yrand), ba, xrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, copy(yrand), $prep_op( @@ -219,7 +227,7 @@ for op in ALL_OPS xrand, contextsrand..., ) - [(), (prep,), (prepprep,)] + [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val = mysimilar(y) @@ -252,6 +260,8 @@ for op in ALL_OPS @test mynnz(res1_out2_noval) == mynnz(scen.res1) end end + @test_throws PME $val_and_op(nothing, y_in1_val, prepstrict, ba, x, contexts...) + @test_throws PME $op(nothing, y_in1_noval, prepstrict, ba, x, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -277,7 +287,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, copy(yrand), ba, xrand, contextsrand...) 
- prepprep = $prep_op!( + prepstrict = $prep_op!( f, copy(yrand), $prep_op( @@ -292,7 +302,7 @@ for op in ALL_OPS xrand, contextsrand..., ) - [(), (prep,), (prepprep,)] + [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val, res1_in1_val = mysimilar(y), mysimilar(res1) @@ -333,6 +343,12 @@ for op in ALL_OPS @test mynnz(res1_out2_noval) == mynnz(scen.res1) end end + @test_throws PME $val_and_op!( + nothing, y_in1_val, res1_in1_val, prepstrict, ba, x, contexts... + ) + @test_throws PME $op!( + nothing, y_in1_noval, res1_in1_noval, prepstrict, ba, x, contexts... + ) scenario_intact && @test new_scen == scen return nothing end @@ -359,7 +375,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, $prep_op( new_smaller.f, @@ -372,7 +388,7 @@ for op in ALL_OPS xrand, contextsrand..., ) - [(), (prep,), (prepprep,)] + [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_out1_val, res1_out1_val, res2_out1_val = $val_and_op( @@ -401,6 +417,8 @@ for op in ALL_OPS @test mynnz(res2_out2_noval) == mynnz(scen.res2) end end + @test_throws PME $val_and_op(nothing, prepstrict, ba, x, contexts...) + @test_throws PME $op(nothing, prepstrict, ba, x, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -426,7 +444,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, contextsrand...) 
- prepprep = $prep_op!( + prepstrict = $prep_op!( f, $prep_op( new_smaller.f, @@ -439,7 +457,7 @@ for op in ALL_OPS xrand, contextsrand..., ) - [(), (prep,), (prepprep,)] + [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res1_in1_val, res2_in1_val = mysimilar(res1), mysimilar(res2) @@ -482,6 +500,10 @@ for op in ALL_OPS @test mynnz(res2_out2_noval) == mynnz(scen.res2) end end + @test_throws PME $val_and_op!( + nothing, res1_in1_val, res2_in1_val, prepstrict, ba, x, contexts... + ) + @test_throws PME $op!(nothing, res2_in1_noval, prepstrict, ba, x, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -508,7 +530,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, $prep_op( new_smaller.f, @@ -524,7 +546,7 @@ for op in ALL_OPS contextsrand..., ) prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) - [(), (prep,), (prepprep,), (prep_same,)] + [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_out1_val, res1_out1_val = $val_and_op( @@ -547,6 +569,8 @@ for op in ALL_OPS end end end + @test_throws PME $val_and_op(nothing, prepstrict, ba, x, tang, contexts...) + @test_throws PME $op(nothing, prepstrict, ba, x, tang, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -572,7 +596,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, $prep_op( new_smaller.f, @@ -588,7 +612,7 @@ for op in ALL_OPS contextsrand..., ) prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) 
- [(), (prep,), (prepprep,), (prep_same,)] + [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res1_in1_val = mysimilar(res1) @@ -623,6 +647,12 @@ for op in ALL_OPS end end end + @test_throws PME $val_and_op!( + nothing, res1_in1_val, prepstrict, ba, x, tang, contexts... + ) + @test_throws PME $op!( + nothing, res1_in1_noval, prepstrict, ba, x, tang, contexts... + ) scenario_intact && @test new_scen == scen return nothing end @@ -648,7 +678,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, copy(yrand), ba, xrand, tangrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, copy(yrand), $prep_op( @@ -666,7 +696,7 @@ for op in ALL_OPS contextsrand..., ) prep_same = $prep_op_same(f, copy(yrand), ba, x, tangrand, contexts...) - [(), (prep,), (prepprep,), (prep_same,)] + [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val = mysimilar(y) @@ -699,6 +729,10 @@ for op in ALL_OPS end end end + @test_throws PME $val_and_op( + nothing, y_in1_val, prepstrict, ba, x, tang, contexts... + ) + @test_throws PME $op(nothing, y_in1_noval, prepstrict, ba, x, tang, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -724,7 +758,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, copy(yrand), ba, xrand, tangrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, copy(yrand), $prep_op( @@ -742,7 +776,7 @@ for op in ALL_OPS contextsrand..., ) prep_same = $prep_op_same(f, copy(yrand), ba, x, tangrand, contexts...) 
- [(), (prep,), (prepprep,), (prep_same,)] + [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val, res1_in1_val = mysimilar(y), mysimilar(res1) @@ -793,6 +827,12 @@ for op in ALL_OPS end end end + @test_throws PME $val_and_op!( + nothing, y_in1_val, res1_in1_val, prepstrict, ba, x, tang, contexts... + ) + @test_throws PME $op!( + nothing, y_in2_noval, res1_in2_noval, prepstrict, ba, x, tang, contexts... + ) scenario_intact && @test new_scen == scen return nothing end @@ -819,7 +859,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, $prep_op( new_smaller.f, @@ -835,7 +875,7 @@ for op in ALL_OPS contextsrand..., ) prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) - [(), (prep,), (prepprep,), (prep_same,)] + [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res2_out1_noval = $op(f, preptup_noval..., ba, x, tang, contexts...) @@ -858,6 +898,8 @@ for op in ALL_OPS end end end + @test_throws PME $val_and_op(nothing, prepstrict, ba, x, tang, contexts...) + @test_throws PME $op(nothing, prepstrict, ba, x, tang, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -883,7 +925,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, $prep_op( new_smaller.f, @@ -899,7 +941,7 @@ for op in ALL_OPS contextsrand..., ) prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) 
- [(), (prep,), (prepprep,), (prep_same,)] + [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res2_in1_noval = mysimilar(res2) @@ -950,6 +992,12 @@ for op in ALL_OPS end end end + @test_throws PME $op!( + nothing, res2_in1_noval, prepstrict, ba, x, tang, contexts... + ) + @test_throws PME $val_and_op!( + nothing, res1_in1_val, res2_in1_val, prepstrict, ba, x, tang, contexts... + ) scenario_intact && @test new_scen == scen return nothing end From df4a0d28f4db5d2ecccc1e9f7cdfe73129b1f90c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 17 Mar 2025 10:24:53 +0100 Subject: [PATCH 16/22] Fix --- .../src/DifferentiationInterfaceTest.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index 58e9777c5..a40645fbf 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -98,7 +98,7 @@ using ProgressMeter: ProgressUnknown, next! using Random: AbstractRNG, default_rng, rand! 
using SparseArrays: SparseArrays, AbstractSparseMatrix, SparseMatrixCSC, nnz, sparse, spdiagm -using Test: @testset, @test +using Test: @testset, @test, @test_throws """ FIRST_ORDER = [:pushforward, :pullback, :derivative, :gradient, :jacobian] From 0b2e62e84727a255405c126c2a95259ff0d02ea6 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 17 Mar 2025 10:32:12 +0100 Subject: [PATCH 17/22] Fix --- .../src/tests/correctness_eval.jl | 55 +++++++++++++------ 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl index a73dfbd5d..9ed8013f8 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl @@ -54,6 +54,7 @@ for op in ALL_OPS xrand = myrandom(x) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -121,6 +122,7 @@ for op in ALL_OPS xrand = myrandom(x) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -182,9 +184,9 @@ for op in ALL_OPS end end @test_throws PME $val_and_op!( - nothing, res1_in1_val, prepstrict, ba, x, contexts... + nothing, mysimilar(res1), prepstrict, ba, x, contexts... ) - @test_throws PME $op!(nothing, res1_in1_noval, prepstrict, ba, x, contexts...) + @test_throws PME $op!(nothing, mysimilar(res1), prepstrict, ba, x, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -204,6 +206,7 @@ for op in ALL_OPS xrand, yrand = myrandom(x), myrandom(y) rewrap = Rewrap(contexts...) 
contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -260,8 +263,10 @@ for op in ALL_OPS @test mynnz(res1_out2_noval) == mynnz(scen.res1) end end - @test_throws PME $val_and_op(nothing, y_in1_val, prepstrict, ba, x, contexts...) - @test_throws PME $op(nothing, y_in1_noval, prepstrict, ba, x, contexts...) + @test_throws PME $val_and_op( + nothing, mysimilar(y), prepstrict, ba, x, contexts... + ) + @test_throws PME $op(nothing, mysimilar(y), prepstrict, ba, x, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -279,6 +284,7 @@ for op in ALL_OPS xrand, yrand = myrandom(x), myrandom(y) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -344,10 +350,10 @@ for op in ALL_OPS end end @test_throws PME $val_and_op!( - nothing, y_in1_val, res1_in1_val, prepstrict, ba, x, contexts... + nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, contexts... ) @test_throws PME $op!( - nothing, y_in1_noval, res1_in1_noval, prepstrict, ba, x, contexts... + nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, contexts... ) scenario_intact && @test new_scen == scen return nothing @@ -367,6 +373,7 @@ for op in ALL_OPS xrand = myrandom(x) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -436,6 +443,7 @@ for op in ALL_OPS xrand = myrandom(x) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) 
+ local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -501,9 +509,9 @@ for op in ALL_OPS end end @test_throws PME $val_and_op!( - nothing, res1_in1_val, res2_in1_val, prepstrict, ba, x, contexts... + nothing, mysimilar(res1), mysimilar(res2), prepstrict, ba, x, contexts... ) - @test_throws PME $op!(nothing, res2_in1_noval, prepstrict, ba, x, contexts...) + @test_throws PME $op!(nothing, mysimilar(res2), prepstrict, ba, x, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -522,6 +530,7 @@ for op in ALL_OPS xrand, tangrand = myrandom(x), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -588,6 +597,7 @@ for op in ALL_OPS xrand, tangrand = myrandom(x), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -648,10 +658,10 @@ for op in ALL_OPS end end @test_throws PME $val_and_op!( - nothing, res1_in1_val, prepstrict, ba, x, tang, contexts... + nothing, mysimilar(res1), prepstrict, ba, x, tang, contexts... ) @test_throws PME $op!( - nothing, res1_in1_noval, prepstrict, ba, x, tang, contexts... + nothing, mysimilar(res1), prepstrict, ba, x, tang, contexts... ) scenario_intact && @test new_scen == scen return nothing @@ -670,6 +680,7 @@ for op in ALL_OPS xrand, yrand, tangrand = myrandom(x), myrandom(y), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) 
+ local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -730,9 +741,11 @@ for op in ALL_OPS end end @test_throws PME $val_and_op( - nothing, y_in1_val, prepstrict, ba, x, tang, contexts... + nothing, mysimilar(y), prepstrict, ba, x, tang, contexts... + ) + @test_throws PME $op( + nothing, mysimilar(y), prepstrict, ba, x, tang, contexts... ) - @test_throws PME $op(nothing, y_in1_noval, prepstrict, ba, x, tang, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -750,6 +763,7 @@ for op in ALL_OPS xrand, yrand, tangrand = myrandom(x), myrandom(y), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -828,10 +842,10 @@ for op in ALL_OPS end end @test_throws PME $val_and_op!( - nothing, y_in1_val, res1_in1_val, prepstrict, ba, x, tang, contexts... + nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, tang, contexts... ) @test_throws PME $op!( - nothing, y_in2_noval, res1_in2_noval, prepstrict, ba, x, tang, contexts... + nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, tang, contexts... ) scenario_intact && @test new_scen == scen return nothing @@ -851,6 +865,7 @@ for op in ALL_OPS xrand, tangrand = myrandom(x), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -917,6 +932,7 @@ for op in ALL_OPS xrand, tangrand = myrandom(x), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) 
+ local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -993,10 +1009,17 @@ for op in ALL_OPS end end @test_throws PME $op!( - nothing, res2_in1_noval, prepstrict, ba, x, tang, contexts... + nothing, mysimilar(res2), prepstrict, ba, x, tang, contexts... ) @test_throws PME $val_and_op!( - nothing, res1_in1_val, res2_in1_val, prepstrict, ba, x, tang, contexts... + nothing, + mysimilar(res1), + mysimilar(res2), + prepstrict, + ba, + x, + tang, + contexts..., ) scenario_intact && @test new_scen == scen return nothing From b13a31db6899b297d023c1062ab3a963c8c58273 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 17 Mar 2025 11:18:16 +0100 Subject: [PATCH 18/22] Fix --- .../onearg.jl | 12 ++++++++---- .../test/Back/PolyesterForwardDiff/test.jl | 7 ------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index f3dbcaa47..089f53258 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -442,7 +442,7 @@ function DI.value_derivative_and_second_derivative( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.value_derivative_and_second_derivative( - f, prep, single_threaded(backend), x, contexts... + f, prep.single_threaded_prep, single_threaded(backend), x, contexts... ) end @@ -457,7 +457,7 @@ function DI.value_derivative_and_second_derivative!( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) return DI.value_derivative_and_second_derivative!( - f, der, der2, prep, single_threaded(backend), x, contexts... 
+ f, der, der2, prep.single_threaded_prep, single_threaded(backend), x, contexts... ) end @@ -469,7 +469,9 @@ function DI.second_derivative( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - return DI.second_derivative(f, prep, single_threaded(backend), x, contexts...) + return DI.second_derivative( + f, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.second_derivative!( @@ -481,5 +483,7 @@ function DI.second_derivative!( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - return DI.second_derivative!(f, der2, prep, single_threaded(backend), x, contexts...) + return DI.second_derivative!( + f, der2, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end diff --git a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl index 71ea785db..4f38af5b1 100644 --- a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl @@ -30,13 +30,6 @@ test_differentiation( backends, default_scenarios(; include_constantified=true, include_cachified=true); logging=LOGGING, - excluded=SECOND_ORDER, -); - -test_differentiation( - SecondOrder(AutoPolyesterForwardDiff(), AutoPolyesterForwardDiff()), - default_scenarios(); - logging=LOGGING, ); @testset "Batch size" begin From bcf644da06a151d57b79dcceaae5c400e4b31f7d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 17 Mar 2025 15:30:25 +0100 Subject: [PATCH 19/22] Positional arguments --- .../reverse_onearg.jl | 4 +- .../DifferentiationInterfaceDiffractorExt.jl | 4 +- .../forward_onearg.jl | 6 +- .../forward_twoarg.jl | 2 +- .../reverse_onearg.jl | 4 +- .../reverse_twoarg.jl | 2 +- .../onearg.jl | 28 ++----- .../twoarg.jl | 13 +--- .../onearg.jl | 15 ++-- .../twoarg.jl | 24 ++---- 
...rentiationInterfaceFiniteDifferencesExt.jl | 16 +--- .../onearg.jl | 30 ++----- .../twoarg.jl | 24 ++---- .../onearg.jl | 19 ++--- .../twoarg.jl | 4 +- .../onearg.jl | 9 +-- .../twoarg.jl | 2 +- .../onearg.jl | 36 +++------ .../twoarg.jl | 11 +-- .../onearg.jl | 23 +++--- .../twoarg.jl | 11 +-- .../hessian.jl | 10 +-- .../jacobian.jl | 20 ++--- .../jacobian_mixed.jl | 18 ++--- .../onearg.jl | 30 +++---- .../twoarg.jl | 6 +- .../DifferentiationInterfaceTrackerExt.jl | 8 +- .../DifferentiationInterfaceZygoteExt.jl | 29 ++----- .../src/fallbacks/change_prep.jl | 22 +++--- .../src/fallbacks/no_prep.jl | 78 +++++++++---------- .../src/first_order/derivative.jl | 17 ++-- .../src/first_order/gradient.jl | 8 +- .../src/first_order/jacobian.jl | 25 +++--- .../src/first_order/pullback.jl | 29 ++++--- .../src/first_order/pushforward.jl | 29 ++++--- .../src/misc/from_primitive.jl | 16 ++-- .../src/misc/simple_finite_diff.jl | 4 +- .../src/misc/zero_backends.jl | 18 +---- .../src/second_order/hessian.jl | 14 ++-- .../src/second_order/hvp.jl | 53 +++++++------ .../src/second_order/second_derivative.jl | 8 +- DifferentiationInterface/src/utils/prep.jl | 2 +- .../test/Core/ZeroBackends/test.jl | 18 +---- 43 files changed, 290 insertions(+), 459 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index 6f6caba52..6d7f0d517 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -7,25 +7,25 @@ struct ChainRulesPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} end function DI.prepare_pullback( + strict::Val, f, backend::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f, backend, x, ty, 
contexts...; strict) return DI.NoPullbackPrep(_sig) end function DI.prepare_pullback_same_point( + strict, f, prep::DI.NoPullbackPrep, backend::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}; - strict::Val=Val(false), ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) _sig = DI.signature(f, backend, x, ty, contexts...; strict) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl index 92176fc11..baab8a75f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl @@ -10,8 +10,8 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow() ## Pushforward -function DI.prepare_pushforward(f, backend::AutoDiffractor, x, tx::NTuple) - _sig = DI.signature(f, backend, x, tx) +function DI.prepare_pushforward(strict::Val, f, backend::AutoDiffractor, x, tx::NTuple) + _sig = DI.signature(f, backend, x, tx; strict) return DI.NoPushforwardPrep(_sig) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index c424dc540..96b2fff76 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -1,12 +1,12 @@ ## Pushforward function DI.prepare_pushforward( + strict::Val, f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {F,C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) return DI.NoPushforwardPrep(_sig) @@ -123,11 +123,11 @@ struct 
EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG} end function DI.prepare_gradient( + strict::Val, f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}; - strict::Val=Val(false), ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) valB = to_val(DI.pick_batchsize(backend, x)) @@ -204,11 +204,11 @@ struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian( + strict::Val, f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}; - strict::Val=Val(false), ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) y = f(x, map(DI.unwrap, contexts)...) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index f3d9e777c..33e5ce1aa 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -1,13 +1,13 @@ ## Pushforward function DI.prepare_pushforward( + strict::Val, f!::F, y, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {F,C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) return DI.NoPushforwardPrep(_sig) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index aa6465270..205a86837 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -53,12 +53,12 @@ struct EnzymeReverseOneArgPullbackPrep{SIG,Y} <: DI.PullbackPrep{SIG} end function DI.prepare_pullback( + strict::Val, f::F, 
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {F,C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) y = f(x, map(DI.unwrap, contexts)...) @@ -192,11 +192,11 @@ end ## Gradient function DI.prepare_gradient( + strict::Val, f::F, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) return DI.NoGradientPrep(_sig) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index 7b36e748d..30562fb7f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -6,13 +6,13 @@ struct EnzymeReverseTwoArgPullbackPrep{SIG,TY} <: DI.PullbackPrep{SIG} end function DI.prepare_pullback( + strict::Val, f!::F, y, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {F,C} _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) ty_copy = map(copy, ty) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index 6a9c8f30d..0ebfdb7dc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -8,12 +8,12 @@ struct FastDifferentiationOneArgPushforwardPrep{SIG,Y,E1,E1!} <: DI.PushforwardP end function DI.prepare_pushforward( + strict::Val, f, backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f, 
backend, x, tx, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) @@ -106,12 +106,12 @@ struct FastDifferentiationOneArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG} end function DI.prepare_pullback( + strict::Val, f, backend::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) x_var = variablize(x, :x) @@ -205,11 +205,7 @@ struct FastDifferentiationOneArgDerivativePrep{SIG,Y,E1,E1!} <: DI.DerivativePre end function DI.prepare_derivative( - f, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) @@ -289,11 +285,7 @@ struct FastDifferentiationOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG} end function DI.prepare_gradient( - f, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) @@ -369,11 +361,11 @@ struct FastDifferentiationOneArgJacobianPrep{SIG,Y,E1,E1!} <: DI.JacobianPrep{SI end function DI.prepare_jacobian( + strict::Val, f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) 
@@ -454,11 +446,7 @@ struct FastDifferentiationAllocatingSecondDerivativePrep{SIG,Y,D,E2,E2!} <: end function DI.prepare_second_derivative( - f, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) @@ -547,12 +535,12 @@ struct FastDifferentiationHVPPrep{SIG,E2,E2!,E1} <: DI.HVPPrep{SIG} end function DI.prepare_hvp( + strict::Val, f, backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) x_var = variablize(x, :x) @@ -646,11 +634,11 @@ struct FastDifferentiationHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG} end function DI.prepare_hessian( + strict::Val, f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl index 7904f8789..f67ed4324 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl @@ -7,13 +7,13 @@ struct FastDifferentiationTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPre end function DI.prepare_pushforward( + strict::Val, f!, y, backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) x_var = variablize(x, :x) @@ -108,13 +108,13 @@ struct 
FastDifferentiationTwoArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG} end function DI.prepare_pullback( + strict::Val, f!, y, backend::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) x_var = variablize(x, :x) @@ -214,12 +214,7 @@ struct FastDifferentiationTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{ end function DI.prepare_derivative( - f!, - y, - backend::AutoFastDifferentiation, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f!, y, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) @@ -301,12 +296,12 @@ struct FastDifferentiationTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian( + strict::Val, f!, y, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl index 79b3d528a..7ebdb6e49 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl @@ -9,12 +9,7 @@ struct FiniteDiffOneArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} end function DI.prepare_pushforward( - f, - backend::AutoFiniteDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) fc = DI.with_contexts(f, contexts...) 
@@ -130,7 +125,7 @@ struct FiniteDiffOneArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} end function DI.prepare_derivative( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) @@ -259,7 +254,7 @@ struct FiniteDiffGradientPrep{SIG,C,R,A,D} <: DI.GradientPrep{SIG} end function DI.prepare_gradient( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) @@ -347,7 +342,7 @@ struct FiniteDiffOneArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) @@ -452,7 +447,7 @@ struct FiniteDiffHessianPrep{SIG,C1,C2,RG,AG,RH,AH} <: DI.HessianPrep{SIG} end function DI.prepare_hessian( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) 
diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl index a81c49249..259ebbdd3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl @@ -9,13 +9,13 @@ struct FiniteDiffTwoArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} end function DI.prepare_pushforward( + strict::Val, f!, y, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) cache = if x isa Number @@ -162,12 +162,7 @@ struct FiniteDiffTwoArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} end function DI.prepare_derivative( - f!, - y, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) df = similar(y) @@ -203,9 +198,7 @@ function DI.prepare!_derivative( cache.c3 isa Union{Number,Nothing} || resize!(cache.c3, length(y)) return old_prep else - return DI.prepare_derivative( - f!, y, backend, x, contexts...; strict=DI.is_strict(old_prep) - ) + return DI.prepare_derivative(DI.is_strict(old_prep), f!, y, backend, x, contexts...) 
end end @@ -285,12 +278,7 @@ struct FiniteDiffTwoArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian( - f!, - y, - backend::AutoFiniteDiff, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) x1 = similar(x) @@ -330,9 +318,7 @@ function DI.prepare!_jacobian( cache.sparsity = nothing return old_prep else - return DI.prepare_jacobian( - f!, y, backend, x, contexts...; strict=DI.is_strict(old_prep) - ) + return DI.prepare_jacobian(DI.is_strict(old_prep), f!, y, backend, x, contexts...) end end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl index 69cc37c1f..31bcd5961 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl @@ -12,12 +12,12 @@ DI.inner_preparation_behavior(::AutoFiniteDifferences) = DI.PrepareInnerSimple() ## Pushforward function DI.prepare_pushforward( + strict::Val, f, backend::AutoFiniteDifferences, x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) return DI.NoPushforwardPrep(_sig) @@ -55,12 +55,12 @@ end ## Pullback function DI.prepare_pullback( + strict::Val, f, backend::AutoFiniteDifferences, x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) return DI.NoPullbackPrep(_sig) @@ -98,11 +98,7 @@ end ## Gradient function DI.prepare_gradient( - f, - 
backend::AutoFiniteDifferences, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) return DI.NoGradientPrep(_sig) @@ -159,11 +155,7 @@ end ## Jacobian function DI.prepare_jacobian( - f, - backend::AutoFiniteDifferences, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) return DI.NoJacobianPrep(_sig) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index bd5a3f3cd..703fe6be6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -68,12 +68,12 @@ struct ForwardDiffOneArgPushforwardPrep{SIG,T,X,CD} <: DI.PushforwardPrep{SIG} end function DI.prepare_pushforward( + strict::Val, f::F, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {F,B,C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) T = tag_type(f, backend, x) @@ -216,14 +216,10 @@ end ### Prepared function DI.prepare_derivative( - f::F, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) - pushforward_prep = DI.prepare_pushforward(f, backend, x, (one(x),), contexts...; strict) + pushforward_prep = DI.prepare_pushforward(strict, f, backend, x, (one(x),), contexts...) 
return ForwardDiffOneArgDerivativePrep(_sig, pushforward_prep) end @@ -365,11 +361,11 @@ struct ForwardDiffGradientPrep{SIG,C,CD} <: DI.GradientPrep{SIG} end function DI.prepare_gradient( + strict::Val, f::F, backend::AutoForwardDiff, x::AbstractArray, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) @@ -537,11 +533,7 @@ struct ForwardDiffOneArgJacobianPrep{SIG,C,CD} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian( - f::F, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) @@ -629,11 +621,7 @@ end ## Second derivative function DI.prepare_second_derivative( - f::F, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) return DI.NoSecondDerivativePrep(_sig) @@ -806,11 +794,7 @@ struct ForwardDiffHessianPrep{SIG,C1,C2,CD} <: DI.HessianPrep{SIG} end function DI.prepare_hessian( - f::F, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index 950c41383..a919d8965 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -9,13 +9,13 
@@ struct ForwardDiffTwoArgPushforwardPrep{SIG,T,X,Y,CD} <: DI.PushforwardPrep{SIG} end function DI.prepare_pushforward( + strict::Val, f!::F, y, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {F,B,C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) T = tag_type(f!, backend, x) @@ -189,12 +189,7 @@ struct ForwardDiffTwoArgDerivativePrep{SIG,C,CD} <: DI.DerivativePrep{SIG} end function DI.prepare_derivative( - f!::F, - y, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) tag = get_tag(f!, backend, x) @@ -217,9 +212,7 @@ function DI.prepare!_derivative( resize!(config.duals, length(y)) return old_prep else - return DI.prepare_derivative( - f!, y, backend, x, contexts...; strict=DI.is_strict(old_prep) - ) + return DI.prepare_derivative(DI.is_strict(old_prep), f!, y, backend, x, contexts...) end end @@ -383,12 +376,7 @@ struct ForwardDiffTwoArgJacobianPrep{SIG,C,CD} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian( - f!::F, - y, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {F,C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) @@ -414,9 +402,7 @@ function DI.prepare!_jacobian( resize!(xduals, length(x)) return old_prep else - return DI.prepare_jacobian( - f!, y, backend, x, contexts...; strict=DI.is_strict(old_prep) - ) + return DI.prepare_jacobian(DI.is_strict(old_prep), f!, y, backend, x, contexts...) 
end end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl index ae4161556..83be97ed1 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl @@ -8,12 +8,12 @@ struct GTPSAOneArgPushforwardPrep{SIG,X} <: DI.PushforwardPrep{SIG} end function DI.prepare_pushforward( + strict::Val, f::F, backend::AutoGTPSA{D}, x, tx::NTuple, contexts::Vararg{DI.Constant,C}; - strict::Val=Val(false), ) where {F,D,C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) # For pushforward/JVP, we only actually need 1 single variable (in the GTPSA sense) @@ -116,7 +116,7 @@ end # Unlike JVP, this requires us to use all variables function DI.prepare_gradient( - f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Val=Val(false) + strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing @@ -199,7 +199,7 @@ end # To materialize the entire Jacobian we use all variables function DI.prepare_jacobian( - f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Val=Val(false) + strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing @@ -285,7 +285,7 @@ struct GTPSAOneArgSecondDerivativePrep{SIG,X} <: DI.SecondDerivativePrep{SIG} end function DI.prepare_second_derivative( - f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Val=Val(false) + strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing @@ -415,7 +415,7 @@ struct GTPSAOneArgHessianPrep{SIG,X,M} <: DI.HessianPrep{SIG} end function DI.prepare_hessian( - f, 
backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Val=Val(false) + strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing @@ -550,15 +550,10 @@ struct GTPSAOneArgHVPPrep{SIG,E,H} <: DI.HVPPrep{SIG} end function DI.prepare_hvp( - f, - backend::AutoGTPSA, - x, - tx::NTuple, - contexts::Vararg{DI.Constant,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant,C} ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) - hessprep = DI.prepare_hessian(f, backend, x, contexts...; strict) + hessprep = DI.prepare_hessian(strict, f, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) hess = similar(x, typeof(fc(x)), (length(x), length(x))) return GTPSAOneArgHVPPrep(_sig, hessprep, hess) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl index 647662672..e77c4900d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl @@ -11,13 +11,13 @@ struct GTPSATwoArgPushforwardPrep{SIG,X,Y} <: DI.PushforwardPrep{SIG} end function DI.prepare_pushforward( + strict::Val, f!::F, y, backend::AutoGTPSA{D}, x, tx::NTuple, contexts::Vararg{DI.Constant,C}; - strict::Val=Val(false), ) where {F,D,C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) # For pushforward/JVP, we only actually need 1 single variable (in the GTPSA sense) @@ -126,7 +126,7 @@ struct GTPSATwoArgJacobianPrep{SIG,X,Y} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian( - f!, y, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}; strict::Val=Val(false) + strict::Val, f!, y, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} _sig = DI.signature(f!, y, backend, x, 
contexts...; strict) if D != Nothing diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 27dbab0b4..e1e0c580b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -7,12 +7,7 @@ struct MooncakeOneArgPullbackPrep{SIG,Tcache,DY} <: DI.PullbackPrep{SIG} end function DI.prepare_pullback( - f::F, - backend::AutoMooncake, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f::F, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) config = get_config(backend) @@ -112,7 +107,7 @@ struct MooncakeGradientPrep{SIG,Tcache} <: DI.GradientPrep{SIG} end function DI.prepare_gradient( - f::F, backend::AutoMooncake, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f::F, backend::AutoMooncake, x, contexts::Vararg{DI.Context,C} ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) config = get_config(backend) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 29558cc55..43bed9857 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -6,13 +6,13 @@ struct MooncakeTwoArgPullbackPrep{SIG,Tcache,DY,F} <: DI.PullbackPrep{SIG} end function DI.prepare_pullback( + strict::Val, f!::F, y, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {F,C} _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) target_function = function (f!, y, x, contexts...) 
diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index 089f53258..2e9380fd4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -7,16 +7,16 @@ struct PolyesterForwardDiffOneArgPushforwardPrep{SIG,P} <: DI.PushforwardPrep{SI end function DI.prepare_pushforward( + strict::Val, f, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) single_threaded_prep = DI.prepare_pushforward( - f, single_threaded(backend), x, tx, contexts...; strict + strict, f, single_threaded(backend), x, tx, contexts... ) return PolyesterForwardDiffOneArgPushforwardPrep(_sig, single_threaded_prep) end @@ -87,15 +87,11 @@ struct PolyesterForwardDiffOneArgDerivativePrep{SIG,P} <: DI.DerivativePrep{SIG} end function DI.prepare_derivative( - f, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) single_threaded_prep = DI.prepare_derivative( - f, single_threaded(backend), x, contexts...; strict + strict, f, single_threaded(backend), x, contexts... 
) return PolyesterForwardDiffOneArgDerivativePrep(_sig, single_threaded_prep) end @@ -163,11 +159,11 @@ struct PolyesterForwardDiffGradientPrep{SIG,chunksize,P} <: DI.GradientPrep{SIG} end function DI.prepare_gradient( + strict::Val, f, backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {chunksize,C} _sig = DI.signature(f, backend, x, contexts...; strict) if isnothing(chunksize) @@ -176,7 +172,7 @@ function DI.prepare_gradient( chunk = Chunk{chunksize}() end single_threaded_prep = DI.prepare_gradient( - f, single_threaded(backend), x, contexts...; strict + strict, f, single_threaded(backend), x, contexts... ) return PolyesterForwardDiffGradientPrep(_sig, chunk, single_threaded_prep) end @@ -254,11 +250,11 @@ struct PolyesterForwardDiffOneArgJacobianPrep{SIG,chunksize,P} <: DI.JacobianPre end function DI.prepare_jacobian( + strict::Val, f, backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {chunksize,C} _sig = DI.signature(f, backend, x, contexts...; strict) if isnothing(chunksize) @@ -267,7 +263,7 @@ function DI.prepare_jacobian( chunk = Chunk{chunksize}() end single_threaded_prep = DI.prepare_jacobian( - f, single_threaded(backend), x, contexts...; strict + strict, f, single_threaded(backend), x, contexts... ) return PolyesterForwardDiffOneArgJacobianPrep(_sig, chunk, single_threaded_prep) end @@ -344,15 +340,11 @@ struct PolyesterForwardDiffHessianPrep{SIG,P} <: DI.HessianPrep{SIG} end function DI.prepare_hessian( - f, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) single_threaded_prep = DI.prepare_hessian( - f, single_threaded(backend), x, contexts...; strict + strict, f, single_threaded(backend), x, contexts... 
) return PolyesterForwardDiffHessianPrep(_sig, single_threaded_prep) end @@ -420,15 +412,11 @@ struct PolyesterForwardDiffOneArgSecondDerivativePrep{SIG,P} <: DI.SecondDerivat end function DI.prepare_second_derivative( - f, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) single_threaded_prep = DI.prepare_second_derivative( - f, single_threaded(backend), x, contexts...; strict + strict, f, single_threaded(backend), x, contexts... ) return PolyesterForwardDiffOneArgSecondDerivativePrep(_sig, single_threaded_prep) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl index 818cc1e71..1c7a8c92a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl @@ -12,7 +12,7 @@ function DI.prepare_pushforward( x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, ) where {C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) single_threaded_prep = DI.prepare_pushforward( @@ -91,12 +91,7 @@ struct PolyesterForwardDiffTwoArgDerivativePrep{SIG,P} <: DI.DerivativePrep{SIG} end function DI.prepare_derivative( - f!, - y, - backend::AutoPolyesterForwardDiff, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + f!, y, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}; strict::Val ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) single_threaded_prep = DI.prepare_derivative( @@ -177,7 +172,7 @@ function DI.prepare_jacobian( backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C}; - 
strict::Val=Val(false), + strict::Val, ) where {chunksize,C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) if isnothing(chunksize) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index 4e094a306..13b673ac5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -1,12 +1,7 @@ ## Pullback function DI.prepare_pullback( - f, - backend::AutoReverseDiff, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoReverseDiff, x, ty::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) return DI.NoPullbackPrep(_sig) @@ -85,7 +80,7 @@ struct ReverseDiffGradientPrep{SIG,C,T} <: DI.GradientPrep{SIG} end function DI.prepare_gradient( - f, backend::AutoReverseDiff{compile}, x; strict::Val=Val(false) + strict::Val, f, backend::AutoReverseDiff{compile}, x ) where {compile} _sig = DI.signature(f, backend, x; strict) if compile @@ -149,7 +144,7 @@ end ### With contexts function DI.prepare_gradient( - f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) config = GradientConfig(x) @@ -222,7 +217,7 @@ struct ReverseDiffOneArgJacobianPrep{SIG,C,T} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian( - f, backend::AutoReverseDiff{compile}, x; strict::Val=Val(false) + strict::Val, f, backend::AutoReverseDiff{compile}, x ) where {compile} _sig = DI.signature(f, backend, x; strict) if compile @@ -286,7 +281,7 @@ end ### With contexts function DI.prepare_jacobian( - f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}; 
strict::Val=Val(false) + strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) config = JacobianConfig(x) @@ -360,10 +355,10 @@ struct ReverseDiffHessianPrep{SIG,G<:ReverseDiffGradientPrep,HC,HT} <: DI.Hessia end function DI.prepare_hessian( - f, backend::AutoReverseDiff{compile}, x; strict::Val=Val(false) + strict::Val, f, backend::AutoReverseDiff{compile}, x ) where {compile} _sig = DI.signature(f, backend, x; strict) - gradient_prep = DI.prepare_gradient(f, backend, x) + gradient_prep = DI.prepare_gradient(strict, f, backend, x) if compile hessian_tape = ReverseDiff.compile(HessianTape(f, x)) return ReverseDiffHessianPrep(_sig, gradient_prep, nothing, hessian_tape) @@ -418,10 +413,10 @@ end ### With contexts function DI.prepare_hessian( - f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) - gradient_prep = DI.prepare_gradient(f, backend, x, contexts...) + gradient_prep = DI.prepare_gradient(strict, f, backend, x, contexts...) 
hessian_config = HessianConfig(x) return ReverseDiffHessianPrep(_sig, gradient_prep, hessian_config, nothing) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl index 6aceb1c29..531212fe5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl @@ -1,13 +1,13 @@ ## Pullback function DI.prepare_pullback( + strict::Val, f!, y, backend::AutoReverseDiff, x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) return DI.NoPullbackPrep(_sig) @@ -140,7 +140,7 @@ struct ReverseDiffTwoArgJacobianPrep{SIG,C,T} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian( - f!, y, backend::AutoReverseDiff{compile}, x; strict::Val=Val(false) + strict::Val, f!, y, backend::AutoReverseDiff{compile}, x ) where {compile} _sig = DI.signature(f!, y, backend, x; strict) if compile @@ -206,12 +206,7 @@ end ### With contexts function DI.prepare_jacobian( - f!, - y, - backend::AutoReverseDiff, - x, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f!, y, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) config = JacobianConfig(y, x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index 487ad770f..1084583f2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -26,7 +26,7 @@ SMC.ncolors(prep::SparseHessianPrep) = ncolors(prep.coloring_result) ## Hessian, one argument 
function DI.prepare_hessian( - f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} dense_backend = dense_ad(backend) sparsity = DI.hessian_sparsity_with_contexts( @@ -39,18 +39,18 @@ function DI.prepare_hessian( N = length(column_groups(coloring_result)) batch_size_settings = DI.pick_batchsize(DI.outer(dense_backend), N) return _prepare_sparse_hessian_aux( - batch_size_settings, coloring_result, f, backend, x, contexts...; strict + strict, batch_size_settings, coloring_result, f, backend, x, contexts... ) end function _prepare_sparse_hessian_aux( + strict::Val, batch_size_settings::DI.BatchSizeSettings{B}, coloring_result::AbstractColoringResult{:symmetric,:column}, f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; - strict::Val, ) where {B,F,C} _sig = DI.signature(f, backend, x, contexts...; strict) (; N, A) = batch_size_settings @@ -62,8 +62,8 @@ function _prepare_sparse_hessian_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - hvp_prep = DI.prepare_hvp(f, dense_backend, x, batched_seeds[1], contexts...; strict) - gradient_prep = DI.prepare_gradient(f, DI.inner(dense_backend), x, contexts...; strict) + hvp_prep = DI.prepare_hvp(strict, f, dense_backend, x, batched_seeds[1], contexts...) + gradient_prep = DI.prepare_gradient(strict, f, DI.inner(dense_backend), x, contexts...) 
return SparseHessianPrep( _sig, batch_size_settings, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 82a8e540d..85eac3b0f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -37,30 +37,30 @@ struct PullbackSparseJacobianPrep{ end function DI.prepare_jacobian( - f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} dense_backend = dense_ad(backend) y = f(x, map(DI.unwrap, contexts)...) perf = DI.pushforward_performance(dense_backend) - return _prepare_sparse_jacobian_aux(perf, y, (f,), backend, x, contexts...; strict) + return _prepare_sparse_jacobian_aux(strict, perf, y, (f,), backend, x, contexts...) end function DI.prepare_jacobian( - f!::F, y, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f!::F, y, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} dense_backend = dense_ad(backend) perf = DI.pushforward_performance(dense_backend) - return _prepare_sparse_jacobian_aux(perf, y, (f!, y), backend, x, contexts...; strict) + return _prepare_sparse_jacobian_aux(strict, perf, y, (f!, y), backend, x, contexts...) 
end function _prepare_sparse_jacobian_aux( + strict::Val, perf::DI.PushforwardPerformance, y, f_or_f!y::FY, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; - strict::Val, ) where {FY,C} dense_backend = dense_ad(backend) sparsity = DI.jacobian_sparsity_with_contexts( @@ -84,11 +84,12 @@ function _prepare_sparse_jacobian_aux( end batch_size_settings = DI.pick_batchsize(dense_backend, N) return _prepare_sparse_jacobian_aux_aux( - batch_size_settings, coloring_result, y, f_or_f!y, backend, x, contexts...; strict + strict, batch_size_settings, coloring_result, y, f_or_f!y, backend, x, contexts... ) end function _prepare_sparse_jacobian_aux_aux( + strict::Val, batch_size_settings::DI.BatchSizeSettings{B}, coloring_result::AbstractColoringResult{:nonsymmetric,:column}, y, @@ -96,7 +97,6 @@ function _prepare_sparse_jacobian_aux_aux( backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; - strict::Val, ) where {B,FY,C} _sig = DI.signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings @@ -109,7 +109,7 @@ function _prepare_sparse_jacobian_aux_aux( ] batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds] pushforward_prep = DI.prepare_pushforward( - f_or_f!y..., dense_backend, x, batched_seeds[1], contexts...; strict + strict, f_or_f!y..., dense_backend, x, batched_seeds[1], contexts... 
) return PushforwardSparseJacobianPrep( _sig, @@ -123,6 +123,7 @@ function _prepare_sparse_jacobian_aux_aux( end function _prepare_sparse_jacobian_aux_aux( + strict::Val, batch_size_settings::DI.BatchSizeSettings{B}, coloring_result::AbstractColoringResult{:nonsymmetric,:row}, y, @@ -130,7 +131,6 @@ function _prepare_sparse_jacobian_aux_aux( backend::AutoSparse, x, contexts::Vararg{DI.Context,C}; - strict::Val, ) where {B,FY,C} _sig = DI.signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings @@ -143,7 +143,7 @@ function _prepare_sparse_jacobian_aux_aux( ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] pullback_prep = DI.prepare_pullback( - f_or_f!y..., dense_backend, x, batched_seeds[1], contexts...; strict + strict, f_or_f!y..., dense_backend, x, batched_seeds[1], contexts... ) return PullbackSparseJacobianPrep( _sig, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl index a8447c56b..c5d0d2847 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl @@ -28,34 +28,34 @@ struct MixedModeSparseJacobianPrep{ end function DI.prepare_jacobian( + strict::Val, f::F, backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {F,C} y = f(x, map(DI.unwrap, contexts)...) - return _prepare_mixed_sparse_jacobian_aux(y, (f,), backend, x, contexts...; strict) + return _prepare_mixed_sparse_jacobian_aux(strict, y, (f,), backend, x, contexts...) 
end function DI.prepare_jacobian( + strict::Val, f!::F, y, backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {F,C} - return _prepare_mixed_sparse_jacobian_aux(y, (f!, y), backend, x, contexts...; strict) + return _prepare_mixed_sparse_jacobian_aux(strict, y, (f!, y), backend, x, contexts...) end function _prepare_mixed_sparse_jacobian_aux( + strict::Val, y, f_or_f!y::FY, backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C}; - strict::Val, ) where {FY,C} dense_backend = dense_ad(backend) sparsity = DI.jacobian_sparsity_with_contexts( @@ -75,6 +75,7 @@ function _prepare_mixed_sparse_jacobian_aux( batch_size_settings_reverse = DI.pick_batchsize(DI.reverse_backend(dense_backend), Nr) return _prepare_mixed_sparse_jacobian_aux_aux( + strict, batch_size_settings_forward, batch_size_settings_reverse, coloring_result, @@ -83,11 +84,11 @@ function _prepare_mixed_sparse_jacobian_aux( backend, x, contexts...; - strict, ) end function _prepare_mixed_sparse_jacobian_aux_aux( + strict::Val, batch_size_settings_forward::DI.BatchSizeSettings{Bf}, batch_size_settings_reverse::DI.BatchSizeSettings{Br}, coloring_result::AbstractColoringResult{:nonsymmetric,:bidirectional}, @@ -96,7 +97,6 @@ function _prepare_mixed_sparse_jacobian_aux_aux( backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C}; - strict::Val, ) where {Bf,Br,FY,C} _sig = DI.signature(f_or_f!y..., backend, x, contexts...; strict) Nf, Af = batch_size_settings_forward.N, batch_size_settings_forward.A @@ -128,20 +128,20 @@ function _prepare_mixed_sparse_jacobian_aux_aux( ] pushforward_prep = DI.prepare_pushforward( + strict, f_or_f!y..., DI.forward_backend(dense_backend), x, batched_seeds_forward[1], contexts...; - strict, ) pullback_prep = DI.prepare_pullback( + strict, f_or_f!y..., DI.reverse_backend(dense_backend), x, batched_seeds_reverse[1], contexts...; - strict, ) return MixedModeSparseJacobianPrep( diff --git 
a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 330a70455..105d8a6a1 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -7,12 +7,7 @@ struct SymbolicsOneArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} end function DI.prepare_pushforward( - f, - backend::AutoSymbolics, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) dx = first(tx) @@ -100,7 +95,7 @@ struct SymbolicsOneArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} end function DI.prepare_derivative( - f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) @@ -174,7 +169,7 @@ struct SymbolicsOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG} end function DI.prepare_gradient( - f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) @@ -244,11 +239,11 @@ struct SymbolicsOneArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian( + strict::Val, f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) @@ -322,11 +317,11 @@ struct SymbolicsOneArgHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG} end 
function DI.prepare_hessian( + strict::Val, f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) @@ -341,7 +336,7 @@ function DI.prepare_hessian( res = build_function(hess_var, vec(x_var), context_vars...; expression=Val(false)) (hess_exe, hess_exe!) = res - gradient_prep = DI.prepare_gradient(f, dense_ad(backend), x, contexts...) + gradient_prep = DI.prepare_gradient(strict, f, dense_ad(backend), x, contexts...) return SymbolicsOneArgHessianPrep(_sig, gradient_prep, hess_exe, hess_exe!) end @@ -411,12 +406,7 @@ struct SymbolicsOneArgHVPPrep{SIG,G,E2,E2!} <: DI.HVPPrep{SIG} end function DI.prepare_hvp( - f, - backend::AutoSymbolics, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) dx = first(tx) @@ -432,7 +422,7 @@ function DI.prepare_hvp( ) (hvp_exe, hvp_exe!) = res - gradient_prep = DI.prepare_gradient(f, backend, x, contexts...) + gradient_prep = DI.prepare_gradient(strict, f, backend, x, contexts...) return SymbolicsOneArgHVPPrep(_sig, gradient_prep, hvp_exe, hvp_exe!) end @@ -508,7 +498,7 @@ struct SymbolicsOneArgSecondDerivativePrep{SIG,D,E1,E1!} <: DI.SecondDerivativeP end function DI.prepare_second_derivative( - f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) @@ -522,7 +512,7 @@ function DI.prepare_second_derivative( elseif res isa RuntimeGeneratedFunction res, nothing end - derivative_prep = DI.prepare_derivative(f, backend, x, contexts...) 
+ derivative_prep = DI.prepare_derivative(strict, f, backend, x, contexts...) return SymbolicsOneArgSecondDerivativePrep(_sig, derivative_prep, der2_exe, der2_exe!) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl index eeeb9ed73..58623720e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl @@ -7,13 +7,13 @@ struct SymbolicsTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} end function DI.prepare_pushforward( + strict::Val, f!, y, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) dx = first(tx) @@ -105,7 +105,7 @@ struct SymbolicsTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} end function DI.prepare_derivative( - f!, y, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f!, y, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) @@ -183,12 +183,12 @@ struct SymbolicsTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} end function DI.prepare_jacobian( + strict::Val, f!, y, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index fd771bb01..38ecfa117 100644 --- 
a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -16,12 +16,12 @@ struct TrackerPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} end function DI.prepare_pullback( + strict::Val, f, backend::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}; - strict::Val=Val(false), ) where {C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) return DI.NoPullbackPrep(_sig) @@ -92,11 +92,7 @@ end ## Gradient function DI.prepare_gradient( - f, - backend::AutoTracker, - x, - contexts::Vararg{DI.GeneralizedConstant,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C}; ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) return DI.NoGradientPrep(_sig) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 983af3748..c7a246211 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -28,12 +28,7 @@ struct ZygotePullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} end function DI.prepare_pullback( - f, - backend::AutoZygote, - x, - ty::NTuple, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) return DI.NoPullbackPrep(_sig) @@ -46,10 +41,9 @@ function DI.prepare_pullback_same_point( x, ty::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) 
- _sig = DI.signature(f, backend, x, ty, contexts...; strict) + _sig = DI.signature(f, backend, x, ty, contexts...; strict=DI.is_strict(prep)) y, pb = pullback(f, x, map(translate, contexts)...) return ZygotePullbackPrepSamePoint(_sig, y, pb) end @@ -105,7 +99,7 @@ end ## Gradient function DI.prepare_gradient( - f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) return DI.NoGradientPrep(_sig) @@ -145,7 +139,7 @@ end ## Jacobian function DI.prepare_jacobian( - f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false) + strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) return DI.NoJacobianPrep(_sig) @@ -194,16 +188,11 @@ struct ZygoteHVPPrep{SIG,P} <: DI.HVPPrep{SIG} end function DI.prepare_hvp( - f, - backend::AutoZygote, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) fd_prep = DI.prepare_hvp( - f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...; strict + strict, f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... 
) return ZygoteHVPPrep(_sig, fd_prep) end @@ -277,11 +266,7 @@ end ## Hessian function DI.prepare_hessian( - f, - backend::AutoZygote, - x, - contexts::Vararg{DI.GeneralizedConstant,C}; - strict::Val=Val(false), + strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) return DI.NoHessianPrep(_sig) diff --git a/DifferentiationInterface/src/fallbacks/change_prep.jl b/DifferentiationInterface/src/fallbacks/change_prep.jl index bf5202c9e..492813391 100644 --- a/DifferentiationInterface/src/fallbacks/change_prep.jl +++ b/DifferentiationInterface/src/fallbacks/change_prep.jl @@ -43,10 +43,10 @@ for op in [ if op in (:derivative, :gradient, :jacobian) # 1-arg @eval function $prep_op!( - f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}; + f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} check_prep(f, old_prep, backend, x, contexts...) - return $prep_op(f, backend, x, contexts...; strict=is_strict(old_prep)) + return $prep_op(is_strict(old_prep), f, backend, x, contexts...) end op == :gradient && continue # 2-arg @@ -54,7 +54,7 @@ for op in [ f!::F, y, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} check_prep(f!, y, old_prep, backend, x, contexts...) - return $prep_op(f!, y, backend, x, contexts...; strict=is_strict(old_prep)) + return $prep_op(is_strict(old_prep), f!, y, backend, x, contexts...) end elseif op in (:second_derivative, :hessian) @@ -63,7 +63,7 @@ for op in [ f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} check_prep(f, old_prep, backend, x, contexts...) - return $prep_op(f, backend, x, contexts...; strict=is_strict(old_prep)) + return $prep_op(is_strict(old_prep), f, backend, x, contexts...) 
end elseif op in (:pushforward, :pullback, :hvp) @@ -77,7 +77,7 @@ for op in [ contexts::Vararg{Context,C}; ) where {F,C} check_prep(f, old_prep, backend, x, seed, contexts...) - return $prep_op(f, backend, x, seed, contexts...; strict=is_strict(old_prep)) + return $prep_op(is_strict(old_prep), f, backend, x, seed, contexts...) end @eval function $prep_op_same_point( f::F, @@ -91,14 +91,14 @@ for op in [ return prep end @eval function $prep_op_same_point( + strict::Val, f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...; strict) + prep = $prep_op(strict, f, backend, x, seed, contexts...) return $prep_op_same_point(f, prep, backend, x, seed, contexts...) end op == :hvp && continue @@ -113,9 +113,7 @@ for op in [ contexts::Vararg{Context,C}, ) where {F,C} check_prep(f!, y, old_prep, backend, x, seed, contexts...) - return $prep_op( - f!, y, backend, x, seed, contexts...; strict=is_strict(old_prep) - ) + return $prep_op(is_strict(old_prep), f!, y, backend, x, seed, contexts...) end @eval function $prep_op_same_point( f!::F, @@ -130,15 +128,15 @@ for op in [ return prep end @eval function $prep_op_same_point( + strict::Val, f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...; strict) + prep = $prep_op(strict, f!, y, backend, x, seed, contexts...) return $prep_op_same_point(f!, y, prep, backend, x, seed, contexts...) 
end end diff --git a/DifferentiationInterface/src/fallbacks/no_prep.jl b/DifferentiationInterface/src/fallbacks/no_prep.jl index bfec092ed..81c7b2bd0 100644 --- a/DifferentiationInterface/src/fallbacks/no_prep.jl +++ b/DifferentiationInterface/src/fallbacks/no_prep.jl @@ -45,25 +45,25 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) + prep = $prep_op(Val(true), f, backend, x, contexts...) return $op(f, prep, backend, x, contexts...) end @eval function $op!( f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) + prep = $prep_op(Val(true), f, backend, x, contexts...) return $op!(f, result, prep, backend, x, contexts...) end @eval function $val_and_op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) + prep = $prep_op(Val(true), f, backend, x, contexts...) return $val_and_op(f, prep, backend, x, contexts...) end @eval function $val_and_op!( f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) + prep = $prep_op(Val(true), f, backend, x, contexts...) return $val_and_op!(f, result, prep, backend, x, contexts...) end op == :gradient && continue @@ -71,25 +71,25 @@ for op in [ @eval function $op( f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...; strict=Val(true)) + prep = $prep_op(Val(true), f!, y, backend, x, contexts...) return $op(f!, y, prep, backend, x, contexts...) 
end @eval function $op!( f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...; strict=Val(true)) + prep = $prep_op(Val(true), f!, y, backend, x, contexts...) return $op!(f!, y, result, prep, backend, x, contexts...) end @eval function $val_and_op( f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...; strict=Val(true)) + prep = $prep_op(Val(true), f!, y, backend, x, contexts...) return $val_and_op(f!, y, prep, backend, x, contexts...) end @eval function $val_and_op!( f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...; strict=Val(true)) + prep = $prep_op(Val(true), f!, y, backend, x, contexts...) return $val_and_op!(f!, y, result, prep, backend, x, contexts...) end @@ -98,51 +98,51 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) + prep = $prep_op(Val(true), f, backend, x, contexts...) return $op(f, prep, backend, x, contexts...) end @eval function $op!( f::F, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) + prep = $prep_op(Val(true), f, backend, x, contexts...) return $op!(f, result2, prep, backend, x, contexts...) end @eval function $val_and_op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) + prep = $prep_op(Val(true), f, backend, x, contexts...) return $val_and_op(f, prep, backend, x, contexts...) 
end @eval function $val_and_op!( f::F, result1, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...; strict=Val(true)) + prep = $prep_op(Val(true), f, backend, x, contexts...) return $val_and_op!(f, result1, result2, prep, backend, x, contexts...) end elseif op in (:pushforward, :pullback, :hvp) @eval function $op( - f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + f::F, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...; strict=Val(true)) - return $op(f, prep, backend, x, seed, contexts...) + prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + return $op(f, prep, backend, x, tang, contexts...) end @eval function $op!( f::F, result::NTuple, backend::AbstractADType, x, - seed::NTuple, + tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...; strict=Val(true)) - return $op!(f, result, prep, backend, x, seed, contexts...) + prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + return $op!(f, result, prep, backend, x, tang, contexts...) end @eval function $val_and_op( - f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + f::F, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...; strict=Val(true)) - return $val_and_op(f, prep, backend, x, seed, contexts...) + prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + return $val_and_op(f, prep, backend, x, tang, contexts...) end if op in (:pushforward, :pullback) @@ -151,11 +151,11 @@ for op in [ result::NTuple, backend::AbstractADType, x, - seed::NTuple, + tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...; strict=Val(true)) - return $val_and_op!(f, result, prep, backend, x, seed, contexts...) 
+ prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + return $val_and_op!(f, result, prep, backend, x, tang, contexts...) end elseif op == :hvp @eval function $val_and_op!( @@ -164,12 +164,12 @@ for op in [ result2::NTuple, backend::AbstractADType, x, - seed::NTuple, + tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...; strict=Val(true)) + prep = $prep_op(Val(true), f, backend, x, tang, contexts...) return $val_and_op!( - f, result1, result2, prep, backend, x, seed, contexts... + f, result1, result2, prep, backend, x, tang, contexts... ) end end @@ -177,10 +177,10 @@ for op in [ op == :hvp && continue @eval function $op( - f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + f!::F, y, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=Val(true)) - return $op(f!, y, prep, backend, x, seed, contexts...) + prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + return $op(f!, y, prep, backend, x, tang, contexts...) end @eval function $op!( f!::F, @@ -188,17 +188,17 @@ for op in [ result::NTuple, backend::AbstractADType, x, - seed::NTuple, + tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=Val(true)) - return $op!(f!, y, result, prep, backend, x, seed, contexts...) + prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + return $op!(f!, y, result, prep, backend, x, tang, contexts...) end @eval function $val_and_op( - f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + f!::F, y, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=Val(true)) - return $val_and_op(f!, y, prep, backend, x, seed, contexts...) 
+ prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + return $val_and_op(f!, y, prep, backend, x, tang, contexts...) end @eval function $val_and_op!( f!::F, @@ -206,11 +206,11 @@ for op in [ result::NTuple, backend::AbstractADType, x, - seed::NTuple, + tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...; strict=Val(true)) - return $val_and_op!(f!, y, result, prep, backend, x, seed, contexts...) + prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + return $val_and_op!(f!, y, result, prep, backend, x, tang, contexts...) end end end diff --git a/DifferentiationInterface/src/first_order/derivative.jl b/DifferentiationInterface/src/first_order/derivative.jl index f0329a639..c8753432d 100644 --- a/DifferentiationInterface/src/first_order/derivative.jl +++ b/DifferentiationInterface/src/first_order/derivative.jl @@ -6,7 +6,9 @@ $(docstring_prepare("derivative"; inplace=true)) """ -function prepare_derivative end +function prepare_derivative(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_derivative(strict, args...) +end """ prepare!_derivative(f, prep, backend, x, [contexts...]) -> new_prep @@ -64,24 +66,19 @@ struct PushforwardDerivativePrep{SIG,E<:PushforwardPrep} <: DerivativePrep{SIG} end function prepare_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) + strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} _sig = signature(f, backend, x, contexts...; strict) - pushforward_prep = prepare_pushforward(f, backend, x, (one(x),), contexts...; strict) + pushforward_prep = prepare_pushforward(strict, f, backend, x, (one(x),), contexts...) 
return PushforwardDerivativePrep(_sig, pushforward_prep) end function prepare_derivative( - f!::F, - y, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val, f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f!, y, backend, x, contexts...; strict) pushforward_prep = prepare_pushforward( - f!, y, backend, x, (one(x),), contexts...; strict + strict, f!, y, backend, x, (one(x),), contexts... ) return PushforwardDerivativePrep(_sig, pushforward_prep) end diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index e9481c9c5..8ddfb7137 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -5,7 +5,9 @@ $(docstring_prepare("gradient")) """ -function prepare_gradient end +function prepare_gradient(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_gradient(strict, args...) +end """ prepare!_gradient(f, prep, backend, x, [contexts...]) -> new_prep @@ -59,11 +61,11 @@ struct PullbackGradientPrep{SIG,Y,E<:PullbackPrep} <: GradientPrep{SIG} end function prepare_gradient( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) + strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} _sig = signature(f, backend, x, contexts...; strict) y = f(x, map(unwrap, contexts)...) # TODO: replace with output type inference? - pullback_prep = prepare_pullback(f, backend, x, (one(typeof(y)),), contexts...; strict) + pullback_prep = prepare_pullback(strict, f, backend, x, (one(typeof(y)),), contexts...) 
return PullbackGradientPrep(_sig, y, pullback_prep) end diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 3b4f5fff9..201cbea23 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -6,7 +6,9 @@ $(docstring_prepare("jacobian"; inplace=true)) """ -function prepare_jacobian end +function prepare_jacobian(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_jacobian(strict, args...) +end """ prepare!_jacobian(f, prep, backend, x, [contexts...]) -> new_prep @@ -89,7 +91,7 @@ struct PullbackJacobianPrep{ end function prepare_jacobian( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) + strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} y = f(x, map(unwrap, contexts)...) perf = pushforward_performance(backend) @@ -101,17 +103,12 @@ function prepare_jacobian( end # function barrier return _prepare_jacobian_aux( - perf, batch_size_settings, y, (f,), backend, x, contexts...; strict + strict, perf, batch_size_settings, y, (f,), backend, x, contexts... ) end function prepare_jacobian( - f!::F, - y, - backend::AbstractADType, - x, - contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val, f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} perf = pushforward_performance(backend) # type-unstable @@ -122,11 +119,12 @@ function prepare_jacobian( end # function barrier return _prepare_jacobian_aux( - perf, batch_size_settings, y, (f!, y), backend, x, contexts...; strict + strict, perf, batch_size_settings, y, (f!, y), backend, x, contexts... 
) end function _prepare_jacobian_aux( + strict::Val, ::PushforwardFast, batch_size_settings::BatchSizeSettings{B}, y, @@ -134,7 +132,6 @@ function _prepare_jacobian_aux( backend::AbstractADType, x, contexts::Vararg{Context,C}; - strict::Val, ) where {B,FY,C} _sig = signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings @@ -144,7 +141,7 @@ function _prepare_jacobian_aux( ] batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds] pushforward_prep = prepare_pushforward( - f_or_f!y..., backend, x, batched_seeds[1], contexts...; strict + strict, f_or_f!y..., backend, x, batched_seeds[1], contexts... ) return PushforwardJacobianPrep( _sig, batch_size_settings, batched_seeds, batched_results, pushforward_prep @@ -152,6 +149,7 @@ function _prepare_jacobian_aux( end function _prepare_jacobian_aux( + strict::Val, ::PushforwardSlow, batch_size_settings::BatchSizeSettings{B}, y, @@ -159,7 +157,6 @@ function _prepare_jacobian_aux( backend::AbstractADType, x, contexts::Vararg{Context,C}; - strict::Val, ) where {B,FY,C} _sig = signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings @@ -169,7 +166,7 @@ function _prepare_jacobian_aux( ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] pullback_prep = prepare_pullback( - f_or_f!y..., backend, x, batched_seeds[1], contexts...; strict + strict, f_or_f!y..., backend, x, batched_seeds[1], contexts... 
) return PullbackJacobianPrep( _sig, batch_size_settings, batched_seeds, batched_results, pullback_prep diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 1837ef7d3..63afa0656 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -6,7 +6,9 @@ $(docstring_prepare("pullback"; inplace=true)) """ -function prepare_pullback end +function prepare_pullback(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_pullback(strict, args...) +end """ prepare!_pullback(f, prep, backend, x, ty, [contexts...]) -> new_prep @@ -22,7 +24,9 @@ function prepare!_pullback end $(docstring_prepare("pullback"; samepoint=true, inplace=true)) """ -function prepare_pullback_same_point end +function prepare_pullback_same_point(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_pullback_same_point(strict, args...) +end """ value_and_pullback(f, [prep,] backend, x, ty, [contexts...]) -> (y, tx) @@ -91,48 +95,44 @@ struct PushforwardPullbackPrep{SIG,E} <: PullbackPrep{SIG} end function prepare_pullback( - f::F, - backend::AbstractADType, - x, - ty::NTuple, - contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val, f::F, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context,C}; ) where {F,C} return _prepare_pullback_aux( - pullback_performance(backend), f, backend, x, ty, contexts...; strict + strict, pullback_performance(backend), f, backend, x, ty, contexts... ) end function prepare_pullback( + strict::Val, f!::F, y, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), ) where {F,C} return _prepare_pullback_aux( - pullback_performance(backend), f!, y, backend, x, ty, contexts...; strict + strict, pullback_performance(backend), f!, y, backend, x, ty, contexts... 
) end function _prepare_pullback_aux( + strict::Val, ::PullbackSlow, f::F, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Val, ) where {F,C} _sig = signature(f, backend, x, ty, contexts...; strict) dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x))) - pushforward_prep = prepare_pushforward(f, backend, x, (dx,), contexts...; strict) + pushforward_prep = prepare_pushforward(strict, f, backend, x, (dx,), contexts...) return PushforwardPullbackPrep(_sig, pushforward_prep) end function _prepare_pullback_aux( + strict::Val, ::PullbackSlow, f!::F, y, @@ -140,11 +140,10 @@ function _prepare_pullback_aux( x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Val, ) where {F,C} _sig = signature(f!, y, backend, x, ty, contexts...; strict) dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x))) - pushforward_prep = prepare_pushforward(f!, y, backend, x, (dx,), contexts...; strict) + pushforward_prep = prepare_pushforward(strict, f!, y, backend, x, (dx,), contexts...) return PushforwardPullbackPrep(_sig, pushforward_prep) end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 304208dc5..883681031 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -6,7 +6,9 @@ $(docstring_prepare("pushforward"; inplace=true)) """ -function prepare_pushforward end +function prepare_pushforward(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_pushforward(strict, args...) +end """ prepare!_pushforward(f, prep, backend, x, tx, [contexts...]) -> new_prep @@ -22,7 +24,9 @@ function prepare!_pushforward end $(docstring_prepare("pushforward"; samepoint=true, inplace=true)) """ -function prepare_pushforward_same_point end +function prepare_pushforward_same_point(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_pushforward_same_point(strict, args...) 
+end """ value_and_pushforward(f, [prep,] backend, x, tx, [contexts...]) -> (y, ty) @@ -91,49 +95,45 @@ struct PullbackPushforwardPrep{SIG,E} <: PushforwardPrep{SIG} end function prepare_pushforward( - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}; ) where {F,C} return _prepare_pushforward_aux( - pushforward_performance(backend), f, backend, x, tx, contexts...; strict + strict, pushforward_performance(backend), f, backend, x, tx, contexts... ) end function prepare_pushforward( + strict::Val, f!::F, y, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), ) where {F,C} return _prepare_pushforward_aux( - pushforward_performance(backend), f!, y, backend, x, tx, contexts...; strict + strict, pushforward_performance(backend), f!, y, backend, x, tx, contexts... ) end function _prepare_pushforward_aux( + strict::Val, ::PushforwardSlow, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val, ) where {F,C} _sig = signature(f, backend, x, tx, contexts...; strict) y = f(x, map(unwrap, contexts)...) dy = y isa Number ? one(y) : basis(y, first(CartesianIndices(y))) - pullback_prep = prepare_pullback(f, backend, x, (dy,), contexts...; strict) + pullback_prep = prepare_pullback(strict, f, backend, x, (dy,), contexts...) return PullbackPushforwardPrep(_sig, pullback_prep) end function _prepare_pushforward_aux( + strict::Val, ::PushforwardSlow, f!::F, y, @@ -141,11 +141,10 @@ function _prepare_pushforward_aux( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val, ) where {F,C} _sig = signature(f!, y, backend, x, tx, contexts...; strict) dy = y isa Number ? 
one(y) : basis(y, first(CartesianIndices(y))) - pullback_prep = prepare_pullback(f!, y, backend, x, (dy,), contexts...; strict) + pullback_prep = prepare_pullback(strict, f!, y, backend, x, (dy,), contexts...) return PullbackPushforwardPrep(_sig, pullback_prep) end diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index 492fcae6d..35c21a71e 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -42,29 +42,29 @@ struct FromPrimitivePushforwardPrep{SIG,E<:PushforwardPrep} <: PushforwardPrep{S end function prepare_pushforward( + strict::Val, f::F, backend::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), ) where {F,C} _sig = signature(f, backend, x, tx, contexts...; strict) - primitive_prep = prepare_pushforward(f, backend.backend, x, tx, contexts...; strict) + primitive_prep = prepare_pushforward(strict, f, backend.backend, x, tx, contexts...) return FromPrimitivePushforwardPrep(_sig, primitive_prep) end function prepare_pushforward( + strict::Val, f!::F, y, backend::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), ) where {F,C} _sig = signature(f!, y, backend, x, tx, contexts...; strict) - primitive_prep = prepare_pushforward(f!, y, backend.backend, x, tx, contexts...; strict) + primitive_prep = prepare_pushforward(strict, f!, y, backend.backend, x, tx, contexts...) 
return FromPrimitivePushforwardPrep(_sig, primitive_prep) end @@ -159,29 +159,29 @@ struct FromPrimitivePullbackPrep{SIG,E<:PullbackPrep} <: PullbackPrep{SIG} end function prepare_pullback( + strict::Val, f::F, backend::AutoReverseFromPrimitive, x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), ) where {F,C} _sig = signature(f, backend, x, ty, contexts...; strict) - primitive_prep = prepare_pullback(f, backend.backend, x, ty, contexts...; strict) + primitive_prep = prepare_pullback(strict, f, backend.backend, x, ty, contexts...) return FromPrimitivePullbackPrep(_sig, primitive_prep) end function prepare_pullback( + strict::Val, f!::F, y, backend::AutoReverseFromPrimitive, x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), ) where {F,C} _sig = signature(f!, y, backend, x, ty, contexts...; strict) - primitive_prep = prepare_pullback(f!, y, backend.backend, x, ty, contexts...; strict) + primitive_prep = prepare_pullback(strict, f!, y, backend.backend, x, ty, contexts...) 
return FromPrimitivePullbackPrep(_sig, primitive_prep) end diff --git a/DifferentiationInterface/src/misc/simple_finite_diff.jl b/DifferentiationInterface/src/misc/simple_finite_diff.jl index 65cbd3c81..bfa6656f4 100644 --- a/DifferentiationInterface/src/misc/simple_finite_diff.jl +++ b/DifferentiationInterface/src/misc/simple_finite_diff.jl @@ -37,25 +37,25 @@ function threshold_batchsize( end function prepare_pushforward( + strict::Val, f::F, backend::AutoSimpleFiniteDiff, x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), ) where {F,C} _sig = signature(f, backend, x, tx, contexts...; strict) return NoPushforwardPrep(_sig) end function prepare_pushforward( + strict::Val, f!::F, y, backend::AutoSimpleFiniteDiff, x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), ) where {F,C} _sig = signature(f!, y, backend, x, tx, contexts...; strict) return NoPushforwardPrep(_sig) diff --git a/DifferentiationInterface/src/misc/zero_backends.jl b/DifferentiationInterface/src/misc/zero_backends.jl index 94b64fa4c..1b3c6c0eb 100644 --- a/DifferentiationInterface/src/misc/zero_backends.jl +++ b/DifferentiationInterface/src/misc/zero_backends.jl @@ -21,25 +21,20 @@ check_available(::AutoZeroForward) = true inplace_support(::AutoZeroForward) = InPlaceSupported() function prepare_pushforward( - f::F, - backend::AutoZeroForward, - x, - tx::NTuple, - contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val, f::F, backend::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f, backend, x, tx, contexts...; strict) return NoPushforwardPrep(_sig) end function prepare_pushforward( + strict::Val, f!::F, y, backend::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), ) where {F,C} _sig = signature(f!, y, backend, x, tx, contexts...; strict) return NoPushforwardPrep(_sig) @@ -124,25 +119,20 @@ check_available(::AutoZeroReverse) = true inplace_support(::AutoZeroReverse) = 
InPlaceSupported() function prepare_pullback( - f::F, - backend::AutoZeroReverse, - x, - ty::NTuple, - contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val, f::F, backend::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f, backend, x, ty, contexts...; strict) return NoPullbackPrep(_sig) end function prepare_pullback( + strict::Val, f!::F, y, backend::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C}; - strict::Val=Val(false), ) where {F,C} _sig = signature(f!, y, backend, x, ty, contexts...; strict) return NoPullbackPrep(_sig) diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 2d9e8c5ac..290adb525 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -5,7 +5,9 @@ $(docstring_prepare("hessian")) """ -function prepare_hessian end +function prepare_hessian(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_hessian(strict, args...) +end """ prepare!_hessian(f, backend, x, [contexts...]) -> new_prep @@ -69,21 +71,21 @@ struct HVPGradientHessianPrep{ end function prepare_hessian( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) + strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} # type-unstable batch_size_settings = pick_batchsize(outer(backend), x) # function barrier - return _prepare_hessian_aux(batch_size_settings, f, backend, x, contexts...; strict) + return _prepare_hessian_aux(strict, batch_size_settings, f, backend, x, contexts...) 
end function _prepare_hessian_aux( + strict::Val, batch_size_settings::BatchSizeSettings{B}, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; - strict::Val, ) where {B,F,C} _sig = signature(f, backend, x, contexts...; strict) (; N, A) = batch_size_settings @@ -92,8 +94,8 @@ function _prepare_hessian_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - hvp_prep = prepare_hvp(f, backend, x, batched_seeds[1], contexts...; strict) - gradient_prep = prepare_gradient(f, inner(backend), x, contexts...; strict) + hvp_prep = prepare_hvp(strict, f, backend, x, batched_seeds[1], contexts...) + gradient_prep = prepare_gradient(strict, f, inner(backend), x, contexts...) return HVPGradientHessianPrep( _sig, batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep ) diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 3a13c6bf0..8caeb2b64 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -5,7 +5,9 @@ $(docstring_prepare("hvp")) """ -function prepare_hvp end +function prepare_hvp(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_hvp(strict, args...) +end """ prepare!_hvp(f, backend, x, tx, [contexts...]) -> new_prep @@ -19,7 +21,9 @@ function prepare!_hvp end $(docstring_prepare("hvp"; samepoint=true)) """ -function prepare_hvp_same_point end +function prepare_hvp_same_point(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_hvp_same_point(strict, args...) +end """ hvp(f, [prep,] backend, x, tx, [contexts...]) -> tg @@ -58,14 +62,10 @@ $(docstring_preparation_hint("hvp"; same_point=true)) function gradient_and_hvp! 
end function prepare_hvp( - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}; - strict::Val=Val(false), + strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}; ) where {F,C} return _prepare_hvp_aux( + strict, hvp_mode(backend), inner_preparation_behavior(outer(backend)), f, @@ -73,7 +73,6 @@ function prepare_hvp( x, tx, contexts...; - strict, ) end @@ -90,6 +89,7 @@ struct ForwardOverAnythingHVPPrep{SIG,G,GO,GI,PO,PI} <: HVPPrep{SIG} end function _prepare_hvp_aux( + strict::Val, ::ForwardOverAnything, ::DontPrepareInner, f::F, @@ -97,7 +97,6 @@ function _prepare_hvp_aux( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val, ) where {F,C} _sig = signature(f, backend, x, tx, contexts...; strict) grad_buffer = similar(x) @@ -107,17 +106,17 @@ function _prepare_hvp_aux( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts...; strict + strict, shuffled_gradient, outer(backend), x, tx, new_contexts... ) outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported prepare_pushforward( + strict, shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts...; - strict, ) else nothing @@ -128,6 +127,7 @@ function _prepare_hvp_aux( end function _prepare_hvp_aux( + strict::Val, ::ForwardOverAnything, ::PrepareInnerSimple, f::F, @@ -135,13 +135,12 @@ function _prepare_hvp_aux( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val, ) where {F,C} _sig = signature(f, backend, x, tx, contexts...; strict) grad_buffer = similar(x) rewrap = Rewrap(contexts...) # Inner gradient - inner_gradient_prep = prepare_gradient(f, inner(backend), x, contexts...; strict) + inner_gradient_prep = prepare_gradient(strict, f, inner(backend), x, contexts...) 
inner_gradient_in_prep = inner_gradient_prep # Outer pushforward new_contexts = ( @@ -159,17 +158,17 @@ function _prepare_hvp_aux( contexts..., ) outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts...; strict + strict, shuffled_gradient, outer(backend), x, tx, new_contexts... ) outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported prepare_pushforward( + strict, shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts_in...; - strict, ) else nothing @@ -185,6 +184,7 @@ function _prepare_hvp_aux( end function _prepare_hvp_aux( + strict::Val, ::ForwardOverAnything, ::PrepareInnerOverload, f::F, @@ -192,7 +192,6 @@ function _prepare_hvp_aux( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val, ) where {F,C} _sig = signature(f, backend, x, tx, contexts...; strict) grad_buffer = similar(x) @@ -204,8 +203,8 @@ function _prepare_hvp_aux( ) contextso = adapt_eltype.(contexts, Ref(eltype(xo))) contextsoi = adapt_eltype.(contexts, Ref(eltype(xoi))) - inner_gradient_prep = prepare_gradient(f, inner(backend), xo, contextso...; strict) - inner_gradient_in_prep = prepare_gradient(f, inner(backend), xoi, contextsoi...; strict) + inner_gradient_prep = prepare_gradient(strict, f, inner(backend), xo, contextso...) + inner_gradient_in_prep = prepare_gradient(strict, f, inner(backend), xoi, contextsoi...) # Outer pushforward new_contexts = ( FunctionContext(f), @@ -222,17 +221,17 @@ function _prepare_hvp_aux( contexts..., ) outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts...; strict + strict, shuffled_gradient, outer(backend), x, tx, new_contexts... 
) outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported prepare_pushforward( + strict, shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts_in...; - strict, ) else nothing @@ -460,6 +459,7 @@ struct ReverseOverForwardHVPPrep{SIG,G2<:GradientPrep,G1<:GradientPrep} <: HVPPr end function _prepare_hvp_aux( + strict::Val, ::ReverseOverForward, ::InnerPreparationBehavior, f::F, @@ -467,7 +467,6 @@ function _prepare_hvp_aux( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val, ) where {F,C} _sig = signature(f, backend, x, tx, contexts...; strict) rewrap = Rewrap(contexts...) @@ -479,9 +478,9 @@ function _prepare_hvp_aux( contexts..., ) outer_gradient_prep = prepare_gradient( - shuffled_single_pushforward, outer(backend), x, new_contexts...; strict + strict, shuffled_single_pushforward, outer(backend), x, new_contexts... ) - gradient_prep = prepare_gradient(f, inner(backend), x, contexts...; strict) + gradient_prep = prepare_gradient(strict, f, inner(backend), x, contexts...) return ReverseOverForwardHVPPrep(_sig, outer_gradient_prep, gradient_prep) end @@ -582,6 +581,7 @@ struct ReverseOverReverseHVPPrep{SIG,G,PO,PI} <: HVPPrep{SIG} end function _prepare_hvp_aux( + strict::Val, ::ReverseOverReverse, ::InnerPreparationBehavior, f::F, @@ -589,7 +589,6 @@ function _prepare_hvp_aux( x, tx::NTuple, contexts::Vararg{Context,C}; - strict::Val, ) where {F,C} _sig = signature(f, backend, x, tx, contexts...; strict) rewrap = Rewrap(contexts...) @@ -598,17 +597,17 @@ function _prepare_hvp_aux( ) grad_buffer = similar(x) outer_pullback_prep = prepare_pullback( - shuffled_gradient, outer(backend), x, tx, new_contexts...; strict + strict, shuffled_gradient, outer(backend), x, tx, new_contexts... 
) outer_pullback_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported prepare_pullback( + strict, shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts...; - strict, ) else nothing diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index 72f4b8855..f26f335e6 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -5,7 +5,9 @@ $(docstring_prepare("second_derivative")) """ -function prepare_second_derivative end +function prepare_second_derivative(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_second_derivative(strict, args...) +end """ prepare!_second_derivative(f, prep, backend, x, [contexts...]) -> new_prep @@ -58,7 +60,7 @@ struct DerivativeSecondDerivativePrep{SIG,E<:DerivativePrep} <: SecondDerivative end function prepare_second_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val=Val(false) + strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} _sig = signature(f, backend, x, contexts...; strict) rewrap = Rewrap(contexts...) @@ -66,7 +68,7 @@ function prepare_second_derivative( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) outer_derivative_prep = prepare_derivative( - shuffled_derivative, outer(backend), x, new_contexts...; strict + strict, shuffled_derivative, outer(backend), x, new_contexts... 
) return DerivativeSecondDerivativePrep(_sig, outer_derivative_prep) end diff --git a/DifferentiationInterface/src/utils/prep.jl b/DifferentiationInterface/src/utils/prep.jl index 91450959c..cb3bb6ff6 100644 --- a/DifferentiationInterface/src/utils/prep.jl +++ b/DifferentiationInterface/src/utils/prep.jl @@ -103,7 +103,7 @@ function Base.showerror( end println( io, - "To disable this check (not recommended), run preparation with the keyword argument `strict=Val(false)` when using DifferentiationInterface.", + "If you are confident that this check is superfluous, you can disable it by running preparation with the keyword argument `strict=Val(false)` inside DifferentiationInterface.", ) return nothing end diff --git a/DifferentiationInterface/test/Core/ZeroBackends/test.jl b/DifferentiationInterface/test/Core/ZeroBackends/test.jl index 0571a73b3..6d56dabd4 100644 --- a/DifferentiationInterface/test/Core/ZeroBackends/test.jl +++ b/DifferentiationInterface/test/Core/ZeroBackends/test.jl @@ -17,24 +17,10 @@ for backend in zero_backends end @testset "Type stability" begin - test_differentiation( - AutoZeroForward(), - default_scenarios(; include_batchified=false, include_constantified=true); - correctness=false, - type_stability=:full, - logging=LOGGING, - ) - - test_differentiation( - AutoZeroReverse(), - default_scenarios(; include_batchified=false, include_constantified=true); - correctness=false, - type_stability=:full, - logging=LOGGING, - ) - test_differentiation( [ + AutoZeroForward(), + AutoZeroReverse(), SecondOrder(AutoZeroForward(), AutoZeroReverse()), SecondOrder(AutoZeroReverse(), AutoZeroForward()), ], From 534f98125fbfb8ff17b33969823bb5a607ffa539 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 17 Mar 2025 15:50:20 +0100 Subject: [PATCH 20/22] Fix --- .../twoarg.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git 
a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl index 1c7a8c92a..74d460cc3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl @@ -6,13 +6,13 @@ struct PolyesterForwardDiffTwoArgPushforwardPrep{SIG,P} <: DI.PushforwardPrep{SI end function DI.prepare_pushforward( + strict::Val, f!, y, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}; - strict::Val, ) where {C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) single_threaded_prep = DI.prepare_pushforward( @@ -91,11 +91,11 @@ struct PolyesterForwardDiffTwoArgDerivativePrep{SIG,P} <: DI.DerivativePrep{SIG} end function DI.prepare_derivative( - f!, y, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}; strict::Val + strict::Val, f!, y, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) single_threaded_prep = DI.prepare_derivative( - f!, y, single_threaded(backend), x, contexts... + strict, f!, y, single_threaded(backend), x, contexts... ) return PolyesterForwardDiffTwoArgDerivativePrep(_sig, single_threaded_prep) end @@ -167,12 +167,12 @@ struct PolyesterForwardDiffTwoArgJacobianPrep{SIG,chunksize,P} <: DI.JacobianPre end function DI.prepare_jacobian( + strict::Val, f!, y, backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C}; - strict::Val, ) where {chunksize,C} _sig = DI.signature(f!, y, backend, x, contexts...; strict) if isnothing(chunksize) @@ -181,7 +181,7 @@ function DI.prepare_jacobian( chunk = Chunk{chunksize}() end single_threaded_prep = DI.prepare_jacobian( - f!, y, single_threaded(backend), x, contexts... + strict, f!, y, single_threaded(backend), x, contexts... 
) return PolyesterForwardDiffTwoArgJacobianPrep(_sig, chunk, single_threaded_prep) end From 594d7209199cb00b6b3aa62be67f6665db50de49 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 17 Mar 2025 16:56:18 +0100 Subject: [PATCH 21/22] Last fix --- .../differentiate_with.jl | 2 +- .../reverse_onearg.jl | 3 +-- .../DifferentiationInterfaceTrackerExt.jl | 2 +- DifferentiationInterface/src/docstrings.jl | 3 +-- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl index f226298d7..be85d1f24 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl @@ -1,7 +1,7 @@ function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x) (; f, backend) = dw y = f(x) - prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,); strict=Val(true)) + prep_same = DI.prepare_pullback_same_point(Val(true), f, backend, x, (y,)) function pullbackfunc(dy) tx = DI.pullback(f, prep_same, backend, x, (dy,)) return (NoTangent(), only(tx)) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index 6d7f0d517..7c68b20e3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -19,7 +19,6 @@ function DI.prepare_pullback( end function DI.prepare_pullback_same_point( - strict, f, prep::DI.NoPullbackPrep, backend::AutoReverseChainRules, @@ -28,7 +27,7 @@ function DI.prepare_pullback_same_point( contexts::Vararg{DI.GeneralizedConstant,C}; ) 
where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) - _sig = DI.signature(f, backend, x, ty, contexts...; strict) + _sig = DI.signature(f, backend, x, ty, contexts...; strict=DI.is_strict(prep)) rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) return ChainRulesPullbackPrepSamePoint(_sig, y, pb) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index 38ecfa117..aaeae6cc7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -35,8 +35,8 @@ function DI.prepare_pullback_same_point( ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} - _sig = DI.signature(f, backend, x, ty, contexts...; strict=DI.is_strict(prep)) DI.check_prep(f, prep, backend, x, ty, contexts...) + _sig = DI.signature(f, backend, x, ty, contexts...; strict=DI.is_strict(prep)) y, pb = forward(f, x, map(DI.unwrap, contexts)...) return TrackerPullbackPrepSamePoint(_sig, y, pb) end diff --git a/DifferentiationInterface/src/docstrings.jl b/DifferentiationInterface/src/docstrings.jl index e2d1ee006..e14ccfb52 100644 --- a/DifferentiationInterface/src/docstrings.jl +++ b/DifferentiationInterface/src/docstrings.jl @@ -19,6 +19,7 @@ function docstring_prepare(operator; samepoint=false, inplace=false) Create a `prep` object that can be given to [`$(operator)`](@ref) and its variants to speed them up$(samepoint_warning(samepoint)). Depending on the backend, this can have several effects (preallocating memory, recording an execution trace) which are transparent to the user. + $(inplace ? "\nFor in-place functions, `y` is mutated by `f!` during preparation." : "") !!! 
warning The preparation result `prep` is only reusable as long as the arguments to `$operator` do not change type or size, and the function and backend themselves are not modified. @@ -26,8 +27,6 @@ function docstring_prepare(operator; samepoint=false, inplace=false) In some settings, invalid preparations may still give correct results (e.g. for backends that require no preparation), but this is not a semantic guarantee and should not be relied upon. When `strict=Val(true)`, type checking is enforced between preparation and execution (but size checking is left to the user). - - $(inplace ? "\nFor in-place functions, `y` is mutated by `f!` during preparation." : "") """ end From 25f91a4aa4fef1a2c33613047c13725e7a27129f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 17 Mar 2025 17:58:41 +0100 Subject: [PATCH 22/22] Versions --- .github/workflows/Test.yml | 2 +- DifferentiationInterface/Project.toml | 2 +- DifferentiationInterfaceTest/Project.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 13de31d96..f91ce0853 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -25,7 +25,7 @@ jobs: actions: write contents: read strategy: - fail-fast: false # TODO: toggle + fail-fast: true # TODO: toggle matrix: version: - "1.10" diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 5017e4ebf..e29a979c0 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.44" +version = "0.6.45" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 8c9c6fb84..729087caf 100644 --- 
a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterfaceTest" uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.9.5" +version = "0.9.6" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"