diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index bb7cbd4d8..bcc7f30bf 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -91,7 +91,7 @@ jobs: actions: write contents: read strategy: - fail-fast: false # TODO: toggle + fail-fast: true # TODO: toggle matrix: version: - '1.10' diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 183b15834..7b1f85fb9 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -320,7 +320,7 @@ end ## One argument -function _pullback_via_pushforward( +function _value_and_pullback_via_pushforward( f::F, pushforward_prep::PushforwardPrep, backend::AbstractADType, @@ -328,12 +328,12 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context, C}, ) where {F, C} - a = only(pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...)) + y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...)) dx = dot(a, dy) - return dx + return y, dx end -function _pullback_via_pushforward( +function _value_and_pullback_via_pushforward( f::F, pushforward_prep::PushforwardPrep, backend::AbstractADType, @@ -341,13 +341,13 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context, C}, ) where {F, C} - a = only(pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...)) + y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...)) b = only(pushforward(f, pushforward_prep, backend, x, (im * oneunit(x),), contexts...)) dx = real(dot(a, dy)) + im * real(dot(b, dy)) - return dx + return y, dx end -function _pullback_via_pushforward( +function _value_and_pullback_via_pushforward( f::F, pushforward_prep::PushforwardPrep, backend::AbstractADType, @@ -355,14 +355,15 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context, C}, ) where {F, C} + y = f(x, map(unwrap, contexts)...) dx = map(CartesianIndices(x)) do j a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...)) dot(a, dy) end - return dx + return y, dx end -function _pullback_via_pushforward( +function _value_and_pullback_via_pushforward( f::F, pushforward_prep::PushforwardPrep, backend::AbstractADType, @@ -370,6 +371,7 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context, C}, ) where {F, C} + y = f(x, map(unwrap, contexts)...) dx = map(CartesianIndices(x)) do j a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...)) b = only( @@ -377,7 +379,7 @@ function _pullback_via_pushforward( ) real(dot(a, dy)) + im * real(dot(b, dy)) end - return dx + return y, dx end function value_and_pullback( @@ -390,11 +392,12 @@ function value_and_pullback( ) where {F, B, C} check_prep(f, prep, backend, x, ty, contexts...) (; pushforward_prep) = prep - y = f(x, map(unwrap, contexts)...) - tx = ntuple( - b -> _pullback_via_pushforward(f, pushforward_prep, backend, x, ty[b], contexts...), + ys_and_tx = ntuple( + b -> _value_and_pullback_via_pushforward(f, pushforward_prep, backend, x, ty[b], contexts...), Val(B), ) + y = first(first(ys_and_tx)) + tx = map(last, ys_and_tx) return y, tx end @@ -440,7 +443,7 @@ end ## Two arguments -function _pullback_via_pushforward( +function _value_and_pullback_via_pushforward( f!::F, y, pushforward_prep::PushforwardPrep, @@ -449,12 +452,12 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context, C}, ) where {F, C} - a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...)) + _, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...)) dx = dot(a, dy) return dx end -function _pullback_via_pushforward( +function _value_and_pullback_via_pushforward( f!::F, y, pushforward_prep::PushforwardPrep, @@ -464,14 +467,14 @@ function _pullback_via_pushforward( contexts::Vararg{Context, C}, ) where {F, C} a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...)) - b = only( - pushforward(f!, y, pushforward_prep, backend, x, (im * oneunit(x),), contexts...) + _, b = onlysecond( + value_and_pushforward(f!, y, pushforward_prep, backend, x, (im * oneunit(x),), contexts...) ) dx = real(dot(a, dy)) + im * real(dot(b, dy)) return dx end -function _pullback_via_pushforward( +function _value_and_pullback_via_pushforward( f!::F, y, pushforward_prep::PushforwardPrep, @@ -481,13 +484,13 @@ function _pullback_via_pushforward( contexts::Vararg{Context, C}, ) where {F, C} dx = map(CartesianIndices(x)) do j # preserve shape - a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)) + _, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)) dot(a, dy) end return dx end -function _pullback_via_pushforward( +function _value_and_pullback_via_pushforward( f!::F, y, pushforward_prep::PushforwardPrep, @@ -498,8 +501,8 @@ function _pullback_via_pushforward( ) where {F, C} dx = map(CartesianIndices(x)) do j # preserve shape a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)) - b = only( - pushforward( + _, b = onlysecond( + value_and_pushforward( f!, y, pushforward_prep, backend, x, (im * basis(x, j),), contexts... ), ) @@ -520,12 +523,11 @@ function value_and_pullback( check_prep(f!, y, prep, backend, x, ty, contexts...) (; pushforward_prep) = prep tx = ntuple( - b -> _pullback_via_pushforward( + b -> _value_and_pullback_via_pushforward( f!, y, pushforward_prep, backend, x, ty[b], contexts... ), Val(B), ) - f!(y, x, map(unwrap, contexts)...) return y, tx end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index d3922bc0a..295f98145 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -252,9 +252,10 @@ end ## Preparation -struct PullbackPushforwardPrep{SIG, E} <: PushforwardPrep{SIG} +struct PullbackPushforwardPrep{SIG, E, Y} <: PushforwardPrep{SIG} _sig::Val{SIG} pullback_prep::E + y_example::Y end function prepare_pushforward_nokwarg( @@ -296,7 +297,7 @@ function _prepare_pushforward_aux( basis(y) end pullback_prep = prepare_pullback_nokwarg(strict, f, backend, x, (dy,), contexts...) - return PullbackPushforwardPrep(_sig, pullback_prep) + return PullbackPushforwardPrep(_sig, pullback_prep, y) end function _prepare_pushforward_aux( @@ -312,13 +313,13 @@ function _prepare_pushforward_aux( _sig = signature(f!, y, backend, x, tx, contexts...; strict) dy = basis(y) pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...) - return PullbackPushforwardPrep(_sig, pullback_prep) + return PullbackPushforwardPrep(_sig, pullback_prep, y) end ## One argument -function _pushforward_via_pullback( - y::Number, +function _value_and_pushforward_via_pullback( + y_ex::Number, f::F, pullback_prep::PullbackPrep, backend::AbstractADType, @@ -326,13 +327,13 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context, C}, ) where {F, C} - a = only(pullback(f, pullback_prep, backend, x, (oneunit(y),), contexts...)) + y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...)) dy = dot(a, dx) - return dy + return y, dy end -function _pushforward_via_pullback( - y::Complex, +function _value_and_pushforward_via_pullback( + y_ex::Complex, f::F, pullback_prep::PullbackPrep, backend::AbstractADType, @@ -340,14 +341,14 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context, C}, ) where {F, C} - a = only(pullback(f, pullback_prep, backend, x, (oneunit(y),), contexts...)) - b = only(pullback(f, pullback_prep, backend, x, (im * oneunit(y),), contexts...)) + y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...)) + b = only(pullback(f, pullback_prep, backend, x, (im * oneunit(y_ex),), contexts...)) dy = real(dot(a, dx)) + im * real(dot(b, dx)) - return dy + return y, dy end -function _pushforward_via_pullback( - y::AbstractArray{<:Real}, +function _value_and_pushforward_via_pullback( + y_ex::AbstractArray{<:Real}, f::F, pullback_prep::PullbackPrep, backend::AbstractADType, @@ -355,15 +356,16 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context, C}, ) where {F, C} - dy = map(CartesianIndices(y)) do i - a = only(pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...)) + y = f(x, map(unwrap, contexts)...) + dy = map(CartesianIndices(y_ex)) do i + a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...)) dot(a, dx) end - return dy + return y, dy end -function _pushforward_via_pullback( - y::AbstractArray{<:Complex}, +function _value_and_pushforward_via_pullback( + y_ex::AbstractArray{<:Complex}, f::F, pullback_prep::PullbackPrep, backend::AbstractADType, @@ -371,12 +373,13 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context, C}, ) where {F, C} - dy = map(CartesianIndices(y)) do i - a = only(pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...)) - b = only(pullback(f, pullback_prep, backend, x, (im * basis(y, i),), contexts...)) + y = f(x, map(unwrap, contexts)...) + dy = map(CartesianIndices(y_ex)) do i + a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...)) + b = only(pullback(f, pullback_prep, backend, x, (im * basis(y_ex, i),), contexts...)) real(dot(a, dx)) + im * real(dot(b, dx)) end - return dy + return y, dy end function value_and_pushforward( @@ -388,12 +391,13 @@ function value_and_pushforward( contexts::Vararg{Context, C}, ) where {F, B, C} check_prep(f, prep, backend, x, tx, contexts...) - (; pullback_prep) = prep - y = f(x, map(unwrap, contexts)...) - ty = ntuple( - b -> _pushforward_via_pullback(y, f, pullback_prep, backend, x, tx[b], contexts...), + (; pullback_prep, y_example) = prep + ys_and_ty = ntuple( + b -> _value_and_pushforward_via_pullback(y_example, f, pullback_prep, backend, x, tx[b], contexts...), Val(B), ) + y = first(first(ys_and_ty)) + ty = map(last, ys_and_ty) return y, ty end @@ -439,7 +443,7 @@ end ## Two arguments -function _pushforward_via_pullback( +function _value_and_pushforward_via_pullback( f!::F, y::AbstractArray{<:Real}, pullback_prep::PullbackPrep, @@ -449,13 +453,13 @@ function _pushforward_via_pullback( contexts::Vararg{Context, C}, ) where {F, C} dy = map(CartesianIndices(y)) do i # preserve shape - a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...)) + _, a = onlysecond(value_and_pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...)) dot(a, dx) end return dy end -function _pushforward_via_pullback( +function _value_and_pushforward_via_pullback( f!::F, y::AbstractArray{<:Complex}, pullback_prep::PullbackPrep, @@ -466,8 +470,8 @@ function _pushforward_via_pullback( ) where {F, C} dy = map(CartesianIndices(y)) do i # preserve shape a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...)) - b = only( - pullback(f!, y, pullback_prep, backend, x, (im * basis(y, i),), contexts...) + _, b = onlysecond( + value_and_pullback(f!, y, pullback_prep, backend, x, (im * basis(y, i),), contexts...) ) real(dot(a, dx)) + im * real(dot(b, dx)) end @@ -487,10 +491,9 @@ function value_and_pushforward( (; pullback_prep) = prep ty = ntuple( b -> - _pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx[b], contexts...), + _value_and_pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx[b], contexts...), Val(B), ) - f!(y, x, map(unwrap, contexts)...) return y, ty end diff --git a/DifferentiationInterface/src/utils/linalg.jl b/DifferentiationInterface/src/utils/linalg.jl index 0be85a3ee..9b952ec2c 100644 --- a/DifferentiationInterface/src/utils/linalg.jl +++ b/DifferentiationInterface/src/utils/linalg.jl @@ -33,3 +33,5 @@ Only specialized on `SparseMatrixCSC` because it is used with symbolic backends, The trivial dense fallback is designed to protect against a change of format in these packages. """ get_pattern(M::AbstractMatrix) = trues(size(M)) + +onlysecond((a, b)) = (a, only(b))