diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 2c6b6f5fd..c2ef71dda 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -25,7 +25,7 @@ jobs: actions: write contents: read strategy: - fail-fast: true # TODO: toggle + fail-fast: false # TODO: toggle matrix: version: - "1.10" diff --git a/DifferentiationInterface/docs/src/dev_guide.md b/DifferentiationInterface/docs/src/dev_guide.md index 10e74a905..b1b9f420f 100644 --- a/DifferentiationInterface/docs/src/dev_guide.md +++ b/DifferentiationInterface/docs/src/dev_guide.md @@ -36,7 +36,7 @@ Your AD package needs to be registered first. ### Core code In the main package, you should define a new struct `SuperDiffBackend` which subtypes [`ADTypes.AbstractADType`](@extref ADTypes), and endow it with the fields you need to parametrize your differentiation routines. -You also have to define [`ADTypes.mode`](@extref) and [`DifferentiationInterface.inplace_support`](@ref) on `SuperDiffBackend`. +You also have to define [`ADTypes.mode`](@extref) and [`DifferentiationInterface.check_inplace`](@ref) on `SuperDiffBackend`. !!! info In the end, this backend struct will need to be contributed to [ADTypes.jl](https://github.com/SciML/ADTypes.jl). 
diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl index bacb7baa4..9ce083126 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl @@ -18,7 +18,7 @@ const AutoForwardChainRules = AutoChainRules{<:RuleConfig{>:HasForwardsMode}} const AutoReverseChainRules = AutoChainRules{<:RuleConfig{>:HasReverseMode}} DI.check_available(::AutoChainRules) = true -DI.inplace_support(::AutoChainRules) = DI.InPlaceNotSupported() +DI.check_inplace(::AutoChainRules) = false include("reverse_onearg.jl") include("differentiate_with.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index 04d3f55f1..079e5dd1a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -6,11 +6,7 @@ struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep end function DI.prepare_pullback( - f, - ::AutoReverseChainRules, - x, - ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoPullbackPrep() end @@ -21,7 +17,7 @@ function DI.prepare_pullback_same_point( backend::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) 
@@ -34,7 +30,7 @@ function DI.value_and_pullback( backend::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) @@ -50,7 +46,7 @@ function DI.value_and_pullback( ::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -65,7 +61,7 @@ function DI.pullback( ::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; pb) = prep tx = map(ty) do dy diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl index 2973f3b37..42569a19d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl @@ -5,7 +5,7 @@ import DifferentiationInterface as DI using Diffractor: DiffractorRuleConfig, TaylorTangentIndex, ZeroBundle, bundle, ∂☆ DI.check_available(::AutoDiffractor) = true -DI.inplace_support(::AutoDiffractor) = DI.InPlaceNotSupported() +DI.check_inplace(::AutoDiffractor) = false DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow() ## Pushforward diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index a6ad7cbc4..b64beee17 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -54,7 +54,7 @@ 
force_annotation(f::F) where {F} = Const(f) end @inline function _translate( - backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache + backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.GeneralizedCache ) where {B} if B == 1 return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c))) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl index 54dd38501..bd43c65d8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl @@ -6,7 +6,7 @@ using FiniteDifferences: FiniteDifferences, grad, jacobian, jvp, j′vp using LinearAlgebra: dot DI.check_available(::AutoFiniteDifferences) = true -DI.inplace_support(::AutoFiniteDifferences) = DI.InPlaceNotSupported() +DI.check_inplace(::AutoFiniteDifferences) = false ## Pushforward diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index d4f8570b0..2ff24032d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -26,11 +26,11 @@ using ForwardDiff: value DI.check_available(::AutoForwardDiff) = true +DI.check_operator_overloading(::AutoForwardDiff) = true include("utils.jl") include("onearg.jl") include("twoarg.jl") -include("secondorder.jl") include("differentiate_with.jl") include("misc.jl") diff --git 
a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl index d2b76a6d1..132156e33 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl @@ -1,16 +1,58 @@ ## Pushforward + +function DI.overloaded_input( + ::typeof(DI.pushforward), + f::F, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context,C}, +) where {F,B,C} + T = tag_type(f, backend, x) + xdual = make_dual(T, x, tx) + return xdual +end +#= +function DI.overloaded_input( + ::typeof(DI.pushforward), + f!::F, + y, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context,C}, +) where {F,B,C} + T = tag_type(f!, backend, x) + xdual = if x isa Number + make_dual(T, x, tx) + else + make_dual_similar(T, x, tx) + end + return xdual +end +=# + DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep) = typeof(prep.xdual_tmp) DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep) = typeof(prep.xdual_tmp) ## Derivative + function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep) return DI.overloaded_input_type(prep.pushforward_prep) end -DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep) = typeof(prep.config.duals) +function DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep) + return typeof(prep.config.duals) +end ## Gradient + DI.overloaded_input_type(prep::ForwardDiffGradientPrep) = typeof(prep.config.duals) ## Jacobian -DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(prep.config.duals[2]) -DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) = typeof(prep.config.duals[2]) + +function DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) + return typeof(prep.config.duals[2]) +end +function DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) + return 
typeof(prep.config.duals[2]) +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index f50f030a7..9cd778b7a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -170,7 +170,7 @@ struct ForwardDiffOneArgDerivativePrep{E} <: DI.DerivativePrep end function DI.prepare_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {F,C} pushforward_prep = DI.prepare_pushforward(f, backend, x, (one(x),), contexts...) return ForwardDiffOneArgDerivativePrep(pushforward_prep) @@ -181,7 +181,7 @@ function DI.value_and_derivative( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} y, ty = DI.value_and_pushforward( f, prep.pushforward_prep, backend, x, (one(x),), contexts... @@ -195,7 +195,7 @@ function DI.value_and_derivative!( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} y, _ = DI.value_and_pushforward!( f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts... @@ -208,7 +208,7 @@ function DI.derivative( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} return only( DI.pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...) 
@@ -221,7 +221,7 @@ function DI.derivative!( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} DI.pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...) return der @@ -236,7 +236,7 @@ function DI.value_and_gradient!( grad, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -255,7 +255,7 @@ function DI.value_and_gradient( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -273,7 +273,7 @@ function DI.gradient!( grad, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -288,7 +288,7 @@ function DI.gradient( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -309,7 +309,7 @@ function DI.prepare_gradient( f::F, backend::AutoForwardDiff, x::AbstractArray, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) 
chunk = choose_chunk(backend, x) @@ -324,7 +324,7 @@ function DI.value_and_gradient!( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = DiffResult(zero(eltype(x)), (grad,)) @@ -340,7 +340,7 @@ function DI.value_and_gradient( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = GradientResult(x) @@ -355,7 +355,7 @@ function DI.gradient!( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -367,7 +367,7 @@ function DI.gradient( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -383,7 +383,7 @@ function DI.value_and_jacobian!( jac, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -403,7 +403,7 @@ function DI.value_and_jacobian( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) 
@@ -419,7 +419,7 @@ function DI.jacobian!( jac, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -434,7 +434,7 @@ function DI.jacobian( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -452,7 +452,7 @@ struct ForwardDiffOneArgJacobianPrep{C} <: DI.JacobianPrep end function DI.prepare_jacobian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -467,7 +467,7 @@ function DI.value_and_jacobian!( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) y = fc(x) @@ -484,7 +484,7 @@ function DI.value_and_jacobian( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -497,7 +497,7 @@ function DI.jacobian!( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) 
CHK = tag_type(backend) === Nothing @@ -509,7 +509,7 @@ function DI.jacobian( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -519,7 +519,7 @@ end ## Second derivative function DI.prepare_second_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {F,C} return DI.NoSecondDerivativePrep() end @@ -529,7 +529,7 @@ function DI.second_derivative( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -544,7 +544,7 @@ function DI.second_derivative!( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -558,7 +558,7 @@ function DI.value_derivative_and_second_derivative( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -577,7 +577,7 @@ function DI.value_derivative_and_second_derivative!( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -589,69 +589,6 @@ function DI.value_derivative_and_second_derivative!( return y, der, der2 end -## HVP - -function DI.prepare_hvp( - f::F, - backend::AutoForwardDiff, - x, - 
tx::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, -) where {F,C} - return DI.prepare_hvp(f, DI.SecondOrder(backend, backend), x, tx, contexts...) -end - -function DI.hvp( - f::F, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, -) where {F,C} - return DI.hvp(f, prep, DI.SecondOrder(backend, backend), x, tx, contexts...) -end - -function DI.hvp!( - f::F, - tg::NTuple, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, -) where {F,C} - return DI.hvp!(f, tg, prep, DI.SecondOrder(backend, backend), x, tx, contexts...) -end - -function DI.gradient_and_hvp( - f::F, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, -) where {F,C} - return DI.gradient_and_hvp( - f, prep, DI.SecondOrder(backend, backend), x, tx, contexts... - ) -end - -function DI.gradient_and_hvp!( - f::F, - grad, - tg::NTuple, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, -) where {F,C} - return DI.gradient_and_hvp!( - f, grad, tg, prep, DI.SecondOrder(backend, backend), x, tx, contexts... - ) -end - ## Hessian ### Unprepared, only when chunk size and tag are not specified @@ -661,7 +598,7 @@ function DI.hessian!( hess, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -676,7 +613,7 @@ function DI.hessian( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) 
@@ -693,7 +630,7 @@ function DI.value_gradient_and_hessian!( hess, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -713,7 +650,7 @@ function DI.value_gradient_and_hessian( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -734,7 +671,7 @@ struct ForwardDiffHessianPrep{C1,C2} <: DI.HessianPrep end function DI.prepare_hessian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -751,7 +688,7 @@ function DI.hessian!( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -763,7 +700,7 @@ function DI.hessian( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -777,7 +714,7 @@ function DI.value_gradient_and_hessian!( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) 
result = DiffResult(one(eltype(x)), (grad, hess)) @@ -794,7 +731,7 @@ function DI.value_gradient_and_hessian( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = HessianResult(x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl deleted file mode 100644 index d01efd074..000000000 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl +++ /dev/null @@ -1,144 +0,0 @@ -struct ForwardDiffOverSomethingHVPPrep{E1<:DI.GradientPrep,E2<:DI.PushforwardPrep} <: - DI.HVPPrep - inner_gradient_prep::E1 - outer_pushforward_prep::E2 -end - -function DI.prepare_hvp( - f::F, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - T = tag_type(DI.shuffled_gradient, DI.outer(backend), x) - xdual = make_dual(T, x, tx) - inner_gradient_prep = DI.prepare_gradient(f, DI.inner(backend), xdual, contexts...) - rewrap = DI.Rewrap(contexts...) - new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - outer_pushforward_prep = DI.prepare_pushforward( - DI.shuffled_gradient, DI.outer(backend), x, tx, new_contexts... - ) - return ForwardDiffOverSomethingHVPPrep(inner_gradient_prep, outer_pushforward_prep) -end - -function DI.hvp( - f::F, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep - rewrap = DI.Rewrap(contexts...) 
- new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - return DI.pushforward( - DI.shuffled_gradient, - outer_pushforward_prep, - DI.outer(backend), - x, - tx, - new_contexts..., - ) -end - -function DI.hvp!( - f::F, - tg::NTuple, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep - rewrap = DI.Rewrap(contexts...) - new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - return DI.pushforward!( - DI.shuffled_gradient, - tg, - outer_pushforward_prep, - DI.outer(backend), - x, - tx, - new_contexts..., - ) - return tg -end - -function DI.gradient_and_hvp( - f::F, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep - rewrap = DI.Rewrap(contexts...) - new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - return DI.value_and_pushforward( - DI.shuffled_gradient, - outer_pushforward_prep, - DI.outer(backend), - x, - tx, - new_contexts..., - ) -end - -function DI.gradient_and_hvp!( - f::F, - grad, - tg::NTuple, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep - rewrap = DI.Rewrap(contexts...) 
- new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - new_grad, _ = DI.value_and_pushforward!( - DI.shuffled_gradient, - tg, - outer_pushforward_prep, - DI.outer(backend), - x, - tx, - new_contexts..., - ) - return copyto!(grad, new_grad), tg -end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index 7289c5379..d6987e571 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -115,7 +115,7 @@ function DI.value_and_derivative( y, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -134,7 +134,7 @@ function DI.value_and_derivative!( der, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -152,7 +152,7 @@ function DI.derivative( y, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -169,7 +169,7 @@ function DI.derivative!( der, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) 
@@ -187,11 +187,7 @@ struct ForwardDiffTwoArgDerivativePrep{C} <: DI.DerivativePrep end function DI.prepare_derivative( - f!::F, - y, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {F,C} fc! = DI.with_contexts(f!, contexts...) tag = get_tag(fc!, backend, x) @@ -205,7 +201,7 @@ function DI.prepare!_derivative( old_prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} if y isa Vector (; config) = old_prep @@ -222,7 +218,7 @@ function DI.value_and_derivative( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (similar(y),)) @@ -238,7 +234,7 @@ function DI.value_and_derivative!( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (der,)) @@ -253,7 +249,7 @@ function DI.derivative( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing @@ -267,7 +263,7 @@ function DI.derivative!( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) 
CHK = tag_type(backend) === Nothing @@ -283,7 +279,7 @@ function DI.value_and_jacobian( y, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -303,7 +299,7 @@ function DI.value_and_jacobian!( jac, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -321,7 +317,7 @@ function DI.jacobian( y, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -338,7 +334,7 @@ function DI.jacobian!( jac, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -356,11 +352,7 @@ struct ForwardDiffTwoArgJacobianPrep{C} <: DI.JacobianPrep end function DI.prepare_jacobian( - f!::F, - y, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {F,C} fc! = DI.with_contexts(f!, contexts...) 
chunk = choose_chunk(backend, x) @@ -375,7 +367,7 @@ function DI.prepare!_jacobian( old_prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} if x isa Vector && y isa Vector (; config) = old_prep @@ -394,7 +386,7 @@ function DI.value_and_jacobian( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) jac = similar(y, length(y), length(x)) @@ -411,7 +403,7 @@ function DI.value_and_jacobian!( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (jac,)) @@ -426,7 +418,7 @@ function DI.jacobian( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing @@ -440,7 +432,7 @@ function DI.jacobian!( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) 
CHK = tag_type(backend) === Nothing diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index 2c53d17e2..7e08abb5b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -77,15 +77,10 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B} return ty end -# store preparation result with the right input eltype -struct PrepContext{T<:DI.Prep} <: DI.Context - data::T -end - -function _translate(::Type{T}, ::Val{B}, c::DI.ConstantOrFunctionOrBackend) where {T,B} +function _translate(::Type{T}, ::Val{B}, c::DI.GeneralizedConstant) where {T,B} return DI.unwrap(c) end -_translate(::Type{T}, ::Val{B}, c::PrepContext) where {T,B} = DI.unwrap(c) +_translate(::Type{T}, ::Val{B}, c::DI.PrepContext) where {T,B} = DI.unwrap(c) function _translate(::Type{T}, ::Val{B}, c::DI.Cache) where {T,B} c0 = DI.unwrap(c) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/DifferentiationInterfaceGTPSAExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/DifferentiationInterfaceGTPSAExt.jl index b2d24dac3..533ba6961 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/DifferentiationInterfaceGTPSAExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/DifferentiationInterfaceGTPSAExt.jl @@ -5,6 +5,7 @@ using ADTypes: AutoGTPSA using GTPSA: GTPSA, TPS, Descriptor DI.check_available(::AutoGTPSA) = true +DI.check_operator_overloading(::AutoGTPSA) = true include("onearg.jl") include("twoarg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 667aadee9..22d658fed 100644 --- 
a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -14,6 +14,7 @@ using Mooncake: Mooncake DI.check_available(::AutoMooncake) = true +DI.check_operator_overloading(::AutoMooncake) = true copyto!!(dst::Number, src::Number) = convert(typeof(dst), src) copyto!!(dst, src) = DI.ismutable_array(dst) ? copyto!(dst, src) : convert(typeof(dst), src) diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl index 0efbc9695..ae6425a24 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -12,6 +12,7 @@ function single_threaded(backend::AutoPolyesterForwardDiff{chunksize,T}) where { end DI.check_available(::AutoPolyesterForwardDiff) = true +DI.check_operator_overloading(::AutoPolyesterForwardDiff) = true function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, x::AbstractArray) return DI.pick_batchsize(single_threaded(backend), x) @@ -31,4 +32,17 @@ end include("onearg.jl") include("twoarg.jl") +function DI.overloaded_input( + ::typeof(DI.pushforward), + f::F, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context,C}, +) where {F,B,C} + return DI.overloaded_input( + DI.pushforward, f, single_threaded(backend), x, tx, contexts... 
+ ) +end + end # module diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index 8f67ad78d..6571a3ab9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -286,78 +286,6 @@ function DI.value_gradient_and_hessian!( ) end -## HVP - -function DI.prepare_hvp( - f, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C} -) where {C} - return DI.prepare_hvp( - f, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... - ) -end - -function DI.hvp( - f, - prep::DI.HVPPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return DI.hvp( - f, prep, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... - ) -end - -function DI.hvp!( - f, - tg::NTuple, - prep::DI.HVPPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return DI.hvp!( - f, tg, prep, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... - ) -end - -function DI.gradient_and_hvp( - f, - prep::DI.HVPPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return DI.gradient_and_hvp( - f, prep, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... 
- ) -end - -function DI.gradient_and_hvp!( - f, - grad, - tg::NTuple, - prep::DI.HVPPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return DI.gradient_and_hvp!( - f, - grad, - tg, - prep, - DI.SecondOrder(single_threaded(backend), backend), - x, - tx, - contexts..., - ) -end - ## Second derivative function DI.prepare_second_derivative( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl index 061765e0f..ce23d24c6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl @@ -22,9 +22,11 @@ using ReverseDiff: hessian, hessian!, jacobian, - jacobian! + jacobian!, + value DI.check_available(::AutoReverseDiff) = true +DI.check_operator_overloading(::AutoReverseDiff) = true include("onearg.jl") include("twoarg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl index 2631469fc..bf4782a4e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl @@ -1,3 +1,16 @@ +## Pullback + +function DI.overloaded_input( + ::typeof(DI.pullback), + f, + ::AutoReverseDiff, + x::AbstractArray, + ty::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} + return nothing +end + ## Gradient DI.overloaded_input_type(prep::ReverseDiffGradientPrep) = typeof(prep.config.input) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl 
b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index fb5da6b76..db4998021 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -5,7 +5,7 @@ import DifferentiationInterface as DI using Tracker: Tracker, back, data, forward, gradient, jacobian, param, withgradient DI.check_available(::AutoTracker) = true -DI.inplace_support(::AutoTracker) = DI.InPlaceNotSupported() +DI.check_inplace(::AutoTracker) = false ## Pullback @@ -15,7 +15,7 @@ struct TrackerPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep end function DI.prepare_pullback( - f, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoPullbackPrep() end @@ -26,7 +26,7 @@ function DI.prepare_pullback_same_point( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, pb = forward(f, x, map(DI.unwrap, contexts)...) return TrackerPullbackPrepSamePoint(y, pb) @@ -38,7 +38,7 @@ function DI.value_and_pullback( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, pb = forward(f, x, map(DI.unwrap, contexts)...) 
tx = map(ty) do dy @@ -53,7 +53,7 @@ function DI.value_and_pullback( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -68,7 +68,7 @@ function DI.pullback( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; pb) = prep tx = map(ty) do dy @@ -80,28 +80,20 @@ end ## Gradient function DI.prepare_gradient( - f, ::AutoTracker, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoGradientPrep() end function DI.value_and_gradient( - f, - ::DI.NoGradientPrep, - ::AutoTracker, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} (; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return val, data(first(grad)) end function DI.gradient( - f, - ::DI.NoGradientPrep, - ::AutoTracker, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} (; grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return data(first(grad)) @@ -113,7 +105,7 @@ function DI.value_and_gradient!( prep::DI.NoGradientPrep, backend::AutoTracker, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) 
return y, copyto!(grad, new_grad) @@ -125,7 +117,7 @@ function DI.gradient!( prep::DI.NoGradientPrep, backend::AutoTracker, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index b336b8810..97280aee6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -32,7 +32,8 @@ check_nothing(::Nothing, f, x, contexts) = throw(ZygoteNothingError(f, x, contex check_nothing(::Any, f, x, contexts) = nothing DI.check_available(::AutoZygote) = true -DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported() +DI.check_inplace(::AutoZygote) = false +DI.check_operator_overloading(::AutoZygote) = false translate(c::DI.Context) = DI.unwrap(c) translate(c::DI.Cache) = Buffer(DI.unwrap(c)) @@ -183,7 +184,12 @@ function DI.prepare_hvp( end function DI.hvp( - f, prep::DI.HVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f, + prep::DI.ForwardOverAnythingHVPPrep, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, ) where {C} return DI.hvp(f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) 
end @@ -191,7 +197,7 @@ end function DI.hvp!( f, tg::NTuple, - prep::DI.HVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoZygote, x, tx::NTuple, @@ -203,7 +209,12 @@ function DI.hvp!( end function DI.gradient_and_hvp( - f, prep::DI.HVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f, + prep::DI.ForwardOverAnythingHVPPrep, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, ) where {C} return DI.gradient_and_hvp( f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -214,7 +225,7 @@ function DI.gradient_and_hvp!( f, grad, tg::NTuple, - prep::DI.HVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoZygote, x, tx::NTuple, @@ -228,17 +239,13 @@ end ## Hessian function DI.prepare_hessian( - f, ::AutoZygote, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoHessianPrep() end function DI.hessian( - f, - ::DI.NoHessianPrep, - ::AutoZygote, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::DI.NoHessianPrep, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} fc = DI.with_contexts(f, contexts...) hess = hessian(fc, x) @@ -252,7 +259,7 @@ function DI.hessian!( prep::DI.NoHessianPrep, backend::AutoZygote, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} return copyto!(hess, DI.hessian(f, prep, backend, x, contexts...)) end @@ -262,7 +269,7 @@ function DI.value_gradient_and_hessian( prep::DI.NoHessianPrep, backend::AutoZygote, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, grad = DI.value_and_gradient(f, DI.NoGradientPrep(), backend, x, contexts...) hess = DI.hessian(f, prep, backend, x, contexts...) 
@@ -276,7 +283,7 @@ function DI.value_gradient_and_hessian!( prep::DI.NoHessianPrep, backend::AutoZygote, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, _ = DI.value_and_gradient!(f, grad, DI.NoGradientPrep(), backend, x, contexts...) DI.hessian!(f, hess, prep, backend, x, contexts...) diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 99ecfc2d4..57a11bf78 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -41,7 +41,6 @@ include("utils/prep.jl") include("utils/traits.jl") include("utils/basis.jl") include("utils/batchsize.jl") -include("utils/check.jl") include("utils/printing.jl") include("utils/context.jl") include("utils/linalg.jl") @@ -58,13 +57,13 @@ include("second_order/hessian.jl") include("fallbacks/no_prep.jl") include("fallbacks/change_prep.jl") +include("fallbacks/input.jl") include("misc/differentiate_with.jl") include("misc/from_primitive.jl") include("misc/sparsity_detector.jl") include("misc/simple_finite_diff.jl") include("misc/zero_backends.jl") -include("misc/overloading.jl") ## Exported diff --git a/DifferentiationInterface/src/fallbacks/input.jl b/DifferentiationInterface/src/fallbacks/input.jl new file mode 100644 index 000000000..a3162449d --- /dev/null +++ b/DifferentiationInterface/src/fallbacks/input.jl @@ -0,0 +1,26 @@ +function overloaded_input end +function overloaded_input_type end + +function error_if_overloading(backend) + if check_operator_overloading(backend) + throw( + ArgumentError( + "The current backend is based on operator overloading, a custom method for `overloaded_input` is therefore necessary. 
Please open an issue on DifferentiationInterface.jl if you encounter this error.", + ), + ) + end +end + +for op in [:pushforward, :pullback] + @eval function overloaded_input( + ::typeof($op), + f::F, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}, + ) where {F,C} + error_if_overloading(backend) + return copy(x) + end +end diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 13a260b02..f854fba1f 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -129,6 +129,19 @@ function shuffled_gradient( return gradient(f, backend, x, rewrap(unannotated_contexts...)...) end +#= +function shuffled_gradient!( + grad, + x, + f::F, + backend::AbstractADType, + rewrap::Rewrap{C}, + unannotated_contexts::Vararg{Any,C}, +) where {F,C} + return gradient!(f, grad, backend, x, rewrap(unannotated_contexts...)...) +end +=# + function shuffled_gradient( x, f::F, @@ -139,3 +152,17 @@ function shuffled_gradient( ) where {F,C} return gradient(f, prep, backend, x, rewrap(unannotated_contexts...)...) end + +#= +function shuffled_gradient!( + grad, + x, + f::F, + prep::GradientPrep, + backend::AbstractADType, + rewrap::Rewrap{C}, + unannotated_contexts::Vararg{Any,C}, +) where {F,C} + return gradient!(f, grad, prep, backend, x, rewrap(unannotated_contexts...)...) 
+end +=# diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index 59c2663c1..d3878c638 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -1,7 +1,10 @@ abstract type FromPrimitive <: AbstractADType end check_available(fromprim::FromPrimitive) = check_available(fromprim.backend) -inplace_support(fromprim::FromPrimitive) = inplace_support(fromprim.backend) +check_inplace(fromprim::FromPrimitive) = check_inplace(fromprim.backend) +function check_operator_overloading(fromprim::FromPrimitive) + return check_operator_overloading(fromprim.backend) +end function pick_batchsize(fromprim::FromPrimitive, N::Integer) return pick_batchsize(fromprim.backend, N) diff --git a/DifferentiationInterface/src/misc/overloading.jl b/DifferentiationInterface/src/misc/overloading.jl deleted file mode 100644 index bda61a192..000000000 --- a/DifferentiationInterface/src/misc/overloading.jl +++ /dev/null @@ -1,9 +0,0 @@ -""" - overloaded_input_type(prep) - -If it exists, return the overloaded input type which will be passed to the differentiated function when preparation result `prep` is reused. - -!!! danger - This function is experimental and not part of the public API. 
-""" -function overloaded_input_type end diff --git a/DifferentiationInterface/src/misc/simple_finite_diff.jl b/DifferentiationInterface/src/misc/simple_finite_diff.jl index 11f05ac1e..74f1b80f2 100644 --- a/DifferentiationInterface/src/misc/simple_finite_diff.jl +++ b/DifferentiationInterface/src/misc/simple_finite_diff.jl @@ -17,7 +17,7 @@ end ADTypes.mode(::AutoSimpleFiniteDiff) = ForwardMode() check_available(::AutoSimpleFiniteDiff) = true -inplace_support(::AutoSimpleFiniteDiff) = InPlaceSupported() +check_operator_overloading(::AutoSimpleFiniteDiff) = false function pick_batchsize(::AutoSimpleFiniteDiff{nothing}, N::Integer) B = reasonable_batchsize(N, 12) diff --git a/DifferentiationInterface/src/misc/zero_backends.jl b/DifferentiationInterface/src/misc/zero_backends.jl index a340edd8f..41a2d285e 100644 --- a/DifferentiationInterface/src/misc/zero_backends.jl +++ b/DifferentiationInterface/src/misc/zero_backends.jl @@ -18,7 +18,7 @@ struct AutoZeroForward <: AbstractADType end ADTypes.mode(::AutoZeroForward) = ForwardMode() check_available(::AutoZeroForward) = true -inplace_support(::AutoZeroForward) = InPlaceSupported() +check_operator_overloading(::AutoZeroForward) = false function prepare_pushforward( f::F, ::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C} @@ -104,7 +104,7 @@ struct AutoZeroReverse <: AbstractADType end ADTypes.mode(::AutoZeroReverse) = ReverseMode() check_available(::AutoZeroReverse) = true -inplace_support(::AutoZeroReverse) = InPlaceSupported() +check_operator_overloading(::AutoZeroReverse) = false function prepare_pullback( f::F, ::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C} diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index cea59e7ef..416a39547 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -74,15 +74,15 @@ function prepare_hvp( return 
_prepare_hvp_aux(hvp_mode(backend), f, backend, x, tx, contexts...) end -## Forward over forward +## Forward over anything -struct ForwardOverForwardHVPPrep{E2<:PushforwardPrep} <: HVPPrep - # pushforward of many pushforwards in theory, but pushforward of gradient in practice +struct ForwardOverAnythingHVPPrep{E1<:GradientPrep,E2<:PushforwardPrep} <: HVPPrep + inner_gradient_prep::E1 outer_pushforward_prep::E2 end function _prepare_hvp_aux( - ::ForwardOverForward, + ::Union{ForwardOverForward,ForwardOverReverse}, f::F, backend::AbstractADType, x, @@ -90,140 +90,46 @@ function _prepare_hvp_aux( contexts::Vararg{Context,C}, ) where {F,C} rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) - outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts... - ) - return ForwardOverForwardHVPPrep(outer_pushforward_prep) -end - -function hvp( - f::F, - prep::ForwardOverForwardHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - (; outer_pushforward_prep) = prep - rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) - return pushforward( - shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... - ) -end - -function hvp!( - f::F, - tg::NTuple, - prep::ForwardOverForwardHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - (; outer_pushforward_prep) = prep - rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... 
- ) - return pushforward!( - shuffled_gradient, - tg, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) -end - -function gradient_and_hvp( - f::F, - prep::ForwardOverForwardHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - (; outer_pushforward_prep) = prep - rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) - return value_and_pushforward( - shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... - ) -end - -function gradient_and_hvp!( - f::F, - grad, - tg::NTuple, - prep::ForwardOverForwardHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - (; outer_pushforward_prep) = prep - rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + new_contexts_unknown_prep = ( + FunctionContext(f), + UnknownContext(), # placeholder for inner_gradient_prep + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) - new_grad, _ = value_and_pushforward!( - shuffled_gradient, - tg, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., + xo = overloaded_input( + pushforward, shuffled_gradient, outer(backend), x, tx, new_contexts_unknown_prep... ) - return copyto!(grad, new_grad), tg -end - -## Forward over reverse - -struct ForwardOverReverseHVPPrep{E2<:PushforwardPrep} <: HVPPrep - # pushforward of gradient - outer_pushforward_prep::E2 -end - -function _prepare_hvp_aux( - ::ForwardOverReverse, - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - rewrap = Rewrap(contexts...) + inner_gradient_prep = prepare_gradient(f, inner(backend), xo, contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... 
+ FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) outer_pushforward_prep = prepare_pushforward( shuffled_gradient, outer(backend), x, tx, new_contexts... ) - return ForwardOverReverseHVPPrep(outer_pushforward_prep) + return ForwardOverAnythingHVPPrep(inner_gradient_prep, outer_pushforward_prep) end function hvp( f::F, - prep::ForwardOverReverseHVPPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) return pushforward( shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... @@ -233,16 +139,20 @@ end function hvp!( f::F, tg::NTuple, - prep::ForwardOverReverseHVPPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) return pushforward!( shuffled_gradient, @@ -257,16 +167,20 @@ end function gradient_and_hvp( f::F, - prep::ForwardOverReverseHVPPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) 
new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) return value_and_pushforward( shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... @@ -277,16 +191,20 @@ function gradient_and_hvp!( f::F, grad, tg::NTuple, - prep::ForwardOverReverseHVPPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) new_grad, _ = value_and_pushforward!( shuffled_gradient, @@ -303,7 +221,6 @@ end ## Reverse over forward struct ReverseOverForwardHVPPrep{E2<:GradientPrep,E1<:GradientPrep} <: HVPPrep - # gradient of pushforward outer_gradient_prep::E2 gradient_prep::E1 end @@ -415,8 +332,9 @@ end ## Reverse over reverse -struct ReverseOverReverseHVPPrep{E2<:PullbackPrep} <: HVPPrep - # pullback of gradient +struct ReverseOverReverseHVPPrep{E1<:Union{Nothing,GradientPrep},E2<:PullbackPrep} <: + HVPPrep + inner_gradient_prep::E1 outer_pullback_prep::E2 end @@ -429,13 +347,38 @@ function _prepare_hvp_aux( contexts::Vararg{Context,C}, ) where {F,C} rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + new_contexts_unknown_prep = ( + FunctionContext(f), + UnknownContext(), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + xo = overloaded_input( + pullback, shuffled_gradient, backend, x, tx, new_contexts_unknown_prep... 
) + if isnothing(xo) + inner_gradient_prep = nothing + new_contexts = ( + FunctionContext(f), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + else + inner_gradient_prep = prepare_gradient(f, backend, xo, contexts...) + new_contexts = ( + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + end outer_pullback_prep = prepare_pullback( shuffled_gradient, outer(backend), x, tx, new_contexts... ) - return ReverseOverReverseHVPPrep(outer_pullback_prep) + return ReverseOverReverseHVPPrep(inner_gradient_prep, outer_pullback_prep) end function hvp( @@ -446,11 +389,24 @@ function hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pullback_prep) = prep + (; inner_gradient_prep, outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) + if isnothing(inner_gradient_prep) + new_contexts = ( + FunctionContext(f), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + else + new_contexts = ( + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + end return pullback( shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) @@ -465,11 +421,24 @@ function hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pullback_prep) = prep + (; inner_gradient_prep, outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... 
- ) + if isnothing(inner_gradient_prep) + new_contexts = ( + FunctionContext(f), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + else + new_contexts = ( + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + end return pullback!( shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) @@ -483,11 +452,24 @@ function gradient_and_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pullback_prep) = prep + (; inner_gradient_prep, outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) + if isnothing(inner_gradient_prep) + new_contexts = ( + FunctionContext(f), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + else + new_contexts = ( + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + end return value_and_pullback( shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) @@ -503,11 +485,24 @@ function gradient_and_hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pullback_prep) = prep + (; inner_gradient_prep, outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) + if isnothing(inner_gradient_prep) + new_contexts = ( + FunctionContext(f), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + else + new_contexts = ( + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + end new_grad, _ = value_and_pullback!( shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... 
) diff --git a/DifferentiationInterface/src/utils/check.jl b/DifferentiationInterface/src/utils/check.jl deleted file mode 100644 index 6e2d947a3..000000000 --- a/DifferentiationInterface/src/utils/check.jl +++ /dev/null @@ -1,24 +0,0 @@ -""" - check_available(backend) - -Check whether `backend` is available (i.e. whether the extension is loaded). -""" -check_available(backend::AbstractADType) = false - -function check_available(backend::SecondOrder) - return check_available(inner(backend)) && check_available(outer(backend)) -end - -check_available(backend::AutoSparse) = check_available(dense_ad(backend)) - -function check_available(backend::MixedMode) - return check_available(forward_backend(backend)) && - check_available(reverse_backend(backend)) -end - -""" - check_inplace(backend) - -Check whether `backend` supports differentiation of in-place functions. -""" -check_inplace(backend::AbstractADType) = Bool(inplace_support(backend)) diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 15a9d4a0e..4ab112054 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -1,12 +1,3 @@ -struct FixTail{F,A<:Tuple} - f::F - tail_args::A -end - -function (ft::FixTail)(args::Vararg{Any,N}) where {N} - return ft.f(args..., ft.tail_args...) -end - """ Context @@ -22,12 +13,15 @@ abstract type Context end unwrap(c::Context) = c.data Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2) +abstract type GeneralizedConstant <: Context end +abstract type GeneralizedCache <: Context end + ## Public contexts """ Constant -Concrete type of [`Context`](@ref) argument which is kept constant during differentiation. +Concrete subtype of [`Context`](@ref) argument which is kept constant during differentiation. Note that an operator can be prepared with an arbitrary value of the constant. 
However, same-point preparation must occur with the exact value that will be reused later. @@ -55,7 +49,7 @@ julia> gradient(f, AutoForwardDiff(), [1.0, 2.0], Constant(100)) 400.0 ``` """ -struct Constant{T} <: Context +struct Constant{T} <: GeneralizedConstant data::T end @@ -65,14 +59,14 @@ maker(::Constant) = constant_maker """ Cache -Concrete type of [`Context`](@ref) argument which can be mutated with active values during differentiation. +Concrete subtype of [`Context`](@ref) argument which can be mutated with active values during differentiation. The initial values present inside the cache do not matter. !!! warning Most backends require any `Cache` context to be an `AbstractArray`. """ -struct Cache{T} <: Context +struct Cache{T} <: GeneralizedCache data::T end @@ -81,15 +75,47 @@ maker(::Cache) = cache_maker ## Internal contexts for passing stuff around -struct FunctionContext{T} <: Context +""" + FunctionContext + +Concrete subtype of [`Context`](@ref) argument designed to contain functions, for internal use only. + +It is mostly similar to [`Constant`](@ref). +""" +struct FunctionContext{T} <: GeneralizedConstant + data::T +end + +""" + BackendContext + +Concrete subtype of [`Context`](@ref) argument designed to contain backend objects, for internal use only. + +It is mostly similar to [`Constant`](@ref). +""" +struct BackendContext{T<:AbstractADType} <: GeneralizedConstant data::T end -struct BackendContext{T} <: Context +""" + PrepContext <: Context + +Concrete subtype of [`Context`](@ref) argument designed to contain preparation objects, for internal use only. + +It is mostly similar to [`Cache`](@ref). +""" +struct PrepContext{T<:Prep} <: GeneralizedCache data::T end -const ConstantOrFunctionOrBackend = Union{Constant,FunctionContext,BackendContext} +""" + UnknownContext <: Context + +Concrete subtype of [`Context`](@ref) argument designed as a placeholder for when a future context value is not yet known, for internal use only. 
+ +It is relevant in second-order preparation. +""" +struct UnknownContext <: Context end ## Context manipulation @@ -109,6 +135,15 @@ function (r::Rewrap{C,T})(unannotated_contexts::Vararg{Any,C}) where {C,T} end end +struct FixTail{F,A<:Tuple} + f::F + tail_args::A +end + +function (ft::FixTail)(args::Vararg{Any,N}) where {N} + return ft.f(args..., ft.tail_args...) +end + with_contexts(f) = f function with_contexts(f::F, contexts::Vararg{Context,N}) where {F,N} diff --git a/DifferentiationInterface/src/utils/traits.jl b/DifferentiationInterface/src/utils/traits.jl index 055d32fd8..c8eaeb852 100644 --- a/DifferentiationInterface/src/utils/traits.jl +++ b/DifferentiationInterface/src/utils/traits.jl @@ -1,46 +1,80 @@ -## Mutation - -abstract type InPlaceBehavior end +## Availability """ - InPlaceSupported + check_available(backend) -Trait identifying backends that support in-place functions `f!(y, x)`. +Check whether `backend` is available (i.e. whether the extension is loaded). """ -struct InPlaceSupported <: InPlaceBehavior end +check_available(backend::AbstractADType) = false + +function check_available(backend::SecondOrder) + return check_available(inner(backend)) && check_available(outer(backend)) +end + +check_available(backend::AutoSparse) = check_available(dense_ad(backend)) + +function check_available(backend::MixedMode) + return check_available(forward_backend(backend)) && + check_available(reverse_backend(backend)) +end + +## Mutation """ - InPlaceNotSupported + check_inplace(backend) + +Check whether `backend` supports differentiation of in-place functions. -Trait identifying backends that do not support in-place functions `f!(y, x)`. +Returns `true` or `false` in a statically predictable way. + +!!! warning + This function defaults to `true` if the backend is not loaded. 
""" -struct InPlaceNotSupported <: InPlaceBehavior end +check_inplace(::AbstractADType) = true + +function check_inplace(backend::SecondOrder) + return check_inplace(inner(backend)) && check_inplace(outer(backend)) +end + +check_inplace(backend::AutoSparse) = check_inplace(dense_ad(backend)) + +function check_inplace(backend::MixedMode) + return check_inplace(forward_backend(backend)) && + check_inplace(reverse_backend(backend)) +end + +## Operator overloading """ - inplace_support(backend) + check_operator_overloading(backend) + +Check whether backend relies on operator overloading. + +Returns `true` or `false` in a statically predictable way. -Return [`InPlaceSupported`](@ref) or [`InPlaceNotSupported`](@ref) in a statically predictable way. +!!! warning + This function defaults to `false` if the backend is not loaded. """ -inplace_support(::AbstractADType) = InPlaceSupported() +check_operator_overloading(::AbstractADType) = false -function inplace_support(backend::SecondOrder) - if inplace_support(inner(backend)) isa InPlaceSupported && - inplace_support(outer(backend)) isa InPlaceSupported - return InPlaceSupported() - else - return InPlaceNotSupported() - end +function check_operator_overloading(::SecondOrder) + return throw( + ArgumentError( + "Operator overloading check does not make sense for second-order backend" + ), + ) end -inplace_support(backend::AutoSparse) = inplace_support(dense_ad(backend)) +function check_operator_overloading(backend::AutoSparse) + return check_operator_overloading(dense_ad(backend)) +end -function inplace_support(backend::MixedMode) - if Bool(inplace_support(forward_backend(backend))) && - Bool(inplace_support(reverse_backend(backend))) - return InPlaceSupported() - else - return InPlaceNotSupported() - end +function check_operator_overloading(::MixedMode) + return throw( + ArgumentError( + "Operator overloading check does not make sense for mixed-mode backend" + ), + ) end ## Pushforward @@ -161,9 +195,6 @@ end ## Conversions 
-Base.Bool(::InPlaceSupported) = true -Base.Bool(::InPlaceNotSupported) = false - Base.Bool(::PushforwardFast) = true Base.Bool(::PushforwardSlow) = false diff --git a/DifferentiationInterface/test/Back/ChainRules/zygote.jl b/DifferentiationInterface/test/Back/ChainRules/zygote.jl index ef54db928..67ae40dd9 100644 --- a/DifferentiationInterface/test/Back/ChainRules/zygote.jl +++ b/DifferentiationInterface/test/Back/ChainRules/zygote.jl @@ -11,15 +11,10 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoChainRules(ZygoteRuleConfig())] - @test check_available(backend) - @test !check_inplace(backend) -end - test_differentiation( AutoChainRules(ZygoteRuleConfig()), default_scenarios(); - excluded=[:second_derivative], + excluded=SECOND_ORDER, logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/Diffractor/test.jl b/DifferentiationInterface/test/Back/Diffractor/test.jl index 08315b9c3..d90475db2 100644 --- a/DifferentiationInterface/test/Back/Diffractor/test.jl +++ b/DifferentiationInterface/test/Back/Diffractor/test.jl @@ -10,11 +10,6 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoDiffractor()] - @test check_available(backend) - @test !check_inplace(backend) -end - test_differentiation( AutoDiffractor(), default_scenarios(; linalg=false); diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index f4803ccb3..989f28bf2 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -34,13 +34,6 @@ duplicated_backends = [ AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=Enzyme.Duplicated), ] -@testset "Checks" begin - @testset "Check $(typeof(backend))" for backend in backends - @test check_available(backend) - @test check_inplace(backend) - end -end; - @testset "First order" begin 
test_differentiation( backends, default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING diff --git a/DifferentiationInterface/test/Back/FiniteDiff/test.jl b/DifferentiationInterface/test/Back/FiniteDiff/test.jl index 5276d6233..370d50046 100644 --- a/DifferentiationInterface/test/Back/FiniteDiff/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDiff/test.jl @@ -12,11 +12,6 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoFiniteDiff()] - @test check_available(backend) - @test check_inplace(backend) -end - test_differentiation( AutoFiniteDiff(), default_scenarios(; include_constantified=true, include_cachified=true); diff --git a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl index cb83ab0cf..692feb5c8 100644 --- a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl @@ -10,11 +10,6 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1))] - @test check_available(backend) - @test !check_inplace(backend) -end - test_differentiation( AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)), default_scenarios(; include_constantified=true, include_cachified=true); diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 061836a81..ee65aafcc 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -23,11 +23,6 @@ backends = [ AutoForwardDiff(; tag=ForwardDiff.Tag(MyTag(), Float64)), ] -for backend in backends - @test check_available(backend) - @test check_inplace(backend) -end - ## Dense test_differentiation( diff --git 
a/DifferentiationInterface/test/Back/GTPSA/test.jl b/DifferentiationInterface/test/Back/GTPSA/test.jl index 68b66918e..16d13bb74 100644 --- a/DifferentiationInterface/test/Back/GTPSA/test.jl +++ b/DifferentiationInterface/test/Back/GTPSA/test.jl @@ -10,11 +10,6 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoGTPSA()] - @test check_available(backend) - @test check_inplace(backend) -end - # Test no Descriptor (use context) test_differentiation( AutoGTPSA(), diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 35df1002c..eb83a03a2 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -12,11 +12,6 @@ LOGGING = get(ENV, "CI", "false") == "false" backends = [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())] -for backend in backends - @test check_available(backend) - @test check_inplace(backend) -end - test_differentiation( backends, default_scenarios(; include_constantified=true, include_cachified=true); diff --git a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl index 21fb3d297..2a888574a 100644 --- a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl @@ -17,11 +17,6 @@ backends = [ AutoPolyesterForwardDiff(; chunksize=2), ] -for backend in backends - @test check_available(backend) - @test check_inplace(backend) -end - test_differentiation( backends, default_scenarios(; include_constantified=true); logging=LOGGING ); diff --git a/DifferentiationInterface/test/Back/ReverseDiff/test.jl b/DifferentiationInterface/test/Back/ReverseDiff/test.jl index 8549de0e6..d506913a8 100644 --- a/DifferentiationInterface/test/Back/ReverseDiff/test.jl +++ 
b/DifferentiationInterface/test/Back/ReverseDiff/test.jl @@ -16,11 +16,6 @@ LOGGING = get(ENV, "CI", "false") == "false" backends = [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] second_order_backends = [SecondOrder(AutoForwardDiff(), AutoReverseDiff())] -for backend in vcat(backends, second_order_backends) - @test check_available(backend) - @test check_inplace(backend) -end - ## Dense test_differentiation( diff --git a/DifferentiationInterface/test/Back/SparsityDetector/test.jl b/DifferentiationInterface/test/Back/SparsityDetector/test.jl index 1ea6a975a..560e954f4 100644 --- a/DifferentiationInterface/test/Back/SparsityDetector/test.jl +++ b/DifferentiationInterface/test/Back/SparsityDetector/test.jl @@ -12,6 +12,9 @@ using Test rng = StableRNG(63) +using Random +rng = Random.default_rng() + const Jc = sprand(rng, Bool, 10, 20, 0.3) const Hc = sparse(Symmetric(sprand(rng, Bool, 20, 20, 0.3))) diff --git a/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl b/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl index 191025ae2..24c6475f1 100644 --- a/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl +++ b/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl @@ -10,11 +10,6 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoFastDifferentiation(), AutoSparse(AutoFastDifferentiation())] - @test check_available(backend) - @test check_inplace(backend) -end - test_differentiation( AutoFastDifferentiation(), filter(default_scenarios(; include_constantified=true, include_cachified=true)) do s diff --git a/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl b/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl index 31f8316a0..b4ca3b119 100644 --- a/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl +++ 
b/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl @@ -10,11 +10,6 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoSymbolics(), AutoSparse(AutoSymbolics())] - @test check_available(backend) - @test check_inplace(backend) -end - test_differentiation( AutoSymbolics(), default_scenarios(; include_constantified=true); logging=LOGGING ); diff --git a/DifferentiationInterface/test/Back/Tracker/test.jl b/DifferentiationInterface/test/Back/Tracker/test.jl index 0131a183e..557798ca0 100644 --- a/DifferentiationInterface/test/Back/Tracker/test.jl +++ b/DifferentiationInterface/test/Back/Tracker/test.jl @@ -10,11 +10,6 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoTracker()] - @test check_available(backend) - @test !check_inplace(backend) -end - test_differentiation( AutoTracker(), default_scenarios(; include_constantified=true); diff --git a/DifferentiationInterface/test/Back/Zygote/test.jl b/DifferentiationInterface/test/Back/Zygote/test.jl index 4f7943b17..a115bc83a 100644 --- a/DifferentiationInterface/test/Back/Zygote/test.jl +++ b/DifferentiationInterface/test/Back/Zygote/test.jl @@ -17,11 +17,6 @@ LOGGING = get(ENV, "CI", "false") == "false" backends = [AutoZygote()] second_order_backends = [SecondOrder(AutoForwardDiff(), AutoZygote())] -for backend in vcat(backends, second_order_backends) - @test check_available(backend) - @test !check_inplace(backend) -end - ## Dense @testset "Dense" begin diff --git a/DifferentiationInterface/test/Core/Internals/backends.jl b/DifferentiationInterface/test/Core/Internals/backends.jl index 5ce73706f..1f7bda4a1 100644 --- a/DifferentiationInterface/test/Core/Internals/backends.jl +++ b/DifferentiationInterface/test/Core/Internals/backends.jl @@ -8,7 +8,8 @@ using DifferentiationInterface: outer, forward_backend, reverse_backend, - inplace_support, + check_inplace, + 
check_operator_overloading, pushforward_performance, pullback_performance, hvp_mode @@ -24,7 +25,8 @@ rb = AutoReverseFromPrimitive(AutoSimpleFiniteDiff()) @test outer(backend) isa AutoSimpleFiniteDiff @test inner(backend) isa AutoReverseFromPrimitive @test mode(backend) isa ADTypes.ForwardMode - @test Bool(inplace_support(backend)) + @test check_inplace(backend) + @test_throws ArgumentError check_operator_overloading(backend) @test_throws ArgumentError pushforward_performance(backend) @test_throws ArgumentError pullback_performance(backend) end @@ -35,7 +37,8 @@ end @test mode(backend) isa DifferentiationInterface.ForwardAndReverseMode @test forward_backend(backend) isa AutoSimpleFiniteDiff @test reverse_backend(backend) isa AutoReverseFromPrimitive - @test Bool(inplace_support(backend)) + @test check_inplace(backend) + @test_throws ArgumentError check_operator_overloading(backend) @test_throws MethodError pushforward_performance(backend) @test_throws MethodError pullback_performance(backend) end @@ -44,9 +47,24 @@ end for dense_backend in [fb, rb] backend = AutoSparse(dense_backend) @test mode(backend) == ADTypes.mode(dense_backend) - @test Bool(inplace_support(backend)) + @test check_inplace(backend) + @test !check_operator_overloading(backend) @test_throws ArgumentError pushforward_performance(backend) @test_throws ArgumentError pullback_performance(backend) @test_throws ArgumentError hvp_mode(backend) end end + +struct FakeBackend <: AbstractADType end +struct FakeOOBackend <: AbstractADType end +DI.check_operator_overloading(::FakeOOBackend) = true + +@testset "Fake" begin + backend = FakeBackend() + @test !check_available(backend) + @test check_inplace(backend) + @test !check_operator_overloading(backend) + @test_throws ArgumentError DI.overloaded_input( + pushforward, identity, FakeOOBackend(), 1.0, (1.0,) + ) +end diff --git a/DifferentiationInterface/test/Core/ZeroBackends/test.jl b/DifferentiationInterface/test/Core/ZeroBackends/test.jl index 
0571a73b3..37c2549c5 100644 --- a/DifferentiationInterface/test/Core/ZeroBackends/test.jl +++ b/DifferentiationInterface/test/Core/ZeroBackends/test.jl @@ -11,9 +11,8 @@ LOGGING = get(ENV, "CI", "false") == "false" zero_backends = [AutoZeroForward(), AutoZeroReverse()] -for backend in zero_backends - @test check_available(backend) - @test check_inplace(backend) +@testset "Correctness" begin + test_differentiation(zero_backends, map(zero, default_scenarios()); logging=LOGGING) end @testset "Type stability" begin diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index 2c92546ed..49630a8e8 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -86,7 +86,7 @@ using DifferentiationInterface: inner, mode, outer, - inplace_support, + check_inplace, pushforward_performance, pullback_performance using DifferentiationInterface: Rewrap, Context, Constant, Cache, unwrap diff --git a/DifferentiationInterfaceTest/src/scenarios/default.jl b/DifferentiationInterfaceTest/src/scenarios/default.jl index ef9767df4..729048c16 100644 --- a/DifferentiationInterfaceTest/src/scenarios/default.jl +++ b/DifferentiationInterfaceTest/src/scenarios/default.jl @@ -266,8 +266,8 @@ end ## Array to number -const α = 4 -const β = 6 +const α = 3 +const β = 4 arr_to_num_linalg(x::AbstractArray) = sum(vec(x .^ α) .* transpose(vec(x .^ β))) diff --git a/DifferentiationInterfaceTest/src/scenarios/scenario.jl b/DifferentiationInterfaceTest/src/scenarios/scenario.jl index e113b912e..0517ddcfa 100644 --- a/DifferentiationInterfaceTest/src/scenarios/scenario.jl +++ b/DifferentiationInterfaceTest/src/scenarios/scenario.jl @@ -101,7 +101,7 @@ function order(scen::Scenario) end function compatible(backend::AbstractADType, scen::Scenario) - place_compatible = function_place(scen) == :out || Bool(inplace_support(backend)) + 
place_compatible = function_place(scen) == :out || check_inplace(backend) sparse_compatible = operator(scen) in (:jacobian, :hessian) || !isa(backend, AutoSparse) secondorder_compatible = order(scen) == 2 || !isa(backend, Union{SecondOrder,AutoSparse{<:SecondOrder}}) diff --git a/DifferentiationInterfaceTest/src/test_differentiation.jl b/DifferentiationInterfaceTest/src/test_differentiation.jl index 4ad566a35..2a6f3fcd8 100644 --- a/DifferentiationInterfaceTest/src/test_differentiation.jl +++ b/DifferentiationInterfaceTest/src/test_differentiation.jl @@ -117,6 +117,7 @@ function test_differentiation( @testset verbose = true "$title" begin @testset verbose = detailed "$backend" for (i, backend) in enumerate(backends) + @test DI.check_available(backend) filtered_scenarios = filter(s -> compatible(backend, s), scenarios) grouped_scenarios = group_by_operator(filtered_scenarios) @testset verbose = detailed "$op" for (j, (op, op_group)) in diff --git a/DifferentiationInterfaceTest/test/zero_backends.jl b/DifferentiationInterfaceTest/test/zero_backends.jl index e8f5ace6c..4e3051fcf 100644 --- a/DifferentiationInterfaceTest/test/zero_backends.jl +++ b/DifferentiationInterfaceTest/test/zero_backends.jl @@ -44,6 +44,7 @@ data1 = benchmark_differentiation( struct FakeBackend <: ADTypes.AbstractADType end ADTypes.mode(::FakeBackend) = ADTypes.ForwardMode() +DifferentiationInterface.check_available(::FakeBackend) = true data2 = benchmark_differentiation( FakeBackend(),