From bad9cb5adf2230b771b3065bbf05d3536686c088 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 14 Mar 2025 19:48:02 +0100 Subject: [PATCH 01/15] perf: compute in-place HVP from in-place gradient --- .../src/first_order/gradient.jl | 13 + .../src/misc/from_primitive.jl | 8 +- .../src/second_order/hvp.jl | 309 +++++++++++++++++- .../test/Core/SimpleFiniteDiff/test.jl | 1 + 4 files changed, 318 insertions(+), 13 deletions(-) diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 3b6041f8e..eda75f359 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -132,3 +132,16 @@ function shuffled_gradient( ) where {F,C} return gradient(f, prep, backend, x, rewrap(unannotated_contexts...)...) end + +function shuffled_gradient!( + grad, + x, + f::F, + prep::GradientPrep, + backend::AbstractADType, + rewrap::Rewrap{C}, + unannotated_contexts::Vararg{Any,C}, +) where {F,C} + gradient!(f, grad, prep, backend, x, rewrap(unannotated_contexts...)...) + return nothing +end diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index ecdc88445..df5d4aa13 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -14,10 +14,16 @@ Wrapper which forces a given backend to act as a reverse-mode backend. Used in internal testing. 
""" -struct AutoReverseFromPrimitive{B} <: FromPrimitive +struct AutoReverseFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive backend::B end +function AutoReverseFromPrimitive(backend::AbstractADType; inplace=false) + return AutoReverseFromPrimitive{inplace,typeof(backend)}(backend) +end + +inplace_support(::AutoReverseFromPrimitive{true}) = InPlaceSupported() +inplace_support(::AutoReverseFromPrimitive{false}) = InPlaceNotSupported() ADTypes.mode(::AutoReverseFromPrimitive) = ADTypes.ReverseMode() function threshold_batchsize(fromprim::AutoReverseFromPrimitive, dimension::Integer) diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 3fd696daa..e0445dfa4 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -65,9 +65,11 @@ end ## Forward over forward -struct ForwardOverForwardHVPPrep{E2<:PushforwardPrep} <: HVPPrep +struct ForwardOverForwardHVPPrep{G,PO<:PushforwardPrep,PI<:PushforwardPrep} <: HVPPrep # pushforward of many pushforwards in theory, but pushforward of gradient in practice - outer_pushforward_prep::E2 + grad_buffer::G + outer_pushforward_prep::PO + outer_pushforward_in_prep::PI end function _prepare_hvp_aux( @@ -82,10 +84,16 @@ function _prepare_hvp_aux( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) + grad_buffer = similar(x) outer_pushforward_prep = prepare_pushforward( shuffled_gradient, outer(backend), x, tx, new_contexts... ) - return ForwardOverForwardHVPPrep(outer_pushforward_prep) + outer_pushforward_in_prep = prepare_pushforward( + shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... 
+ ) + return ForwardOverForwardHVPPrep( + grad_buffer, outer_pushforward_prep, outer_pushforward_in_prep + ) end function hvp( @@ -114,6 +122,48 @@ function hvp!( x, tx::NTuple, contexts::Vararg{Context,C}, +) where {F,C} + return _hvp_aux!( + inplace_support(outer(backend)), f, tg, prep, backend, x, tx, contexts... + ) +end + +function _hvp_aux!( + ::InPlaceSupported, + f::F, + tg::NTuple, + prep::ForwardOverForwardHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + (; grad_buffer, outer_pushforward_in_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) + return pushforward!( + shuffled_gradient!, + grad_buffer, + tg, + outer_pushforward_in_prep, + outer(backend), + x, + tx, + new_contexts..., + ) +end + +function _hvp_aux!( + ::InPlaceNotSupported, + f::F, + tg::NTuple, + prep::ForwardOverForwardHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) @@ -158,6 +208,51 @@ function gradient_and_hvp!( x, tx::NTuple, contexts::Vararg{Context,C}, +) where {F,C} + return _gradient_and_hvp_aux!( + inplace_support(outer(backend)), f, grad, tg, prep, backend, x, tx, contexts... + ) +end + +function _gradient_and_hvp_aux!( + ::InPlaceSupported, + f::F, + grad, + tg::NTuple, + prep::ForwardOverForwardHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + (; outer_pushforward_in_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... 
+ ) + value_and_pushforward!( + shuffled_gradient!, + grad, + tg, + outer_pushforward_in_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + return grad, tg +end + +function _gradient_and_hvp_aux!( + ::InPlaceNotSupported, + f::F, + grad, + tg::NTuple, + prep::ForwardOverForwardHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) @@ -178,9 +273,11 @@ end ## Forward over reverse -struct ForwardOverReverseHVPPrep{E2<:PushforwardPrep} <: HVPPrep +struct ForwardOverReverseHVPPrep{G,PO<:PushforwardPrep,PI<:PushforwardPrep} <: HVPPrep # pushforward of gradient - outer_pushforward_prep::E2 + grad_buffer::G + outer_pushforward_prep::PO + outer_pushforward_in_prep::PI end function _prepare_hvp_aux( @@ -195,10 +292,16 @@ function _prepare_hvp_aux( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) + grad_buffer = similar(x) outer_pushforward_prep = prepare_pushforward( shuffled_gradient, outer(backend), x, tx, new_contexts... ) - return ForwardOverReverseHVPPrep(outer_pushforward_prep) + outer_pushforward_in_prep = prepare_pushforward( + shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... + ) + return ForwardOverReverseHVPPrep( + grad_buffer, outer_pushforward_prep, outer_pushforward_in_prep + ) end function hvp( @@ -227,6 +330,48 @@ function hvp!( x, tx::NTuple, contexts::Vararg{Context,C}, +) where {F,C} + return _hvp_aux!( + inplace_support(outer(backend)), f, tg, prep, backend, x, tx, contexts... + ) +end + +function _hvp_aux!( + ::InPlaceSupported, + f::F, + tg::NTuple, + prep::ForwardOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + (; grad_buffer, outer_pushforward_in_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... 
+ ) + return pushforward!( + shuffled_gradient!, + grad_buffer, + tg, + outer_pushforward_in_prep, + outer(backend), + x, + tx, + new_contexts..., + ) +end + +function _hvp_aux!( + ::InPlaceNotSupported, + f::F, + tg::NTuple, + prep::ForwardOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) @@ -271,6 +416,51 @@ function gradient_and_hvp!( x, tx::NTuple, contexts::Vararg{Context,C}, +) where {F,C} + return _gradient_and_hvp_aux!( + inplace_support(outer(backend)), f, grad, tg, prep, backend, x, tx, contexts... + ) +end + +function _gradient_and_hvp_aux!( + ::InPlaceSupported, + f::F, + grad, + tg::NTuple, + prep::ForwardOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + (; outer_pushforward_in_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) + value_and_pushforward!( + shuffled_gradient!, + grad, + tg, + outer_pushforward_in_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + return grad, tg +end + +function _gradient_and_hvp_aux!( + ::InPlaceNotSupported, + f::F, + grad, + tg::NTuple, + prep::ForwardOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) 
@@ -291,10 +481,10 @@ end ## Reverse over forward -struct ReverseOverForwardHVPPrep{E2<:GradientPrep,E1<:GradientPrep} <: HVPPrep +struct ReverseOverForwardHVPPrep{G2<:GradientPrep,G1<:GradientPrep} <: HVPPrep # gradient of pushforward - outer_gradient_prep::E2 - gradient_prep::E1 + outer_gradient_prep::G2 + gradient_prep::G1 end function _prepare_hvp_aux( @@ -404,9 +594,11 @@ end ## Reverse over reverse -struct ReverseOverReverseHVPPrep{E2<:PullbackPrep} <: HVPPrep +struct ReverseOverReverseHVPPrep{G,PO<:PullbackPrep,PI<:PullbackPrep} <: HVPPrep # pullback of gradient - outer_pullback_prep::E2 + grad_buffer::G + outer_pullback_prep::PO + outer_pullback_in_prep::PI end function _prepare_hvp_aux( @@ -421,10 +613,16 @@ function _prepare_hvp_aux( new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) + grad_buffer = similar(x) outer_pullback_prep = prepare_pullback( shuffled_gradient, outer(backend), x, tx, new_contexts... ) - return ReverseOverReverseHVPPrep(outer_pullback_prep) + outer_pullback_in_prep = prepare_pullback( + shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... + ) + return ReverseOverReverseHVPPrep( + grad_buffer, outer_pullback_prep, outer_pullback_in_prep + ) end function hvp( @@ -453,6 +651,48 @@ function hvp!( x, tx::NTuple, contexts::Vararg{Context,C}, +) where {F,C} + return _hvp_aux!( + inplace_support(outer(backend)), f, tg, prep, backend, x, tx, contexts... + ) +end + +function _hvp_aux!( + ::InPlaceSupported, + f::F, + tg::NTuple, + prep::ReverseOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + (; grad_buffer, outer_pullback_in_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... 
+ ) + return pullback!( + shuffled_gradient!, + grad_buffer, + tg, + outer_pullback_in_prep, + outer(backend), + x, + tx, + new_contexts..., + ) +end + +function _hvp_aux!( + ::InPlaceNotSupported, + f::F, + tg::NTuple, + prep::ReverseOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, ) where {F,C} (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) @@ -491,6 +731,51 @@ function gradient_and_hvp!( x, tx::NTuple, contexts::Vararg{Context,C}, +) where {F,C} + return _gradient_and_hvp_aux!( + inplace_support(outer(backend)), f, grad, tg, prep, backend, x, tx, contexts... + ) +end + +function _gradient_and_hvp_aux!( + ::InPlaceSupported, + f::F, + grad, + tg::NTuple, + prep::ReverseOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + (; outer_pullback_in_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) + new_grad, _ = value_and_pullback!( + shuffled_gradient!, + grad, + tg, + outer_pullback_in_prep, + outer(backend), + x, + tx, + new_contexts..., + ) + return grad, tg +end + +function _gradient_and_hvp_aux!( + ::InPlaceNotSupported, + f::F, + grad, + tg::NTuple, + prep::ReverseOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, ) where {F,C} (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) 
diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index e1c2974cd..708f83e2a 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -9,6 +9,7 @@ LOGGING = get(ENV, "CI", "false") == "false" backends = [ # AutoSimpleFiniteDiff(; chunksize=5), AutoReverseFromPrimitive(AutoSimpleFiniteDiff(; chunksize=4)), + AutoReverseFromPrimitive(AutoSimpleFiniteDiff(; chunksize=3); inplace=false), ] second_order_backends = [ # From 0fd107ca0ff46de42d96fc6c215e095f3aa2c8f7 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 14 Mar 2025 19:57:24 +0100 Subject: [PATCH 02/15] Shuffled gradient without prep --- .../src/first_order/gradient.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index eda75f359..64f8aaca4 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -133,6 +133,19 @@ function shuffled_gradient( return gradient(f, prep, backend, x, rewrap(unannotated_contexts...)...) end +function shuffled_gradient!( + grad, + x, + f::F, + backend::AbstractADType, + rewrap::Rewrap{C}, + unannotated_contexts::Vararg{Any,C}, +) where {F,C} + gradient!(f, grad, backend, x, rewrap(unannotated_contexts...)...) + return nothing +end + +#= function shuffled_gradient!( grad, x, @@ -145,3 +158,4 @@ function shuffled_gradient!( gradient!(f, grad, prep, backend, x, rewrap(unannotated_contexts...)...) 
return nothing end +=# From f863413a197beeaf90c9d74d02349cb3c387468f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 14 Mar 2025 20:14:14 +0100 Subject: [PATCH 03/15] Inplace true by default --- DifferentiationInterface/src/misc/from_primitive.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index df5d4aa13..7398a2861 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -18,7 +18,7 @@ struct AutoReverseFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive backend::B end -function AutoReverseFromPrimitive(backend::AbstractADType; inplace=false) +function AutoReverseFromPrimitive(backend::AbstractADType; inplace=true) return AutoReverseFromPrimitive{inplace,typeof(backend)}(backend) end From eff8ff54234fa483969063cbd3f48da0715e4bb0 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 14 Mar 2025 21:00:51 +0100 Subject: [PATCH 04/15] Fix FromPrimitive tests --- .../src/misc/from_primitive.jl | 116 ++++++++- .../src/second_order/hvp.jl | 232 +----------------- DifferentiationInterface/src/utils/traits.jl | 2 + .../test/Core/SimpleFiniteDiff/test.jl | 31 ++- 4 files changed, 152 insertions(+), 229 deletions(-) diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index 7398a2861..731f27c99 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -1,12 +1,112 @@ -abstract type FromPrimitive <: AbstractADType end +abstract type FromPrimitive{inplace} <: AbstractADType end check_available(fromprim::FromPrimitive) = check_available(fromprim.backend) -inplace_support(fromprim::FromPrimitive) = inplace_support(fromprim.backend) 
+inplace_support(::FromPrimitive{true}) = InPlaceSupported()
+inplace_support(::FromPrimitive{false}) = InPlaceNotSupported()
 
 function pick_batchsize(fromprim::FromPrimitive, N::Integer)
     return pick_batchsize(fromprim.backend, N)
 end
 
+"""
+    AutoForwardFromPrimitive
+
+Wrapper which forces a given backend to act as a forward-mode backend.
+
+Used in internal testing.
+"""
+struct AutoForwardFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace}
+    backend::B
+end
+
+function AutoForwardFromPrimitive(backend::AbstractADType; inplace=true)
+    return AutoForwardFromPrimitive{inplace,typeof(backend)}(backend)
+end
+
+ADTypes.mode(::AutoForwardFromPrimitive) = ADTypes.ForwardMode()
+
+function threshold_batchsize(
+    fromprim::AutoForwardFromPrimitive{inplace}, dimension::Integer
+) where {inplace}
+    return AutoForwardFromPrimitive(
+        threshold_batchsize(fromprim.backend, dimension); inplace
+    )
+end
+
+struct FromPrimitivePushforwardPrep{E<:PushforwardPrep} <: PushforwardPrep
+    pushforward_prep::E
+end
+
+function prepare_pushforward(
+    f::F, fromprim::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}
+) where {F,C}
+    primitive_prep = prepare_pushforward(f, fromprim.backend, x, tx, contexts...)
+    return FromPrimitivePushforwardPrep(primitive_prep)
+end
+
+function prepare_pushforward(
+    f!::F, y, fromprim::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}
+) where {F,C}
+    primitive_prep = prepare_pushforward(f!, y, fromprim.backend, x, tx, contexts...)
+    return FromPrimitivePushforwardPrep(primitive_prep)
+end
+
+function value_and_pushforward(
+    f::F,
+    prep::FromPrimitivePushforwardPrep,
+    fromprim::AutoForwardFromPrimitive,
+    x,
+    tx::NTuple,
+    contexts::Vararg{Context,C},
+) where {F,C}
+    return value_and_pushforward(
+        f, prep.pushforward_prep, fromprim.backend, x, tx, contexts...
+ ) +end + +function value_and_pushforward( + f!::F, + y, + prep::FromPrimitivePushforwardPrep, + fromprim::AutoForwardFromPrimitive, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + return value_and_pushforward( + f!, y, prep.pushforward_prep, fromprim.backend, x, tx, contexts... + ) +end + +function value_and_pushforward!( + f::F, + ty::NTuple, + prep::FromPrimitivePushforwardPrep, + fromprim::AutoForwardFromPrimitive, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + return value_and_pushforward!( + f, ty, prep.pushforward_prep, fromprim.backend, x, tx, contexts... + ) +end + +function value_and_pushforward!( + f!::F, + y, + ty::NTuple, + prep::FromPrimitivePushforwardPrep, + fromprim::AutoForwardFromPrimitive, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + return value_and_pushforward!( + f!, y, ty, prep.pushforward_prep, fromprim.backend, x, tx, contexts... + ) +end + """ AutoReverseFromPrimitive @@ -14,7 +114,7 @@ Wrapper which forces a given backend to act as a reverse-mode backend. Used in internal testing. 
""" -struct AutoReverseFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive +struct AutoReverseFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace} backend::B end @@ -22,12 +122,14 @@ function AutoReverseFromPrimitive(backend::AbstractADType; inplace=true) return AutoReverseFromPrimitive{inplace,typeof(backend)}(backend) end -inplace_support(::AutoReverseFromPrimitive{true}) = InPlaceSupported() -inplace_support(::AutoReverseFromPrimitive{false}) = InPlaceNotSupported() ADTypes.mode(::AutoReverseFromPrimitive) = ADTypes.ReverseMode() -function threshold_batchsize(fromprim::AutoReverseFromPrimitive, dimension::Integer) - return AutoReverseFromPrimitive(threshold_batchsize(fromprim.backend, dimension)) +function threshold_batchsize( + fromprim::AutoReverseFromPrimitive{inplace}, dimension::Integer +) where {inplace} + return AutoReverseFromPrimitive( + threshold_batchsize(fromprim.backend, dimension); inplace + ) end struct FromPrimitivePullbackPrep{E<:PullbackPrep} <: PullbackPrep diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index e0445dfa4..26049ec30 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -63,9 +63,9 @@ function prepare_hvp( return _prepare_hvp_aux(hvp_mode(backend), f, backend, x, tx, contexts...) 
end -## Forward over forward +## Forward over anything -struct ForwardOverForwardHVPPrep{G,PO<:PushforwardPrep,PI<:PushforwardPrep} <: HVPPrep +struct ForwardOverAnythingPrep{G,PO<:PushforwardPrep,PI<:PushforwardPrep} <: HVPPrep # pushforward of many pushforwards in theory, but pushforward of gradient in practice grad_buffer::G outer_pushforward_prep::PO @@ -73,7 +73,7 @@ struct ForwardOverForwardHVPPrep{G,PO<:PushforwardPrep,PI<:PushforwardPrep} <: H end function _prepare_hvp_aux( - ::ForwardOverForward, + ::ForwardOverAnything, f::F, backend::AbstractADType, x, @@ -91,14 +91,14 @@ function _prepare_hvp_aux( outer_pushforward_in_prep = prepare_pushforward( shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... ) - return ForwardOverForwardHVPPrep( + return ForwardOverAnythingPrep( grad_buffer, outer_pushforward_prep, outer_pushforward_in_prep ) end function hvp( f::F, - prep::ForwardOverForwardHVPPrep, + prep::ForwardOverAnythingPrep, backend::AbstractADType, x, tx::NTuple, @@ -117,7 +117,7 @@ end function hvp!( f::F, tg::NTuple, - prep::ForwardOverForwardHVPPrep, + prep::ForwardOverAnythingPrep, backend::AbstractADType, x, tx::NTuple, @@ -132,7 +132,7 @@ function _hvp_aux!( ::InPlaceSupported, f::F, tg::NTuple, - prep::ForwardOverForwardHVPPrep, + prep::ForwardOverAnythingPrep, backend::AbstractADType, x, tx::NTuple, @@ -159,7 +159,7 @@ function _hvp_aux!( ::InPlaceNotSupported, f::F, tg::NTuple, - prep::ForwardOverForwardHVPPrep, + prep::ForwardOverAnythingPrep, backend::AbstractADType, x, tx::NTuple, @@ -183,7 +183,7 @@ end function gradient_and_hvp( f::F, - prep::ForwardOverForwardHVPPrep, + prep::ForwardOverAnythingPrep, backend::AbstractADType, x, tx::NTuple, @@ -203,7 +203,7 @@ function gradient_and_hvp!( f::F, grad, tg::NTuple, - prep::ForwardOverForwardHVPPrep, + prep::ForwardOverAnythingPrep, backend::AbstractADType, x, tx::NTuple, @@ -219,7 +219,7 @@ function _gradient_and_hvp_aux!( f::F, grad, tg::NTuple, - 
prep::ForwardOverForwardHVPPrep, + prep::ForwardOverAnythingPrep, backend::AbstractADType, x, tx::NTuple, @@ -248,215 +248,7 @@ function _gradient_and_hvp_aux!( f::F, grad, tg::NTuple, - prep::ForwardOverForwardHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - (; outer_pushforward_prep) = prep - rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) - new_grad, _ = value_and_pushforward!( - shuffled_gradient, - tg, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - return copyto!(grad, new_grad), tg -end - -## Forward over reverse - -struct ForwardOverReverseHVPPrep{G,PO<:PushforwardPrep,PI<:PushforwardPrep} <: HVPPrep - # pushforward of gradient - grad_buffer::G - outer_pushforward_prep::PO - outer_pushforward_in_prep::PI -end - -function _prepare_hvp_aux( - ::ForwardOverReverse, - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) - grad_buffer = similar(x) - outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts... - ) - outer_pushforward_in_prep = prepare_pushforward( - shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... - ) - return ForwardOverReverseHVPPrep( - grad_buffer, outer_pushforward_prep, outer_pushforward_in_prep - ) -end - -function hvp( - f::F, - prep::ForwardOverReverseHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - (; outer_pushforward_prep) = prep - rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... 
- ) - return pushforward( - shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... - ) -end - -function hvp!( - f::F, - tg::NTuple, - prep::ForwardOverReverseHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - return _hvp_aux!( - inplace_support(outer(backend)), f, tg, prep, backend, x, tx, contexts... - ) -end - -function _hvp_aux!( - ::InPlaceSupported, - f::F, - tg::NTuple, - prep::ForwardOverReverseHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - (; grad_buffer, outer_pushforward_in_prep) = prep - rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) - return pushforward!( - shuffled_gradient!, - grad_buffer, - tg, - outer_pushforward_in_prep, - outer(backend), - x, - tx, - new_contexts..., - ) -end - -function _hvp_aux!( - ::InPlaceNotSupported, - f::F, - tg::NTuple, - prep::ForwardOverReverseHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - (; outer_pushforward_prep) = prep - rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) - return pushforward!( - shuffled_gradient, - tg, - outer_pushforward_prep, - outer(backend), - x, - tx, - new_contexts..., - ) -end - -function gradient_and_hvp( - f::F, - prep::ForwardOverReverseHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - (; outer_pushforward_prep) = prep - rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) - return value_and_pushforward( - shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... 
- ) -end - -function gradient_and_hvp!( - f::F, - grad, - tg::NTuple, - prep::ForwardOverReverseHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - return _gradient_and_hvp_aux!( - inplace_support(outer(backend)), f, grad, tg, prep, backend, x, tx, contexts... - ) -end - -function _gradient_and_hvp_aux!( - ::InPlaceSupported, - f::F, - grad, - tg::NTuple, - prep::ForwardOverReverseHVPPrep, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - (; outer_pushforward_in_prep) = prep - rewrap = Rewrap(contexts...) - new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... - ) - value_and_pushforward!( - shuffled_gradient!, - grad, - tg, - outer_pushforward_in_prep, - outer(backend), - x, - tx, - new_contexts..., - ) - return grad, tg -end - -function _gradient_and_hvp_aux!( - ::InPlaceNotSupported, - f::F, - grad, - tg::NTuple, - prep::ForwardOverReverseHVPPrep, + prep::ForwardOverAnythingPrep, backend::AbstractADType, x, tx::NTuple, diff --git a/DifferentiationInterface/src/utils/traits.jl b/DifferentiationInterface/src/utils/traits.jl index f684aea36..7cf00b547 100644 --- a/DifferentiationInterface/src/utils/traits.jl +++ b/DifferentiationInterface/src/utils/traits.jl @@ -141,6 +141,8 @@ Traits identifying second-order backends that compute HVPs in forward over forwa """ struct ForwardOverForward <: HVPMode end +const ForwardOverAnything = Union{ForwardOverForward,ForwardOverReverse} + """ hvp_mode(backend) diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index 708f83e2a..eb24a0041 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -1,6 +1,9 @@ using DifferentiationInterface, DifferentiationInterfaceTest using DifferentiationInterface: - AutoSimpleFiniteDiff, 
AutoReverseFromPrimitive, DenseSparsityDetector + AutoSimpleFiniteDiff, + AutoForwardFromPrimitive, + AutoReverseFromPrimitive, + DenseSparsityDetector using SparseMatrixColorings using Test @@ -9,7 +12,6 @@ LOGGING = get(ENV, "CI", "false") == "false" backends = [ # AutoSimpleFiniteDiff(; chunksize=5), AutoReverseFromPrimitive(AutoSimpleFiniteDiff(; chunksize=4)), - AutoReverseFromPrimitive(AutoSimpleFiniteDiff(; chunksize=3); inplace=false), ] second_order_backends = [ # @@ -23,6 +25,25 @@ second_order_backends = [ # ), ] +second_order_hvp_backends = [ # + SecondOrder( + AutoReverseFromPrimitive(AutoSimpleFiniteDiff(); inplace=false), + AutoForwardFromPrimitive(AutoSimpleFiniteDiff()), + ), + SecondOrder( + AutoForwardFromPrimitive(AutoSimpleFiniteDiff(); inplace=false), + AutoReverseFromPrimitive(AutoSimpleFiniteDiff();), + ), + SecondOrder( + AutoForwardFromPrimitive(AutoSimpleFiniteDiff(); inplace=false), + AutoForwardFromPrimitive(AutoSimpleFiniteDiff();), + ), + SecondOrder( + AutoReverseFromPrimitive(AutoSimpleFiniteDiff(); inplace=false), + AutoReverseFromPrimitive(AutoSimpleFiniteDiff();), + ), +] + adaptive_backends = [ # AutoSimpleFiniteDiff(), AutoReverseFromPrimitive(AutoSimpleFiniteDiff()), @@ -44,6 +65,12 @@ end logging=LOGGING, ) + test_differentiation( + second_order_hvp_backends; + excluded=vcat(FIRST_ORDER, :hessian, :second_derivative), + logging=true, + ) + test_differentiation(backends, complex_scenarios(); logging=LOGGING) end From f5c6ad29890345a7db7401104e05710a700cfeaa Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 14 Mar 2025 21:05:44 +0100 Subject: [PATCH 05/15] Fix Zygote --- .../DifferentiationInterfaceZygoteExt.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl 
b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index c5214c666..2d0ab5cfc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -159,7 +159,7 @@ end function DI.hvp( f, - prep::DI.ForwardOverReverseHVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoZygote, x, tx::NTuple, @@ -171,7 +171,7 @@ end function DI.hvp!( f, tg::NTuple, - prep::DI.ForwardOverReverseHVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoZygote, x, tx::NTuple, @@ -184,7 +184,7 @@ end function DI.gradient_and_hvp( f, - prep::DI.ForwardOverReverseHVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoZygote, x, tx::NTuple, @@ -199,7 +199,7 @@ function DI.gradient_and_hvp!( f, grad, tg::NTuple, - prep::DI.ForwardOverReverseHVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoZygote, x, tx::NTuple, From 6ab9b938eb8cb3cdb95747d57f728f857e5ed8f3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 14 Mar 2025 21:10:29 +0100 Subject: [PATCH 06/15] Fix --- .../src/second_order/hvp.jl | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 26049ec30..ffdd095c0 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -65,7 +65,7 @@ end ## Forward over anything -struct ForwardOverAnythingPrep{G,PO<:PushforwardPrep,PI<:PushforwardPrep} <: HVPPrep +struct ForwardOverAnythingHVPPrep{G,PO<:PushforwardPrep,PI<:PushforwardPrep} <: HVPPrep # pushforward of many pushforwards in theory, but pushforward of gradient in practice grad_buffer::G outer_pushforward_prep::PO @@ -91,14 +91,14 @@ function _prepare_hvp_aux( 
outer_pushforward_in_prep = prepare_pushforward( shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... ) - return ForwardOverAnythingPrep( + return ForwardOverAnythingHVPPrep( grad_buffer, outer_pushforward_prep, outer_pushforward_in_prep ) end function hvp( f::F, - prep::ForwardOverAnythingPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, @@ -117,7 +117,7 @@ end function hvp!( f::F, tg::NTuple, - prep::ForwardOverAnythingPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, @@ -132,7 +132,7 @@ function _hvp_aux!( ::InPlaceSupported, f::F, tg::NTuple, - prep::ForwardOverAnythingPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, @@ -159,7 +159,7 @@ function _hvp_aux!( ::InPlaceNotSupported, f::F, tg::NTuple, - prep::ForwardOverAnythingPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, @@ -183,7 +183,7 @@ end function gradient_and_hvp( f::F, - prep::ForwardOverAnythingPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, @@ -203,7 +203,7 @@ function gradient_and_hvp!( f::F, grad, tg::NTuple, - prep::ForwardOverAnythingPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, @@ -219,7 +219,7 @@ function _gradient_and_hvp_aux!( f::F, grad, tg::NTuple, - prep::ForwardOverAnythingPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, @@ -248,7 +248,7 @@ function _gradient_and_hvp_aux!( f::F, grad, tg::NTuple, - prep::ForwardOverAnythingPrep, + prep::ForwardOverAnythingHVPPrep, backend::AbstractADType, x, tx::NTuple, From 2cc261f497b9d2a6ea954a4c6df4b864c37860e1 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 14 Mar 2025 22:27:17 +0100 Subject: [PATCH 07/15] Avoid prep --- .../src/second_order/hvp.jl | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git 
a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index ffdd095c0..b4a6e749d 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -65,7 +65,9 @@ end ## Forward over anything -struct ForwardOverAnythingHVPPrep{G,PO<:PushforwardPrep,PI<:PushforwardPrep} <: HVPPrep +struct ForwardOverAnythingHVPPrep{ + G,PO<:PushforwardPrep,PI<:Union{Nothing,PushforwardPrep} +} <: HVPPrep # pushforward of many pushforwards in theory, but pushforward of gradient in practice grad_buffer::G outer_pushforward_prep::PO @@ -88,9 +90,13 @@ function _prepare_hvp_aux( outer_pushforward_prep = prepare_pushforward( shuffled_gradient, outer(backend), x, tx, new_contexts... ) - outer_pushforward_in_prep = prepare_pushforward( - shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... - ) + outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported + prepare_pushforward( + shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... + ) + else + nothing + end return ForwardOverAnythingHVPPrep( grad_buffer, outer_pushforward_prep, outer_pushforward_in_prep ) @@ -386,7 +392,8 @@ end ## Reverse over reverse -struct ReverseOverReverseHVPPrep{G,PO<:PullbackPrep,PI<:PullbackPrep} <: HVPPrep +struct ReverseOverReverseHVPPrep{G,PO<:PullbackPrep,PI<:Union{Nothing,PullbackPrep}} <: + HVPPrep # pullback of gradient grad_buffer::G outer_pullback_prep::PO @@ -409,9 +416,13 @@ function _prepare_hvp_aux( outer_pullback_prep = prepare_pullback( shuffled_gradient, outer(backend), x, tx, new_contexts... ) - outer_pullback_in_prep = prepare_pullback( - shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... - ) + outer_pullback_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported + prepare_pullback( + shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts... 
+ ) + else + nothing + end return ReverseOverReverseHVPPrep( grad_buffer, outer_pullback_prep, outer_pullback_in_prep ) From fb524418ae0c2e738b1328c5455a0fb2703bfd59 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 14 Mar 2025 23:13:53 +0100 Subject: [PATCH 08/15] Inner HVP preparation --- .../reverse_onearg.jl | 14 +- .../DifferentiationInterfaceForwardDiffExt.jl | 1 + .../misc.jl | 27 +++ .../onearg.jl | 24 +-- .../twoarg.jl | 20 +- .../utils.jl | 2 +- .../DifferentiationInterfaceTrackerExt.jl | 28 +-- .../DifferentiationInterfaceZygoteExt.jl | 14 +- .../src/first_order/gradient.jl | 2 - .../src/misc/from_primitive.jl | 3 + .../src/misc/simple_finite_diff.jl | 1 + .../src/second_order/hvp.jl | 196 ++++++++++++++++-- DifferentiationInterface/src/utils/context.jl | 17 +- DifferentiationInterface/src/utils/traits.jl | 12 ++ .../test/Back/ForwardDiff/test.jl | 6 +- 15 files changed, 278 insertions(+), 89 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index 04d3f55f1..079e5dd1a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -6,11 +6,7 @@ struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep end function DI.prepare_pullback( - f, - ::AutoReverseChainRules, - x, - ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoPullbackPrep() end @@ -21,7 +17,7 @@ function DI.prepare_pullback_same_point( backend::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} rc = ruleconfig(backend) y, pb = 
rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) @@ -34,7 +30,7 @@ function DI.value_and_pullback( backend::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) @@ -50,7 +46,7 @@ function DI.value_and_pullback( ::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -65,7 +61,7 @@ function DI.pullback( ::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; pb) = prep tx = map(ty) do dy diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index b978c065a..2e46031a0 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -28,6 +28,7 @@ using ForwardDiff: value DI.check_available(::AutoForwardDiff) = true +DI.inner_preparation_behavior(::AutoForwardDiff) = DI.PrepareInnerOverload() include("utils.jl") include("onearg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl index d2b76a6d1..54e308485 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl @@ -2,6 +2,33 @@ DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep) = typeof(prep.xdual_tmp) 
DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep) = typeof(prep.xdual_tmp) +function DI.overloaded_input( + ::typeof(DI.pushforward), + f::F, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context,C}, +) where {F,B,C} + T = tag_type(f, backend, x) + xdual = make_dual(T, x, tx) + return xdual +end + +function DI.overloaded_input( + ::typeof(DI.pushforward), + f!::F, + y, + backend::AutoForwardDiff, + x, + tx::NTuple{B}, + contexts::Vararg{DI.Context,C}, +) where {F,B,C} + T = tag_type(f!, backend, x) + xdual = make_dual(T, x, tx) + return xdual +end + ## Derivative function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep) return DI.overloaded_input_type(prep.pushforward_prep) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index b405cbb8d..43c0f44be 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -272,7 +272,7 @@ function DI.value_and_gradient!( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc = DI.with_contexts(f, contexts...) result = DiffResult(zero(eltype(x)), (grad,)) @@ -292,7 +292,7 @@ function DI.value_and_gradient( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc = DI.with_contexts(f, contexts...) result = GradientResult(x) @@ -310,7 +310,7 @@ function DI.gradient!( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc = DI.with_contexts(f, contexts...) 
return gradient!(grad, fc, x) @@ -326,7 +326,7 @@ function DI.gradient( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc = DI.with_contexts(f, contexts...) return gradient(fc, x) @@ -435,7 +435,7 @@ function DI.value_and_jacobian!( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc = DI.with_contexts(f, contexts...) y = fc(x) @@ -456,7 +456,7 @@ function DI.value_and_jacobian( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc = DI.with_contexts(f, contexts...) return fc(x), jacobian(fc, x) @@ -472,7 +472,7 @@ function DI.jacobian!( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc = DI.with_contexts(f, contexts...) return jacobian!(jac, fc, x) @@ -488,7 +488,7 @@ function DI.jacobian( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc = DI.with_contexts(f, contexts...) return jacobian(fc, x) @@ -738,7 +738,7 @@ function DI.hessian!( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc = DI.with_contexts(f, contexts...) return hessian!(hess, fc, x) @@ -754,7 +754,7 @@ function DI.hessian( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc = DI.with_contexts(f, contexts...) 
return hessian(fc, x) @@ -775,7 +775,7 @@ function DI.value_gradient_and_hessian!( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc = DI.with_contexts(f, contexts...) result = DiffResult(one(eltype(x)), (grad, hess)) @@ -796,7 +796,7 @@ function DI.value_gradient_and_hessian( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc = DI.with_contexts(f, contexts...) result = HessianResult(x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index 8acb1d9a0..f52ae69d9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -117,7 +117,7 @@ end function DI.value_and_derivative( f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} ) where {F,C,chunksize,T} - if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}) + if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant}) fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (similar(y),)) result = derivative!(result, fc!, y, x) @@ -131,7 +131,7 @@ end function DI.value_and_derivative!( f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} ) where {F,C,chunksize,T} - if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}) + if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant}) fc! = DI.with_contexts(f!, contexts...) 
result = MutableDiffResult(y, (der,)) result = derivative!(result, fc!, y, x) @@ -145,7 +145,7 @@ end function DI.derivative( f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} ) where {F,C,chunksize,T} - if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}) + if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant}) fc! = DI.with_contexts(f!, contexts...) return derivative(fc!, y, x) else @@ -157,7 +157,7 @@ end function DI.derivative!( f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} ) where {F,C,chunksize,T} - if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}) + if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant}) fc! = DI.with_contexts(f!, contexts...) return derivative!(der, fc!, y, x) else @@ -188,7 +188,7 @@ function DI.prepare!_derivative( old_prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} if y isa Vector (; config) = old_prep @@ -283,7 +283,7 @@ function DI.value_and_jacobian( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc! = DI.with_contexts(f!, contexts...) jac = similar(y, length(y), length(x)) @@ -302,7 +302,7 @@ function DI.value_and_jacobian!( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (jac,)) @@ -320,7 +320,7 @@ function DI.jacobian( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc! = DI.with_contexts(f!, contexts...) 
return jacobian(fc!, y, x) @@ -336,7 +336,7 @@ function DI.jacobian!( if ( isnothing(chunksize) && T === Nothing && - contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend} + contexts isa NTuple{C,DI.GeneralizedConstant} ) fc! = DI.with_contexts(f!, contexts...) return jacobian!(jac, fc!, y, x) @@ -369,7 +369,7 @@ function DI.prepare!_jacobian( old_prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {F,C} if x isa Vector && y isa Vector (; config) = old_prep diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index ed8393576..fb5d86474 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -87,7 +87,7 @@ struct PrepContext{T<:DI.Prep} <: DI.Context data::T end -NotCache = Union{DI.ConstantOrFunctionOrBackend,PrepContext} +NotCache = Union{DI.GeneralizedConstant,PrepContext} _translate(::Type{D}, c::NotCache) where {D<:Dual} = DI.unwrap(c) function _translate(::Type{D}, c::DI.Cache) where {D<:Dual} diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index fb5da6b76..17d885ca7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -15,7 +15,7 @@ struct TrackerPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep end function DI.prepare_pullback( - f, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoTracker, x, ty::NTuple, 
contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoPullbackPrep() end @@ -26,7 +26,7 @@ function DI.prepare_pullback_same_point( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, pb = forward(f, x, map(DI.unwrap, contexts)...) return TrackerPullbackPrepSamePoint(y, pb) @@ -38,7 +38,7 @@ function DI.value_and_pullback( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, pb = forward(f, x, map(DI.unwrap, contexts)...) tx = map(ty) do dy @@ -53,7 +53,7 @@ function DI.value_and_pullback( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -68,7 +68,7 @@ function DI.pullback( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} (; pb) = prep tx = map(ty) do dy @@ -80,28 +80,20 @@ end ## Gradient function DI.prepare_gradient( - f, ::AutoTracker, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoGradientPrep() end function DI.value_and_gradient( - f, - ::DI.NoGradientPrep, - ::AutoTracker, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} (; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return val, data(first(grad)) end function DI.gradient( - f, - ::DI.NoGradientPrep, - ::AutoTracker, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} (; grad) = withgradient(f, x, map(DI.unwrap, contexts)...) 
return data(first(grad)) @@ -113,7 +105,7 @@ function DI.value_and_gradient!( prep::DI.NoGradientPrep, backend::AutoTracker, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) @@ -125,7 +117,7 @@ function DI.gradient!( prep::DI.NoGradientPrep, backend::AutoTracker, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 2d0ab5cfc..adf1c397e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -213,17 +213,13 @@ end ## Hessian function DI.prepare_hessian( - f, ::AutoZygote, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} return DI.NoHessianPrep() end function DI.hessian( - f, - ::DI.NoHessianPrep, - ::AutoZygote, - x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + f, ::DI.NoHessianPrep, ::AutoZygote, x, contexts::Vararg{DI.GeneralizedConstant,C} ) where {C} fc = DI.with_contexts(f, contexts...) 
hess = hessian(fc, x) @@ -236,7 +232,7 @@ function DI.hessian!( prep::DI.NoHessianPrep, backend::AutoZygote, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} return copyto!(hess, DI.hessian(f, prep, backend, x, contexts...)) end @@ -246,7 +242,7 @@ function DI.value_gradient_and_hessian( prep::DI.NoHessianPrep, backend::AutoZygote, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, grad = DI.value_and_gradient(f, DI.NoGradientPrep(), backend, x, contexts...) hess = DI.hessian(f, prep, backend, x, contexts...) @@ -260,7 +256,7 @@ function DI.value_gradient_and_hessian!( prep::DI.NoHessianPrep, backend::AutoZygote, x, - contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} y, _ = DI.value_and_gradient!(f, grad, DI.NoGradientPrep(), backend, x, contexts...) DI.hessian!(f, hess, prep, backend, x, contexts...) diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 64f8aaca4..17c39ac92 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -145,7 +145,6 @@ function shuffled_gradient!( return nothing end -#= function shuffled_gradient!( grad, x, @@ -158,4 +157,3 @@ function shuffled_gradient!( gradient!(f, grad, prep, backend, x, rewrap(unannotated_contexts...)...) 
return nothing end -=# diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index 731f27c99..496675503 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -3,6 +3,9 @@ abstract type FromPrimitive{inplace} <: AbstractADType end check_available(fromprim::FromPrimitive) = check_available(fromprim.backend) inplace_support(::FromPrimitive{true}) = InPlaceSupported() inplace_support(::FromPrimitive{false}) = InPlaceNotSupported() +function inner_preparation_behavior(fromprim::FromPrimitive) + return inner_preparation_behavior(fromprim.backend) +end function pick_batchsize(fromprim::FromPrimitive, N::Integer) return pick_batchsize(fromprim.backend, N) diff --git a/DifferentiationInterface/src/misc/simple_finite_diff.jl b/DifferentiationInterface/src/misc/simple_finite_diff.jl index 11f05ac1e..b36f36503 100644 --- a/DifferentiationInterface/src/misc/simple_finite_diff.jl +++ b/DifferentiationInterface/src/misc/simple_finite_diff.jl @@ -18,6 +18,7 @@ end ADTypes.mode(::AutoSimpleFiniteDiff) = ForwardMode() check_available(::AutoSimpleFiniteDiff) = true inplace_support(::AutoSimpleFiniteDiff) = InPlaceSupported() +inner_preparation_behavior(::AutoSimpleFiniteDiff) = PrepareInnerSimple() function pick_batchsize(::AutoSimpleFiniteDiff{nothing}, N::Integer) B = reasonable_batchsize(N, 12) diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index b4a6e749d..8e71ee4b4 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -60,33 +60,43 @@ function gradient_and_hvp! end function prepare_hvp( f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C} ) where {F,C} - return _prepare_hvp_aux(hvp_mode(backend), f, backend, x, tx, contexts...) 
+ return _prepare_hvp_aux( + hvp_mode(backend), + inner_preparation_behavior(outer(backend)), + f, + backend, + x, + tx, + contexts..., + ) end ## Forward over anything -struct ForwardOverAnythingHVPPrep{ - G,PO<:PushforwardPrep,PI<:Union{Nothing,PushforwardPrep} -} <: HVPPrep +struct ForwardOverAnythingHVPPrep{G,GO,GI,PO,PI} <: HVPPrep # pushforward of many pushforwards in theory, but pushforward of gradient in practice grad_buffer::G + maybe_inner_gradient_prep::GO + maybe_inner_gradient_in_prep::GI outer_pushforward_prep::PO outer_pushforward_in_prep::PI end function _prepare_hvp_aux( ::ForwardOverAnything, + ::DontPrepareInner, f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + grad_buffer = similar(x) rewrap = Rewrap(contexts...) + # Outer pushforward new_contexts = ( FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... ) - grad_buffer = similar(x) outer_pushforward_prep = prepare_pushforward( shuffled_gradient, outer(backend), x, tx, new_contexts... ) @@ -98,7 +108,131 @@ function _prepare_hvp_aux( nothing end return ForwardOverAnythingHVPPrep( - grad_buffer, outer_pushforward_prep, outer_pushforward_in_prep + grad_buffer, (), (), outer_pushforward_prep, outer_pushforward_in_prep + ) +end + +function _prepare_hvp_aux( + ::ForwardOverAnything, + ::PrepareInnerSimple, + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + grad_buffer = similar(x) + rewrap = Rewrap(contexts...) + # Inner gradient + inner_gradient_prep = prepare_gradient(f, inner(backend), x, contexts...) 
+ inner_gradient_in_prep = inner_gradient_prep + # Outer pushforward + new_contexts = ( + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + new_contexts_in = ( + FunctionContext(f), + PrepContext(inner_gradient_in_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + outer_pushforward_prep = prepare_pushforward( + shuffled_gradient, outer(backend), x, tx, new_contexts... + ) + outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported + prepare_pushforward( + shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts_in... + ) + else + nothing + end + return ForwardOverAnythingHVPPrep( + grad_buffer, + (inner_gradient_prep,), + (inner_gradient_in_prep,), + outer_pushforward_prep, + outer_pushforward_in_prep, + ) +end + +function _prepare_hvp_aux( + ::ForwardOverAnything, + ::PrepareInnerOverload, + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + grad_buffer = similar(x) + rewrap = Rewrap(contexts...) + # Inner gradient + new_contexts_unknown = ( + FunctionContext(f), + UnknownContext(), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + inner_gradient_prep = let + xo = overloaded_input( + pushforward, + shuffled_gradient, + outer(backend), + x, + tx, + new_contexts_unknown..., + ) + prepare_gradient(f, inner(backend), xo, contexts...) + end + inner_gradient_in_prep = let + xo = overloaded_input( + pushforward, + shuffled_gradient!, + grad_buffer, + outer(backend), + x, + tx, + new_contexts_unknown..., + ) + prepare_gradient(f, inner(backend), xo, contexts...) 
+ end + # Outer pushforward + new_contexts = ( + FunctionContext(f), + PrepContext(inner_gradient_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + new_contexts_in = ( + FunctionContext(f), + PrepContext(inner_gradient_in_prep), + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., + ) + outer_pushforward_prep = prepare_pushforward( + shuffled_gradient, outer(backend), x, tx, new_contexts... + ) + outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported + prepare_pushforward( + shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts_in... + ) + else + nothing + end + return ForwardOverAnythingHVPPrep( + grad_buffer, + (inner_gradient_prep,), + (inner_gradient_in_prep,), + outer_pushforward_prep, + outer_pushforward_in_prep, ) end @@ -110,10 +244,14 @@ function hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; maybe_inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + map(PrepContext, maybe_inner_gradient_prep)..., + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) return pushforward( shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... @@ -144,10 +282,14 @@ function _hvp_aux!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; grad_buffer, outer_pushforward_in_prep) = prep + (; grad_buffer, maybe_inner_gradient_in_prep, outer_pushforward_in_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... 
+ FunctionContext(f), + map(PrepContext, maybe_inner_gradient_in_prep)..., + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) return pushforward!( shuffled_gradient!, @@ -171,10 +313,14 @@ function _hvp_aux!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; maybe_inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + map(PrepContext, maybe_inner_gradient_prep)..., + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) return pushforward!( shuffled_gradient, @@ -195,10 +341,14 @@ function gradient_and_hvp( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; maybe_inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + map(PrepContext, maybe_inner_gradient_prep)..., + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) return value_and_pushforward( shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... @@ -231,10 +381,14 @@ function _gradient_and_hvp_aux!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_in_prep) = prep + (; maybe_inner_gradient_in_prep, outer_pushforward_in_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... 
+ FunctionContext(f), + map(PrepContext, maybe_inner_gradient_in_prep)..., + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) value_and_pushforward!( shuffled_gradient!, @@ -260,10 +414,14 @@ function _gradient_and_hvp_aux!( tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pushforward_prep) = prep + (; maybe_inner_gradient_prep, outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), + map(PrepContext, maybe_inner_gradient_prep)..., + BackendContext(inner(backend)), + Constant(rewrap), + contexts..., ) new_grad, _ = value_and_pushforward!( shuffled_gradient, @@ -287,6 +445,7 @@ end function _prepare_hvp_aux( ::ReverseOverForward, + ::InnerPreparationBehavior, f::F, backend::AbstractADType, x, @@ -402,6 +561,7 @@ end function _prepare_hvp_aux( ::ReverseOverReverse, + ::InnerPreparationBehavior, f::F, backend::AbstractADType, x, diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 0b1f8f591..2ed868397 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -22,6 +22,9 @@ Abstract supertype for additional context arguments, which can be passed to diff """ abstract type Context end +abstract type GeneralizedConstant <: Context end +abstract type GeneralizedCache <: Context end + unwrap(c::Context) = c.data Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2) @@ -58,7 +61,7 @@ julia> gradient(f, AutoForwardDiff(), [1.0, 2.0], Constant(100)) 400.0 ``` """ -struct Constant{T} <: Context +struct Constant{T} <: GeneralizedConstant data::T end @@ -94,7 +97,7 @@ julia> gradient(f, prep, AutoForwardDiff(), [3.0, 4.0], Cache(zeros(2))) 1.0 ```` """ -struct Cache{T} <: Context +struct Cache{T} <: GeneralizedCache data::T end @@ -103,15 +106,19 @@ maker(::Cache) = cache_maker ## Internal 
contexts for passing stuff around -struct FunctionContext{T} <: Context +struct FunctionContext{T} <: GeneralizedConstant + data::T +end + +struct BackendContext{T} <: GeneralizedConstant data::T end -struct BackendContext{T} <: Context +struct PrepContext{T} <: GeneralizedConstant data::T end -const ConstantOrFunctionOrBackend = Union{Constant,FunctionContext,BackendContext} +struct UnknownContext <: Context end ## Context manipulation diff --git a/DifferentiationInterface/src/utils/traits.jl b/DifferentiationInterface/src/utils/traits.jl index 7cf00b547..d29307930 100644 --- a/DifferentiationInterface/src/utils/traits.jl +++ b/DifferentiationInterface/src/utils/traits.jl @@ -171,6 +171,18 @@ function hvp_mode(backend::AutoSparse) throw(ArgumentError("HVP mode not defined for $backend`.")) end +## Inner prep + +abstract type InnerPreparationBehavior end + +struct PrepareInnerSimple <: InnerPreparationBehavior end +struct PrepareInnerOverload <: InnerPreparationBehavior end +struct DontPrepareInner <: InnerPreparationBehavior end + +inner_preparation_behavior(::AbstractADType) = DontPrepareInner() + +function overloaded_input end + ## Conversions Base.Bool(::InPlaceSupported) = true diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 6d5209833..f4024afbd 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -42,11 +42,7 @@ end ) test_differentiation( - AutoForwardDiff(); - correctness=false, - type_stability=:prepared, - excluded=[:hvp], # TODO: toggle - logging=LOGGING, + AutoForwardDiff(); correctness=false, type_stability=:prepared, logging=LOGGING ) test_differentiation( From 1c9f6e1a1655e0839a01c7ab3c70ca76b347d710 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 14 Mar 2025 23:14:25 +0100 Subject: [PATCH 09/15] No fail fast --- .github/workflows/Test.yml 
| 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index f91ce0853..13de31d96 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -25,7 +25,7 @@ jobs: actions: write contents: read strategy: - fail-fast: true # TODO: toggle + fail-fast: false # TODO: toggle matrix: version: - "1.10" From ddc2e3f89462df543accc2e34249899ed1568ab8 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 14 Mar 2025 23:27:54 +0100 Subject: [PATCH 10/15] Appease JET --- .../DifferentiationInterfaceForwardDiffExt/utils.jl | 13 +++---------- DifferentiationInterface/src/second_order/hvp.jl | 3 +-- DifferentiationInterface/src/utils/traits.jl | 4 +++- .../test/Core/Internals/_formalities.jl | 1 + 4 files changed, 8 insertions(+), 13 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index fb5d86474..1b983bca8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -82,14 +82,7 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B} return ty end -# store preparation result with the right input eltype -struct PrepContext{T<:DI.Prep} <: DI.Context - data::T -end - -NotCache = Union{DI.GeneralizedConstant,PrepContext} - -_translate(::Type{D}, c::NotCache) where {D<:Dual} = DI.unwrap(c) +_translate(::Type{D}, c::DI.GeneralizedConstant) where {D<:Dual} = DI.unwrap(c) function _translate(::Type{D}, c::DI.Cache) where {D<:Dual} c0 = DI.unwrap(c) return similar(c0, D) @@ -102,7 +95,7 @@ function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C} return new_contexts end -_translate_toprep(::Type{D}, c::NotCache) where {D<:Dual} = nothing +_translate_toprep(::Type{D}, 
c::DI.GeneralizedConstant) where {D<:Dual} = nothing function _translate_toprep(::Type{D}, c::DI.Cache) where {D<:Dual} c0 = DI.unwrap(c) return similar(c0, D) @@ -115,7 +108,7 @@ function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:D return new_contexts end -_translate_prepared(c::NotCache, _pc) = DI.unwrap(c) +_translate_prepared(c::DI.GeneralizedConstant, _pc) = DI.unwrap(c) _translate_prepared(_c::DI.Cache, pc) = pc function translate_prepared( diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 8e71ee4b4..848c14191 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -551,8 +551,7 @@ end ## Reverse over reverse -struct ReverseOverReverseHVPPrep{G,PO<:PullbackPrep,PI<:Union{Nothing,PullbackPrep}} <: - HVPPrep +struct ReverseOverReverseHVPPrep{G,PO,PI} <: HVPPrep # pullback of gradient grad_buffer::G outer_pullback_prep::PO diff --git a/DifferentiationInterface/src/utils/traits.jl b/DifferentiationInterface/src/utils/traits.jl index d29307930..63bb2a05f 100644 --- a/DifferentiationInterface/src/utils/traits.jl +++ b/DifferentiationInterface/src/utils/traits.jl @@ -181,7 +181,9 @@ struct DontPrepareInner <: InnerPreparationBehavior end inner_preparation_behavior(::AbstractADType) = DontPrepareInner() -function overloaded_input end +function overloaded_input(optype, f, backend, x, args...) 
+ throw(ArgumentError("Just to appease JET")) +end ## Conversions diff --git a/DifferentiationInterface/test/Core/Internals/_formalities.jl b/DifferentiationInterface/test/Core/Internals/_formalities.jl index c0ad40c27..e23f8624d 100644 --- a/DifferentiationInterface/test/Core/Internals/_formalities.jl +++ b/DifferentiationInterface/test/Core/Internals/_formalities.jl @@ -3,6 +3,7 @@ using Aqua: Aqua using DifferentiationInterface using ExplicitImports +using ForwardDiff: ForwardDiff using JET: JET using JuliaFormatter: JuliaFormatter using Test From dfd537985f855fa009e8cab889500445d06518a6 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 14 Mar 2025 23:36:48 +0100 Subject: [PATCH 11/15] ForwardDiff test dep --- DifferentiationInterface/Project.toml | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 39e53ba6e..9d96a9d42 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -109,4 +109,22 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "ExplicitImports", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test"] +test = [ + "ADTypes", + "Aqua", + "ComponentArrays", + "DataFrames", + "ExplicitImports", + "ForwardDiff", + "JET", + "JLArrays", + "JuliaFormatter", + "Pkg", + "Random", + "SparseArrays", + "SparseConnectivityTracer", + "SparseMatrixColorings", + "StableRNGs", + "StaticArrays", + "Test", +] From b09faabf56e8be9f9ef2f06761ab4df15d2028ed Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 15 Mar 2025 08:53:14 +0100 Subject: [PATCH 12/15] Fix PolyesterForwardDiff --- 
DifferentiationInterface/Project.toml | 3 +- .../utils.jl | 4 +- .../misc.jl | 19 +--- .../utils.jl | 14 ++- ...tiationInterfacePolyesterForwardDiffExt.jl | 25 ++---- .../misc.jl | 11 +++ .../onearg.jl | 86 +++++++++++++------ .../twoarg.jl | 74 +++++++++++----- .../utils.jl | 14 +++ .../src/misc/overloading.jl | 10 +++ .../src/second_order/hvp.jl | 24 +----- DifferentiationInterface/src/utils/context.jl | 2 +- DifferentiationInterface/src/utils/traits.jl | 26 +++++- .../test/Back/PolyesterForwardDiff/test.jl | 9 +- .../test/Core/Internals/_formalities.jl | 3 +- .../test/Core/SimpleFiniteDiff/test.jl | 9 ++ 16 files changed, 212 insertions(+), 121 deletions(-) create mode 100644 DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/misc.jl create mode 100644 DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/utils.jl diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 9d96a9d42..1a1f819dd 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -39,7 +39,7 @@ DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences" DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] DifferentiationInterfaceGTPSAExt = "GTPSA" DifferentiationInterfaceMooncakeExt = "Mooncake" -DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff" +DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"] DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] DifferentiationInterfaceSparseArraysExt = "SparseArrays" DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer" @@ -115,7 +115,6 @@ test = [ "ComponentArrays", "DataFrames", "ExplicitImports", - "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl 
b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index edadfff96..c189d09d9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -48,13 +48,13 @@ force_annotation(f::F) where {F<:Annotation} = f force_annotation(f::F) where {F} = Const(f) @inline function _translate( - ::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant,DI.BackendContext} + ::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedConstant ) where {B} return Const(DI.unwrap(c)) end @inline function _translate( - backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache + ::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedCache ) where {B} if B == 1 return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c))) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl index 54e308485..2d5e4c9cc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl @@ -3,27 +3,16 @@ DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep) = typeof(prep.x DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep) = typeof(prep.xdual_tmp) function DI.overloaded_input( - ::typeof(DI.pushforward), - f::F, - backend::AutoForwardDiff, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,B,C} + ::typeof(DI.pushforward), f::F, backend::AutoForwardDiff, x, tx::NTuple{B} +) where {F,B} T = tag_type(f, backend, x) xdual = make_dual(T, x, tx) return xdual end function DI.overloaded_input( - ::typeof(DI.pushforward), - f!::F, - y, - backend::AutoForwardDiff, - x, - tx::NTuple{B}, - contexts::Vararg{DI.Context,C}, -) where {F,B,C} + ::typeof(DI.pushforward), f!::F, y, backend::AutoForwardDiff, x, tx::NTuple{B} +) where {F,B} T = tag_type(f!, backend, x) xdual = 
make_dual(T, x, tx) return xdual diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index 1b983bca8..c094bd14a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -82,7 +82,11 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B} return ty end -_translate(::Type{D}, c::DI.GeneralizedConstant) where {D<:Dual} = DI.unwrap(c) +function _translate( + ::Type{D}, c::Union{DI.GeneralizedConstant,DI.PrepContext} +) where {D<:Dual} + return DI.unwrap(c) +end function _translate(::Type{D}, c::DI.Cache) where {D<:Dual} c0 = DI.unwrap(c) return similar(c0, D) @@ -95,7 +99,11 @@ function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C} return new_contexts end -_translate_toprep(::Type{D}, c::DI.GeneralizedConstant) where {D<:Dual} = nothing +function _translate_toprep( + ::Type{D}, c::Union{DI.GeneralizedConstant,DI.PrepContext} +) where {D<:Dual} + return nothing +end function _translate_toprep(::Type{D}, c::DI.Cache) where {D<:Dual} c0 = DI.unwrap(c) return similar(c0, D) @@ -108,7 +116,7 @@ function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:D return new_contexts end -_translate_prepared(c::DI.GeneralizedConstant, _pc) = DI.unwrap(c) +_translate_prepared(c::Union{DI.GeneralizedConstant,DI.PrepContext}, _pc) = DI.unwrap(c) _translate_prepared(_c::DI.Cache, pc) = pc function translate_prepared( diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl index 0efbc9695..666b7f5f9 100644 --- 
a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -4,31 +4,22 @@ using ADTypes: AutoForwardDiff, AutoPolyesterForwardDiff import DifferentiationInterface as DI using LinearAlgebra: mul! using PolyesterForwardDiff: threaded_gradient!, threaded_jacobian! -using PolyesterForwardDiff.ForwardDiff: Chunk -using PolyesterForwardDiff.ForwardDiff.DiffResults: DiffResults +using ForwardDiff: Chunk +using DiffResults: DiffResults + +const FDExt = Base.get_extension(DI, :DifferentiationInterfaceForwardDiffExt) +@assert !isnothing(FDExt) function single_threaded(backend::AutoPolyesterForwardDiff{chunksize,T}) where {chunksize,T} return AutoForwardDiff(; chunksize, tag=backend.tag) end DI.check_available(::AutoPolyesterForwardDiff) = true +DI.inner_preparation_behavior(::AutoPolyesterForwardDiff) = DI.PrepareInnerOverload() -function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, x::AbstractArray) - return DI.pick_batchsize(single_threaded(backend), x) -end - -function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, N::Integer) - return DI.pick_batchsize(single_threaded(backend), N) -end - -function DI.threshold_batchsize( - backend::AutoPolyesterForwardDiff{chunksize1}, chunksize2::Integer -) where {chunksize1} - chunksize = isnothing(chunksize1) ? 
nothing : min(chunksize1, chunksize2) - return AutoPolyesterForwardDiff(; chunksize, tag=backend.tag) -end - +include("utils.jl") include("onearg.jl") include("twoarg.jl") +include("misc.jl") end # module diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/misc.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/misc.jl new file mode 100644 index 000000000..2262e407c --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/misc.jl @@ -0,0 +1,11 @@ +function DI.overloaded_input( + ::typeof(DI.pushforward), f::F, backend::AutoPolyesterForwardDiff, x, tx::NTuple{B} +) where {F,B} + return DI.overloaded_input(DI.pushforward, f, single_threaded(backend), x, tx) +end + +function DI.overloaded_input( + ::typeof(DI.pushforward), f!::F, y, backend::AutoPolyesterForwardDiff, x, tx::NTuple{B} +) where {F,B} + return DI.overloaded_input(DI.pushforward, f!, y, single_threaded(backend), x, tx) +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index dde2203ab..17cf1ab6c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -107,45 +107,61 @@ end ## Gradient -struct PolyesterForwardDiffGradientPrep{chunksize} <: DI.GradientPrep +struct PolyesterForwardDiffGradientPrep{chunksize,P} <: DI.GradientPrep chunk::Chunk{chunksize} + single_threaded_prep::P end function DI.prepare_gradient( - f, ::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C} + f, backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C} ) where {chunksize,C} if isnothing(chunksize) chunk = Chunk(x) else chunk = Chunk{chunksize}() end - return PolyesterForwardDiffGradientPrep(chunk) + 
single_threaded_prep = DI.prepare_gradient(f, single_threaded(backend), x, contexts...) + return PolyesterForwardDiffGradientPrep(chunk, single_threaded_prep) end function DI.value_and_gradient!( f, grad, prep::PolyesterForwardDiffGradientPrep, - ::AutoPolyesterForwardDiff, + backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - fc = DI.with_contexts(f, contexts...) - threaded_gradient!(fc, grad, x, prep.chunk) - return fc(x), grad + if contexts isa NTuple{C,DI.GeneralizedConstant} + fc = DI.with_contexts(f, contexts...) + threaded_gradient!(fc, grad, x, prep.chunk) + return fc(x), grad + else + # TODO: optimize + return DI.value_and_gradient!( + f, grad, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) + end end function DI.gradient!( f, grad, prep::PolyesterForwardDiffGradientPrep, - ::AutoPolyesterForwardDiff, + backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - fc = DI.with_contexts(f, contexts...) - threaded_gradient!(fc, grad, x, prep.chunk) - return grad + if contexts isa NTuple{C,DI.GeneralizedConstant} + fc = DI.with_contexts(f, contexts...) + threaded_gradient!(fc, grad, x, prep.chunk) + return grad + else + # TODO: optimize + return DI.gradient!( + f, grad, prep.single_threaded_prep, single_threaded(backend), x, contexts... 
+ ) + end end function DI.value_and_gradient( @@ -170,43 +186,57 @@ end ## Jacobian -struct PolyesterForwardDiffOneArgJacobianPrep{chunksize} <: DI.JacobianPrep +struct PolyesterForwardDiffOneArgJacobianPrep{chunksize,P} <: DI.JacobianPrep chunk::Chunk{chunksize} + single_threaded_prep::P end function DI.prepare_jacobian( - f, ::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C} + f, backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C} ) where {chunksize,C} if isnothing(chunksize) chunk = Chunk(x) else chunk = Chunk{chunksize}() end - return PolyesterForwardDiffOneArgJacobianPrep(chunk) + single_threaded_prep = DI.prepare_jacobian(f, single_threaded(backend), x, contexts...) + return PolyesterForwardDiffOneArgJacobianPrep(chunk, single_threaded_prep) end function DI.value_and_jacobian!( f, jac, prep::PolyesterForwardDiffOneArgJacobianPrep, - ::AutoPolyesterForwardDiff, + backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - fc = DI.with_contexts(f, contexts...) - return fc(x), threaded_jacobian!(fc, jac, x, prep.chunk) + if contexts isa NTuple{C,DI.GeneralizedConstant} + fc = DI.with_contexts(f, contexts...) + return fc(x), threaded_jacobian!(fc, jac, x, prep.chunk) + else + return DI.value_and_jacobian!( + f, jac, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) + end end function DI.jacobian!( f, jac, prep::PolyesterForwardDiffOneArgJacobianPrep, - ::AutoPolyesterForwardDiff, + backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - fc = DI.with_contexts(f, contexts...) - return threaded_jacobian!(fc, jac, x, prep.chunk) + if contexts isa NTuple{C,DI.GeneralizedConstant} + fc = DI.with_contexts(f, contexts...) + return threaded_jacobian!(fc, jac, x, prep.chunk) + else + return DI.jacobian!( + f, jac, prep.single_threaded_prep, single_threaded(backend), x, contexts... 
+ ) + end end function DI.value_and_jacobian( @@ -217,9 +247,8 @@ function DI.value_and_jacobian( contexts::Vararg{DI.Context,C}, ) where {C} y = f(x, map(DI.unwrap, contexts)...) - return DI.value_and_jacobian!( - f, similar(y, length(y), length(x)), prep, backend, x, contexts... - ) + jac = similar(y, length(y), length(x)) + return DI.value_and_jacobian!(f, jac, prep, backend, x, contexts...) end function DI.jacobian( @@ -230,7 +259,8 @@ function DI.jacobian( contexts::Vararg{DI.Context,C}, ) where {C} y = f(x, map(DI.unwrap, contexts)...) - return DI.jacobian!(f, similar(y, length(y), length(x)), prep, backend, x, contexts...) + jac = similar(y, length(y), length(x)) + return DI.jacobian!(f, jac, prep, backend, x, contexts...) end ## Hessian @@ -299,7 +329,7 @@ end function DI.hvp( f, - prep::DI.HVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, @@ -313,7 +343,7 @@ end function DI.hvp!( f, tg::NTuple, - prep::DI.HVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, @@ -326,7 +356,7 @@ end function DI.gradient_and_hvp( f, - prep::DI.HVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, @@ -341,7 +371,7 @@ function DI.gradient_and_hvp!( f, grad, tg::NTuple, - prep::DI.HVPPrep, + prep::DI.ForwardOverAnythingHVPPrep, backend::AutoPolyesterForwardDiff, x, tx::NTuple, diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl index 1d03290aa..4f63a65c1 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl @@ -118,34 +118,44 @@ end ## Jacobian -struct PolyesterForwardDiffTwoArgJacobianPrep{chunksize} <: DI.JacobianPrep +struct 
PolyesterForwardDiffTwoArgJacobianPrep{chunksize,P} <: DI.JacobianPrep chunk::Chunk{chunksize} + single_threaded_prep::P end function DI.prepare_jacobian( - f!, y, ::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C} + f!, y, backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C} ) where {chunksize,C} if isnothing(chunksize) chunk = Chunk(x) else chunk = Chunk{chunksize}() end - return PolyesterForwardDiffTwoArgJacobianPrep(chunk) + single_threaded_prep = DI.prepare_jacobian( + f!, y, single_threaded(backend), x, contexts... + ) + return PolyesterForwardDiffTwoArgJacobianPrep(chunk, single_threaded_prep) end function DI.value_and_jacobian( f!, y, prep::PolyesterForwardDiffTwoArgJacobianPrep, - ::AutoPolyesterForwardDiff{K}, + backend::AutoPolyesterForwardDiff{K}, x, contexts::Vararg{DI.Context,C}, ) where {K,C} - fc! = DI.with_contexts(f!, contexts...) - jac = similar(y, length(y), length(x)) - threaded_jacobian!(fc!, y, jac, x, prep.chunk) - fc!(y, x) - return y, jac + if contexts isa NTuple{C,DI.GeneralizedConstant} + fc! = DI.with_contexts(f!, contexts...) + jac = similar(y, length(y), length(x)) + threaded_jacobian!(fc!, y, jac, x, prep.chunk) + fc!(y, x) + return y, jac + else + return DI.value_and_jacobian( + f!, y, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) + end end function DI.value_and_jacobian!( @@ -153,28 +163,40 @@ function DI.value_and_jacobian!( y, jac, prep::PolyesterForwardDiffTwoArgJacobianPrep, - ::AutoPolyesterForwardDiff{K}, + backend::AutoPolyesterForwardDiff{K}, x, contexts::Vararg{DI.Context,C}, ) where {K,C} - fc! = DI.with_contexts(f!, contexts...) - threaded_jacobian!(fc!, y, jac, x, prep.chunk) - fc!(y, x) - return y, jac + if contexts isa NTuple{C,DI.GeneralizedConstant} + fc! = DI.with_contexts(f!, contexts...) 
+ threaded_jacobian!(fc!, y, jac, x, prep.chunk) + fc!(y, x) + return y, jac + else + return DI.value_and_jacobian!( + f!, y, jac, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) + end end function DI.jacobian( f!, y, prep::PolyesterForwardDiffTwoArgJacobianPrep, - ::AutoPolyesterForwardDiff, + backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - fc! = DI.with_contexts(f!, contexts...) - jac = similar(y, length(y), length(x)) - threaded_jacobian!(fc!, y, jac, x, prep.chunk) - return jac + if contexts isa NTuple{C,DI.GeneralizedConstant} + fc! = DI.with_contexts(f!, contexts...) + jac = similar(y, length(y), length(x)) + threaded_jacobian!(fc!, y, jac, x, prep.chunk) + return jac + else + return DI.jacobian( + f!, y, prep.single_threaded_prep, single_threaded(backend), x, contexts... + ) + end end function DI.jacobian!( @@ -182,11 +204,17 @@ function DI.jacobian!( y, jac, prep::PolyesterForwardDiffTwoArgJacobianPrep, - ::AutoPolyesterForwardDiff, + backend::AutoPolyesterForwardDiff, x, contexts::Vararg{DI.Context,C}, ) where {C} - fc! = DI.with_contexts(f!, contexts...) - threaded_jacobian!(fc!, y, jac, x, prep.chunk) - return jac + if contexts isa NTuple{C,DI.GeneralizedConstant} + fc! = DI.with_contexts(f!, contexts...) + threaded_jacobian!(fc!, y, jac, x, prep.chunk) + return jac + else + return DI.jacobian!( + f!, y, jac, prep.single_threaded_prep, single_threaded(backend), x, contexts... 
+ ) + end end diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/utils.jl new file mode 100644 index 000000000..0322023fb --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/utils.jl @@ -0,0 +1,14 @@ +function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, x::AbstractArray) + return DI.pick_batchsize(single_threaded(backend), x) +end + +function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, N::Integer) + return DI.pick_batchsize(single_threaded(backend), N) +end + +function DI.threshold_batchsize( + backend::AutoPolyesterForwardDiff{chunksize1}, chunksize2::Integer +) where {chunksize1} + chunksize = isnothing(chunksize1) ? nothing : min(chunksize1, chunksize2) + return AutoPolyesterForwardDiff(; chunksize, tag=backend.tag) +end diff --git a/DifferentiationInterface/src/misc/overloading.jl b/DifferentiationInterface/src/misc/overloading.jl index bda61a192..8c11a5bcf 100644 --- a/DifferentiationInterface/src/misc/overloading.jl +++ b/DifferentiationInterface/src/misc/overloading.jl @@ -7,3 +7,13 @@ If it exists, return the overloaded input type which will be passed to the diffe This function is experimental and not part of the public API. 
""" function overloaded_input_type end + +function overloaded_input(::typeof(pushforward), f, backend::AbstractADType, x, tx::NTuple) + throw(ArgumentError("Overloaded input not defined")) +end + +function overloaded_input( + ::typeof(pushforward), f!, y, backend::AbstractADType, x, tx::NTuple +) + throw(ArgumentError("Overloaded input not defined")) +end diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 848c14191..85d08cd53 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -172,33 +172,13 @@ function _prepare_hvp_aux( grad_buffer = similar(x) rewrap = Rewrap(contexts...) # Inner gradient - new_contexts_unknown = ( - FunctionContext(f), - UnknownContext(), - BackendContext(inner(backend)), - Constant(rewrap), - contexts..., - ) inner_gradient_prep = let - xo = overloaded_input( - pushforward, - shuffled_gradient, - outer(backend), - x, - tx, - new_contexts_unknown..., - ) + xo = overloaded_input(pushforward, shuffled_gradient, outer(backend), x, tx) prepare_gradient(f, inner(backend), xo, contexts...) end inner_gradient_in_prep = let xo = overloaded_input( - pushforward, - shuffled_gradient!, - grad_buffer, - outer(backend), - x, - tx, - new_contexts_unknown..., + pushforward, shuffled_gradient!, grad_buffer, outer(backend), x, tx ) prepare_gradient(f, inner(backend), xo, contexts...) 
end diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 2ed868397..6d85ec0b9 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -114,7 +114,7 @@ struct BackendContext{T} <: GeneralizedConstant data::T end -struct PrepContext{T} <: GeneralizedConstant +struct PrepContext{T} <: GeneralizedCache data::T end diff --git a/DifferentiationInterface/src/utils/traits.jl b/DifferentiationInterface/src/utils/traits.jl index 63bb2a05f..06b23529b 100644 --- a/DifferentiationInterface/src/utils/traits.jl +++ b/DifferentiationInterface/src/utils/traits.jl @@ -175,15 +175,33 @@ end abstract type InnerPreparationBehavior end +""" + PrepareInnerSimple + +Trait identifying outer backends for which the inner backend in second-order autodiff should be prepared with the same input type. +""" struct PrepareInnerSimple <: InnerPreparationBehavior end + +""" + PrepareInnerOverload + +Trait identifying outer backends for which the inner backend in second-order autodiff should be prepared with an overloaded input type. +""" struct PrepareInnerOverload <: InnerPreparationBehavior end + +""" + DontPrepareInner + +Trait identifying outer backends for which the inner backend in second-order autodiff should not be prepared at all. +""" struct DontPrepareInner <: InnerPreparationBehavior end -inner_preparation_behavior(::AbstractADType) = DontPrepareInner() +""" + inner_preparation_behavior(backend::AbstractADType) -function overloaded_input(optype, f, backend, x, args...) - throw(ArgumentError("Just to appease JET")) -end +Return [`PrepareInnerSimple`](@ref), [`PrepareInnerOverload`](@ref) or [`DontPrepareInner`](@ref) in a statically predictable way. 
+""" +inner_preparation_behavior(::AbstractADType) = DontPrepareInner() ## Conversions diff --git a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl index 21fb3d297..ba1cf8e37 100644 --- a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl @@ -12,8 +12,10 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" +struct MyTag end + backends = [ - AutoPolyesterForwardDiff(; tag=:hello), # + AutoPolyesterForwardDiff(; tag=ForwardDiff.Tag(MyTag(), Float64)), # AutoPolyesterForwardDiff(; chunksize=2), ] @@ -23,7 +25,10 @@ for backend in backends end test_differentiation( - backends, default_scenarios(; include_constantified=true); logging=LOGGING + backends, + default_scenarios(; include_constantified=true, include_cachified=true); + logging=LOGGING, + excluded=SECOND_ORDER, ); @testset "Batch size" begin diff --git a/DifferentiationInterface/test/Core/Internals/_formalities.jl b/DifferentiationInterface/test/Core/Internals/_formalities.jl index e23f8624d..bf851018f 100644 --- a/DifferentiationInterface/test/Core/Internals/_formalities.jl +++ b/DifferentiationInterface/test/Core/Internals/_formalities.jl @@ -3,7 +3,6 @@ using Aqua: Aqua using DifferentiationInterface using ExplicitImports -using ForwardDiff: ForwardDiff using JET: JET using JuliaFormatter: JuliaFormatter using Test @@ -37,7 +36,7 @@ end @test check_all_qualified_accesses_via_owners(DifferentiationInterface) === nothing @test check_no_self_qualified_accesses(DifferentiationInterface) === nothing if VERSION >= v"1.11" - @test check_all_explicit_imports_are_public(DifferentiationInterface) === nothing + @test check_all_explicit_imports_are_public(DifferentiationInterface;) === nothing @test_skip check_all_qualified_accesses_are_public(DifferentiationInterface) === nothing end diff --git 
a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index eb24a0041..7a4a59e26 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -121,3 +121,12 @@ end @test only(column_groups(hess_prep)) == 1:10 end end + +@testset "Misc" begin + @test_throws ArgumentError DifferentiationInterface.overloaded_input( + pushforward, sum, AutoSimpleFiniteDiff(), 1, (1, 2) + ) + @test_throws ArgumentError DifferentiationInterface.overloaded_input( + pushforward, copyto!, [1.0], AutoSimpleFiniteDiff(), [1.0], ([1.0], [1.0]) + ) +end From 606b895a3c6fab40f9718e1f8ae00bf9e90ab2d6 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 15 Mar 2025 09:34:58 +0100 Subject: [PATCH 13/15] Improve coverage --- .../DifferentiationInterfaceFiniteDiffExt.jl | 1 + .../DifferentiationInterfaceFiniteDifferencesExt.jl | 1 + DifferentiationInterface/test/Back/FiniteDiff/test.jl | 9 +++++++++ .../test/Back/FiniteDifferences/test.jl | 2 ++ .../test/Back/PolyesterForwardDiff/test.jl | 8 ++++++++ .../test/Core/SimpleFiniteDiff/test.jl | 4 ++-- 6 files changed, 23 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl index d7441adfb..1822b7fe2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl @@ -21,6 +21,7 @@ using FiniteDiff: using LinearAlgebra: dot, mul! 
DI.check_available(::AutoFiniteDiff) = true +DI.inner_preparation_behavior(::AutoFiniteDiff) = DI.PrepareInnerSimple() # see https://github.com/SciML/ADTypes.jl/issues/33 diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl index 54dd38501..f2f692002 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl @@ -7,6 +7,7 @@ using LinearAlgebra: dot DI.check_available(::AutoFiniteDifferences) = true DI.inplace_support(::AutoFiniteDifferences) = DI.InPlaceNotSupported() +DI.inner_preparation_behavior(::AutoFiniteDifferences) = DI.PrepareInnerSimple() ## Pushforward diff --git a/DifferentiationInterface/test/Back/FiniteDiff/test.jl b/DifferentiationInterface/test/Back/FiniteDiff/test.jl index 30ef9b38a..aa92743b6 100644 --- a/DifferentiationInterface/test/Back/FiniteDiff/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDiff/test.jl @@ -15,6 +15,8 @@ LOGGING = get(ENV, "CI", "false") == "false" for backend in [AutoFiniteDiff()] @test check_available(backend) @test check_inplace(backend) + @test DifferentiationInterface.inner_preparation_behavior(backend) isa + DifferentiationInterface.PrepareInnerSimple end @testset "Dense" begin @@ -25,6 +27,13 @@ end logging=LOGGING, ) + test_differentiation( + SecondOrder(AutoFiniteDiff(; relstep=1e-5, absstep=1e-5), AutoFiniteDiff()), + default_scenarios(); + logging=LOGGING, + rtol=1e-2, + ) + test_differentiation( [ AutoFiniteDiff(; relstep=cbrt(eps(Float64))), diff --git a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl index cb83ab0cf..d512ee8d4 
100644 --- a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl @@ -13,6 +13,8 @@ LOGGING = get(ENV, "CI", "false") == "false" for backend in [AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1))] @test check_available(backend) @test !check_inplace(backend) + @test DifferentiationInterface.inner_preparation_behavior(backend) isa + DifferentiationInterface.PrepareInnerSimple end test_differentiation( diff --git a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl index ba1cf8e37..71ea785db 100644 --- a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl @@ -22,6 +22,8 @@ backends = [ for backend in backends @test check_available(backend) @test check_inplace(backend) + @test DifferentiationInterface.inner_preparation_behavior(backend) isa + DifferentiationInterface.PrepareInnerOverload end test_differentiation( @@ -31,6 +33,12 @@ test_differentiation( excluded=SECOND_ORDER, ); +test_differentiation( + SecondOrder(AutoPolyesterForwardDiff(), AutoPolyesterForwardDiff()), + default_scenarios(); + logging=LOGGING, +); + @testset "Batch size" begin @test DI.pick_batchsize(AutoPolyesterForwardDiff(), 10) == DI.pick_batchsize(AutoForwardDiff(), 10) diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index 7a4a59e26..af9a42768 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -16,12 +16,12 @@ backends = [ # second_order_backends = [ # SecondOrder( - AutoSimpleFiniteDiff(; chunksize=5), + AutoForwardFromPrimitive(AutoSimpleFiniteDiff(; chunksize=5)), AutoReverseFromPrimitive(AutoSimpleFiniteDiff(; chunksize=4)), ), SecondOrder( 
AutoReverseFromPrimitive(AutoSimpleFiniteDiff(; chunksize=5)), - AutoSimpleFiniteDiff(; chunksize=4), + AutoForwardFromPrimitive(AutoSimpleFiniteDiff(; chunksize=4)), ), ] From 923eca4037bbf251bd1b664238c30b5bb686dafc Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 15 Mar 2025 10:18:12 +0100 Subject: [PATCH 14/15] Codecov --- DifferentiationInterface/src/second_order/hvp.jl | 16 ++++++---------- .../test/Core/SimpleFiniteDiff/test.jl | 3 ++- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 85d08cd53..edc1699f8 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -172,16 +172,12 @@ function _prepare_hvp_aux( grad_buffer = similar(x) rewrap = Rewrap(contexts...) # Inner gradient - inner_gradient_prep = let - xo = overloaded_input(pushforward, shuffled_gradient, outer(backend), x, tx) - prepare_gradient(f, inner(backend), xo, contexts...) - end - inner_gradient_in_prep = let - xo = overloaded_input( - pushforward, shuffled_gradient!, grad_buffer, outer(backend), x, tx - ) - prepare_gradient(f, inner(backend), xo, contexts...) - end + xo = overloaded_input(pushforward, shuffled_gradient, outer(backend), x, tx) + xoi = overloaded_input( + pushforward, shuffled_gradient!, grad_buffer, outer(backend), x, tx + ) + inner_gradient_prep = prepare_gradient(f, inner(backend), xo, contexts...) + inner_gradient_in_prep = prepare_gradient(f, inner(backend), xoi, contexts...) 
# Outer pushforward new_contexts = ( FunctionContext(f), diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index af9a42768..d716f136b 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -11,6 +11,7 @@ LOGGING = get(ENV, "CI", "false") == "false" backends = [ # AutoSimpleFiniteDiff(; chunksize=5), + AutoForwardFromPrimitive(AutoSimpleFiniteDiff(; chunksize=4)), AutoReverseFromPrimitive(AutoSimpleFiniteDiff(; chunksize=4)), ] @@ -68,7 +69,7 @@ end test_differentiation( second_order_hvp_backends; excluded=vcat(FIRST_ORDER, :hessian, :second_derivative), - logging=true, + logging=LOGGING, ) test_differentiation(backends, complex_scenarios(); logging=LOGGING) From 8698e48697c68d8798a877b60f20a1a1903283e6 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 15 Mar 2025 11:21:09 +0100 Subject: [PATCH 15/15] Fail fast toggle --- .github/workflows/Test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 13de31d96..f91ce0853 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -25,7 +25,7 @@ jobs: actions: write contents: read strategy: - fail-fast: false # TODO: toggle + fail-fast: true # TODO: toggle matrix: version: - "1.10"