Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
DifferentiationInterfaceGPUArraysCoreExt = ["GPUArraysCore", "Adapt"]
DifferentiationInterfaceGTPSAExt = "GTPSA"
DifferentiationInterfaceMooncakeExt = "Mooncake"
DifferentiationInterfaceMooncakeStaticArraysExt = ["Mooncake", "StaticArrays"]
DifferentiationInterfacePolyesterForwardDiffExt = [
"PolyesterForwardDiff",
"ForwardDiff",
Expand Down Expand Up @@ -71,7 +72,7 @@ ForwardDiff = "0.10.36,1"
GPUArraysCore = "0.2"
GTPSA = "1.4.0"
LinearAlgebra = "1"
Mooncake = "0.5.1"
Mooncake = "0.5.27"
PolyesterForwardDiff = "0.1.2"
ReverseDiff = "1.15.1"
SparseArrays = "1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ using Mooncake:
primal,
_copy_output,
_copy_to_output!!,
tangent_to_primal!!
tangent_to_friendly!!

const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ function DI.value_and_pushforward(
map(first_unwrap, contexts, prep.context_tangents)...,
)
y = first(y_and_dy)
dy = _copy_output(last(y_and_dy))
dy_raw = last(y_and_dy)
dy = @something(_to_friendly_value(dy_raw, y), _copy_output(dy_raw))
return y, dy
end
y = _copy_output(first(ys_and_ty[1]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function DI.value_and_pushforward(
(x, dx),
map(first_unwrap, contexts, prep.context_tangents)...,
)
return _copy_output(new_dy)
return @something(_to_friendly_value(new_dy, y), _copy_output(new_dy))
end
return y, ty
end
Expand Down Expand Up @@ -93,7 +93,7 @@ function DI.value_and_pushforward!(
(x, dx),
map(first_unwrap, contexts, prep.context_tangents)...,
)
copyto!(dy, new_dy)
copyto!(dy, something(_to_friendly_value(new_dy, y), new_dy))
end
return y, ty
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function DI.value_and_pullback(
new_y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
)
return new_y, (_copy_output(new_dx),)
return new_y, (@something(_to_friendly_value(new_dx, x), _copy_output(new_dx)),)
end

function DI.value_and_pullback(
Expand All @@ -51,7 +51,7 @@ function DI.value_and_pullback(
y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
)
y, _copy_output(new_dx)
y, @something(_to_friendly_value(new_dx, x), _copy_output(new_dx))
end
y = first(ys_and_tx[1])
tx = map(last, ys_and_tx)
Expand Down Expand Up @@ -134,7 +134,7 @@ function DI.value_and_gradient(
prep.cache, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
)
return y, _copy_output(new_grad)
return y, @something(_to_friendly_value(new_grad, x), _copy_output(new_grad))
end

function DI.value_and_gradient!(
Expand All @@ -150,7 +150,7 @@ function DI.value_and_gradient!(
prep.cache, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
)
copyto!(grad, new_grad)
copyto!(grad, something(_to_friendly_value(new_grad, x), new_grad))
return y, grad
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ function DI.value_and_pullback(
prep.args_to_zero
)
copyto!(y, y_after)
return y, (_copy_output(dx),)
return y, (@something(_to_friendly_value(dx, x), _copy_output(dx)),)
end

function DI.value_and_pullback(
Expand All @@ -90,7 +90,7 @@ function DI.value_and_pullback(
prep.args_to_zero
)
copyto!(y, y_after)
_copy_output(dx)
@something(_to_friendly_value(dx, x), _copy_output(dx))
end
return y, tx
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,19 @@ function call_and_return(f!::F, y, x, contexts...) where {F}
return y
end

# Hook for bridging primal types whose `friendly_tangent_cache` falls through to
# `:as_raw` in Mooncake's framework, leaking a raw `Tangent` / `MutableTangent`
# instead of a primal-shaped value. The default returns `nothing`; specialised
# methods are loaded by triple-extensions when the relevant primal-type packages
# are available (see DifferentiationInterfaceMooncakeStaticArraysExt for the
# `SArray` / `MArray` case).
_to_friendly_value(t, x) = nothing

function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
zt = zero_tangent(x)
if get_config(backend).friendly_tangents
# zero(x) but safer
return tangent_to_primal!!(_copy_output(x), zero_tangent(x))
return @something(_to_friendly_value(zt, x), tangent_to_friendly!!(x, zt))
else
return zero_tangent(x)
return zt
end
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
module DifferentiationInterfaceMooncakeStaticArraysExt

using Base: IEEEFloat
using DifferentiationInterface: DifferentiationInterface
using Mooncake: Mooncake
using StaticArrays: MArray, SArray

# Reach into the binary MooncakeExt to extend its `_to_friendly_value` hook.
# Both Mooncake and StaticArrays are loaded whenever this extension is active,
# so the binary extension is guaranteed to be loaded as well.
const _MooncakeExt = Base.get_extension(
DifferentiationInterface, :DifferentiationInterfaceMooncakeExt,
)

# Restrict to scalar float / complex-float eltypes: those are the layouts where
# Mooncake's framework sends `SArray` / `MArray` through `AsRaw` and the single
# `data::NTuple` element-to-position mapping is 1:1 (no aliasing). Non-float
# eltypes hit Mooncake's element-wise `AbstractArray` recursion at
# `Mooncake/src/tangents/tangents.jl:1453`; let that path run unimpeded.
const _StaticEltype = Union{IEEEFloat, Complex{<:IEEEFloat}}

# Mooncake's `friendly_tangent_cache` framework defaults to `:as_raw` for
# `SArray` / `MArray` primals with float eltype, leaking a raw `Tangent` /
# `MutableTangent` instead of a primal-shaped value. Bridge that gap here. The
# reconstruction is unambiguous because the `data::NTuple` field maps each
# element to one logical position (unlike `Symmetric` / `Hermitian`, where a
# single stored entry can represent two positions).
@inline _MooncakeExt._to_friendly_value(
t::Mooncake.Tangent, x::SArray{S, T}
) where {S, T <: _StaticEltype} = typeof(x)(Mooncake.val(t.fields.data))

@inline _MooncakeExt._to_friendly_value(
t::Mooncake.MutableTangent, x::MArray{S, T}
) where {S, T <: _StaticEltype} = typeof(x)(Mooncake.val(t.fields.data))

end # module
108 changes: 107 additions & 1 deletion DifferentiationInterface/test/Back/Mooncake/test.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
include("../../testutils.jl")

using DifferentiationInterface, DifferentiationInterfaceTest
using LinearAlgebra: Hermitian, SymTridiagonal, Symmetric
using Mooncake: Mooncake
using Test

Expand Down Expand Up @@ -78,5 +79,110 @@ test_differentiation(
backends[3:4],
nomatrix(static_scenarios());
logging = LOGGING,
excluded = SECOND_ORDER
excluded = SECOND_ORDER,
)

@testset "Friendly tangents structured matrices" begin
# Mooncake 0.5.25+ returns a plain `Matrix` for structured inputs under
# `friendly_tangents=true` (chalk-lab/Mooncake.jl#1103); the complex case
# follows the standard reverse-mode convention (chalk-lab/Mooncake.jl#773).
#
# Per-wrapper test functions are chosen for their non-triviality given
# that matmul on small matrices hits a `utf8proc_isupper` ccall that
# Mooncake cannot differentiate (LinearAlgebra._matmul2x2_elements →
# WrapperChar). For the real wrappers we use a manual triple-loop tr(X^3)
# whose unrestricted gradient is 3·X²; for Hermitian we use the simpler
# abs2-sum because the complex Wirtinger ground truth via tr(X^3) is
# convention-heavy. The expected friendly gradient is then computed by
# aggregating the unrestricted per-element gradient into the wrapper's
# canonical storage cells, derived independently of Mooncake.

# tr(X^3) without matmul. Indices i,j,k each range over axes(X,1).
function tr_x3(X)
s = zero(eltype(X))
n = size(X, 1)
@inbounds for i in 1:n, j in 1:n, k in 1:n
s += X[i, j] * X[j, k] * X[k, i]
end
return real(s)
end

# Symmetric storage: upper triangle holds the sum of (i,j) and (j,i)
# per-element contributions; strict lower triangle is zero.
function aggregate_symmetric(G)
n = size(G, 1)
H = zero(G)
@inbounds for i in 1:n
H[i, i] = G[i, i]
for j in (i + 1):n
H[i, j] = G[i, j] + G[j, i]
end
end
return H
end

# SymTridiagonal storage: diagonal + symmetric off-diagonals (both
# `(i,i+1)` and `(i+1,i)` slots hold the doubled contribution); entries
# outside that band are structurally zero in the wrapper.
function aggregate_symtridiagonal(G)
n = size(G, 1)
H = zero(G)
@inbounds for i in 1:n
H[i, i] = G[i, i]
if i < n
aggregated = G[i, i + 1] + G[i + 1, i]
H[i, i + 1] = aggregated
H[i + 1, i] = aggregated
end
end
return H
end

backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
abs2_sum(x) = real(sum(abs2, x))
cases = (
(
x = Symmetric([2.0 1.0; 1.0 3.0]),
f = tr_x3,
expected_grad = let M = Matrix(Symmetric([2.0 1.0; 1.0 3.0]))
aggregate_symmetric(3 * M^2)
end,
),
(
x = Hermitian(ComplexF64[2 1 + im; 1 - im 3]),
f = abs2_sum,
expected_grad = ComplexF64[4 4 + 4im; 0 6],
),
(
x = SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0]),
f = tr_x3,
expected_grad = let M = Matrix(SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0]))
aggregate_symtridiagonal(3 * M^2)
end,
),
)

@testset "$(typeof(case.x))" for case in cases
x = case.x
f = case.f
grad = gradient(f, backend, x)
y, grad2 = value_and_gradient(f, backend, x)
pb = only(pullback(identity, backend, x, (x,)))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a strong enough test, the function is too simple


@test grad isa Matrix
@test grad2 isa Matrix
@test pb isa Matrix
@test grad == grad2
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grad and grad2 are never compared against the ground truth

@test grad ≈ case.expected_grad
@test y ≈ f(x)
@test pb == Matrix(x)

grad_dense = zero(Matrix(x))
@test gradient!(f, grad_dense, backend, x) === grad_dense
@test grad_dense == grad

tx_dense = (zero(Matrix(x)),)
@test only(pullback!(identity, tx_dense, backend, x, (x,))) === tx_dense[1]
@test tx_dense[1] == pb
end
end