Codecov Report
❌ Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main     #988      +/-   ##
==========================================
- Coverage   98.21%   97.28%   -0.94%
==========================================
  Files         135      131       -4
  Lines        8000     7984      -16
==========================================
- Hits         7857     7767      -90
- Misses        143      217      +74

Flags with carried forward coverage won't be shown.
|
|
I think I understand the CI error: there is something I need to patch on the Mooncake side. I will come back to this.
|
chalk-lab/Mooncake.jl#1129 should unblock this PR, we'll release it as soon as it's available
|
Thank you for taking a crack at this! I'll wait until the tests pass before reviewing if that's okay
|
totally fine!
Mooncake returns raw Tangent objects instead of friendly arrays for StaticArrays on Julia 1.11. This is an upstream bug — skip the test until it is fixed.
On Julia 1.11, Mooncake may return raw Tangent objects instead of friendly arrays for StaticArrays even with friendly_tangents=true. Add _maybe_to_primal dispatch as a safety net that converts leaked Tangent objects to primal-shaped values, no-op otherwise.
Also convert leaked Mooncake.MutableTangent (e.g. MVector tangents) and apply _maybe_to_primal in forward mode (pushforward) paths.
|
@gdalle the Mooncake CIs are passing (it probably requires Mooncake v0.5.26). Could you take over? I also won't be offended if you want to start a new PR.
|
I'll take a look when I can, thanks a bunch! DI's tests are failing on main too because of Mooncake's breaking release so this is a priority for me. Do you know why coverage is not complete?
|
Thanks a lot.
I am not certain. My rough guess is that some of the code changes are more defensive than necessary, so those branches are never exercised. Sorry!
gdalle left a comment
Thank you for trying to fix what others broke! I added a few remarks to understand the task a bit better. I'll wait for your answers.
@inline maybe_getfield(mod, name::Symbol) =
    isdefined(mod, name) ? getfield(mod, name) : nothing

const mooncake_tangent_to_friendly = maybe_getfield(
    Mooncake, Symbol("tangent_to_friendly!!")
)
const mooncake_friendly_tangent_cache = maybe_getfield(Mooncake, :FriendlyTangentCache)
const mooncake_as_primal = maybe_getfield(Mooncake, :AsPrimal)
const mooncake_no_cache = maybe_getfield(Mooncake, :NoCache)
This doesn't seem to be very robust? I'd rather impose a lower bound for Mooncake at v0.5.25 in Project.toml (that way we're sure we can use all of these symbols)
)
y = first(y_and_dy)
dy = _copy_output(last(y_and_dy))
dy = _maybe_to_primal(last(y_and_dy), y)
Why do we need to ensure that primal conversion happens here? If friendly_tangents is set to true, won't Mooncake's pushforward and pullback already return a primal-like object?
backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
inputs = (
    Symmetric([2.0 1.0; 1.0 3.0]),
    Hermitian(ComplexF64[2 1 + im; 1 - im 3]),
Do you know which convention Mooncake uses for gradients of functions with complex inputs and real outputs? There are two possible choices, see e.g. https://arxiv.org/abs/2409.06752
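For reference, a sketch of the two conventions in question, written for a scalar input $z = x + iy$ and a real-valued $f$. The labels $g_{\mathrm{A}}$ and $g_{\mathrm{B}}$ are notation assumed here for illustration; they are not taken from Mooncake's documentation or from the linked paper, and nothing below says which one Mooncake follows.

```latex
\begin{align*}
g_{\mathrm{A}} &= \frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y}
               = 2\,\frac{\partial f}{\partial \bar z}, \\
g_{\mathrm{B}} &= \frac{\partial f}{\partial x} - i\,\frac{\partial f}{\partial y}
               = 2\,\frac{\partial f}{\partial z}
               = \overline{g_{\mathrm{A}}}.
\end{align*}
```

For $f(z) = |z|^2 = x^2 + y^2$ this gives $g_{\mathrm{A}} = 2z$ and $g_{\mathrm{B}} = 2\bar z$, so a test on a `Hermitian` input with nonzero imaginary parts would distinguish the two.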
@testset "$(typeof(x))" for x in inputs
    grad = gradient(f, backend, x)
    y, grad2 = value_and_gradient(f, backend, x)
    pb = only(pullback(identity, backend, x, (x,)))
This is not a strong enough test, the function is too simple
@test grad isa Matrix
@test grad2 isa Matrix
@test pb isa Matrix
@test grad == grad2
grad and grad2 are never compared against the ground truth
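One way to get a ground truth, sketched with plain Julia only: pick a function with a known closed-form gradient and cross-check it with central finite differences. The names `f_ref` and `fd_gradient` are hypothetical helpers, not part of DI or Mooncake; in the testset, the backend's `grad` would be compared against `analytic` in the same way.

```julia
# Hypothetical test function with a known gradient (not taken from the PR):
# d/dx of sum(abs2, x) is 2x, and the product term couples the first and last entries.
f_ref(x) = sum(abs2, x) + x[1] * x[end]

# Central finite differences over a dense real array.
function fd_gradient(f, x::AbstractArray{<:Real}; h = 1e-6)
    g = zeros(float(eltype(x)), size(x))
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = [2.0 1.0; 1.0 3.0]
analytic = [7.0 2.0; 2.0 8.0]   # 2x, plus x[end] at index 1 and x[1] at the last index
fd = fd_gradient(f_ref, x)
```

For `Symmetric` or `Hermitian` inputs the expected value also depends on how the wrapper's structure is treated, which this sketch does not settle.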
!isnothing(mooncake_as_primal) &&
!isnothing(mooncake_no_cache)
dest = mooncake_friendly_tangent_cache{mooncake_as_primal}(_copy_output(x))
cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any, Any}()
Why is a type-unstable dictionary needed here?
Does this make every tangent-to-primal conversion outside of bitstypes slow?
Regardless of why, we may want to allocate this dictionary in the preparation phase
yup this will regress performance on hot paths. This should be allocated once during the prepare phase and stored in the extras cache.
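A minimal sketch of that idea, with hypothetical names throughout (`ConversionPrep`, `prepare_conversion`, and `convert_with!` are illustrative and not DI or Mooncake API): the dictionary is constructed once during preparation and emptied before each reuse, so the hot path does not build a new `IdDict` per call.

```julia
# Hypothetical prep object holding a reusable scratch cache.
struct ConversionPrep
    cache::IdDict{Any, Any}
end

# Preparation phase: construct the dictionary once.
prepare_conversion() = ConversionPrep(IdDict{Any, Any}())

# Hot path: reuse the stored cache. `convert!` stands in for the real conversion routine.
function convert_with!(convert!, prep::ConversionPrep, x)
    empty!(prep.cache)   # drop entries left over from the previous call
    return convert!(prep.cache, x)
end

# Toy conversion that records the object it saw and returns a copy.
toy!(cache, x) = (cache[x] = true; copy(x))

prep = prepare_conversion()
a = convert_with!(toy!, prep, [1.0, 2.0])
b = convert_with!(toy!, prep, [3.0])
```

This only addresses the per-call construction; the `Any`-typed dictionary itself remains type-unstable.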
|
@AstitvaAggarwal @Technici4n could you maybe take a look too?
AstitvaAggarwal left a comment
Also, we might want to keep track of possible future tangent types: the fallback _maybe_to_primal(x, _) = x will silently pass through any tangent type not yet accounted for (e.g. a future Mooncake.SparseTangent), making failures invisible.
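One possible shape for a stricter fallback, sketched with stand-in types because the real hierarchy lives in Mooncake. `AbstractFakeTangent`, `FakeTangent`, `FakeSparseTangent`, and `to_primal_strict` are all hypothetical, and whether Mooncake exposes a common supertype to dispatch on is an assumption that would need checking.

```julia
# Stand-ins for a tangent type hierarchy (hypothetical, for illustration only).
abstract type AbstractFakeTangent end

struct FakeTangent <: AbstractFakeTangent
    data::Vector{Float64}
end

# A "future" tangent type for which no conversion has been written yet.
struct FakeSparseTangent <: AbstractFakeTangent end

# Plain values pass through unchanged.
to_primal_strict(x, _) = x

# Known tangent type: convert to a primal-shaped value.
to_primal_strict(t::FakeTangent, primal::AbstractVector) = copy(t.data)

# Any other tangent type fails loudly instead of leaking through.
to_primal_strict(t::AbstractFakeTangent, _) =
    throw(ArgumentError("no primal conversion defined for $(typeof(t))"))
```

With this layout, adding a new tangent type without a matching method surfaces as an `ArgumentError` at the call site.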
Replace the broad type-unstable _maybe_to_primal shim (with its IdDict fallback and defensive symbol lookups) with a narrow _to_friendly_value hook in the Mooncake extension and a new triple-extension activated by [Mooncake, StaticArrays].

Mooncake's friendly_tangents framework (chalk-lab/Mooncake.jl#1103) deliberately defaults to :as_raw for SArray/MArray primals; bridge that gap at the DI boundary where the primal type is statically known. Reconstruction via typeof(x)(Mooncake.val(t.fields.data)) is unambiguous because the data::NTuple field maps 1:1 to logical positions, with no aliasing. Restrict the dispatch to scalar float / complex-float eltypes so static arrays with non-float eltype keep using Mooncake's element-wise AbstractArray recursion.

Route zero_tangent_or_primal through the same hook so the prep-time dy_backup buffer is primal-shaped (otherwise copyto!(::MutableTangent, ::MVector) fails at twoarg.jl:54). Restore nomatrix(static_scenarios()) on the friendly_tangents=true backends.
|
@sunxd3 can you explain why type piracy is needed here? Sounds like there should be a StaticArraysExt in Mooncake (but the maintainer doesn't like it) or the other way around. I will not introduce piracy in a package depended upon by the entire ecosystem.
|
It would be weird, but one could add a thin DI wrapper, DISArray, around an SArray, wrapping the SArray every time it needs to go into a Mooncake function. LLVM should understand that the wrapper is unneeded and remove it at compile time (since it's all static). But then there is the question of whether we need any kind of interface on those wrappers; I hope not, otherwise it's just too big of a job.
An attempt at addressing #986.
Feel free to make any edits or take over!