diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 5017e4ebf..e29a979c0 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.44" +version = "0.6.45" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl index 2eb8d2e93..be85d1f24 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl @@ -1,7 +1,7 @@ function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x) (; f, backend) = dw y = f(x) - prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,)) + prep_same = DI.prepare_pullback_same_point(Val(true), f, backend, x, (y,)) function pullbackfunc(dy) tx = DI.pullback(f, prep_same, backend, x, (dy,)) return (NoTangent(), only(tx)) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index 079e5dd1a..7c68b20e3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -1,37 +1,47 @@ ## Pullback -struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep +struct ChainRulesPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} + _sig::Val{SIG} y::Y pb::PB end function DI.prepare_pullback( - f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C} + strict::Val, + f, + 
backend::AutoReverseChainRules, + x, + ty::NTuple, + contexts::Vararg{DI.GeneralizedConstant,C}; ) where {C} - return DI.NoPullbackPrep() + _sig = DI.signature(f, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep(_sig) end function DI.prepare_pullback_same_point( f, - ::DI.NoPullbackPrep, + prep::DI.NoPullbackPrep, backend::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.GeneralizedConstant,C}, + contexts::Vararg{DI.GeneralizedConstant,C}; ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) + _sig = DI.signature(f, backend, x, ty, contexts...; strict=DI.is_strict(prep)) rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) - return ChainRulesPullbackPrepSamePoint(y, pb) + return ChainRulesPullbackPrepSamePoint(_sig, y, pb) end function DI.value_and_pullback( f, - ::DI.NoPullbackPrep, + prep::DI.NoPullbackPrep, backend::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) tx = map(ty) do dy @@ -43,11 +53,12 @@ end function DI.value_and_pullback( f, prep::ChainRulesPullbackPrepSamePoint, - ::AutoReverseChainRules, + backend::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) (; y, pb) = prep tx = map(ty) do dy unthunk(pb(dy)[2]) @@ -58,11 +69,12 @@ end function DI.pullback( f, prep::ChainRulesPullbackPrepSamePoint, - ::AutoReverseChainRules, + backend::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) 
(; pb) = prep tx = map(ty) do dy unthunk(pb(dy)[2]) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl index 2973f3b37..baab8a75f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl @@ -10,9 +10,15 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow() ## Pushforward -DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::NTuple) = DI.NoPushforwardPrep() +function DI.prepare_pushforward(strict::Val, f, backend::AutoDiffractor, x, tx::NTuple) + _sig = DI.signature(f, backend, x, tx; strict) + return DI.NoPushforwardPrep(_sig) +end -function DI.pushforward(f, ::DI.NoPushforwardPrep, ::AutoDiffractor, x, tx::NTuple) +function DI.pushforward( + f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple +) + DI.check_prep(f, prep, backend, x, tx) ty = map(tx) do dx # code copied from Diffractor.jl z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx)) @@ -25,6 +31,7 @@ end function DI.value_and_pushforward( f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple ) + DI.check_prep(f, prep, backend, x, tx) return f(x), DI.pushforward(f, prep, backend, x, tx) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 72c6df2df..96b2fff76 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -1,23 +1,26 @@ ## Pushforward function DI.prepare_pushforward( + strict::Val, f::F, - ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, + 
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; ) where {F,C} - return DI.NoPushforwardPrep() + _sig = DI.signature(f, backend, x, tx, contexts...; strict) + return DI.NoPushforwardPrep(_sig) end function DI.value_and_pushforward( f::F, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) mode = forward_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) dx = only(tx) @@ -29,12 +32,13 @@ end function DI.value_and_pushforward( f::F, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) mode = forward_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode, Val(B)) x_and_tx = BatchDuplicated(x, tx) @@ -45,12 +49,13 @@ end function DI.pushforward( f::F, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) mode = forward_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) dx = only(tx) @@ -62,12 +67,13 @@ end function DI.pushforward( f::F, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) mode = forward_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode, Val(B)) x_and_tx = BatchDuplicated(x, tx) @@ -85,6 +91,7 @@ function DI.value_and_pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) 
# dy cannot be passed anyway y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...) foreach(copyto!, ty, new_ty) @@ -100,6 +107,7 @@ function DI.pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) # dy cannot be passed anyway new_ty = DI.pushforward(f, prep, backend, x, tx, contexts...) foreach(copyto!, ty, new_ty) @@ -108,32 +116,33 @@ end ## Gradient -struct EnzymeForwardGradientPrep{B,O} <: DI.GradientPrep +struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG} + _sig::Val{SIG} + _valB::Val{B} shadows::O end -function EnzymeForwardGradientPrep(::Val{B}, shadows::O) where {B,O} - return EnzymeForwardGradientPrep{B,O}(shadows) -end - function DI.prepare_gradient( + strict::Val, f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.Constant,C}; ) where {F,C} + _sig = DI.signature(f, backend, x, contexts...; strict) valB = to_val(DI.pick_batchsize(backend, x)) shadows = create_shadows(valB, x) - return EnzymeForwardGradientPrep(valB, shadows) + return EnzymeForwardGradientPrep(_sig, valB, shadows) end function DI.gradient( f::F, - prep::EnzymeForwardGradientPrep{B}, + prep::EnzymeForwardGradientPrep{SIG,B}, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, -) where {F,B,C} +) where {F,SIG,B,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = forward_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(B), contexts...) @@ -145,11 +154,12 @@ end function DI.value_and_gradient( f::F, - prep::EnzymeForwardGradientPrep{B}, + prep::EnzymeForwardGradientPrep{SIG,B}, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, -) where {F,B,C} +) where {F,SIG,B,C} + DI.check_prep(f, prep, backend, x, contexts...) 
mode = forward_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(B), contexts...) @@ -162,58 +172,59 @@ end function DI.gradient!( f::F, grad, - prep::EnzymeForwardGradientPrep{B}, + prep::EnzymeForwardGradientPrep{SIG,B}, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, -) where {F,B,C} +) where {F,SIG,B,C} + DI.check_prep(f, prep, backend, x, contexts...) return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end function DI.value_and_gradient!( f::F, grad, - prep::EnzymeForwardGradientPrep{B}, + prep::EnzymeForwardGradientPrep{SIG,B}, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, -) where {F,B,C} +) where {F,SIG,B,C} + DI.check_prep(f, prep, backend, x, contexts...) y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) end ## Jacobian -struct EnzymeForwardOneArgJacobianPrep{B,O} <: DI.JacobianPrep +struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} + _valB::Val{B} shadows::O output_length::Int end -function EnzymeForwardOneArgJacobianPrep( - ::Val{B}, shadows::O, output_length::Integer -) where {B,O} - return EnzymeForwardOneArgJacobianPrep{B,O}(shadows, output_length) -end - function DI.prepare_jacobian( + strict::Val, f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.Constant,C}; ) where {F,C} + _sig = DI.signature(f, backend, x, contexts...; strict) y = f(x, map(DI.unwrap, contexts)...) 
valB = to_val(DI.pick_batchsize(backend, x)) shadows = create_shadows(valB, x) - return EnzymeForwardOneArgJacobianPrep(valB, shadows, length(y)) + return EnzymeForwardOneArgJacobianPrep(_sig, valB, shadows, length(y)) end function DI.jacobian( f::F, - prep::EnzymeForwardOneArgJacobianPrep{B}, + prep::EnzymeForwardOneArgJacobianPrep{SIG,B}, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, -) where {F,B,C} +) where {F,SIG,B,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = forward_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(B), contexts...) @@ -226,11 +237,12 @@ end function DI.value_and_jacobian( f::F, - prep::EnzymeForwardOneArgJacobianPrep{B}, + prep::EnzymeForwardOneArgJacobianPrep{SIG,B}, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, -) where {F,B,C} +) where {F,SIG,B,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = forward_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(B), contexts...) @@ -249,6 +261,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Constant,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...)) end @@ -260,6 +273,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Constant,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...) 
return y, copyto!(jac, new_jac) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index 5055a23af..33e5ce1aa 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -1,25 +1,28 @@ ## Pushforward function DI.prepare_pushforward( + strict::Val, f!::F, y, - ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, + backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; ) where {F,C} - return DI.NoPushforwardPrep() + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) + return DI.NoPushforwardPrep(_sig) end function DI.value_and_pushforward( f!::F, y, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) mode = forward_noprimal(backend) f!_and_df! = get_f_and_df(f!, backend, mode) dx = only(tx) @@ -34,12 +37,13 @@ end function DI.value_and_pushforward( f!::F, y, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) mode = forward_noprimal(backend) f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) ty = ntuple(_ -> make_zero(y), Val(B)) @@ -59,6 +63,7 @@ function DI.pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) _, ty = DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...) 
return ty end @@ -67,12 +72,13 @@ function DI.value_and_pushforward!( f!::F, y, ty::NTuple{B}, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) mode = forward_noprimal(backend) f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) x_and_tx = BatchDuplicated(x, tx) @@ -92,6 +98,7 @@ function DI.pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) DI.value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) return ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index a6ceda51a..205a86837 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -47,19 +47,22 @@ end ## Pullback -struct EnzymeReverseOneArgPullbackPrep{Y} <: DI.PullbackPrep +struct EnzymeReverseOneArgPullbackPrep{SIG,Y} <: DI.PullbackPrep{SIG} + _sig::Val{SIG} y_example::Y # useful to create return activity end function DI.prepare_pullback( + strict::Val, f::F, - ::AutoEnzyme{<:Union{ReverseMode,Nothing}}, + backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ty::NTuple, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; ) where {F,C} + _sig = DI.signature(f, backend, x, ty, contexts...; strict) y = f(x, map(DI.unwrap, contexts)...) - return EnzymeReverseOneArgPullbackPrep(y) + return EnzymeReverseOneArgPullbackPrep(_sig, y) end ### Out-of-place @@ -72,6 +75,7 @@ function DI.value_and_pullback( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) 
mode = reverse_split_withprimal(backend) f_and_df = force_annotation(get_f_and_df(f, backend, mode)) IA = guess_activity(typeof(x), mode) @@ -97,6 +101,7 @@ function DI.value_and_pullback( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) mode = reverse_split_withprimal(backend) f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B))) IA = batchify_activity(guess_activity(typeof(x), mode), Val(B)) @@ -122,6 +127,7 @@ function DI.pullback( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) return last(DI.value_and_pullback(f, prep, backend, x, ty, contexts...)) end @@ -136,6 +142,7 @@ function DI.value_and_pullback!( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) mode = reverse_split_withprimal(backend) f_and_df = force_annotation(get_f_and_df(f, backend, mode)) RA = guess_activity(typeof(prep.y_example), mode) @@ -157,6 +164,7 @@ function DI.value_and_pullback!( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) mode = reverse_split_withprimal(backend) f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B))) RA = batchify_activity(guess_activity(typeof(prep.y_example), mode), Val(B)) @@ -177,26 +185,33 @@ function DI.pullback!( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) 
return last(DI.value_and_pullback!(f, tx, prep, backend, x, ty, contexts...)) end ## Gradient function DI.prepare_gradient( - f::F, ::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C} + strict::Val, + f::F, + backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, + x, + contexts::Vararg{DI.Context,C}; ) where {F,C} - return DI.NoGradientPrep() + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoGradientPrep(_sig) end ### Enzyme gradient API (only constants) function DI.gradient( f::F, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = reverse_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(1), contexts...) @@ -206,11 +221,12 @@ end function DI.value_and_gradient( f::F, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = reverse_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(1), contexts...) 
@@ -221,10 +237,11 @@ end function DI.gradient!( f::F, grad, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, x, ) where {F} + DI.check_prep(f, prep, backend, x) mode = reverse_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) gradient!(mode, grad, f_and_df, x) @@ -234,10 +251,11 @@ end function DI.value_and_gradient!( f::F, grad, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, x, ) where {F} + DI.check_prep(f, prep, backend, x) mode = reverse_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) _, result = gradient!(mode, grad, f_and_df, x) @@ -248,11 +266,12 @@ end function DI.gradient( f::F, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = reverse_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) IA = guess_activity(typeof(x), mode) @@ -271,11 +290,12 @@ end function DI.value_and_gradient( f::F, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = reverse_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) IA = guess_activity(typeof(x), mode) @@ -295,11 +315,12 @@ end function DI.gradient!( f::F, grad, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = reverse_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(1), contexts...) 
@@ -311,11 +332,12 @@ end function DI.value_and_gradient!( f::F, grad, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) mode = reverse_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) annotated_contexts = translate(backend, mode, Val(1), contexts...) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index 091b695e4..30562fb7f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -1,18 +1,22 @@ ## Pullback -struct EnzymeReverseTwoArgPullbackPrep{TY} <: DI.PullbackPrep +struct EnzymeReverseTwoArgPullbackPrep{SIG,TY} <: DI.PullbackPrep{SIG} + _sig::Val{SIG} ty_copy::TY end function DI.prepare_pullback( + strict::Val, f!::F, y, - ::AutoEnzyme{<:Union{ReverseMode,Nothing}}, + backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, ty::NTuple, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; ) where {F,C} - return EnzymeReverseTwoArgPullbackPrep(map(copy, ty)) + _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) + ty_copy = map(copy, ty) + return EnzymeReverseTwoArgPullbackPrep(_sig, ty_copy) end function DI.value_and_pullback( @@ -24,6 +28,7 @@ function DI.value_and_pullback( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) copyto!(only(prep.ty_copy), only(ty)) mode = reverse_noprimal(backend) f!_and_df! = get_f_and_df(f!, backend, mode) @@ -46,6 +51,7 @@ function DI.value_and_pullback( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) 
foreach(copyto!, prep.ty_copy, ty) mode = reverse_noprimal(backend) f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) @@ -68,6 +74,7 @@ function DI.value_and_pullback( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) copyto!(only(prep.ty_copy), only(ty)) mode = reverse_noprimal(backend) f!_and_df! = get_f_and_df(f!, backend, mode) @@ -89,6 +96,7 @@ function DI.value_and_pullback( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) foreach(copyto!, prep.ty_copy, ty) mode = reverse_noprimal(backend) f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) @@ -111,6 +119,7 @@ function DI.value_and_pullback!( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) copyto!(only(prep.ty_copy), only(ty)) mode = reverse_noprimal(backend) f!_and_df! = get_f_and_df(f!, backend, mode) @@ -134,6 +143,7 @@ function DI.value_and_pullback!( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) foreach(copyto!, prep.ty_copy, ty) mode = reverse_noprimal(backend) f!_and_df! 
= get_f_and_df(f!, backend, mode, Val(B)) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index c189d09d9..8b1550532 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -9,20 +9,20 @@ to_val(::DI.BatchSizeSettings{B}) where {B} = Val(B) ## Annotations @inline function get_f_and_df( - f::F, ::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1) + f::F, backend::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1) ) where {F,M,B} return f end @inline function get_f_and_df( - f::F, ::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1) + f::F, backend::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1) ) where {F,M,B} return Const(f) end @inline function get_f_and_df( f::F, - ::AutoEnzyme{ + backend::AutoEnzyme{ M, <:Union{ Duplicated, @@ -48,13 +48,13 @@ force_annotation(f::F) where {F<:Annotation} = f force_annotation(f::F) where {F} = Const(f) @inline function _translate( - ::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedConstant + backend::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedConstant ) where {B} return Const(DI.unwrap(c)) end @inline function _translate( - ::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedCache + backend::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedCache ) where {B} if B == 1 return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c))) @@ -100,8 +100,10 @@ function reverse_split_withprimal(backend::AutoEnzyme{Nothing}) return set_err(ReverseSplitWithPrimal, backend) end -set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode) -set_err(mode::Mode, ::AutoEnzyme{<:Any,<:Annotation}) = mode +function set_err(mode::Mode, backend::AutoEnzyme{<:Any,Nothing}) + return EnzymeCore.set_err_if_func_written(mode) +end +set_err(mode::Mode, backend::AutoEnzyme{<:Any,<:Annotation}) = mode function 
maybe_reshape(A::AbstractMatrix, m, n) @assert size(A) == (m, n) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index 6d19c4ce3..0ebfdb7dc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -1,14 +1,21 @@ ## Pushforward -struct FastDifferentiationOneArgPushforwardPrep{Y,E1,E1!} <: DI.PushforwardPrep +struct FastDifferentiationOneArgPushforwardPrep{SIG,Y,E1,E1!} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} y_prototype::Y jvp_exe::E1 jvp_exe!::E1! end function DI.prepare_pushforward( - f, ::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, + f, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f, backend, x, tx, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -23,17 +30,18 @@ function DI.prepare_pushforward( jvp_exe! = make_function( jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true ) - return FastDifferentiationOneArgPushforwardPrep(y_prototype, jvp_exe, jvp_exe!) + return FastDifferentiationOneArgPushforwardPrep(_sig, y_prototype, jvp_exe, jvp_exe!) end function DI.pushforward( f, prep::FastDifferentiationOneArgPushforwardPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) ty = map(tx) do dx result = prep.jvp_exe(myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) 
if prep.y_prototype isa Number @@ -49,11 +57,12 @@ function DI.pushforward!( f, ty::NTuple, prep::FastDifferentiationOneArgPushforwardPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] prep.jvp_exe!(myvec(dy), myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) @@ -69,6 +78,7 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pushforward(f, prep, backend, x, tx, contexts...) end @@ -82,20 +92,28 @@ function DI.value_and_pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pushforward!(f, ty, prep, backend, x, tx, contexts...) end ## Pullback -struct FastDifferentiationOneArgPullbackPrep{E1,E1!} <: DI.PullbackPrep +struct FastDifferentiationOneArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG} + _sig::Val{SIG} vjp_exe::E1 vjp_exe!::E1! end function DI.prepare_pullback( - f, ::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, + f, + backend::AutoFastDifferentiation, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f, backend, x, ty, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = f(x_var, context_vars...) @@ -110,17 +128,18 @@ function DI.prepare_pullback( vjp_exe! = make_function( vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true ) - return FastDifferentiationOneArgPullbackPrep(vjp_exe, vjp_exe!) + return FastDifferentiationOneArgPullbackPrep(_sig, vjp_exe, vjp_exe!) 
end function DI.pullback( f, prep::FastDifferentiationOneArgPullbackPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) tx = map(ty) do dy result = prep.vjp_exe(myvec(x), myvec(dy), map(myvec_unwrap, contexts)...) if x isa Number @@ -136,11 +155,12 @@ function DI.pullback!( f, tx::NTuple, prep::FastDifferentiationOneArgPullbackPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] prep.vjp_exe!(myvec(dx), myvec(x), myvec(dy), map(myvec_unwrap, contexts)...) @@ -156,6 +176,7 @@ function DI.value_and_pullback( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pullback(f, prep, backend, x, ty, contexts...) end @@ -169,21 +190,24 @@ function DI.value_and_pullback!( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pullback!(f, tx, prep, backend, x, ty, contexts...) end ## Derivative -struct FastDifferentiationOneArgDerivativePrep{Y,E1,E1!} <: DI.DerivativePrep +struct FastDifferentiationOneArgDerivativePrep{SIG,Y,E1,E1!} <: DI.DerivativePrep{SIG} + _sig::Val{SIG} y_prototype::Y der_exe::E1 der_exe!::E1! end function DI.prepare_derivative( - f, ::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) 
x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -195,16 +219,17 @@ function DI.prepare_derivative( der_vec_var = derivative(y_vec_var, x_var) der_exe = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=false) der_exe! = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=true) - return FastDifferentiationOneArgDerivativePrep(y_prototype, der_exe, der_exe!) + return FastDifferentiationOneArgDerivativePrep(_sig, y_prototype, der_exe, der_exe!) end function DI.derivative( f, prep::FastDifferentiationOneArgDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) result = prep.der_exe(myvec(x), map(myvec_unwrap, contexts)...) if prep.y_prototype isa Number return only(result) @@ -217,10 +242,11 @@ function DI.derivative!( f, der, prep::FastDifferentiationOneArgDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.der_exe!(myvec(der), myvec(x), map(myvec_unwrap, contexts)...) return der end @@ -232,6 +258,7 @@ function DI.value_and_derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.derivative(f, prep, backend, x, contexts...) end @@ -244,20 +271,23 @@ function DI.value_and_derivative!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.derivative!(f, der, prep, backend, x, contexts...) end ## Gradient -struct FastDifferentiationOneArgGradientPrep{E1,E1!} <: DI.GradientPrep +struct FastDifferentiationOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG} + _sig::Val{SIG} jac_exe::E1 jac_exe!::E1! 
end function DI.prepare_gradient( - f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = f(x_var, context_vars...) @@ -268,16 +298,17 @@ function DI.prepare_gradient( jac_var = jacobian(y_vec_var, x_vec_var) jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false) jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true) - return FastDifferentiationOneArgGradientPrep(jac_exe, jac_exe!) + return FastDifferentiationOneArgGradientPrep(_sig, jac_exe, jac_exe!) end function DI.gradient( f, prep::FastDifferentiationOneArgGradientPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) jac = prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) grad_vec = @view jac[1, :] return reshape(grad_vec, size(x)) @@ -287,10 +318,11 @@ function DI.gradient!( f, grad, prep::FastDifferentiationOneArgGradientPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.jac_exe!(reshape(grad, 1, length(grad)), myvec(x), map(myvec_unwrap, contexts)...) return grad end @@ -302,6 +334,7 @@ function DI.value_and_gradient( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.gradient(f, prep, backend, x, contexts...) end @@ -313,24 +346,28 @@ function DI.value_and_gradient!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.gradient!(f, grad, prep, backend, x, contexts...) 
end ## Jacobian -struct FastDifferentiationOneArgJacobianPrep{Y,E1,E1!} <: DI.JacobianPrep +struct FastDifferentiationOneArgJacobianPrep{SIG,Y,E1,E1!} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} y_prototype::Y jac_exe::E1 jac_exe!::E1! end function DI.prepare_jacobian( + strict::Val, f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -346,16 +383,17 @@ function DI.prepare_jacobian( end jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false) jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true) - return FastDifferentiationOneArgJacobianPrep(y_prototype, jac_exe, jac_exe!) + return FastDifferentiationOneArgJacobianPrep(_sig, y_prototype, jac_exe, jac_exe!) end function DI.jacobian( f, prep::FastDifferentiationOneArgJacobianPrep, - ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) end @@ -363,10 +401,11 @@ function DI.jacobian!( f, jac, prep::FastDifferentiationOneArgJacobianPrep, - ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.jac_exe!(jac, myvec(x), map(myvec_unwrap, contexts)...) return jac end @@ -378,6 +417,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...) end @@ -389,14 +429,16 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian!(f, jac, prep, backend, x, contexts...) end ## Second derivative -struct FastDifferentiationAllocatingSecondDerivativePrep{Y,D,E2,E2!} <: - DI.SecondDerivativePrep +struct FastDifferentiationAllocatingSecondDerivativePrep{SIG,Y,D,E2,E2!} <: + DI.SecondDerivativePrep{SIG} + _sig::Val{SIG} y_prototype::Y derivative_prep::D der2_exe::E2 @@ -404,8 +446,9 @@ struct FastDifferentiationAllocatingSecondDerivativePrep{Y,D,E2,E2!} <: end function DI.prepare_second_derivative( - f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) y_prototype = f(x, map(DI.unwrap, contexts)...) x_var = variablize(x, :x) context_vars = variablize(contexts) @@ -421,17 +464,18 @@ function DI.prepare_second_derivative( derivative_prep = DI.prepare_derivative(f, backend, x, contexts...) return FastDifferentiationAllocatingSecondDerivativePrep( - y_prototype, derivative_prep, der2_exe, der2_exe! + _sig, y_prototype, derivative_prep, der2_exe, der2_exe! ) end function DI.second_derivative( f, prep::FastDifferentiationAllocatingSecondDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) result = prep.der2_exe(myvec(x), map(myvec_unwrap, contexts)...) 
if prep.y_prototype isa Number return only(result) @@ -444,10 +488,11 @@ function DI.second_derivative!( f, der2, prep::FastDifferentiationAllocatingSecondDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.der2_exe!(myvec(der2), myvec(x), map(myvec_unwrap, contexts)...) return der2 end @@ -459,6 +504,7 @@ function DI.value_derivative_and_second_derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, der = DI.value_and_derivative(f, prep.derivative_prep, backend, x, contexts...) der2 = DI.second_derivative(f, prep, backend, x, contexts...) return y, der, der2 @@ -473,6 +519,7 @@ function DI.value_derivative_and_second_derivative!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_derivative!(f, der, prep.derivative_prep, backend, x, contexts...) DI.second_derivative!(f, der2, prep, backend, x, contexts...) return y, der, der2 @@ -480,15 +527,22 @@ end ## HVP -struct FastDifferentiationHVPPrep{E2,E2!,E1} <: DI.HVPPrep +struct FastDifferentiationHVPPrep{SIG,E2,E2!,E1} <: DI.HVPPrep{SIG} + _sig::Val{SIG} hvp_exe::E2 hvp_exe!::E2! gradient_prep::E1 end function DI.prepare_hvp( - f, backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, + f, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f, backend, x, tx, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = f(x_var, context_vars...) @@ -504,17 +558,18 @@ function DI.prepare_hvp( ) gradient_prep = DI.prepare_gradient(f, backend, x, contexts...)
- return FastDifferentiationHVPPrep(hvp_exe, hvp_exe!, gradient_prep) + return FastDifferentiationHVPPrep(_sig, hvp_exe, hvp_exe!, gradient_prep) end function DI.hvp( f, prep::FastDifferentiationHVPPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) tg = map(tx) do dx dg_vec = prep.hvp_exe(myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) return reshape(dg_vec, size(x)) @@ -526,11 +581,12 @@ function DI.hvp!( f, tg::NTuple, prep::FastDifferentiationHVPPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) for b in eachindex(tx, tg) dx, dg = tx[b], tg[b] prep.hvp_exe!(myvec(dg), myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) @@ -546,6 +602,7 @@ function DI.gradient_and_hvp( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) tg = DI.hvp(f, prep, backend, x, tx, contexts...) grad = DI.gradient(f, prep.gradient_prep, backend, x, contexts...) return grad, tg @@ -561,6 +618,7 @@ function DI.gradient_and_hvp!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) DI.hvp!(f, tg, prep, backend, x, tx, contexts...) DI.gradient!(f, grad, prep.gradient_prep, backend, x, contexts...) return grad, tg @@ -568,18 +626,21 @@ end ## Hessian -struct FastDifferentiationHessianPrep{G,E2,E2!} <: DI.HessianPrep +struct FastDifferentiationHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG} + _sig::Val{SIG} gradient_prep::G hess_exe::E2 hess_exe!::E2! 
end function DI.prepare_hessian( + strict::Val, f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = f(x_var, context_vars...) @@ -596,7 +657,7 @@ function DI.prepare_hessian( hess_exe! = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=true) gradient_prep = DI.prepare_gradient(f, dense_ad(backend), x, contexts...) - return FastDifferentiationHessianPrep(gradient_prep, hess_exe, hess_exe!) + return FastDifferentiationHessianPrep(_sig, gradient_prep, hess_exe, hess_exe!) end function DI.hessian( @@ -606,6 +667,7 @@ function DI.hessian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return prep.hess_exe(myvec(x), map(myvec_unwrap, contexts)...) end @@ -617,6 +679,7 @@ function DI.hessian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.hess_exe!(hess, myvec(x), map(myvec_unwrap, contexts)...) return hess end @@ -628,6 +691,7 @@ function DI.value_gradient_and_hessian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, grad = DI.value_and_gradient( f, prep.gradient_prep, dense_ad(backend), x, contexts... ) @@ -644,6 +708,7 @@ function DI.value_gradient_and_hessian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_gradient!( f, grad, prep.gradient_prep, dense_ad(backend), x, contexts... 
) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl index 6133213e3..f67ed4324 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl @@ -1,13 +1,21 @@ ## Pushforward -struct FastDifferentiationTwoArgPushforwardPrep{E1,E1!} <: DI.PushforwardPrep +struct FastDifferentiationTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} jvp_exe::E1 jvp_exe!::E1! end function DI.prepare_pushforward( - f!, y, ::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, + f!, + y, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = variablize(y, :y) @@ -23,18 +31,19 @@ function DI.prepare_pushforward( jvp_exe! = make_function( jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true ) - return FastDifferentiationTwoArgPushforwardPrep(jvp_exe, jvp_exe!) + return FastDifferentiationTwoArgPushforwardPrep(_sig, jvp_exe, jvp_exe!) end function DI.pushforward( f!, y, prep::FastDifferentiationTwoArgPushforwardPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) 
ty = map(tx) do dx reshape(prep.jvp_exe(myvec(x), myvec(dx), map(myvec_unwrap, contexts)...), size(y)) end @@ -46,11 +55,12 @@ function DI.pushforward!( y, ty::NTuple, prep::FastDifferentiationTwoArgPushforwardPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] prep.jvp_exe!(myvec(dy), myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) @@ -67,6 +77,7 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, ty @@ -82,6 +93,7 @@ function DI.value_and_pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, ty @@ -89,14 +101,22 @@ end ## Pullback -struct FastDifferentiationTwoArgPullbackPrep{E1,E1!} <: DI.PullbackPrep +struct FastDifferentiationTwoArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG} + _sig::Val{SIG} vjp_exe::E1 vjp_exe!::E1! end function DI.prepare_pullback( - f!, y, ::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, + f!, + y, + backend::AutoFastDifferentiation, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = variablize(y, :y) @@ -112,18 +132,19 @@ function DI.prepare_pullback( vjp_exe! = make_function( vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true ) - return FastDifferentiationTwoArgPullbackPrep(vjp_exe, vjp_exe!) 
+ return FastDifferentiationTwoArgPullbackPrep(_sig, vjp_exe, vjp_exe!) end function DI.pullback( f!, y, prep::FastDifferentiationTwoArgPullbackPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) tx = map(ty) do dy result = prep.vjp_exe(myvec(x), myvec(dy), map(myvec_unwrap, contexts)...) if x isa Number @@ -140,11 +161,12 @@ function DI.pullback!( y, tx::NTuple, prep::FastDifferentiationTwoArgPullbackPrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] prep.vjp_exe!(myvec(dx), myvec(x), myvec(dy), map(myvec_unwrap, contexts)...) @@ -161,6 +183,7 @@ function DI.value_and_pullback( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) tx = DI.pullback(f!, y, prep, backend, x, ty, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, tx @@ -176,6 +199,7 @@ function DI.value_and_pullback!( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) DI.pullback!(f!, y, tx, prep, backend, x, ty, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, tx @@ -183,14 +207,16 @@ end ## Derivative -struct FastDifferentiationTwoArgDerivativePrep{E1,E1!} <: DI.DerivativePrep +struct FastDifferentiationTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} + _sig::Val{SIG} der_exe::E1 der_exe!::E1! 
end function DI.prepare_derivative( - f!, y, ::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C} + strict::Val, f!, y, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = variablize(y, :y) @@ -202,17 +228,18 @@ function DI.prepare_derivative( der_vec_var = derivative(y_vec_var, x_var) der_exe = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=false) der_exe! = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=true) - return FastDifferentiationTwoArgDerivativePrep(der_exe, der_exe!) + return FastDifferentiationTwoArgDerivativePrep(_sig, der_exe, der_exe!) end function DI.value_and_derivative( f!, y, prep::FastDifferentiationTwoArgDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) der = reshape(prep.der_exe(myvec(x), map(myvec_unwrap, contexts)...), size(y)) return y, der @@ -223,10 +250,11 @@ function DI.value_and_derivative!( y, der, prep::FastDifferentiationTwoArgDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) prep.der_exe!(myvec(der), myvec(x), map(myvec_unwrap, contexts)...) return y, der @@ -236,10 +264,11 @@ function DI.derivative( f!, y, prep::FastDifferentiationTwoArgDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) 
der = reshape(prep.der_exe(myvec(x), map(myvec_unwrap, contexts)...), size(y)) return der end @@ -249,28 +278,32 @@ function DI.derivative!( y, der, prep::FastDifferentiationTwoArgDerivativePrep, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) prep.der_exe!(myvec(der), myvec(x), map(myvec_unwrap, contexts)...) return der end ## Jacobian -struct FastDifferentiationTwoArgJacobianPrep{E1,E1!} <: DI.JacobianPrep +struct FastDifferentiationTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} jac_exe::E1 jac_exe!::E1! end function DI.prepare_jacobian( + strict::Val, f!, y, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) y_var = variablize(y, :y) @@ -286,17 +319,18 @@ function DI.prepare_jacobian( end jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false) jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true) - return FastDifferentiationTwoArgJacobianPrep(jac_exe, jac_exe!) + return FastDifferentiationTwoArgJacobianPrep(_sig, jac_exe, jac_exe!) end function DI.value_and_jacobian( f!, y, prep::FastDifferentiationTwoArgJacobianPrep, - ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) jac = prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) 
return y, jac @@ -307,10 +341,11 @@ function DI.value_and_jacobian!( y, jac, prep::FastDifferentiationTwoArgJacobianPrep, - ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) prep.jac_exe!(jac, myvec(x), map(myvec_unwrap, contexts)...) return y, jac @@ -320,10 +355,11 @@ function DI.jacobian( f!, y, prep::FastDifferentiationTwoArgJacobianPrep, - ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) jac = prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) return jac end @@ -333,10 +369,11 @@ function DI.jacobian!( y, jac, prep::FastDifferentiationTwoArgJacobianPrep, - ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) prep.jac_exe!(jac, myvec(x), map(myvec_unwrap, contexts)...) 
return jac end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl index 05cf3a999..7ebdb6e49 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl @@ -1,6 +1,7 @@ ## Pushforward -struct FiniteDiffOneArgPushforwardPrep{C,R,A,D} <: DI.PushforwardPrep +struct FiniteDiffOneArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} cache::C relstep::R absstep::A @@ -8,8 +9,9 @@ struct FiniteDiffOneArgPushforwardPrep{C,R,A,D} <: DI.PushforwardPrep end function DI.prepare_pushforward( - f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f, backend, x, tx, contexts...; strict) fc = DI.with_contexts(f, contexts...) y = fc(x) cache = if x isa Number || y isa Number @@ -28,17 +30,18 @@ function DI.prepare_pushforward( backend.relstep end dir = backend.dir - return FiniteDiffOneArgPushforwardPrep(cache, relstep, absstep, dir) + return FiniteDiffOneArgPushforwardPrep(_sig, cache, relstep, absstep, dir) end function DI.pushforward( f, - prep::FiniteDiffOneArgPushforwardPrep{Nothing}, + prep::FiniteDiffOneArgPushforwardPrep{SIG,Nothing}, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep step(t::Number, dx) = f(x .+ t .* dx, map(DI.unwrap, contexts)...) 
ty = map(tx) do dx @@ -51,12 +54,13 @@ end function DI.value_and_pushforward( f, - prep::FiniteDiffOneArgPushforwardPrep{Nothing}, + prep::FiniteDiffOneArgPushforwardPrep{SIG,Nothing}, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep step(t::Number, dx) = f(x .+ t .* dx, map(DI.unwrap, contexts)...) y = f(x, map(DI.unwrap, contexts)...) @@ -77,12 +81,13 @@ end function DI.pushforward( f, - prep::FiniteDiffOneArgPushforwardPrep{<:JVPCache}, - ::AutoFiniteDiff, + prep::FiniteDiffOneArgPushforwardPrep{SIG,<:JVPCache}, + backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) ty = map(tx) do dx @@ -93,12 +98,13 @@ end function DI.value_and_pushforward( f, - prep::FiniteDiffOneArgPushforwardPrep{<:JVPCache}, - ::AutoFiniteDiff, + prep::FiniteDiffOneArgPushforwardPrep{SIG,<:JVPCache}, + backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) y = fc(x) @@ -110,7 +116,8 @@ end ## Derivative -struct FiniteDiffOneArgDerivativePrep{C,R,A,D} <: DI.DerivativePrep +struct FiniteDiffOneArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} + _sig::Val{SIG} cache::C relstep::R absstep::A @@ -118,8 +125,9 @@ struct FiniteDiffOneArgDerivativePrep{C,R,A,D} <: DI.DerivativePrep end function DI.prepare_derivative( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) 
y = fc(x) cache = if y isa Number @@ -139,18 +147,19 @@ function DI.prepare_derivative( backend.relstep end dir = backend.dir - return FiniteDiffOneArgDerivativePrep(cache, relstep, absstep, dir) + return FiniteDiffOneArgDerivativePrep(_sig, cache, relstep, absstep, dir) end ### Scalar to scalar function DI.derivative( f, - prep::FiniteDiffOneArgDerivativePrep{Nothing}, + prep::FiniteDiffOneArgDerivativePrep{SIG,Nothing}, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return finite_difference_derivative(fc, x, fdtype(backend); relstep, absstep, dir) @@ -158,11 +167,12 @@ end function DI.value_and_derivative( f, - prep::FiniteDiffOneArgDerivativePrep{Nothing}, + prep::FiniteDiffOneArgDerivativePrep{SIG,Nothing}, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) y = fc(x) @@ -178,11 +188,12 @@ end function DI.derivative( f, - prep::FiniteDiffOneArgDerivativePrep{<:GradientCache}, - ::AutoFiniteDiff, + prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache}, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir) @@ -191,11 +202,12 @@ end function DI.derivative!( f, der, - prep::FiniteDiffOneArgDerivativePrep{<:GradientCache}, - ::AutoFiniteDiff, + prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache}, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) 
return finite_difference_gradient!(der, fc, x, prep.cache; relstep, absstep, dir) @@ -203,11 +215,12 @@ end function DI.value_and_derivative( f, - prep::FiniteDiffOneArgDerivativePrep{<:GradientCache}, - ::AutoFiniteDiff, + prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache}, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) (; relstep, absstep, dir) = prep y = fc(x) @@ -217,11 +230,12 @@ end function DI.value_and_derivative!( f, der, - prep::FiniteDiffOneArgDerivativePrep{<:GradientCache}, - ::AutoFiniteDiff, + prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache}, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return ( @@ -231,7 +245,8 @@ end ## Gradient -struct FiniteDiffGradientPrep{C,R,A,D} <: DI.GradientPrep +struct FiniteDiffGradientPrep{SIG,C,R,A,D} <: DI.GradientPrep{SIG} + _sig::Val{SIG} cache::C relstep::R absstep::A @@ -239,8 +254,9 @@ struct FiniteDiffGradientPrep{C,R,A,D} <: DI.GradientPrep end function DI.prepare_gradient( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) y = fc(x) df = zero(y) .* x @@ -256,16 +272,17 @@ function DI.prepare_gradient( backend.relstep end dir = backend.dir - return FiniteDiffGradientPrep(cache, relstep, absstep, dir) + return FiniteDiffGradientPrep(_sig, cache, relstep, absstep, dir) end function DI.gradient( f, prep::FiniteDiffGradientPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x::AbstractArray, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
(; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir) @@ -274,10 +291,11 @@ end function DI.value_and_gradient( f, prep::FiniteDiffGradientPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x::AbstractArray, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return fc(x), finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir) @@ -287,10 +305,11 @@ function DI.gradient!( f, grad, prep::FiniteDiffGradientPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x::AbstractArray, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return finite_difference_gradient!(grad, fc, x, prep.cache; relstep, absstep, dir) @@ -300,10 +319,11 @@ function DI.value_and_gradient!( f, grad, prep::FiniteDiffGradientPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x::AbstractArray, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return ( @@ -313,7 +333,8 @@ end ## Jacobian -struct FiniteDiffOneArgJacobianPrep{C,R,A,D} <: DI.JacobianPrep +struct FiniteDiffOneArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} cache::C relstep::R absstep::A @@ -321,8 +342,9 @@ struct FiniteDiffOneArgJacobianPrep{C,R,A,D} <: DI.JacobianPrep end function DI.prepare_jacobian( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) 
y = fc(x) x1 = similar(x) @@ -340,16 +362,17 @@ function DI.prepare_jacobian( backend.relstep end dir = backend.dir - return FiniteDiffOneArgJacobianPrep(cache, relstep, absstep, dir) + return FiniteDiffOneArgJacobianPrep(_sig, cache, relstep, absstep, dir) end function DI.jacobian( f, prep::FiniteDiffOneArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return finite_difference_jacobian(fc, x, prep.cache; relstep, absstep, dir) @@ -358,10 +381,11 @@ end function DI.value_and_jacobian( f, prep::FiniteDiffOneArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) (; relstep, absstep, dir) = prep y = fc(x) @@ -372,10 +396,11 @@ function DI.jacobian!( f, jac, prep::FiniteDiffOneArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) return copyto!( @@ -390,10 +415,11 @@ function DI.value_and_jacobian!( f, jac, prep::FiniteDiffOneArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc = DI.with_contexts(f, contexts...) 
y = fc(x) @@ -410,7 +436,8 @@ end ## Hessian -struct FiniteDiffHessianPrep{C1,C2,RG,AG,RH,AH} <: DI.HessianPrep +struct FiniteDiffHessianPrep{SIG,C1,C2,RG,AG,RH,AH} <: DI.HessianPrep{SIG} + _sig::Val{SIG} gradient_cache::C1 hessian_cache::C2 relstep_g::RG @@ -420,8 +447,9 @@ struct FiniteDiffHessianPrep{C1,C2,RG,AG,RH,AH} <: DI.HessianPrep end function DI.prepare_hessian( - f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) fc = DI.with_contexts(f, contexts...) y = fc(x) df = zero(y) .* x @@ -440,13 +468,18 @@ function DI.prepare_hessian( absstep_g = isnothing(backend.absstep) ? relstep_g : backend.absstep absstep_h = isnothing(backend.absstep) ? relstep_h : backend.absstep return FiniteDiffHessianPrep( - gradient_cache, hessian_cache, relstep_g, absstep_g, relstep_h, absstep_h + _sig, gradient_cache, hessian_cache, relstep_g, absstep_g, relstep_h, absstep_h ) end function DI.hessian( - f, prep::FiniteDiffHessianPrep, ::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} + f, + prep::FiniteDiffHessianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep_h, absstep_h) = prep fc = DI.with_contexts(f, contexts...) return finite_difference_hessian( @@ -458,10 +491,11 @@ function DI.hessian!( f, hess, prep::FiniteDiffHessianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep_h, absstep_h) = prep fc = DI.with_contexts(f, contexts...) 
return finite_difference_hessian!( @@ -470,8 +504,13 @@ function DI.hessian!( end function DI.value_gradient_and_hessian( - f, prep::FiniteDiffHessianPrep, ::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} + f, + prep::FiniteDiffHessianPrep, + backend::AutoFiniteDiff, + x, + contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep_g, absstep_g, relstep_h, absstep_h) = prep fc = DI.with_contexts(f, contexts...) grad = finite_difference_gradient( @@ -488,10 +527,11 @@ function DI.value_gradient_and_hessian!( grad, hess, prep::FiniteDiffHessianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; relstep_g, absstep_g, relstep_h, absstep_h) = prep fc = DI.with_contexts(f, contexts...) finite_difference_gradient!( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl index abd76ed68..259ebbdd3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl @@ -1,6 +1,7 @@ ## Pushforward -struct FiniteDiffTwoArgPushforwardPrep{C,R,A,D} <: DI.PushforwardPrep +struct FiniteDiffTwoArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} cache::C relstep::R absstep::A @@ -8,8 +9,15 @@ struct FiniteDiffTwoArgPushforwardPrep{C,R,A,D} <: DI.PushforwardPrep end function DI.prepare_pushforward( - f!, y, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, + f!, + y, + backend::AutoFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) cache = if x isa Number nothing else @@ -26,18 +34,19 @@ function DI.prepare_pushforward( backend.relstep end dir = backend.dir - return 
FiniteDiffTwoArgPushforwardPrep(cache, relstep, absstep, dir) + return FiniteDiffTwoArgPushforwardPrep(_sig, cache, relstep, absstep, dir) end function DI.value_and_pushforward( f!, y, - prep::FiniteDiffTwoArgPushforwardPrep{Nothing}, + prep::FiniteDiffTwoArgPushforwardPrep{SIG,Nothing}, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep function step(t::Number, dx) new_y = similar(y) @@ -63,12 +72,13 @@ end function DI.pushforward( f!, y, - prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache}, - ::AutoFiniteDiff, + prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache}, + backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) ty = map(tx) do dx @@ -82,12 +92,13 @@ end function DI.value_and_pushforward( f!, y, - prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache}, - ::AutoFiniteDiff, + prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache}, + backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) ty = map(tx) do dx @@ -103,12 +114,13 @@ function DI.pushforward!( f!, y, ty::NTuple, - prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache}, - ::AutoFiniteDiff, + prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache}, + backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) 
for b in eachindex(tx, ty) @@ -122,12 +134,13 @@ function DI.value_and_pushforward!( f!, y, ty::NTuple, - prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache}, - ::AutoFiniteDiff, + prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache}, + backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {C} +) where {SIG,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) for b in eachindex(tx, ty) @@ -140,7 +153,8 @@ end ## Derivative -struct FiniteDiffTwoArgDerivativePrep{C,R,A,D} <: DI.DerivativePrep +struct FiniteDiffTwoArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG} + _sig::Val{SIG} cache::C relstep::R absstep::A @@ -148,8 +162,9 @@ struct FiniteDiffTwoArgDerivativePrep{C,R,A,D} <: DI.DerivativePrep end function DI.prepare_derivative( - f!, y, backend::AutoFiniteDiff, x, ::Vararg{DI.Context,C} + strict::Val, f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f!, y, backend, x, contexts...; strict) df = similar(y) cache = GradientCache(df, x, fdtype(backend), eltype(y), FUNCTION_INPLACE) relstep = if isnothing(backend.relstep) @@ -163,7 +178,7 @@ function DI.prepare_derivative( backend.relstep end dir = backend.dir - return FiniteDiffTwoArgDerivativePrep(cache, relstep, absstep, dir) + return FiniteDiffTwoArgDerivativePrep(_sig, cache, relstep, absstep, dir) end function DI.prepare!_derivative( @@ -174,6 +189,7 @@ function DI.prepare!_derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, old_prep, backend, x, contexts...) if y isa Vector (; cache) = old_prep cache.fx isa Union{Number,Nothing} || resize!(cache.fx, length(y)) @@ -182,7 +198,7 @@ function DI.prepare!_derivative( cache.c3 isa Union{Number,Nothing} || resize!(cache.c3, length(y)) return old_prep else - return DI.prepare_derivative(f!, y, backend, x, contexts...) 
+ return DI.prepare_derivative(DI.is_strict(old_prep), f!, y, backend, x, contexts...) end end @@ -190,10 +206,11 @@ function DI.value_and_derivative( f!, y, prep::FiniteDiffTwoArgDerivativePrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) fc!(y, x) @@ -206,10 +223,11 @@ function DI.value_and_derivative!( y, der, prep::FiniteDiffTwoArgDerivativePrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) fc!(y, x) @@ -221,10 +239,11 @@ function DI.derivative( f!, y, prep::FiniteDiffTwoArgDerivativePrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) fc!(y, x) @@ -237,10 +256,11 @@ function DI.derivative!( y, der, prep::FiniteDiffTwoArgDerivativePrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) 
finite_difference_gradient!(der, fc!, x, prep.cache; relstep, absstep, dir) @@ -249,7 +269,8 @@ end ## Jacobian -struct FiniteDiffTwoArgJacobianPrep{C,R,A,D} <: DI.JacobianPrep +struct FiniteDiffTwoArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} cache::C relstep::R absstep::A @@ -257,8 +278,9 @@ struct FiniteDiffTwoArgJacobianPrep{C,R,A,D} <: DI.JacobianPrep end function DI.prepare_jacobian( - f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f!, y, backend, x, contexts...; strict) x1 = similar(x) fx = similar(y) fx1 = similar(y) @@ -274,7 +296,7 @@ function DI.prepare_jacobian( backend.relstep end dir = backend.dir - return FiniteDiffTwoArgJacobianPrep(cache, relstep, absstep, dir) + return FiniteDiffTwoArgJacobianPrep(_sig, cache, relstep, absstep, dir) end function DI.prepare!_jacobian( @@ -285,6 +307,7 @@ function DI.prepare!_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, old_prep, backend, x, contexts...) if x isa Vector && y isa Vector (; cache) = old_prep cache.x1 isa Union{Number,Nothing} || resize!(cache.x1, length(x)) @@ -295,7 +318,7 @@ function DI.prepare!_jacobian( cache.sparsity = nothing return old_prep else - return DI.prepare_jacobian(f!, y, backend, x, contexts...) + return DI.prepare_jacobian(DI.is_strict(old_prep), f!, y, backend, x, contexts...) end end @@ -303,10 +326,11 @@ function DI.value_and_jacobian( f!, y, prep::FiniteDiffTwoArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) 
jac = similar(y, length(y), length(x)) @@ -320,10 +344,11 @@ function DI.value_and_jacobian!( y, jac, prep::FiniteDiffTwoArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir) @@ -335,10 +360,11 @@ function DI.jacobian( f!, y, prep::FiniteDiffTwoArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) jac = similar(y, length(y), length(x)) @@ -351,10 +377,11 @@ function DI.jacobian!( y, jac, prep::FiniteDiffTwoArgJacobianPrep, - ::AutoFiniteDiff, + backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep fc! = DI.with_contexts(f!, contexts...) 
finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl index f2f692002..31bcd5961 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl @@ -12,19 +12,26 @@ DI.inner_preparation_behavior(::AutoFiniteDifferences) = DI.PrepareInnerSimple() ## Pushforward function DI.prepare_pushforward( - f, ::AutoFiniteDifferences, x, tx::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, + f, + backend::AutoFiniteDifferences, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; ) where {C} - return DI.NoPushforwardPrep() + _sig = DI.signature(f, backend, x, tx, contexts...; strict) + return DI.NoPushforwardPrep(_sig) end function DI.pushforward( f, - ::DI.NoPushforwardPrep, + prep::DI.NoPushforwardPrep, backend::AutoFiniteDifferences, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) fc = DI.with_contexts(f, contexts...) ty = map(tx) do dx jvp(backend.fdm, fc, (x, dx)) @@ -40,6 +47,7 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pushforward(f, prep, backend, x, tx, contexts...) 
end @@ -47,19 +55,26 @@ end ## Pullback function DI.prepare_pullback( - f, ::AutoFiniteDifferences, x, ty::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, + f, + backend::AutoFiniteDifferences, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}; ) where {C} - return DI.NoPullbackPrep() + _sig = DI.signature(f, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep(_sig) end function DI.pullback( f, - ::DI.NoPullbackPrep, + prep::DI.NoPullbackPrep, backend::AutoFiniteDifferences, x, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) fc = DI.with_contexts(f, contexts...) tx = map(ty) do dy only(j′vp(backend.fdm, fc, dy, x)) @@ -75,6 +90,7 @@ function DI.value_and_pullback( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pullback(f, prep, backend, x, ty, contexts...) end @@ -82,18 +98,20 @@ end ## Gradient function DI.prepare_gradient( - f, ::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}; ) where {C} - return DI.NoGradientPrep() + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoGradientPrep(_sig) end function DI.gradient( f, - ::DI.NoGradientPrep, + prep::DI.NoGradientPrep, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) return only(grad(backend.fdm, fc, x)) end @@ -105,6 +123,7 @@ function DI.value_and_gradient( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.gradient(f, prep, backend, x, contexts...) end @@ -116,6 +135,7 @@ function DI.gradient!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end @@ -127,6 +147,7 @@ function DI.value_and_gradient!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) end @@ -134,18 +155,20 @@ end ## Jacobian function DI.prepare_jacobian( - f, ::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}; ) where {C} - return DI.NoJacobianPrep() + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoJacobianPrep(_sig) end function DI.jacobian( f, - ::DI.NoJacobianPrep, + prep::DI.NoJacobianPrep, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) return only(jacobian(backend.fdm, fc, x)) end @@ -157,6 +180,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...) end @@ -168,6 +192,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...)) end @@ -179,6 +204,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...) 
return y, copyto!(jac, new_jac) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index 2e46031a0..996c659c6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -33,7 +33,6 @@ DI.inner_preparation_behavior(::AutoForwardDiff) = DI.PrepareInnerOverload() include("utils.jl") include("onearg.jl") include("twoarg.jl") -# include("secondorder.jl") include("differentiate_with.jl") include("misc.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 43c0f44be..703fe6be6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -60,14 +60,22 @@ end ### Prepared -struct ForwardDiffOneArgPushforwardPrep{T,X,CD} <: DI.PushforwardPrep +struct ForwardDiffOneArgPushforwardPrep{SIG,T,X,CD} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} + _t::Type{T} xdual_tmp::X contexts_dual::CD end function DI.prepare_pushforward( - f::F, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C} + strict::Val, + f::F, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context,C}; ) where {F,B,C} + _sig = DI.signature(f, backend, x, tx, contexts...; strict) T = tag_type(f, backend, x) if DI.ismutable_array(x) xdual_tmp = make_dual_similar(T, x, tx) @@ -75,18 +83,16 @@ function DI.prepare_pushforward( xdual_tmp = nothing end contexts_dual = translate_toprep(Dual{T,eltype(x),B}, contexts) - return 
ForwardDiffOneArgPushforwardPrep{T,typeof(xdual_tmp),typeof(contexts_dual)}( - xdual_tmp, contexts_dual - ) + return ForwardDiffOneArgPushforwardPrep(_sig, T, xdual_tmp, contexts_dual) end function compute_ydual_onearg( f::F, - prep::ForwardDiffOneArgPushforwardPrep{T}, + prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, x::Number, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} +) where {F,SIG,T,B,C} xdual = make_dual(T, x, tx) contexts_dual = translate_prepared(contexts, prep.contexts_dual) ydual = f(xdual, contexts_dual...) @@ -95,11 +101,11 @@ end function compute_ydual_onearg( f::F, - prep::ForwardDiffOneArgPushforwardPrep{T}, + prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} +) where {F,SIG,T,B,C} if DI.ismutable_array(x) make_dual!(T, prep.xdual_tmp, x, tx) xdual_tmp = prep.xdual_tmp @@ -113,12 +119,13 @@ end function DI.value_and_pushforward( f::F, - prep::ForwardDiffOneArgPushforwardPrep{T}, - ::AutoForwardDiff, + prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, + backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} +) where {F,SIG,T,B,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) y = myvalue(T, ydual) ty = mypartials(T, Val(B), ydual) @@ -128,12 +135,13 @@ end function DI.value_and_pushforward!( f::F, ty::NTuple, - prep::ForwardDiffOneArgPushforwardPrep{T}, - ::AutoForwardDiff, + prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, + backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {F,T,C} +) where {F,SIG,T,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) 
y = myvalue(T, ydual) mypartials!(T, ty, ydual) @@ -142,12 +150,13 @@ end function DI.pushforward( f::F, - prep::ForwardDiffOneArgPushforwardPrep{T}, - ::AutoForwardDiff, + prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, + backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} +) where {F,SIG,T,B,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) ty = mypartials(T, Val(B), ydual) return ty @@ -156,12 +165,13 @@ end function DI.pushforward!( f::F, ty::NTuple, - prep::ForwardDiffOneArgPushforwardPrep{T}, - ::AutoForwardDiff, + prep::ForwardDiffOneArgPushforwardPrep{SIG,T}, + backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {F,T,C} +) where {F,SIG,T,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) ydual = compute_ydual_onearg(f, prep, x, tx, contexts...) mypartials!(T, ty, ydual) return ty @@ -169,7 +179,8 @@ end ## Derivative -struct ForwardDiffOneArgDerivativePrep{E} <: DI.DerivativePrep +struct ForwardDiffOneArgDerivativePrep{SIG,E} <: DI.DerivativePrep{SIG} + _sig::Val{SIG} pushforward_prep::E end @@ -205,10 +216,11 @@ end ### Prepared function DI.prepare_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} - pushforward_prep = DI.prepare_pushforward(f, backend, x, (one(x),), contexts...) - return ForwardDiffOneArgDerivativePrep(pushforward_prep) + _sig = DI.signature(f, backend, x, contexts...; strict) + pushforward_prep = DI.prepare_pushforward(strict, f, backend, x, (one(x),), contexts...) + return ForwardDiffOneArgDerivativePrep(_sig, pushforward_prep) end function DI.value_and_derivative( @@ -218,6 +230,7 @@ function DI.value_and_derivative( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) 
y, ty = DI.value_and_pushforward( f, prep.pushforward_prep, backend, x, (one(x),), contexts... ) @@ -232,6 +245,7 @@ function DI.value_and_derivative!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_pushforward!( f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts... ) @@ -245,6 +259,7 @@ function DI.derivative( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) return only( DI.pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...) ) @@ -258,6 +273,7 @@ function DI.derivative!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) DI.pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...) return der end @@ -338,19 +354,25 @@ end ### Prepared -struct ForwardDiffGradientPrep{C,CD} <: DI.GradientPrep +struct ForwardDiffGradientPrep{SIG,C,CD} <: DI.GradientPrep{SIG} + _sig::Val{SIG} config::C contexts_dual::CD end function DI.prepare_gradient( - f::F, backend::AutoForwardDiff, x::AbstractArray, contexts::Vararg{DI.Context,C} + strict::Val, + f::F, + backend::AutoForwardDiff, + x::AbstractArray, + contexts::Vararg{DI.Context,C}; ) where {F,C} + _sig = DI.signature(f, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f, backend, x) config = GradientConfig(nothing, x, chunk, tag) contexts_dual = translate_toprep(dual_type(config), contexts) - return ForwardDiffGradientPrep(config, contexts_dual) + return ForwardDiffGradientPrep(_sig, config, contexts_dual) end function DI.value_and_gradient!( @@ -361,6 +383,7 @@ function DI.value_and_gradient!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) 
result = DiffResult(zero(eltype(x)), (grad,)) @@ -381,6 +404,7 @@ function DI.value_and_gradient( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) result = GradientResult(x) @@ -400,6 +424,7 @@ function DI.gradient!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -416,6 +441,7 @@ function DI.gradient( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -500,19 +526,21 @@ end ### Prepared -struct ForwardDiffOneArgJacobianPrep{C,CD} <: DI.JacobianPrep +struct ForwardDiffOneArgJacobianPrep{SIG,C,CD} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} config::C contexts_dual::CD end function DI.prepare_jacobian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} + _sig = DI.signature(f, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f, backend, x) config = JacobianConfig(nothing, x, chunk, tag) contexts_dual = translate_toprep(dual_type(config), contexts) - return ForwardDiffOneArgJacobianPrep(config, contexts_dual) + return ForwardDiffOneArgJacobianPrep(_sig, config, contexts_dual) end function DI.value_and_jacobian!( @@ -523,6 +551,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) 
y = fc(x) @@ -544,6 +573,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -561,6 +591,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -577,6 +608,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -589,18 +621,20 @@ end ## Second derivative function DI.prepare_second_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} - return DI.NoSecondDerivativePrep() + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoSecondDerivativePrep(_sig) end function DI.second_derivative( f::F, - ::DI.NoSecondDerivativePrep, + prep::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) T2 = tag_type(f, backend, xdual) @@ -613,11 +647,12 @@ end function DI.second_derivative!( f::F, der2, - ::DI.NoSecondDerivativePrep, + prep::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) 
T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) T2 = tag_type(f, backend, xdual) @@ -629,11 +664,12 @@ end function DI.value_derivative_and_second_derivative( f::F, - ::DI.NoSecondDerivativePrep, + prep::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) T2 = tag_type(f, backend, xdual) @@ -650,11 +686,12 @@ function DI.value_derivative_and_second_derivative!( f::F, der, der2, - ::DI.NoSecondDerivativePrep, + prep::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) T2 = tag_type(f, backend, xdual) @@ -667,67 +704,6 @@ function DI.value_derivative_and_second_derivative!( return y, der, der2 end -## HVP - -#= -function DI.prepare_hvp( - f::F, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C} -) where {F,C} - return DI.prepare_hvp(f, DI.SecondOrder(backend, backend), x, tx, contexts...) -end - -function DI.hvp( - f::F, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - return DI.hvp(f, prep, DI.SecondOrder(backend, backend), x, tx, contexts...) -end - -function DI.hvp!( - f::F, - tg::NTuple, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - return DI.hvp!(f, tg, prep, DI.SecondOrder(backend, backend), x, tx, contexts...) -end - -function DI.gradient_and_hvp( - f::F, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - return DI.gradient_and_hvp( - f, prep, DI.SecondOrder(backend, backend), x, tx, contexts... 
- ) -end - -function DI.gradient_and_hvp!( - f::F, - grad, - tg::NTuple, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - return DI.gradient_and_hvp!( - f, grad, tg, prep, DI.SecondOrder(backend, backend), x, tx, contexts... - ) -end -=# - ## Hessian ### Unprepared, only when chunk size and tag are not specified @@ -810,22 +786,24 @@ end ### Prepared -struct ForwardDiffHessianPrep{C1,C2,CD} <: DI.HessianPrep +struct ForwardDiffHessianPrep{SIG,C1,C2,CD} <: DI.HessianPrep{SIG} + _sig::Val{SIG} array_config::C1 result_config::C2 contexts_dual::CD end function DI.prepare_hessian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} + _sig = DI.signature(f, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f, backend, x) result = HessianResult(x) array_config = HessianConfig(nothing, x, chunk, tag) result_config = HessianConfig(nothing, result, x, chunk, tag) contexts_dual = translate_toprep(dual_type(array_config), contexts) - return ForwardDiffHessianPrep(array_config, result_config, contexts_dual) + return ForwardDiffHessianPrep(_sig, array_config, result_config, contexts_dual) end function DI.hessian!( @@ -836,6 +814,7 @@ function DI.hessian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -852,6 +831,7 @@ function DI.hessian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) 
CHK = tag_type(backend) === Nothing @@ -870,6 +850,7 @@ function DI.value_gradient_and_hessian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) result = DiffResult(one(eltype(x)), (grad, hess)) @@ -891,6 +872,7 @@ function DI.value_gradient_and_hessian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.FixTail(f, contexts_dual...) result = HessianResult(x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl deleted file mode 100644 index d01efd074..000000000 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl +++ /dev/null @@ -1,144 +0,0 @@ -struct ForwardDiffOverSomethingHVPPrep{E1<:DI.GradientPrep,E2<:DI.PushforwardPrep} <: - DI.HVPPrep - inner_gradient_prep::E1 - outer_pushforward_prep::E2 -end - -function DI.prepare_hvp( - f::F, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - T = tag_type(DI.shuffled_gradient, DI.outer(backend), x) - xdual = make_dual(T, x, tx) - inner_gradient_prep = DI.prepare_gradient(f, DI.inner(backend), xdual, contexts...) - rewrap = DI.Rewrap(contexts...) - new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - outer_pushforward_prep = DI.prepare_pushforward( - DI.shuffled_gradient, DI.outer(backend), x, tx, new_contexts... 
- ) - return ForwardDiffOverSomethingHVPPrep(inner_gradient_prep, outer_pushforward_prep) -end - -function DI.hvp( - f::F, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep - rewrap = DI.Rewrap(contexts...) - new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - return DI.pushforward( - DI.shuffled_gradient, - outer_pushforward_prep, - DI.outer(backend), - x, - tx, - new_contexts..., - ) -end - -function DI.hvp!( - f::F, - tg::NTuple, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep - rewrap = DI.Rewrap(contexts...) - new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - return DI.pushforward!( - DI.shuffled_gradient, - tg, - outer_pushforward_prep, - DI.outer(backend), - x, - tx, - new_contexts..., - ) - return tg -end - -function DI.gradient_and_hvp( - f::F, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep - rewrap = DI.Rewrap(contexts...) 
- new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - return DI.value_and_pushforward( - DI.shuffled_gradient, - outer_pushforward_prep, - DI.outer(backend), - x, - tx, - new_contexts..., - ) -end - -function DI.gradient_and_hvp!( - f::F, - grad, - tg::NTuple, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep - rewrap = DI.Rewrap(contexts...) - new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - new_grad, _ = DI.value_and_pushforward!( - DI.shuffled_gradient, - tg, - outer_pushforward_prep, - DI.outer(backend), - x, - tx, - new_contexts..., - ) - return copyto!(grad, new_grad), tg -end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index f52ae69d9..a919d8965 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -1,33 +1,38 @@ ## Pushforward -struct ForwardDiffTwoArgPushforwardPrep{T,X,Y,CD} <: DI.PushforwardPrep +struct ForwardDiffTwoArgPushforwardPrep{SIG,T,X,Y,CD} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} + _t::Type{T} xdual_tmp::X ydual_tmp::Y contexts_dual::CD end function DI.prepare_pushforward( - f!::F, y, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C} + strict::Val, + f!::F, + y, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context,C}; ) where {F,B,C} + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) T = tag_type(f!, backend, x) xdual_tmp = make_dual_similar(T, x, tx) 
ydual_tmp = make_dual_similar(T, y, tx) # tx only for batch size contexts_dual = translate_toprep(eltype(xdual_tmp), contexts) - return ForwardDiffTwoArgPushforwardPrep{ - T,typeof(xdual_tmp),typeof(ydual_tmp),typeof(contexts_dual) - }( - xdual_tmp, ydual_tmp, contexts_dual - ) + return ForwardDiffTwoArgPushforwardPrep(_sig, T, xdual_tmp, ydual_tmp, contexts_dual) end function compute_ydual_twoarg( f!::F, y, - prep::ForwardDiffTwoArgPushforwardPrep{T}, + prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, x::Number, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} +) where {F,SIG,T,B,C} (; ydual_tmp) = prep xdual_tmp = make_dual(T, x, tx) contexts_dual = translate_prepared(contexts, prep.contexts_dual) @@ -38,13 +43,18 @@ end function compute_ydual_twoarg( f!::F, y, - prep::ForwardDiffTwoArgPushforwardPrep{T}, + prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} - (; xdual_tmp, ydual_tmp) = prep - make_dual!(T, xdual_tmp, x, tx) +) where {F,SIG,T,B,C} + (; ydual_tmp) = prep + if DI.ismutable_array(x) + make_dual!(T, prep.xdual_tmp, x, tx) + xdual_tmp = prep.xdual_tmp + else + xdual_tmp = make_dual(T, x, tx) + end contexts_dual = translate_prepared(contexts, prep.contexts_dual) f!(ydual_tmp, xdual_tmp, contexts_dual...) return ydual_tmp @@ -53,12 +63,13 @@ end function DI.value_and_pushforward( f!::F, y, - prep::ForwardDiffTwoArgPushforwardPrep{T}, - ::AutoForwardDiff, + prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, + backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} +) where {F,SIG,T,B,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) 
myvalue!(T, y, ydual_tmp) ty = mypartials(T, Val(B), ydual_tmp) @@ -69,12 +80,13 @@ function DI.value_and_pushforward!( f!::F, y, ty::NTuple, - prep::ForwardDiffTwoArgPushforwardPrep{T}, - ::AutoForwardDiff, + prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, + backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {F,T,C} +) where {F,SIG,T,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) myvalue!(T, y, ydual_tmp) mypartials!(T, ty, ydual_tmp) @@ -84,12 +96,13 @@ end function DI.pushforward( f!::F, y, - prep::ForwardDiffTwoArgPushforwardPrep{T}, - ::AutoForwardDiff, + prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, + backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, -) where {F,T,B,C} +) where {F,SIG,T,B,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) ty = mypartials(T, Val(B), ydual_tmp) return ty @@ -99,12 +112,13 @@ function DI.pushforward!( f!::F, y, ty::NTuple, - prep::ForwardDiffTwoArgPushforwardPrep{T}, - ::AutoForwardDiff, + prep::ForwardDiffTwoArgPushforwardPrep{SIG,T}, + backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, -) where {F,T,C} +) where {F,SIG,T,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...) 
mypartials!(T, ty, ydual_tmp) return ty @@ -168,18 +182,20 @@ end ### Prepared -struct ForwardDiffTwoArgDerivativePrep{C,CD} <: DI.DerivativePrep +struct ForwardDiffTwoArgDerivativePrep{SIG,C,CD} <: DI.DerivativePrep{SIG} + _sig::Val{SIG} config::C contexts_dual::CD end function DI.prepare_derivative( - f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}; ) where {F,C} + _sig = DI.signature(f!, y, backend, x, contexts...; strict) tag = get_tag(f!, backend, x) config = DerivativeConfig(nothing, y, x, tag) contexts_dual = translate_toprep(dual_type(config), contexts) - return ForwardDiffTwoArgDerivativePrep(config, contexts_dual) + return ForwardDiffTwoArgDerivativePrep(_sig, config, contexts_dual) end function DI.prepare!_derivative( @@ -190,12 +206,13 @@ function DI.prepare!_derivative( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} + DI.check_prep(f!, y, old_prep, backend, x, contexts...) if y isa Vector (; config) = old_prep resize!(config.duals, length(y)) return old_prep else - return DI.prepare_derivative(f!, y, backend, x, contexts...) + return DI.prepare_derivative(DI.is_strict(old_prep), f!, y, backend, x, contexts...) end end @@ -207,6 +224,7 @@ function DI.value_and_derivative( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) result = MutableDiffResult(y, (similar(y),)) @@ -227,6 +245,7 @@ function DI.value_and_derivative!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) 
result = MutableDiffResult(y, (der,)) @@ -246,6 +265,7 @@ function DI.derivative( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -264,6 +284,7 @@ function DI.derivative!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -348,19 +369,21 @@ end ### Prepared -struct ForwardDiffTwoArgJacobianPrep{C,CD} <: DI.JacobianPrep +struct ForwardDiffTwoArgJacobianPrep{SIG,C,CD} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} config::C contexts_dual::CD end function DI.prepare_jacobian( - f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {F,C} + _sig = DI.signature(f!, y, backend, x, contexts...; strict) chunk = choose_chunk(backend, x) tag = get_tag(f!, backend, x) config = JacobianConfig(nothing, y, x, chunk, tag) contexts_dual = translate_toprep(dual_type(config), contexts) - return ForwardDiffTwoArgJacobianPrep(config, contexts_dual) + return ForwardDiffTwoArgJacobianPrep(_sig, config, contexts_dual) end function DI.prepare!_jacobian( @@ -371,6 +394,7 @@ function DI.prepare!_jacobian( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} + DI.check_prep(f!, y, old_prep, backend, x, contexts...) if x isa Vector && y isa Vector (; config) = old_prep (yduals, xduals) = config.duals @@ -378,7 +402,7 @@ function DI.prepare!_jacobian( resize!(xduals, length(x)) return old_prep else - return DI.prepare_jacobian(f!, y, backend, x, contexts...) + return DI.prepare_jacobian(DI.is_strict(old_prep), f!, y, backend, x, contexts...) 
end end @@ -390,6 +414,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) jac = similar(y, length(y), length(x)) @@ -411,6 +436,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) result = MutableDiffResult(y, (jac,)) @@ -430,6 +456,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) CHK = tag_type(backend) === Nothing @@ -448,6 +475,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc! = DI.FixTail(f!, contexts_dual...) 
CHK = tag_type(backend) === Nothing diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index c094bd14a..2dd4409cc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -19,7 +19,7 @@ choose_chunk(::AutoForwardDiff{chunksize}, x) where {chunksize} = Chunk{chunksiz get_tag(f, backend::AutoForwardDiff, x) = backend.tag -function get_tag(f::F, ::AutoForwardDiff{chunksize,Nothing}, x) where {F,chunksize} +function get_tag(f::F, backend::AutoForwardDiff{chunksize,Nothing}, x) where {F,chunksize} return Tag(f, eltype(x)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl index ca7937d63..83be97ed1 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl @@ -2,14 +2,20 @@ # Contains either a single pre-allocated initial TPS # or a vector of pre-allocated TPSs. -struct GTPSAOneArgPushforwardPrep{X} <: DI.PushforwardPrep +struct GTPSAOneArgPushforwardPrep{SIG,X} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} xt::X end function DI.prepare_pushforward( - ::F, backend::AutoGTPSA{D}, x, tx::NTuple, ::Vararg{DI.Constant,C} + strict::Val, + f::F, + backend::AutoGTPSA{D}, + x, + tx::NTuple, + contexts::Vararg{DI.Constant,C}; ) where {F,D,C} - + _sig = DI.signature(f, backend, x, tx, contexts...; strict) # For pushforward/JVP, we only actually need 1 single variable (in the GTPSA sense) # because we even if we did multiple we will add up the derivatives of each at the end. 
if D != Nothing @@ -19,24 +25,25 @@ function DI.prepare_pushforward( end if x isa Number xt = TPS{promote_type(typeof(first(tx)), typeof(x), Float64)}(; use=d) - return GTPSAOneArgPushforwardPrep(xt) + return GTPSAOneArgPushforwardPrep(_sig, xt) else xt = similar(x, TPS{promote_type(eltype(first(tx)), eltype(x), Float64)}) for i in eachindex(xt) xt[i] = TPS{promote_type(eltype(first(tx)), eltype(x), Float64)}(; use=d) end - return GTPSAOneArgPushforwardPrep(xt) + return GTPSAOneArgPushforwardPrep(_sig, xt) end end function DI.pushforward( f, prep::GTPSAOneArgPushforwardPrep, - ::AutoGTPSA, + backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) fc = DI.with_contexts(f, contexts...) ty = map(tx) do dx foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx) @@ -55,11 +62,12 @@ function DI.pushforward!( f, ty::NTuple, prep::GTPSAOneArgPushforwardPrep, - ::AutoGTPSA, + backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) fc = DI.with_contexts(f, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] @@ -78,9 +86,9 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} - fc = DI.with_contexts(f, contexts...) - ty = DI.pushforward(fc, prep, backend, x, tx) - y = fc(x) # TO-DO: optimize + DI.check_prep(f, prep, backend, x, tx, contexts...) + ty = DI.pushforward(f, prep, backend, x, tx, contexts...) + y = f(x, map(DI.unwrap, contexts)...) # TODO: optimize return y, ty end @@ -93,22 +101,24 @@ function DI.value_and_pushforward!( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} - fc = DI.with_contexts(f, contexts...) - DI.pushforward!(fc, ty, prep, backend, x, tx) - y = fc(x) # TO-DO: optimize + DI.check_prep(f, prep, backend, x, tx, contexts...) + DI.pushforward!(f, ty, prep, backend, x, tx, contexts...) + y = f(x, map(DI.unwrap, contexts)...) 
# TODO: optimize return y, ty end ## Gradient # Contains a vector of pre-allocated TPSs. -struct GTPSAOneArgGradientPrep{X} <: DI.GradientPrep +struct GTPSAOneArgGradientPrep{SIG,X} <: DI.GradientPrep{SIG} + _sig::Val{SIG} xt::X end # Unlike JVP, this requires us to use all variables function DI.prepare_gradient( - f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} + strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} + _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor else @@ -121,12 +131,13 @@ function DI.prepare_gradient( xt[i][j] = 1 j += 1 end - return GTPSAOneArgGradientPrep(xt) + return GTPSAOneArgGradientPrep(_sig, xt) end function DI.gradient( - f, prep::GTPSAOneArgGradientPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, prep::GTPSAOneArgGradientPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part (slopes set in prepare) fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -136,8 +147,14 @@ function DI.gradient( end function DI.gradient!( - f, grad, prep::GTPSAOneArgGradientPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, + grad, + prep::GTPSAOneArgGradientPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -146,8 +163,9 @@ function DI.gradient!( end function DI.value_and_gradient( - f, prep::GTPSAOneArgGradientPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, prep::GTPSAOneArgGradientPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part (slopes set in prepare) fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -157,8 +175,14 @@ function DI.value_and_gradient( end function DI.value_and_gradient!( - f, grad, prep::GTPSAOneArgGradientPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, + grad, + prep::GTPSAOneArgGradientPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part (slopes set in prepare) fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -168,14 +192,16 @@ end ## Jacobian # Contains a vector of pre-allocated TPSs -struct GTPSAOneArgJacobianPrep{X} <: DI.JacobianPrep +struct GTPSAOneArgJacobianPrep{SIG,X} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} xt::X end # To materialize the entire Jacobian we use all variables function DI.prepare_jacobian( - f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} + strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} + _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor else @@ -190,12 +216,13 @@ function DI.prepare_jacobian( xt[i][j] = 1 j += 1 end - return GTPSAOneArgJacobianPrep(xt) + return GTPSAOneArgJacobianPrep(_sig, xt) end function DI.jacobian( - f, prep::GTPSAOneArgJacobianPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, prep::GTPSAOneArgJacobianPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) 
yt = fc(prep.xt) @@ -205,8 +232,14 @@ function DI.jacobian( end function DI.jacobian!( - f, jac, prep::GTPSAOneArgJacobianPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, + jac, + prep::GTPSAOneArgJacobianPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -215,8 +248,9 @@ function DI.jacobian!( end function DI.value_and_jacobian( - f, prep::GTPSAOneArgJacobianPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, prep::GTPSAOneArgJacobianPrep, backend::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -227,8 +261,14 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f, jac, prep::GTPSAOneArgJacobianPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f, + jac, + prep::GTPSAOneArgJacobianPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) 
yt = fc(prep.xt) @@ -239,13 +279,15 @@ end ## Second derivative # Contains single pre-allocated TPS -struct GTPSAOneArgSecondDerivativePrep{X} <: DI.SecondDerivativePrep +struct GTPSAOneArgSecondDerivativePrep{SIG,X} <: DI.SecondDerivativePrep{SIG} + _sig::Val{SIG} xt::X end function DI.prepare_second_derivative( - f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} + strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} + _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor else @@ -253,7 +295,7 @@ function DI.prepare_second_derivative( end xt = TPS{promote_type(typeof(x), Float64)}(; use=d) xt[1] = 1 # Set slope - return GTPSAOneArgSecondDerivativePrep(xt) + return GTPSAOneArgSecondDerivativePrep(_sig, xt) end function DI.second_derivative( @@ -263,6 +305,7 @@ function DI.second_derivative( x, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) prep.xt[0] = x fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -291,6 +334,7 @@ function DI.second_derivative!( x, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) prep.xt[0] = x fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -312,6 +356,7 @@ function DI.value_derivative_and_second_derivative( x, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) prep.xt[0] = x fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -343,6 +388,7 @@ function DI.value_derivative_and_second_derivative!( x, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) prep.xt[0] = x fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -362,14 +408,16 @@ end ## Hessian # Stores allocated array of TPS and an array for the monomial coefficient # indexing in GTPSA.cycle! 
(which is used if a Descriptor is specified) -struct GTPSAOneArgHessianPrep{X,M} <: DI.HessianPrep +struct GTPSAOneArgHessianPrep{SIG,X,M} <: DI.HessianPrep{SIG} + _sig::Val{SIG} xt::X m::M end function DI.prepare_hessian( - f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} + strict::Val, f, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} + _sig = DI.signature(f, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor m = Vector{UInt8}(undef, length(x)) @@ -392,12 +440,17 @@ function DI.prepare_hessian( j += 1 end - return GTPSAOneArgHessianPrep(xt, m) + return GTPSAOneArgHessianPrep(_sig, xt, m) end function DI.hessian( - f, prep::GTPSAOneArgHessianPrep, ::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} + f, + prep::GTPSAOneArgHessianPrep, + backend::AutoGTPSA{D}, + x, + contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -418,10 +471,11 @@ function DI.hessian!( f, hess, prep::GTPSAOneArgHessianPrep, - ::AutoGTPSA{D}, + backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -438,8 +492,13 @@ function DI.hessian!( end function DI.value_gradient_and_hessian( - f, prep::GTPSAOneArgHessianPrep, ::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} + f, + prep::GTPSAOneArgHessianPrep, + backend::AutoGTPSA{D}, + x, + contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) 
yt = fc(prep.xt) @@ -463,10 +522,11 @@ function DI.value_gradient_and_hessian!( grad, hess, prep::GTPSAOneArgHessianPrep, - ::AutoGTPSA{D}, + backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc = DI.with_contexts(f, contexts...) yt = fc(prep.xt) @@ -483,18 +543,20 @@ function DI.value_gradient_and_hessian!( return yt[0], grad, hess end -struct GTPSAOneArgHVPPrep{E,H} <: DI.HVPPrep +struct GTPSAOneArgHVPPrep{SIG,E,H} <: DI.HVPPrep{SIG} + _sig::Val{SIG} hessprep::E hess::H end function DI.prepare_hvp( - f, backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant,C} + strict::Val, f, backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant,C} ) where {C} - hessprep = DI.prepare_hessian(f, backend, x) + _sig = DI.signature(f, backend, x, tx, contexts...; strict) + hessprep = DI.prepare_hessian(strict, f, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) hess = similar(x, typeof(fc(x)), (length(x), length(x))) - return GTPSAOneArgHVPPrep(hessprep, hess) + return GTPSAOneArgHVPPrep(_sig, hessprep, hess) end function DI.hvp( @@ -505,6 +567,7 @@ function DI.hvp( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) DI.hessian!(f, prep.hess, prep.hessprep, backend, x, contexts...) tg = map(tx) do dx dg = similar(x, eltype(prep.hess)) @@ -530,6 +593,7 @@ function DI.hvp!( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) DI.hessian!(f, prep.hess, prep.hessprep, backend, x, contexts...) for b in eachindex(tg) dx, dg = tx[b], tg[b] @@ -553,6 +617,7 @@ function DI.gradient_and_hvp( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) 
grad = similar(x, eltype(prep.hess)) DI.value_gradient_and_hessian!( f, grad, prep.hess, prep.hessprep, backend, x, contexts... @@ -582,6 +647,7 @@ function DI.gradient_and_hvp!( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {D,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) DI.value_gradient_and_hessian!( f, grad, prep.hess, prep.hessprep, backend, x, contexts... ) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl index 3e2e88de5..e77c4900d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl @@ -4,15 +4,22 @@ # or a vector of pre-allocated TPSs. # # Output: Contains a vector of pre-allocated TPSs -struct GTPSATwoArgPushforwardPrep{X,Y} <: DI.PushforwardPrep +struct GTPSATwoArgPushforwardPrep{SIG,X,Y} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} xt::X yt::Y end function DI.prepare_pushforward( - ::F, y, backend::AutoGTPSA{D}, x, tx::NTuple, ::Vararg{DI.Constant,C} + strict::Val, + f!::F, + y, + backend::AutoGTPSA{D}, + x, + tx::NTuple, + contexts::Vararg{DI.Constant,C}; ) where {F,D,C} - + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) # For pushforward/JVP, we only actually need 1 single variable (in the GTPSA sense) # because we even if we did multiple we will add up the derivatives of each at the end. if D != Nothing @@ -34,18 +41,19 @@ function DI.prepare_pushforward( for i in eachindex(yt) yt[i] = TPS{promote_type(eltype(y), Float64)}(; use=d) end - return GTPSATwoArgPushforwardPrep(xt, yt) + return GTPSATwoArgPushforwardPrep(_sig, xt, yt) end function DI.pushforward( f!, y, prep::GTPSATwoArgPushforwardPrep, - ::AutoGTPSA, + backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) fc! = DI.with_contexts(f!, contexts...) 
ty = map(tx) do dx foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx) @@ -62,11 +70,12 @@ function DI.pushforward!( y, ty::NTuple, prep::GTPSATwoArgPushforwardPrep, - ::AutoGTPSA, + backend::AutoGTPSA, x, tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) fc! = DI.with_contexts(f!, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] @@ -87,6 +96,7 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...) return y, ty end @@ -101,6 +111,7 @@ function DI.value_and_pushforward!( tx::NTuple, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) return y, ty end @@ -108,14 +119,16 @@ end ## Jacobian # Input: Contains a vector of pre-allocated TPSs # Output: Contains a vector of pre-allocated TPSs -struct GTPSATwoArgJacobianPrep{X,Y} <: DI.JacobianPrep +struct GTPSATwoArgJacobianPrep{SIG,X,Y} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} xt::X yt::Y end function DI.prepare_jacobian( - f, y, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} + strict::Val, f!, y, backend::AutoGTPSA{D}, x, contexts::Vararg{DI.Constant,C} ) where {D,C} + _sig = DI.signature(f!, y, backend, x, contexts...; strict) if D != Nothing d = backend.descriptor else @@ -137,12 +150,18 @@ function DI.prepare_jacobian( yt[i] = TPS{promote_type(eltype(y), Float64)}(; use=d) end - return GTPSATwoArgJacobianPrep(xt, yt) + return GTPSATwoArgJacobianPrep(_sig, xt, yt) end function DI.jacobian( - f!, y, prep::GTPSATwoArgJacobianPrep, ::AutoGTPSA, x, contexts::Vararg{DI.Constant,C} + f!, + y, + prep::GTPSATwoArgJacobianPrep, + backend::AutoGTPSA, + x, + contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) 
foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc! = DI.with_contexts(f!, contexts...) fc!(prep.yt, prep.xt) @@ -157,10 +176,11 @@ function DI.jacobian!( y, jac, prep::GTPSATwoArgJacobianPrep, - ::AutoGTPSA, + backend::AutoGTPSA, x, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part fc! = DI.with_contexts(f!, contexts...) fc!(prep.yt, prep.xt) @@ -177,6 +197,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) jac = DI.jacobian(f!, y, prep, backend, x, contexts...) # y set on line 151 return y, jac end @@ -190,6 +211,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Constant,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) DI.jacobian!(f!, y, jac, prep, backend, x, contexts...) return y, jac end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 81acd8673..e1e0c580b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -1,20 +1,22 @@ ## Pullback -struct MooncakeOneArgPullbackPrep{Tcache,DY} <: DI.PullbackPrep +struct MooncakeOneArgPullbackPrep{SIG,Tcache,DY} <: DI.PullbackPrep{SIG} + _sig::Val{SIG} cache::Tcache dy_righttype::DY end function DI.prepare_pullback( - f::F, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, f::F, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C}; ) where {F,C} + _sig = DI.signature(f, backend, x, ty, contexts...; strict) config = get_config(backend) cache = prepare_pullback_cache( f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages ) y = f(x, map(DI.unwrap, contexts)...) 
dy_righttype = zero_tangent(y) - prep = MooncakeOneArgPullbackPrep(cache, dy_righttype) + prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype) DI.value_and_pullback(f, prep, backend, x, ty, contexts...) return prep end @@ -22,11 +24,12 @@ end function DI.value_and_pullback( f::F, prep::MooncakeOneArgPullbackPrep{Y}, - ::AutoMooncake, + backend::AutoMooncake, x, ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,Y,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) dy = only(ty) dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy) new_y, (_, new_dx) = value_and_pullback!!( @@ -35,31 +38,21 @@ function DI.value_and_pullback( return new_y, (mycopy(new_dx),) end -function DI.value_and_pullback!( - f::F, - tx::NTuple{1}, - prep::MooncakeOneArgPullbackPrep{Y}, - backend::AutoMooncake, - x, - ty::NTuple{1}, - contexts::Vararg{DI.Context,C}, -) where {F,Y,C} - y, (new_dx,) = DI.value_and_pullback(f, prep, backend, x, ty, contexts...) - copyto!(only(tx), new_dx) - return y, tx -end - function DI.value_and_pullback( f::F, - prep::MooncakeOneArgPullbackPrep, + prep::MooncakeOneArgPullbackPrep{Y}, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C}, -) where {F,C} +) where {F,Y,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) ys_and_tx = map(ty) do dy - y, tx = DI.value_and_pullback(f, prep, backend, x, (dy,), contexts...) - y, only(tx) + dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy) + y, (_, new_dx) = value_and_pullback!!( + prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)... + ) + y, mycopy(new_dx) end y = first(ys_and_tx[1]) tx = last.(ys_and_tx) @@ -75,11 +68,9 @@ function DI.value_and_pullback!( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} - ys = map(tx, ty) do dx, dy - y, _ = DI.value_and_pullback!(f, (dx,), prep, backend, x, (dy,), contexts...) - y - end - y = ys[1] + DI.check_prep(f, prep, backend, x, ty, contexts...) 
+ y, new_tx = DI.value_and_pullback(f, prep, backend, x, ty, contexts...) + foreach(copyto!, tx, new_tx) return y, tx end @@ -91,6 +82,7 @@ function DI.pullback( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) return DI.value_and_pullback(f, prep, backend, x, ty, contexts...)[2] end @@ -103,30 +95,38 @@ function DI.pullback!( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, ty, contexts...) return DI.value_and_pullback!(f, tx, prep, backend, x, ty, contexts...)[2] end ## Gradient -struct MooncakeGradientPrep{Tcache} <: DI.GradientPrep +struct MooncakeGradientPrep{SIG,Tcache} <: DI.GradientPrep{SIG} + _sig::Val{SIG} cache::Tcache end function DI.prepare_gradient( - f::F, backend::AutoMooncake, x, contexts::Vararg{DI.Context,C} + strict::Val, f::F, backend::AutoMooncake, x, contexts::Vararg{DI.Context,C} ) where {F,C} + _sig = DI.signature(f, backend, x, contexts...; strict) config = get_config(backend) cache = prepare_pullback_cache( f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages ) - prep = MooncakeGradientPrep(cache) + prep = MooncakeGradientPrep(_sig, cache) DI.value_and_gradient(f, prep, backend, x, contexts...) return prep end function DI.value_and_gradient( - f::F, prep::MooncakeGradientPrep, ::AutoMooncake, x, contexts::Vararg{DI.Context,C} + f::F, + prep::MooncakeGradientPrep, + backend::AutoMooncake, + x, + contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...) return y, mycopy(new_grad) end @@ -135,10 +135,11 @@ function DI.value_and_gradient!( f::F, grad, prep::MooncakeGradientPrep, - ::AutoMooncake, + backend::AutoMooncake, x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) 
y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...) copyto!(grad, new_grad) return y, grad @@ -151,6 +152,7 @@ function DI.gradient( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) _, grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return grad end @@ -163,6 +165,7 @@ function DI.gradient!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) DI.value_and_gradient!(f, grad, prep, backend, x, contexts...) return grad end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index bd86b34f8..43bed9857 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -1,12 +1,20 @@ -struct MooncakeTwoArgPullbackPrep{Tcache,DY,F} <: DI.PullbackPrep +struct MooncakeTwoArgPullbackPrep{SIG,Tcache,DY,F} <: DI.PullbackPrep{SIG} + _sig::Val{SIG} cache::Tcache dy_righttype::DY target_function::F end function DI.prepare_pullback( - f!::F, y, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, + f!::F, + y, + backend::AutoMooncake, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}; ) where {F,C} + _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) target_function = function (f!, y, x, contexts...) f!(y, x, contexts...) return y @@ -22,26 +30,33 @@ function DI.prepare_pullback( silence_debug_messages=config.silence_debug_messages, ) dy_righttype_after = zero_tangent(y) - return MooncakeTwoArgPullbackPrep(cache, dy_righttype_after, target_function) + prep = MooncakeTwoArgPullbackPrep(_sig, cache, dy_righttype_after, target_function) + DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...) 
+ return prep end function DI.value_and_pullback( f!::F, y, prep::MooncakeTwoArgPullbackPrep, - ::AutoMooncake, + backend::AutoMooncake, x, ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} - # Prepare cotangent to add after the forward pass. + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) dy = only(ty) + # Prepare cotangent to add after the forward pass. dy_righttype_after = copyto!(prep.dy_righttype, dy) - # Run the reverse-pass and return the results. - contexts = map(DI.unwrap, contexts) y_after, (_, _, _, dx) = value_and_pullback!!( - prep.cache, dy_righttype_after, prep.target_function, f!, y, x, contexts... + prep.cache, + dy_righttype_after, + prep.target_function, + f!, + y, + x, + map(DI.unwrap, contexts)..., ) copyto!(y, y_after) return y, (mycopy(dx),) @@ -56,9 +71,20 @@ function DI.value_and_pullback( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) tx = map(ty) do dy - _, tx = DI.value_and_pullback(f!, y, prep, backend, x, (dy,), contexts...) - only(tx) + dy_righttype_after = copyto!(prep.dy_righttype, dy) + y_after, (_, _, _, dx) = value_and_pullback!!( + prep.cache, + dy_righttype_after, + prep.target_function, + f!, + y, + x, + map(DI.unwrap, contexts)..., + ) + copyto!(y, y_after) + mycopy(dx) end return y, tx end @@ -73,6 +99,7 @@ function DI.value_and_pullback!( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) _, new_tx = DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...) foreach(copyto!, tx, new_tx) return y, tx @@ -87,6 +114,7 @@ function DI.pullback( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) 
return DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...)[2] end @@ -100,5 +128,6 @@ function DI.pullback!( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) return DI.value_and_pullback!(f!, y, tx, prep, backend, x, ty, contexts...)[2] end diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index 17cf1ab6c..2e9380fd4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -1,127 +1,180 @@ ## Pushforward +struct PolyesterForwardDiffOneArgPushforwardPrep{SIG,P} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} + single_threaded_prep::P +end + function DI.prepare_pushforward( - f, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, + f, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; ) where {C} - return DI.prepare_pushforward(f, single_threaded(backend), x, tx, contexts...) + _sig = DI.signature(f, backend, x, tx, contexts...; strict) + single_threaded_prep = DI.prepare_pushforward( + strict, f, single_threaded(backend), x, tx, contexts... + ) + return PolyesterForwardDiffOneArgPushforwardPrep(_sig, single_threaded_prep) end function DI.value_and_pushforward( f, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffOneArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.value_and_pushforward(f, prep, single_threaded(backend), x, tx, contexts...) + DI.check_prep(f, prep, backend, x, tx, contexts...) + return DI.value_and_pushforward( + f, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... 
+ ) end function DI.value_and_pushforward!( f, ty::NTuple, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffOneArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.value_and_pushforward!( - f, ty, prep, single_threaded(backend), x, tx, contexts... + f, ty, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... ) end function DI.pushforward( f, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffOneArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.pushforward(f, prep, single_threaded(backend), x, tx, contexts...) + DI.check_prep(f, prep, backend, x, tx, contexts...) + return DI.pushforward( + f, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... + ) end function DI.pushforward!( f, ty::NTuple, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffOneArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.pushforward!(f, ty, prep, single_threaded(backend), x, tx, contexts...) + DI.check_prep(f, prep, backend, x, tx, contexts...) + return DI.pushforward!( + f, ty, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... + ) end ## Derivative +struct PolyesterForwardDiffOneArgDerivativePrep{SIG,P} <: DI.DerivativePrep{SIG} + _sig::Val{SIG} + single_threaded_prep::P +end + function DI.prepare_derivative( - f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} - return DI.prepare_derivative(f, single_threaded(backend), x, contexts...) + _sig = DI.signature(f, backend, x, contexts...; strict) + single_threaded_prep = DI.prepare_derivative( + strict, f, single_threaded(backend), x, contexts... 
+ ) + return PolyesterForwardDiffOneArgDerivativePrep(_sig, single_threaded_prep) end function DI.value_and_derivative( f, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffOneArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.value_and_derivative(f, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + return DI.value_and_derivative( + f, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.value_and_derivative!( f, der, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffOneArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.value_and_derivative!(f, der, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + return DI.value_and_derivative!( + f, der, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.derivative( f, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffOneArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.derivative(f, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + return DI.derivative( + f, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.derivative!( f, der, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffOneArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.derivative!(f, der, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + return DI.derivative!( + f, der, prep.single_threaded_prep, single_threaded(backend), x, contexts... 
+ ) end ## Gradient -struct PolyesterForwardDiffGradientPrep{chunksize,P} <: DI.GradientPrep +struct PolyesterForwardDiffGradientPrep{SIG,chunksize,P} <: DI.GradientPrep{SIG} + _sig::Val{SIG} chunk::Chunk{chunksize} single_threaded_prep::P end function DI.prepare_gradient( - f, backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C} + strict::Val, + f, + backend::AutoPolyesterForwardDiff{chunksize}, + x, + contexts::Vararg{DI.Context,C}; ) where {chunksize,C} + _sig = DI.signature(f, backend, x, contexts...; strict) if isnothing(chunksize) chunk = Chunk(x) else chunk = Chunk{chunksize}() end - single_threaded_prep = DI.prepare_gradient(f, single_threaded(backend), x, contexts...) - return PolyesterForwardDiffGradientPrep(chunk, single_threaded_prep) + single_threaded_prep = DI.prepare_gradient( + strict, f, single_threaded(backend), x, contexts... + ) + return PolyesterForwardDiffGradientPrep(_sig, chunk, single_threaded_prep) end function DI.value_and_gradient!( @@ -132,6 +185,7 @@ function DI.value_and_gradient!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} fc = DI.with_contexts(f, contexts...) threaded_gradient!(fc, grad, x, prep.chunk) @@ -152,6 +206,7 @@ function DI.gradient!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} fc = DI.with_contexts(f, contexts...) threaded_gradient!(fc, grad, x, prep.chunk) @@ -171,6 +226,7 @@ function DI.value_and_gradient( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return DI.value_and_gradient!(f, similar(x), prep, backend, x, contexts...) end @@ -181,26 +237,35 @@ function DI.gradient( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return DI.gradient!(f, similar(x), prep, backend, x, contexts...) 
end ## Jacobian -struct PolyesterForwardDiffOneArgJacobianPrep{chunksize,P} <: DI.JacobianPrep +struct PolyesterForwardDiffOneArgJacobianPrep{SIG,chunksize,P} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} chunk::Chunk{chunksize} single_threaded_prep::P end function DI.prepare_jacobian( - f, backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C} + strict::Val, + f, + backend::AutoPolyesterForwardDiff{chunksize}, + x, + contexts::Vararg{DI.Context,C}; ) where {chunksize,C} + _sig = DI.signature(f, backend, x, contexts...; strict) if isnothing(chunksize) chunk = Chunk(x) else chunk = Chunk{chunksize}() end - single_threaded_prep = DI.prepare_jacobian(f, single_threaded(backend), x, contexts...) - return PolyesterForwardDiffOneArgJacobianPrep(chunk, single_threaded_prep) + single_threaded_prep = DI.prepare_jacobian( + strict, f, single_threaded(backend), x, contexts... + ) + return PolyesterForwardDiffOneArgJacobianPrep(_sig, chunk, single_threaded_prep) end function DI.value_and_jacobian!( @@ -211,6 +276,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} fc = DI.with_contexts(f, contexts...) return fc(x), threaded_jacobian!(fc, jac, x, prep.chunk) @@ -229,6 +295,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} fc = DI.with_contexts(f, contexts...) return threaded_jacobian!(fc, jac, x, prep.chunk) @@ -246,6 +313,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y = f(x, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) return DI.value_and_jacobian!(f, jac, prep, backend, x, contexts...) 
@@ -258,6 +326,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y = f(x, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) return DI.jacobian!(f, jac, prep, backend, x, contexts...) @@ -265,148 +334,103 @@ end ## Hessian +struct PolyesterForwardDiffHessianPrep{SIG,P} <: DI.HessianPrep{SIG} + _sig::Val{SIG} + single_threaded_prep::P +end + function DI.prepare_hessian( - f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} - return DI.prepare_hessian(f, single_threaded(backend), x, contexts...) + _sig = DI.signature(f, backend, x, contexts...; strict) + single_threaded_prep = DI.prepare_hessian( + strict, f, single_threaded(backend), x, contexts... + ) + return PolyesterForwardDiffHessianPrep(_sig, single_threaded_prep) end function DI.hessian( f, - prep::DI.HessianPrep, + prep::PolyesterForwardDiffHessianPrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.hessian(f, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + return DI.hessian( + f, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.hessian!( f, hess, - prep::DI.HessianPrep, + prep::PolyesterForwardDiffHessianPrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.hessian!(f, hess, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + return DI.hessian!( + f, hess, prep.single_threaded_prep, single_threaded(backend), x, contexts... 
+ ) end function DI.value_gradient_and_hessian( f, - prep::DI.HessianPrep, + prep::PolyesterForwardDiffHessianPrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.value_gradient_and_hessian(f, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + return DI.value_gradient_and_hessian( + f, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.value_gradient_and_hessian!( f, grad, hess, - prep::DI.HessianPrep, + prep::PolyesterForwardDiffHessianPrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return DI.value_gradient_and_hessian!( - f, grad, hess, prep, single_threaded(backend), x, contexts... - ) -end - -## HVP - -#= -function DI.prepare_hvp( - f, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C} -) where {C} - return DI.prepare_hvp( - f, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... - ) -end - -function DI.hvp( - f, - prep::DI.ForwardOverAnythingHVPPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return DI.hvp( - f, prep, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... - ) -end - -function DI.hvp!( - f, - tg::NTuple, - prep::DI.ForwardOverAnythingHVPPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return DI.hvp!( - f, tg, prep, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... + f, grad, hess, prep.single_threaded_prep, single_threaded(backend), x, contexts... 
) end -function DI.gradient_and_hvp( - f, - prep::DI.ForwardOverAnythingHVPPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return DI.gradient_and_hvp( - f, prep, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... - ) -end +## Second derivative -function DI.gradient_and_hvp!( - f, - grad, - tg::NTuple, - prep::DI.ForwardOverAnythingHVPPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return DI.gradient_and_hvp!( - f, - grad, - tg, - prep, - DI.SecondOrder(single_threaded(backend), backend), - x, - tx, - contexts..., - ) +struct PolyesterForwardDiffOneArgSecondDerivativePrep{SIG,P} <: DI.SecondDerivativePrep{SIG} + _sig::Val{SIG} + single_threaded_prep::P end -=# - -## Second derivative function DI.prepare_second_derivative( - f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} - return DI.prepare_second_derivative(f, single_threaded(backend), x, contexts...) + _sig = DI.signature(f, backend, x, contexts...; strict) + single_threaded_prep = DI.prepare_second_derivative( + strict, f, single_threaded(backend), x, contexts... + ) + return PolyesterForwardDiffOneArgSecondDerivativePrep(_sig, single_threaded_prep) end function DI.value_derivative_and_second_derivative( f, - prep::DI.SecondDerivativePrep, + prep::PolyesterForwardDiffOneArgSecondDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return DI.value_derivative_and_second_derivative( - f, prep, single_threaded(backend), x, contexts... + f, prep.single_threaded_prep, single_threaded(backend), x, contexts... 
) end @@ -414,33 +438,40 @@ function DI.value_derivative_and_second_derivative!( f, der, der2, - prep::DI.SecondDerivativePrep, + prep::PolyesterForwardDiffOneArgSecondDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return DI.value_derivative_and_second_derivative!( - f, der, der2, prep, single_threaded(backend), x, contexts... + f, der, der2, prep.single_threaded_prep, single_threaded(backend), x, contexts... ) end function DI.second_derivative( f, - prep::DI.SecondDerivativePrep, + prep::PolyesterForwardDiffOneArgSecondDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.second_derivative(f, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + return DI.second_derivative( + f, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.second_derivative!( f, der2, - prep::DI.SecondDerivativePrep, + prep::PolyesterForwardDiffOneArgSecondDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.second_derivative!(f, der2, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + return DI.second_derivative!( + f, der2, prep.single_threaded_prep, single_threaded(backend), x, contexts... 
+ ) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl index 4f63a65c1..74d460cc3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl @@ -1,22 +1,38 @@ ## Pushforward +struct PolyesterForwardDiffTwoArgPushforwardPrep{SIG,P} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} + single_threaded_prep::P +end + function DI.prepare_pushforward( - f!, y, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, + f!, + y, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; ) where {C} - return DI.prepare_pushforward(f!, y, single_threaded(backend), x, tx, contexts...) + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) + single_threaded_prep = DI.prepare_pushforward( + f!, y, single_threaded(backend), x, tx, contexts... + ) + return PolyesterForwardDiffTwoArgPushforwardPrep(_sig, single_threaded_prep) end function DI.value_and_pushforward( f!, y, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffTwoArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) return DI.value_and_pushforward( - f!, y, prep, single_threaded(backend), x, tx, contexts... + f!, y, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... ) end @@ -24,117 +40,150 @@ function DI.value_and_pushforward!( f!, y, ty::NTuple, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffTwoArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) 
return DI.value_and_pushforward!( - f!, y, ty, prep, single_threaded(backend), x, tx, contexts... + f!, y, ty, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... ) end function DI.pushforward( f!, y, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffTwoArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.pushforward(f!, y, prep, single_threaded(backend), x, tx, contexts...) + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) + return DI.pushforward( + f!, y, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... + ) end function DI.pushforward!( f!, y, ty::NTuple, - prep::DI.PushforwardPrep, + prep::PolyesterForwardDiffTwoArgPushforwardPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.pushforward!(f!, y, ty, prep, single_threaded(backend), x, tx, contexts...) + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) + return DI.pushforward!( + f!, y, ty, prep.single_threaded_prep, single_threaded(backend), x, tx, contexts... + ) end ## Derivative +struct PolyesterForwardDiffTwoArgDerivativePrep{SIG,P} <: DI.DerivativePrep{SIG} + _sig::Val{SIG} + single_threaded_prep::P +end + function DI.prepare_derivative( - f!, y, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f!, y, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {C} - return DI.prepare_derivative(f!, y, single_threaded(backend), x, contexts...) + _sig = DI.signature(f!, y, backend, x, contexts...; strict) + single_threaded_prep = DI.prepare_derivative( + strict, f!, y, single_threaded(backend), x, contexts... 
+ ) + return PolyesterForwardDiffTwoArgDerivativePrep(_sig, single_threaded_prep) end function DI.value_and_derivative( f!, y, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffTwoArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.value_and_derivative(f!, y, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f!, y, prep, backend, x, contexts...) + return DI.value_and_derivative( + f!, y, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.value_and_derivative!( f!, y, der, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffTwoArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) return DI.value_and_derivative!( - f!, y, der, prep, single_threaded(backend), x, contexts... + f!, y, der, prep.single_threaded_prep, single_threaded(backend), x, contexts... ) end function DI.derivative( f!, y, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffTwoArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.derivative(f!, y, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f!, y, prep, backend, x, contexts...) + return DI.derivative( + f!, y, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) end function DI.derivative!( f!, y, der, - prep::DI.DerivativePrep, + prep::PolyesterForwardDiffTwoArgDerivativePrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.derivative!(f!, y, der, prep, single_threaded(backend), x, contexts...) + DI.check_prep(f!, y, prep, backend, x, contexts...) + return DI.derivative!( + f!, y, der, prep.single_threaded_prep, single_threaded(backend), x, contexts... 
+ ) end ## Jacobian -struct PolyesterForwardDiffTwoArgJacobianPrep{chunksize,P} <: DI.JacobianPrep +struct PolyesterForwardDiffTwoArgJacobianPrep{SIG,chunksize,P} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} chunk::Chunk{chunksize} single_threaded_prep::P end function DI.prepare_jacobian( - f!, y, backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C} + strict::Val, + f!, + y, + backend::AutoPolyesterForwardDiff{chunksize}, + x, + contexts::Vararg{DI.Context,C}; ) where {chunksize,C} + _sig = DI.signature(f!, y, backend, x, contexts...; strict) if isnothing(chunksize) chunk = Chunk(x) else chunk = Chunk{chunksize}() end single_threaded_prep = DI.prepare_jacobian( - f!, y, single_threaded(backend), x, contexts... + strict, f!, y, single_threaded(backend), x, contexts... ) - return PolyesterForwardDiffTwoArgJacobianPrep(chunk, single_threaded_prep) + return PolyesterForwardDiffTwoArgJacobianPrep(_sig, chunk, single_threaded_prep) end function DI.value_and_jacobian( @@ -145,6 +194,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {K,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} fc! = DI.with_contexts(f!, contexts...) jac = similar(y, length(y), length(x)) @@ -167,6 +217,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {K,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} fc! = DI.with_contexts(f!, contexts...) threaded_jacobian!(fc!, y, jac, x, prep.chunk) @@ -187,6 +238,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} fc! = DI.with_contexts(f!, contexts...) jac = similar(y, length(y), length(x)) @@ -208,6 +260,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) 
if contexts isa NTuple{C,DI.GeneralizedConstant} fc! = DI.with_contexts(f!, contexts...) threaded_jacobian!(fc!, y, jac, x, prep.chunk) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index fadca806a..13b673ac5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -1,19 +1,21 @@ ## Pullback function DI.prepare_pullback( - f, ::AutoReverseDiff, x, ty::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoReverseDiff, x, ty::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} - return DI.NoPullbackPrep() + _sig = DI.signature(f, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep(_sig) end function DI.value_and_pullback( f, - ::DI.NoPullbackPrep, - ::AutoReverseDiff, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, x::AbstractArray, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) fc = DI.with_contexts(f, contexts...) y = fc(x) dotclosure(z, dy) = dot(fc(z), dy) @@ -30,12 +32,13 @@ end function DI.value_and_pullback!( f, tx::NTuple, - ::DI.NoPullbackPrep, - ::AutoReverseDiff, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, x::AbstractArray, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) fc = DI.with_contexts(f, contexts...) y = fc(x) dotclosure(z, dy) = dot(fc(z), dy) @@ -53,12 +56,13 @@ end function DI.value_and_pullback( f, - ::DI.NoPullbackPrep, + prep::DI.NoPullbackPrep, backend::AutoReverseDiff, x::Number, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) x_array = [x] f_array(x_array, args...) = f(only(x_array), args...) 
y, tx_array = DI.value_and_pullback(f_array, backend, x_array, ty, contexts...) @@ -69,24 +73,29 @@ end ### Without contexts -@kwdef struct ReverseDiffGradientPrep{C,T} <: DI.GradientPrep +struct ReverseDiffGradientPrep{SIG,C,T} <: DI.GradientPrep{SIG} + _sig::Val{SIG} config::C tape::T end -function DI.prepare_gradient(f, ::AutoReverseDiff{compile}, x) where {compile} +function DI.prepare_gradient( + strict::Val, f, backend::AutoReverseDiff{compile}, x +) where {compile} + _sig = DI.signature(f, backend, x; strict) if compile tape = ReverseDiff.compile(GradientTape(f, x)) - return ReverseDiffGradientPrep(; config=nothing, tape=tape) + return ReverseDiffGradientPrep(_sig, nothing, tape) else config = GradientConfig(x) - return ReverseDiffGradientPrep(; config=config, tape=nothing) + return ReverseDiffGradientPrep(_sig, config, nothing) end end function DI.value_and_gradient!( - f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff{compile}, x + f, grad, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) result = MutableDiffResult(zero(eltype(x)), (grad,)) # ReverseDiff#251 if compile result = gradient!(result, prep.tape, x) @@ -99,6 +108,7 @@ end function DI.value_and_gradient( f, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) # GradientResult tries to mutate an SArray result = MutableDiffResult(zero(eltype(x)), (similar(x),)) if compile @@ -110,8 +120,9 @@ function DI.value_and_gradient( end function DI.gradient!( - f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff{compile}, x + f, grad, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) if compile return gradient!(grad, prep.tape, x) else @@ -120,8 +131,9 @@ function DI.gradient!( end function DI.gradient( - f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff{compile}, x + f, 
prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) if compile return gradient!(prep.tape, x) else @@ -132,20 +144,22 @@ end ### With contexts function DI.prepare_gradient( - f, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) config = GradientConfig(x) - return ReverseDiffGradientPrep(; config=config, tape=nothing) + return ReverseDiffGradientPrep(_sig, config, nothing) end function DI.value_and_gradient!( f, grad, prep::ReverseDiffGradientPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) result = MutableDiffResult(zero(eltype(x)), (grad,)) # ReverseDiff#251 result = gradient!(result, fc, x, prep.config) @@ -153,8 +167,13 @@ function DI.value_and_gradient!( end function DI.value_and_gradient( - f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} + f, + prep::ReverseDiffGradientPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) # GradientResult tries to mutate an SArray result = MutableDiffResult(zero(eltype(x)), (similar(x),)) @@ -166,17 +185,23 @@ function DI.gradient!( f, grad, prep::ReverseDiffGradientPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) 
return gradient!(grad, fc, x, prep.config) end function DI.gradient( - f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} + f, + prep::ReverseDiffGradientPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) return gradient(fc, x, prep.config) end @@ -185,24 +210,29 @@ end ### Without contexts -@kwdef struct ReverseDiffOneArgJacobianPrep{C,T} <: DI.JacobianPrep +struct ReverseDiffOneArgJacobianPrep{SIG,C,T} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} config::C tape::T end -function DI.prepare_jacobian(f, ::AutoReverseDiff{compile}, x) where {compile} +function DI.prepare_jacobian( + strict::Val, f, backend::AutoReverseDiff{compile}, x +) where {compile} + _sig = DI.signature(f, backend, x; strict) if compile tape = ReverseDiff.compile(JacobianTape(f, x)) - return ReverseDiffOneArgJacobianPrep(; config=nothing, tape=tape) + return ReverseDiffOneArgJacobianPrep(_sig, nothing, tape) else config = JacobianConfig(x) - return ReverseDiffOneArgJacobianPrep(; config=config, tape=nothing) + return ReverseDiffOneArgJacobianPrep(_sig, config, nothing) end end function DI.value_and_jacobian!( - f, jac, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff{compile}, x + f, jac, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) y = f(x) result = DiffResult(y, (jac,)) if compile @@ -216,8 +246,9 @@ function DI.value_and_jacobian!( end function DI.value_and_jacobian( - f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff{compile}, x + f, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) if compile return f(x), jacobian!(prep.tape, x) else @@ -226,8 +257,9 @@ function DI.value_and_jacobian( end function DI.jacobian!( - f, jac, prep::ReverseDiffOneArgJacobianPrep, 
::AutoReverseDiff{compile}, x + f, jac, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) if compile return jacobian!(jac, prep.tape, x) else @@ -236,8 +268,9 @@ function DI.jacobian!( end function DI.jacobian( - f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff{compile}, x + f, prep::ReverseDiffOneArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) if compile return jacobian!(prep.tape, x) else @@ -248,20 +281,22 @@ end ### With contexts function DI.prepare_jacobian( - f, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) config = JacobianConfig(x) - return ReverseDiffOneArgJacobianPrep(; config=config, tape=nothing) + return ReverseDiffOneArgJacobianPrep(_sig, config, nothing) end function DI.value_and_jacobian!( f, jac, prep::ReverseDiffOneArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) y = fc(x) result = DiffResult(y, (jac,)) @@ -274,10 +309,11 @@ end function DI.value_and_jacobian( f, prep::ReverseDiffOneArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) return fc(x), jacobian(fc, x, prep.config) end @@ -286,10 +322,11 @@ function DI.jacobian!( f, jac, prep::ReverseDiffOneArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) 
return jacobian!(jac, fc, x, prep.config) end @@ -297,10 +334,11 @@ end function DI.jacobian( f, prep::ReverseDiffOneArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) return jacobian(fc, x, prep.config) end @@ -309,30 +347,31 @@ end ### Without contexts -@kwdef struct ReverseDiffHessianPrep{G<:ReverseDiffGradientPrep,HC,HT} <: DI.HessianPrep +struct ReverseDiffHessianPrep{SIG,G<:ReverseDiffGradientPrep,HC,HT} <: DI.HessianPrep{SIG} + _sig::Val{SIG} gradient_prep::G hessian_config::HC hessian_tape::HT end -function DI.prepare_hessian(f, backend::AutoReverseDiff{compile}, x) where {compile} - gradient_prep = DI.prepare_gradient(f, backend, x) +function DI.prepare_hessian( + strict::Val, f, backend::AutoReverseDiff{compile}, x +) where {compile} + _sig = DI.signature(f, backend, x; strict) + gradient_prep = DI.prepare_gradient(strict, f, backend, x) if compile hessian_tape = ReverseDiff.compile(HessianTape(f, x)) - return ReverseDiffHessianPrep(; - gradient_prep, hessian_config=nothing, hessian_tape=hessian_tape - ) + return ReverseDiffHessianPrep(_sig, gradient_prep, nothing, hessian_tape) else hessian_config = HessianConfig(x) - return ReverseDiffHessianPrep(; - gradient_prep, hessian_config=hessian_config, hessian_tape=nothing - ) + return ReverseDiffHessianPrep(_sig, gradient_prep, hessian_config, nothing) end end function DI.hessian!( - f, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff{compile}, x + f, hess, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) if compile return hessian!(hess, prep.hessian_tape, x) else @@ -341,8 +380,9 @@ function DI.hessian!( end function DI.hessian( - f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff{compile}, x + f, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + 
DI.check_prep(f, prep, backend, x) if compile return hessian!(prep.hessian_tape, x) else @@ -353,6 +393,7 @@ end function DI.value_gradient_and_hessian!( f, grad, hess, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) y = f(x) DI.gradient!(f, grad, prep.gradient_prep, backend, x) DI.hessian!(f, hess, prep, backend, x) @@ -362,6 +403,7 @@ end function DI.value_gradient_and_hessian( f, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f, prep, backend, x) y = f(x) grad = DI.gradient(f, prep.gradient_prep, backend, x) hess = DI.hessian(f, prep, backend, x) @@ -371,30 +413,35 @@ end ### With contexts function DI.prepare_hessian( - f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} - gradient_prep = DI.prepare_gradient(f, backend, x, contexts...) + _sig = DI.signature(f, backend, x, contexts...; strict) + gradient_prep = DI.prepare_gradient(strict, f, backend, x, contexts...) hessian_config = HessianConfig(x) - return ReverseDiffHessianPrep(; - gradient_prep, hessian_config=hessian_config, hessian_tape=nothing - ) + return ReverseDiffHessianPrep(_sig, gradient_prep, hessian_config, nothing) end function DI.hessian!( f, hess, prep::ReverseDiffHessianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) return hessian!(hess, fc, x, prep.hessian_config) end function DI.hessian( - f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} + f, + prep::ReverseDiffHessianPrep, + backend::AutoReverseDiff, + x, + contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) 
return hessian(fc, x, prep.hessian_config) end @@ -408,6 +455,7 @@ function DI.value_gradient_and_hessian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y = f(x, map(DI.unwrap, contexts)...) DI.gradient!(f, grad, prep.gradient_prep, backend, x, contexts...) DI.hessian!(f, hess, prep, backend, x, contexts...) @@ -421,6 +469,7 @@ function DI.value_gradient_and_hessian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y = f(x, map(DI.unwrap, contexts)...) grad = DI.gradient(f, prep.gradient_prep, backend, x, contexts...) hess = DI.hessian(f, prep, backend, x, contexts...) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl index f76e7c824..531212fe5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl @@ -1,9 +1,16 @@ ## Pullback function DI.prepare_pullback( - f!, y, ::AutoReverseDiff, x, ty::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, + f!, + y, + backend::AutoReverseDiff, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}; ) where {C} - return DI.NoPullbackPrep() + _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep(_sig) end ### Array in @@ -11,12 +18,13 @@ end function DI.value_and_pullback( f!, y, - ::DI.NoPullbackPrep, - ::AutoReverseDiff, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, x::AbstractArray, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) fc! = DI.with_contexts(f!, contexts...) 
function dotclosure(x, dy) y_copy = similar(y, eltype(x)) @@ -34,12 +42,13 @@ function DI.value_and_pullback!( f!, y, tx::NTuple, - ::DI.NoPullbackPrep, - ::AutoReverseDiff, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, x::AbstractArray, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) fc! = DI.with_contexts(f!, contexts...) function dotclosure(x, dy) y_copy = similar(y, eltype(x)) @@ -57,12 +66,13 @@ end function DI.pullback( f!, y, - ::DI.NoPullbackPrep, - ::AutoReverseDiff, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, x::AbstractArray, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) fc! = DI.with_contexts(f!, contexts...) function dotclosure(x, dy) y_copy = similar(y, eltype(x)) @@ -79,12 +89,13 @@ function DI.pullback!( f!, y, tx::NTuple, - ::DI.NoPullbackPrep, - ::AutoReverseDiff, + prep::DI.NoPullbackPrep, + backend::AutoReverseDiff, x::AbstractArray, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) fc! = DI.with_contexts(f!, contexts...) function dotclosure(x, dy) y_copy = similar(y, eltype(x)) @@ -103,12 +114,13 @@ end function DI.value_and_pullback( f!, y, - ::DI.NoPullbackPrep, + prep::DI.NoPullbackPrep, backend::AutoReverseDiff, x::Number, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, ty, contexts...) x_array = [x] function f!_array(_y::AbstractArray, _x_array, args...) return f!(_y, only(_x_array), args...) 
@@ -121,24 +133,29 @@ end ### Without contexts -@kwdef struct ReverseDiffTwoArgJacobianPrep{C,T} <: DI.JacobianPrep +struct ReverseDiffTwoArgJacobianPrep{SIG,C,T} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} config::C tape::T end -function DI.prepare_jacobian(f!, y, ::AutoReverseDiff{compile}, x) where {compile} +function DI.prepare_jacobian( + strict::Val, f!, y, backend::AutoReverseDiff{compile}, x +) where {compile} + _sig = DI.signature(f!, y, backend, x; strict) if compile tape = ReverseDiff.compile(JacobianTape(f!, y, x)) - return ReverseDiffTwoArgJacobianPrep(; config=nothing, tape=tape) + return ReverseDiffTwoArgJacobianPrep(_sig, nothing, tape) else config = JacobianConfig(y, x) - return ReverseDiffTwoArgJacobianPrep(; config=config, tape=nothing) + return ReverseDiffTwoArgJacobianPrep(_sig, config, nothing) end end function DI.value_and_jacobian( - f!, y, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff{compile}, x + f!, y, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f!, y, prep, backend, x) jac = similar(y, length(y), length(x)) result = MutableDiffResult(y, (jac,)) if compile @@ -150,8 +167,9 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f!, y, jac, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff{compile}, x + f!, y, jac, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f!, y, prep, backend, x) result = MutableDiffResult(y, (jac,)) if compile result = jacobian!(result, prep.tape, x) @@ -162,8 +180,9 @@ function DI.value_and_jacobian!( end function DI.jacobian( - f!, y, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff{compile}, x + f!, y, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f!, y, prep, backend, x) if compile jac = jacobian!(prep.tape, x) else @@ -173,8 +192,9 @@ function DI.jacobian( end function DI.jacobian!( - f!, y, 
jac, prep::ReverseDiffTwoArgJacobianPrep, ::AutoReverseDiff{compile}, x + f!, y, jac, prep::ReverseDiffTwoArgJacobianPrep, backend::AutoReverseDiff{compile}, x ) where {compile} + DI.check_prep(f!, y, prep, backend, x) if compile jac = jacobian!(jac, prep.tape, x) else @@ -186,20 +206,22 @@ end ### With contexts function DI.prepare_jacobian( - f!, y, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} + strict::Val, f!, y, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C} ) where {C} + _sig = DI.signature(f!, y, backend, x, contexts...; strict) config = JacobianConfig(y, x) - return ReverseDiffTwoArgJacobianPrep(; config=config, tape=nothing) + return ReverseDiffTwoArgJacobianPrep(_sig, config, nothing) end function DI.value_and_jacobian( f!, y, prep::ReverseDiffTwoArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) fc! = DI.with_contexts(f!, contexts...) jac = similar(y, length(y), length(x)) result = MutableDiffResult(y, (jac,)) @@ -212,10 +234,11 @@ function DI.value_and_jacobian!( y, jac, prep::ReverseDiffTwoArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (jac,)) result = jacobian!(result, fc!, y, x, prep.config) @@ -226,10 +249,11 @@ function DI.jacobian( f!, y, prep::ReverseDiffTwoArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) fc! = DI.with_contexts(f!, contexts...) 
jac = jacobian(fc!, y, x, prep.config) return jac @@ -240,10 +264,11 @@ function DI.jacobian!( y, jac, prep::ReverseDiffTwoArgJacobianPrep, - ::AutoReverseDiff, + backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) fc! = DI.with_contexts(f!, contexts...) jac = jacobian!(jac, fc!, y, x, prep.config) return jac diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl index 1bc99a3e7..203992008 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl @@ -28,7 +28,7 @@ function ADTypes.jacobian_sparsity(f, x, detector::DI.DenseSparsityDetector{:ite if DI.pushforward_performance(backend) isa DI.PushforwardFast p = similar(y) prep = DI.prepare_pushforward_same_point( - f, backend, x, (DI.basis(x, first(eachindex(x))),) + f, backend, x, (DI.basis(x, first(eachindex(x))),); strict=Val(true) ) for (kj, j) in enumerate(eachindex(x)) pushforward!(f, (p,), prep, backend, x, (DI.basis(x, j),)) @@ -42,7 +42,7 @@ function ADTypes.jacobian_sparsity(f, x, detector::DI.DenseSparsityDetector{:ite else p = similar(x) prep = DI.prepare_pullback_same_point( - f, backend, x, (DI.basis(y, first(eachindex(y))),) + f, backend, x, (DI.basis(y, first(eachindex(y))),); strict=Val(true) ) for (ki, i) in enumerate(eachindex(y)) pullback!(f, (p,), prep, backend, x, (DI.basis(y, i),)) @@ -64,7 +64,7 @@ function ADTypes.jacobian_sparsity(f!, y, x, detector::DI.DenseSparsityDetector{ if DI.pushforward_performance(backend) isa DI.PushforwardFast p = similar(y) prep = DI.prepare_pushforward_same_point( - f!, y, backend, x, (DI.basis(x, first(eachindex(x))),) + f!, y, backend, x, (DI.basis(x, first(eachindex(x))),); strict=Val(true) ) for (kj, j) in 
enumerate(eachindex(x)) pushforward!(f!, y, (p,), prep, backend, x, (DI.basis(x, j),)) @@ -78,7 +78,7 @@ function ADTypes.jacobian_sparsity(f!, y, x, detector::DI.DenseSparsityDetector{ else p = similar(x) prep = DI.prepare_pullback_same_point( - f!, y, backend, x, (DI.basis(y, first(eachindex(y))),) + f!, y, backend, x, (DI.basis(y, first(eachindex(y))),); strict=Val(true) ) for (ki, i) in enumerate(eachindex(y)) pullback!(f!, y, (p,), prep, backend, x, (DI.basis(y, i),)) @@ -98,7 +98,9 @@ function ADTypes.hessian_sparsity(f, x, detector::DI.DenseSparsityDetector{:iter n = length(x) I, J = Int[], Int[] p = similar(x) - prep = DI.prepare_hvp_same_point(f, backend, x, (DI.basis(x, first(eachindex(x))),)) + prep = DI.prepare_hvp_same_point( + f, backend, x, (DI.basis(x, first(eachindex(x))),); strict=Val(true) + ) for (kj, j) in enumerate(eachindex(x)) hvp!(f, (p,), prep, backend, x, (DI.basis(x, j),)) for ki in LinearIndices(p) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl index ea8ef94b9..2de265807 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl @@ -15,7 +15,7 @@ using SparseMatrixColorings: decompress! 
import SparseMatrixColorings as SMC -abstract type SparseJacobianPrep <: DI.JacobianPrep end +abstract type SparseJacobianPrep{SIG} <: DI.JacobianPrep{SIG} end SMC.sparsity_pattern(prep::SparseJacobianPrep) = sparsity_pattern(prep.coloring_result) SMC.column_colors(prep::SparseJacobianPrep) = column_colors(prep.coloring_result) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index 61131dc71..1084583f2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -1,4 +1,5 @@ struct SparseHessianPrep{ + SIG, BS<:DI.BatchSizeSettings, C<:AbstractColoringResult{:symmetric,:column}, M<:AbstractMatrix{<:Number}, @@ -6,7 +7,8 @@ struct SparseHessianPrep{ R<:AbstractVector{<:NTuple}, E2<:DI.HVPPrep, E1<:DI.GradientPrep, -} <: DI.HessianPrep +} <: DI.HessianPrep{SIG} + _sig::Val{SIG} batch_size_settings::BS coloring_result::C compressed_matrix::M @@ -24,7 +26,7 @@ SMC.ncolors(prep::SparseHessianPrep) = ncolors(prep.coloring_result) ## Hessian, one argument function DI.prepare_hessian( - f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} + strict::Val, f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} dense_backend = dense_ad(backend) sparsity = DI.hessian_sparsity_with_contexts( @@ -37,18 +39,20 @@ function DI.prepare_hessian( N = length(column_groups(coloring_result)) batch_size_settings = DI.pick_batchsize(DI.outer(dense_backend), N) return _prepare_sparse_hessian_aux( - batch_size_settings, coloring_result, f, backend, x, contexts... + strict, batch_size_settings, coloring_result, f, backend, x, contexts... 
) end function _prepare_sparse_hessian_aux( + strict::Val, batch_size_settings::DI.BatchSizeSettings{B}, coloring_result::AbstractColoringResult{:symmetric,:column}, f::F, backend::AutoSparse, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; ) where {B,F,C} + _sig = DI.signature(f, backend, x, contexts...; strict) (; N, A) = batch_size_settings dense_backend = dense_ad(backend) groups = column_groups(coloring_result) @@ -58,9 +62,10 @@ function _prepare_sparse_hessian_aux( ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - hvp_prep = DI.prepare_hvp(f, dense_backend, x, batched_seeds[1], contexts...) - gradient_prep = DI.prepare_gradient(f, DI.inner(dense_backend), x, contexts...) + hvp_prep = DI.prepare_hvp(strict, f, dense_backend, x, batched_seeds[1], contexts...) + gradient_prep = DI.prepare_gradient(strict, f, DI.inner(dense_backend), x, contexts...) return SparseHessianPrep( + _sig, batch_size_settings, coloring_result, compressed_matrix, @@ -74,11 +79,12 @@ end function DI.hessian!( f::F, hess, - prep::SparseHessianPrep{<:DI.BatchSizeSettings{B}}, + prep::SparseHessianPrep{SIG,<:DI.BatchSizeSettings{B}}, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}, -) where {F,B,C} +) where {F,SIG,B,C} + DI.check_prep(f, prep, backend, x, contexts...) (; batch_size_settings, coloring_result, @@ -118,8 +124,9 @@ function DI.hessian!( end function DI.hessian( - f::F, prep::SparseHessianPrep{B}, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} -) where {F,B,C} + f::F, prep::SparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} +) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) hess = similar(sparsity_pattern(prep), eltype(x)) return DI.hessian!(f, hess, prep, backend, x, contexts...) 
end @@ -133,6 +140,7 @@ function DI.value_gradient_and_hessian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_gradient!( f, grad, prep.gradient_prep, DI.inner(dense_ad(backend)), x, contexts... ) @@ -143,6 +151,7 @@ end function DI.value_gradient_and_hessian( f::F, prep::SparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) y, grad = DI.value_and_gradient( f, prep.gradient_prep, DI.inner(dense_ad(backend)), x, contexts... ) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 4193fd7a0..85eac3b0f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -1,13 +1,15 @@ ## Preparation struct PushforwardSparseJacobianPrep{ + SIG, BS<:DI.BatchSizeSettings, C<:AbstractColoringResult{:nonsymmetric,:column}, M<:AbstractMatrix{<:Number}, S<:AbstractVector{<:NTuple}, R<:AbstractVector{<:NTuple}, E<:DI.PushforwardPrep, -} <: SparseJacobianPrep +} <: SparseJacobianPrep{SIG} + _sig::Val{SIG} batch_size_settings::BS coloring_result::C compressed_matrix::M @@ -17,13 +19,15 @@ struct PushforwardSparseJacobianPrep{ end struct PullbackSparseJacobianPrep{ + SIG, BS<:DI.BatchSizeSettings, C<:AbstractColoringResult{:nonsymmetric,:row}, M<:AbstractMatrix{<:Number}, S<:AbstractVector{<:NTuple}, R<:AbstractVector{<:NTuple}, E<:DI.PullbackPrep, -} <: SparseJacobianPrep +} <: SparseJacobianPrep{SIG} + _sig::Val{SIG} batch_size_settings::BS coloring_result::C compressed_matrix::M @@ -33,29 +37,30 @@ struct PullbackSparseJacobianPrep{ end function DI.prepare_jacobian( - f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} + strict::Val, f::F, 
backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} dense_backend = dense_ad(backend) y = f(x, map(DI.unwrap, contexts)...) perf = DI.pushforward_performance(dense_backend) - return _prepare_sparse_jacobian_aux(perf, y, (f,), backend, x, contexts...) + return _prepare_sparse_jacobian_aux(strict, perf, y, (f,), backend, x, contexts...) end function DI.prepare_jacobian( - f!::F, y, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} + strict::Val, f!::F, y, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} dense_backend = dense_ad(backend) perf = DI.pushforward_performance(dense_backend) - return _prepare_sparse_jacobian_aux(perf, y, (f!, y), backend, x, contexts...) + return _prepare_sparse_jacobian_aux(strict, perf, y, (f!, y), backend, x, contexts...) end function _prepare_sparse_jacobian_aux( + strict::Val, perf::DI.PushforwardPerformance, y, f_or_f!y::FY, backend::AutoSparse, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; ) where {FY,C} dense_backend = dense_ad(backend) sparsity = DI.jacobian_sparsity_with_contexts( @@ -79,19 +84,21 @@ function _prepare_sparse_jacobian_aux( end batch_size_settings = DI.pick_batchsize(dense_backend, N) return _prepare_sparse_jacobian_aux_aux( - batch_size_settings, coloring_result, y, f_or_f!y, backend, x, contexts... + strict, batch_size_settings, coloring_result, y, f_or_f!y, backend, x, contexts... 
) end function _prepare_sparse_jacobian_aux_aux( + strict::Val, batch_size_settings::DI.BatchSizeSettings{B}, coloring_result::AbstractColoringResult{:nonsymmetric,:column}, y, f_or_f!y::FY, backend::AutoSparse, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; ) where {B,FY,C} + _sig = DI.signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings dense_backend = dense_ad(backend) groups = column_groups(coloring_result) @@ -102,9 +109,10 @@ function _prepare_sparse_jacobian_aux_aux( ] batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds] pushforward_prep = DI.prepare_pushforward( - f_or_f!y..., dense_backend, x, batched_seeds[1], contexts... + strict, f_or_f!y..., dense_backend, x, batched_seeds[1], contexts... ) return PushforwardSparseJacobianPrep( + _sig, batch_size_settings, coloring_result, compressed_matrix, @@ -115,14 +123,16 @@ function _prepare_sparse_jacobian_aux_aux( end function _prepare_sparse_jacobian_aux_aux( + strict::Val, batch_size_settings::DI.BatchSizeSettings{B}, coloring_result::AbstractColoringResult{:nonsymmetric,:row}, y, f_or_f!y::FY, backend::AutoSparse, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; ) where {B,FY,C} + _sig = DI.signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings dense_backend = dense_ad(backend) groups = row_groups(coloring_result) @@ -133,9 +143,10 @@ function _prepare_sparse_jacobian_aux_aux( ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] pullback_prep = DI.prepare_pullback( - f_or_f!y..., dense_backend, x, batched_seeds[1], contexts... + strict, f_or_f!y..., dense_backend, x, batched_seeds[1], contexts... ) return PullbackSparseJacobianPrep( + _sig, batch_size_settings, coloring_result, compressed_matrix, @@ -155,12 +166,14 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) 
return _sparse_jacobian_aux!((f,), jac, prep, backend, x, contexts...) end function DI.jacobian( f::F, prep::SparseJacobianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) jac = similar(sparsity_pattern(prep), eltype(x)) return DI.jacobian!(f, jac, prep, backend, x, contexts...) end @@ -168,6 +181,7 @@ end function DI.value_and_jacobian( f::F, prep::SparseJacobianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...) end @@ -179,6 +193,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian!(f, jac, prep, backend, x, contexts...) end @@ -194,6 +209,7 @@ function DI.jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) return _sparse_jacobian_aux!((f!, y), jac, prep, backend, x, contexts...) end @@ -205,6 +221,7 @@ function DI.jacobian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) jac = similar(sparsity_pattern(prep), promote_type(eltype(x), eltype(y))) return DI.jacobian!(f!, y, jac, prep, backend, x, contexts...) end @@ -217,6 +234,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) jac = DI.jacobian(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, jac @@ -231,6 +249,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} + DI.check_prep(f!, y, prep, backend, x, contexts...) DI.jacobian!(f!, y, jac, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) 
return y, jac @@ -241,11 +260,11 @@ end function _sparse_jacobian_aux!( f_or_f!y::FY, jac, - prep::PushforwardSparseJacobianPrep{<:DI.BatchSizeSettings{B}}, + prep::PushforwardSparseJacobianPrep{SIG,<:DI.BatchSizeSettings{B}}, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}, -) where {FY,B,C} +) where {FY,SIG,B,C} (; batch_size_settings, coloring_result, @@ -287,11 +306,11 @@ end function _sparse_jacobian_aux!( f_or_f!y::FY, jac, - prep::PullbackSparseJacobianPrep{<:DI.BatchSizeSettings{B}}, + prep::PullbackSparseJacobianPrep{SIG,<:DI.BatchSizeSettings{B}}, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}, -) where {FY,B,C} +) where {FY,SIG,B,C} (; batch_size_settings, coloring_result, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl index 4ab180aad..c5d0d2847 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl @@ -1,6 +1,7 @@ ## Preparation struct MixedModeSparseJacobianPrep{ + SIG, BSf<:DI.BatchSizeSettings, BSr<:DI.BatchSizeSettings, C<:AbstractColoringResult{:nonsymmetric,:bidirectional}, @@ -11,7 +12,8 @@ struct MixedModeSparseJacobianPrep{ Rr<:Vector{<:NTuple}, Ef<:DI.PushforwardPrep, Er<:DI.PullbackPrep, -} <: SparseJacobianPrep +} <: SparseJacobianPrep{SIG} + _sig::Val{SIG} batch_size_settings_forward::BSf batch_size_settings_reverse::BSr coloring_result::C @@ -26,20 +28,34 @@ struct MixedModeSparseJacobianPrep{ end function DI.prepare_jacobian( - f::F, backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C} + strict::Val, + f::F, + backend::AutoSparse{<:DI.MixedMode}, + x, + contexts::Vararg{DI.Context,C}; ) where {F,C} y = f(x, map(DI.unwrap, contexts)...) 
- return _prepare_mixed_sparse_jacobian_aux(y, (f,), backend, x, contexts...) + return _prepare_mixed_sparse_jacobian_aux(strict, y, (f,), backend, x, contexts...) end function DI.prepare_jacobian( - f!::F, y, backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C} + strict::Val, + f!::F, + y, + backend::AutoSparse{<:DI.MixedMode}, + x, + contexts::Vararg{DI.Context,C}; ) where {F,C} - return _prepare_mixed_sparse_jacobian_aux(y, (f!, y), backend, x, contexts...) + return _prepare_mixed_sparse_jacobian_aux(strict, y, (f!, y), backend, x, contexts...) end function _prepare_mixed_sparse_jacobian_aux( - y, f_or_f!y::FY, backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C} + strict::Val, + y, + f_or_f!y::FY, + backend::AutoSparse{<:DI.MixedMode}, + x, + contexts::Vararg{DI.Context,C}; ) where {FY,C} dense_backend = dense_ad(backend) sparsity = DI.jacobian_sparsity_with_contexts( @@ -59,6 +75,7 @@ function _prepare_mixed_sparse_jacobian_aux( batch_size_settings_reverse = DI.pick_batchsize(DI.reverse_backend(dense_backend), Nr) return _prepare_mixed_sparse_jacobian_aux_aux( + strict, batch_size_settings_forward, batch_size_settings_reverse, coloring_result, @@ -66,11 +83,12 @@ function _prepare_mixed_sparse_jacobian_aux( f_or_f!y, backend, x, - contexts..., + contexts...; ) end function _prepare_mixed_sparse_jacobian_aux_aux( + strict::Val, batch_size_settings_forward::DI.BatchSizeSettings{Bf}, batch_size_settings_reverse::DI.BatchSizeSettings{Br}, coloring_result::AbstractColoringResult{:nonsymmetric,:bidirectional}, @@ -78,8 +96,9 @@ function _prepare_mixed_sparse_jacobian_aux_aux( f_or_f!y::FY, backend::AutoSparse{<:DI.MixedMode}, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; ) where {Bf,Br,FY,C} + _sig = DI.signature(f_or_f!y..., backend, x, contexts...; strict) Nf, Af = batch_size_settings_forward.N, batch_size_settings_forward.A Nr, Ar = batch_size_settings_reverse.N, batch_size_settings_reverse.A @@ -109,21 
+128,24 @@ function _prepare_mixed_sparse_jacobian_aux_aux( ] pushforward_prep = DI.prepare_pushforward( + strict, f_or_f!y..., DI.forward_backend(dense_backend), x, batched_seeds_forward[1], - contexts..., + contexts...; ) pullback_prep = DI.prepare_pullback( + strict, f_or_f!y..., DI.reverse_backend(dense_backend), x, batched_seeds_reverse[1], - contexts..., + contexts...; ) return MixedModeSparseJacobianPrep( + _sig, batch_size_settings_forward, batch_size_settings_reverse, coloring_result, @@ -144,12 +166,12 @@ function _sparse_jacobian_aux!( f_or_f!y::FY, jac, prep::MixedModeSparseJacobianPrep{ - <:DI.BatchSizeSettings{Bf},<:DI.BatchSizeSettings{Br} + SIG,<:DI.BatchSizeSettings{Bf},<:DI.BatchSizeSettings{Br} }, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}, -) where {FY,Bf,Br,C} +) where {FY,SIG,Bf,Br,C} (; batch_size_settings_forward, batch_size_settings_reverse, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 9f0102a8d..105d8a6a1 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -1,13 +1,15 @@ ## Pushforward -struct SymbolicsOneArgPushforwardPrep{E1,E1!} <: DI.PushforwardPrep +struct SymbolicsOneArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} pf_exe::E1 pf_exe!::E1! end function DI.prepare_pushforward( - f, ::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f, backend, x, tx, contexts...; strict) dx = first(tx) x_var = variablize(x, :x) dx_var = variablize(dx, :dx) @@ -22,17 +24,18 @@ function DI.prepare_pushforward( elseif res isa RuntimeGeneratedFunction res, nothing end - return SymbolicsOneArgPushforwardPrep(pf_exe, pf_exe!) 
+ return SymbolicsOneArgPushforwardPrep(_sig, pf_exe, pf_exe!) end function DI.pushforward( f, prep::SymbolicsOneArgPushforwardPrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) ty = map(tx) do dx dy = prep.pf_exe(x, dx, map(DI.unwrap, contexts)...) end @@ -43,11 +46,12 @@ function DI.pushforward!( f, ty::NTuple, prep::SymbolicsOneArgPushforwardPrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] prep.pf_exe!(dy, x, dx, map(DI.unwrap, contexts)...) @@ -63,6 +67,7 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pushforward(f, prep, backend, x, tx, contexts...) end @@ -76,20 +81,23 @@ function DI.value_and_pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.pushforward!(f, ty, prep, backend, x, tx, contexts...) end ## Derivative -struct SymbolicsOneArgDerivativePrep{E1,E1!} <: DI.DerivativePrep +struct SymbolicsOneArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} + _sig::Val{SIG} der_exe::E1 der_exe!::E1! end function DI.prepare_derivative( - f, ::AutoSymbolics, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) der_var = derivative(f(x_var, context_vars...), x_var) @@ -100,16 +108,17 @@ function DI.prepare_derivative( elseif res isa RuntimeGeneratedFunction res, nothing end - return SymbolicsOneArgDerivativePrep(der_exe, der_exe!) 
+ return SymbolicsOneArgDerivativePrep(_sig, der_exe, der_exe!) end function DI.derivative( f, prep::SymbolicsOneArgDerivativePrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return prep.der_exe(x, map(DI.unwrap, contexts)...) end @@ -117,10 +126,11 @@ function DI.derivative!( f, der, prep::SymbolicsOneArgDerivativePrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.der_exe!(der, x, map(DI.unwrap, contexts)...) return der end @@ -132,6 +142,7 @@ function DI.value_and_derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.derivative(f, prep, backend, x, contexts...) end @@ -144,20 +155,23 @@ function DI.value_and_derivative!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.derivative!(f, der, prep, backend, x, contexts...) end ## Gradient -struct SymbolicsOneArgGradientPrep{E1,E1!} <: DI.GradientPrep +struct SymbolicsOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG} + _sig::Val{SIG} grad_exe::E1 grad_exe!::E1! end function DI.prepare_gradient( - f, ::AutoSymbolics, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) # Symbolic.gradient only accepts vectors @@ -165,12 +179,17 @@ function DI.prepare_gradient( res = build_function(grad_var, vec(x_var), context_vars...; expression=Val(false)) (grad_exe, grad_exe!) = res - return SymbolicsOneArgGradientPrep(grad_exe, grad_exe!) + return SymbolicsOneArgGradientPrep(_sig, grad_exe, grad_exe!) 
end function DI.gradient( - f, prep::SymbolicsOneArgGradientPrep, ::AutoSymbolics, x, contexts::Vararg{DI.Context,C} + f, + prep::SymbolicsOneArgGradientPrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return reshape(prep.grad_exe(vec(x), map(DI.unwrap, contexts)...), size(x)) end @@ -178,10 +197,11 @@ function DI.gradient!( f, grad, prep::SymbolicsOneArgGradientPrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.grad_exe!(vec(grad), vec(x), map(DI.unwrap, contexts)...) return grad end @@ -193,6 +213,7 @@ function DI.value_and_gradient( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.gradient(f, prep, backend, x, contexts...) end @@ -204,23 +225,27 @@ function DI.value_and_gradient!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.gradient!(f, grad, prep, backend, x, contexts...) end ## Jacobian -struct SymbolicsOneArgJacobianPrep{E1,E1!} <: DI.JacobianPrep +struct SymbolicsOneArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} jac_exe::E1 jac_exe!::E1! end function DI.prepare_jacobian( + strict::Val, f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) jac_var = if backend isa AutoSparse @@ -231,16 +256,17 @@ function DI.prepare_jacobian( res = build_function(jac_var, x_var, context_vars...; expression=Val(false)) (jac_exe, jac_exe!) = res - return SymbolicsOneArgJacobianPrep(jac_exe, jac_exe!) + return SymbolicsOneArgJacobianPrep(_sig, jac_exe, jac_exe!) 
end function DI.jacobian( f, prep::SymbolicsOneArgJacobianPrep, - ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return prep.jac_exe(x, map(DI.unwrap, contexts)...) end @@ -248,10 +274,11 @@ function DI.jacobian!( f, jac, prep::SymbolicsOneArgJacobianPrep, - ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.jac_exe!(jac, x, map(DI.unwrap, contexts)...) return jac end @@ -263,6 +290,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...) end @@ -274,24 +302,28 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return f(x, map(DI.unwrap, contexts)...), DI.jacobian!(f, jac, prep, backend, x, contexts...) end ## Hessian -struct SymbolicsOneArgHessianPrep{G,E2,E2!} <: DI.HessianPrep +struct SymbolicsOneArgHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG} + _sig::Val{SIG} gradient_prep::G hess_exe::E2 hess_exe!::E2! end function DI.prepare_hessian( + strict::Val, f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) # Symbolic.hessian only accepts vectors @@ -304,17 +336,18 @@ function DI.prepare_hessian( res = build_function(hess_var, vec(x_var), context_vars...; expression=Val(false)) (hess_exe, hess_exe!) = res - gradient_prep = DI.prepare_gradient(f, dense_ad(backend), x, contexts...) 
- return SymbolicsOneArgHessianPrep(gradient_prep, hess_exe, hess_exe!) + gradient_prep = DI.prepare_gradient(strict, f, dense_ad(backend), x, contexts...) + return SymbolicsOneArgHessianPrep(_sig, gradient_prep, hess_exe, hess_exe!) end function DI.hessian( f, prep::SymbolicsOneArgHessianPrep, - ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return prep.hess_exe(vec(x), map(DI.unwrap, contexts)...) end @@ -322,10 +355,11 @@ function DI.hessian!( f, hess, prep::SymbolicsOneArgHessianPrep, - ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.hess_exe!(hess, vec(x), map(DI.unwrap, contexts)...) return hess end @@ -337,6 +371,7 @@ function DI.value_gradient_and_hessian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, grad = DI.value_and_gradient( f, prep.gradient_prep, dense_ad(backend), x, contexts... ) @@ -353,6 +388,7 @@ function DI.value_gradient_and_hessian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_gradient!( f, grad, prep.gradient_prep, dense_ad(backend), x, contexts... ) @@ -362,15 +398,17 @@ end ## HVP -struct SymbolicsOneArgHVPPrep{G,E2,E2!} <: DI.HVPPrep +struct SymbolicsOneArgHVPPrep{SIG,G,E2,E2!} <: DI.HVPPrep{SIG} + _sig::Val{SIG} gradient_prep::G hvp_exe::E2 hvp_exe!::E2! 
end function DI.prepare_hvp( - f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f, backend, x, tx, contexts...; strict) dx = first(tx) x_var = variablize(x, :x) dx_var = variablize(dx, :dx) @@ -384,18 +422,19 @@ function DI.prepare_hvp( ) (hvp_exe, hvp_exe!) = res - gradient_prep = DI.prepare_gradient(f, backend, x, contexts...) - return SymbolicsOneArgHVPPrep(gradient_prep, hvp_exe, hvp_exe!) + gradient_prep = DI.prepare_gradient(strict, f, backend, x, contexts...) + return SymbolicsOneArgHVPPrep(_sig, gradient_prep, hvp_exe, hvp_exe!) end function DI.hvp( f, prep::SymbolicsOneArgHVPPrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return map(tx) do dx dg_vec = prep.hvp_exe(vec(x), vec(dx), map(DI.unwrap, contexts)...) reshape(dg_vec, size(x)) @@ -406,11 +445,12 @@ function DI.hvp!( f, tg::NTuple, prep::SymbolicsOneArgHVPPrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) for b in eachindex(tx, tg) dx, dg = tx[b], tg[b] prep.hvp_exe!(vec(dg), vec(x), vec(dx), map(DI.unwrap, contexts)...) @@ -426,6 +466,7 @@ function DI.gradient_and_hvp( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) tg = DI.hvp(f, prep, backend, x, tx, contexts...) grad = DI.gradient(f, prep.gradient_prep, backend, x, contexts...) return grad, tg @@ -441,6 +482,7 @@ function DI.gradient_and_hvp!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) DI.hvp!(f, tg, prep, backend, x, tx, contexts...) DI.gradient!(f, grad, prep.gradient_prep, backend, x, contexts...) 
return grad, tg @@ -448,15 +490,17 @@ end ## Second derivative -struct SymbolicsOneArgSecondDerivativePrep{D,E1,E1!} <: DI.SecondDerivativePrep +struct SymbolicsOneArgSecondDerivativePrep{SIG,D,E1,E1!} <: DI.SecondDerivativePrep{SIG} + _sig::Val{SIG} derivative_prep::D der2_exe::E1 der2_exe!::E1! end function DI.prepare_second_derivative( - f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} ) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) x_var = variablize(x, :x) context_vars = variablize(contexts) der_var = derivative(f(x_var, context_vars...), x_var) @@ -468,17 +512,18 @@ function DI.prepare_second_derivative( elseif res isa RuntimeGeneratedFunction res, nothing end - derivative_prep = DI.prepare_derivative(f, backend, x, contexts...) - return SymbolicsOneArgSecondDerivativePrep(derivative_prep, der2_exe, der2_exe!) + derivative_prep = DI.prepare_derivative(strict, f, backend, x, contexts...) + return SymbolicsOneArgSecondDerivativePrep(_sig, derivative_prep, der2_exe, der2_exe!) end function DI.second_derivative( f, prep::SymbolicsOneArgSecondDerivativePrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return prep.der2_exe(x, map(DI.unwrap, contexts)...) end @@ -486,10 +531,11 @@ function DI.second_derivative!( f, der2, prep::SymbolicsOneArgSecondDerivativePrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) prep.der2_exe!(der2, x, map(DI.unwrap, contexts)...) return der2 end @@ -501,6 +547,7 @@ function DI.value_derivative_and_second_derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, der = DI.value_and_derivative(f, prep.derivative_prep, backend, x, contexts...) 
der2 = DI.second_derivative(f, prep, backend, x, contexts...) return y, der, der2 @@ -515,6 +562,7 @@ function DI.value_derivative_and_second_derivative!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_derivative!(f, der, prep.derivative_prep, backend, x, contexts...) DI.second_derivative!(f, der2, prep, backend, x, contexts...) return y, der, der2 diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl index ffe6ee0f4..58623720e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl @@ -1,13 +1,21 @@ ## Pushforward -struct SymbolicsTwoArgPushforwardPrep{E1,E1!} <: DI.PushforwardPrep +struct SymbolicsTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} pushforward_exe::E1 pushforward_exe!::E1! end function DI.prepare_pushforward( - f!, y, ::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, + f!, + y, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) dx = first(tx) x_var = variablize(x, :x) dx_var = variablize(dx, :dx) @@ -20,18 +28,19 @@ function DI.prepare_pushforward( res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false)) (pushforward_exe, pushforward_exe!) = res - return SymbolicsTwoArgPushforwardPrep(pushforward_exe, pushforward_exe!) + return SymbolicsTwoArgPushforwardPrep(_sig, pushforward_exe, pushforward_exe!) end function DI.pushforward( f!, y, prep::SymbolicsTwoArgPushforwardPrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) 
ty = map(tx) do dx dy = prep.pushforward_exe(x, dx, map(DI.unwrap, contexts)...) end @@ -43,11 +52,12 @@ function DI.pushforward!( y, ty::NTuple, prep::SymbolicsTwoArgPushforwardPrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] prep.pushforward_exe!(dy, x, dx, map(DI.unwrap, contexts)...) @@ -64,6 +74,7 @@ function DI.value_and_pushforward( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, ty @@ -79,6 +90,7 @@ function DI.value_and_pushforward!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, ty @@ -86,14 +98,16 @@ end ## Derivative -struct SymbolicsTwoArgDerivativePrep{E1,E1!} <: DI.DerivativePrep +struct SymbolicsTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{SIG} + _sig::Val{SIG} der_exe::E1 der_exe!::E1! end function DI.prepare_derivative( - f!, y, ::AutoSymbolics, x, contexts::Vararg{DI.Context,C} + strict::Val, f!, y, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} ) where {C} + _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) y_var = variablize(y, :y) context_vars = variablize(contexts) @@ -102,17 +116,18 @@ function DI.prepare_derivative( res = build_function(der_var, x_var, context_vars...; expression=Val(false)) (der_exe, der_exe!) = res - return SymbolicsTwoArgDerivativePrep(der_exe, der_exe!) + return SymbolicsTwoArgDerivativePrep(_sig, der_exe, der_exe!) 
end function DI.derivative( f!, y, prep::SymbolicsTwoArgDerivativePrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) return prep.der_exe(x, map(DI.unwrap, contexts)...) end @@ -121,10 +136,11 @@ function DI.derivative!( y, der, prep::SymbolicsTwoArgDerivativePrep, - ::AutoSymbolics, + backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) prep.der_exe!(der, x, map(DI.unwrap, contexts)...) return der end @@ -137,6 +153,7 @@ function DI.value_and_derivative( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) der = DI.derivative(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, der @@ -151,6 +168,7 @@ function DI.value_and_derivative!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) DI.derivative!(f!, y, der, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, der @@ -158,18 +176,21 @@ end ## Jacobian -struct SymbolicsTwoArgJacobianPrep{E1,E1!} <: DI.JacobianPrep +struct SymbolicsTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG} + _sig::Val{SIG} jac_exe::E1 jac_exe!::E1! end function DI.prepare_jacobian( + strict::Val, f!, y, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, - contexts::Vararg{DI.Context,C}, + contexts::Vararg{DI.Context,C}; ) where {C} + _sig = DI.signature(f!, y, backend, x, contexts...; strict) x_var = variablize(x, :x) y_var = variablize(y, :y) context_vars = variablize(contexts) @@ -182,17 +203,18 @@ function DI.prepare_jacobian( res = build_function(jac_var, x_var, context_vars...; expression=Val(false)) (jac_exe, jac_exe!) = res - return SymbolicsTwoArgJacobianPrep(jac_exe, jac_exe!) + return SymbolicsTwoArgJacobianPrep(_sig, jac_exe, jac_exe!) 
end function DI.jacobian( f!, y, prep::SymbolicsTwoArgJacobianPrep, - ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) return prep.jac_exe(x, map(DI.unwrap, contexts)...) end @@ -201,10 +223,11 @@ function DI.jacobian!( y, jac, prep::SymbolicsTwoArgJacobianPrep, - ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) prep.jac_exe!(jac, x, map(DI.unwrap, contexts)...) return jac end @@ -217,6 +240,7 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) jac = DI.jacobian(f!, y, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) return y, jac @@ -231,6 +255,7 @@ function DI.value_and_jacobian!( x, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f!, y, prep, backend, x, contexts...) DI.jacobian!(f!, y, jac, prep, backend, x, contexts...) f!(y, x, map(DI.unwrap, contexts)...) 
return y, jac diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index 17d885ca7..aaeae6cc7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -9,37 +9,47 @@ DI.inplace_support(::AutoTracker) = DI.InPlaceNotSupported() ## Pullback -struct TrackerPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep +struct TrackerPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} + _sig::Val{SIG} y::Y pb::PB end function DI.prepare_pullback( - f, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C} + strict::Val, + f, + backend::AutoTracker, + x, + ty::NTuple, + contexts::Vararg{DI.GeneralizedConstant,C}; ) where {C} - return DI.NoPullbackPrep() + _sig = DI.signature(f, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep(_sig) end function DI.prepare_pullback_same_point( f, - ::DI.NoPullbackPrep, - ::AutoTracker, + prep::DI.NoPullbackPrep, + backend::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) + _sig = DI.signature(f, backend, x, ty, contexts...; strict=DI.is_strict(prep)) y, pb = forward(f, x, map(DI.unwrap, contexts)...) - return TrackerPullbackPrepSamePoint(y, pb) + return TrackerPullbackPrepSamePoint(_sig, y, pb) end function DI.value_and_pullback( f, - ::DI.NoPullbackPrep, - ::AutoTracker, + prep::DI.NoPullbackPrep, + backend::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) y, pb = forward(f, x, map(DI.unwrap, contexts)...) 
tx = map(ty) do dy data(first(pb(dy))) @@ -50,11 +60,12 @@ end function DI.value_and_pullback( f, prep::TrackerPullbackPrepSamePoint, - ::AutoTracker, + backend::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) (; y, pb) = prep tx = map(ty) do dy data(first(pb(dy))) @@ -65,11 +76,12 @@ end function DI.pullback( f, prep::TrackerPullbackPrepSamePoint, - ::AutoTracker, + backend::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) (; pb) = prep tx = map(ty) do dy data(first(pb(dy))) @@ -80,21 +92,32 @@ end ## Gradient function DI.prepare_gradient( - f, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} + strict::Val, f, backend::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C}; ) where {C} - return DI.NoGradientPrep() + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoGradientPrep(_sig) end function DI.value_and_gradient( - f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} + f, + prep::DI.NoGradientPrep, + backend::AutoTracker, + x, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return val, data(first(grad)) end function DI.gradient( - f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} + f, + prep::DI.NoGradientPrep, + backend::AutoTracker, + x, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return data(first(grad)) end @@ -107,6 +130,7 @@ function DI.value_and_gradient!( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) 
return y, copyto!(grad, new_grad) end @@ -119,6 +143,7 @@ function DI.gradient!( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index adf1c397e..c7a246211 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -21,27 +21,42 @@ translate(c::DI.Cache) = Buffer(DI.unwrap(c)) ## Pullback -struct ZygotePullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep +struct ZygotePullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG} + _sig::Val{SIG} y::Y pb::PB end function DI.prepare_pullback( - f, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} - return DI.NoPullbackPrep() + _sig = DI.signature(f, backend, x, ty, contexts...; strict) + return DI.NoPullbackPrep(_sig) end function DI.prepare_pullback_same_point( - f, ::DI.NoPullbackPrep, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C} + f, + prep::DI.NoPullbackPrep, + backend::AutoZygote, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}; ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) + _sig = DI.signature(f, backend, x, ty, contexts...; strict=DI.is_strict(prep)) y, pb = pullback(f, x, map(translate, contexts)...) 
- return ZygotePullbackPrepSamePoint(y, pb) + return ZygotePullbackPrepSamePoint(_sig, y, pb) end function DI.value_and_pullback( - f, ::DI.NoPullbackPrep, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C} + f, + prep::DI.NoPullbackPrep, + backend::AutoZygote, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) y, pb = pullback(f, x, map(translate, contexts)...) tx = map(ty) do dy first(pb(dy)) @@ -52,11 +67,12 @@ end function DI.value_and_pullback( f, prep::ZygotePullbackPrepSamePoint, - ::AutoZygote, + backend::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) (; y, pb) = prep tx = map(ty) do dy first(pb(dy)) @@ -67,11 +83,12 @@ end function DI.pullback( f, prep::ZygotePullbackPrepSamePoint, - ::AutoZygote, + backend::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, ty, contexts...) (; pb) = prep tx = map(ty) do dy first(pb(dy)) @@ -81,20 +98,25 @@ end ## Gradient -function DI.prepare_gradient(f, ::AutoZygote, x, contexts::Vararg{DI.Context,C}) where {C} - return DI.NoGradientPrep() +function DI.prepare_gradient( + strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} +) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoGradientPrep(_sig) end function DI.value_and_gradient( - f, ::DI.NoGradientPrep, ::AutoZygote, x, contexts::Vararg{DI.Context,C} + f, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) (; val, grad) = withgradient(f, x, map(translate, contexts)...) 
return val, first(grad) end function DI.gradient( - f, ::DI.NoGradientPrep, ::AutoZygote, x, contexts::Vararg{DI.Context,C} + f, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) grad = gradient(f, x, map(translate, contexts)...) return first(grad) end @@ -102,6 +124,7 @@ end function DI.value_and_gradient!( f, grad, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) end @@ -109,18 +132,23 @@ end function DI.gradient!( f, grad, prep::DI.NoGradientPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end ## Jacobian -function DI.prepare_jacobian(f, ::AutoZygote, x, contexts::Vararg{DI.Context,C}) where {C} - return DI.NoJacobianPrep() +function DI.prepare_jacobian( + strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} +) where {C} + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoJacobianPrep(_sig) end function DI.value_and_jacobian( - f, ::DI.NoJacobianPrep, ::AutoZygote, x, contexts::Vararg{DI.Context,C} + f, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y = f(x, map(translate, contexts)...) # https://github.com/FluxML/Zygote.jl/issues/1506 jac = jacobian(f, x, map(translate, contexts)...) @@ -128,8 +156,9 @@ function DI.value_and_jacobian( end function DI.jacobian( - f, ::DI.NoJacobianPrep, ::AutoZygote, x, contexts::Vararg{DI.Context,C} + f, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
jac = jacobian(f, x, map(translate, contexts)...) return first(jac) end @@ -137,6 +166,7 @@ end function DI.value_and_jacobian!( f, jac, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...) return y, copyto!(jac, new_jac) end @@ -144,6 +174,7 @@ end function DI.jacobian!( f, jac, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Context,C} ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...)) end @@ -151,47 +182,61 @@ end # Beware, this uses ForwardDiff for the inner differentiation +struct ZygoteHVPPrep{SIG,P} <: DI.HVPPrep{SIG} + _sig::Val{SIG} + fd_prep::P +end + function DI.prepare_hvp( - f, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C} + strict::Val, f, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C} ) where {C} - return DI.prepare_hvp(f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) + _sig = DI.signature(f, backend, x, tx, contexts...; strict) + fd_prep = DI.prepare_hvp( + strict, f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... + ) + return ZygoteHVPPrep(_sig, fd_prep) end function DI.hvp( f, - prep::DI.ForwardOverAnythingHVPPrep, + prep::ZygoteHVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - return DI.hvp(f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) + DI.check_prep(f, prep, backend, x, tx, contexts...) + return DI.hvp( + f, prep.fd_prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... + ) end function DI.hvp!( f, tg::NTuple, - prep::DI.ForwardOverAnythingHVPPrep, + prep::ZygoteHVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) 
return DI.hvp!( - f, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... + f, tg, prep.fd_prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... ) end function DI.gradient_and_hvp( f, - prep::DI.ForwardOverAnythingHVPPrep, + prep::ZygoteHVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.gradient_and_hvp( - f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... + f, prep.fd_prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... ) end @@ -199,28 +244,42 @@ function DI.gradient_and_hvp!( f, grad, tg::NTuple, - prep::DI.ForwardOverAnythingHVPPrep, + prep::ZygoteHVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} + DI.check_prep(f, prep, backend, x, tx, contexts...) return DI.gradient_and_hvp!( - f, grad, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... + f, + grad, + tg, + prep.fd_prep, + DI.SecondOrder(AutoForwardDiff(), backend), + x, + tx, + contexts..., ) end ## Hessian function DI.prepare_hessian( - f, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} + strict::Val, f, backend::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} - return DI.NoHessianPrep() + _sig = DI.signature(f, backend, x, contexts...; strict) + return DI.NoHessianPrep(_sig) end function DI.hessian( - f, ::DI.NoHessianPrep, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} + f, + prep::DI.NoHessianPrep, + backend::AutoZygote, + x, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) fc = DI.with_contexts(f, contexts...) hess = hessian(fc, x) return hess @@ -234,6 +293,7 @@ function DI.hessian!( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} + DI.check_prep(f, prep, backend, x, contexts...) 
return copyto!(hess, DI.hessian(f, prep, backend, x, contexts...)) end @@ -244,7 +304,8 @@ function DI.value_gradient_and_hessian( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} - y, grad = DI.value_and_gradient(f, DI.NoGradientPrep(), backend, x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + y, grad = DI.value_and_gradient(f, backend, x, contexts...) hess = DI.hessian(f, prep, backend, x, contexts...) return y, grad, hess end @@ -258,7 +319,8 @@ function DI.value_gradient_and_hessian!( x, contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} - y, _ = DI.value_and_gradient!(f, grad, DI.NoGradientPrep(), backend, x, contexts...) + DI.check_prep(f, prep, backend, x, contexts...) + y, _ = DI.value_and_gradient!(f, grad, backend, x, contexts...) DI.hessian!(f, hess, prep, backend, x, contexts...) return y, grad, hess end diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 2cbd312ff..984a194f7 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -41,13 +41,13 @@ include("docstrings.jl") include("first_order/mixed_mode.jl") include("second_order/second_order.jl") +include("utils/context.jl") include("utils/prep.jl") include("utils/traits.jl") include("utils/basis.jl") include("utils/batchsize.jl") include("utils/check.jl") include("utils/errors.jl") -include("utils/context.jl") include("utils/linalg.jl") include("utils/sparse.jl") diff --git a/DifferentiationInterface/src/docstrings.jl b/DifferentiationInterface/src/docstrings.jl index 66834d1e2..e14ccfb52 100644 --- a/DifferentiationInterface/src/docstrings.jl +++ b/DifferentiationInterface/src/docstrings.jl @@ -19,11 +19,14 @@ function docstring_prepare(operator; samepoint=false, inplace=false) Create a `prep` object that can be given to [`$(operator)`](@ref) and its variants to speed them up$(samepoint_warning(samepoint)). 
Depending on the backend, this can have several effects (preallocating memory, recording an execution trace) which are transparent to the user. + $(inplace ? "\nFor in-place functions, `y` is mutated by `f!` during preparation." : "") !!! warning - The preparation result is only reusable as long as the arguments to `$operator` do not change type or size, and the function and backend themselves are not modified. - Otherwise, preparation will be invalidated and you will need to run it again. - $(inplace ? "\nFor in-place functions, `y` is mutated by `f!` during preparation." : "") + The preparation result `prep` is only reusable as long as the arguments to `$operator` do not change type or size, and the function and backend themselves are not modified. + Otherwise, preparation becomes invalid and you need to run it again. + In some settings, invalid preparations may still give correct results (e.g. for backends that require no preparation), but this is not a semantic guarantee and should not be relied upon. + + When `strict=Val(true)`, type checking is enforced between preparation and execution (but size checking is left to the user). """ end diff --git a/DifferentiationInterface/src/fallbacks/change_prep.jl b/DifferentiationInterface/src/fallbacks/change_prep.jl index 48ea62583..492813391 100644 --- a/DifferentiationInterface/src/fallbacks/change_prep.jl +++ b/DifferentiationInterface/src/fallbacks/change_prep.jl @@ -45,22 +45,25 @@ for op in [ @eval function $prep_op!( f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - return $prep_op(f, backend, x, contexts...) + check_prep(f, old_prep, backend, x, contexts...) + return $prep_op(is_strict(old_prep), f, backend, x, contexts...) 
end op == :gradient && continue # 2-arg @eval function $prep_op!( - f!::F, y, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C} + f!::F, y, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} - return $prep_op(f!, y, backend, x, contexts...) + check_prep(f!, y, old_prep, backend, x, contexts...) + return $prep_op(is_strict(old_prep), f!, y, backend, x, contexts...) end elseif op in (:second_derivative, :hessian) # 1-arg @eval function $prep_op!( - f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C} + f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} - return $prep_op(f, backend, x, contexts...) + check_prep(f, old_prep, backend, x, contexts...) + return $prep_op(is_strict(old_prep), f, backend, x, contexts...) end elseif op in (:pushforward, :pullback, :hvp) @@ -71,9 +74,10 @@ for op in [ backend::AbstractADType, x, seed::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {F,C} - return $prep_op(f, backend, x, seed, contexts...) + check_prep(f, old_prep, backend, x, seed, contexts...) + return $prep_op(is_strict(old_prep), f, backend, x, seed, contexts...) end @eval function $prep_op_same_point( f::F, @@ -83,12 +87,18 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, seed, contexts...) return prep end @eval function $prep_op_same_point( - f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + strict::Val, + f::F, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}; ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) + prep = $prep_op(strict, f, backend, x, seed, contexts...) return $prep_op_same_point(f, prep, backend, x, seed, contexts...) 
end op == :hvp && continue @@ -102,7 +112,8 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - return $prep_op(f!, y, backend, x, seed, contexts...) + check_prep(f!, y, old_prep, backend, x, seed, contexts...) + return $prep_op(is_strict(old_prep), f!, y, backend, x, seed, contexts...) end @eval function $prep_op_same_point( f!::F, @@ -113,12 +124,19 @@ for op in [ seed::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, seed, contexts...) return prep end @eval function $prep_op_same_point( - f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + strict::Val, + f!::F, + y, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}; ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...) + prep = $prep_op(strict, f!, y, backend, x, seed, contexts...) return $prep_op_same_point(f!, y, prep, backend, x, seed, contexts...) end end diff --git a/DifferentiationInterface/src/fallbacks/no_prep.jl b/DifferentiationInterface/src/fallbacks/no_prep.jl index 0f2ecd4e6..81c7b2bd0 100644 --- a/DifferentiationInterface/src/fallbacks/no_prep.jl +++ b/DifferentiationInterface/src/fallbacks/no_prep.jl @@ -45,25 +45,25 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) + prep = $prep_op(Val(true), f, backend, x, contexts...) return $op(f, prep, backend, x, contexts...) end @eval function $op!( f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) + prep = $prep_op(Val(true), f, backend, x, contexts...) return $op!(f, result, prep, backend, x, contexts...) end @eval function $val_and_op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) + prep = $prep_op(Val(true), f, backend, x, contexts...) 
return $val_and_op(f, prep, backend, x, contexts...) end @eval function $val_and_op!( f::F, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) + prep = $prep_op(Val(true), f, backend, x, contexts...) return $val_and_op!(f, result, prep, backend, x, contexts...) end op == :gradient && continue @@ -71,25 +71,25 @@ for op in [ @eval function $op( f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...) + prep = $prep_op(Val(true), f!, y, backend, x, contexts...) return $op(f!, y, prep, backend, x, contexts...) end @eval function $op!( f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...) + prep = $prep_op(Val(true), f!, y, backend, x, contexts...) return $op!(f!, y, result, prep, backend, x, contexts...) end @eval function $val_and_op( f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...) + prep = $prep_op(Val(true), f!, y, backend, x, contexts...) return $val_and_op(f!, y, prep, backend, x, contexts...) end @eval function $val_and_op!( f!::F, y, result, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, contexts...) + prep = $prep_op(Val(true), f!, y, backend, x, contexts...) return $val_and_op!(f!, y, result, prep, backend, x, contexts...) end @@ -98,51 +98,51 @@ for op in [ @eval function $op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) + prep = $prep_op(Val(true), f, backend, x, contexts...) return $op(f, prep, backend, x, contexts...) end @eval function $op!( f::F, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) 
+ prep = $prep_op(Val(true), f, backend, x, contexts...) return $op!(f, result2, prep, backend, x, contexts...) end @eval function $val_and_op( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) + prep = $prep_op(Val(true), f, backend, x, contexts...) return $val_and_op(f, prep, backend, x, contexts...) end @eval function $val_and_op!( f::F, result1, result2, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, contexts...) + prep = $prep_op(Val(true), f, backend, x, contexts...) return $val_and_op!(f, result1, result2, prep, backend, x, contexts...) end elseif op in (:pushforward, :pullback, :hvp) @eval function $op( - f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + f::F, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) - return $op(f, prep, backend, x, seed, contexts...) + prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + return $op(f, prep, backend, x, tang, contexts...) end @eval function $op!( f::F, result::NTuple, backend::AbstractADType, x, - seed::NTuple, + tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) - return $op!(f, result, prep, backend, x, seed, contexts...) + prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + return $op!(f, result, prep, backend, x, tang, contexts...) end @eval function $val_and_op( - f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + f::F, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) - return $val_and_op(f, prep, backend, x, seed, contexts...) + prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + return $val_and_op(f, prep, backend, x, tang, contexts...) 
end if op in (:pushforward, :pullback) @@ -151,11 +151,11 @@ for op in [ result::NTuple, backend::AbstractADType, x, - seed::NTuple, + tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) - return $val_and_op!(f, result, prep, backend, x, seed, contexts...) + prep = $prep_op(Val(true), f, backend, x, tang, contexts...) + return $val_and_op!(f, result, prep, backend, x, tang, contexts...) end elseif op == :hvp @eval function $val_and_op!( @@ -164,12 +164,12 @@ for op in [ result2::NTuple, backend::AbstractADType, x, - seed::NTuple, + tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) + prep = $prep_op(Val(true), f, backend, x, tang, contexts...) return $val_and_op!( - f, result1, result2, prep, backend, x, seed, contexts... + f, result1, result2, prep, backend, x, tang, contexts... ) end end @@ -177,10 +177,10 @@ for op in [ op == :hvp && continue @eval function $op( - f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + f!::F, y, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...) - return $op(f!, y, prep, backend, x, seed, contexts...) + prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + return $op(f!, y, prep, backend, x, tang, contexts...) end @eval function $op!( f!::F, @@ -188,17 +188,17 @@ for op in [ result::NTuple, backend::AbstractADType, x, - seed::NTuple, + tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...) - return $op!(f!, y, result, prep, backend, x, seed, contexts...) + prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + return $op!(f!, y, result, prep, backend, x, tang, contexts...) 
end @eval function $val_and_op( - f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} + f!::F, y, backend::AbstractADType, x, tang::NTuple, contexts::Vararg{Context,C} ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...) - return $val_and_op(f!, y, prep, backend, x, seed, contexts...) + prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + return $val_and_op(f!, y, prep, backend, x, tang, contexts...) end @eval function $val_and_op!( f!::F, @@ -206,11 +206,11 @@ for op in [ result::NTuple, backend::AbstractADType, x, - seed::NTuple, + tang::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - prep = $prep_op(f!, y, backend, x, seed, contexts...) - return $val_and_op!(f!, y, result, prep, backend, x, seed, contexts...) + prep = $prep_op(Val(true), f!, y, backend, x, tang, contexts...) + return $val_and_op!(f!, y, result, prep, backend, x, tang, contexts...) end end end diff --git a/DifferentiationInterface/src/first_order/derivative.jl b/DifferentiationInterface/src/first_order/derivative.jl index 84b080ae5..c8753432d 100644 --- a/DifferentiationInterface/src/first_order/derivative.jl +++ b/DifferentiationInterface/src/first_order/derivative.jl @@ -1,12 +1,14 @@ ## Docstrings """ - prepare_derivative(f, backend, x, [contexts...]) -> prep - prepare_derivative(f!, y, backend, x, [contexts...]) -> prep + prepare_derivative(f, backend, x, [contexts...]; strict=Val(false)) -> prep + prepare_derivative(f!, y, backend, x, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("derivative"; inplace=true)) """ -function prepare_derivative end +function prepare_derivative(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_derivative(strict, args...) +end """ prepare!_derivative(f, prep, backend, x, [contexts...]) -> new_prep @@ -58,22 +60,27 @@ function derivative! 
end ## Preparation -struct PushforwardDerivativePrep{E<:PushforwardPrep} <: DerivativePrep +struct PushforwardDerivativePrep{SIG,E<:PushforwardPrep} <: DerivativePrep{SIG} + _sig::Val{SIG} pushforward_prep::E end function prepare_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} - pushforward_prep = prepare_pushforward(f, backend, x, (one(x),), contexts...) - return PushforwardDerivativePrep(pushforward_prep) + _sig = signature(f, backend, x, contexts...; strict) + pushforward_prep = prepare_pushforward(strict, f, backend, x, (one(x),), contexts...) + return PushforwardDerivativePrep(_sig, pushforward_prep) end function prepare_derivative( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} + strict::Val, f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} - pushforward_prep = prepare_pushforward(f!, y, backend, x, (one(x),), contexts...) - return PushforwardDerivativePrep(pushforward_prep) + _sig = signature(f!, y, backend, x, contexts...; strict) + pushforward_prep = prepare_pushforward( + strict, f!, y, backend, x, (one(x),), contexts... + ) + return PushforwardDerivativePrep(_sig, pushforward_prep) end ## One argument @@ -85,6 +92,7 @@ function value_and_derivative( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) y, ty = value_and_pushforward( f, prep.pushforward_prep, backend, x, (one(x),), contexts... ) @@ -99,6 +107,7 @@ function value_and_derivative!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) y, _ = value_and_pushforward!( f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts... ) @@ -112,6 +121,7 @@ function derivative( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) ty = pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...) 
return only(ty) end @@ -124,6 +134,7 @@ function derivative!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...) return der end @@ -138,6 +149,7 @@ function value_and_derivative( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) y, ty = value_and_pushforward( f!, y, prep.pushforward_prep, backend, x, (one(x),), contexts... ) @@ -153,6 +165,7 @@ function value_and_derivative!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) y, _ = value_and_pushforward!( f!, y, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts... ) @@ -167,6 +180,7 @@ function derivative( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) ty = pushforward(f!, y, prep.pushforward_prep, backend, x, (one(x),), contexts...) return only(ty) end @@ -180,6 +194,7 @@ function derivative!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) pushforward!(f!, y, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...) return der end diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 17c39ac92..8ddfb7137 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -1,11 +1,13 @@ ## Docstrings """ - prepare_gradient(f, backend, x, [contexts...]) -> prep + prepare_gradient(f, backend, x, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("gradient")) """ -function prepare_gradient end +function prepare_gradient(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_gradient(strict, args...) +end """ prepare!_gradient(f, prep, backend, x, [contexts...]) -> new_prep @@ -52,27 +54,31 @@ function gradient! 
end ## Preparation -struct PullbackGradientPrep{Y,E<:PullbackPrep} <: GradientPrep +struct PullbackGradientPrep{SIG,Y,E<:PullbackPrep} <: GradientPrep{SIG} + _sig::Val{SIG} + y::Y pullback_prep::E end function prepare_gradient( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} + _sig = signature(f, backend, x, contexts...; strict) y = f(x, map(unwrap, contexts)...) # TODO: replace with output type inference? - pullback_prep = prepare_pullback(f, backend, x, (true,), contexts...) - return PullbackGradientPrep{typeof(y),typeof(pullback_prep)}(pullback_prep) + pullback_prep = prepare_pullback(strict, f, backend, x, (one(typeof(y)),), contexts...) + return PullbackGradientPrep(_sig, y, pullback_prep) end ## One argument function value_and_gradient( f::F, - prep::PullbackGradientPrep{Y}, + prep::PullbackGradientPrep{SIG,Y}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {F,Y,C} +) where {F,SIG,Y,C} + check_prep(f, prep, backend, x, contexts...) y, tx = value_and_pullback(f, prep.pullback_prep, backend, x, (one(Y),), contexts...) return y, only(tx) end @@ -80,11 +86,12 @@ end function value_and_gradient!( f::F, grad, - prep::PullbackGradientPrep{Y}, + prep::PullbackGradientPrep{SIG,Y}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {F,Y,C} +) where {F,SIG,Y,C} + check_prep(f, prep, backend, x, contexts...) y, _ = value_and_pullback!( f, (grad,), prep.pullback_prep, backend, x, (one(Y),), contexts... ) @@ -93,11 +100,12 @@ end function gradient( f::F, - prep::PullbackGradientPrep{Y}, + prep::PullbackGradientPrep{SIG,Y}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {F,Y,C} +) where {F,SIG,Y,C} + check_prep(f, prep, backend, x, contexts...) tx = pullback(f, prep.pullback_prep, backend, x, (one(Y),), contexts...) 
return only(tx) end @@ -105,11 +113,12 @@ end function gradient!( f::F, grad, - prep::PullbackGradientPrep{Y}, + prep::PullbackGradientPrep{SIG,Y}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {F,Y,C} +) where {F,SIG,Y,C} + check_prep(f, prep, backend, x, contexts...) pullback!(f, (grad,), prep.pullback_prep, backend, x, (one(Y),), contexts...) return grad end diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 546c2b80f..201cbea23 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -1,12 +1,14 @@ ## Docstrings """ - prepare_jacobian(f, backend, x, [contexts...]) -> prep - prepare_jacobian(f!, y, backend, x, [contexts...]) -> prep + prepare_jacobian(f, backend, x, [contexts...]; strict=Val(false)) -> prep + prepare_jacobian(f!, y, backend, x, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("jacobian"; inplace=true)) """ -function prepare_jacobian end +function prepare_jacobian(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_jacobian(strict, args...) +end """ prepare!_jacobian(f, prep, backend, x, [contexts...]) -> new_prep @@ -58,14 +60,16 @@ function jacobian! 
end ## Preparation -abstract type StandardJacobianPrep <: JacobianPrep end +abstract type StandardJacobianPrep{SIG} <: JacobianPrep{SIG} end struct PushforwardJacobianPrep{ + SIG, BS<:BatchSizeSettings, S<:AbstractVector{<:NTuple}, R<:AbstractVector{<:NTuple}, E<:PushforwardPrep, -} <: StandardJacobianPrep +} <: StandardJacobianPrep{SIG} + _sig::Val{SIG} batch_size_settings::BS batched_seeds::S batched_results::R @@ -73,11 +77,13 @@ struct PushforwardJacobianPrep{ end struct PullbackJacobianPrep{ + SIG, BS<:BatchSizeSettings, S<:AbstractVector{<:NTuple}, R<:AbstractVector{<:NTuple}, E<:PullbackPrep, -} <: StandardJacobianPrep +} <: StandardJacobianPrep{SIG} + _sig::Val{SIG} batch_size_settings::BS batched_seeds::S batched_results::R @@ -85,7 +91,7 @@ struct PullbackJacobianPrep{ end function prepare_jacobian( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} y = f(x, map(unwrap, contexts)...) perf = pushforward_performance(backend) @@ -97,12 +103,12 @@ function prepare_jacobian( end # function barrier return _prepare_jacobian_aux( - perf, batch_size_settings, y, (f,), backend, x, contexts... + strict, perf, batch_size_settings, y, (f,), backend, x, contexts... ) end function prepare_jacobian( - f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} + strict::Val, f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; ) where {F,C} perf = pushforward_performance(backend) # type-unstable @@ -113,19 +119,21 @@ function prepare_jacobian( end # function barrier return _prepare_jacobian_aux( - perf, batch_size_settings, y, (f!, y), backend, x, contexts... + strict, perf, batch_size_settings, y, (f!, y), backend, x, contexts... 
) end function _prepare_jacobian_aux( + strict::Val, ::PushforwardFast, batch_size_settings::BatchSizeSettings{B}, y, f_or_f!y::FY, backend::AbstractADType, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {B,FY,C} + _sig = signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings seeds = [basis(x, ind) for ind in eachindex(x)] batched_seeds = [ @@ -133,31 +141,35 @@ function _prepare_jacobian_aux( ] batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds] pushforward_prep = prepare_pushforward( - f_or_f!y..., backend, x, batched_seeds[1], contexts... + strict, f_or_f!y..., backend, x, batched_seeds[1], contexts... ) return PushforwardJacobianPrep( - batch_size_settings, batched_seeds, batched_results, pushforward_prep + _sig, batch_size_settings, batched_seeds, batched_results, pushforward_prep ) end function _prepare_jacobian_aux( + strict::Val, ::PushforwardSlow, batch_size_settings::BatchSizeSettings{B}, y, f_or_f!y::FY, backend::AbstractADType, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {B,FY,C} + _sig = signature(f_or_f!y..., backend, x, contexts...; strict) (; N, A) = batch_size_settings seeds = [basis(y, ind) for ind in eachindex(y)] batched_seeds = [ ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - pullback_prep = prepare_pullback(f_or_f!y..., backend, x, batched_seeds[1], contexts...) + pullback_prep = prepare_pullback( + strict, f_or_f!y..., backend, x, batched_seeds[1], contexts... + ) return PullbackJacobianPrep( - batch_size_settings, batched_seeds, batched_results, pullback_prep + _sig, batch_size_settings, batched_seeds, batched_results, pullback_prep ) end @@ -170,6 +182,7 @@ function jacobian( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) return _jacobian_aux((f,), prep, backend, x, contexts...) 
end @@ -181,18 +194,21 @@ function jacobian!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) return _jacobian_aux!((f,), jac, prep, backend, x, contexts...) end function value_and_jacobian( f::F, prep::JacobianPrep, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} + check_prep(f, prep, backend, x, contexts...) return f(x, map(unwrap, contexts)...), jacobian(f, prep, backend, x, contexts...) end function value_and_jacobian!( f::F, jac, prep::JacobianPrep, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} + check_prep(f, prep, backend, x, contexts...) return f(x, map(unwrap, contexts)...), jacobian!(f, jac, prep, backend, x, contexts...) end @@ -206,6 +222,7 @@ function jacobian( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) return _jacobian_aux((f!, y), prep, backend, x, contexts...) end @@ -218,12 +235,14 @@ function jacobian!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) return _jacobian_aux!((f!, y), jac, prep, backend, x, contexts...) end function value_and_jacobian( f!::F, y, prep::JacobianPrep, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) jac = jacobian(f!, y, prep, backend, x, contexts...) f!(y, x, map(unwrap, contexts)...) return y, jac @@ -238,6 +257,7 @@ function value_and_jacobian!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, contexts...) jacobian!(f!, y, jac, prep, backend, x, contexts...) f!(y, x, map(unwrap, contexts)...) 
return y, jac @@ -247,11 +267,11 @@ end function _jacobian_aux( f_or_f!y::FY, - prep::PushforwardJacobianPrep{<:BatchSizeSettings{B,true,aligned}}, + prep::PushforwardJacobianPrep{SIG,<:BatchSizeSettings{B,true,aligned}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {FY,B,aligned,C} +) where {FY,SIG,B,aligned,C} (; batch_size_settings, batched_seeds, pushforward_prep) = prep (; B_last) = batch_size_settings dy_batch = pushforward( @@ -267,11 +287,11 @@ end function _jacobian_aux( f_or_f!y::FY, - prep::PushforwardJacobianPrep{<:BatchSizeSettings{B,false,aligned}}, + prep::PushforwardJacobianPrep{SIG,<:BatchSizeSettings{B,false,aligned}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {FY,B,aligned,C} +) where {FY,SIG,B,aligned,C} (; batch_size_settings, batched_seeds, pushforward_prep) = prep (; A, B_last) = batch_size_settings @@ -300,11 +320,11 @@ end function _jacobian_aux( f_or_f!y::FY, - prep::PullbackJacobianPrep{<:BatchSizeSettings{B,true,aligned}}, + prep::PullbackJacobianPrep{SIG,<:BatchSizeSettings{B,true,aligned}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {FY,B,aligned,C} +) where {FY,SIG,B,aligned,C} (; batch_size_settings, batched_seeds, pullback_prep) = prep (; B_last) = batch_size_settings dx_batch = pullback( @@ -323,11 +343,11 @@ end function _jacobian_aux( f_or_f!y::FY, - prep::PullbackJacobianPrep{<:BatchSizeSettings{B,false,aligned}}, + prep::PullbackJacobianPrep{SIG,<:BatchSizeSettings{B,false,aligned}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {FY,B,aligned,C} +) where {FY,SIG,B,aligned,C} (; batch_size_settings, batched_seeds, pullback_prep) = prep (; A, B_last) = batch_size_settings @@ -355,11 +375,11 @@ end function _jacobian_aux!( f_or_f!y::FY, jac, - prep::PushforwardJacobianPrep{<:BatchSizeSettings{B}}, + prep::PushforwardJacobianPrep{SIG,<:BatchSizeSettings{B}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {FY,B,C} +) where 
{FY,SIG,B,C} (; batch_size_settings, batched_seeds, batched_results, pushforward_prep) = prep (; N) = batch_size_settings @@ -391,11 +411,11 @@ end function _jacobian_aux!( f_or_f!y::FY, jac, - prep::PullbackJacobianPrep{<:BatchSizeSettings{B}}, + prep::PullbackJacobianPrep{SIG,<:BatchSizeSettings{B}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {FY,B,C} +) where {FY,SIG,B,C} (; batch_size_settings, batched_seeds, batched_results, pullback_prep) = prep (; N) = batch_size_settings diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index d80a4ff7e..63afa0656 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -1,12 +1,14 @@ ## Docstrings """ - prepare_pullback(f, backend, x, ty, [contexts...]) -> prep - prepare_pullback(f!, y, backend, x, ty, [contexts...]) -> prep + prepare_pullback(f, backend, x, ty, [contexts...]; strict=Val(false)) -> prep + prepare_pullback(f!, y, backend, x, ty, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("pullback"; inplace=true)) """ -function prepare_pullback end +function prepare_pullback(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_pullback(strict, args...) 
+end """ prepare!_pullback(f, prep, backend, x, ty, [contexts...]) -> new_prep @@ -17,12 +19,14 @@ $(docstring_prepare!("pullback")) function prepare!_pullback end """ - prepare_pullback_same_point(f, backend, x, ty, [contexts...]) -> prep_same - prepare_pullback_same_point(f!, y, backend, x, ty, [contexts...]) -> prep_same + prepare_pullback_same_point(f, backend, x, ty, [contexts...]; strict=Val(false)) -> prep_same + prepare_pullback_same_point(f!, y, backend, x, ty, [contexts...]; strict=Val(false)) -> prep_same $(docstring_prepare("pullback"; samepoint=true, inplace=true)) """ -function prepare_pullback_same_point end +function prepare_pullback_same_point(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_pullback_same_point(strict, args...) +end """ value_and_pullback(f, [prep,] backend, x, ty, [contexts...]) -> (y, tx) @@ -85,51 +89,62 @@ function pullback! end ## Preparation -struct PushforwardPullbackPrep{E} <: PullbackPrep +struct PushforwardPullbackPrep{SIG,E} <: PullbackPrep{SIG} + _sig::Val{SIG} pushforward_prep::E end function prepare_pullback( - f::F, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context,C} + strict::Val, f::F, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context,C}; ) where {F,C} return _prepare_pullback_aux( - pullback_performance(backend), f, backend, x, ty, contexts... + strict, pullback_performance(backend), f, backend, x, ty, contexts... ) end function prepare_pullback( - f!::F, y, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context,C} + strict::Val, + f!::F, + y, + backend::AbstractADType, + x, + ty::NTuple, + contexts::Vararg{Context,C}; ) where {F,C} return _prepare_pullback_aux( - pullback_performance(backend), f!, y, backend, x, ty, contexts... + strict, pullback_performance(backend), f!, y, backend, x, ty, contexts... 
) end function _prepare_pullback_aux( + strict::Val, ::PullbackSlow, f::F, backend::AbstractADType, x, ty::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {F,C} + _sig = signature(f, backend, x, ty, contexts...; strict) dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x))) - pushforward_prep = prepare_pushforward(f, backend, x, (dx,), contexts...) - return PushforwardPullbackPrep(pushforward_prep) + pushforward_prep = prepare_pushforward(strict, f, backend, x, (dx,), contexts...) + return PushforwardPullbackPrep(_sig, pushforward_prep) end function _prepare_pullback_aux( + strict::Val, ::PullbackSlow, f!::F, y, backend::AbstractADType, x, ty::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {F,C} + _sig = signature(f!, y, backend, x, ty, contexts...; strict) dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x))) - pushforward_prep = prepare_pushforward(f!, y, backend, x, (dx,), contexts...) - return PushforwardPullbackPrep(pushforward_prep) + pushforward_prep = prepare_pushforward(strict, f!, y, backend, x, (dx,), contexts...) + return PushforwardPullbackPrep(_sig, pushforward_prep) end ## One argument @@ -202,6 +217,7 @@ function value_and_pullback( ty::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f, prep, backend, x, ty, contexts...) (; pushforward_prep) = prep y = f(x, map(unwrap, contexts)...) tx = ntuple( @@ -220,6 +236,7 @@ function value_and_pullback!( ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, ty, contexts...) y, new_tx = value_and_pullback(f, prep, backend, x, ty, contexts...) foreach(copyto!, tx, new_tx) return y, tx @@ -233,6 +250,7 @@ function pullback( ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, ty, contexts...) 
return value_and_pullback(f, prep, backend, x, ty, contexts...)[2] end @@ -245,6 +263,7 @@ function pullback!( ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, ty, contexts...) return value_and_pullback!(f, tx, prep, backend, x, ty, contexts...)[2] end @@ -325,6 +344,7 @@ function value_and_pullback( ty::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) (; pushforward_prep) = prep tx = ntuple( b -> _pullback_via_pushforward( @@ -346,6 +366,7 @@ function value_and_pullback!( ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) y, new_tx = value_and_pullback(f!, y, prep, backend, x, ty, contexts...) foreach(copyto!, tx, new_tx) return y, tx @@ -360,6 +381,7 @@ function pullback( ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) return value_and_pullback(f!, y, prep, backend, x, ty, contexts...)[2] end @@ -373,5 +395,6 @@ function pullback!( ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) 
return value_and_pullback!(f!, y, tx, prep, backend, x, ty, contexts...)[2] end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 91a967ad6..883681031 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -1,12 +1,14 @@ ## Docstrings """ - prepare_pushforward(f, backend, x, tx, [contexts...]) -> prep - prepare_pushforward(f!, y, backend, x, tx, [contexts...]) -> prep + prepare_pushforward(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep + prepare_pushforward(f!, y, backend, x, tx, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("pushforward"; inplace=true)) """ -function prepare_pushforward end +function prepare_pushforward(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_pushforward(strict, args...) +end """ prepare!_pushforward(f, prep, backend, x, tx, [contexts...]) -> new_prep @@ -17,12 +19,14 @@ $(docstring_prepare!("pushforward")) function prepare!_pushforward end """ - prepare_pushforward_same_point(f, backend, x, tx, [contexts...]) -> prep_same - prepare_pushforward_same_point(f!, y, backend, x, tx, [contexts...]) -> prep_same + prepare_pushforward_same_point(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep_same + prepare_pushforward_same_point(f!, y, backend, x, tx, [contexts...]; strict=Val(false)) -> prep_same $(docstring_prepare("pushforward"; samepoint=true, inplace=true)) """ -function prepare_pushforward_same_point end +function prepare_pushforward_same_point(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_pushforward_same_point(strict, args...) +end """ value_and_pushforward(f, [prep,] backend, x, tx, [contexts...]) -> (y, ty) @@ -85,52 +89,63 @@ function pushforward! 
end ## Preparation -struct PullbackPushforwardPrep{E} <: PushforwardPrep +struct PullbackPushforwardPrep{SIG,E} <: PushforwardPrep{SIG} + _sig::Val{SIG} pullback_prep::E end function prepare_pushforward( - f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C} + strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}; ) where {F,C} return _prepare_pushforward_aux( - pushforward_performance(backend), f, backend, x, tx, contexts... + strict, pushforward_performance(backend), f, backend, x, tx, contexts... ) end function prepare_pushforward( - f!::F, y, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C} + strict::Val, + f!::F, + y, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}; ) where {F,C} return _prepare_pushforward_aux( - pushforward_performance(backend), f!, y, backend, x, tx, contexts... + strict, pushforward_performance(backend), f!, y, backend, x, tx, contexts... ) end function _prepare_pushforward_aux( + strict::Val, ::PushforwardSlow, f::F, backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {F,C} + _sig = signature(f, backend, x, tx, contexts...; strict) y = f(x, map(unwrap, contexts)...) dy = y isa Number ? one(y) : basis(y, first(CartesianIndices(y))) - pullback_prep = prepare_pullback(f, backend, x, (dy,), contexts...) - return PullbackPushforwardPrep(pullback_prep) + pullback_prep = prepare_pullback(strict, f, backend, x, (dy,), contexts...) + return PullbackPushforwardPrep(_sig, pullback_prep) end function _prepare_pushforward_aux( + strict::Val, ::PushforwardSlow, f!::F, y, backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {F,C} + _sig = signature(f!, y, backend, x, tx, contexts...; strict) dy = y isa Number ? one(y) : basis(y, first(CartesianIndices(y))) - pullback_prep = prepare_pullback(f!, y, backend, x, (dy,), contexts...) 
- return PullbackPushforwardPrep(pullback_prep) + pullback_prep = prepare_pullback(strict, f!, y, backend, x, (dy,), contexts...) + return PullbackPushforwardPrep(_sig, pullback_prep) end ## One argument @@ -205,6 +220,7 @@ function value_and_pushforward( tx::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f, prep, backend, x, tx, contexts...) (; pullback_prep) = prep y = f(x, map(unwrap, contexts)...) ty = ntuple( @@ -223,6 +239,7 @@ function value_and_pushforward!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) y, new_ty = value_and_pushforward(f, prep, backend, x, tx, contexts...) foreach(copyto!, ty, new_ty) return y, ty @@ -236,6 +253,7 @@ function pushforward( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return value_and_pushforward(f, prep, backend, x, tx, contexts...)[2] end @@ -248,6 +266,7 @@ function pushforward!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return value_and_pushforward!(f, ty, prep, backend, x, tx, contexts...)[2] end @@ -297,6 +316,7 @@ function value_and_pushforward( tx::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) (; pullback_prep) = prep ty = ntuple( b -> @@ -317,6 +337,7 @@ function value_and_pushforward!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) y, new_ty = value_and_pushforward(f!, y, prep, backend, x, tx, contexts...) foreach(copyto!, ty, new_ty) return y, ty @@ -331,6 +352,7 @@ function pushforward( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) 
return value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)[2] end @@ -344,6 +366,7 @@ function pushforward!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) return value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...)[2] end diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index 496675503..35c21a71e 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -1,14 +1,14 @@ abstract type FromPrimitive{inplace} <: AbstractADType end -check_available(fromprim::FromPrimitive) = check_available(fromprim.backend) +check_available(backend::FromPrimitive) = check_available(backend.backend) inplace_support(::FromPrimitive{true}) = InPlaceSupported() inplace_support(::FromPrimitive{false}) = InPlaceNotSupported() -function inner_preparation_behavior(fromprim::FromPrimitive) - return inner_preparation_behavior(fromprim.backend) +function inner_preparation_behavior(backend::FromPrimitive) + return inner_preparation_behavior(backend.backend) end -function pick_batchsize(fromprim::FromPrimitive, N::Integer) - return pick_batchsize(fromprim.backend, N) +function pick_batchsize(backend::FromPrimitive, N::Integer) + return pick_batchsize(backend.backend, N) end """ @@ -29,41 +29,56 @@ end ADTypes.mode(::AutoForwardFromPrimitive) = ADTypes.ForwardMode() function threshold_batchsize( - fromprim::AutoForwardFromPrimitive{inplace}, dimension::Integer + backend::AutoForwardFromPrimitive{inplace}, dimension::Integer ) where {inplace} return AutoForwardFromPrimitive( - threshold_batchsize(fromprim.backend, dimension); inplace + threshold_batchsize(backend.backend, dimension); inplace ) end -struct FromPrimitivePushforwardPrep{E<:PushforwardPrep} <: PushforwardPrep +struct FromPrimitivePushforwardPrep{SIG,E<:PushforwardPrep} <: PushforwardPrep{SIG} + _sig::Val{SIG} 
pushforward_prep::E end function prepare_pushforward( - f::F, fromprim::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C} + strict::Val, + f::F, + backend::AutoForwardFromPrimitive, + x, + tx::NTuple, + contexts::Vararg{Context,C}; ) where {F,C} - primitive_prep = prepare_pushforward(f, fromprim.backend, x, tx, contexts...) - return FromPrimitivePushforwardPrep(primitive_prep) + _sig = signature(f, backend, x, tx, contexts...; strict) + primitive_prep = prepare_pushforward(strict, f, backend.backend, x, tx, contexts...) + return FromPrimitivePushforwardPrep(_sig, primitive_prep) end function prepare_pushforward( - f!::F, y, fromprim::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C} + strict::Val, + f!::F, + y, + backend::AutoForwardFromPrimitive, + x, + tx::NTuple, + contexts::Vararg{Context,C}; ) where {F,C} - primitive_prep = prepare_pushforward(f!, y, fromprim.backend, x, tx, contexts...) - return FromPrimitivePushforwardPrep(primitive_prep) + _sig = signature(f!, y, backend, x, tx, contexts...; strict) + primitive_prep = prepare_pushforward(strict, f!, y, backend.backend, x, tx, contexts...) + return FromPrimitivePushforwardPrep(_sig, primitive_prep) end function value_and_pushforward( f::F, prep::FromPrimitivePushforwardPrep, - fromprim::AutoForwardFromPrimitive, + backend::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return value_and_pushforward( - f, prep.pushforward_prep, fromprim.backend, x, tx, contexts... + f, prep.pushforward_prep, backend.backend, x, tx, contexts... ) end @@ -71,13 +86,14 @@ function value_and_pushforward( f!::F, y, prep::FromPrimitivePushforwardPrep, - fromprim::AutoForwardFromPrimitive, + backend::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) 
return value_and_pushforward( - f!, y, prep.pushforward_prep, fromprim.backend, x, tx, contexts... + f!, y, prep.pushforward_prep, backend.backend, x, tx, contexts... ) end @@ -85,13 +101,14 @@ function value_and_pushforward!( f::F, ty::NTuple, prep::FromPrimitivePushforwardPrep, - fromprim::AutoForwardFromPrimitive, + backend::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return value_and_pushforward!( - f, ty, prep.pushforward_prep, fromprim.backend, x, tx, contexts... + f, ty, prep.pushforward_prep, backend.backend, x, tx, contexts... ) end @@ -100,13 +117,14 @@ function value_and_pushforward!( y, ty::NTuple, prep::FromPrimitivePushforwardPrep, - fromprim::AutoForwardFromPrimitive, + backend::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) return value_and_pushforward!( - f!, y, ty, prep.pushforward_prep, fromprim.backend, x, tx, contexts... + f!, y, ty, prep.pushforward_prep, backend.backend, x, tx, contexts... 
) end @@ -128,53 +146,69 @@ end ADTypes.mode(::AutoReverseFromPrimitive) = ADTypes.ReverseMode() function threshold_batchsize( - fromprim::AutoReverseFromPrimitive{inplace}, dimension::Integer + backend::AutoReverseFromPrimitive{inplace}, dimension::Integer ) where {inplace} return AutoReverseFromPrimitive( - threshold_batchsize(fromprim.backend, dimension); inplace + threshold_batchsize(backend.backend, dimension); inplace ) end -struct FromPrimitivePullbackPrep{E<:PullbackPrep} <: PullbackPrep +struct FromPrimitivePullbackPrep{SIG,E<:PullbackPrep} <: PullbackPrep{SIG} + _sig::Val{SIG} pullback_prep::E end function prepare_pullback( - f::F, fromprim::AutoReverseFromPrimitive, x, ty::NTuple, contexts::Vararg{Context,C} + strict::Val, + f::F, + backend::AutoReverseFromPrimitive, + x, + ty::NTuple, + contexts::Vararg{Context,C}; ) where {F,C} - primitive_prep = prepare_pullback(f, fromprim.backend, x, ty, contexts...) - return FromPrimitivePullbackPrep(primitive_prep) + _sig = signature(f, backend, x, ty, contexts...; strict) + primitive_prep = prepare_pullback(strict, f, backend.backend, x, ty, contexts...) + return FromPrimitivePullbackPrep(_sig, primitive_prep) end function prepare_pullback( - f!::F, y, fromprim::AutoReverseFromPrimitive, x, ty::NTuple, contexts::Vararg{Context,C} + strict::Val, + f!::F, + y, + backend::AutoReverseFromPrimitive, + x, + ty::NTuple, + contexts::Vararg{Context,C}; ) where {F,C} - primitive_prep = prepare_pullback(f!, y, fromprim.backend, x, ty, contexts...) - return FromPrimitivePullbackPrep(primitive_prep) + _sig = signature(f!, y, backend, x, ty, contexts...; strict) + primitive_prep = prepare_pullback(strict, f!, y, backend.backend, x, ty, contexts...) 
+ return FromPrimitivePullbackPrep(_sig, primitive_prep) end function value_and_pullback( f::F, prep::FromPrimitivePullbackPrep, - fromprim::AutoReverseFromPrimitive, + backend::AutoReverseFromPrimitive, x, ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - return value_and_pullback(f, prep.pullback_prep, fromprim.backend, x, ty, contexts...) + check_prep(f, prep, backend, x, ty, contexts...) + return value_and_pullback(f, prep.pullback_prep, backend.backend, x, ty, contexts...) end function value_and_pullback( f!::F, y, prep::FromPrimitivePullbackPrep, - fromprim::AutoReverseFromPrimitive, + backend::AutoReverseFromPrimitive, x, ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) return value_and_pullback( - f!, y, prep.pullback_prep, fromprim.backend, x, ty, contexts... + f!, y, prep.pullback_prep, backend.backend, x, ty, contexts... ) end @@ -182,13 +216,14 @@ function value_and_pullback!( f::F, tx::NTuple, prep::FromPrimitivePullbackPrep, - fromprim::AutoReverseFromPrimitive, + backend::AutoReverseFromPrimitive, x, ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, ty, contexts...) return value_and_pullback!( - f, tx, prep.pullback_prep, fromprim.backend, x, ty, contexts... + f, tx, prep.pullback_prep, backend.backend, x, ty, contexts... ) end @@ -197,12 +232,13 @@ function value_and_pullback!( y, tx::NTuple, prep::FromPrimitivePullbackPrep, - fromprim::AutoReverseFromPrimitive, + backend::AutoReverseFromPrimitive, x, ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) return value_and_pullback!( - f!, y, tx, prep.pullback_prep, fromprim.backend, x, ty, contexts... + f!, y, tx, prep.pullback_prep, backend.backend, x, ty, contexts... 
) end diff --git a/DifferentiationInterface/src/misc/simple_finite_diff.jl b/DifferentiationInterface/src/misc/simple_finite_diff.jl index b36f36503..bfa6656f4 100644 --- a/DifferentiationInterface/src/misc/simple_finite_diff.jl +++ b/DifferentiationInterface/src/misc/simple_finite_diff.jl @@ -37,25 +37,39 @@ function threshold_batchsize( end function prepare_pushforward( - f::F, ::AutoSimpleFiniteDiff, x, tx::NTuple, contexts::Vararg{Context,C} + strict::Val, + f::F, + backend::AutoSimpleFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{Context,C}; ) where {F,C} - return NoPushforwardPrep() + _sig = signature(f, backend, x, tx, contexts...; strict) + return NoPushforwardPrep(_sig) end function prepare_pushforward( - f!::F, y, ::AutoSimpleFiniteDiff, x, tx::NTuple, contexts::Vararg{Context,C} + strict::Val, + f!::F, + y, + backend::AutoSimpleFiniteDiff, + x, + tx::NTuple, + contexts::Vararg{Context,C}; ) where {F,C} - return NoPushforwardPrep() + _sig = signature(f!, y, backend, x, tx, contexts...; strict) + return NoPushforwardPrep(_sig) end function value_and_pushforward( f::F, - ::NoPushforwardPrep, + prep::NoPushforwardPrep, backend::AutoSimpleFiniteDiff, x, tx::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f, prep, backend, x, tx, contexts...) ε = eltype(x)(backend.ε) y = f(x, map(unwrap, contexts)...) ty = map(tx) do dx @@ -69,12 +83,13 @@ end function value_and_pushforward( f!::F, y, - ::NoPushforwardPrep, + prep::NoPushforwardPrep, backend::AutoSimpleFiniteDiff, x, tx::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) ε = eltype(x)(backend.ε) ty = map(tx) do dx f!(y, x + ε * dx, map(unwrap, contexts)...) 
diff --git a/DifferentiationInterface/src/misc/zero_backends.jl b/DifferentiationInterface/src/misc/zero_backends.jl index a340edd8f..1b3c6c0eb 100644 --- a/DifferentiationInterface/src/misc/zero_backends.jl +++ b/DifferentiationInterface/src/misc/zero_backends.jl @@ -21,25 +21,34 @@ check_available(::AutoZeroForward) = true inplace_support(::AutoZeroForward) = InPlaceSupported() function prepare_pushforward( - f::F, ::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C} + strict::Val, f::F, backend::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C}; ) where {F,C} - return NoPushforwardPrep() + _sig = signature(f, backend, x, tx, contexts...; strict) + return NoPushforwardPrep(_sig) end function prepare_pushforward( - f!::F, y, ::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C} + strict::Val, + f!::F, + y, + backend::AutoZeroForward, + x, + tx::NTuple, + contexts::Vararg{Context,C}; ) where {F,C} - return NoPushforwardPrep() + _sig = signature(f!, y, backend, x, tx, contexts...; strict) + return NoPushforwardPrep(_sig) end function value_and_pushforward( f::F, - ::NoPushforwardPrep, - ::AutoZeroForward, + prep::NoPushforwardPrep, + backend::AutoZeroForward, x, tx::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f, prep, backend, x, tx, contexts...) y = f(x, map(unwrap, contexts)...) ty = map(ReturnZero(y), tx) return y, ty @@ -48,12 +57,13 @@ end function value_and_pushforward( f!::F, y, - ::NoPushforwardPrep, - ::AutoZeroForward, + prep::NoPushforwardPrep, + backend::AutoZeroForward, x, tx::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) f!(y, x, map(unwrap, contexts)...) 
ty = map(ReturnZero(y), tx) return y, ty @@ -62,12 +72,13 @@ end function value_and_pushforward!( f::F, ty::NTuple, - ::NoPushforwardPrep, - ::AutoZeroForward, + prep::NoPushforwardPrep, + backend::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) y = f(x, map(unwrap, contexts)...) for b in eachindex(ty) _zero!(ty[b]) @@ -79,12 +90,13 @@ function value_and_pushforward!( f!::F, y, ty::NTuple, - ::NoPushforwardPrep, - ::AutoZeroForward, + prep::NoPushforwardPrep, + backend::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, tx, contexts...) f!(y, x, map(unwrap, contexts)...) for b in eachindex(ty) _zero!(ty[b]) @@ -107,20 +119,34 @@ check_available(::AutoZeroReverse) = true inplace_support(::AutoZeroReverse) = InPlaceSupported() function prepare_pullback( - f::F, ::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C} + strict::Val, f::F, backend::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C}; ) where {F,C} - return NoPullbackPrep() + _sig = signature(f, backend, x, ty, contexts...; strict) + return NoPullbackPrep(_sig) end function prepare_pullback( - f!::F, y, ::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C} + strict::Val, + f!::F, + y, + backend::AutoZeroReverse, + x, + ty::NTuple, + contexts::Vararg{Context,C}; ) where {F,C} - return NoPullbackPrep() + _sig = signature(f!, y, backend, x, ty, contexts...; strict) + return NoPullbackPrep(_sig) end function value_and_pullback( - f::F, ::NoPullbackPrep, ::AutoZeroReverse, x, ty::NTuple{B}, contexts::Vararg{Context,C} + f::F, + prep::NoPullbackPrep, + backend::AutoZeroReverse, + x, + ty::NTuple{B}, + contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f, prep, backend, x, ty, contexts...) y = f(x, map(unwrap, contexts)...) 
tx = ntuple(ReturnZero(x), Val(B)) return y, tx @@ -129,12 +155,13 @@ end function value_and_pullback( f!::F, y, - ::NoPullbackPrep, - ::AutoZeroReverse, + prep::NoPullbackPrep, + backend::AutoZeroReverse, x, ty::NTuple{B}, contexts::Vararg{Context,C}, ) where {F,B,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) f!(y, x, map(unwrap, contexts)...) tx = ntuple(ReturnZero(x), Val(B)) return y, tx @@ -143,12 +170,13 @@ end function value_and_pullback!( f::F, tx::NTuple, - ::NoPullbackPrep, - ::AutoZeroReverse, + prep::NoPullbackPrep, + backend::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, ty, contexts...) y = f(x, map(unwrap, contexts)...) for b in eachindex(tx) _zero!(tx[b]) @@ -160,12 +188,13 @@ function value_and_pullback!( f!::F, y, tx::NTuple, - ::NoPullbackPrep, - ::AutoZeroReverse, + prep::NoPullbackPrep, + backend::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f!, y, prep, backend, x, ty, contexts...) f!(y, x, map(unwrap, contexts)...) for b in eachindex(tx) _zero!(tx[b]) diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 3dc8b7916..290adb525 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -1,11 +1,13 @@ ## Docstrings """ - prepare_hessian(f, backend, x, [contexts...]) -> prep + prepare_hessian(f, backend, x, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("hessian")) """ -function prepare_hessian end +function prepare_hessian(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_hessian(strict, args...) +end """ prepare!_hessian(f, backend, x, [contexts...]) -> new_prep @@ -53,12 +55,14 @@ function value_gradient_and_hessian! 
end ## Preparation struct HVPGradientHessianPrep{ + SIG, BS<:BatchSizeSettings, S<:AbstractVector{<:NTuple}, R<:AbstractVector{<:NTuple}, E2<:HVPPrep, E1<:GradientPrep, -} <: HessianPrep +} <: HessianPrep{SIG} + _sig::Val{SIG} batch_size_settings::BS batched_seeds::S batched_results::R @@ -67,31 +71,33 @@ struct HVPGradientHessianPrep{ end function prepare_hessian( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} # type-unstable batch_size_settings = pick_batchsize(outer(backend), x) # function barrier - return _prepare_hessian_aux(batch_size_settings, f, backend, x, contexts...) + return _prepare_hessian_aux(strict, batch_size_settings, f, backend, x, contexts...) end function _prepare_hessian_aux( + strict::Val, batch_size_settings::BatchSizeSettings{B}, f::F, backend::AbstractADType, x, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {B,F,C} + _sig = signature(f, backend, x, contexts...; strict) (; N, A) = batch_size_settings seeds = [basis(x, ind) for ind in eachindex(x)] batched_seeds = [ ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] - hvp_prep = prepare_hvp(f, backend, x, batched_seeds[1], contexts...) - gradient_prep = prepare_gradient(f, inner(backend), x, contexts...) + hvp_prep = prepare_hvp(strict, f, backend, x, batched_seeds[1], contexts...) + gradient_prep = prepare_gradient(strict, f, inner(backend), x, contexts...) 
return HVPGradientHessianPrep( - batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep + _sig, batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep ) end @@ -99,11 +105,12 @@ end function hessian( f::F, - prep::HVPGradientHessianPrep{<:BatchSizeSettings{B,true}}, + prep::HVPGradientHessianPrep{SIG,<:BatchSizeSettings{B,true}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {F,B,C} +) where {F,SIG,B,C} + check_prep(f, prep, backend, x, contexts...) (; batched_seeds, hvp_prep) = prep dg_batch = hvp(f, hvp_prep, backend, x, only(batched_seeds), contexts...) block = stack_vec_col(dg_batch) @@ -112,11 +119,12 @@ end function hessian( f::F, - prep::HVPGradientHessianPrep{<:BatchSizeSettings{B,false,aligned}}, + prep::HVPGradientHessianPrep{SIG,<:BatchSizeSettings{B,false,aligned}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {F,B,aligned,C} +) where {F,SIG,B,aligned,C} + check_prep(f, prep, backend, x, contexts...) (; batch_size_settings, batched_seeds, hvp_prep) = prep (; A, B_last) = batch_size_settings @@ -139,11 +147,12 @@ end function hessian!( f::F, hess, - prep::HVPGradientHessianPrep{<:BatchSizeSettings{B}}, + prep::HVPGradientHessianPrep{SIG,<:BatchSizeSettings{B}}, backend::AbstractADType, x, contexts::Vararg{Context,C}, -) where {F,B,C} +) where {F,SIG,B,C} + check_prep(f, prep, backend, x, contexts...) (; batch_size_settings, batched_seeds, batched_results, hvp_prep) = prep (; N) = batch_size_settings @@ -173,6 +182,7 @@ function value_gradient_and_hessian( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) y, grad = value_and_gradient(f, prep.gradient_prep, inner(backend), x, contexts...) hess = hessian(f, prep, backend, x, contexts...) return y, grad, hess @@ -187,6 +197,7 @@ function value_gradient_and_hessian!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) 
y, _ = value_and_gradient!(f, grad, prep.gradient_prep, inner(backend), x, contexts...) hessian!(f, hess, prep, backend, x, contexts...) return y, grad, hess diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index edc1699f8..8caeb2b64 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -1,11 +1,13 @@ ## Docstrings """ - prepare_hvp(f, backend, x, tx, [contexts...]) -> prep + prepare_hvp(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("hvp")) """ -function prepare_hvp end +function prepare_hvp(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_hvp(strict, args...) +end """ prepare!_hvp(f, backend, x, tx, [contexts...]) -> new_prep @@ -15,11 +17,13 @@ $(docstring_prepare("hvp")) function prepare!_hvp end """ - prepare_hvp_same_point(f, backend, x, tx, [contexts...]) -> prep_same + prepare_hvp_same_point(f, backend, x, tx, [contexts...]; strict=Val(false)) -> prep_same $(docstring_prepare("hvp"; samepoint=true)) """ -function prepare_hvp_same_point end +function prepare_hvp_same_point(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_hvp_same_point(strict, args...) +end """ hvp(f, [prep,] backend, x, tx, [contexts...]) -> tg @@ -58,23 +62,25 @@ $(docstring_preparation_hint("hvp"; same_point=true)) function gradient_and_hvp! 
end function prepare_hvp( - f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C} + strict::Val, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}; ) where {F,C} return _prepare_hvp_aux( + strict, hvp_mode(backend), inner_preparation_behavior(outer(backend)), f, backend, x, tx, - contexts..., + contexts...; ) end ## Forward over anything -struct ForwardOverAnythingHVPPrep{G,GO,GI,PO,PI} <: HVPPrep +struct ForwardOverAnythingHVPPrep{SIG,G,GO,GI,PO,PI} <: HVPPrep{SIG} # pushforward of many pushforwards in theory, but pushforward of gradient in practice + _sig::Val{SIG} grad_buffer::G maybe_inner_gradient_prep::GO maybe_inner_gradient_in_prep::GI @@ -83,14 +89,16 @@ struct ForwardOverAnythingHVPPrep{G,GO,GI,PO,PI} <: HVPPrep end function _prepare_hvp_aux( + strict::Val, ::ForwardOverAnything, ::DontPrepareInner, f::F, backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {F,C} + _sig = signature(f, backend, x, tx, contexts...; strict) grad_buffer = similar(x) rewrap = Rewrap(contexts...) # Outer pushforward @@ -98,33 +106,41 @@ function _prepare_hvp_aux( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts... + strict, shuffled_gradient, outer(backend), x, tx, new_contexts... ) outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported prepare_pushforward( - shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... 
+ strict, + shuffled_gradient!, + grad_buffer, + outer(backend), + x, + tx, + new_contexts...; ) else nothing end return ForwardOverAnythingHVPPrep( - grad_buffer, (), (), outer_pushforward_prep, outer_pushforward_in_prep + _sig, grad_buffer, (), (), outer_pushforward_prep, outer_pushforward_in_prep ) end function _prepare_hvp_aux( + strict::Val, ::ForwardOverAnything, ::PrepareInnerSimple, f::F, backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {F,C} + _sig = signature(f, backend, x, tx, contexts...; strict) grad_buffer = similar(x) rewrap = Rewrap(contexts...) # Inner gradient - inner_gradient_prep = prepare_gradient(f, inner(backend), x, contexts...) + inner_gradient_prep = prepare_gradient(strict, f, inner(backend), x, contexts...) inner_gradient_in_prep = inner_gradient_prep # Outer pushforward new_contexts = ( @@ -142,16 +158,23 @@ function _prepare_hvp_aux( contexts..., ) outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts... + strict, shuffled_gradient, outer(backend), x, tx, new_contexts... ) outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported prepare_pushforward( - shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts_in... + strict, + shuffled_gradient!, + grad_buffer, + outer(backend), + x, + tx, + new_contexts_in...; ) else nothing end return ForwardOverAnythingHVPPrep( + _sig, grad_buffer, (inner_gradient_prep,), (inner_gradient_in_prep,), @@ -161,14 +184,16 @@ function _prepare_hvp_aux( end function _prepare_hvp_aux( + strict::Val, ::ForwardOverAnything, ::PrepareInnerOverload, f::F, backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {F,C} + _sig = signature(f, backend, x, tx, contexts...; strict) grad_buffer = similar(x) rewrap = Rewrap(contexts...) 
# Inner gradient @@ -176,8 +201,10 @@ function _prepare_hvp_aux( xoi = overloaded_input( pushforward, shuffled_gradient!, grad_buffer, outer(backend), x, tx ) - inner_gradient_prep = prepare_gradient(f, inner(backend), xo, contexts...) - inner_gradient_in_prep = prepare_gradient(f, inner(backend), xoi, contexts...) + contextso = adapt_eltype.(contexts, Ref(eltype(xo))) + contextsoi = adapt_eltype.(contexts, Ref(eltype(xoi))) + inner_gradient_prep = prepare_gradient(strict, f, inner(backend), xo, contextso...) + inner_gradient_in_prep = prepare_gradient(strict, f, inner(backend), xoi, contextsoi...) # Outer pushforward new_contexts = ( FunctionContext(f), @@ -194,16 +221,23 @@ function _prepare_hvp_aux( contexts..., ) outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts... + strict, shuffled_gradient, outer(backend), x, tx, new_contexts... ) outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported prepare_pushforward( - shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts_in... + strict, + shuffled_gradient!, + grad_buffer, + outer(backend), + x, + tx, + new_contexts_in...; ) else nothing end return ForwardOverAnythingHVPPrep( + _sig, grad_buffer, (inner_gradient_prep,), (inner_gradient_in_prep,), @@ -220,6 +254,7 @@ function hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) (; maybe_inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -243,6 +278,7 @@ function hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return _hvp_aux!( inplace_support(outer(backend)), f, tg, prep, backend, x, tx, contexts... ) @@ -317,6 +353,7 @@ function gradient_and_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) 
(; maybe_inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -341,6 +378,7 @@ function gradient_and_hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return _gradient_and_hvp_aux!( inplace_support(outer(backend)), f, grad, tg, prep, backend, x, tx, contexts... ) @@ -413,21 +451,24 @@ end ## Reverse over forward -struct ReverseOverForwardHVPPrep{G2<:GradientPrep,G1<:GradientPrep} <: HVPPrep +struct ReverseOverForwardHVPPrep{SIG,G2<:GradientPrep,G1<:GradientPrep} <: HVPPrep{SIG} # gradient of pushforward + _sig::Val{SIG} outer_gradient_prep::G2 gradient_prep::G1 end function _prepare_hvp_aux( + strict::Val, ::ReverseOverForward, ::InnerPreparationBehavior, f::F, backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {F,C} + _sig = signature(f, backend, x, tx, contexts...; strict) rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), @@ -437,10 +478,10 @@ function _prepare_hvp_aux( contexts..., ) outer_gradient_prep = prepare_gradient( - shuffled_single_pushforward, outer(backend), x, new_contexts... + strict, shuffled_single_pushforward, outer(backend), x, new_contexts... ) - gradient_prep = prepare_gradient(f, inner(backend), x, contexts...) - return ReverseOverForwardHVPPrep(outer_gradient_prep, gradient_prep) + gradient_prep = prepare_gradient(strict, f, inner(backend), x, contexts...) + return ReverseOverForwardHVPPrep(_sig, outer_gradient_prep, gradient_prep) end function hvp( @@ -449,8 +490,9 @@ function hvp( backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) (; outer_gradient_prep) = prep rewrap = Rewrap(contexts...) 
tg = map(tx) do dx @@ -478,6 +520,7 @@ function hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) (; outer_gradient_prep) = prep rewrap = Rewrap(contexts...) for b in eachindex(tx, tg) @@ -505,6 +548,7 @@ function gradient_and_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) tg = hvp(f, prep, backend, x, tx, contexts...) grad = gradient(f, prep.gradient_prep, inner(backend), x, contexts...) return grad, tg @@ -520,6 +564,7 @@ function gradient_and_hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) hvp!(f, tg, prep, backend, x, tx, contexts...) gradient!(f, grad, prep.gradient_prep, inner(backend), x, contexts...) return grad, tg @@ -527,39 +572,48 @@ end ## Reverse over reverse -struct ReverseOverReverseHVPPrep{G,PO,PI} <: HVPPrep +struct ReverseOverReverseHVPPrep{SIG,G,PO,PI} <: HVPPrep{SIG} # pullback of gradient + _sig::Val{SIG} grad_buffer::G outer_pullback_prep::PO outer_pullback_in_prep::PI end function _prepare_hvp_aux( + strict::Val, ::ReverseOverReverse, ::InnerPreparationBehavior, f::F, backend::AbstractADType, x, tx::NTuple, - contexts::Vararg{Context,C}, + contexts::Vararg{Context,C}; ) where {F,C} + _sig = signature(f, backend, x, tx, contexts...; strict) rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) grad_buffer = similar(x) outer_pullback_prep = prepare_pullback( - shuffled_gradient, outer(backend), x, tx, new_contexts... + strict, shuffled_gradient, outer(backend), x, tx, new_contexts... ) outer_pullback_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported prepare_pullback( - shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... 
+ strict, + shuffled_gradient!, + grad_buffer, + outer(backend), + x, + tx, + new_contexts...; ) else nothing end return ReverseOverReverseHVPPrep( - grad_buffer, outer_pullback_prep, outer_pullback_in_prep + _sig, grad_buffer, outer_pullback_prep, outer_pullback_in_prep ) end @@ -571,6 +625,7 @@ function hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -590,6 +645,7 @@ function hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return _hvp_aux!( inplace_support(outer(backend)), f, tg, prep, backend, x, tx, contexts... ) @@ -650,6 +706,7 @@ function gradient_and_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -670,6 +727,7 @@ function gradient_and_hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, tx, contexts...) return _gradient_and_hvp_aux!( inplace_support(outer(backend)), f, grad, tg, prep, backend, x, tx, contexts... ) diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index 52f07d74b..f26f335e6 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -1,11 +1,13 @@ ## Docstrings """ - prepare_second_derivative(f, backend, x, [contexts...]) -> prep + prepare_second_derivative(f, backend, x, [contexts...]; strict=Val(false)) -> prep $(docstring_prepare("second_derivative")) """ -function prepare_second_derivative end +function prepare_second_derivative(args::Vararg{Any,N}; strict=Val(false)) where {N} + return prepare_second_derivative(strict, args...) 
+end """ prepare!_second_derivative(f, prep, backend, x, [contexts...]) -> new_prep @@ -52,21 +54,23 @@ function value_derivative_and_second_derivative! end ## Preparation -struct DerivativeSecondDerivativePrep{E<:DerivativePrep} <: SecondDerivativePrep +struct DerivativeSecondDerivativePrep{SIG,E<:DerivativePrep} <: SecondDerivativePrep{SIG} + _sig::Val{SIG} outer_derivative_prep::E end function prepare_second_derivative( - f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + strict::Val, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} + _sig = signature(f, backend, x, contexts...; strict) rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) outer_derivative_prep = prepare_derivative( - shuffled_derivative, outer(backend), x, new_contexts... + strict, shuffled_derivative, outer(backend), x, new_contexts... ) - return DerivativeSecondDerivativePrep(outer_derivative_prep) + return DerivativeSecondDerivativePrep(_sig, outer_derivative_prep) end ## One argument @@ -78,6 +82,7 @@ function second_derivative( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -95,6 +100,7 @@ function value_derivative_and_second_derivative( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -115,6 +121,7 @@ function second_derivative!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( @@ -134,6 +141,7 @@ function value_derivative_and_second_derivative!( x, contexts::Vararg{Context,C}, ) where {F,C} + check_prep(f, prep, backend, x, contexts...) (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) 
new_contexts = ( diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 6d85ec0b9..65edbde0f 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -144,3 +144,6 @@ function with_contexts(f::F, contexts::Vararg{Context,N}) where {F,N} tail_args = map(unwrap, contexts) return FixTail(f, tail_args...) end + +adapt_eltype(c::Constant, ::Type) = c +adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(similar(unwrap(c), T)) diff --git a/DifferentiationInterface/src/utils/prep.jl b/DifferentiationInterface/src/utils/prep.jl index 0fb7af0df..cb3bb6ff6 100644 --- a/DifferentiationInterface/src/utils/prep.jl +++ b/DifferentiationInterface/src/utils/prep.jl @@ -1,49 +1,215 @@ -abstract type Prep end +abstract type Prep{SIG} end """ $(docstring_preptype("PushforwardPrep", "pushforward")) """ -abstract type PushforwardPrep <: Prep end -struct NoPushforwardPrep <: PushforwardPrep end +abstract type PushforwardPrep{SIG} <: Prep{SIG} end + +struct NoPushforwardPrep{SIG} <: PushforwardPrep{SIG} + _sig::Val{SIG} +end """ $(docstring_preptype("PullbackPrep", "pullback")) """ -abstract type PullbackPrep <: Prep end -struct NoPullbackPrep <: PullbackPrep end +abstract type PullbackPrep{SIG} <: Prep{SIG} end + +struct NoPullbackPrep{SIG} <: PullbackPrep{SIG} + _sig::Val{SIG} +end """ $(docstring_preptype("DerivativePrep", "derivative")) """ -abstract type DerivativePrep <: Prep end -struct NoDerivativePrep <: DerivativePrep end +abstract type DerivativePrep{SIG} <: Prep{SIG} end + +struct NoDerivativePrep{SIG} <: DerivativePrep{SIG} + _sig::Val{SIG} +end """ $(docstring_preptype("GradientPrep", "gradient")) """ -abstract type GradientPrep <: Prep end -struct NoGradientPrep <: GradientPrep end +abstract type GradientPrep{SIG} <: Prep{SIG} end + +struct NoGradientPrep{SIG} <: GradientPrep{SIG} + _sig::Val{SIG} +end """ $(docstring_preptype("JacobianPrep", "jacobian")) """ 
-abstract type JacobianPrep <: Prep end -struct NoJacobianPrep <: JacobianPrep end +abstract type JacobianPrep{SIG} <: Prep{SIG} end + +struct NoJacobianPrep{SIG} <: JacobianPrep{SIG} + _sig::Val{SIG} +end """ $(docstring_preptype("HVPPrep", "hvp")) """ -abstract type HVPPrep <: Prep end -struct NoHVPPrep <: HVPPrep end +abstract type HVPPrep{SIG} <: Prep{SIG} end + +struct NoHVPPrep{SIG} <: HVPPrep{SIG} + _sig::Val{SIG} +end """ $(docstring_preptype("HessianPrep", "hessian")) """ -abstract type HessianPrep <: Prep end -struct NoHessianPrep <: HessianPrep end +abstract type HessianPrep{SIG} <: Prep{SIG} end + +struct NoHessianPrep{SIG} <: HessianPrep{SIG} + _sig::Val{SIG} +end """ $(docstring_preptype("SecondDerivativePrep", "second_derivative")) """ -abstract type SecondDerivativePrep <: Prep end -struct NoSecondDerivativePrep <: SecondDerivativePrep end +abstract type SecondDerivativePrep{SIG} <: Prep{SIG} end + +struct NoSecondDerivativePrep{SIG} <: SecondDerivativePrep{SIG} + _sig::Val{SIG} +end + +## Checks + +is_strict(::Prep{Nothing}) = Val(false) +is_strict(::Prep) = Val(true) + +struct PreparationMismatchError{SIG,EXEC_SIG} <: Exception + format::Vector{Symbol} +end + +function PreparationMismatchError( + ::Type{SIG}, ::Type{EXEC_SIG}; format +) where {SIG,EXEC_SIG} + return PreparationMismatchError{SIG,EXEC_SIG}(format) +end + +function Base.showerror( + io::IO, e::PreparationMismatchError{SIG,EXEC_SIG} +) where {SIG<:Tuple,EXEC_SIG<:Tuple} + println( + io, + "PreparationMismatchError (inconsistent types between preparation and execution):", + ) + for (s, pt, et) in zip(e.format, SIG.types, EXEC_SIG.types) + if pt == et + println(io, " - $s: ✅") + else + println(io, " - $s: ❌\n - prep: $pt\n - exec: $et") + end + end + println( + io, + "If you are confident that this check is superfluous, you can disable it by running preparation with the keyword argument `strict=Val(false)` inside DifferentiationInterface.", + ) + return nothing +end + +function 
signature( + f, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val{S} +) where {C,S} + if S + return Val(typeof((f, backend, x, contexts))) + else + return Val(Nothing) + end +end + +function signature( + f!, y, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Val{S} +) where {C,S} + if S + return Val(typeof((f!, y, backend, x, contexts))) + else + return Val(Nothing) + end +end + +function signature( + f, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C}; strict::Val{S} +) where {C,S} + if S + return Val(typeof((f, backend, x, t, contexts))) + else + return Val(Nothing) + end +end + +function signature( + f!, + y, + backend::AbstractADType, + x, + t::NTuple, + contexts::Vararg{Context,C}; + strict::Val{S}, +) where {C,S} + if S + return Val(typeof((f!, y, backend, x, t, contexts))) + else + return Val(Nothing) + end +end + +function check_prep( + f, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context,C} +) where {SIG,C} + if SIG !== Nothing + EXEC_SIG = typeof((f, backend, x, contexts)) + if SIG != EXEC_SIG + throw( + PreparationMismatchError( + SIG, EXEC_SIG; format=[:f, :backend, :x, :contexts] + ), + ) + end + end +end + +function check_prep( + f!, y, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context,C} +) where {SIG,C} + if SIG !== Nothing + EXEC_SIG = typeof((f!, y, backend, x, contexts)) + if SIG != EXEC_SIG + throw( + PreparationMismatchError( + SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :contexts] + ), + ) + end + end +end + +function check_prep( + f, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C} +) where {SIG,C} + if SIG !== Nothing + EXEC_SIG = typeof((f, backend, x, t, contexts)) + if SIG != EXEC_SIG + throw( + PreparationMismatchError( + SIG, EXEC_SIG; format=[:f, :backend, :x, :tang, :contexts] + ), + ) + end + end +end + +function check_prep( + f!, y, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C} +) 
where {SIG,C} + if SIG !== Nothing + EXEC_SIG = typeof((f!, y, backend, x, t, contexts)) + if SIG != EXEC_SIG + throw( + PreparationMismatchError( + SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :tang, :contexts] + ), + ) + end + end +end diff --git a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl index 71ea785db..4f38af5b1 100644 --- a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl @@ -30,13 +30,6 @@ test_differentiation( backends, default_scenarios(; include_constantified=true, include_cachified=true); logging=LOGGING, - excluded=SECOND_ORDER, -); - -test_differentiation( - SecondOrder(AutoPolyesterForwardDiff(), AutoPolyesterForwardDiff()), - default_scenarios(); - logging=LOGGING, ); @testset "Batch size" begin diff --git a/DifferentiationInterface/test/Core/Internals/signature.jl b/DifferentiationInterface/test/Core/Internals/signature.jl new file mode 100644 index 000000000..2d28ce7dc --- /dev/null +++ b/DifferentiationInterface/test/Core/Internals/signature.jl @@ -0,0 +1,125 @@ +using DifferentiationInterface +using DifferentiationInterface: AutoZeroReverse, AutoZeroForward +using Test + +backend = AutoZeroForward() +other_backend = AutoZeroReverse() +f(x, c) = x + c +f!(y, x) = y .= x +x = 1.0 +y = zeros(2) +c = 2.0 + +@testset "Out of place, no tangents" begin + prep = prepare_derivative(f, backend, x, Constant(c); strict=Val(true)) + prep_chill = prepare_derivative(f, backend, x, Constant(c); strict=Val(false)) + + @test_throws MethodError derivative(nothing, prep_chill, backend, x, Constant(c)) + + @test_throws """ + PreparationMismatchError (inconsistent types between preparation and execution): + - f: ❌ + - prep: typeof(f) + - exec: Nothing + - backend: ✅ + - x: ✅ + - contexts: ✅ + """ derivative(nothing, prep, backend, x, Constant(c)) + + @test_throws """ + PreparationMismatchError 
(inconsistent types between preparation and execution): + - f: ✅ + - backend: ❌ + - prep: AutoZeroForward + - exec: AutoZeroReverse + - x: ✅ + - contexts: ✅ + """ derivative(f, prep, other_backend, x, Constant(c)) + + @test_throws """ + PreparationMismatchError (inconsistent types between preparation and execution): + - f: ✅ + - backend: ✅ + - x: ❌ + - prep: Float64 + - exec: Int64 + - contexts: ✅ + """ derivative(f, prep, backend, 1, Constant(c)) + + @test_throws """ + PreparationMismatchError (inconsistent types between preparation and execution): + - f: ✅ + - backend: ✅ + - x: ✅ + - contexts: ❌ + - prep: Tuple{Constant{Float64}} + - exec: Tuple{Constant{Int64}} + """ derivative(f, prep, backend, x, Constant(2)) + + @test_throws """ + PreparationMismatchError (inconsistent types between preparation and execution): + - f: ✅ + - backend: ✅ + - x: ✅ + - contexts: ❌ + - prep: Tuple{Constant{Float64}} + - exec: Tuple{Constant{Int64}, Constant{Int64}} + """ derivative(f, prep, backend, x, Constant(2), Constant(3)) +end + +@testset "In place, no tangents" begin + prep = prepare_derivative(f!, y, backend, x; strict=Val(true)) + prep_chill = prepare_derivative(f!, y, backend, x; strict=Val(false)) + + @test_throws MethodError derivative(nothing, y, prep_chill, backend, x, Constant(c)) + + @test_throws """ + PreparationMismatchError (inconsistent types between preparation and execution): + - f!: ❌ + - prep: typeof(f!) 
+ - exec: Nothing + - y: ✅ + - backend: ✅ + - x: ✅ + - contexts: ✅ + """ derivative(nothing, y, prep, backend, x) +end + +@testset "Out of place, with tangents" begin + prep = prepare_pushforward(f, backend, x, (x,), Constant(c); strict=Val(true)) + prep_chill = prepare_pushforward(f, backend, x, (x,), Constant(c); strict=Val(false)) + + @test_throws MethodError pushforward(nothing, prep_chill, backend, x, (x,)) + + @test_throws """ + PreparationMismatchError (inconsistent types between preparation and execution): + - f: ❌ + - prep: typeof(f) + - exec: Nothing + - backend: ✅ + - x: ✅ + - tang: ✅ + - contexts: ✅ + """ pushforward(nothing, prep, backend, x, (x,), Constant(c)) +end + +@testset "In place, with tangents" begin + prep = prepare_pushforward(f!, y, backend, x, (x,); strict=Val(true)) + prep_chill = prepare_pushforward( + f!, y, backend, x, (x,), Constant(c); strict=Val(false) + ) + + @test_throws MethodError pushforward(nothing, y, prep_chill, backend, x, (x,)) + + @test_throws """ + PreparationMismatchError (inconsistent types between preparation and execution): + - f!: ❌ + - prep: typeof(f!) 
+ - exec: Nothing + - y: ✅ + - backend: ✅ + - x: ✅ + - tang: ✅ + - contexts: ✅ + """ pushforward(nothing, y, prep, backend, x, (x,)) +end diff --git a/DifferentiationInterface/test/Core/ZeroBackends/test.jl b/DifferentiationInterface/test/Core/ZeroBackends/test.jl index 0571a73b3..6d56dabd4 100644 --- a/DifferentiationInterface/test/Core/ZeroBackends/test.jl +++ b/DifferentiationInterface/test/Core/ZeroBackends/test.jl @@ -17,24 +17,10 @@ for backend in zero_backends end @testset "Type stability" begin - test_differentiation( - AutoZeroForward(), - default_scenarios(; include_batchified=false, include_constantified=true); - correctness=false, - type_stability=:full, - logging=LOGGING, - ) - - test_differentiation( - AutoZeroReverse(), - default_scenarios(; include_batchified=false, include_constantified=true); - correctness=false, - type_stability=:full, - logging=LOGGING, - ) - test_differentiation( [ + AutoZeroForward(), + AutoZeroReverse(), SecondOrder(AutoZeroForward(), AutoZeroReverse()), SecondOrder(AutoZeroReverse(), AutoZeroForward()), ], diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 8c9c6fb84..729087caf 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterfaceTest" uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.9.5" +version = "0.9.6" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index 2c92546ed..a40645fbf 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -90,6 +90,7 @@ using DifferentiationInterface: pushforward_performance, pullback_performance using DifferentiationInterface: Rewrap, Context, 
Constant, Cache, unwrap +using DifferentiationInterface: PreparationMismatchError using DocStringExtensions: TYPEDFIELDS, TYPEDSIGNATURES using JET: @test_opt using LinearAlgebra: Adjoint, Diagonal, Transpose, dot, parent @@ -97,7 +98,7 @@ using ProgressMeter: ProgressUnknown, next! using Random: AbstractRNG, default_rng, rand! using SparseArrays: SparseArrays, AbstractSparseMatrix, SparseMatrixCSC, nnz, sparse, spdiagm -using Test: @testset, @test +using Test: @testset, @test, @test_throws """ FIRST_ORDER = [:pushforward, :pullback, :derivative, :gradient, :jacobian] diff --git a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl index 9500cf7aa..9ed8013f8 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl @@ -1,3 +1,5 @@ +const PME = PreparationMismatchError + for op in ALL_OPS op! = Symbol(op, "!") val_prefix = if op == :second_derivative @@ -52,6 +54,7 @@ for op in ALL_OPS xrand = myrandom(x) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -60,14 +63,20 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, contextsrand...) 
- prepprep = $prep_op!( + prepstrict = $prep_op!( f, - $prep_op(new_smaller.f, ba, new_smaller.x, new_smaller.contexts...), + $prep_op( + new_smaller.f, + ba, + new_smaller.x, + new_smaller.contexts...; + strict=Val(true), + ), ba, xrand, contextsrand..., ) - [(), (prep,), (prepprep,)] + [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_out1_val, res1_out1_val = $val_and_op( @@ -94,6 +103,8 @@ for op in ALL_OPS @test mynnz(res1_out2_noval) == mynnz(scen.res1) end end + @test_throws PME $val_and_op(nothing, prepstrict, ba, x, contexts...) + @test_throws PME $op(nothing, prepstrict, ba, x, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -111,6 +122,7 @@ for op in ALL_OPS xrand = myrandom(x) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -119,14 +131,20 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, - $prep_op(new_smaller.f, ba, new_smaller.x, new_smaller.contexts...), + $prep_op( + new_smaller.f, + ba, + new_smaller.x, + new_smaller.contexts...; + strict=Val(true), + ), ba, xrand, contextsrand..., ) - [(), (prep,), (prepprep,)] + [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res1_in1_val = mysimilar(res1) @@ -165,6 +183,10 @@ for op in ALL_OPS @test mynnz(res1_out2_noval) == mynnz(scen.res1) end end + @test_throws PME $val_and_op!( + nothing, mysimilar(res1), prepstrict, ba, x, contexts... + ) + @test_throws PME $op!(nothing, mysimilar(res1), prepstrict, ba, x, contexts...) 
scenario_intact && @test new_scen == scen return nothing end @@ -184,6 +206,7 @@ for op in ALL_OPS xrand, yrand = myrandom(x), myrandom(y) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -192,7 +215,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, copy(yrand), ba, xrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, copy(yrand), $prep_op( @@ -200,13 +223,14 @@ for op in ALL_OPS copy(new_smaller.y), ba, new_smaller.x, - new_smaller.contexts..., + new_smaller.contexts...; + strict=Val(true), ), ba, xrand, contextsrand..., ) - [(), (prep,), (prepprep,)] + [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val = mysimilar(y) @@ -239,6 +263,10 @@ for op in ALL_OPS @test mynnz(res1_out2_noval) == mynnz(scen.res1) end end + @test_throws PME $val_and_op( + nothing, mysimilar(y), prepstrict, ba, x, contexts... + ) + @test_throws PME $op(nothing, mysimilar(y), prepstrict, ba, x, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -256,6 +284,7 @@ for op in ALL_OPS xrand, yrand = myrandom(x), myrandom(y) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -264,7 +293,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, copy(yrand), ba, xrand, contextsrand...) 
- prepprep = $prep_op!( + prepstrict = $prep_op!( f, copy(yrand), $prep_op( @@ -272,13 +301,14 @@ for op in ALL_OPS copy(new_smaller.y), ba, new_smaller.x, - new_smaller.contexts..., + new_smaller.contexts...; + strict=Val(true), ), ba, xrand, contextsrand..., ) - [(), (prep,), (prepprep,)] + [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val, res1_in1_val = mysimilar(y), mysimilar(res1) @@ -319,6 +349,12 @@ for op in ALL_OPS @test mynnz(res1_out2_noval) == mynnz(scen.res1) end end + @test_throws PME $val_and_op!( + nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, contexts... + ) + @test_throws PME $op!( + nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, contexts... + ) scenario_intact && @test new_scen == scen return nothing end @@ -337,6 +373,7 @@ for op in ALL_OPS xrand = myrandom(x) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -345,14 +382,20 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, - $prep_op(new_smaller.f, ba, new_smaller.x, new_smaller.contexts...), + $prep_op( + new_smaller.f, + ba, + new_smaller.x, + new_smaller.contexts...; + strict=Val(true), + ), ba, xrand, contextsrand..., ) - [(), (prep,), (prepprep,)] + [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_out1_val, res1_out1_val, res2_out1_val = $val_and_op( @@ -381,6 +424,8 @@ for op in ALL_OPS @test mynnz(res2_out2_noval) == mynnz(scen.res2) end end + @test_throws PME $val_and_op(nothing, prepstrict, ba, x, contexts...) + @test_throws PME $op(nothing, prepstrict, ba, x, contexts...) 
scenario_intact && @test new_scen == scen return nothing end @@ -398,6 +443,7 @@ for op in ALL_OPS xrand = myrandom(x) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -406,14 +452,20 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, - $prep_op(new_smaller.f, ba, new_smaller.x, new_smaller.contexts...), + $prep_op( + new_smaller.f, + ba, + new_smaller.x, + new_smaller.contexts...; + strict=Val(true), + ), ba, xrand, contextsrand..., ) - [(), (prep,), (prepprep,)] + [(), (prep,), (prepstrict,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res1_in1_val, res2_in1_val = mysimilar(res1), mysimilar(res2) @@ -456,6 +508,10 @@ for op in ALL_OPS @test mynnz(res2_out2_noval) == mynnz(scen.res2) end end + @test_throws PME $val_and_op!( + nothing, mysimilar(res1), mysimilar(res2), prepstrict, ba, x, contexts... + ) + @test_throws PME $op!(nothing, mysimilar(res2), prepstrict, ba, x, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -474,6 +530,7 @@ for op in ALL_OPS xrand, tangrand = myrandom(x), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -482,14 +539,15 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) 
- prepprep = $prep_op!( + prepstrict = $prep_op!( f, $prep_op( new_smaller.f, ba, new_smaller.x, new_smaller.tang, - new_smaller.contexts..., + new_smaller.contexts...; + strict=Val(true), ), ba, xrand, @@ -497,7 +555,7 @@ for op in ALL_OPS contextsrand..., ) prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) - [(), (prep,), (prepprep,), (prep_same,)] + [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_out1_val, res1_out1_val = $val_and_op( @@ -520,6 +578,8 @@ for op in ALL_OPS end end end + @test_throws PME $val_and_op(nothing, prepstrict, ba, x, tang, contexts...) + @test_throws PME $op(nothing, prepstrict, ba, x, tang, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -537,6 +597,7 @@ for op in ALL_OPS xrand, tangrand = myrandom(x), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -545,14 +606,15 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, $prep_op( new_smaller.f, ba, new_smaller.x, new_smaller.tang, - new_smaller.contexts..., + new_smaller.contexts...; + strict=Val(true), ), ba, xrand, @@ -560,7 +622,7 @@ for op in ALL_OPS contextsrand..., ) prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) - [(), (prep,), (prepprep,), (prep_same,)] + [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res1_in1_val = mysimilar(res1) @@ -595,6 +657,12 @@ for op in ALL_OPS end end end + @test_throws PME $val_and_op!( + nothing, mysimilar(res1), prepstrict, ba, x, tang, contexts... + ) + @test_throws PME $op!( + nothing, mysimilar(res1), prepstrict, ba, x, tang, contexts... 
+ ) scenario_intact && @test new_scen == scen return nothing end @@ -612,6 +680,7 @@ for op in ALL_OPS xrand, yrand, tangrand = myrandom(x), myrandom(y), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -620,7 +689,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, copy(yrand), ba, xrand, tangrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, copy(yrand), $prep_op( @@ -629,7 +698,8 @@ for op in ALL_OPS ba, new_smaller.x, new_smaller.tang, - new_smaller.contexts..., + new_smaller.contexts...; + strict=Val(true), ), ba, xrand, @@ -637,7 +707,7 @@ for op in ALL_OPS contextsrand..., ) prep_same = $prep_op_same(f, copy(yrand), ba, x, tangrand, contexts...) - [(), (prep,), (prepprep,), (prep_same,)] + [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val = mysimilar(y) @@ -670,6 +740,12 @@ for op in ALL_OPS end end end + @test_throws PME $val_and_op( + nothing, mysimilar(y), prepstrict, ba, x, tang, contexts... + ) + @test_throws PME $op( + nothing, mysimilar(y), prepstrict, ba, x, tang, contexts... + ) scenario_intact && @test new_scen == scen return nothing end @@ -687,6 +763,7 @@ for op in ALL_OPS xrand, yrand, tangrand = myrandom(x), myrandom(y), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -695,7 +772,7 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, copy(yrand), ba, xrand, tangrand, contextsrand...) 
- prepprep = $prep_op!( + prepstrict = $prep_op!( f, copy(yrand), $prep_op( @@ -704,7 +781,8 @@ for op in ALL_OPS ba, new_smaller.x, new_smaller.tang, - new_smaller.contexts..., + new_smaller.contexts...; + strict=Val(true), ), ba, xrand, @@ -712,7 +790,7 @@ for op in ALL_OPS contextsrand..., ) prep_same = $prep_op_same(f, copy(yrand), ba, x, tangrand, contexts...) - [(), (prep,), (prepprep,), (prep_same,)] + [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) y_in1_val, res1_in1_val = mysimilar(y), mysimilar(res1) @@ -763,6 +841,12 @@ for op in ALL_OPS end end end + @test_throws PME $val_and_op!( + nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, tang, contexts... + ) + @test_throws PME $op!( + nothing, mysimilar(y), mysimilar(res1), prepstrict, ba, x, tang, contexts... + ) scenario_intact && @test new_scen == scen return nothing end @@ -781,6 +865,7 @@ for op in ALL_OPS xrand, tangrand = myrandom(x), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -789,14 +874,15 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, $prep_op( new_smaller.f, ba, new_smaller.x, new_smaller.tang, - new_smaller.contexts..., + new_smaller.contexts...; + strict=Val(true), ), ba, xrand, @@ -804,7 +890,7 @@ for op in ALL_OPS contextsrand..., ) prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) - [(), (prep,), (prepprep,), (prep_same,)] + [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res2_out1_noval = $op(f, preptup_noval..., ba, x, tang, contexts...) 
@@ -827,6 +913,8 @@ for op in ALL_OPS end end end + @test_throws PME $val_and_op(nothing, prepstrict, ba, x, tang, contexts...) + @test_throws PME $op(nothing, prepstrict, ba, x, tang, contexts...) scenario_intact && @test new_scen == scen return nothing end @@ -844,6 +932,7 @@ for op in ALL_OPS xrand, tangrand = myrandom(x), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) + local prepstrict preptup_cands_val, preptup_cands_noval = map(1:2) do _ new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba @@ -852,14 +941,15 @@ for op in ALL_OPS deepcopy(smaller) end prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) - prepprep = $prep_op!( + prepstrict = $prep_op!( f, $prep_op( new_smaller.f, ba, new_smaller.x, new_smaller.tang, - new_smaller.contexts..., + new_smaller.contexts...; + strict=Val(true), ), ba, xrand, @@ -867,7 +957,7 @@ for op in ALL_OPS contextsrand..., ) prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) - [(), (prep,), (prepprep,), (prep_same,)] + [(), (prep,), (prepstrict,), (prep_same,)] end for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res2_in1_noval = mysimilar(res2) @@ -918,6 +1008,19 @@ for op in ALL_OPS end end end + @test_throws PME $op!( + nothing, mysimilar(res2), prepstrict, ba, x, tang, contexts... + ) + @test_throws PME $val_and_op!( + nothing, + mysimilar(res1), + mysimilar(res2), + prepstrict, + ba, + x, + tang, + contexts..., + ) scenario_intact && @test new_scen == scen return nothing end