Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/Test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ jobs:
actions: write
contents: read
strategy:
fail-fast: true
fail-fast: false # TODO: toggle
Comment thread
gdalle marked this conversation as resolved.
Outdated
matrix:
version:
- '1.10'
Expand Down
104 changes: 55 additions & 49 deletions DifferentiationInterface/src/first_order/pullback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,61 +325,69 @@ function _value_and_pullback_via_pushforward(
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::Real,
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
dx = dot(a, dy)
return y, dx
tx = map(ty) do dy
dot(a, dy)
end
return y, arroftup_to_tupofarr(tx)
end

function _value_and_pullback_via_pushforward(
f::F,
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::Complex,
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
b = only(pushforward(f, pushforward_prep, backend, x, (im * oneunit(x),), contexts...))
dx = real(dot(a, dy)) + im * real(dot(b, dy))
return y, dx
tx = map(ty) do dy
real(dot(a, dy)) + im * real(dot(b, dy))
end
return y, arroftup_to_tupofarr(tx)
end

function _value_and_pullback_via_pushforward(
f::F,
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::AbstractArray{<:Real},
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y = f(x, map(unwrap, contexts)...)
dx = map(CartesianIndices(x)) do j
tx = map(CartesianIndices(x)) do j
a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...))
dot(a, dy)
map(ty) do dy
dot(a, dy)
end
end
return y, dx
return y, arroftup_to_tupofarr(tx)
end

function _value_and_pullback_via_pushforward(
f::F,
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::AbstractArray{<:Complex},
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y = f(x, map(unwrap, contexts)...)
dx = map(CartesianIndices(x)) do j
tx = map(CartesianIndices(x)) do j
a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...))
b = only(
pushforward(f, pushforward_prep, backend, x, (im * basis(x, j),), contexts...),
)
real(dot(a, dy)) + im * real(dot(b, dy))
map(ty) do dy
real(dot(a, dy)) + im * real(dot(b, dy))
end
end
return y, dx
return y, arroftup_to_tupofarr(tx)
end

function value_and_pullback(
Expand All @@ -392,13 +400,7 @@ function value_and_pullback(
) where {F, B, C}
check_prep(f, prep, backend, x, ty, contexts...)
(; pushforward_prep) = prep
ys_and_tx = ntuple(
b -> _value_and_pullback_via_pushforward(f, pushforward_prep, backend, x, ty[b], contexts...),
Val(B),
)
y = first(first(ys_and_tx))
tx = map(last, ys_and_tx)
return y, tx
return _value_and_pullback_via_pushforward(f, pushforward_prep, backend, x, ty, contexts...)
end

function value_and_pullback!(
Expand Down Expand Up @@ -449,12 +451,14 @@ function _value_and_pullback_via_pushforward(
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::Real,
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
_, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...))
dx = dot(a, dy)
return dx
tx = map(ty) do dy
dot(a, dy)
end
return y, arroftup_to_tupofarr(tx)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -463,15 +467,17 @@ function _value_and_pullback_via_pushforward(
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::Complex,
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...))
_, b = onlysecond(
value_and_pushforward(f!, y, pushforward_prep, backend, x, (im * oneunit(x),), contexts...)
)
dx = real(dot(a, dy)) + im * real(dot(b, dy))
return dx
tx = map(ty) do dy
real(dot(a, dy)) + im * real(dot(b, dy))
end
return y, arroftup_to_tupofarr(tx)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -480,14 +486,16 @@ function _value_and_pullback_via_pushforward(
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::AbstractArray{<:Real},
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
dx = map(CartesianIndices(x)) do j # preserve shape
) where {F, B, C}
tx = map(CartesianIndices(x)) do j # preserve shape
_, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...))
dot(a, dy)
map(ty) do dy
dot(a, dy)
end
end
return dx
return y, arroftup_to_tupofarr(tx)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -496,19 +504,21 @@ function _value_and_pullback_via_pushforward(
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::AbstractArray{<:Complex},
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
dx = map(CartesianIndices(x)) do j # preserve shape
) where {F, B, C}
tx = map(CartesianIndices(x)) do j # preserve shape
a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...))
_, b = onlysecond(
value_and_pushforward(
f!, y, pushforward_prep, backend, x, (im * basis(x, j),), contexts...
),
)
real(dot(a, dy)) + im * real(dot(b, dy))
map(ty) do dy
real(dot(a, dy)) + im * real(dot(b, dy))
end
end
return dx
return y, arroftup_to_tupofarr(tx)
end

