-
Notifications
You must be signed in to change notification settings - Fork 32
Mooncake 0.5.25 compat #988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
sunxd3
wants to merge
9
commits into
JuliaDiff:main
Choose a base branch
from
sunxd3:sunxd/fix-mooncake-friendly-tangents
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
75d3c05
fix Mooncake friendly_tangents compatibility
sunxd3 f420234
format Mooncake fix
sunxd3 8235e6c
remove Mooncake test formatting churn
sunxd3 a478578
ci: retrigger after Mooncake v0.5.26 release
sunxd3 65997c4
style: apply Runic formatting to utils.jl
sunxd3 106f50f
test: skip friendly_tangents static_scenarios on Julia 1.11
sunxd3 7043da2
fix: convert raw Mooncake.Tangent in pullback/gradient results
sunxd3 ce72baf
fix: handle MutableTangent and forward mode tangent leaks
sunxd3 150f0a6
refactor: bridge SArray/MArray friendly_tangents via triple-extension
sunxd3 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 36 additions & 0 deletions
36
...iationInterfaceMooncakeStaticArraysExt/DifferentiationInterfaceMooncakeStaticArraysExt.jl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| module DifferentiationInterfaceMooncakeStaticArraysExt | ||
|
|
||
| using Base: IEEEFloat | ||
| using DifferentiationInterface: DifferentiationInterface | ||
| using Mooncake: Mooncake | ||
| using StaticArrays: MArray, SArray | ||
|
|
||
| # Reach into the binary MooncakeExt to extend its `_to_friendly_value` hook. | ||
| # Both Mooncake and StaticArrays are loaded whenever this extension is active, | ||
| # so the binary extension is guaranteed to be loaded as well. | ||
| const _MooncakeExt = Base.get_extension( | ||
| DifferentiationInterface, :DifferentiationInterfaceMooncakeExt, | ||
| ) | ||
|
|
||
| # Restrict to scalar float / complex-float eltypes: those are the layouts where | ||
| # Mooncake's framework sends `SArray` / `MArray` through `AsRaw` and the single | ||
| # `data::NTuple` element-to-position mapping is 1:1 (no aliasing). Non-float | ||
| # eltypes hit Mooncake's element-wise `AbstractArray` recursion at | ||
| # `Mooncake/src/tangents/tangents.jl:1453`; let that path run unimpeded. | ||
| const _StaticEltype = Union{IEEEFloat, Complex{<:IEEEFloat}} | ||
|
|
||
| # Mooncake's `friendly_tangent_cache` framework defaults to `:as_raw` for | ||
| # `SArray` / `MArray` primals with float eltype, leaking a raw `Tangent` / | ||
| # `MutableTangent` instead of a primal-shaped value. Bridge that gap here. The | ||
| # reconstruction is unambiguous because the `data::NTuple` field maps each | ||
| # element to one logical position (unlike `Symmetric` / `Hermitian`, where a | ||
| # single stored entry can represent two positions). | ||
| @inline _MooncakeExt._to_friendly_value( | ||
| t::Mooncake.Tangent, x::SArray{S, T} | ||
| ) where {S, T <: _StaticEltype} = typeof(x)(Mooncake.val(t.fields.data)) | ||
|
|
||
| @inline _MooncakeExt._to_friendly_value( | ||
| t::Mooncake.MutableTangent, x::MArray{S, T} | ||
| ) where {S, T <: _StaticEltype} = typeof(x)(Mooncake.val(t.fields.data)) | ||
|
|
||
| end # module |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| include("../../testutils.jl") | ||
|
|
||
| using DifferentiationInterface, DifferentiationInterfaceTest | ||
| using LinearAlgebra: Hermitian, SymTridiagonal, Symmetric | ||
| using Mooncake: Mooncake | ||
| using Test | ||
|
|
||
|
|
@@ -78,5 +79,110 @@ test_differentiation( | |
| backends[3:4], | ||
| nomatrix(static_scenarios()); | ||
| logging = LOGGING, | ||
| excluded = SECOND_ORDER | ||
| excluded = SECOND_ORDER, | ||
| ) | ||
|
|
||
| @testset "Friendly tangents structured matrices" begin | ||
| # Mooncake 0.5.25+ returns a plain `Matrix` for structured inputs under | ||
| # `friendly_tangents=true` (chalk-lab/Mooncake.jl#1103); the complex case | ||
| # follows the standard reverse-mode convention (chalk-lab/Mooncake.jl#773). | ||
| # | ||
| # Per-wrapper test functions are chosen for their non-triviality given | ||
| # that matmul on small matrices hits a `utf8proc_isupper` ccall that | ||
| # Mooncake cannot differentiate (LinearAlgebra._matmul2x2_elements → | ||
| # WrapperChar). For the real wrappers we use a manual triple-loop tr(X^3) | ||
| # whose unrestricted gradient is 3·X²; for Hermitian we use the simpler | ||
| # abs2-sum because the complex Wirtinger ground truth via tr(X^3) is | ||
| # convention-heavy. The expected friendly gradient is then computed by | ||
| # aggregating the unrestricted per-element gradient into the wrapper's | ||
| # canonical storage cells, derived independently of Mooncake. | ||
|
|
||
| # tr(X^3) without matmul. Indices i,j,k each range over axes(X,1). | ||
| function tr_x3(X) | ||
| s = zero(eltype(X)) | ||
| n = size(X, 1) | ||
| @inbounds for i in 1:n, j in 1:n, k in 1:n | ||
| s += X[i, j] * X[j, k] * X[k, i] | ||
| end | ||
| return real(s) | ||
| end | ||
|
|
||
| # Symmetric storage: upper triangle holds the sum of (i,j) and (j,i) | ||
| # per-element contributions; strict lower triangle is zero. | ||
| function aggregate_symmetric(G) | ||
| n = size(G, 1) | ||
| H = zero(G) | ||
| @inbounds for i in 1:n | ||
| H[i, i] = G[i, i] | ||
| for j in (i + 1):n | ||
| H[i, j] = G[i, j] + G[j, i] | ||
| end | ||
| end | ||
| return H | ||
| end | ||
|
|
||
| # SymTridiagonal storage: diagonal + symmetric off-diagonals (both | ||
| # `(i,i+1)` and `(i+1,i)` slots hold the doubled contribution); entries | ||
| # outside that band are structurally zero in the wrapper. | ||
| function aggregate_symtridiagonal(G) | ||
| n = size(G, 1) | ||
| H = zero(G) | ||
| @inbounds for i in 1:n | ||
| H[i, i] = G[i, i] | ||
| if i < n | ||
| aggregated = G[i, i + 1] + G[i + 1, i] | ||
| H[i, i + 1] = aggregated | ||
| H[i + 1, i] = aggregated | ||
| end | ||
| end | ||
| return H | ||
| end | ||
|
|
||
| backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) | ||
| abs2_sum(x) = real(sum(abs2, x)) | ||
| cases = ( | ||
| ( | ||
| x = Symmetric([2.0 1.0; 1.0 3.0]), | ||
| f = tr_x3, | ||
| expected_grad = let M = Matrix(Symmetric([2.0 1.0; 1.0 3.0])) | ||
| aggregate_symmetric(3 * M^2) | ||
| end, | ||
| ), | ||
| ( | ||
| x = Hermitian(ComplexF64[2 1 + im; 1 - im 3]), | ||
| f = abs2_sum, | ||
| expected_grad = ComplexF64[4 4 + 4im; 0 6], | ||
| ), | ||
| ( | ||
| x = SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0]), | ||
| f = tr_x3, | ||
| expected_grad = let M = Matrix(SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0])) | ||
| aggregate_symtridiagonal(3 * M^2) | ||
| end, | ||
| ), | ||
| ) | ||
|
|
||
| @testset "$(typeof(case.x))" for case in cases | ||
| x = case.x | ||
| f = case.f | ||
| grad = gradient(f, backend, x) | ||
| y, grad2 = value_and_gradient(f, backend, x) | ||
| pb = only(pullback(identity, backend, x, (x,))) | ||
|
|
||
| @test grad isa Matrix | ||
| @test grad2 isa Matrix | ||
| @test pb isa Matrix | ||
| @test grad == grad2 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| @test grad ≈ case.expected_grad | ||
| @test y ≈ f(x) | ||
| @test pb == Matrix(x) | ||
|
|
||
| grad_dense = zero(Matrix(x)) | ||
| @test gradient!(f, grad_dense, backend, x) === grad_dense | ||
| @test grad_dense == grad | ||
|
|
||
| tx_dense = (zero(Matrix(x)),) | ||
| @test only(pullback!(identity, tx_dense, backend, x, (x,))) === tx_dense[1] | ||
| @test tx_dense[1] == pb | ||
| end | ||
| end | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not a strong enough test, the function is too simple