diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index d4058e7ac..40d96972c 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -42,6 +42,7 @@ DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] DifferentiationInterfaceGPUArraysCoreExt = ["GPUArraysCore", "Adapt"] DifferentiationInterfaceGTPSAExt = "GTPSA" DifferentiationInterfaceMooncakeExt = "Mooncake" +DifferentiationInterfaceMooncakeStaticArraysExt = ["Mooncake", "StaticArrays"] DifferentiationInterfacePolyesterForwardDiffExt = [ "PolyesterForwardDiff", "ForwardDiff", @@ -71,7 +72,7 @@ ForwardDiff = "0.10.36,1" GPUArraysCore = "0.2" GTPSA = "1.4.0" LinearAlgebra = "1" -Mooncake = "0.5.1" +Mooncake = "0.5.27" PolyesterForwardDiff = "0.1.2" ReverseDiff = "1.15.1" SparseArrays = "1" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 3513d548c..8d1fc46f3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -30,7 +30,7 @@ using Mooncake: primal, _copy_output, _copy_to_output!!, - tangent_to_primal!! + tangent_to_friendly!! const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}} diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl index c470b6473..447f84b6b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -41,7 +41,8 @@ function DI.value_and_pushforward( map(first_unwrap, contexts, prep.context_tangents)..., ) y = first(y_and_dy) - dy = _copy_output(last(y_and_dy)) + dy_raw = last(y_and_dy) + dy = @something(_to_friendly_value(dy_raw, y), _copy_output(dy_raw)) return y, dy end y = _copy_output(first(ys_and_ty[1])) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index 3c75f530b..a6354a0f7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -55,7 +55,7 @@ function DI.value_and_pushforward( (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) - return _copy_output(new_dy) + return @something(_to_friendly_value(new_dy, y), _copy_output(new_dy)) end return y, ty end @@ -93,7 +93,7 @@ function DI.value_and_pushforward!( (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) - copyto!(dy, new_dy) + copyto!(dy, something(_to_friendly_value(new_dy, y), new_dy)) end return y, ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 2514cdc40..61f5eb837 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -35,7 +35,7 @@ function DI.value_and_pullback( new_y, (_, new_dx) = value_and_pullback!!( prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - return new_y, (_copy_output(new_dx),) + return new_y, (@something(_to_friendly_value(new_dx, x), _copy_output(new_dx)),) end function DI.value_and_pullback( @@ -51,7 +51,7 @@ function DI.value_and_pullback( y, (_, new_dx) = value_and_pullback!!( prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - y, _copy_output(new_dx) + y, @something(_to_friendly_value(new_dx, x), _copy_output(new_dx)) end y = first(ys_and_tx[1]) tx = map(last, ys_and_tx) @@ -134,7 +134,7 @@ function DI.value_and_gradient( prep.cache, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - return y, _copy_output(new_grad) + return y, @something(_to_friendly_value(new_grad, x), _copy_output(new_grad)) end function DI.value_and_gradient!( @@ -150,7 +150,7 @@ function DI.value_and_gradient!( prep.cache, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - copyto!(grad, new_grad) + copyto!(grad, something(_to_friendly_value(new_grad, x), new_grad)) return y, grad end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 2b55131b9..4dbc6d742 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -64,7 +64,7 @@ function DI.value_and_pullback( prep.args_to_zero ) copyto!(y, y_after) - return y, (_copy_output(dx),) + return y, (@something(_to_friendly_value(dx, x), _copy_output(dx)),) end function DI.value_and_pullback( @@ -90,7 +90,7 @@ function DI.value_and_pullback( prep.args_to_zero ) copyto!(y, y_after) - _copy_output(dx) + @something(_to_friendly_value(dx, x), _copy_output(dx)) end return y, tx end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index b22d8d49b..11119face 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -9,11 +9,19 @@ function call_and_return(f!::F, y, x, contexts...) where {F} return y end +# Hook for bridging primal types whose `friendly_tangent_cache` falls through to +# `:as_raw` in Mooncake's framework, leaking a raw `Tangent` / `MutableTangent` +# instead of a primal-shaped value. The default returns `nothing`; specialised +# methods are loaded by triple-extensions when the relevant primal-type packages +# are available (see DifferentiationInterfaceMooncakeStaticArraysExt for the +# `SArray` / `MArray` case). +_to_friendly_value(t, x) = nothing + function zero_tangent_or_primal(x, backend::AnyAutoMooncake) + zt = zero_tangent(x) if get_config(backend).friendly_tangents - # zero(x) but safer - return tangent_to_primal!!(_copy_output(x), zero_tangent(x)) + return @something(_to_friendly_value(zt, x), tangent_to_friendly!!(x, zt)) else - return zero_tangent(x) + return zt end end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeStaticArraysExt/DifferentiationInterfaceMooncakeStaticArraysExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeStaticArraysExt/DifferentiationInterfaceMooncakeStaticArraysExt.jl new file mode 100644 index 000000000..797445b6f --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeStaticArraysExt/DifferentiationInterfaceMooncakeStaticArraysExt.jl @@ -0,0 +1,36 @@ +module DifferentiationInterfaceMooncakeStaticArraysExt + +using Base: IEEEFloat +using DifferentiationInterface: DifferentiationInterface +using Mooncake: Mooncake +using StaticArrays: MArray, SArray + +# Reach into the binary MooncakeExt to extend its `_to_friendly_value` hook. +# Both Mooncake and StaticArrays are loaded whenever this extension is active, +# so the binary extension is guaranteed to be loaded as well. +const _MooncakeExt = Base.get_extension( + DifferentiationInterface, :DifferentiationInterfaceMooncakeExt, +) + +# Restrict to scalar float / complex-float eltypes: those are the layouts where +# Mooncake's framework sends `SArray` / `MArray` through `AsRaw` and the single +# `data::NTuple` element-to-position mapping is 1:1 (no aliasing). Non-float +# eltypes hit Mooncake's element-wise `AbstractArray` recursion at +# `Mooncake/src/tangents/tangents.jl:1453`; let that path run unimpeded. +const _StaticEltype = Union{IEEEFloat, Complex{<:IEEEFloat}} + +# Mooncake's `friendly_tangent_cache` framework defaults to `:as_raw` for +# `SArray` / `MArray` primals with float eltype, leaking a raw `Tangent` / +# `MutableTangent` instead of a primal-shaped value. Bridge that gap here. The +# reconstruction is unambiguous because the `data::NTuple` field maps each +# element to one logical position (unlike `Symmetric` / `Hermitian`, where a +# single stored entry can represent two positions). +@inline _MooncakeExt._to_friendly_value( + t::Mooncake.Tangent, x::SArray{S, T} +) where {S, T <: _StaticEltype} = typeof(x)(Mooncake.val(t.fields.data)) + +@inline _MooncakeExt._to_friendly_value( + t::Mooncake.MutableTangent, x::MArray{S, T} +) where {S, T <: _StaticEltype} = typeof(x)(Mooncake.val(t.fields.data)) + +end # module diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index d531e542a..7555752af 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -1,6 +1,7 @@ include("../../testutils.jl") using DifferentiationInterface, DifferentiationInterfaceTest +using LinearAlgebra: Hermitian, SymTridiagonal, Symmetric using Mooncake: Mooncake using Test @@ -78,5 +79,110 @@ test_differentiation( backends[3:4], nomatrix(static_scenarios()); logging = LOGGING, - excluded = SECOND_ORDER + excluded = SECOND_ORDER, ) + +@testset "Friendly tangents structured matrices" begin + # Mooncake 0.5.25+ returns a plain `Matrix` for structured inputs under + # `friendly_tangents=true` (chalk-lab/Mooncake.jl#1103); the complex case + # follows the standard reverse-mode convention (chalk-lab/Mooncake.jl#773). + # + # Per-wrapper test functions are chosen for their non-triviality given + # that matmul on small matrices hits a `utf8proc_isupper` ccall that + # Mooncake cannot differentiate (LinearAlgebra._matmul2x2_elements → + # WrapperChar). For the real wrappers we use a manual triple-loop tr(X^3) + # whose unrestricted gradient is 3·X²; for Hermitian we use the simpler + # abs2-sum because the complex Wirtinger ground truth via tr(X^3) is + # convention-heavy. The expected friendly gradient is then computed by + # aggregating the unrestricted per-element gradient into the wrapper's + # canonical storage cells, derived independently of Mooncake. + + # tr(X^3) without matmul. Indices i,j,k each range over axes(X,1). + function tr_x3(X) + s = zero(eltype(X)) + n = size(X, 1) + @inbounds for i in 1:n, j in 1:n, k in 1:n + s += X[i, j] * X[j, k] * X[k, i] + end + return real(s) + end + + # Symmetric storage: upper triangle holds the sum of (i,j) and (j,i) + # per-element contributions; strict lower triangle is zero. + function aggregate_symmetric(G) + n = size(G, 1) + H = zero(G) + @inbounds for i in 1:n + H[i, i] = G[i, i] + for j in (i + 1):n + H[i, j] = G[i, j] + G[j, i] + end + end + return H + end + + # SymTridiagonal storage: diagonal + symmetric off-diagonals (both + # `(i,i+1)` and `(i+1,i)` slots hold the doubled contribution); entries + # outside that band are structurally zero in the wrapper. + function aggregate_symtridiagonal(G) + n = size(G, 1) + H = zero(G) + @inbounds for i in 1:n + H[i, i] = G[i, i] + if i < n + aggregated = G[i, i + 1] + G[i + 1, i] + H[i, i + 1] = aggregated + H[i + 1, i] = aggregated + end + end + return H + end + + backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) + abs2_sum(x) = real(sum(abs2, x)) + cases = ( + ( + x = Symmetric([2.0 1.0; 1.0 3.0]), + f = tr_x3, + expected_grad = let M = Matrix(Symmetric([2.0 1.0; 1.0 3.0])) + aggregate_symmetric(3 * M^2) + end, + ), + ( + x = Hermitian(ComplexF64[2 1 + im; 1 - im 3]), + f = abs2_sum, + expected_grad = ComplexF64[4 4 + 4im; 0 6], + ), + ( + x = SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0]), + f = tr_x3, + expected_grad = let M = Matrix(SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0])) + aggregate_symtridiagonal(3 * M^2) + end, + ), + ) + + @testset "$(typeof(case.x))" for case in cases + x = case.x + f = case.f + grad = gradient(f, backend, x) + y, grad2 = value_and_gradient(f, backend, x) + pb = only(pullback(identity, backend, x, (x,))) + + @test grad isa Matrix + @test grad2 isa Matrix + @test pb isa Matrix + @test grad == grad2 + @test grad ≈ case.expected_grad + @test y ≈ f(x) + @test pb == Matrix(x) + + grad_dense = zero(Matrix(x)) + @test gradient!(f, grad_dense, backend, x) === grad_dense + @test grad_dense == grad + + tx_dense = (zero(Matrix(x)),) + @test only(pullback!(identity, tx_dense, backend, x, (x,))) === tx_dense[1] + @test tx_dense[1] == pb + end +end