function value_and_pullback(
Expand All @@ -522,13 +532,9 @@ function value_and_pullback(
) where {F, B, C}
check_prep(f!, y, prep, backend, x, ty, contexts...)
(; pushforward_prep) = prep
tx = ntuple(
b -> _value_and_pullback_via_pushforward(
f!, y, pushforward_prep, backend, x, ty[b], contexts...
),
Val(B),
return _value_and_pullback_via_pushforward(
f!, y, pushforward_prep, backend, x, ty, contexts...
)
return y, tx
end

function value_and_pullback!(
Expand Down
83 changes: 42 additions & 41 deletions DifferentiationInterface/src/first_order/pushforward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,14 @@ function _value_and_pushforward_via_pullback(
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
tx::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...))
dy = dot(a, dx)
return y, dy
ty = map(tx) do dx
dot(a, dx)
end
return y, arroftup_to_tupofarr(ty)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -338,13 +340,15 @@ function _value_and_pushforward_via_pullback(
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
tx::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...))
b = only(pullback(f, pullback_prep, backend, x, (im * oneunit(y_ex),), contexts...))
dy = real(dot(a, dx)) + im * real(dot(b, dx))
return y, dy
ty = map(tx) do dx
real(dot(a, dx)) + im * real(dot(b, dx))
end
return y, arroftup_to_tupofarr(ty)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -353,15 +357,17 @@ function _value_and_pushforward_via_pullback(
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
tx::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y = f(x, map(unwrap, contexts)...)
dy = map(CartesianIndices(y_ex)) do i
ty = map(CartesianIndices(y_ex)) do i
a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...))
dot(a, dx)
map(tx) do dx
dot(a, dx)
end
end
return y, dy
return y, arroftup_to_tupofarr(ty)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -370,16 +376,18 @@ function _value_and_pushforward_via_pullback(
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
tx::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y = f(x, map(unwrap, contexts)...)
dy = map(CartesianIndices(y_ex)) do i
ty = map(CartesianIndices(y_ex)) do i
a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...))
b = only(pullback(f, pullback_prep, backend, x, (im * basis(y_ex, i),), contexts...))
real(dot(a, dx)) + im * real(dot(b, dx))
map(tx) do dx
real(dot(a, dx)) + im * real(dot(b, dx))
end
end
return y, dy
return y, arroftup_to_tupofarr(ty)
end

function value_and_pushforward(
Expand All @@ -392,13 +400,7 @@ function value_and_pushforward(
) where {F, B, C}
check_prep(f, prep, backend, x, tx, contexts...)
(; pullback_prep, y_example) = prep
ys_and_ty = ntuple(
b -> _value_and_pushforward_via_pullback(y_example, f, pullback_prep, backend, x, tx[b], contexts...),
Val(B),
)
y = first(first(ys_and_ty))
ty = map(last, ys_and_ty)
return y, ty
return _value_and_pushforward_via_pullback(y_example, f, pullback_prep, backend, x, tx, contexts...)
end

function value_and_pushforward!(
Expand Down Expand Up @@ -449,14 +451,16 @@ function _value_and_pushforward_via_pullback(
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
tx::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
dy = map(CartesianIndices(y)) do i # preserve shape
) where {F, B, C}
ty = map(CartesianIndices(y)) do i # preserve shape
_, a = onlysecond(value_and_pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...))
dot(a, dx)
map(tx) do dx
dot(a, dx)
end
end
return dy
return y, arroftup_to_tupofarr(ty)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -465,17 +469,19 @@ function _value_and_pushforward_via_pullback(
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
tx::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
dy = map(CartesianIndices(y)) do i # preserve shape
) where {F, B, C}
ty = map(CartesianIndices(y)) do i # preserve shape
a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...))
_, b = onlysecond(
value_and_pullback(f!, y, pullback_prep, backend, x, (im * basis(y, i),), contexts...)
)
real(dot(a, dx)) + im * real(dot(b, dx))
map(tx) do dx
real(dot(a, dx)) + im * real(dot(b, dx))
end
end
return dy
return y, arroftup_to_tupofarr(ty)
end

function value_and_pushforward(
Expand All @@ -489,12 +495,7 @@ function value_and_pushforward(
) where {F, B, C}
check_prep(f!, y, prep, backend, x, tx, contexts...)
(; pullback_prep) = prep
ty = ntuple(
b ->
_value_and_pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx[b], contexts...),
Val(B),
)
return y, ty
return _value_and_pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx, contexts...)
end

function value_and_pushforward!(
Expand Down
3 changes: 3 additions & 0 deletions DifferentiationInterface/src/utils/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,6 @@ The trivial dense fallback is designed to protect against a change of format in
get_pattern(M::AbstractMatrix) = trues(size(M))

onlysecond((a, b)) = (a, only(b))

arroftup_to_tupofarr(x::NTuple) = x
arroftup_to_tupofarr(x::AbstractArray{<:NTuple{B}}) where {B} = ntuple(b -> getindex.(x, b), Val(B))
Loading