Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/Test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ jobs:
actions: write
contents: read
strategy:
fail-fast: false # TODO: toggle
fail-fast: true # TODO: toggle
matrix:
version:
- '1.10'
Expand Down
52 changes: 27 additions & 25 deletions DifferentiationInterface/src/first_order/pullback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,64 +320,66 @@ end

## One argument

function _pullback_via_pushforward(
function _value_and_pullback_via_pushforward(
f::F,
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::Real,
dy,
contexts::Vararg{Context, C},
) where {F, C}
a = only(pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
dx = dot(a, dy)
return dx
return y, dx
end

function _pullback_via_pushforward(
function _value_and_pullback_via_pushforward(
f::F,
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::Complex,
dy,
contexts::Vararg{Context, C},
) where {F, C}
a = only(pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
b = only(pushforward(f, pushforward_prep, backend, x, (im * oneunit(x),), contexts...))
dx = real(dot(a, dy)) + im * real(dot(b, dy))
return dx
return y, dx
end

function _pullback_via_pushforward(
function _value_and_pullback_via_pushforward(
f::F,
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::AbstractArray{<:Real},
dy,
contexts::Vararg{Context, C},
) where {F, C}
y = f(x, map(unwrap, contexts)...)
dx = map(CartesianIndices(x)) do j
a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...))
dot(a, dy)
end
return dx
return y, dx
end

function _pullback_via_pushforward(
function _value_and_pullback_via_pushforward(
f::F,
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::AbstractArray{<:Complex},
dy,
contexts::Vararg{Context, C},
) where {F, C}
y = f(x, map(unwrap, contexts)...)
dx = map(CartesianIndices(x)) do j
a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...))
b = only(
pushforward(f, pushforward_prep, backend, x, (im * basis(x, j),), contexts...),
)
real(dot(a, dy)) + im * real(dot(b, dy))
end
return dx
return y, dx
end

function value_and_pullback(
Expand All @@ -390,11 +392,12 @@ function value_and_pullback(
) where {F, B, C}
check_prep(f, prep, backend, x, ty, contexts...)
(; pushforward_prep) = prep
y = f(x, map(unwrap, contexts)...)
tx = ntuple(
b -> _pullback_via_pushforward(f, pushforward_prep, backend, x, ty[b], contexts...),
ys_and_tx = ntuple(
b -> _value_and_pullback_via_pushforward(f, pushforward_prep, backend, x, ty[b], contexts...),
Val(B),
)
y = first(first(ys_and_tx))
tx = map(last, ys_and_tx)
return y, tx
end

Expand Down Expand Up @@ -440,7 +443,7 @@ end

## Two arguments

function _pullback_via_pushforward(
function _value_and_pullback_via_pushforward(
f!::F,
y,
pushforward_prep::PushforwardPrep,
Expand All @@ -449,12 +452,12 @@ function _pullback_via_pushforward(
dy,
contexts::Vararg{Context, C},
) where {F, C}
a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...))
_, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...))
dx = dot(a, dy)
return dx
end

function _pullback_via_pushforward(
function _value_and_pullback_via_pushforward(
f!::F,
y,
pushforward_prep::PushforwardPrep,
Expand All @@ -464,14 +467,14 @@ function _pullback_via_pushforward(
contexts::Vararg{Context, C},
) where {F, C}
a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...))
b = only(
pushforward(f!, y, pushforward_prep, backend, x, (im * oneunit(x),), contexts...)
_, b = onlysecond(
value_and_pushforward(f!, y, pushforward_prep, backend, x, (im * oneunit(x),), contexts...)
)
dx = real(dot(a, dy)) + im * real(dot(b, dy))
return dx
end

function _pullback_via_pushforward(
function _value_and_pullback_via_pushforward(
f!::F,
y,
pushforward_prep::PushforwardPrep,
Expand All @@ -481,13 +484,13 @@ function _pullback_via_pushforward(
contexts::Vararg{Context, C},
) where {F, C}
dx = map(CartesianIndices(x)) do j # preserve shape
a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...))
_, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...))
dot(a, dy)
end
return dx
end

function _pullback_via_pushforward(
function _value_and_pullback_via_pushforward(
f!::F,
y,
pushforward_prep::PushforwardPrep,
Expand All @@ -498,8 +501,8 @@ function _pullback_via_pushforward(
) where {F, C}
dx = map(CartesianIndices(x)) do j # preserve shape
a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...))
b = only(
pushforward(
_, b = onlysecond(
value_and_pushforward(
f!, y, pushforward_prep, backend, x, (im * basis(x, j),), contexts...
),
)
Expand All @@ -520,12 +523,11 @@ function value_and_pullback(
check_prep(f!, y, prep, backend, x, ty, contexts...)
(; pushforward_prep) = prep
tx = ntuple(
b -> _pullback_via_pushforward(
b -> _value_and_pullback_via_pushforward(
f!, y, pushforward_prep, backend, x, ty[b], contexts...
),
Val(B),
)
f!(y, x, map(unwrap, contexts)...)
return y, tx
end

Expand Down
71 changes: 37 additions & 34 deletions DifferentiationInterface/src/first_order/pushforward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,10 @@ end

## Preparation

struct PullbackPushforwardPrep{SIG, E} <: PushforwardPrep{SIG}
struct PullbackPushforwardPrep{SIG, E, Y} <: PushforwardPrep{SIG}
_sig::Val{SIG}
pullback_prep::E
y_example::Y
end

function prepare_pushforward_nokwarg(
Expand Down Expand Up @@ -296,7 +297,7 @@ function _prepare_pushforward_aux(
basis(y)
end
pullback_prep = prepare_pullback_nokwarg(strict, f, backend, x, (dy,), contexts...)
return PullbackPushforwardPrep(_sig, pullback_prep)
return PullbackPushforwardPrep(_sig, pullback_prep, y)
end

function _prepare_pushforward_aux(
Expand All @@ -312,71 +313,73 @@ function _prepare_pushforward_aux(
_sig = signature(f!, y, backend, x, tx, contexts...; strict)
dy = basis(y)
pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...)
return PullbackPushforwardPrep(_sig, pullback_prep)
return PullbackPushforwardPrep(_sig, pullback_prep, y)
end

## One argument

function _pushforward_via_pullback(
y::Number,
function _value_and_pushforward_via_pullback(
y_ex::Number,
f::F,
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
contexts::Vararg{Context, C},
) where {F, C}
a = only(pullback(f, pullback_prep, backend, x, (oneunit(y),), contexts...))
y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...))
dy = dot(a, dx)
return dy
return y, dy
end

function _pushforward_via_pullback(
y::Complex,
function _value_and_pushforward_via_pullback(
y_ex::Complex,
f::F,
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
contexts::Vararg{Context, C},
) where {F, C}
a = only(pullback(f, pullback_prep, backend, x, (oneunit(y),), contexts...))
b = only(pullback(f, pullback_prep, backend, x, (im * oneunit(y),), contexts...))
y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...))
b = only(pullback(f, pullback_prep, backend, x, (im * oneunit(y_ex),), contexts...))
dy = real(dot(a, dx)) + im * real(dot(b, dx))
return dy
return y, dy
end

function _pushforward_via_pullback(
y::AbstractArray{<:Real},
function _value_and_pushforward_via_pullback(
y_ex::AbstractArray{<:Real},
f::F,
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
contexts::Vararg{Context, C},
) where {F, C}
dy = map(CartesianIndices(y)) do i
a = only(pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...))
y = f(x, map(unwrap, contexts)...)
dy = map(CartesianIndices(y_ex)) do i
a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...))
dot(a, dx)
end
return dy
return y, dy
end

function _pushforward_via_pullback(
y::AbstractArray{<:Complex},
function _value_and_pushforward_via_pullback(
y_ex::AbstractArray{<:Complex},
f::F,
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
contexts::Vararg{Context, C},
) where {F, C}
dy = map(CartesianIndices(y)) do i
a = only(pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...))
b = only(pullback(f, pullback_prep, backend, x, (im * basis(y, i),), contexts...))
y = f(x, map(unwrap, contexts)...)
dy = map(CartesianIndices(y_ex)) do i
a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...))
b = only(pullback(f, pullback_prep, backend, x, (im * basis(y_ex, i),), contexts...))
real(dot(a, dx)) + im * real(dot(b, dx))
end
return dy
return y, dy
end

function value_and_pushforward(
Expand All @@ -388,12 +391,13 @@ function value_and_pushforward(
contexts::Vararg{Context, C},
) where {F, B, C}
check_prep(f, prep, backend, x, tx, contexts...)
(; pullback_prep) = prep
y = f(x, map(unwrap, contexts)...)
ty = ntuple(
b -> _pushforward_via_pullback(y, f, pullback_prep, backend, x, tx[b], contexts...),
(; pullback_prep, y_example) = prep
ys_and_ty = ntuple(
b -> _value_and_pushforward_via_pullback(y_example, f, pullback_prep, backend, x, tx[b], contexts...),
Val(B),
)
y = first(first(ys_and_ty))
ty = map(last, ys_and_ty)
return y, ty
end

Expand Down Expand Up @@ -439,7 +443,7 @@ end

## Two arguments

function _pushforward_via_pullback(
function _value_and_pushforward_via_pullback(
f!::F,
y::AbstractArray{<:Real},
pullback_prep::PullbackPrep,
Expand All @@ -449,13 +453,13 @@ function _pushforward_via_pullback(
contexts::Vararg{Context, C},
) where {F, C}
dy = map(CartesianIndices(y)) do i # preserve shape
a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...))
_, a = onlysecond(value_and_pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...))
dot(a, dx)
end
return dy
end

function _pushforward_via_pullback(
function _value_and_pushforward_via_pullback(
f!::F,
y::AbstractArray{<:Complex},
pullback_prep::PullbackPrep,
Expand All @@ -466,8 +470,8 @@ function _pushforward_via_pullback(
) where {F, C}
dy = map(CartesianIndices(y)) do i # preserve shape
a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...))
b = only(
pullback(f!, y, pullback_prep, backend, x, (im * basis(y, i),), contexts...)
_, b = onlysecond(
value_and_pullback(f!, y, pullback_prep, backend, x, (im * basis(y, i),), contexts...)
)
real(dot(a, dx)) + im * real(dot(b, dx))
end
Expand All @@ -487,10 +491,9 @@ function value_and_pushforward(
(; pullback_prep) = prep
ty = ntuple(
b ->
_pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx[b], contexts...),
_value_and_pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx[b], contexts...),
Val(B),
)
f!(y, x, map(unwrap, contexts)...)
return y, ty
end

Expand Down
2 changes: 2 additions & 0 deletions DifferentiationInterface/src/utils/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@ Only specialized on `SparseMatrixCSC` because it is used with symbolic backends,
The trivial dense fallback is designed to protect against a change of format in these packages.
"""
function get_pattern(M::AbstractMatrix)
    # Dense fallback: report every entry of `M` as structurally present.
    return trues(size(M))
end

"""
    onlysecond(pair)

Destructure `pair` into `(a, b)` and return `(a, only(b))`: the first entry is passed
through unchanged and the second is replaced by its unique element.
`only` throws if `b` does not contain exactly one element.
"""
function onlysecond(pair)
    a, b = pair
    return (a, only(b))
end
Loading