From f93ea8fe38b3f8d90b4c1e632951bdf78887e866 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 29 Jan 2025 21:02:31 +0100 Subject: [PATCH 01/24] perf: exploit in-place backends for HVPs --- .../docs/src/dev_guide.md | 2 +- ...fferentiationInterfaceChainRulesCoreExt.jl | 2 +- .../DifferentiationInterfaceDiffractorExt.jl | 2 +- ...rentiationInterfaceFiniteDifferencesExt.jl | 2 +- .../DifferentiationInterfaceTrackerExt.jl | 2 +- .../DifferentiationInterfaceZygoteExt.jl | 2 +- .../src/DifferentiationInterface.jl | 1 - .../src/first_order/gradient.jl | 23 ++ .../src/misc/from_primitive.jl | 2 +- .../src/misc/simple_finite_diff.jl | 1 - .../src/misc/zero_backends.jl | 2 - .../src/second_order/hvp.jl | 383 ++++++++++++++---- DifferentiationInterface/src/utils/check.jl | 24 -- DifferentiationInterface/src/utils/traits.jl | 58 ++- .../test/Core/Internals/backends.jl | 8 +- .../src/DifferentiationInterfaceTest.jl | 2 +- .../src/scenarios/default.jl | 4 +- .../src/scenarios/scenario.jl | 2 +- 18 files changed, 360 insertions(+), 162 deletions(-) delete mode 100644 DifferentiationInterface/src/utils/check.jl diff --git a/DifferentiationInterface/docs/src/dev_guide.md b/DifferentiationInterface/docs/src/dev_guide.md index 10e74a905..b1b9f420f 100644 --- a/DifferentiationInterface/docs/src/dev_guide.md +++ b/DifferentiationInterface/docs/src/dev_guide.md @@ -36,7 +36,7 @@ Your AD package needs to be registered first. ### Core code In the main package, you should define a new struct `SuperDiffBackend` which subtypes [`ADTypes.AbstractADType`](@extref ADTypes), and endow it with the fields you need to parametrize your differentiation routines. -You also have to define [`ADTypes.mode`](@extref) and [`DifferentiationInterface.inplace_support`](@ref) on `SuperDiffBackend`. +You also have to define [`ADTypes.mode`](@extref) and [`DifferentiationInterface.check_inplace`](@ref) on `SuperDiffBackend`. !!! info In the end, this backend struct will need to be contributed to [ADTypes.jl](https://github.com/SciML/ADTypes.jl). diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl index bacb7baa4..9ce083126 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl @@ -18,7 +18,7 @@ const AutoForwardChainRules = AutoChainRules{<:RuleConfig{>:HasForwardsMode}} const AutoReverseChainRules = AutoChainRules{<:RuleConfig{>:HasReverseMode}} DI.check_available(::AutoChainRules) = true -DI.inplace_support(::AutoChainRules) = DI.InPlaceNotSupported() +DI.check_inplace(::AutoChainRules) = false include("reverse_onearg.jl") include("differentiate_with.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl index 2973f3b37..42569a19d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl @@ -5,7 +5,7 @@ import DifferentiationInterface as DI using Diffractor: DiffractorRuleConfig, TaylorTangentIndex, ZeroBundle, bundle, ∂☆ DI.check_available(::AutoDiffractor) = true -DI.inplace_support(::AutoDiffractor) = DI.InPlaceNotSupported() +DI.check_inplace(::AutoDiffractor) = false DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow() ## Pushforward diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl index 54dd38501..bd43c65d8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl @@ -6,7 +6,7 @@ using FiniteDifferences: FiniteDifferences, grad, jacobian, jvp, j′vp using LinearAlgebra: dot DI.check_available(::AutoFiniteDifferences) = true -DI.inplace_support(::AutoFiniteDifferences) = DI.InPlaceNotSupported() +DI.check_inplace(::AutoFiniteDifferences) = false ## Pushforward diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index fb5da6b76..b95ed22f0 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -5,7 +5,7 @@ import DifferentiationInterface as DI using Tracker: Tracker, back, data, forward, gradient, jacobian, param, withgradient DI.check_available(::AutoTracker) = true -DI.inplace_support(::AutoTracker) = DI.InPlaceNotSupported() +DI.check_inplace(::AutoTracker) = false ## Pullback diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index d86f2ec89..7f5b004ea 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -25,7 +25,7 @@ check_nothing(::Nothing, f, x, contexts) = throw(ZygoteNothingError(f, x, contex check_nothing(::Any, f, x, contexts) = nothing DI.check_available(::AutoZygote) = true -DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported() +DI.check_inplace(::AutoZygote) = false ## Pullback diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 99ecfc2d4..f8d1f6ca4 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -41,7 +41,6 @@ include("utils/prep.jl") include("utils/traits.jl") include("utils/basis.jl") include("utils/batchsize.jl") -include("utils/check.jl") include("utils/printing.jl") include("utils/context.jl") include("utils/linalg.jl") diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 13a260b02..d689da5c1 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -129,6 +129,17 @@ function shuffled_gradient( return gradient(f, backend, x, rewrap(unannotated_contexts...)...) end +function shuffled_gradient!( + grad, + x, + f::F, + backend::AbstractADType, + rewrap::Rewrap{C}, + unannotated_contexts::Vararg{Any,C}, +) where {F,C} + return gradient!(f, grad, backend, x, rewrap(unannotated_contexts...)...) +end + function shuffled_gradient( x, f::F, @@ -139,3 +150,15 @@ function shuffled_gradient( ) where {F,C} return gradient(f, prep, backend, x, rewrap(unannotated_contexts...)...) end + +function shuffled_gradient!( + grad, + x, + f::F, + prep::GradientPrep, + backend::AbstractADType, + rewrap::Rewrap{C}, + unannotated_contexts::Vararg{Any,C}, +) where {F,C} + return gradient!(f, grad, prep, backend, x, rewrap(unannotated_contexts...)...) +end diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index 59c2663c1..5645443fd 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -1,7 +1,7 @@ abstract type FromPrimitive <: AbstractADType end check_available(fromprim::FromPrimitive) = check_available(fromprim.backend) -inplace_support(fromprim::FromPrimitive) = inplace_support(fromprim.backend) +check_inplace(fromprim::FromPrimitive) = check_inplace(fromprim.backend) function pick_batchsize(fromprim::FromPrimitive, N::Integer) return pick_batchsize(fromprim.backend, N) diff --git a/DifferentiationInterface/src/misc/simple_finite_diff.jl b/DifferentiationInterface/src/misc/simple_finite_diff.jl index 11f05ac1e..6bdbd0e84 100644 --- a/DifferentiationInterface/src/misc/simple_finite_diff.jl +++ b/DifferentiationInterface/src/misc/simple_finite_diff.jl @@ -17,7 +17,6 @@ end ADTypes.mode(::AutoSimpleFiniteDiff) = ForwardMode() check_available(::AutoSimpleFiniteDiff) = true -inplace_support(::AutoSimpleFiniteDiff) = InPlaceSupported() function pick_batchsize(::AutoSimpleFiniteDiff{nothing}, N::Integer) B = reasonable_batchsize(N, 12) diff --git a/DifferentiationInterface/src/misc/zero_backends.jl b/DifferentiationInterface/src/misc/zero_backends.jl index a340edd8f..4c253cfab 100644 --- a/DifferentiationInterface/src/misc/zero_backends.jl +++ b/DifferentiationInterface/src/misc/zero_backends.jl @@ -18,7 +18,6 @@ struct AutoZeroForward <: AbstractADType end ADTypes.mode(::AutoZeroForward) = ForwardMode() check_available(::AutoZeroForward) = true -inplace_support(::AutoZeroForward) = InPlaceSupported() function prepare_pushforward( f::F, ::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C} @@ -104,7 +103,6 @@ struct AutoZeroReverse <: AbstractADType end ADTypes.mode(::AutoZeroReverse) = ReverseMode() check_available(::AutoZeroReverse) = true -inplace_support(::AutoZeroReverse) = InPlaceSupported() function prepare_pullback( f::F, ::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C} diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index cea59e7ef..96368e7d4 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -76,7 +76,8 @@ end ## Forward over forward -struct ForwardOverForwardHVPPrep{E2<:PushforwardPrep} <: HVPPrep +struct ForwardOverForwardHVPPrep{G,E2<:PushforwardPrep} <: HVPPrep + grad_buffer::G # pushforward of many pushforwards in theory, but pushforward of gradient in practice outer_pushforward_prep::E2 end @@ -93,10 +94,17 @@ function _prepare_hvp_aux( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts... - ) - return ForwardOverForwardHVPPrep(outer_pushforward_prep) + grad_buffer = similar(x) + if check_inplace(backend) + outer_pushforward_prep = prepare_pushforward( + shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... + ) + else + outer_pushforward_prep = prepare_pushforward( + shuffled_gradient, outer(backend), x, tx, new_contexts... + ) + end + return ForwardOverForwardHVPPrep(grad_buffer, outer_pushforward_prep) end function hvp( @@ -107,14 +115,31 @@ function hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; grad_buffer, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - return pushforward( - shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... - ) + if check_inplace(backend) + return pushforward( + shuffled_gradient!, + grad_buffer, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + else + return pushforward( + shuffled_gradient, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + end end function hvp!( @@ -126,20 +151,33 @@ function hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; grad_buffer, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - return pushforward!( - shuffled_gradient, - tg, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) + if check_inplace(backend) + return pushforward!( + shuffled_gradient!, + grad_buffer, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + else + return pushforward!( + shuffled_gradient, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + end end function gradient_and_hvp( @@ -150,14 +188,32 @@ function gradient_and_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; grad_buffer, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - return value_and_pushforward( - shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... - ) + if check_inplace(backend) + y, tg = value_and_pushforward( + shuffled_gradient!, + grad_buffer, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + return copy(y), tg + else + return value_and_pushforward( + shuffled_gradient, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + end end function gradient_and_hvp!( @@ -175,21 +231,35 @@ function gradient_and_hvp!( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - new_grad, _ = value_and_pushforward!( - shuffled_gradient, - tg, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - return copyto!(grad, new_grad), tg + if check_inplace(backend) + return value_and_pushforward!( + shuffled_gradient!, + grad, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + else + new_grad, _ = value_and_pushforward!( + shuffled_gradient, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + return copyto!(grad, new_grad), tg + end end ## Forward over reverse -struct ForwardOverReverseHVPPrep{E2<:PushforwardPrep} <: HVPPrep +struct ForwardOverReverseHVPPrep{G,E2<:PushforwardPrep} <: HVPPrep + grad_buffer::G # pushforward of gradient outer_pushforward_prep::E2 end @@ -206,10 +276,17 @@ function _prepare_hvp_aux( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts... - ) - return ForwardOverReverseHVPPrep(outer_pushforward_prep) + grad_buffer = similar(x) + if check_inplace(backend) + outer_pushforward_prep = prepare_pushforward( + shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... + ) + else + outer_pushforward_prep = prepare_pushforward( + shuffled_gradient, outer(backend), x, tx, new_contexts... + ) + end + return ForwardOverReverseHVPPrep(grad_buffer, outer_pushforward_prep) end function hvp( @@ -220,14 +297,31 @@ function hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; grad_buffer, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - return pushforward( - shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... - ) + if check_inplace(backend) + return pushforward( + shuffled_gradient!, + grad_buffer, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + else + return pushforward( + shuffled_gradient, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + end end function hvp!( @@ -239,20 +333,33 @@ function hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; grad_buffer, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - return pushforward!( - shuffled_gradient, - tg, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) + if check_inplace(backend) + return pushforward!( + shuffled_gradient!, + grad_buffer, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + else + return pushforward!( + shuffled_gradient, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + end end function gradient_and_hvp( @@ -263,14 +370,32 @@ function gradient_and_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; grad_buffer, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - return value_and_pushforward( - shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... - ) + if check_inplace(backend) + y, tg = value_and_pushforward( + shuffled_gradient!, + grad_buffer, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + return copy(y), tg + else + return value_and_pushforward( + shuffled_gradient, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + end end function gradient_and_hvp!( @@ -288,16 +413,29 @@ function gradient_and_hvp!( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - new_grad, _ = value_and_pushforward!( - shuffled_gradient, - tg, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - return copyto!(grad, new_grad), tg + if check_inplace(backend) + return value_and_pushforward!( + shuffled_gradient!, + grad, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + else + new_grad, _ = value_and_pushforward!( + shuffled_gradient, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + return copyto!(grad, new_grad), tg + end end ## Reverse over forward @@ -415,7 +553,8 @@ end ## Reverse over reverse -struct ReverseOverReverseHVPPrep{E2<:PullbackPrep} <: HVPPrep +struct ReverseOverReverseHVPPrep{G,E2<:PullbackPrep} <: HVPPrep + grad_buffer::G # pullback of gradient outer_pullback_prep::E2 end @@ -432,10 +571,17 @@ function _prepare_hvp_aux( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - outer_pullback_prep = prepare_pullback( - shuffled_gradient, outer(backend), x, tx, new_contexts... - ) - return ReverseOverReverseHVPPrep(outer_pullback_prep) + grad_buffer = similar(x) + if check_inplace(backend) + outer_pullback_prep = prepare_pullback( + shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... + ) + else + outer_pullback_prep = prepare_pullback( + shuffled_gradient, outer(backend), x, tx, new_contexts... + ) + end + return ReverseOverReverseHVPPrep(grad_buffer, outer_pullback_prep) end function hvp( @@ -446,14 +592,26 @@ function hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pullback_prep) = prep + (; grad_buffer, outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - return pullback( - shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... - ) + if check_inplace(backend) + return pullback( + shuffled_gradient!, + grad_buffer, + outer_pullback_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + else + return pullback( + shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... + ) + end end function hvp!( @@ -465,14 +623,33 @@ function hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pullback_prep) = prep + (; grad_buffer, outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - return pullback!( - shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... - ) + if check_inplace(backend) + return pullback!( + shuffled_gradient!, + grad_buffer, + tg, + outer_pullback_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + else + return pullback!( + shuffled_gradient, + tg, + outer_pullback_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + end end function gradient_and_hvp( @@ -483,14 +660,27 @@ function gradient_and_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pullback_prep) = prep + (; grad_buffer, outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - return value_and_pullback( - shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... - ) + if check_inplace(backend) + y, tg = value_and_pullback( + shuffled_gradient!, + grad_buffer, + outer_pullback_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + return copy(y), tg + else + return value_and_pullback( + shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... + ) + end end function gradient_and_hvp!( @@ -508,8 +698,27 @@ function gradient_and_hvp!( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - new_grad, _ = value_and_pullback!( - shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... - ) - return copyto!(grad, new_grad), tg + if check_inplace(backend) + return value_and_pullback!( + shuffled_gradient!, + grad, + tg, + outer_pullback_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + else + new_grad, _ = value_and_pullback!( + shuffled_gradient, + tg, + outer_pullback_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + return copyto!(grad, new_grad), tg + end end diff --git a/DifferentiationInterface/src/utils/check.jl b/DifferentiationInterface/src/utils/check.jl deleted file mode 100644 index 6e2d947a3..000000000 --- a/DifferentiationInterface/src/utils/check.jl +++ /dev/null @@ -1,24 +0,0 @@ -""" - check_available(backend) - -Check whether `backend` is available (i.e. whether the extension is loaded). -""" -check_available(backend::AbstractADType) = false - -function check_available(backend::SecondOrder) - return check_available(inner(backend)) && check_available(outer(backend)) -end - -check_available(backend::AutoSparse) = check_available(dense_ad(backend)) - -function check_available(backend::MixedMode) - return check_available(forward_backend(backend)) && - check_available(reverse_backend(backend)) -end - -""" - check_inplace(backend) - -Check whether `backend` supports differentiation of in-place functions. -""" -check_inplace(backend::AbstractADType) = Bool(inplace_support(backend)) diff --git a/DifferentiationInterface/src/utils/traits.jl b/DifferentiationInterface/src/utils/traits.jl index 055d32fd8..797dbb0dd 100644 --- a/DifferentiationInterface/src/utils/traits.jl +++ b/DifferentiationInterface/src/utils/traits.jl @@ -1,46 +1,43 @@ -## Mutation - -abstract type InPlaceBehavior end +## Availability """ - InPlaceSupported + check_available(backend) -Trait identifying backends that support in-place functions `f!(y, x)`. +Check whether `backend` is available (i.e. whether the extension is loaded). """ -struct InPlaceSupported <: InPlaceBehavior end +check_available(backend::AbstractADType) = false -""" - InPlaceNotSupported +function check_available(backend::SecondOrder) + return check_available(inner(backend)) && check_available(outer(backend)) +end -Trait identifying backends that do not support in-place functions `f!(y, x)`. -""" -struct InPlaceNotSupported <: InPlaceBehavior end +check_available(backend::AutoSparse) = check_available(dense_ad(backend)) + +function check_available(backend::MixedMode) + return check_available(forward_backend(backend)) && + check_available(reverse_backend(backend)) +end + +## Mutation """ - inplace_support(backend) + check_inplace(backend) + +Check whether `backend` supports differentiation of in-place functions. -Return [`InPlaceSupported`](@ref) or [`InPlaceNotSupported`](@ref) in a statically predictable way. +Returns `true` or `false` in a statically predictable way. """ -inplace_support(::AbstractADType) = InPlaceSupported() +check_inplace(::AbstractADType) = true -function inplace_support(backend::SecondOrder) - if inplace_support(inner(backend)) isa InPlaceSupported && - inplace_support(outer(backend)) isa InPlaceSupported - return InPlaceSupported() - else - return InPlaceNotSupported() - end +function check_inplace(backend::SecondOrder) + return check_inplace(inner(backend)) && check_inplace(outer(backend)) end -inplace_support(backend::AutoSparse) = inplace_support(dense_ad(backend)) +check_inplace(backend::AutoSparse) = check_inplace(dense_ad(backend)) -function inplace_support(backend::MixedMode) - if Bool(inplace_support(forward_backend(backend))) && - Bool(inplace_support(reverse_backend(backend))) - return InPlaceSupported() - else - return InPlaceNotSupported() - end +function check_inplace(backend::MixedMode) + return check_inplace(forward_backend(backend)) && + check_inplace(reverse_backend(backend)) end ## Pushforward @@ -161,9 +158,6 @@ end ## Conversions -Base.Bool(::InPlaceSupported) = true -Base.Bool(::InPlaceNotSupported) = false - Base.Bool(::PushforwardFast) = true Base.Bool(::PushforwardSlow) = false diff --git a/DifferentiationInterface/test/Core/Internals/backends.jl b/DifferentiationInterface/test/Core/Internals/backends.jl index 5ce73706f..2edca6aa3 100644 --- a/DifferentiationInterface/test/Core/Internals/backends.jl +++ b/DifferentiationInterface/test/Core/Internals/backends.jl @@ -8,7 +8,7 @@ using DifferentiationInterface: outer, forward_backend, reverse_backend, - inplace_support, + check_inplace, pushforward_performance, pullback_performance, hvp_mode @@ -24,7 +24,7 @@ rb = AutoReverseFromPrimitive(AutoSimpleFiniteDiff()) @test outer(backend) isa AutoSimpleFiniteDiff @test inner(backend) isa AutoReverseFromPrimitive @test mode(backend) isa ADTypes.ForwardMode - @test Bool(inplace_support(backend)) + @test check_inplace(backend) @test_throws ArgumentError pushforward_performance(backend) @test_throws ArgumentError pullback_performance(backend) end @@ -35,7 +35,7 @@ end @test mode(backend) isa DifferentiationInterface.ForwardAndReverseMode @test forward_backend(backend) isa AutoSimpleFiniteDiff @test reverse_backend(backend) isa AutoReverseFromPrimitive - @test Bool(inplace_support(backend)) + @test check_inplace(backend) @test_throws MethodError pushforward_performance(backend) @test_throws MethodError pullback_performance(backend) end @@ -44,7 +44,7 @@ end for dense_backend in [fb, rb] backend = AutoSparse(dense_backend) @test mode(backend) == ADTypes.mode(dense_backend) - @test Bool(inplace_support(backend)) + @test check_inplace(backend) @test_throws ArgumentError pushforward_performance(backend) @test_throws ArgumentError pullback_performance(backend) @test_throws ArgumentError hvp_mode(backend) diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index 2c92546ed..49630a8e8 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -86,7 +86,7 @@ using DifferentiationInterface: inner, mode, outer, - inplace_support, + check_inplace, pushforward_performance, pullback_performance using DifferentiationInterface: Rewrap, Context, Constant, Cache, unwrap diff --git a/DifferentiationInterfaceTest/src/scenarios/default.jl b/DifferentiationInterfaceTest/src/scenarios/default.jl index ef9767df4..729048c16 100644 --- a/DifferentiationInterfaceTest/src/scenarios/default.jl +++ b/DifferentiationInterfaceTest/src/scenarios/default.jl @@ -266,8 +266,8 @@ end ## Array to number -const α = 4 -const β = 6 +const α = 3 +const β = 4 arr_to_num_linalg(x::AbstractArray) = sum(vec(x .^ α) .* transpose(vec(x .^ β))) diff --git a/DifferentiationInterfaceTest/src/scenarios/scenario.jl b/DifferentiationInterfaceTest/src/scenarios/scenario.jl index e113b912e..0517ddcfa 100644 --- a/DifferentiationInterfaceTest/src/scenarios/scenario.jl +++ b/DifferentiationInterfaceTest/src/scenarios/scenario.jl @@ -101,7 +101,7 @@ function order(scen::Scenario) end function compatible(backend::AbstractADType, scen::Scenario) - place_compatible = function_place(scen) == :out || Bool(inplace_support(backend)) + place_compatible = function_place(scen) == :out || check_inplace(backend) sparse_compatible = operator(scen) in (:jacobian, :hessian) || !isa(backend, AutoSparse) secondorder_compatible = order(scen) == 2 || !isa(backend, Union{SecondOrder,AutoSparse{<:SecondOrder}}) From 6426efa370c9ab8aebda205746b648f3e714b727 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 29 Jan 2025 21:04:08 +0100 Subject: [PATCH 02/24] only outer backend matters --- .../src/second_order/hvp.jl | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 96368e7d4..c7a574300 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -95,7 +95,7 @@ function _prepare_hvp_aux( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) grad_buffer = similar(x) - if check_inplace(backend) + if check_inplace(outer(backend)) outer_pushforward_prep = prepare_pushforward( shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... ) @@ -120,7 +120,7 @@ function hvp( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(backend) + if check_inplace(outer(backend)) return pushforward( shuffled_gradient!, grad_buffer, @@ -156,7 +156,7 @@ function hvp!( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(backend) + if check_inplace(outer(backend)) return pushforward!( shuffled_gradient!, grad_buffer, @@ -193,7 +193,7 @@ function gradient_and_hvp( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(backend) + if check_inplace(outer(backend)) y, tg = value_and_pushforward( shuffled_gradient!, grad_buffer, @@ -231,7 +231,7 @@ function gradient_and_hvp!( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(backend) + if check_inplace(outer(backend)) return value_and_pushforward!( shuffled_gradient!, grad, @@ -277,7 +277,7 @@ function _prepare_hvp_aux( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) grad_buffer = similar(x) - if check_inplace(backend) + if check_inplace(outer(backend)) outer_pushforward_prep = prepare_pushforward( shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... ) @@ -302,7 +302,7 @@ function hvp( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(backend) + if check_inplace(outer(backend)) return pushforward( shuffled_gradient!, grad_buffer, @@ -338,7 +338,7 @@ function hvp!( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(backend) + if check_inplace(outer(backend)) return pushforward!( shuffled_gradient!, grad_buffer, @@ -375,7 +375,7 @@ function gradient_and_hvp( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(backend) + if check_inplace(outer(backend)) y, tg = value_and_pushforward( shuffled_gradient!, grad_buffer, @@ -413,7 +413,7 @@ function gradient_and_hvp!( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(backend) + if check_inplace(outer(backend)) return value_and_pushforward!( shuffled_gradient!, grad, @@ -572,7 +572,7 @@ function _prepare_hvp_aux( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) grad_buffer = similar(x) - if check_inplace(backend) + if check_inplace(outer(backend)) outer_pullback_prep = prepare_pullback( shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... ) @@ -597,7 +597,7 @@ function hvp( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(backend) + if check_inplace(outer(backend)) return pullback( shuffled_gradient!, grad_buffer, @@ -628,7 +628,7 @@ function hvp!( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(backend) + if check_inplace(outer(backend)) return pullback!( shuffled_gradient!, grad_buffer, @@ -665,7 +665,7 @@ function gradient_and_hvp( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(backend) + if check_inplace(outer(backend)) y, tg = value_and_pullback( shuffled_gradient!, grad_buffer, @@ -698,7 +698,7 @@ function gradient_and_hvp!( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(backend) + if check_inplace(outer(backend)) return value_and_pullback!( shuffled_gradient!, grad, From 811198b76f9e7226f0befa372f80b7a34fc0a70c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 29 Jan 2025 22:35:27 +0100 Subject: [PATCH 03/24] Only in-place --- .../secondorder.jl | 35 ++- .../src/second_order/hvp.jl | 221 ++++++------------ 2 files changed, 93 insertions(+), 163 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl index d01efd074..39c066c98 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl @@ -1,7 +1,10 @@ -struct ForwardDiffOverSomethingHVPPrep{E1<:DI.GradientPrep,E2<:DI.PushforwardPrep} <: - DI.HVPPrep +struct ForwardDiffOverSomethingHVPPrep{ + G,E1<:DI.GradientPrep,E2<:DI.PushforwardPrep,E2IP<:DI.PushforwardPrep +} <: DI.HVPPrep + grad_buffer::G inner_gradient_prep::E1 outer_pushforward_prep::E2 + outer_pushforward_prep_inplace::E2IP end function DI.prepare_hvp( @@ -22,10 +25,19 @@ function DI.prepare_hvp( DI.Constant(rewrap), contexts..., ) + grad_buffer = similar(x) outer_pushforward_prep = DI.prepare_pushforward( DI.shuffled_gradient, DI.outer(backend), x, tx, new_contexts... ) - return ForwardDiffOverSomethingHVPPrep(inner_gradient_prep, outer_pushforward_prep) + outer_pushforward_prep_inplace = DI.prepare_pushforward( + DI.shuffled_gradient!, grad_buffer, DI.outer(backend), x, tx, new_contexts... + ) + return ForwardDiffOverSomethingHVPPrep( + grad_buffer, + inner_gradient_prep, + outer_pushforward_prep, + outer_pushforward_prep_inplace, + ) end function DI.hvp( @@ -64,7 +76,7 @@ function DI.hvp!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep + (; grad_buffer, inner_gradient_prep, outer_pushforward_prep_inplace) = prep rewrap = DI.Rewrap(contexts...) new_contexts = ( DI.FunctionContext(f), @@ -74,9 +86,10 @@ function DI.hvp!( contexts..., ) return DI.pushforward!( - DI.shuffled_gradient, + DI.shuffled_gradient!, + grad_buffer, tg, - outer_pushforward_prep, + outer_pushforward_prep_inplace, DI.outer(backend), x, tx, @@ -122,7 +135,7 @@ function DI.gradient_and_hvp!( tx::NTuple, contexts::Vararg{DI.Context,C}, ) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep + (; inner_gradient_prep, outer_pushforward_prep_inplace) = prep rewrap = DI.Rewrap(contexts...) new_contexts = ( DI.FunctionContext(f), @@ -131,14 +144,14 @@ function DI.gradient_and_hvp!( DI.Constant(rewrap), contexts..., ) - new_grad, _ = DI.value_and_pushforward!( - DI.shuffled_gradient, + return DI.value_and_pushforward!( + DI.shuffled_gradient!, + grad, tg, - outer_pushforward_prep, + outer_pushforward_prep_inplace, DI.outer(backend), x, tx, new_contexts..., ) - return copyto!(grad, new_grad), tg end diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index c7a574300..fd0101ed1 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -76,10 +76,11 @@ end ## Forward over forward -struct ForwardOverForwardHVPPrep{G,E2<:PushforwardPrep} <: HVPPrep +struct ForwardOverForwardHVPPrep{G,E2<:PushforwardPrep,E2IP} <: HVPPrep grad_buffer::G # pushforward of many pushforwards in theory, but pushforward of gradient in practice outer_pushforward_prep::E2 + outer_pushforward_prep_inplace::E2IP end function _prepare_hvp_aux( @@ -95,16 +96,19 @@ function _prepare_hvp_aux( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) grad_buffer = similar(x) - if check_inplace(outer(backend)) - outer_pushforward_prep = prepare_pushforward( + outer_pushforward_prep_inplace = if check_inplace(outer(backend)) + prepare_pushforward( shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... ) else - outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts... - ) + nothing end - return ForwardOverForwardHVPPrep(grad_buffer, outer_pushforward_prep) + outer_pushforward_prep = prepare_pushforward( + shuffled_gradient, outer(backend), x, tx, new_contexts... + ) + return ForwardOverForwardHVPPrep( + grad_buffer, outer_pushforward_prep, outer_pushforward_prep_inplace + ) end function hvp( @@ -115,31 +119,14 @@ function hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; grad_buffer, outer_pushforward_prep) = prep + (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(outer(backend)) - return pushforward( - shuffled_gradient!, - grad_buffer, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - else - return pushforward( - shuffled_gradient, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - end + return pushforward( + shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... + ) end function hvp!( @@ -151,7 +138,7 @@ function hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; grad_buffer, outer_pushforward_prep) = prep + (; grad_buffer, outer_pushforward_prep, outer_pushforward_prep_inplace) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... @@ -161,7 +148,7 @@ function hvp!( shuffled_gradient!, grad_buffer, tg, - outer_pushforward_prep, + outer_pushforward_prep_inplace, outer(backend), x, tx, @@ -188,32 +175,14 @@ function gradient_and_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; grad_buffer, outer_pushforward_prep) = prep + (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(outer(backend)) - y, tg = value_and_pushforward( - shuffled_gradient!, - grad_buffer, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - return copy(y), tg - else - return value_and_pushforward( - shuffled_gradient, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - end + return value_and_pushforward( + shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... + ) end function gradient_and_hvp!( @@ -226,7 +195,7 @@ function gradient_and_hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; outer_pushforward_prep, outer_pushforward_prep_inplace) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... @@ -236,7 +205,7 @@ function gradient_and_hvp!( shuffled_gradient!, grad, tg, - outer_pushforward_prep, + outer_pushforward_prep_inplace, outer(backend), x, tx, @@ -258,10 +227,11 @@ end ## Forward over reverse -struct ForwardOverReverseHVPPrep{G,E2<:PushforwardPrep} <: HVPPrep +struct ForwardOverReverseHVPPrep{G,E2<:PushforwardPrep,E2IP} <: HVPPrep grad_buffer::G # pushforward of gradient outer_pushforward_prep::E2 + outer_pushforward_prep_inplace::E2IP end function _prepare_hvp_aux( @@ -277,16 +247,19 @@ function _prepare_hvp_aux( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) grad_buffer = similar(x) - if check_inplace(outer(backend)) - outer_pushforward_prep = prepare_pushforward( + outer_pushforward_prep_inplace = if check_inplace(outer(backend)) + prepare_pushforward( shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... ) else - outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts... - ) + nothing end - return ForwardOverReverseHVPPrep(grad_buffer, outer_pushforward_prep) + outer_pushforward_prep = prepare_pushforward( + shuffled_gradient, outer(backend), x, tx, new_contexts... + ) + return ForwardOverReverseHVPPrep( + grad_buffer, outer_pushforward_prep, outer_pushforward_prep_inplace + ) end function hvp( @@ -297,31 +270,14 @@ function hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; grad_buffer, outer_pushforward_prep) = prep + (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(outer(backend)) - return pushforward( - shuffled_gradient!, - grad_buffer, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - else - return pushforward( - shuffled_gradient, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - end + return pushforward( + shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... + ) end function hvp!( @@ -333,7 +289,7 @@ function hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; grad_buffer, outer_pushforward_prep) = prep + (; grad_buffer, outer_pushforward_prep, outer_pushforward_prep_inplace) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... @@ -343,7 +299,7 @@ function hvp!( shuffled_gradient!, grad_buffer, tg, - outer_pushforward_prep, + outer_pushforward_prep_inplace, outer(backend), x, tx, @@ -370,32 +326,14 @@ function gradient_and_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; grad_buffer, outer_pushforward_prep) = prep + (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(outer(backend)) - y, tg = value_and_pushforward( - shuffled_gradient!, - grad_buffer, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - return copy(y), tg - else - return value_and_pushforward( - shuffled_gradient, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - end + return value_and_pushforward( + shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... + ) end function gradient_and_hvp!( @@ -408,7 +346,7 @@ function gradient_and_hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; outer_pushforward_prep, outer_pushforward_prep_inplace) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... @@ -418,7 +356,7 @@ function gradient_and_hvp!( shuffled_gradient!, grad, tg, - outer_pushforward_prep, + outer_pushforward_prep_inplace, outer(backend), x, tx, @@ -553,10 +491,11 @@ end ## Reverse over reverse -struct ReverseOverReverseHVPPrep{G,E2<:PullbackPrep} <: HVPPrep +struct ReverseOverReverseHVPPrep{G,E2<:PullbackPrep,E2IP} <: HVPPrep grad_buffer::G # pullback of gradient outer_pullback_prep::E2 + outer_pullback_prep_inplace::E2IP end function _prepare_hvp_aux( @@ -572,16 +511,19 @@ function _prepare_hvp_aux( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) grad_buffer = similar(x) - if check_inplace(outer(backend)) - outer_pullback_prep = prepare_pullback( + outer_pullback_prep_inplace = if check_inplace(outer(backend)) + prepare_pullback( shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... ) else - outer_pullback_prep = prepare_pullback( - shuffled_gradient, outer(backend), x, tx, new_contexts... - ) + nothing end - return ReverseOverReverseHVPPrep(grad_buffer, outer_pullback_prep) + outer_pullback_prep = prepare_pullback( + shuffled_gradient, outer(backend), x, tx, new_contexts... + ) + return ReverseOverReverseHVPPrep( + grad_buffer, outer_pullback_prep, outer_pullback_prep_inplace + ) end function hvp( @@ -592,26 +534,14 @@ function hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; grad_buffer, outer_pullback_prep) = prep + (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(outer(backend)) - return pullback( - shuffled_gradient!, - grad_buffer, - outer_pullback_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - else - return pullback( - shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... - ) - end + return pullback( + shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... + ) end function hvp!( @@ -623,7 +553,7 @@ function hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; grad_buffer, outer_pullback_prep) = prep + (; grad_buffer, outer_pullback_prep, outer_pullback_prep_inplace) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... @@ -633,7 +563,7 @@ function hvp!( shuffled_gradient!, grad_buffer, tg, - outer_pullback_prep, + outer_pullback_prep_inplace, outer(backend), x, tx, @@ -660,27 +590,14 @@ function gradient_and_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; grad_buffer, outer_pullback_prep) = prep + (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - if check_inplace(outer(backend)) - y, tg = value_and_pullback( - shuffled_gradient!, - grad_buffer, - outer_pullback_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - return copy(y), tg - else - return value_and_pullback( - shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... - ) - end + return value_and_pullback( + shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... + ) end function gradient_and_hvp!( @@ -693,7 +610,7 @@ function gradient_and_hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pullback_prep) = prep + (; outer_pullback_prep, outer_pullback_prep_inplace) = prep rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... @@ -703,7 +620,7 @@ function gradient_and_hvp!( shuffled_gradient!, grad, tg, - outer_pullback_prep, + outer_pullback_prep_inplace, outer(backend), x, tx, From 50617924df2aa66e542c8fdecdad61c2728231d2 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 14:45:33 +0100 Subject: [PATCH 04/24] Prepare inner gradient in HVP --- ...fferentiationInterfaceChainRulesCoreExt.jl | 1 + .../reverse_onearg.jl | 14 +- .../DifferentiationInterfaceDiffractorExt.jl | 1 + .../DifferentiationInterfaceEnzymeExt.jl | 1 + ...ntiationInterfaceFastDifferentiationExt.jl | 1 + .../DifferentiationInterfaceFiniteDiffExt.jl | 1 + ...rentiationInterfaceFiniteDifferencesExt.jl | 1 + .../DifferentiationInterfaceForwardDiffExt.jl | 2 +- .../misc.jl | 73 ++++ .../onearg.jl | 137 ++---- .../secondorder.jl | 157 ------- .../twoarg.jl | 48 +- .../utils.jl | 9 +- .../DifferentiationInterfaceGTPSAExt.jl | 1 + .../DifferentiationInterfaceMooncakeExt.jl | 1 + ...tiationInterfacePolyesterForwardDiffExt.jl | 1 + .../DifferentiationInterfaceReverseDiffExt.jl | 1 + .../DifferentiationInterfaceSymbolicsExt.jl | 1 + .../DifferentiationInterfaceTrackerExt.jl | 29 +- .../DifferentiationInterfaceZygoteExt.jl | 75 ++-- .../src/DifferentiationInterface.jl | 1 + .../src/fallbacks/input.jl | 70 +++ .../src/misc/from_primitive.jl | 3 + .../src/misc/simple_finite_diff.jl | 1 + .../src/misc/zero_backends.jl | 2 + .../src/second_order/hvp.jl | 412 +++++------------- DifferentiationInterface/src/utils/context.jl | 65 ++- DifferentiationInterface/src/utils/traits.jl | 11 + 28 files changed, 442 insertions(+), 678 deletions(-) delete mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl create mode 100644 DifferentiationInterface/src/fallbacks/input.jl diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl index 9ce083126..250360e84 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl @@ -19,6 +19,7 @@ const AutoReverseChainRules = AutoChainRules{<:RuleConfig{>:HasReverseMode}} DI.check_available(::AutoChainRules) = true DI.check_inplace(::AutoChainRules) = false +DI.check_operator_overloading(::AutoChainRules) = false include("reverse_onearg.jl") include("differentiate_with.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index 04d3f55f1..079e5dd1a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -6,11 +6,7 @@ struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep end function DI.prepare_pullback( - f, - ::AutoReverseChainRules, - x, - ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoPullbackPrep() end @@ -21,7 +17,7 @@ function DI.prepare_pullback_same_point( backend::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) @@ -34,7 +30,7 @@ function DI.value_and_pullback( backend::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) @@ -50,7 +46,7 @@ function DI.value_and_pullback( ::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -65,7 +61,7 @@ function DI.pullback( ::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; pb) = prep tx = map(ty) do dy diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl index 42569a19d..24a260032 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl @@ -6,6 +6,7 @@ using Diffractor: DiffractorRuleConfig, TaylorTangentIndex, ZeroBundle, bundle, DI.check_available(::AutoDiffractor) = true DI.check_inplace(::AutoDiffractor) = false +DI.check_operator_overloading(::AutoDiffractor) = false DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow() ## Pushforward diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index 328bffaf3..7a01c61dd 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -44,6 +44,7 @@ using Enzyme: onehot DI.check_available(::AutoEnzyme) = true +DI.check_operator_overloading(::AutoEnzyme) = false include("utils.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl index 430cc95f9..acd61b1be 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl @@ -17,6 +17,7 @@ using LinearAlgebra: dot using FastDifferentiation.RuntimeGeneratedFunctions: RuntimeGeneratedFunction DI.check_available(::AutoFastDifferentiation) = true +DI.operator_overloading(::AutoFastDifferentiation) = false myvec(x::Number) = [x] myvec(x::AbstractArray) = vec(x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl index 4970eb96e..ce578b02d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl @@ -18,6 +18,7 @@ using FiniteDiff: using LinearAlgebra: dot, mul! DI.check_available(::AutoFiniteDiff) = true +DI.check_operator_overloading(::AutoFiniteDiff) = false # see https://github.com/SciML/ADTypes.jl/issues/33 diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl index bd43c65d8..588143289 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl @@ -7,6 +7,7 @@ using LinearAlgebra: dot DI.check_available(::AutoFiniteDifferences) = true DI.check_inplace(::AutoFiniteDifferences) = false +DI.operator_overloading(::AutoFiniteDifferences) = false ## Pushforward diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index d4f8570b0..2ff24032d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -26,11 +26,11 @@ using ForwardDiff: value DI.check_available(::AutoForwardDiff) = true +DI.check_operator_overloading(::AutoForwardDiff) = true include("utils.jl") include("onearg.jl") include("twoarg.jl") -include("secondorder.jl") include("differentiate_with.jl") include("misc.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl index d2b76a6d1..cb25dcd56 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl @@ -1,16 +1,89 @@ ## Pushforward + +function DI.overloaded_input_type( + ::typeof(DI.pushforward), + f::F, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context,C}, +) where {F,B,C} + return DI.overloaded_input_type(DI.prepare_pushforward(f, backend, x, tx, contexts...)) +end + +function DI.overloaded_input_type( + ::typeof(DI.pushforward), + f!::F, + y, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context,C}, +) where {F,B,C} + return DI.overloaded_input_type( + DI.prepare_pushforward(f!, y, backend, x, tx, contexts...) + ) +end + DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep) = typeof(prep.xdual_tmp) DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep) = typeof(prep.xdual_tmp) ## Derivative + +function DI.overloaded_input_type( + ::typeof(DI.derivative), + f::F, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context,C}, +) where {F,C} + return DI.overloaded_input_type(DI.prepare_derivative(f, backend, x, contexts...)) +end + +function DI.overloaded_input_type( + ::typeof(DI.derivative), + f!::F, + y, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context,C}, +) where {F,C} + return DI.overloaded_input_type(DI.prepare_derivative(f!, y, backend, x, contexts...)) +end + function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep) return DI.overloaded_input_type(prep.pushforward_prep) end DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep) = typeof(prep.config.duals) ## Gradient + +function DI.overloaded_input_type( + ::typeof(DI.gradient), f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} +) where {F,C} + return DI.overloaded_input_type(DI.prepare_gradient(f, backend, x, contexts...)) +end + DI.overloaded_input_type(prep::ForwardDiffGradientPrep) = typeof(prep.config.duals) ## Jacobian + +function DI.overloaded_input_type( + ::typeof(DI.jacobian), f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} +) where {F,C} + return DI.overloaded_input_type(DI.prepare_jacobian(f, backend, x, contexts...)) +end + +function DI.overloaded_input_type( + ::typeof(DI.jacobian), + f!::F, + y, + backend::AutoForwardDiff, + x, + contexts::Vararg{DI.Context,C}, +) where {F,C} + return DI.overloaded_input_type(DI.prepare_jacobian(f!, y, backend, x, contexts...)) +end + DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(prep.config.duals[2]) DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) = typeof(prep.config.duals[2]) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index f50f030a7..9cd778b7a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -170,7 +170,7 @@ struct ForwardDiffOneArgDerivativePrep{E} <: DI.DerivativePrep end function DI.prepare_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {F,C} pushforward_prep = DI.prepare_pushforward(f, backend, x, (one(x),), contexts...) return ForwardDiffOneArgDerivativePrep(pushforward_prep) @@ -181,7 +181,7 @@ function DI.value_and_derivative( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} y, ty = DI.value_and_pushforward( f, prep.pushforward_prep, backend, x, (one(x),), contexts... @@ -195,7 +195,7 @@ function DI.value_and_derivative!( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} y, _ = DI.value_and_pushforward!( f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts... @@ -208,7 +208,7 @@ function DI.derivative( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} return only( DI.pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...) @@ -221,7 +221,7 @@ function DI.derivative!( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} DI.pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...) return der @@ -236,7 +236,7 @@ function DI.value_and_gradient!( grad, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -255,7 +255,7 @@ function DI.value_and_gradient( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -273,7 +273,7 @@ function DI.gradient!( grad, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -288,7 +288,7 @@ function DI.gradient( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -309,7 +309,7 @@ function DI.prepare_gradient( f::F, backend::AutoForwardDiff, x::AbstractArray, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -324,7 +324,7 @@ function DI.value_and_gradient!( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = DiffResult(zero(eltype(x)), (grad,)) @@ -340,7 +340,7 @@ function DI.value_and_gradient( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = GradientResult(x) @@ -355,7 +355,7 @@ function DI.gradient!( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -367,7 +367,7 @@ function DI.gradient( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -383,7 +383,7 @@ function DI.value_and_jacobian!( jac, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -403,7 +403,7 @@ function DI.value_and_jacobian( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -419,7 +419,7 @@ function DI.jacobian!( jac, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -434,7 +434,7 @@ function DI.jacobian( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -452,7 +452,7 @@ struct ForwardDiffOneArgJacobianPrep{C} <: DI.JacobianPrep end function DI.prepare_jacobian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -467,7 +467,7 @@ function DI.value_and_jacobian!( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) y = fc(x) @@ -484,7 +484,7 @@ function DI.value_and_jacobian( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -497,7 +497,7 @@ function DI.jacobian!( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -509,7 +509,7 @@ function DI.jacobian( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -519,7 +519,7 @@ end ## Second derivative function DI.prepare_second_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {F,C} return DI.NoSecondDerivativePrep() end @@ -529,7 +529,7 @@ function DI.second_derivative( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -544,7 +544,7 @@ function DI.second_derivative!( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -558,7 +558,7 @@ function DI.value_derivative_and_second_derivative( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -577,7 +577,7 @@ function DI.value_derivative_and_second_derivative!( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -589,69 +589,6 @@ function DI.value_derivative_and_second_derivative!( return y, der, der2 end -## HVP - -function DI.prepare_hvp( - f::F, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, -) where {F,C} - return DI.prepare_hvp(f, DI.SecondOrder(backend, backend), x, tx, contexts...) -end - -function DI.hvp( - f::F, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, -) where {F,C} - return DI.hvp(f, prep, DI.SecondOrder(backend, backend), x, tx, contexts...) -end - -function DI.hvp!( - f::F, - tg::NTuple, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, -) where {F,C} - return DI.hvp!(f, tg, prep, DI.SecondOrder(backend, backend), x, tx, contexts...) -end - -function DI.gradient_and_hvp( - f::F, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, -) where {F,C} - return DI.gradient_and_hvp( - f, prep, DI.SecondOrder(backend, backend), x, tx, contexts... - ) -end - -function DI.gradient_and_hvp!( - f::F, - grad, - tg::NTuple, - prep::DI.HVPPrep, - backend::AutoForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, -) where {F,C} - return DI.gradient_and_hvp!( - f, grad, tg, prep, DI.SecondOrder(backend, backend), x, tx, contexts... - ) -end - ## Hessian ### Unprepared, only when chunk size and tag are not specified @@ -661,7 +598,7 @@ function DI.hessian!( hess, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -676,7 +613,7 @@ function DI.hessian( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -693,7 +630,7 @@ function DI.value_gradient_and_hessian!( hess, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -713,7 +650,7 @@ function DI.value_gradient_and_hessian( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -734,7 +671,7 @@ struct ForwardDiffHessianPrep{C1,C2} <: DI.HessianPrep end function DI.prepare_hessian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -751,7 +688,7 @@ function DI.hessian!( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -763,7 +700,7 @@ function DI.hessian( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -777,7 +714,7 @@ function DI.value_gradient_and_hessian!( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = DiffResult(one(eltype(x)), (grad, hess)) @@ -794,7 +731,7 @@ function DI.value_gradient_and_hessian( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = HessianResult(x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl deleted file mode 100644 index 39c066c98..000000000 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl +++ /dev/null @@ -1,157 +0,0 @@ -struct ForwardDiffOverSomethingHVPPrep{ - G,E1<:DI.GradientPrep,E2<:DI.PushforwardPrep,E2IP<:DI.PushforwardPrep -} <: DI.HVPPrep - grad_buffer::G - inner_gradient_prep::E1 - outer_pushforward_prep::E2 - outer_pushforward_prep_inplace::E2IP -end - -function DI.prepare_hvp( - f::F, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - T = tag_type(DI.shuffled_gradient, DI.outer(backend), x) - xdual = make_dual(T, x, tx) - inner_gradient_prep = DI.prepare_gradient(f, DI.inner(backend), xdual, contexts...) - rewrap = DI.Rewrap(contexts...) - new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - grad_buffer = similar(x) - outer_pushforward_prep = DI.prepare_pushforward( - DI.shuffled_gradient, DI.outer(backend), x, tx, new_contexts... - ) - outer_pushforward_prep_inplace = DI.prepare_pushforward( - DI.shuffled_gradient!, grad_buffer, DI.outer(backend), x, tx, new_contexts... - ) - return ForwardDiffOverSomethingHVPPrep( - grad_buffer, - inner_gradient_prep, - outer_pushforward_prep, - outer_pushforward_prep_inplace, - ) -end - -function DI.hvp( - f::F, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep - rewrap = DI.Rewrap(contexts...) - new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - return DI.pushforward( - DI.shuffled_gradient, - outer_pushforward_prep, - DI.outer(backend), - x, - tx, - new_contexts..., - ) -end - -function DI.hvp!( - f::F, - tg::NTuple, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; grad_buffer, inner_gradient_prep, outer_pushforward_prep_inplace) = prep - rewrap = DI.Rewrap(contexts...) - new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - return DI.pushforward!( - DI.shuffled_gradient!, - grad_buffer, - tg, - outer_pushforward_prep_inplace, - DI.outer(backend), - x, - tx, - new_contexts..., - ) - return tg -end - -function DI.gradient_and_hvp( - f::F, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep) = prep - rewrap = DI.Rewrap(contexts...) - new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - return DI.value_and_pushforward( - DI.shuffled_gradient, - outer_pushforward_prep, - DI.outer(backend), - x, - tx, - new_contexts..., - ) -end - -function DI.gradient_and_hvp!( - f::F, - grad, - tg::NTuple, - prep::ForwardDiffOverSomethingHVPPrep, - backend::DI.SecondOrder{<:AutoForwardDiff}, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {F,C} - (; inner_gradient_prep, outer_pushforward_prep_inplace) = prep - rewrap = DI.Rewrap(contexts...) - new_contexts = ( - DI.FunctionContext(f), - PrepContext(inner_gradient_prep), - DI.BackendContext(DI.inner(backend)), - DI.Constant(rewrap), - contexts..., - ) - return DI.value_and_pushforward!( - DI.shuffled_gradient!, - grad, - tg, - outer_pushforward_prep_inplace, - DI.outer(backend), - x, - tx, - new_contexts..., - ) -end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index 7289c5379..d6987e571 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -115,7 +115,7 @@ function DI.value_and_derivative( y, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -134,7 +134,7 @@ function DI.value_and_derivative!( der, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -152,7 +152,7 @@ function DI.derivative( y, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -169,7 +169,7 @@ function DI.derivative!( der, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -187,11 +187,7 @@ struct ForwardDiffTwoArgDerivativePrep{C} <: DI.DerivativePrep end function DI.prepare_derivative( - f!::F, - y, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {F,C} fc! = DI.with_contexts(f!, contexts...) tag = get_tag(fc!, backend, x) @@ -205,7 +201,7 @@ function DI.prepare!_derivative( old_prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} if y isa Vector (; config) = old_prep @@ -222,7 +218,7 @@ function DI.value_and_derivative( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (similar(y),)) @@ -238,7 +234,7 @@ function DI.value_and_derivative!( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (der,)) @@ -253,7 +249,7 @@ function DI.derivative( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing @@ -267,7 +263,7 @@ function DI.derivative!( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing @@ -283,7 +279,7 @@ function DI.value_and_jacobian( y, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -303,7 +299,7 @@ function DI.value_and_jacobian!( jac, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -321,7 +317,7 @@ function DI.jacobian( y, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -338,7 +334,7 @@ function DI.jacobian!( jac, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -356,11 +352,7 @@ struct ForwardDiffTwoArgJacobianPrep{C} <: DI.JacobianPrep end function DI.prepare_jacobian( - f!::F, - y, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {F,C} fc! = DI.with_contexts(f!, contexts...) chunk = choose_chunk(backend, x) @@ -375,7 +367,7 @@ function DI.prepare!_jacobian( old_prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} if x isa Vector && y isa Vector (; config) = old_prep @@ -394,7 +386,7 @@ function DI.value_and_jacobian( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) jac = similar(y, length(y), length(x)) @@ -411,7 +403,7 @@ function DI.value_and_jacobian!( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (jac,)) @@ -426,7 +418,7 @@ function DI.jacobian( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing @@ -440,7 +432,7 @@ function DI.jacobian!( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index 2c53d17e2..7e08abb5b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -77,15 +77,10 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B} return ty end -# store preparation result with the right input eltype -struct PrepContext{T<:DI.Prep} <: DI.Context - data::T -end - -function _translate(::Type{T}, ::Val{B}, c::DI.ConstantOrFunctionOrBackend) where {T,B} +function _translate(::Type{T}, ::Val{B}, c::DI.GeneralizedConstant) where {T,B} return DI.unwrap(c) end -_translate(::Type{T}, ::Val{B}, c::PrepContext) where {T,B} = DI.unwrap(c) +_translate(::Type{T}, ::Val{B}, c::DI.PrepContext) where {T,B} = DI.unwrap(c) function _translate(::Type{T}, ::Val{B}, c::DI.Cache) where {T,B} c0 = DI.unwrap(c) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/DifferentiationInterfaceGTPSAExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/DifferentiationInterfaceGTPSAExt.jl index b2d24dac3..533ba6961 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/DifferentiationInterfaceGTPSAExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/DifferentiationInterfaceGTPSAExt.jl @@ -5,6 +5,7 @@ using ADTypes: AutoGTPSA using GTPSA: GTPSA, TPS, Descriptor DI.check_available(::AutoGTPSA) = true +DI.check_operator_overloading(::AutoGTPSA) = true include("onearg.jl") include("twoarg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 667aadee9..22d658fed 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -14,6 +14,7 @@ using Mooncake: Mooncake DI.check_available(::AutoMooncake) = true +DI.check_operator_overloading(::AutoMooncake) = true copyto!!(dst::Number, src::Number) = convert(typeof(dst), src) copyto!!(dst, src) = DI.ismutable_array(dst) ? copyto!(dst, src) : convert(typeof(dst), src) diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl index 0efbc9695..c6650991a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -12,6 +12,7 @@ function single_threaded(backend::AutoPolyesterForwardDiff{chunksize,T}) where { end DI.check_available(::AutoPolyesterForwardDiff) = true +DI.check_operator_overloading(::AutoPolyesterForwardDiff) = true function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, x::AbstractArray) return DI.pick_batchsize(single_threaded(backend), x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl index 061765e0f..461e987d1 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl @@ -25,6 +25,7 @@ using ReverseDiff: jacobian! DI.check_available(::AutoReverseDiff) = true +DI.check_operator_overloading(::AutoReverseDiff) = true include("onearg.jl") include("twoarg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl index 2dc8d0018..0e3711e80 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl @@ -19,6 +19,7 @@ using Symbolics: using Symbolics.RuntimeGeneratedFunctions: RuntimeGeneratedFunction DI.check_available(::AutoSymbolics) = true +DI.check_operator_overloading(::AutoSymbolics) = false DI.pullback_performance(::AutoSymbolics) = DI.PullbackSlow() dense_ad(backend::AutoSymbolics) = backend diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index b95ed22f0..244d8342f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -6,6 +6,7 @@ using Tracker: Tracker, back, data, forward, gradient, jacobian, param, withgrad DI.check_available(::AutoTracker) = true DI.check_inplace(::AutoTracker) = false +DI.check_operator_overloading(::AutoTracker) = true ## Pullback @@ -15,7 +16,7 @@ struct TrackerPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep end function DI.prepare_pullback( - f, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoPullbackPrep() end @@ -26,7 +27,7 @@ function DI.prepare_pullback_same_point( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, pb = forward(f, x, map(DI.unwrap, contexts)...) return TrackerPullbackPrepSamePoint(y, pb) @@ -38,7 +39,7 @@ function DI.value_and_pullback( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, pb = forward(f, x, map(DI.unwrap, contexts)...) tx = map(ty) do dy @@ -53,7 +54,7 @@ function DI.value_and_pullback( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -68,7 +69,7 @@ function DI.pullback( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; pb) = prep tx = map(ty) do dy @@ -80,28 +81,20 @@ end ## Gradient function DI.prepare_gradient( - f, ::AutoTracker, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoGradientPrep() end function DI.value_and_gradient( - f, - ::DI.NoGradientPrep, - ::AutoTracker, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} (; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return val, data(first(grad)) end function DI.gradient( - f, - ::DI.NoGradientPrep, - ::AutoTracker, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} (; grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return data(first(grad)) @@ -113,7 +106,7 @@ function DI.value_and_gradient!( prep::DI.NoGradientPrep, backend::AutoTracker, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) @@ -125,7 +118,7 @@ function DI.gradient!( prep::DI.NoGradientPrep, backend::AutoTracker, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 7f5b004ea..307804f87 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -26,6 +26,7 @@ check_nothing(::Any, f, x, contexts) = nothing DI.check_available(::AutoZygote) = true DI.check_inplace(::AutoZygote) = false +DI.check_operator_overloading(::AutoZygote) = false ## Pullback @@ -35,7 +36,7 @@ struct ZygotePullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep end function DI.prepare_pullback( - f, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoPullbackPrep() end @@ -46,7 +47,7 @@ function DI.prepare_pullback_same_point( ::AutoZygote, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, pb = pullback(f, x, map(DI.unwrap, contexts)...) return ZygotePullbackPrepSamePoint(y, pb) @@ -58,7 +59,7 @@ function DI.value_and_pullback( ::AutoZygote, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, pb = pullback(f, x, map(DI.unwrap, contexts)...) tx = map(ty) do dy @@ -74,7 +75,7 @@ function DI.value_and_pullback( ::AutoZygote, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -90,7 +91,7 @@ function DI.pullback( ::AutoZygote, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; pb) = prep tx = map(ty) do dy @@ -103,17 +104,13 @@ end ## Gradient function DI.prepare_gradient( - f, ::AutoZygote, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoGradientPrep() end function DI.value_and_gradient( - f, - ::DI.NoGradientPrep, - ::AutoZygote, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::DI.NoGradientPrep, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} (; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) check_nothing(first(grad), f, x, contexts) @@ -121,11 +118,7 @@ function DI.value_and_gradient( end function DI.gradient( - f, - ::DI.NoGradientPrep, - ::AutoZygote, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::DI.NoGradientPrep, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} grad = gradient(f, x, map(DI.unwrap, contexts)...) check_nothing(first(grad), f, x, contexts) @@ -138,7 +131,7 @@ function DI.value_and_gradient!( prep::DI.NoGradientPrep, backend::AutoZygote, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) @@ -150,7 +143,7 @@ function DI.gradient!( prep::DI.NoGradientPrep, backend::AutoZygote, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end @@ -158,17 +151,13 @@ end ## Jacobian function DI.prepare_jacobian( - f, ::AutoZygote, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoJacobianPrep() end function DI.value_and_jacobian( - f, - ::DI.NoJacobianPrep, - ::AutoZygote, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::DI.NoJacobianPrep, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} y = f(x, map(DI.unwrap, contexts)...) # https://github.com/FluxML/Zygote.jl/issues/1506 @@ -178,11 +167,7 @@ function DI.value_and_jacobian( end function DI.jacobian( - f, - ::DI.NoJacobianPrep, - ::AutoZygote, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::DI.NoJacobianPrep, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} jac = jacobian(f, x, map(DI.unwrap, contexts)...) check_nothing(first(jac), f, x, contexts) @@ -195,7 +180,7 @@ function DI.value_and_jacobian!( prep::DI.NoJacobianPrep, backend::AutoZygote, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...) return y, copyto!(jac, new_jac) @@ -207,7 +192,7 @@ function DI.jacobian!( prep::DI.NoJacobianPrep, backend::AutoZygote, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...)) end @@ -217,11 +202,7 @@ end # Beware, this uses ForwardDiff for the inner differentiation function DI.prepare_hvp( - f, - backend::AutoZygote, - x, - tx::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.prepare_hvp(f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) end @@ -232,7 +213,7 @@ function DI.hvp( backend::AutoZygote, x, tx::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} return DI.hvp(f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) end @@ -244,7 +225,7 @@ function DI.hvp!( backend::AutoZygote, x, tx::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} return DI.hvp!( f, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -257,7 +238,7 @@ function DI.gradient_and_hvp( backend::AutoZygote, x, tx::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} return DI.gradient_and_hvp( f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -272,7 +253,7 @@ function DI.gradient_and_hvp!( backend::AutoZygote, x, tx::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} return DI.gradient_and_hvp!( f, grad, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -282,17 +263,13 @@ end ## Hessian function DI.prepare_hessian( - f, ::AutoZygote, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoHessianPrep() end function DI.hessian( - f, - ::DI.NoHessianPrep, - ::AutoZygote, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::DI.NoHessianPrep, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} fc = DI.with_contexts(f, contexts...) hess = hessian(fc, x) @@ -306,7 +283,7 @@ function DI.hessian!( prep::DI.NoHessianPrep, backend::AutoZygote, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} return copyto!(hess, DI.hessian(f, prep, backend, x, contexts...)) end @@ -316,7 +293,7 @@ function DI.value_gradient_and_hessian( prep::DI.NoHessianPrep, backend::AutoZygote, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, grad = DI.value_and_gradient(f, DI.NoGradientPrep(), backend, x, contexts...) hess = DI.hessian(f, prep, backend, x, contexts...) @@ -330,7 +307,7 @@ function DI.value_gradient_and_hessian!( prep::DI.NoHessianPrep, backend::AutoZygote, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, _ = DI.value_and_gradient!(f, grad, DI.NoGradientPrep(), backend, x, contexts...) DI.hessian!(f, hess, prep, backend, x, contexts...) diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index f8d1f6ca4..59e23aaef 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -57,6 +57,7 @@ include("second_order/hessian.jl") include("fallbacks/no_prep.jl") include("fallbacks/change_prep.jl") +include("fallbacks/input.jl") include("misc/differentiate_with.jl") include("misc/from_primitive.jl") diff --git a/DifferentiationInterface/src/fallbacks/input.jl b/DifferentiationInterface/src/fallbacks/input.jl new file mode 100644 index 000000000..61237cce9 --- /dev/null +++ b/DifferentiationInterface/src/fallbacks/input.jl @@ -0,0 +1,70 @@ +function error_if_overloading(backend) + if check_operator_overloading(backend) + throw( + ArgumentError( + "The current backend is based on operator overloading, a custom method for `overloaded_input_type` is therefore necessary. Please open an issue on DifferentiationInterface.jl if you encounter this error.", + ), + ) + end +end + +for op in [ + :derivative, + :gradient, + :jacobian, + :second_derivative, + :hessian, + :pushforward, + :pullback, + :hvp, +] + if op in (:derivative, :jacobian, :gradient) + @eval function overloaded_input_type( + ::typeof($op), f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + error_if_overloading(backend) + return typeof(x) + end + op == :gradient && continue + @eval function overloaded_input_type( + ::typeof($op), f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + error_if_overloading(backend) + return typeof(x) + end + + elseif op in (:second_derivative, :hessian) + @eval function overloaded_input_type( + ::typeof($op), f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} + ) where {F,C} + error_if_overloading(backend) + return typeof(x) + end + + elseif op in (:pushforward, :pullback, :hvp) + @eval function overloaded_input_type( + ::typeof($op), + f::F, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}, + ) where {F,C} + error_if_overloading(backend) + return typeof(x) + end + op == :hvp && continue + @eval function overloaded_input_type( + ::typeof($op), + f!::F, + y, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}, + ) where {F,C} + error_if_overloading(backend) + return typeof(x) + end + end +end diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index 5645443fd..d3878c638 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -2,6 +2,9 @@ abstract type FromPrimitive <: AbstractADType end check_available(fromprim::FromPrimitive) = check_available(fromprim.backend) check_inplace(fromprim::FromPrimitive) = check_inplace(fromprim.backend) +function check_operator_overloading(fromprim::FromPrimitive) + return check_operator_overloading(fromprim.backend) +end function pick_batchsize(fromprim::FromPrimitive, N::Integer) return pick_batchsize(fromprim.backend, N) diff --git a/DifferentiationInterface/src/misc/simple_finite_diff.jl b/DifferentiationInterface/src/misc/simple_finite_diff.jl index 6bdbd0e84..74f1b80f2 100644 --- a/DifferentiationInterface/src/misc/simple_finite_diff.jl +++ b/DifferentiationInterface/src/misc/simple_finite_diff.jl @@ -17,6 +17,7 @@ end ADTypes.mode(::AutoSimpleFiniteDiff) = ForwardMode() check_available(::AutoSimpleFiniteDiff) = true +check_operator_overloading(::AutoSimpleFiniteDiff) = false function pick_batchsize(::AutoSimpleFiniteDiff{nothing}, N::Integer) B = reasonable_batchsize(N, 12) diff --git a/DifferentiationInterface/src/misc/zero_backends.jl b/DifferentiationInterface/src/misc/zero_backends.jl index 4c253cfab..41a2d285e 100644 --- a/DifferentiationInterface/src/misc/zero_backends.jl +++ b/DifferentiationInterface/src/misc/zero_backends.jl @@ -18,6 +18,7 @@ struct AutoZeroForward <: AbstractADType end ADTypes.mode(::AutoZeroForward) = ForwardMode() check_available(::AutoZeroForward) = true +check_operator_overloading(::AutoZeroForward) = false function prepare_pushforward( f::F, ::AutoZeroForward, x, tx::NTuple, contexts::Vararg{Context,C} @@ -103,6 +104,7 @@ struct AutoZeroReverse <: AbstractADType end ADTypes.mode(::AutoZeroReverse) = ReverseMode() check_available(::AutoZeroReverse) = true +check_operator_overloading(::AutoZeroReverse) = false function prepare_pullback( f::F, ::AutoZeroReverse, x, ty::NTuple, contexts::Vararg{Context,C} diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index fd0101ed1..b1e7276be 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -74,17 +74,15 @@ function prepare_hvp( return _prepare_hvp_aux(hvp_mode(backend), f, backend, x, tx, contexts...) end -## Forward over forward +## Forward over anything -struct ForwardOverForwardHVPPrep{G,E2<:PushforwardPrep,E2IP} <: HVPPrep - grad_buffer::G - # pushforward of many pushforwards in theory, but pushforward of gradient in practice +struct ForwardOverAnythingHVPPrep{E1<:GradientPrep,E2<:PushforwardPrep} <: HVPPrep + inner_gradient_prep::E1 outer_pushforward_prep::E2 - outer_pushforward_prep_inplace::E2IP end function _prepare_hvp_aux( - ::ForwardOverForward, + ::Union{ForwardOverForward,ForwardOverReverse}, f::F, backend::AbstractADType, x, @@ -92,188 +90,47 @@ function _prepare_hvp_aux( contexts::Vararg{Context,C}, ) where {F,C} rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) - grad_buffer = similar(x) - outer_pushforward_prep_inplace = if check_inplace(outer(backend)) - prepare_pushforward( - shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... - ) - else - nothing - end - outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts... - ) - return ForwardOverForwardHVPPrep( - grad_buffer, outer_pushforward_prep, outer_pushforward_prep_inplace - ) -end - -function hvp( - f::F, - prep::ForwardOverForwardHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - (; outer_pushforward_prep) = prep - rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) - return pushforward( - shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... - ) -end - -function hvp!( - f::F, - tg::NTuple, - prep::ForwardOverForwardHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - (; grad_buffer, outer_pushforward_prep, outer_pushforward_prep_inplace) = prep - rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) - if check_inplace(outer(backend)) - return pushforward!( - shuffled_gradient!, - grad_buffer, - tg, - outer_pushforward_prep_inplace, - outer(backend), - x, - tx, - new_contexts..., - ) - else - return pushforward!( - shuffled_gradient, - tg, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - end -end - -function gradient_and_hvp( - f::F, - prep::ForwardOverForwardHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - (; outer_pushforward_prep) = prep - rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + new_contexts_unknown_prep = ( + FunctionContext(f), + UnknownContext(), # placeholder for inner_gradient_prep + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) - return value_and_pushforward( - shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... + XO = overloaded_input_type( + pushforward, shuffled_gradient, outer(backend), x, tx, new_contexts_unknown_prep... ) -end - -function gradient_and_hvp!( - f::F, - grad, - tg::NTuple, - prep::ForwardOverForwardHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - (; outer_pushforward_prep, outer_pushforward_prep_inplace) = prep - rewrap = Rewrap(contexts...) + xo = XO(x) + inner_gradient_prep = prepare_gradient(f, inner(backend), xo, contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) - if check_inplace(outer(backend)) - return value_and_pushforward!( - shuffled_gradient!, - grad, - tg, - outer_pushforward_prep_inplace, - outer(backend), - x, - tx, - new_contexts..., - ) - else - new_grad, _ = value_and_pushforward!( - shuffled_gradient, - tg, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - return copyto!(grad, new_grad), tg - end -end - -## Forward over reverse - -struct ForwardOverReverseHVPPrep{G,E2<:PushforwardPrep,E2IP} <: HVPPrep - grad_buffer::G - # pushforward of gradient - outer_pushforward_prep::E2 - outer_pushforward_prep_inplace::E2IP -end - -function _prepare_hvp_aux( - ::ForwardOverReverse, - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) - grad_buffer = similar(x) - outer_pushforward_prep_inplace = if check_inplace(outer(backend)) - prepare_pushforward( - shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... - ) - else - nothing - end outer_pushforward_prep = prepare_pushforward( shuffled_gradient, outer(backend), x, tx, new_contexts... ) - return ForwardOverReverseHVPPrep( - grad_buffer, outer_pushforward_prep, outer_pushforward_prep_inplace - ) + return ForwardOverAnythingHVPPrep(inner_gradient_prep, outer_pushforward_prep) end function hvp( f::F, - prep::ForwardOverReverseHVPPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) return pushforward( shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... @@ -283,53 +140,48 @@ end function hvp!( f::F, tg::NTuple, - prep::ForwardOverReverseHVPPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; grad_buffer, outer_pushforward_prep, outer_pushforward_prep_inplace) = prep + (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + return pushforward!( + shuffled_gradient, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., ) - if check_inplace(outer(backend)) - return pushforward!( - shuffled_gradient!, - grad_buffer, - tg, - outer_pushforward_prep_inplace, - outer(backend), - x, - tx, - new_contexts..., - ) - else - return pushforward!( - shuffled_gradient, - tg, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - end end function gradient_and_hvp( f::F, - prep::ForwardOverReverseHVPPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) return value_and_pushforward( shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... @@ -340,46 +192,36 @@ function gradient_and_hvp!( f::F, grad, tg::NTuple, - prep::ForwardOverReverseHVPPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep, outer_pushforward_prep_inplace) = prep + (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) - if check_inplace(outer(backend)) - return value_and_pushforward!( - shuffled_gradient!, - grad, - tg, - outer_pushforward_prep_inplace, - outer(backend), - x, - tx, - new_contexts..., - ) - else - new_grad, _ = value_and_pushforward!( - shuffled_gradient, - tg, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - return copyto!(grad, new_grad), tg - end + new_grad, _ = value_and_pushforward!( + shuffled_gradient, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + return copyto!(grad, new_grad), tg end ## Reverse over forward struct ReverseOverForwardHVPPrep{E2<:GradientPrep,E1<:GradientPrep} <: HVPPrep - # gradient of pushforward outer_gradient_prep::E2 gradient_prep::E1 end @@ -491,11 +333,9 @@ end ## Reverse over reverse -struct ReverseOverReverseHVPPrep{G,E2<:PullbackPrep,E2IP} <: HVPPrep - grad_buffer::G - # pullback of gradient +struct ReverseOverReverseHVPPrep{E1<:GradientPrep,E2<:PullbackPrep} <: HVPPrep + inner_gradient_prep::E1 outer_pullback_prep::E2 - outer_pullback_prep_inplace::E2IP end function _prepare_hvp_aux( @@ -507,23 +347,29 @@ function _prepare_hvp_aux( contexts::Vararg{Context,C}, ) where {F,C} rewrap = Rewrap(contexts...) + new_contexts_unknown_prep = ( + FunctionContext(f), + UnknownContext(), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + XO = overloaded_input_type( + pullback, shuffled_gradient, backend, x, tx, new_contexts_unknown_prep... + ) + xo = XO(x) + inner_gradient_prep = prepare_gradient(f, backend, xo, contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) - grad_buffer = similar(x) - outer_pullback_prep_inplace = if check_inplace(outer(backend)) - prepare_pullback( - shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... - ) - else - nothing - end outer_pullback_prep = prepare_pullback( shuffled_gradient, outer(backend), x, tx, new_contexts... ) - return ReverseOverReverseHVPPrep( - grad_buffer, outer_pullback_prep, outer_pullback_prep_inplace - ) + return ReverseOverReverseHVPPrep(inner_gradient_prep, outer_pullback_prep) end function hvp( @@ -534,10 +380,14 @@ function hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pullback_prep) = prep + (; inner_gradient_prep, outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) return pullback( shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... @@ -553,33 +403,18 @@ function hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; grad_buffer, outer_pullback_prep, outer_pullback_prep_inplace) = prep + (; inner_gradient_prep, outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + return pullback!( + shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) - if check_inplace(outer(backend)) - return pullback!( - shuffled_gradient!, - grad_buffer, - tg, - outer_pullback_prep_inplace, - outer(backend), - x, - tx, - new_contexts..., - ) - else - return pullback!( - shuffled_gradient, - tg, - outer_pullback_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - end end function gradient_and_hvp( @@ -590,10 +425,14 @@ function gradient_and_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pullback_prep) = prep + (; inner_gradient_prep, outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) return value_and_pullback( shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... @@ -610,32 +449,17 @@ function gradient_and_hvp!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pullback_prep, outer_pullback_prep_inplace) = prep + (; inner_gradient_prep, outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) - if check_inplace(outer(backend)) - return value_and_pullback!( - shuffled_gradient!, - grad, - tg, - outer_pullback_prep_inplace, - outer(backend), - x, - tx, - new_contexts..., - ) - else - new_grad, _ = value_and_pullback!( - shuffled_gradient, - tg, - outer_pullback_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - return copyto!(grad, new_grad), tg - end + new_grad, _ = value_and_pullback!( + shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... + ) + return copyto!(grad, new_grad), tg end diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 15a9d4a0e..178d83247 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -1,12 +1,3 @@ -struct FixTail{F,A<:Tuple} - f::F - tail_args::A -end - -function (ft::FixTail)(args::Vararg{Any,N}) where {N} - return ft.f(args..., ft.tail_args...) -end - """ Context @@ -22,12 +13,15 @@ abstract type Context end unwrap(c::Context) = c.data Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2) +abstract type GeneralizedConstant <: Context end +abstract type GeneralizedCache <: Context end + ## Public contexts """ Constant -Concrete type of [`Context`](@ref) argument which is kept constant during differentiation. +Concrete subtype of [`Context`](@ref) argument which is kept constant during differentiation. Note that an operator can be prepared with an arbitrary value of the constant. However, same-point preparation must occur with the exact value that will be reused later. @@ -55,7 +49,7 @@ julia> gradient(f, AutoForwardDiff(), [1.0, 2.0], Constant(100)) 400.0 ``` """ -struct Constant{T} <: Context +struct Constant{T} <: GeneralizedConstant data::T end @@ -65,7 +59,7 @@ maker(::Constant) = constant_maker """ Cache -Concrete type of [`Context`](@ref) argument which can be mutated with active values during differentiation. +Concrete subtype of [`Context`](@ref) argument which can be mutated with active values during differentiation. The initial values present inside the cache do not matter. @@ -81,15 +75,47 @@ maker(::Cache) = cache_maker ## Internal contexts for passing stuff around -struct FunctionContext{T} <: Context +""" + FunctionContext + +Concrete subtype of [`Context`](@ref) argument designed to contain functions, for internal use only. + +It is mostly similar to [`Constant`](@ref). +""" +struct FunctionContext{T} <: GeneralizedConstant + data::T +end + +""" + BackendContext + +Concrete subtype of [`Context`](@ref) argument designed to contain backend objects, for internal use only. + +It is mostly similar to [`Constant`](@ref). +""" +struct BackendContext{T<:AbstractADType} <: GeneralizedConstant data::T end -struct BackendContext{T} <: Context +""" + PrepContext <: Context + +Concrete subtype of [`Context`](@ref) argument designed to contain preparation objects, for internal use only. + +It is mostly similar to [`Cache`](@ref). +""" +struct PrepContext{T<:Prep} <: GeneralizedCache data::T end -const ConstantOrFunctionOrBackend = Union{Constant,FunctionContext,BackendContext} +""" + UnknownContext <: Context + +Concrete subtype of [`Context`](@ref) argument designed as a placeholder for when a future context value is not yet known, for internal use only. + +It is relevant in second-order preparation. +""" +struct UnknownContext <: Context end ## Context manipulation @@ -109,6 +135,15 @@ function (r::Rewrap{C,T})(unannotated_contexts::Vararg{Any,C}) where {C,T} end end +struct FixTail{F,A<:Tuple} + f::F + tail_args::A +end + +function (ft::FixTail)(args::Vararg{Any,N}) where {N} + return ft.f(args..., ft.tail_args...) +end + with_contexts(f) = f function with_contexts(f::F, contexts::Vararg{Context,N}) where {F,N} diff --git a/DifferentiationInterface/src/utils/traits.jl b/DifferentiationInterface/src/utils/traits.jl index 797dbb0dd..9f89e1953 100644 --- a/DifferentiationInterface/src/utils/traits.jl +++ b/DifferentiationInterface/src/utils/traits.jl @@ -40,6 +40,17 @@ function check_inplace(backend::MixedMode) check_inplace(reverse_backend(backend)) end +## Operator overloading + +""" + check_operator_overloading(backend) + +Check whether backend relies on operator overloading. + +Returns `true` or `false` in a statically predictable way. +""" +function check_operator_overloading end + ## Pushforward abstract type PushforwardPerformance end From 0acc1116a18bb57e753e9916f2e6d630f39dbc91 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 15:39:33 +0100 Subject: [PATCH 05/24] ReverseDiff fix --- .../misc.jl | 52 +++++++++++-------- .../utils.jl | 19 +++++-- .../jacobian.jl | 4 +- .../src/DifferentiationInterface.jl | 1 - .../src/fallbacks/input.jl | 32 ++++++++---- .../src/misc/overloading.jl | 9 ---- .../src/second_order/hvp.jl | 8 +-- .../test/Back/ForwardDiff/test.jl | 14 ++--- .../test/Back/ReverseDiff/test.jl | 8 +-- 9 files changed, 84 insertions(+), 63 deletions(-) delete mode 100644 DifferentiationInterface/src/misc/overloading.jl diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl index cb25dcd56..85c208301 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl @@ -1,6 +1,6 @@ ## Pushforward -function DI.overloaded_input_type( +function DI.overloaded_input_example( ::typeof(DI.pushforward), f::F, backend::AutoForwardDiff, @@ -8,10 +8,12 @@ function DI.overloaded_input_type( tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - return DI.overloaded_input_type(DI.prepare_pushforward(f, backend, x, tx, contexts...)) + return DI.overloaded_input_example( + DI.prepare_pushforward(f, backend, x, tx, contexts...) + ) end -function DI.overloaded_input_type( +function DI.overloaded_input_example( ::typeof(DI.pushforward), f!::F, y, @@ -20,27 +22,27 @@ function DI.overloaded_input_type( tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - return DI.overloaded_input_type( + return DI.overloaded_input_example( DI.prepare_pushforward(f!, y, backend, x, tx, contexts...) ) end -DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep) = typeof(prep.xdual_tmp) -DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep) = typeof(prep.xdual_tmp) +DI.overloaded_input_example(prep::ForwardDiffOneArgPushforwardPrep) = copy(prep.xdual_tmp) +DI.overloaded_input_example(prep::ForwardDiffTwoArgPushforwardPrep) = copy(prep.xdual_tmp) ## Derivative -function DI.overloaded_input_type( +function DI.overloaded_input_example( ::typeof(DI.derivative), f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {F,C} - return DI.overloaded_input_type(DI.prepare_derivative(f, backend, x, contexts...)) + return DI.overloaded_input_example(DI.prepare_derivative(f, backend, x, contexts...)) end -function DI.overloaded_input_type( +function DI.overloaded_input_example( ::typeof(DI.derivative), f!::F, y, @@ -48,33 +50,35 @@ function DI.overloaded_input_type( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - return DI.overloaded_input_type(DI.prepare_derivative(f!, y, backend, x, contexts...)) + return DI.overloaded_input_example( + DI.prepare_derivative(f!, y, backend, x, contexts...) + ) end -function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep) - return DI.overloaded_input_type(prep.pushforward_prep) +function DI.overloaded_input_example(prep::ForwardDiffOneArgDerivativePrep) + return DI.overloaded_input_example(prep.pushforward_prep) end -DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep) = typeof(prep.config.duals) +DI.overloaded_input_example(prep::ForwardDiffTwoArgDerivativePrep) = copy(prep.config.duals) ## Gradient -function DI.overloaded_input_type( +function DI.overloaded_input_example( ::typeof(DI.gradient), f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {F,C} - return DI.overloaded_input_type(DI.prepare_gradient(f, backend, x, contexts...)) + return DI.overloaded_input_example(DI.prepare_gradient(f, backend, x, contexts...)) end -DI.overloaded_input_type(prep::ForwardDiffGradientPrep) = typeof(prep.config.duals) +DI.overloaded_input_example(prep::ForwardDiffGradientPrep) = copy(prep.config.duals) ## Jacobian -function DI.overloaded_input_type( +function DI.overloaded_input_example( ::typeof(DI.jacobian), f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {F,C} - return DI.overloaded_input_type(DI.prepare_jacobian(f, backend, x, contexts...)) + return DI.overloaded_input_example(DI.prepare_jacobian(f, backend, x, contexts...)) end -function DI.overloaded_input_type( +function DI.overloaded_input_example( ::typeof(DI.jacobian), f!::F, y, @@ -82,8 +86,12 @@ function DI.overloaded_input_type( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - return DI.overloaded_input_type(DI.prepare_jacobian(f!, y, backend, x, contexts...)) + return DI.overloaded_input_example(DI.prepare_jacobian(f!, y, backend, x, contexts...)) end -DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(prep.config.duals[2]) -DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) = typeof(prep.config.duals[2]) +function DI.overloaded_input_example(prep::ForwardDiffOneArgJacobianPrep) + return copy(prep.config.duals[2]) +end +function DI.overloaded_input_example(prep::ForwardDiffTwoArgJacobianPrep) + return copy(prep.config.duals[2]) +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl index 2631469fc..dccbb547e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl @@ -1,6 +1,19 @@ +## Pullback + +function DI.overloaded_input_example( + ::typeof(DI.pullback), + f, + ::AutoReverseDiff, + x::AbstractArray, + ty::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} + return copy(ReverseDiff.track(x)) +end + ## Gradient -DI.overloaded_input_type(prep::ReverseDiffGradientPrep) = typeof(prep.config.input) +DI.overloaded_input_example(prep::ReverseDiffGradientPrep) = copy(prep.config.input) ## Jacobian -DI.overloaded_input_type(prep::ReverseDiffOneArgJacobianPrep) = typeof(prep.config.input) -DI.overloaded_input_type(prep::ReverseDiffTwoArgJacobianPrep) = typeof(prep.config.input) +DI.overloaded_input_example(prep::ReverseDiffOneArgJacobianPrep) = copy(prep.config.input) +DI.overloaded_input_example(prep::ReverseDiffTwoArgJacobianPrep) = copy(prep.config.input) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 17928d44c..f8a32b06c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -335,6 +335,6 @@ end ## Operator overloading -function DI.overloaded_input_type(prep::PushforwardSparseJacobianPrep) - return DI.overloaded_input_type(prep.pushforward_prep) +function DI.overloaded_input_example(prep::PushforwardSparseJacobianPrep) + return DI.overloaded_input_example(prep.pushforward_prep) end diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 59e23aaef..57a11bf78 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -64,7 +64,6 @@ include("misc/from_primitive.jl") include("misc/sparsity_detector.jl") include("misc/simple_finite_diff.jl") include("misc/zero_backends.jl") -include("misc/overloading.jl") ## Exported diff --git a/DifferentiationInterface/src/fallbacks/input.jl b/DifferentiationInterface/src/fallbacks/input.jl index 61237cce9..c4a821a4e 100644 --- a/DifferentiationInterface/src/fallbacks/input.jl +++ b/DifferentiationInterface/src/fallbacks/input.jl @@ -1,8 +1,18 @@ +""" + overloaded_input_example(prep) + +If it exists, return an example of overloaded input which will be passed to the differentiated function when preparation result `prep` is reused. + +!!! danger + This function is experimental and not part of the public API. +""" +function overloaded_input_example end + function error_if_overloading(backend) if check_operator_overloading(backend) throw( ArgumentError( - "The current backend is based on operator overloading, a custom method for `overloaded_input_type` is therefore necessary. Please open an issue on DifferentiationInterface.jl if you encounter this error.", + "The current backend is based on operator overloading, a custom method for `overloaded_input_example` is therefore necessary. Please open an issue on DifferentiationInterface.jl if you encounter this error.", ), ) end @@ -19,30 +29,30 @@ for op in [ :hvp, ] if op in (:derivative, :jacobian, :gradient) - @eval function overloaded_input_type( + @eval function overloaded_input_example( ::typeof($op), f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} error_if_overloading(backend) - return typeof(x) + return copy(x) end op == :gradient && continue - @eval function overloaded_input_type( + @eval function overloaded_input_example( ::typeof($op), f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} error_if_overloading(backend) - return typeof(x) + return copy(x) end elseif op in (:second_derivative, :hessian) - @eval function overloaded_input_type( + @eval function overloaded_input_example( ::typeof($op), f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} error_if_overloading(backend) - return typeof(x) + return copy(x) end elseif op in (:pushforward, :pullback, :hvp) - @eval function overloaded_input_type( + @eval function overloaded_input_example( ::typeof($op), f::F, backend::AbstractADType, @@ -51,10 +61,10 @@ for op in [ contexts::Vararg{Context,C}, ) where {F,C} error_if_overloading(backend) - return typeof(x) + return copy(x) end op == :hvp && continue - @eval function overloaded_input_type( + @eval function overloaded_input_example( ::typeof($op), f!::F, y, @@ -64,7 +74,7 @@ for op in [ contexts::Vararg{Context,C}, ) where {F,C} error_if_overloading(backend) - return typeof(x) + return copy(x) end end end diff --git a/DifferentiationInterface/src/misc/overloading.jl b/DifferentiationInterface/src/misc/overloading.jl deleted file mode 100644 index bda61a192..000000000 --- a/DifferentiationInterface/src/misc/overloading.jl +++ /dev/null @@ -1,9 +0,0 @@ -""" - overloaded_input_type(prep) - -If it exists, return the overloaded input type which will be passed to the differentiated function when preparation result `prep` is reused. - -!!! danger - This function is experimental and not part of the public API. -""" -function overloaded_input_type end diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index b1e7276be..cb4b1877b 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -97,10 +97,10 @@ function _prepare_hvp_aux( Constant(rewrap), contexts..., ) - XO = overloaded_input_type( + xo = overloaded_input_example( pushforward, shuffled_gradient, outer(backend), x, tx, new_contexts_unknown_prep... ) - xo = XO(x) + copyto!(xo, x) inner_gradient_prep = prepare_gradient(f, inner(backend), xo, contexts...) new_contexts = ( FunctionContext(f), @@ -354,10 +354,10 @@ function _prepare_hvp_aux( Constant(rewrap), contexts..., ) - XO = overloaded_input_type( + xo = overloaded_input_example( pullback, shuffled_gradient, backend, x, tx, new_contexts_unknown_prep... ) - xo = XO(x) + copyto!(xo, x) inner_gradient_prep = prepare_gradient(f, backend, xo, contexts...) new_contexts = ( FunctionContext(f), diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 061836a81..d83d8bd2d 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -112,23 +112,23 @@ end; # Derivative x = 1.0 y = [1.0, 1.0] - @test DI.overloaded_input_type(prepare_derivative(copy, backend, x)) == + @test DI.overloaded_input_example(prepare_derivative(copy, backend, x)) isa ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy),Float64},Float64,1} - @test DI.overloaded_input_type(prepare_derivative(copyto!, y, backend, x)) == + @test DI.overloaded_input_example(prepare_derivative(copyto!, y, backend, x)) isa Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,1}} # Gradient x = [1.0, 1.0] - @test DI.overloaded_input_type(prepare_gradient(sum, backend, x)) == + @test DI.overloaded_input_example(prepare_gradient(sum, backend, x)) isa Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(sum),Float64},Float64,2}} # Jacobian x = [1.0, 0.0, 0.0] - @test DI.overloaded_input_type(prepare_jacobian(copy, backend, x)) == + @test DI.overloaded_input_example(prepare_jacobian(copy, backend, x)) isa ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy),Float64},Float64,3} - @test DI.overloaded_input_type(prepare_jacobian(copyto!, similar(x), backend, x)) == + @test DI.overloaded_input_example(prepare_jacobian(copyto!, similar(x), backend, x)) isa Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,3}} - @test DI.overloaded_input_type( + @test DI.overloaded_input_example( prepare_jacobian(copyto!, similar(x), sparse_backend, x) - ) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,1}} + ) isa Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,1}} end; diff --git a/DifferentiationInterface/test/Back/ReverseDiff/test.jl b/DifferentiationInterface/test/Back/ReverseDiff/test.jl index 8549de0e6..1cebc7ae7 100644 --- a/DifferentiationInterface/test/Back/ReverseDiff/test.jl +++ b/DifferentiationInterface/test/Back/ReverseDiff/test.jl @@ -45,17 +45,17 @@ test_differentiation( # Derivative x = 1.0 - @test_skip DI.overloaded_input_type(prepare_derivative(copy, backend, x)) == + @test_skip DI.overloaded_input_example(prepare_derivative(copy, backend, x)) isa ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} # Gradient x = [1.0; 0.0; 0.0] - @test DI.overloaded_input_type(prepare_gradient(sum, backend, x)) == + @test DI.overloaded_input_example(prepare_gradient(sum, backend, x)) isa ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} # Jacobian - @test DI.overloaded_input_type(prepare_jacobian(copy, backend, x)) == + @test DI.overloaded_input_example(prepare_jacobian(copy, backend, x)) isa ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} - @test DI.overloaded_input_type(prepare_jacobian(copyto!, similar(x), backend, x)) == + @test DI.overloaded_input_example(prepare_jacobian(copyto!, similar(x), backend, x)) isa ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} end; From fc0b8b9f57a74fdf4e7322aa30271f050409af6e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 16:07:37 +0100 Subject: [PATCH 06/24] Separate overloaded_input and overloaded_input_type --- .../misc.jl | 86 +++++-------------- .../utils.jl | 20 ++++- .../jacobian.jl | 4 +- .../src/fallbacks/input.jl | 20 +++-- .../src/second_order/hvp.jl | 6 +- .../test/Back/ForwardDiff/test.jl | 15 ++-- .../test/Back/ReverseDiff/test.jl | 8 +- .../test/Back/SparsityDetector/test.jl | 3 + 8 files changed, 68 insertions(+), 94 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl index 85c208301..e7793391e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl @@ -1,6 +1,6 @@ ## Pushforward -function DI.overloaded_input_example( +function DI.overloaded_input( ::typeof(DI.pushforward), f::F, backend::AutoForwardDiff, @@ -8,12 +8,12 @@ function DI.overloaded_input_example( tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - return DI.overloaded_input_example( - DI.prepare_pushforward(f, backend, x, tx, contexts...) - ) + T = tag_type(f, backend, x) + xdual = make_dual(T, x, tx) + return xdual end -function DI.overloaded_input_example( +function DI.overloaded_input( ::typeof(DI.pushforward), f!::F, y, @@ -22,76 +22,36 @@ function DI.overloaded_input_example( tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - return DI.overloaded_input_example( - DI.prepare_pushforward(f!, y, backend, x, tx, contexts...) - ) + T = tag_type(f, backend, x) + xdual = if x isa Number + make_dual(T, x, tx) + else + make_dual_similar(T, x, tx) + end + return xdual end -DI.overloaded_input_example(prep::ForwardDiffOneArgPushforwardPrep) = copy(prep.xdual_tmp) -DI.overloaded_input_example(prep::ForwardDiffTwoArgPushforwardPrep) = copy(prep.xdual_tmp) +DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep) = typeof(prep.xdual_tmp) +DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep) = typeof(prep.xdual_tmp) ## Derivative -function DI.overloaded_input_example( - ::typeof(DI.derivative), - f::F, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} - return DI.overloaded_input_example(DI.prepare_derivative(f, backend, x, contexts...)) -end - -function DI.overloaded_input_example( - ::typeof(DI.derivative), - f!::F, - y, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} - return DI.overloaded_input_example( - DI.prepare_derivative(f!, y, backend, x, contexts...) - ) +function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep) + return DI.overloaded_input(prep.pushforward_prep) end - -function DI.overloaded_input_example(prep::ForwardDiffOneArgDerivativePrep) - return DI.overloaded_input_example(prep.pushforward_prep) +function DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep) + return typeof(prep.config.duals) end -DI.overloaded_input_example(prep::ForwardDiffTwoArgDerivativePrep) = copy(prep.config.duals) ## Gradient -function DI.overloaded_input_example( - ::typeof(DI.gradient), f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} -) where {F,C} - return DI.overloaded_input_example(DI.prepare_gradient(f, backend, x, contexts...)) -end - -DI.overloaded_input_example(prep::ForwardDiffGradientPrep) = copy(prep.config.duals) +DI.overloaded_input_type(prep::ForwardDiffGradientPrep) = typeof(prep.config.duals) ## Jacobian -function DI.overloaded_input_example( - ::typeof(DI.jacobian), f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} -) where {F,C} - return DI.overloaded_input_example(DI.prepare_jacobian(f, backend, x, contexts...)) -end - -function DI.overloaded_input_example( - ::typeof(DI.jacobian), - f!::F, - y, - backend::AutoForwardDiff, - x, - contexts::Vararg{DI.Context,C}, -) where {F,C} - return DI.overloaded_input_example(DI.prepare_jacobian(f!, y, backend, x, contexts...)) -end - -function DI.overloaded_input_example(prep::ForwardDiffOneArgJacobianPrep) - return copy(prep.config.duals[2]) +function DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) + return typeof(prep.config.duals[2]) end -function DI.overloaded_input_example(prep::ForwardDiffTwoArgJacobianPrep) - return copy(prep.config.duals[2]) +function DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) + return typeof(prep.config.duals[2]) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl index dccbb547e..9ea057974 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl @@ -1,6 +1,6 @@ ## Pullback -function DI.overloaded_input_example( +function DI.overloaded_input( ::typeof(DI.pullback), f, ::AutoReverseDiff, @@ -11,9 +11,21 @@ function DI.overloaded_input_example( return copy(ReverseDiff.track(x)) end +function DI.overloaded_input( + ::typeof(DI.pullback), + f!, + y, + ::AutoReverseDiff, + x::AbstractArray, + ty::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} + return copy(ReverseDiff.track(x)) +end + ## Gradient -DI.overloaded_input_example(prep::ReverseDiffGradientPrep) = copy(prep.config.input) +DI.overloaded_input_type(prep::ReverseDiffGradientPrep) = typeof(prep.config.input) ## Jacobian -DI.overloaded_input_example(prep::ReverseDiffOneArgJacobianPrep) = copy(prep.config.input) -DI.overloaded_input_example(prep::ReverseDiffTwoArgJacobianPrep) = copy(prep.config.input) +DI.overloaded_input_type(prep::ReverseDiffOneArgJacobianPrep) = typeof(prep.config.input) +DI.overloaded_input_type(prep::ReverseDiffTwoArgJacobianPrep) = typeof(prep.config.input) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index f8a32b06c..67883ea00 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -335,6 +335,6 @@ end ## Operator overloading -function DI.overloaded_input_example(prep::PushforwardSparseJacobianPrep) - return DI.overloaded_input_example(prep.pushforward_prep) +function DI.overloaded_input(prep::PushforwardSparseJacobianPrep) + return DI.overloaded_input(prep.pushforward_prep) end diff --git a/DifferentiationInterface/src/fallbacks/input.jl b/DifferentiationInterface/src/fallbacks/input.jl index c4a821a4e..971c8be33 100644 --- a/DifferentiationInterface/src/fallbacks/input.jl +++ b/DifferentiationInterface/src/fallbacks/input.jl @@ -1,18 +1,20 @@ """ - overloaded_input_example(prep) + overloaded_input(prep) -If it exists, return an example of overloaded input which will be passed to the differentiated function when preparation result `prep` is reused. +If it can be deduced, return the overloaded input which will be passed to the differentiated function when preparation result `prep` is reused. !!! danger This function is experimental and not part of the public API. """ -function overloaded_input_example end +function overloaded_input end + +function overloaded_input_type end function error_if_overloading(backend) if check_operator_overloading(backend) throw( ArgumentError( - "The current backend is based on operator overloading, a custom method for `overloaded_input_example` is therefore necessary. Please open an issue on DifferentiationInterface.jl if you encounter this error.", + "The current backend is based on operator overloading, a custom method for `overloaded_input` is therefore necessary. Please open an issue on DifferentiationInterface.jl if you encounter this error.", ), ) end @@ -29,14 +31,14 @@ for op in [ :hvp, ] if op in (:derivative, :jacobian, :gradient) - @eval function overloaded_input_example( + @eval function overloaded_input( ::typeof($op), f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} error_if_overloading(backend) return copy(x) end op == :gradient && continue - @eval function overloaded_input_example( + @eval function overloaded_input( ::typeof($op), f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} error_if_overloading(backend) @@ -44,7 +46,7 @@ for op in [ end elseif op in (:second_derivative, :hessian) - @eval function overloaded_input_example( + @eval function overloaded_input( ::typeof($op), f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} error_if_overloading(backend) @@ -52,7 +54,7 @@ for op in [ end elseif op in (:pushforward, :pullback, :hvp) - @eval function overloaded_input_example( + @eval function overloaded_input( ::typeof($op), f::F, backend::AbstractADType, @@ -64,7 +66,7 @@ for op in [ return copy(x) end op == :hvp && continue - @eval function overloaded_input_example( + @eval function overloaded_input( ::typeof($op), f!::F, y, diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index cb4b1877b..3e6cdc579 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -97,10 +97,9 @@ function _prepare_hvp_aux( Constant(rewrap), contexts..., ) - xo = overloaded_input_example( + xo = overloaded_input( pushforward, shuffled_gradient, outer(backend), x, tx, new_contexts_unknown_prep... ) - copyto!(xo, x) inner_gradient_prep = prepare_gradient(f, inner(backend), xo, contexts...) new_contexts = ( FunctionContext(f), @@ -354,10 +353,9 @@ function _prepare_hvp_aux( Constant(rewrap), contexts..., ) - xo = overloaded_input_example( + xo = overloaded_input( pullback, shuffled_gradient, backend, x, tx, new_contexts_unknown_prep... ) - copyto!(xo, x) inner_gradient_prep = prepare_gradient(f, backend, xo, contexts...) new_contexts = ( FunctionContext(f), diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index d83d8bd2d..5c659e7ec 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -112,23 +112,22 @@ end; # Derivative x = 1.0 y = [1.0, 1.0] - @test DI.overloaded_input_example(prepare_derivative(copy, backend, x)) isa + @test DI.overloaded_input(prepare_derivative(copy, backend, x)) isa ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy),Float64},Float64,1} - @test DI.overloaded_input_example(prepare_derivative(copyto!, y, backend, x)) isa + @test DI.overloaded_input(prepare_derivative(copyto!, y, backend, x)) isa Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,1}} # Gradient x = [1.0, 1.0] - @test DI.overloaded_input_example(prepare_gradient(sum, backend, x)) isa + @test DI.overloaded_input(prepare_gradient(sum, backend, x)) isa Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(sum),Float64},Float64,2}} # Jacobian x = [1.0, 0.0, 0.0] - @test DI.overloaded_input_example(prepare_jacobian(copy, backend, x)) isa + @test DI.overloaded_input(prepare_jacobian(copy, backend, x)) isa ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy),Float64},Float64,3} - @test DI.overloaded_input_example(prepare_jacobian(copyto!, similar(x), backend, x)) isa + @test DI.overloaded_input(prepare_jacobian(copyto!, similar(x), backend, x)) isa Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,3}} - @test DI.overloaded_input_example( - prepare_jacobian(copyto!, similar(x), sparse_backend, x) - ) isa Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,1}} + @test DI.overloaded_input(prepare_jacobian(copyto!, similar(x), sparse_backend, x)) isa + Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,1}} end; diff --git a/DifferentiationInterface/test/Back/ReverseDiff/test.jl b/DifferentiationInterface/test/Back/ReverseDiff/test.jl index 1cebc7ae7..ba985331a 100644 --- a/DifferentiationInterface/test/Back/ReverseDiff/test.jl +++ b/DifferentiationInterface/test/Back/ReverseDiff/test.jl @@ -45,17 +45,17 @@ test_differentiation( # Derivative x = 1.0 - @test_skip DI.overloaded_input_example(prepare_derivative(copy, backend, x)) isa + @test_skip DI.overloaded_input(prepare_derivative(copy, backend, x)) isa ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} # Gradient x = [1.0; 0.0; 0.0] - @test DI.overloaded_input_example(prepare_gradient(sum, backend, x)) isa + @test DI.overloaded_input(prepare_gradient(sum, backend, x)) isa ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} # Jacobian - @test DI.overloaded_input_example(prepare_jacobian(copy, backend, x)) isa + @test DI.overloaded_input(prepare_jacobian(copy, backend, x)) isa ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} - @test DI.overloaded_input_example(prepare_jacobian(copyto!, similar(x), backend, x)) isa + @test DI.overloaded_input(prepare_jacobian(copyto!, similar(x), backend, x)) isa ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} end; diff --git a/DifferentiationInterface/test/Back/SparsityDetector/test.jl b/DifferentiationInterface/test/Back/SparsityDetector/test.jl index 1ea6a975a..560e954f4 100644 --- a/DifferentiationInterface/test/Back/SparsityDetector/test.jl +++ b/DifferentiationInterface/test/Back/SparsityDetector/test.jl @@ -12,6 +12,9 @@ using Test rng = StableRNG(63) +using Random +rng = Random.default_rng() + const Jc = sprand(rng, Bool, 10, 20, 0.3) const Hc = sparse(Symmetric(sprand(rng, Bool, 20, 20, 0.3))) From 9549ec48d1d62cc5ce1bc8d3e563192d01fd2492 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 16:08:00 +0100 Subject: [PATCH 07/24] No docstring --- DifferentiationInterface/src/fallbacks/input.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/DifferentiationInterface/src/fallbacks/input.jl b/DifferentiationInterface/src/fallbacks/input.jl index 971c8be33..38fb3a847 100644 --- a/DifferentiationInterface/src/fallbacks/input.jl +++ b/DifferentiationInterface/src/fallbacks/input.jl @@ -1,13 +1,4 @@ -""" - overloaded_input(prep) - -If it can be deduced, return the overloaded input which will be passed to the differentiated function when preparation result `prep` is reused. - -!!! danger - This function is experimental and not part of the public API. -""" function overloaded_input end - function overloaded_input_type end function error_if_overloading(backend) From 0c36c4d5ba04823bbee6a41e6057e79f2de0d2be Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 16:16:59 +0100 Subject: [PATCH 08/24] Enzyme --- .../ext/DifferentiationInterfaceEnzymeExt/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index a6ad7cbc4..b64beee17 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -54,7 +54,7 @@ force_annotation(f::F) where {F} = Const(f) end @inline function _translate( - backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache + backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.GeneralizedCache ) where {B} if B == 1 return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c))) From 5973ff8b4280cf332017838db7c92279e7ead045 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 16:50:34 +0100 Subject: [PATCH 09/24] typos --- .../DifferentiationInterfaceFiniteDifferencesExt.jl | 2 +- DifferentiationInterface/test/Back/ChainRules/zygote.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl index 588143289..79969f08d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl @@ -7,7 +7,7 @@ using LinearAlgebra: dot DI.check_available(::AutoFiniteDifferences) = true DI.check_inplace(::AutoFiniteDifferences) = false -DI.operator_overloading(::AutoFiniteDifferences) = false +DI.check_operator_overloading(::AutoFiniteDifferences) = false ## Pushforward diff --git a/DifferentiationInterface/test/Back/ChainRules/zygote.jl b/DifferentiationInterface/test/Back/ChainRules/zygote.jl index ef54db928..a77668e55 100644 --- a/DifferentiationInterface/test/Back/ChainRules/zygote.jl +++ b/DifferentiationInterface/test/Back/ChainRules/zygote.jl @@ -19,7 +19,7 @@ end test_differentiation( AutoChainRules(ZygoteRuleConfig()), default_scenarios(); - excluded=[:second_derivative], + excluded=SECOND_ORDER, logging=LOGGING, ); From 6d46464fd4b86f61f939c45e30e64d2fd3d96ecb Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 16:56:16 +0100 Subject: [PATCH 10/24] Zygote fix --- .../DifferentiationInterfaceZygoteExt.jl | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 59074f606..97280aee6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -184,7 +184,12 @@ function DI.prepare_hvp( end function DI.hvp( - f, prep::DI.HVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f, + prep::DI.ForwardOverAnythingHVPPrep, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, ) where {C} return DI.hvp(f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) end @@ -192,7 +197,7 @@ end function DI.hvp!( f, tg::NTuple, - prep::DI.HVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoZygote, x, tx::NTuple, @@ -204,7 +209,12 @@ function DI.hvp!( end function DI.gradient_and_hvp( - f, prep::DI.HVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C} + f, + prep::DI.ForwardOverAnythingHVPPrep, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, ) where {C} return DI.gradient_and_hvp( f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -215,7 +225,7 @@ function DI.gradient_and_hvp!( f, grad, tg::NTuple, - prep::DI.HVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoZygote, x, tx::NTuple, From ea0209ec98450f49180dfabdb0131bab7e3ef571 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 16:57:00 +0100 Subject: [PATCH 11/24] No fail fast --- .github/workflows/Test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 2c6b6f5fd..c2ef71dda 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -25,7 +25,7 @@ jobs: actions: write contents: read strategy: - fail-fast: true # TODO: toggle + fail-fast: false # TODO: toggle matrix: version: - "1.10" From 124341e905aa23d8a2995bd66408f1a1b7fc17f3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 17:27:25 +0100 Subject: [PATCH 12/24] typos --- .../DifferentiationInterfaceFastDifferentiationExt.jl | 2 +- .../onearg.jl | 8 ++++---- DifferentiationInterface/test/Back/ReverseDiff/test.jl | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl index acd61b1be..9b47937ec 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl @@ -17,7 +17,7 @@ using LinearAlgebra: dot using FastDifferentiation.RuntimeGeneratedFunctions: RuntimeGeneratedFunction DI.check_available(::AutoFastDifferentiation) = true -DI.operator_overloading(::AutoFastDifferentiation) = false +DI.check_operator_overloading(::AutoFastDifferentiation) = false myvec(x::Number) = [x] myvec(x::AbstractArray) = vec(x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index 8f67ad78d..6ec715c78 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -298,7 +298,7 @@ end function DI.hvp( f, - prep::DI.HVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, @@ -312,7 +312,7 @@ end function DI.hvp!( f, tg::NTuple, - prep::DI.HVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, @@ -325,7 +325,7 @@ end function DI.gradient_and_hvp( f, - prep::DI.HVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, @@ -340,7 +340,7 @@ function DI.gradient_and_hvp!( f, grad, tg::NTuple, - prep::DI.HVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, diff --git a/DifferentiationInterface/test/Back/ReverseDiff/test.jl b/DifferentiationInterface/test/Back/ReverseDiff/test.jl index ba985331a..8549de0e6 100644 --- a/DifferentiationInterface/test/Back/ReverseDiff/test.jl +++ b/DifferentiationInterface/test/Back/ReverseDiff/test.jl @@ -45,17 +45,17 @@ test_differentiation( # Derivative x = 1.0 - @test_skip DI.overloaded_input(prepare_derivative(copy, backend, x)) isa + @test_skip DI.overloaded_input_type(prepare_derivative(copy, backend, x)) == ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} # Gradient x = [1.0; 0.0; 0.0] - @test DI.overloaded_input(prepare_gradient(sum, backend, x)) isa + @test DI.overloaded_input_type(prepare_gradient(sum, backend, x)) == ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} # Jacobian - @test DI.overloaded_input(prepare_jacobian(copy, backend, x)) isa + @test DI.overloaded_input_type(prepare_jacobian(copy, backend, x)) == ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} - @test DI.overloaded_input(prepare_jacobian(copyto!, similar(x), backend, x)) isa + @test DI.overloaded_input_type(prepare_jacobian(copyto!, similar(x), backend, x)) == ReverseDiff.TrackedArray{Float64,Float64,1,Vector{Float64},Vector{Float64}} end; From 9f1d40ebca90d4254677fa2c1e69d60648904b2d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 17:48:31 +0100 Subject: [PATCH 13/24] ReverseDiff --- .../DifferentiationInterfaceReverseDiffExt.jl | 3 ++- .../ext/DifferentiationInterfaceReverseDiffExt/onearg.jl | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl index 461e987d1..ce23d24c6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl @@ -22,7 +22,8 @@ using ReverseDiff: hessian, hessian!, jacobian, - jacobian! + jacobian!, + value DI.check_available(::AutoReverseDiff) = true DI.check_operator_overloading(::AutoReverseDiff) = true diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index 13285670c..3147d7329 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -15,7 +15,7 @@ function DI.value_and_pullback( contexts::Vararg{DI.Context,C}, ) where {C} fc = DI.with_contexts(f, contexts...) - y = fc(x) + y = value(fc(x)) dotclosure(z, dy) = dot(fc(z), dy) tx = map(ty) do dy if y isa Number @@ -37,7 +37,7 @@ function DI.value_and_pullback!( contexts::Vararg{DI.Context,C}, ) where {C} fc = DI.with_contexts(f, contexts...) - y = fc(x) + y = value(fc(x)) dotclosure(z, dy) = dot(fc(z), dy) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] From 542594e86b3f828628e02a7c6a1e3a169af29221 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 20:47:35 +0100 Subject: [PATCH 14/24] Fixes --- .../onearg.jl | 4 +- .../utils.jl | 4 +- .../src/second_order/hvp.jl | 121 ++++++++++++------ DifferentiationInterface/src/utils/context.jl | 2 +- .../test/Back/ChainRules/zygote.jl | 5 - .../test/Back/Diffractor/test.jl | 5 - .../test/Back/Enzyme/test.jl | 7 - .../test/Back/FiniteDiff/test.jl | 5 - .../test/Back/FiniteDifferences/test.jl | 5 - .../test/Back/ForwardDiff/test.jl | 28 ++-- .../test/Back/GTPSA/test.jl | 5 - .../test/Back/Mooncake/test.jl | 5 - .../test/Back/PolyesterForwardDiff/test.jl | 5 - .../test/Back/ReverseDiff/test.jl | 5 - .../SymbolicBackends/fastdifferentiation.jl | 5 - .../test/Back/SymbolicBackends/symbolics.jl | 5 - .../test/Back/Tracker/test.jl | 5 - .../test/Back/Zygote/test.jl | 5 - .../src/test_differentiation.jl | 1 + 19 files changed, 102 insertions(+), 125 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index 3147d7329..13285670c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -15,7 +15,7 @@ function DI.value_and_pullback( contexts::Vararg{DI.Context,C}, ) where {C} fc = DI.with_contexts(f, contexts...) - y = value(fc(x)) + y = fc(x) dotclosure(z, dy) = dot(fc(z), dy) tx = map(ty) do dy if y isa Number @@ -37,7 +37,7 @@ function DI.value_and_pullback!( contexts::Vararg{DI.Context,C}, ) where {C} fc = DI.with_contexts(f, contexts...) - y = value(fc(x)) + y = fc(x) dotclosure(z, dy) = dot(fc(z), dy) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl index 9ea057974..fbaf0ee23 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl @@ -8,7 +8,7 @@ function DI.overloaded_input( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - return copy(ReverseDiff.track(x)) + return nothing end function DI.overloaded_input( @@ -20,7 +20,7 @@ function DI.overloaded_input( ty::NTuple, contexts::Vararg{DI.Context,C}, ) where {C} - return copy(ReverseDiff.track(x)) + return nothing end ## Gradient diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 3e6cdc579..416a39547 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -332,7 +332,8 @@ end ## Reverse over reverse -struct ReverseOverReverseHVPPrep{E1<:GradientPrep,E2<:PullbackPrep} <: HVPPrep +struct ReverseOverReverseHVPPrep{E1<:Union{Nothing,GradientPrep},E2<:PullbackPrep} <: + HVPPrep inner_gradient_prep::E1 outer_pullback_prep::E2 end @@ -356,14 +357,24 @@ function _prepare_hvp_aux( xo = overloaded_input( pullback, shuffled_gradient, backend, x, tx, new_contexts_unknown_prep... ) - inner_gradient_prep = prepare_gradient(f, backend, xo, contexts...) - new_contexts = ( - FunctionContext(f), - PrepContext(inner_gradient_prep), - BackendContext(inner(backend)), - Constant(rewrap), - contexts..., - ) + if isnothing(xo) + inner_gradient_prep = nothing + new_contexts = ( + FunctionContext(f), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + else + inner_gradient_prep = prepare_gradient(f, backend, xo, contexts...) + new_contexts = ( + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + end outer_pullback_prep = prepare_pullback( shuffled_gradient, outer(backend), x, tx, new_contexts... ) @@ -380,13 +391,22 @@ function hvp( ) where {F,C} (; inner_gradient_prep, outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), - PrepContext(inner_gradient_prep), - BackendContext(inner(backend)), - Constant(rewrap), - contexts..., - ) + if isnothing(inner_gradient_prep) + new_contexts = ( + FunctionContext(f), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + else + new_contexts = ( + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + end return pullback( shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) @@ -403,13 +423,22 @@ function hvp!( ) where {F,C} (; inner_gradient_prep, outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), - PrepContext(inner_gradient_prep), - BackendContext(inner(backend)), - Constant(rewrap), - contexts..., - ) + if isnothing(inner_gradient_prep) + new_contexts = ( + FunctionContext(f), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + else + new_contexts = ( + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + end return pullback!( shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) @@ -425,13 +454,22 @@ function gradient_and_hvp( ) where {F,C} (; inner_gradient_prep, outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), - PrepContext(inner_gradient_prep), - BackendContext(inner(backend)), - Constant(rewrap), - contexts..., - ) + if isnothing(inner_gradient_prep) + new_contexts = ( + FunctionContext(f), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + else + new_contexts = ( + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + end return value_and_pullback( shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) @@ -449,13 +487,22 @@ function gradient_and_hvp!( ) where {F,C} (; inner_gradient_prep, outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), - PrepContext(inner_gradient_prep), - BackendContext(inner(backend)), - Constant(rewrap), - contexts..., - ) + if isnothing(inner_gradient_prep) + new_contexts = ( + FunctionContext(f), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + else + new_contexts = ( + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + end new_grad, _ = value_and_pullback!( shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 178d83247..4ab112054 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -66,7 +66,7 @@ The initial values present inside the cache do not matter. !!! warning Most backends require any `Cache` context to be an `AbstractArray`. """ -struct Cache{T} <: Context +struct Cache{T} <: GeneralizedCache data::T end diff --git a/DifferentiationInterface/test/Back/ChainRules/zygote.jl b/DifferentiationInterface/test/Back/ChainRules/zygote.jl index a77668e55..67ae40dd9 100644 --- a/DifferentiationInterface/test/Back/ChainRules/zygote.jl +++ b/DifferentiationInterface/test/Back/ChainRules/zygote.jl @@ -11,11 +11,6 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoChainRules(ZygoteRuleConfig())] - @test check_available(backend) - @test !check_inplace(backend) -end - test_differentiation( AutoChainRules(ZygoteRuleConfig()), default_scenarios(); diff --git a/DifferentiationInterface/test/Back/Diffractor/test.jl b/DifferentiationInterface/test/Back/Diffractor/test.jl index 08315b9c3..d90475db2 100644 --- a/DifferentiationInterface/test/Back/Diffractor/test.jl +++ b/DifferentiationInterface/test/Back/Diffractor/test.jl @@ -10,11 +10,6 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoDiffractor()] - @test check_available(backend) - @test !check_inplace(backend) -end - test_differentiation( AutoDiffractor(), default_scenarios(; linalg=false); diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index f4803ccb3..989f28bf2 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -34,13 +34,6 @@ duplicated_backends = [ AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=Enzyme.Duplicated), ] -@testset "Checks" begin - @testset "Check $(typeof(backend))" for backend in backends - @test check_available(backend) - @test check_inplace(backend) - end -end; - @testset "First order" begin test_differentiation( backends, default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING diff --git a/DifferentiationInterface/test/Back/FiniteDiff/test.jl b/DifferentiationInterface/test/Back/FiniteDiff/test.jl index 5276d6233..370d50046 100644 --- a/DifferentiationInterface/test/Back/FiniteDiff/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDiff/test.jl @@ -12,11 +12,6 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoFiniteDiff()] - @test check_available(backend) - @test check_inplace(backend) -end - test_differentiation( AutoFiniteDiff(), default_scenarios(; include_constantified=true, include_cachified=true); diff --git a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl index cb83ab0cf..692feb5c8 100644 --- a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl @@ -10,11 +10,6 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1))] - @test check_available(backend) - @test !check_inplace(backend) -end - test_differentiation( AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)), default_scenarios(; include_constantified=true, include_cachified=true); diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 5c659e7ec..ef98bb049 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -23,11 +23,6 @@ backends = [ AutoForwardDiff(; tag=ForwardDiff.Tag(MyTag(), Float64)), ] -for backend in backends - @test check_available(backend) - @test check_inplace(backend) -end - ## Dense test_differentiation( @@ -78,14 +73,14 @@ test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING) @testset verbose = true "StaticArrays" begin @testset "Batch size" begin - @test DI.pick_batchsize(AutoForwardDiff(), rand(7)) isa DI.BatchSizeSettings{7} - @test DI.pick_batchsize(AutoForwardDiff(; chunksize=5), rand(7)) isa + @test DI.pick_batchsize(AutoForwardDiff(), rand(7)) == DI.BatchSizeSettings{7} + @test DI.pick_batchsize(AutoForwardDiff(; chunksize=5), rand(7)) == DI.BatchSizeSettings{5} - @test (@inferred DI.pick_batchsize(AutoForwardDiff(), @SVector(rand(7)))) isa + @test (@inferred DI.pick_batchsize(AutoForwardDiff(), @SVector(rand(7)))) == DI.BatchSizeSettings{7} @test (@inferred DI.pick_batchsize( AutoForwardDiff(; chunksize=5), @SVector(rand(7)) - )) isa DI.BatchSizeSettings{5} + )) == DI.BatchSizeSettings{5} end filtered_static_scenarios = filter(static_scenarios(; include_batchified=false)) do scen @@ -112,22 +107,23 @@ end; # Derivative x = 1.0 y = [1.0, 1.0] - @test DI.overloaded_input(prepare_derivative(copy, backend, x)) isa + @test DI.overloaded_input_type(prepare_derivative(copy, backend, x)) == ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy),Float64},Float64,1} - @test DI.overloaded_input(prepare_derivative(copyto!, y, backend, x)) isa + @test DI.overloaded_input_type(prepare_derivative(copyto!, y, backend, x)) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,1}} # Gradient x = [1.0, 1.0] - @test DI.overloaded_input(prepare_gradient(sum, backend, x)) isa + @test DI.overloaded_input_type(prepare_gradient(sum, backend, x)) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(sum),Float64},Float64,2}} # Jacobian x = [1.0, 0.0, 0.0] - @test DI.overloaded_input(prepare_jacobian(copy, backend, x)) isa + @test DI.overloaded_input_type(prepare_jacobian(copy, backend, x)) == ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy),Float64},Float64,3} - @test DI.overloaded_input(prepare_jacobian(copyto!, similar(x), backend, x)) isa + @test DI.overloaded_input_type(prepare_jacobian(copyto!, similar(x), backend, x)) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,3}} - @test DI.overloaded_input(prepare_jacobian(copyto!, similar(x), sparse_backend, x)) isa - Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,1}} + @test DI.overloaded_input_type( + prepare_jacobian(copyto!, similar(x), sparse_backend, x) + ) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!),Float64},Float64,1}} end; diff --git a/DifferentiationInterface/test/Back/GTPSA/test.jl b/DifferentiationInterface/test/Back/GTPSA/test.jl index 68b66918e..16d13bb74 100644 --- a/DifferentiationInterface/test/Back/GTPSA/test.jl +++ b/DifferentiationInterface/test/Back/GTPSA/test.jl @@ -10,11 +10,6 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoGTPSA()] - @test check_available(backend) - @test check_inplace(backend) -end - # Test no Descriptor (use context) test_differentiation( AutoGTPSA(), diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 35df1002c..eb83a03a2 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -12,11 +12,6 @@ LOGGING = get(ENV, "CI", "false") == "false" backends = [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())] -for backend in backends - @test check_available(backend) - @test check_inplace(backend) -end - test_differentiation( backends, default_scenarios(; include_constantified=true, include_cachified=true); diff --git a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl index 21fb3d297..2a888574a 100644 --- a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl @@ -17,11 +17,6 @@ backends = [ AutoPolyesterForwardDiff(; chunksize=2), ] -for backend in backends - @test check_available(backend) - @test check_inplace(backend) -end - test_differentiation( backends, default_scenarios(; include_constantified=true); logging=LOGGING ); diff --git a/DifferentiationInterface/test/Back/ReverseDiff/test.jl b/DifferentiationInterface/test/Back/ReverseDiff/test.jl index 8549de0e6..d506913a8 100644 --- a/DifferentiationInterface/test/Back/ReverseDiff/test.jl +++ b/DifferentiationInterface/test/Back/ReverseDiff/test.jl @@ -16,11 +16,6 @@ LOGGING = get(ENV, "CI", "false") == "false" backends = [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] second_order_backends = [SecondOrder(AutoForwardDiff(), AutoReverseDiff())] -for backend in vcat(backends, second_order_backends) - @test check_available(backend) - @test check_inplace(backend) -end - ## Dense test_differentiation( diff --git a/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl b/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl index 191025ae2..24c6475f1 100644 --- a/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl +++ b/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl @@ -10,11 +10,6 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoFastDifferentiation(), AutoSparse(AutoFastDifferentiation())] - @test check_available(backend) - @test check_inplace(backend) -end - test_differentiation( AutoFastDifferentiation(), filter(default_scenarios(; include_constantified=true, include_cachified=true)) do s diff --git a/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl b/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl index 31f8316a0..b4ca3b119 100644 --- a/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl +++ b/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl @@ -10,11 +10,6 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoSymbolics(), AutoSparse(AutoSymbolics())] - @test check_available(backend) - @test check_inplace(backend) -end - test_differentiation( AutoSymbolics(), default_scenarios(; include_constantified=true); logging=LOGGING ); diff --git a/DifferentiationInterface/test/Back/Tracker/test.jl b/DifferentiationInterface/test/Back/Tracker/test.jl index 0131a183e..557798ca0 100644 --- a/DifferentiationInterface/test/Back/Tracker/test.jl +++ b/DifferentiationInterface/test/Back/Tracker/test.jl @@ -10,11 +10,6 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -for backend in [AutoTracker()] - @test check_available(backend) - @test !check_inplace(backend) -end - test_differentiation( AutoTracker(), default_scenarios(; include_constantified=true); diff --git a/DifferentiationInterface/test/Back/Zygote/test.jl b/DifferentiationInterface/test/Back/Zygote/test.jl index 4f7943b17..a115bc83a 100644 --- a/DifferentiationInterface/test/Back/Zygote/test.jl +++ b/DifferentiationInterface/test/Back/Zygote/test.jl @@ -17,11 +17,6 @@ LOGGING = get(ENV, "CI", "false") == "false" backends = [AutoZygote()] second_order_backends = [SecondOrder(AutoForwardDiff(), AutoZygote())] -for backend in vcat(backends, second_order_backends) - @test check_available(backend) - @test !check_inplace(backend) -end - ## Dense @testset "Dense" begin diff --git a/DifferentiationInterfaceTest/src/test_differentiation.jl b/DifferentiationInterfaceTest/src/test_differentiation.jl index 4ad566a35..2a6f3fcd8 100644 --- a/DifferentiationInterfaceTest/src/test_differentiation.jl +++ b/DifferentiationInterfaceTest/src/test_differentiation.jl @@ -117,6 +117,7 @@ function test_differentiation( @testset verbose = true "$title" begin @testset verbose = detailed "$backend" for (i, backend) in enumerate(backends) + @test DI.check_available(backend) filtered_scenarios = filter(s -> compatible(backend, s), scenarios) grouped_scenarios = group_by_operator(filtered_scenarios) @testset verbose = detailed "$op" for (j, (op, op_group)) in From a44bbd89f4265c36d9f970051d8dab6d01515f62 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 21:09:39 +0100 Subject: [PATCH 15/24] typo --- DifferentiationInterface/test/Core/ZeroBackends/test.jl | 5 ----- DifferentiationInterfaceTest/test/zero_backends.jl | 1 + 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/DifferentiationInterface/test/Core/ZeroBackends/test.jl b/DifferentiationInterface/test/Core/ZeroBackends/test.jl index 0571a73b3..836db05de 100644 --- a/DifferentiationInterface/test/Core/ZeroBackends/test.jl +++ b/DifferentiationInterface/test/Core/ZeroBackends/test.jl @@ -11,11 +11,6 @@ LOGGING = get(ENV, "CI", "false") == "false" zero_backends = [AutoZeroForward(), AutoZeroReverse()] -for backend in zero_backends - @test check_available(backend) - @test check_inplace(backend) -end - @testset "Type stability" begin test_differentiation( AutoZeroForward(), diff --git a/DifferentiationInterfaceTest/test/zero_backends.jl b/DifferentiationInterfaceTest/test/zero_backends.jl index e8f5ace6c..4e3051fcf 100644 --- a/DifferentiationInterfaceTest/test/zero_backends.jl +++ b/DifferentiationInterfaceTest/test/zero_backends.jl @@ -44,6 +44,7 @@ data1 = benchmark_differentiation( struct FakeBackend <: ADTypes.AbstractADType end ADTypes.mode(::FakeBackend) = ADTypes.ForwardMode() +DifferentiationInterface.check_available(::FakeBackend) = true data2 = benchmark_differentiation( FakeBackend(), From abd36b8f8a41f0fc49f1136d209efbfc19f338af Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 21:40:12 +0100 Subject: [PATCH 16/24] test: Enzyme Cache testing --- .../utils.jl | 6 ++-- .../test/Back/Enzyme/test.jl | 9 +++++- .../src/scenarios/modify.jl | 30 +++++++++++-------- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index a6ad7cbc4..12ce9ba6d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -57,9 +57,11 @@ end backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache ) where {B} if B == 1 - return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c))) + return DuplicatedNoNeed(DI.unwrap(c), make_zero(DI.unwrap(c))) else - return BatchDuplicated(DI.unwrap(c), ntuple(_ -> make_zero(DI.unwrap(c)), Val(B))) + return BatchDuplicatedNoNeed( + DI.unwrap(c), ntuple(_ -> make_zero(DI.unwrap(c)), Val(B)) + ) end end diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index f4803ccb3..4528166b7 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -54,7 +54,7 @@ end; ) test_differentiation( - backends[2], + backends[2:3], default_scenarios(; include_normal=false, include_cachified=true); excluded=SECOND_ORDER, logging=LOGGING, @@ -68,6 +68,13 @@ end; ) end +test_differentiation( + AutoEnzyme(mode=Enzyme.Reverse), + default_scenarios(; include_normal=false, include_cachified=true); + excluded=vcat(SECOND_ORDER, :jacobian, :gradient, :pushforward, :derivative), + logging=LOGGING, +) + #= # TODO: reactivate type stability tests diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index 7d35ca335..47a0a7015 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -173,20 +173,24 @@ end Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))") -function (sc::StoreInCache{:out})(x, y_cache) - y = sc.f(x) - if y isa Number - y_cache[1] = y - return y_cache[1] +function (sc::StoreInCache{:out})(x, x_cache) + if x isa Number + x_cache[1] = x + return sc.f(x_cache[1]) else - copyto!(y_cache, y) - return copy(y_cache) + copyto!(x_cache, x) + return sc.f(x_cache) end end -function (sc::StoreInCache{:in})(y, x, y_cache) - sc.f(y_cache, x) - copyto!(y, y_cache) +function (sc::StoreInCache{:in})(y, x, x_cache) + if x isa Number + x_cache[1] = x + sc.f(y, x_cache[1]) + else + copyto!(x_cache, x) + sc.f(y, x_cache) + end return nothing end @@ -199,10 +203,10 @@ function cachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} (; f,) = scen @assert isempty(scen.contexts) cache_f = StoreInCache{pl_fun}(f) - y_cache = if scen.y isa Number - [myzero(scen.y)] + y_cache = if scen.x isa Number + [myzero(scen.x)] else - mysimilar(scen.y) + mysimilar(scen.x) end return Scenario{op,pl_op,pl_fun}( cache_f; From 5c14f2a78f74e00253bb974c0954685b3899a8f5 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 21:50:28 +0100 Subject: [PATCH 17/24] typos --- .../ext/DifferentiationInterfaceForwardDiffExt/misc.jl | 2 +- DifferentiationInterface/test/Back/ForwardDiff/test.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl index e7793391e..261c302bc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl @@ -37,7 +37,7 @@ DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep) = typeof(prep.x ## Derivative function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep) - return DI.overloaded_input(prep.pushforward_prep) + return DI.overloaded_input_type(prep.pushforward_prep) end function DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep) return typeof(prep.config.duals) diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index ef98bb049..3a4b0f9d0 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -73,14 +73,14 @@ test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING) @testset verbose = true "StaticArrays" begin @testset "Batch size" begin - @test DI.pick_batchsize(AutoForwardDiff(), rand(7)) == DI.BatchSizeSettings{7} + @test DI.pick_batchsize(AutoForwardDiff(), rand(7)) isa DI.BatchSizeSettings{7} @test DI.pick_batchsize(AutoForwardDiff(; chunksize=5), rand(7)) == DI.BatchSizeSettings{5} @test (@inferred DI.pick_batchsize(AutoForwardDiff(), @SVector(rand(7)))) == DI.BatchSizeSettings{7} @test (@inferred DI.pick_batchsize( AutoForwardDiff(; chunksize=5), @SVector(rand(7)) - )) == DI.BatchSizeSettings{5} + )) isa DI.BatchSizeSettings{5} end filtered_static_scenarios = filter(static_scenarios(; include_batchified=false)) do scen From 4bfe3c8dd2f5a878dcc17c448cc4320272304675 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 22:14:54 +0100 Subject: [PATCH 18/24] Undo cache modif --- .../test/Back/Enzyme/test.jl | 9 +----- .../src/scenarios/modify.jl | 30 ++++++++----------- 2 files changed, 14 insertions(+), 25 deletions(-) diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index c81011f44..989f28bf2 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -47,7 +47,7 @@ duplicated_backends = [ ) test_differentiation( - backends[2:3], + backends[2], default_scenarios(; include_normal=false, include_cachified=true); excluded=SECOND_ORDER, logging=LOGGING, @@ -61,13 +61,6 @@ duplicated_backends = [ ) end -test_differentiation( - AutoEnzyme(mode=Enzyme.Reverse), - default_scenarios(; include_normal=false, include_cachified=true); - excluded=vcat(SECOND_ORDER, :jacobian, :gradient, :pushforward, :derivative), - logging=LOGGING, -) - #= # TODO: reactivate type stability tests diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index 47a0a7015..7d35ca335 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -173,24 +173,20 @@ end Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))") -function (sc::StoreInCache{:out})(x, x_cache) - if x isa Number - x_cache[1] = x - return sc.f(x_cache[1]) +function (sc::StoreInCache{:out})(x, y_cache) + y = sc.f(x) + if y isa Number + y_cache[1] = y + return y_cache[1] else - copyto!(x_cache, x) - return sc.f(x_cache) + copyto!(y_cache, y) + return copy(y_cache) end end -function (sc::StoreInCache{:in})(y, x, x_cache) - if x isa Number - x_cache[1] = x - sc.f(y, x_cache[1]) - else - copyto!(x_cache, x) - sc.f(y, x_cache) - end +function (sc::StoreInCache{:in})(y, x, y_cache) + sc.f(y_cache, x) + copyto!(y, y_cache) return nothing end @@ -203,10 +199,10 @@ function cachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} (; f,) = scen @assert isempty(scen.contexts) cache_f = StoreInCache{pl_fun}(f) - y_cache = if scen.x isa Number - [myzero(scen.x)] + y_cache = if scen.y isa Number + [myzero(scen.y)] else - mysimilar(scen.x) + mysimilar(scen.y) end return Scenario{op,pl_op,pl_fun}( cache_f; From 8226d283893e2e1e9c4aff2acbd6d21484fba849 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 22:58:19 +0100 Subject: [PATCH 19/24] Duplicated --- .../ext/DifferentiationInterfaceEnzymeExt/utils.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 215b61814..b64beee17 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -57,11 +57,9 @@ end backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.GeneralizedCache ) where {B} if B == 1 - return DuplicatedNoNeed(DI.unwrap(c), make_zero(DI.unwrap(c))) + return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c))) else - return BatchDuplicatedNoNeed( - DI.unwrap(c), ntuple(_ -> make_zero(DI.unwrap(c)), Val(B)) - ) + return BatchDuplicated(DI.unwrap(c), ntuple(_ -> make_zero(DI.unwrap(c)), Val(B))) end end From 0cbc32b73aabd2ab46dbd9c4634fcdb6a6f05ebf Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 31 Jan 2025 07:03:42 +0100 Subject: [PATCH 20/24] typos --- .../jacobian.jl | 4 ++-- DifferentiationInterface/test/Back/ForwardDiff/test.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 67883ea00..17928d44c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -335,6 +335,6 @@ end ## Operator overloading -function DI.overloaded_input(prep::PushforwardSparseJacobianPrep) - return DI.overloaded_input(prep.pushforward_prep) +function DI.overloaded_input_type(prep::PushforwardSparseJacobianPrep) + return DI.overloaded_input_type(prep.pushforward_prep) end diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 3a4b0f9d0..ee65aafcc 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -74,9 +74,9 @@ test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING) @testset verbose = true "StaticArrays" begin @testset "Batch size" begin @test DI.pick_batchsize(AutoForwardDiff(), rand(7)) isa DI.BatchSizeSettings{7} - @test DI.pick_batchsize(AutoForwardDiff(; chunksize=5), rand(7)) == + @test DI.pick_batchsize(AutoForwardDiff(; chunksize=5), rand(7)) isa DI.BatchSizeSettings{5} - @test (@inferred DI.pick_batchsize(AutoForwardDiff(), @SVector(rand(7)))) == + @test (@inferred DI.pick_batchsize(AutoForwardDiff(), @SVector(rand(7)))) isa DI.BatchSizeSettings{7} @test (@inferred DI.pick_batchsize( AutoForwardDiff(; chunksize=5), @SVector(rand(7)) From 8b74f6fe7d87c1f99d5455f4a16de8694e4e9250 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 31 Jan 2025 08:07:47 +0100 Subject: [PATCH 21/24] increase coverage --- ...fferentiationInterfaceChainRulesCoreExt.jl | 1 - .../DifferentiationInterfaceDiffractorExt.jl | 1 - .../DifferentiationInterfaceEnzymeExt.jl | 1 - ...ntiationInterfaceFastDifferentiationExt.jl | 1 - .../DifferentiationInterfaceFiniteDiffExt.jl | 1 - ...rentiationInterfaceFiniteDifferencesExt.jl | 1 - .../misc.jl | 3 +- .../onearg.jl | 72 ------------------- .../utils.jl | 12 ---- .../DifferentiationInterfaceSymbolicsExt.jl | 1 - .../DifferentiationInterfaceTrackerExt.jl | 1 - .../src/fallbacks/input.jl | 69 +++--------------- .../src/first_order/gradient.jl | 4 ++ DifferentiationInterface/src/utils/traits.jl | 28 +++++++- .../test/Core/Internals/backends.jl | 13 ++++ .../test/Core/ZeroBackends/test.jl | 4 ++ 16 files changed, 61 insertions(+), 152 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl index 250360e84..9ce083126 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl @@ -19,7 +19,6 @@ const AutoReverseChainRules = AutoChainRules{<:RuleConfig{>:HasReverseMode}} DI.check_available(::AutoChainRules) = true DI.check_inplace(::AutoChainRules) = false -DI.check_operator_overloading(::AutoChainRules) = false include("reverse_onearg.jl") include("differentiate_with.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl index 24a260032..42569a19d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl @@ -6,7 +6,6 @@ using Diffractor: DiffractorRuleConfig, TaylorTangentIndex, ZeroBundle, bundle, DI.check_available(::AutoDiffractor) = true DI.check_inplace(::AutoDiffractor) = false -DI.check_operator_overloading(::AutoDiffractor) = false DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow() ## Pushforward diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index 7a01c61dd..328bffaf3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -44,7 +44,6 @@ using Enzyme: onehot DI.check_available(::AutoEnzyme) = true -DI.check_operator_overloading(::AutoEnzyme) = false include("utils.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl index 9b47937ec..430cc95f9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl @@ -17,7 +17,6 @@ using LinearAlgebra: dot using FastDifferentiation.RuntimeGeneratedFunctions: RuntimeGeneratedFunction DI.check_available(::AutoFastDifferentiation) = true -DI.check_operator_overloading(::AutoFastDifferentiation) = false myvec(x::Number) = [x] myvec(x::AbstractArray) = vec(x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl index ce578b02d..4970eb96e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl @@ -18,7 +18,6 @@ using FiniteDiff: using LinearAlgebra: dot, mul! DI.check_available(::AutoFiniteDiff) = true -DI.check_operator_overloading(::AutoFiniteDiff) = false # see https://github.com/SciML/ADTypes.jl/issues/33 diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl index 79969f08d..bd43c65d8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl @@ -7,7 +7,6 @@ using LinearAlgebra: dot DI.check_available(::AutoFiniteDifferences) = true DI.check_inplace(::AutoFiniteDifferences) = false -DI.check_operator_overloading(::AutoFiniteDifferences) = false ## Pushforward diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl index 261c302bc..132156e33 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl @@ -12,7 +12,7 @@ function DI.overloaded_input( xdual = make_dual(T, x, tx) return xdual end - +#= function DI.overloaded_input( ::typeof(DI.pushforward), f!::F, @@ -30,6 +30,7 @@ function DI.overloaded_input( end return xdual end +=# DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep) = typeof(prep.xdual_tmp) DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep) = typeof(prep.xdual_tmp) diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index 6ec715c78..6571a3ab9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -286,78 +286,6 @@ function DI.value_gradient_and_hessian!( ) end -## HVP - -function DI.prepare_hvp( - f, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C} -) where {C} - return DI.prepare_hvp( - f, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... - ) -end - -function DI.hvp( - f, - prep::DI.ForwardOverAnythingHVPPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return DI.hvp( - f, prep, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... - ) -end - -function DI.hvp!( - f, - tg::NTuple, - prep::DI.ForwardOverAnythingHVPPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return DI.hvp!( - f, tg, prep, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... - ) -end - -function DI.gradient_and_hvp( - f, - prep::DI.ForwardOverAnythingHVPPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return DI.gradient_and_hvp( - f, prep, DI.SecondOrder(single_threaded(backend), backend), x, tx, contexts... - ) -end - -function DI.gradient_and_hvp!( - f, - grad, - tg::NTuple, - prep::DI.ForwardOverAnythingHVPPrep, - backend::AutoPolyesterForwardDiff, - x, - tx::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return DI.gradient_and_hvp!( - f, - grad, - tg, - prep, - DI.SecondOrder(single_threaded(backend), backend), - x, - tx, - contexts..., - ) -end - ## Second derivative function DI.prepare_second_derivative( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl index fbaf0ee23..bf4782a4e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/utils.jl @@ -11,18 +11,6 @@ function DI.overloaded_input( return nothing end -function DI.overloaded_input( - ::typeof(DI.pullback), - f!, - y, - ::AutoReverseDiff, - x::AbstractArray, - ty::NTuple, - contexts::Vararg{DI.Context,C}, -) where {C} - return nothing -end - ## Gradient DI.overloaded_input_type(prep::ReverseDiffGradientPrep) = typeof(prep.config.input) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl index 0e3711e80..2dc8d0018 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl @@ -19,7 +19,6 @@ using Symbolics: using Symbolics.RuntimeGeneratedFunctions: RuntimeGeneratedFunction DI.check_available(::AutoSymbolics) = true -DI.check_operator_overloading(::AutoSymbolics) = false DI.pullback_performance(::AutoSymbolics) = DI.PullbackSlow() dense_ad(backend::AutoSymbolics) = backend diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index 244d8342f..db4998021 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -6,7 +6,6 @@ using Tracker: Tracker, back, data, forward, gradient, jacobian, param, withgrad DI.check_available(::AutoTracker) = true DI.check_inplace(::AutoTracker) = false -DI.check_operator_overloading(::AutoTracker) = true ## Pullback diff --git a/DifferentiationInterface/src/fallbacks/input.jl b/DifferentiationInterface/src/fallbacks/input.jl index 38fb3a847..a3162449d 100644 --- a/DifferentiationInterface/src/fallbacks/input.jl +++ b/DifferentiationInterface/src/fallbacks/input.jl @@ -11,63 +11,16 @@ function error_if_overloading(backend) end end -for op in [ - :derivative, - :gradient, - :jacobian, - :second_derivative, - :hessian, - :pushforward, - :pullback, - :hvp, -] - if op in (:derivative, :jacobian, :gradient) - @eval function overloaded_input( - ::typeof($op), f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} - ) where {F,C} - error_if_overloading(backend) - return copy(x) - end - op == :gradient && continue - @eval function overloaded_input( - ::typeof($op), f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C} - ) where {F,C} - error_if_overloading(backend) - return copy(x) - end - - elseif op in (:second_derivative, :hessian) - @eval function overloaded_input( - ::typeof($op), f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} - ) where {F,C} - error_if_overloading(backend) - return copy(x) - end - - elseif op in (:pushforward, :pullback, :hvp) - @eval function overloaded_input( - ::typeof($op), - f::F, - backend::AbstractADType, - x, - seed::NTuple, - contexts::Vararg{Context,C}, - ) where {F,C} - error_if_overloading(backend) - return copy(x) - end - op == :hvp && continue - @eval function overloaded_input( - ::typeof($op), - f!::F, - y, - backend::AbstractADType, - x, - seed::NTuple, - contexts::Vararg{Context,C}, - ) where {F,C} - error_if_overloading(backend) - return copy(x) - end +for op in [:pushforward, :pullback] + @eval function overloaded_input( + ::typeof($op), + f::F, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}, + ) where {F,C} + error_if_overloading(backend) + return copy(x) end end diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index d689da5c1..f854fba1f 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -129,6 +129,7 @@ function shuffled_gradient( return gradient(f, backend, x, rewrap(unannotated_contexts...)...) end +#= function shuffled_gradient!( grad, x, @@ -139,6 +140,7 @@ function shuffled_gradient!( ) where {F,C} return gradient!(f, grad, backend, x, rewrap(unannotated_contexts...)...) end +=# function shuffled_gradient( x, @@ -151,6 +153,7 @@ function shuffled_gradient( return gradient(f, prep, backend, x, rewrap(unannotated_contexts...)...) end +#= function shuffled_gradient!( grad, x, @@ -162,3 +165,4 @@ function shuffled_gradient!( ) where {F,C} return gradient!(f, grad, prep, backend, x, rewrap(unannotated_contexts...)...) end +=# diff --git a/DifferentiationInterface/src/utils/traits.jl b/DifferentiationInterface/src/utils/traits.jl index 9f89e1953..c8eaeb852 100644 --- a/DifferentiationInterface/src/utils/traits.jl +++ b/DifferentiationInterface/src/utils/traits.jl @@ -26,6 +26,9 @@ end Check whether `backend` supports differentiation of in-place functions. Returns `true` or `false` in a statically predictable way. + +!!! warning + This function defaults to `true` if the backend is not loaded. """ check_inplace(::AbstractADType) = true @@ -48,8 +51,31 @@ end Check whether backend relies on operator overloading. Returns `true` or `false` in a statically predictable way. + +!!! warning + This function defaults to `false` if the backend is not loaded. """ -function check_operator_overloading end +check_operator_overloading(::AbstractADType) = false + +function check_operator_overloading(::SecondOrder) + return throw( + ArgumentError( + "Operator overloading check does not make sense for second-order backend" + ), + ) +end + +function check_operator_overloading(backend::AutoSparse) + return check_operator_overloading(dense_ad(backend)) +end + +function check_operator_overloading(::MixedMode) + return throw( + ArgumentError( + "Operator overloading check does not make sense for mixed-mode backend" + ), + ) +end ## Pushforward diff --git a/DifferentiationInterface/test/Core/Internals/backends.jl b/DifferentiationInterface/test/Core/Internals/backends.jl index 2edca6aa3..08d4a535c 100644 --- a/DifferentiationInterface/test/Core/Internals/backends.jl +++ b/DifferentiationInterface/test/Core/Internals/backends.jl @@ -9,6 +9,7 @@ using DifferentiationInterface: forward_backend, reverse_backend, check_inplace, + check_operator_overloading, pushforward_performance, pullback_performance, hvp_mode @@ -25,6 +26,7 @@ rb = AutoReverseFromPrimitive(AutoSimpleFiniteDiff()) @test inner(backend) isa AutoReverseFromPrimitive @test mode(backend) isa ADTypes.ForwardMode @test check_inplace(backend) + @test_throws ArgumentError check_operator_overloading(backend) @test_throws ArgumentError pushforward_performance(backend) @test_throws ArgumentError pullback_performance(backend) end @@ -36,6 +38,7 @@ end @test forward_backend(backend) isa AutoSimpleFiniteDiff @test reverse_backend(backend) isa AutoReverseFromPrimitive @test check_inplace(backend) + @test_throws ArgumentError check_operator_overloading(backend) @test_throws MethodError pushforward_performance(backend) @test_throws MethodError pullback_performance(backend) end @@ -45,8 +48,18 @@ end backend = AutoSparse(dense_backend) @test mode(backend) == ADTypes.mode(dense_backend) @test check_inplace(backend) + @test !check_operator_overloading(backend) @test_throws ArgumentError pushforward_performance(backend) @test_throws ArgumentError pullback_performance(backend) @test_throws ArgumentError hvp_mode(backend) end end + +struct FakeBackend <: AbstractADType end + +@testset "Fake" begin + backend = FakeBackend() + @test !check_available(backend) + @test check_inplace(backend) + @test !check_operator_overloading(backend) +end diff --git a/DifferentiationInterface/test/Core/ZeroBackends/test.jl b/DifferentiationInterface/test/Core/ZeroBackends/test.jl index 836db05de..37c2549c5 100644 --- a/DifferentiationInterface/test/Core/ZeroBackends/test.jl +++ b/DifferentiationInterface/test/Core/ZeroBackends/test.jl @@ -11,6 +11,10 @@ LOGGING = get(ENV, "CI", "false") == "false" zero_backends = [AutoZeroForward(), AutoZeroReverse()] +@testset "Correctness" begin + test_differentiation(zero_backends, map(zero, default_scenarios()); logging=LOGGING) +end + @testset "Type stability" begin test_differentiation( AutoZeroForward(), From e8f3dec168c3f3420c62d958e973304594186a04 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 31 Jan 2025 08:27:19 +0100 Subject: [PATCH 22/24] Poly --- ...DifferentiationInterfacePolyesterForwardDiffExt.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl index c6650991a..0a53dfbcc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -32,4 +32,15 @@ end include("onearg.jl") include("twoarg.jl") +function DI.overloaded_input( + ::typeof(DI.pushforward), + f::F, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context,C}, +) where {F,B,C} + return DI.overloaded_input(pushforward, f, single_threaded(backend), x, tx, contexts...) +end + end # module From 7dc3e4f4d434d91e064af46e82c5fb4a5a011bb7 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 31 Jan 2025 08:50:20 +0100 Subject: [PATCH 23/24] typo --- .../DifferentiationInterfacePolyesterForwardDiffExt.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl index 0a53dfbcc..ae6425a24 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -40,7 +40,9 @@ function DI.overloaded_input( tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - return DI.overloaded_input(pushforward, f, single_threaded(backend), x, tx, contexts...) + return DI.overloaded_input( + DI.pushforward, f, single_threaded(backend), x, tx, contexts... + ) end end # module From 2471e2462c016f965d4d7c1f8603432996294f4e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 31 Jan 2025 08:53:04 +0100 Subject: [PATCH 24/24] typo --- DifferentiationInterface/test/Core/Internals/backends.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/DifferentiationInterface/test/Core/Internals/backends.jl b/DifferentiationInterface/test/Core/Internals/backends.jl index 08d4a535c..1f7bda4a1 100644 --- a/DifferentiationInterface/test/Core/Internals/backends.jl +++ b/DifferentiationInterface/test/Core/Internals/backends.jl @@ -56,10 +56,15 @@ end end struct FakeBackend <: AbstractADType end +struct FakeOOBackend <: AbstractADType end +DI.check_operator_overloading(::FakeOOBackend) = true @testset "Fake" begin backend = FakeBackend() @test !check_available(backend) @test check_inplace(backend) @test !check_operator_overloading(backend) + @test_throws ArgumentError DI.overloaded_input( + pushforward, identity, FakeOOBackend(), 1.0, (1.0,) + ) end