From 75d3c056ce09314359a4ab5eea8b10c3572559fa Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sat, 4 Apr 2026 15:37:41 +0100 Subject: [PATCH 1/9] fix Mooncake friendly_tangents compatibility --- .../utils.jl | 27 ++++++++++++++-- .../test/Back/Mooncake/test.jl | 32 +++++++++++++++++++ 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index b22d8d49b..ddf1d281d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -11,9 +11,32 @@ end function zero_tangent_or_primal(x, backend::AnyAutoMooncake) if get_config(backend).friendly_tangents - # zero(x) but safer - return tangent_to_primal!!(_copy_output(x), zero_tangent(x)) + # Mooncake 0.5.25+ replaced `tangent_to_primal!!` with the + # `tangent_to_friendly!!` framework. For this internal backup we still + # need a primal-shaped value, so use the `AsPrimal` path when + # available and fall back for older Mooncake releases. + return tangent_to_user_primal(zero_tangent(x), x) else return zero_tangent(x) end end + +@inline maybe_getfield(mod, name::Symbol) = isdefined(mod, name) ? getfield(mod, name) : nothing + +const mooncake_tangent_to_friendly = maybe_getfield(Mooncake, Symbol("tangent_to_friendly!!")) +const mooncake_friendly_tangent_cache = maybe_getfield(Mooncake, :FriendlyTangentCache) +const mooncake_as_primal = maybe_getfield(Mooncake, :AsPrimal) +const mooncake_no_cache = maybe_getfield(Mooncake, :NoCache) + +function tangent_to_user_primal(tx, x) + if !isnothing(mooncake_tangent_to_friendly) && + !isnothing(mooncake_friendly_tangent_cache) && + !isnothing(mooncake_as_primal) && + !isnothing(mooncake_no_cache) + dest = mooncake_friendly_tangent_cache{mooncake_as_primal}(_copy_output(x)) + cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any,Any}() + return mooncake_tangent_to_friendly(dest, x, tx, cache) + else + return tangent_to_primal!!(_copy_output(x), tx) + end +end diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index d531e542a..b481b759c 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -1,6 +1,7 @@ include("../../testutils.jl") using DifferentiationInterface, DifferentiationInterfaceTest +using LinearAlgebra: Hermitian, SymTridiagonal, Symmetric using Mooncake: Mooncake using Test @@ -80,3 +81,34 @@ test_differentiation( logging = LOGGING, excluded = SECOND_ORDER ) + +@testset "Friendly tangents structured matrices" begin + backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) + inputs = ( + Symmetric([2.0 1.0; 1.0 3.0]), + Hermitian(ComplexF64[2 1 + im; 1 - im 3]), + SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0]), + ) + f(x) = real(sum(abs2, x)) + + @testset "$(typeof(x))" for x in inputs + grad = gradient(f, backend, x) + y, grad2 = value_and_gradient(f, backend, x) + pb = only(pullback(identity, backend, x, (x,))) + + @test grad isa Matrix + @test grad2 isa Matrix + @test pb isa Matrix + @test grad == grad2 + @test y == f(x) + @test pb == Matrix(x) + + grad_dense = zero(Matrix(x)) + @test gradient!(f, grad_dense, backend, x) === grad_dense + @test grad_dense == grad + + tx_dense = (zero(Matrix(x)),) + @test only(pullback!(identity, tx_dense, backend, x, (x,))) === tx_dense[1] + @test tx_dense[1] == pb + end +end From f42023428e7e4a055c2b6f3a2d27d99fe63ed2be Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sat, 4 Apr 2026 15:41:39 +0100 Subject: [PATCH 2/9] format Mooncake fix --- .../utils.jl | 13 +++--- .../test/Back/Mooncake/test.jl | 41 ++++++++----------- 2 files changed, 24 insertions(+), 30 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index ddf1d281d..927a4cfbe 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -21,18 +21,21 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake) end end -@inline maybe_getfield(mod, name::Symbol) = isdefined(mod, name) ? getfield(mod, name) : nothing +@inline maybe_getfield(mod, name::Symbol) = + isdefined(mod, name) ? getfield(mod, name) : nothing -const mooncake_tangent_to_friendly = maybe_getfield(Mooncake, Symbol("tangent_to_friendly!!")) +const mooncake_tangent_to_friendly = maybe_getfield( + Mooncake, Symbol("tangent_to_friendly!!") +) const mooncake_friendly_tangent_cache = maybe_getfield(Mooncake, :FriendlyTangentCache) const mooncake_as_primal = maybe_getfield(Mooncake, :AsPrimal) const mooncake_no_cache = maybe_getfield(Mooncake, :NoCache) function tangent_to_user_primal(tx, x) if !isnothing(mooncake_tangent_to_friendly) && - !isnothing(mooncake_friendly_tangent_cache) && - !isnothing(mooncake_as_primal) && - !isnothing(mooncake_no_cache) + !isnothing(mooncake_friendly_tangent_cache) && + !isnothing(mooncake_as_primal) && + !isnothing(mooncake_no_cache) dest = mooncake_friendly_tangent_cache{mooncake_as_primal}(_copy_output(x)) cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any,Any}() return mooncake_tangent_to_friendly(dest, x, tx, cache) diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index b481b759c..eb296a567 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -13,8 +13,8 @@ nomatrix(scens) = filter(s -> !(s.x isa AbstractMatrix) && !(s.y isa AbstractMat backends = [ AutoMooncake(), AutoMooncakeForward(), - AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)), - AutoMooncakeForward(; config = Mooncake.Config(; friendly_tangents = true)), + AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true)), + AutoMooncakeForward(; config=Mooncake.Config(; friendly_tangents=true)), ] for backend in backends @@ -23,31 +23,25 @@ for backend in backends end test_differentiation( - backends[3:4], - default_scenarios(); - excluded = SECOND_ORDER, - logging = LOGGING, + backends[3:4], default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING ); test_differentiation( backends[3:4], nomatrix( default_scenarios(; - include_normal = false, - include_constantified = true, - include_cachified = true, - use_tuples = true - ) + include_normal=false, + include_constantified=true, + include_cachified=true, + use_tuples=true, + ), ); - excluded = SECOND_ORDER, - logging = LOGGING, + excluded=SECOND_ORDER, + logging=LOGGING, ); test_differentiation( - backends[1:2], - nomatrix(default_scenarios()); - excluded = SECOND_ORDER, - logging = LOGGING, + backends[1:2], nomatrix(default_scenarios()); excluded=SECOND_ORDER, logging=LOGGING ); EXCLUDED = @static if VERSION ≥ v"1.11-" && VERSION ≤ v"1.12-" @@ -63,12 +57,12 @@ end test_differentiation( [SecondOrder(AutoMooncakeForward(), AutoMooncake())], nomatrix(default_scenarios()); - excluded = EXCLUDED, - logging = LOGGING, + excluded=EXCLUDED, + logging=LOGGING, ) @testset "NamedTuples" begin - ps = (; A = rand(5), B = rand(5)) + ps = (; A=rand(5), B=rand(5)) myfun(ps) = sum(ps.A .* ps.B) grad = gradient(myfun, backends[1], ps) @test grad.A == ps.B @@ -76,14 +70,11 @@ test_differentiation( end test_differentiation( - backends[3:4], - nomatrix(static_scenarios()); - logging = LOGGING, - excluded = SECOND_ORDER + backends[3:4], nomatrix(static_scenarios()); logging=LOGGING, excluded=SECOND_ORDER ) @testset "Friendly tangents structured matrices" begin - backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) + backend = AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true)) inputs = ( Symmetric([2.0 1.0; 1.0 3.0]), Hermitian(ComplexF64[2 1 + im; 1 - im 3]), From 8235e6c60398e7e9afe676ab5afb1792ce614c34 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sat, 4 Apr 2026 15:46:04 +0100 Subject: [PATCH 3/9] remove Mooncake test formatting churn --- .../test/Back/Mooncake/test.jl | 41 +++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index eb296a567..b481b759c 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -13,8 +13,8 @@ nomatrix(scens) = filter(s -> !(s.x isa AbstractMatrix) && !(s.y isa AbstractMat backends = [ AutoMooncake(), AutoMooncakeForward(), - AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true)), - AutoMooncakeForward(; config=Mooncake.Config(; friendly_tangents=true)), + AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)), + AutoMooncakeForward(; config = Mooncake.Config(; friendly_tangents = true)), ] for backend in backends @@ -23,25 +23,31 @@ for backend in backends end test_differentiation( - backends[3:4], default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING + backends[3:4], + default_scenarios(); + excluded = SECOND_ORDER, + logging = LOGGING, ); test_differentiation( backends[3:4], nomatrix( default_scenarios(; - include_normal=false, - include_constantified=true, - include_cachified=true, - use_tuples=true, - ), + include_normal = false, + include_constantified = true, + include_cachified = true, + use_tuples = true + ) ); - excluded=SECOND_ORDER, - logging=LOGGING, + excluded = SECOND_ORDER, + logging = LOGGING, ); test_differentiation( - backends[1:2], nomatrix(default_scenarios()); excluded=SECOND_ORDER, logging=LOGGING + backends[1:2], + nomatrix(default_scenarios()); + excluded = SECOND_ORDER, + logging = LOGGING, ); EXCLUDED = @static if VERSION ≥ v"1.11-" && VERSION ≤ v"1.12-" @@ -57,12 +63,12 @@ end test_differentiation( [SecondOrder(AutoMooncakeForward(), AutoMooncake())], nomatrix(default_scenarios()); - excluded=EXCLUDED, - logging=LOGGING, + excluded = EXCLUDED, + logging = LOGGING, ) @testset "NamedTuples" begin - ps = (; A=rand(5), B=rand(5)) + ps = (; A = rand(5), B = rand(5)) myfun(ps) = sum(ps.A .* ps.B) grad = gradient(myfun, backends[1], ps) @test grad.A == ps.B @@ -70,11 +76,14 @@ test_differentiation( end test_differentiation( - backends[3:4], nomatrix(static_scenarios()); logging=LOGGING, excluded=SECOND_ORDER + backends[3:4], + nomatrix(static_scenarios()); + logging = LOGGING, + excluded = SECOND_ORDER ) @testset "Friendly tangents structured matrices" begin - backend = AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true)) + backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) inputs = ( Symmetric([2.0 1.0; 1.0 3.0]), Hermitian(ComplexF64[2 1 + im; 1 - im 3]), From a478578e57a8a0c958c103376b4dea0156cd837f Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 10 Apr 2026 08:12:26 +0100 Subject: [PATCH 4/9] ci: retrigger after Mooncake v0.5.26 release From 65997c462b8b493756bb37c1d7019fa7f6f92c81 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 10 Apr 2026 08:25:15 +0100 Subject: [PATCH 5/9] style: apply Runic formatting to utils.jl --- .../ext/DifferentiationInterfaceMooncakeExt/utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index 927a4cfbe..2184e9ecb 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -33,11 +33,11 @@ const mooncake_no_cache = maybe_getfield(Mooncake, :NoCache) function tangent_to_user_primal(tx, x) if !isnothing(mooncake_tangent_to_friendly) && - !isnothing(mooncake_friendly_tangent_cache) && - !isnothing(mooncake_as_primal) && - !isnothing(mooncake_no_cache) + !isnothing(mooncake_friendly_tangent_cache) && + !isnothing(mooncake_as_primal) && + !isnothing(mooncake_no_cache) dest = mooncake_friendly_tangent_cache{mooncake_as_primal}(_copy_output(x)) - cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any,Any}() + cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any, Any}() return mooncake_tangent_to_friendly(dest, x, tx, cache) else return tangent_to_primal!!(_copy_output(x), tx) From 106f50fc7b1b77e9f350a45ad4fc2b867864dc8c Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 10 Apr 2026 10:41:25 +0100 Subject: [PATCH 6/9] test: skip friendly_tangents static_scenarios on Julia 1.11 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mooncake returns raw Tangent objects instead of friendly arrays for StaticArrays on Julia 1.11. This is an upstream bug — skip the test until it is fixed. --- .../test/Back/Mooncake/test.jl | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index b481b759c..7e5a291a7 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -75,12 +75,15 @@ test_differentiation( @test grad.B == ps.A end -test_differentiation( - backends[3:4], - nomatrix(static_scenarios()); - logging = LOGGING, - excluded = SECOND_ORDER -) +# friendly_tangents + StaticArrays broken on Julia 1.11 (upstream Mooncake bug) +@static if !(VERSION ≥ v"1.11-" && VERSION < v"1.12-") + test_differentiation( + backends[3:4], + nomatrix(static_scenarios()); + logging = LOGGING, + excluded = SECOND_ORDER, + ) +end @testset "Friendly tangents structured matrices" begin backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) From 7043da2edaf2fb58870902246784eb175d66aa5b Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 10 Apr 2026 10:44:47 +0100 Subject: [PATCH 7/9] fix: convert raw Mooncake.Tangent in pullback/gradient results On Julia 1.11, Mooncake may return raw Tangent objects instead of friendly arrays for StaticArrays even with friendly_tangents=true. Add _maybe_to_primal dispatch as a safety net that converts leaked Tangent objects to primal-shaped values, no-op otherwise. --- .../DifferentiationInterfaceMooncakeExt/onearg.jl | 8 ++++---- .../DifferentiationInterfaceMooncakeExt/twoarg.jl | 4 ++-- .../DifferentiationInterfaceMooncakeExt/utils.jl | 5 +++++ .../test/Back/Mooncake/test.jl | 15 ++++++--------- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 2514cdc40..2bbf49f1c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -35,7 +35,7 @@ function DI.value_and_pullback( new_y, (_, new_dx) = value_and_pullback!!( prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - return new_y, (_copy_output(new_dx),) + return new_y, (_maybe_to_primal(new_dx, x),) end function DI.value_and_pullback( @@ -51,7 +51,7 @@ function DI.value_and_pullback( y, (_, new_dx) = value_and_pullback!!( prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - y, _copy_output(new_dx) + y, _maybe_to_primal(new_dx, x) end y = first(ys_and_tx[1]) tx = map(last, ys_and_tx) @@ -134,7 +134,7 @@ function DI.value_and_gradient( prep.cache, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - return y, _copy_output(new_grad) + return y, _maybe_to_primal(new_grad, x) end function DI.value_and_gradient!( @@ -150,7 +150,7 @@ function DI.value_and_gradient!( prep.cache, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - copyto!(grad, new_grad) + copyto!(grad, _maybe_to_primal(new_grad, x)) return y, grad end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 2b55131b9..ed7f4ca9c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -64,7 +64,7 @@ function DI.value_and_pullback( prep.args_to_zero ) copyto!(y, y_after) - return y, (_copy_output(dx),) + return y, (_maybe_to_primal(dx, x),) end function DI.value_and_pullback( @@ -90,7 +90,7 @@ function DI.value_and_pullback( prep.args_to_zero ) copyto!(y, y_after) - _copy_output(dx) + _maybe_to_primal(dx, x) end return y, tx end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index 2184e9ecb..5be9097b3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -21,6 +21,11 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake) end end +# Safety net: if Mooncake returns a raw Tangent (e.g. Julia 1.11 + StaticArrays), +# convert it to a primal-shaped value. No-op for already-converted results. +_maybe_to_primal(tx, x) = _copy_output(tx) +_maybe_to_primal(tx::Mooncake.Tangent, x) = tangent_to_user_primal(tx, x) + @inline maybe_getfield(mod, name::Symbol) = isdefined(mod, name) ? getfield(mod, name) : nothing diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 7e5a291a7..313c0d2bf 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -75,15 +75,12 @@ test_differentiation( @test grad.B == ps.A end -# friendly_tangents + StaticArrays broken on Julia 1.11 (upstream Mooncake bug) -@static if !(VERSION ≥ v"1.11-" && VERSION < v"1.12-") - test_differentiation( - backends[3:4], - nomatrix(static_scenarios()); - logging = LOGGING, - excluded = SECOND_ORDER, - ) -end +test_differentiation( + backends[3:4], + nomatrix(static_scenarios()); + logging = LOGGING, + excluded = SECOND_ORDER, +) @testset "Friendly tangents structured matrices" begin backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) From ce72bafde06c0427f022cdd93703a7d95c4ed6b7 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 10 Apr 2026 11:57:34 +0100 Subject: [PATCH 8/9] fix: handle MutableTangent and forward mode tangent leaks Also convert leaked Mooncake.MutableTangent (e.g. MVector tangents) and apply _maybe_to_primal in forward mode (pushforward) paths. --- .../ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl | 2 +- .../ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl | 4 ++-- .../ext/DifferentiationInterfaceMooncakeExt/utils.jl | 1 + 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl index c470b6473..cab7f84d6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -41,7 +41,7 @@ function DI.value_and_pushforward( map(first_unwrap, contexts, prep.context_tangents)..., ) y = first(y_and_dy) - dy = _copy_output(last(y_and_dy)) + dy = _maybe_to_primal(last(y_and_dy), y) return y, dy end y = _copy_output(first(ys_and_ty[1])) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index 3c75f530b..8ebb6ef99 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -55,7 +55,7 @@ function DI.value_and_pushforward( (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) - return _copy_output(new_dy) + return _maybe_to_primal(new_dy, y) end return y, ty end @@ -93,7 +93,7 @@ function DI.value_and_pushforward!( (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) - copyto!(dy, new_dy) + copyto!(dy, _maybe_to_primal(new_dy, y)) end return y, ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index 5be9097b3..beeb6f611 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -25,6 +25,7 @@ end # convert it to a primal-shaped value. No-op for already-converted results. _maybe_to_primal(tx, x) = _copy_output(tx) _maybe_to_primal(tx::Mooncake.Tangent, x) = tangent_to_user_primal(tx, x) +_maybe_to_primal(tx::Mooncake.MutableTangent, x) = tangent_to_user_primal(tx, x) @inline maybe_getfield(mod, name::Symbol) = isdefined(mod, name) ? getfield(mod, name) : nothing From 150f0a68018e62e126044ed9ae680e07c7aac0f6 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sun, 26 Apr 2026 18:12:56 +0100 Subject: [PATCH 9/9] refactor: bridge SArray/MArray friendly_tangents via triple-extension Replace the broad type-unstable _maybe_to_primal shim (with its IdDict fallback and defensive symbol lookups) with a narrow _to_friendly_value hook in the Mooncake extension and a new triple-extension activated by [Mooncake, StaticArrays]. Mooncake's friendly_tangents framework (chalk-lab/Mooncake.jl#1103) deliberately defaults to :as_raw for SArray/MArray primals; bridge that gap at the DI boundary where the primal type is statically known. Reconstruction via typeof(x)(Mooncake.val(t.fields.data)) is unambiguous because the data::NTuple field maps 1:1 to logical positions, with no aliasing. Restrict the dispatch to scalar float / complex-float eltypes so static arrays with non-float eltype keep using Mooncake's element-wise AbstractArray recursion. Route zero_tangent_or_primal through the same hook so the prep-time dy_backup buffer is primal-shaped (otherwise copyto!(::MutableTangent, ::MVector) fails at twoarg.jl:54). Restore nomatrix(static_scenarios()) on the friendly_tangents=true backends. --- DifferentiationInterface/Project.toml | 3 +- .../DifferentiationInterfaceMooncakeExt.jl | 2 +- .../forward_onearg.jl | 3 +- .../forward_twoarg.jl | 4 +- .../onearg.jl | 8 +- .../twoarg.jl | 4 +- .../utils.jl | 46 +++------- ...tiationInterfaceMooncakeStaticArraysExt.jl | 36 ++++++++ .../test/Back/Mooncake/test.jl | 88 +++++++++++++++++-- 9 files changed, 141 insertions(+), 53 deletions(-) create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceMooncakeStaticArraysExt/DifferentiationInterfaceMooncakeStaticArraysExt.jl diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index d4058e7ac..40d96972c 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -42,6 +42,7 @@ DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] DifferentiationInterfaceGPUArraysCoreExt = ["GPUArraysCore", "Adapt"] DifferentiationInterfaceGTPSAExt = "GTPSA" DifferentiationInterfaceMooncakeExt = "Mooncake" +DifferentiationInterfaceMooncakeStaticArraysExt = ["Mooncake", "StaticArrays"] DifferentiationInterfacePolyesterForwardDiffExt = [ "PolyesterForwardDiff", "ForwardDiff", @@ -71,7 +72,7 @@ ForwardDiff = "0.10.36,1" GPUArraysCore = "0.2" GTPSA = "1.4.0" LinearAlgebra = "1" -Mooncake = "0.5.1" +Mooncake = "0.5.27" PolyesterForwardDiff = "0.1.2" ReverseDiff = "1.15.1" SparseArrays = "1" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 3513d548c..8d1fc46f3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -30,7 +30,7 @@ using Mooncake: primal, _copy_output, _copy_to_output!!, - tangent_to_primal!! + tangent_to_friendly!! const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}} diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl index cab7f84d6..447f84b6b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -41,7 +41,8 @@ function DI.value_and_pushforward( map(first_unwrap, contexts, prep.context_tangents)..., ) y = first(y_and_dy) - dy = _maybe_to_primal(last(y_and_dy), y) + dy_raw = last(y_and_dy) + dy = @something(_to_friendly_value(dy_raw, y), _copy_output(dy_raw)) return y, dy end y = _copy_output(first(ys_and_ty[1])) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index 8ebb6ef99..a6354a0f7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -55,7 +55,7 @@ function DI.value_and_pushforward( (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) - return _maybe_to_primal(new_dy, y) + return @something(_to_friendly_value(new_dy, y), _copy_output(new_dy)) end return y, ty end @@ -93,7 +93,7 @@ function DI.value_and_pushforward!( (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) - copyto!(dy, _maybe_to_primal(new_dy, y)) + copyto!(dy, something(_to_friendly_value(new_dy, y), new_dy)) end return y, ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 2bbf49f1c..61f5eb837 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -35,7 +35,7 @@ function DI.value_and_pullback( new_y, (_, new_dx) = value_and_pullback!!( prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - return new_y, (_maybe_to_primal(new_dx, x),) + return new_y, (@something(_to_friendly_value(new_dx, x), _copy_output(new_dx)),) end function DI.value_and_pullback( @@ -51,7 +51,7 @@ function DI.value_and_pullback( y, (_, new_dx) = value_and_pullback!!( prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - y, _maybe_to_primal(new_dx, x) + y, @something(_to_friendly_value(new_dx, x), _copy_output(new_dx)) end y = first(ys_and_tx[1]) tx = map(last, ys_and_tx) @@ -134,7 +134,7 @@ function DI.value_and_gradient( prep.cache, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - return y, _maybe_to_primal(new_grad, x) + return y, @something(_to_friendly_value(new_grad, x), _copy_output(new_grad)) end function DI.value_and_gradient!( @@ -150,7 +150,7 @@ function DI.value_and_gradient!( prep.cache, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - copyto!(grad, _maybe_to_primal(new_grad, x)) + copyto!(grad, something(_to_friendly_value(new_grad, x), new_grad)) return y, grad end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index ed7f4ca9c..4dbc6d742 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -64,7 +64,7 @@ function DI.value_and_pullback( prep.args_to_zero ) copyto!(y, y_after) - return y, (_maybe_to_primal(dx, x),) + return y, (@something(_to_friendly_value(dx, x), _copy_output(dx)),) end function DI.value_and_pullback( @@ -90,7 +90,7 @@ function DI.value_and_pullback( prep.args_to_zero ) copyto!(y, y_after) - _maybe_to_primal(dx, x) + @something(_to_friendly_value(dx, x), _copy_output(dx)) end return y, tx end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index beeb6f611..11119face 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -9,43 +9,19 @@ function call_and_return(f!::F, y, x, contexts...) where {F} return y end +# Hook for bridging primal types whose `friendly_tangent_cache` falls through to +# `:as_raw` in Mooncake's framework, leaking a raw `Tangent` / `MutableTangent` +# instead of a primal-shaped value. The default returns `nothing`; specialised +# methods are loaded by triple-extensions when the relevant primal-type packages +# are available (see DifferentiationInterfaceMooncakeStaticArraysExt for the +# `SArray` / `MArray` case). +_to_friendly_value(t, x) = nothing + function zero_tangent_or_primal(x, backend::AnyAutoMooncake) + zt = zero_tangent(x) if get_config(backend).friendly_tangents - # Mooncake 0.5.25+ replaced `tangent_to_primal!!` with the - # `tangent_to_friendly!!` framework. For this internal backup we still - # need a primal-shaped value, so use the `AsPrimal` path when - # available and fall back for older Mooncake releases. - return tangent_to_user_primal(zero_tangent(x), x) - else - return zero_tangent(x) - end -end - -# Safety net: if Mooncake returns a raw Tangent (e.g. Julia 1.11 + StaticArrays), -# convert it to a primal-shaped value. No-op for already-converted results. -_maybe_to_primal(tx, x) = _copy_output(tx) -_maybe_to_primal(tx::Mooncake.Tangent, x) = tangent_to_user_primal(tx, x) -_maybe_to_primal(tx::Mooncake.MutableTangent, x) = tangent_to_user_primal(tx, x) - -@inline maybe_getfield(mod, name::Symbol) = - isdefined(mod, name) ? getfield(mod, name) : nothing - -const mooncake_tangent_to_friendly = maybe_getfield( - Mooncake, Symbol("tangent_to_friendly!!") -) -const mooncake_friendly_tangent_cache = maybe_getfield(Mooncake, :FriendlyTangentCache) -const mooncake_as_primal = maybe_getfield(Mooncake, :AsPrimal) -const mooncake_no_cache = maybe_getfield(Mooncake, :NoCache) - -function tangent_to_user_primal(tx, x) - if !isnothing(mooncake_tangent_to_friendly) && - !isnothing(mooncake_friendly_tangent_cache) && - !isnothing(mooncake_as_primal) && - !isnothing(mooncake_no_cache) - dest = mooncake_friendly_tangent_cache{mooncake_as_primal}(_copy_output(x)) - cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any, Any}() - return mooncake_tangent_to_friendly(dest, x, tx, cache) + return @something(_to_friendly_value(zt, x), tangent_to_friendly!!(x, zt)) else - return tangent_to_primal!!(_copy_output(x), tx) + return zt end end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeStaticArraysExt/DifferentiationInterfaceMooncakeStaticArraysExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeStaticArraysExt/DifferentiationInterfaceMooncakeStaticArraysExt.jl new file mode 100644 index 000000000..797445b6f --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeStaticArraysExt/DifferentiationInterfaceMooncakeStaticArraysExt.jl @@ -0,0 +1,36 @@ +module DifferentiationInterfaceMooncakeStaticArraysExt + +using Base: IEEEFloat +using DifferentiationInterface: DifferentiationInterface +using Mooncake: Mooncake +using StaticArrays: MArray, SArray + +# Reach into the binary MooncakeExt to extend its `_to_friendly_value` hook. +# Both Mooncake and StaticArrays are loaded whenever this extension is active, +# so the binary extension is guaranteed to be loaded as well. +const _MooncakeExt = Base.get_extension( + DifferentiationInterface, :DifferentiationInterfaceMooncakeExt, +) + +# Restrict to scalar float / complex-float eltypes: those are the layouts where +# Mooncake's framework sends `SArray` / `MArray` through `AsRaw` and the single +# `data::NTuple` element-to-position mapping is 1:1 (no aliasing). Non-float +# eltypes hit Mooncake's element-wise `AbstractArray` recursion at +# `Mooncake/src/tangents/tangents.jl:1453`; let that path run unimpeded. +const _StaticEltype = Union{IEEEFloat, Complex{<:IEEEFloat}} + +# Mooncake's `friendly_tangent_cache` framework defaults to `:as_raw` for +# `SArray` / `MArray` primals with float eltype, leaking a raw `Tangent` / +# `MutableTangent` instead of a primal-shaped value. Bridge that gap here. The +# reconstruction is unambiguous because the `data::NTuple` field maps each +# element to one logical position (unlike `Symmetric` / `Hermitian`, where a +# single stored entry can represent two positions). +@inline _MooncakeExt._to_friendly_value( + t::Mooncake.Tangent, x::SArray{S, T} +) where {S, T <: _StaticEltype} = typeof(x)(Mooncake.val(t.fields.data)) + +@inline _MooncakeExt._to_friendly_value( + t::Mooncake.MutableTangent, x::MArray{S, T} +) where {S, T <: _StaticEltype} = typeof(x)(Mooncake.val(t.fields.data)) + +end # module diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 313c0d2bf..7555752af 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -83,15 +83,88 @@ test_differentiation( ) @testset "Friendly tangents structured matrices" begin + # Mooncake 0.5.25+ returns a plain `Matrix` for structured inputs under + # `friendly_tangents=true` (chalk-lab/Mooncake.jl#1103); the complex case + # follows the standard reverse-mode convention (chalk-lab/Mooncake.jl#773). + # + # Per-wrapper test functions are chosen for their non-triviality given + # that matmul on small matrices hits a `utf8proc_isupper` ccall that + # Mooncake cannot differentiate (LinearAlgebra._matmul2x2_elements → + # WrapperChar). For the real wrappers we use a manual triple-loop tr(X^3) + # whose unrestricted gradient is 3·X²; for Hermitian we use the simpler + # abs2-sum because the complex Wirtinger ground truth via tr(X^3) is + # convention-heavy. The expected friendly gradient is then computed by + # aggregating the unrestricted per-element gradient into the wrapper's + # canonical storage cells, derived independently of Mooncake. + + # tr(X^3) without matmul. Indices i,j,k each range over axes(X,1). + function tr_x3(X) + s = zero(eltype(X)) + n = size(X, 1) + @inbounds for i in 1:n, j in 1:n, k in 1:n + s += X[i, j] * X[j, k] * X[k, i] + end + return real(s) + end + + # Symmetric storage: upper triangle holds the sum of (i,j) and (j,i) + # per-element contributions; strict lower triangle is zero. + function aggregate_symmetric(G) + n = size(G, 1) + H = zero(G) + @inbounds for i in 1:n + H[i, i] = G[i, i] + for j in (i + 1):n + H[i, j] = G[i, j] + G[j, i] + end + end + return H + end + + # SymTridiagonal storage: diagonal + symmetric off-diagonals (both + # `(i,i+1)` and `(i+1,i)` slots hold the doubled contribution); entries + # outside that band are structurally zero in the wrapper. + function aggregate_symtridiagonal(G) + n = size(G, 1) + H = zero(G) + @inbounds for i in 1:n + H[i, i] = G[i, i] + if i < n + aggregated = G[i, i + 1] + G[i + 1, i] + H[i, i + 1] = aggregated + H[i + 1, i] = aggregated + end + end + return H + end + backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) - inputs = ( - Symmetric([2.0 1.0; 1.0 3.0]), - Hermitian(ComplexF64[2 1 + im; 1 - im 3]), - SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0]), + abs2_sum(x) = real(sum(abs2, x)) + cases = ( + ( + x = Symmetric([2.0 1.0; 1.0 3.0]), + f = tr_x3, + expected_grad = let M = Matrix(Symmetric([2.0 1.0; 1.0 3.0])) + aggregate_symmetric(3 * M^2) + end, + ), + ( + x = Hermitian(ComplexF64[2 1 + im; 1 - im 3]), + f = abs2_sum, + expected_grad = ComplexF64[4 4 + 4im; 0 6], + ), + ( + x = SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0]), + f = tr_x3, + expected_grad = let M = Matrix(SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0])) + aggregate_symtridiagonal(3 * M^2) + end, + ), ) - f(x) = real(sum(abs2, x)) - @testset "$(typeof(x))" for x in inputs + @testset "$(typeof(case.x))" for case in cases + x = case.x + f = case.f grad = gradient(f, backend, x) y, grad2 = value_and_gradient(f, backend, x) pb = only(pullback(identity, backend, x, (x,))) @@ -100,7 +173,8 @@ test_differentiation( @test grad2 isa Matrix @test pb isa Matrix @test grad == grad2 - @test y == f(x) + @test grad ≈ case.expected_grad + @test y ≈ f(x) @test pb == Matrix(x) grad_dense = zero(Matrix(x))