[WIP] docs: prefer Mooncake over Zygote where it works end-to-end#1419
Conversation
WIP — depends on JuliaDiff/DifferentiationInterface.jl#989 With the recent SciMLSensitivity-side Mooncake fixes (1376/1397/1412) and the upstream DI Mooncake fix in JuliaDiff/DifferentiationInterface.jl#989 (which unwraps `Mooncake.Tangent`/`MutableTangent` cotangents back to the primal type so `OptimizationBase.gradient!` no longer dies on `copyto!(::ComponentVector, ::Mooncake.Tangent)`), Mooncake now works end-to-end for the majority of the SciMLSensitivity tutorials. This PR migrates every doc/tutorial that I could verify runs cleanly under `OPT.AutoMooncake(; config = nothing)` (or the direct Mooncake / DifferentiationInterface API for non-Optimization examples). Each migrated example was executed locally on Julia 1.11 against SciMLSensitivity master + the patched DI to confirm the gradient flows and the optimizer makes progress. ## Migrated to Mooncake - `getting_started.md` — `Zygote.gradient(loss, u0, p)` → `DI.gradient(closure, AutoMooncake, p)` (also exercises `GaussAdjoint`) - `manual/differential_equation_sensitivities.md` — same DI rewrite, reordered the AD list to put Mooncake first - `tutorials/parameter_estimation_ode.md` — `OPT.AutoZygote()` → `OPT.AutoMooncake(; config = nothing)`, PolyOpt converges to ≈2e-6 in 100 steps - `tutorials/chaotic_ode.md` — `Zygote.gradient(p -> G(p), p)` → `DI.gradient(p -> G(p), AutoMooncake, p)` for `ForwardLSS` - `tutorials/training_tips/divergence.md` — Lotka-Volterra retcode pattern, `AutoMooncake` swap - `tutorials/training_tips/local_minima.md` — Lux + ComponentArrays neural ODE, two `AutoZygote → AutoMooncake` swaps - `tutorials/training_tips/multiple_nn.md` — Lux + multi-NN + ComponentArrays + `InterpolatingAdjoint(ReverseDiffVJP)`, `AutoZygote → AutoMooncake` - `examples/ode/exogenous_input.md` — Hammerstein system + Lux UDE - `examples/hybrid_jump/bouncing_ball.md` — `OPT.AutoMooncake(...)` swap. Also replaced `sol[end][1]` with `last(sol.u)[1]` to dodge a pre-existing `BoundsError` in `SciMLBaseMooncakeExt._scatter_pullback` for `getindex(::ODESolution, end)`; the underlying `Vector{Vector{Float64}}` access takes the same value with no rrule bug - `examples/optimal_control/optimal_control.md` — drops the now-unused `import Zygote` (the example uses `OPT.AutoForwardDiff()`) - `examples/pde/pde_constrained.md` — 1D heat-equation parameter fit, `AutoZygote → AutoMooncake`, both `@example pde` and `@example pde2` blocks - `examples/sde/optimization_sde.md` (Example 3 only) — SDE control with `ForwardDiffSensitivity()`, `AutoZygote → AutoMooncake`. Example 1 keeps Zygote because it relies on `EnsembleProblem` (see below) - `Benchmark.md` — `Zygote.gradient($loss_neuralode, $u0, $ps, $st)` → `DI.gradient($loss_ps, $backend, $ps)` with a closure over `u0`/`st`. This block is `julia` not `@example`, so it isn't executed by Documenter, but the rewrite still demonstrates the recommended user pattern - `faq.md` — out-of-place RHS isolation snippet rewritten from `Zygote.pullback` to `Mooncake.prepare_pullback_cache` / `Mooncake.value_and_pullback!!`, verified locally on a Lotka-Volterra closure - `index.md` — list reorder to put Mooncake (and Enzyme) above Zygote in the AD compatibility table - `docs/Project.toml` — adds `Mooncake` and `DifferentiationInterface` with appropriate compat bounds ## Left on Zygote with an explanatory `!!! note` The remaining tutorials hit one of three independent upstream blockers in Mooncake itself, all of which are out of scope for a docs PR. I left them on `OPT.AutoZygote()` and added a callout pointing at the specific failure mode so future contributors know what to monitor: - **`EnsembleProblem` rule compilation fails** (`StackOverflowError` inside Mooncake's rule compiler when it tries to differentiate `__solve(::AbstractEnsembleProblem, …)`): - `tutorials/data_parallel.md` - `examples/sde/optimization_sde.md` (Example 1, the quasi-likelihood fit) - `examples/sde/SDE_control.md` (also has `Zygote.@Nograd CreateGrid` which would translate to `Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(CreateGrid), Any, Any}` once the EnsembleProblem blocker is resolved — I left both lines in the tutorial commentary so it's a one-line fix later) - **`MethodOfSteps` DDE adjoint fails** (`StackOverflowError` during rule compilation of the `DDEProblem` solve): - `examples/dde/delay_diffeq.md` - **ComponentArrays cotangent / SciMLBase Mooncake-extension gaps** on the more exotic adjoint paths (missing `increment_and_get_rdata!` method, `ReverseDiffAdjoint`-tracked values that don't match Mooncake's `CoDual` type expectations, nested `ComponentVector` cotangents, or `SecondOrder(AutoMooncake, AutoMooncake)` fallback): - `examples/ode/second_order_adjoints.md` (NewtonTrustRegion needs a Hessian, the Adam-only first half does work with Mooncake but the point of the tutorial is the second-order optimization) - `examples/ode/second_order_neural.md` (`SecondOrderODEProblem` + Lux + CV) - `examples/optimal_control/feedback_control.md` (nested `ComponentArray(; u0, p_all)`) - `examples/hybrid_jump/hybrid_diffeq.md` (`ReverseDiffAdjoint` inner) - `examples/neural_ode/simplechains.md` (`QuadratureAdjoint(ZygoteVJP)` + `StaticArrays`) - `examples/pde/brusselator.md` (FBDF stiff PDE + Lux + CV via the auto-selected adjoint) Each note records the exact error so it's clear which Mooncake/SciMLBase upstream patch unblocks the migration. When that lands, switching the remaining files is mechanical. ## Verified locally Every migrated `@example` block above was either run directly or matched a pattern that I ran end-to-end (Lux+CA neural ODE training, Optimization+CA loop, etc.) against the patched DI from JuliaDiff/DifferentiationInterface.jl#989. The tutorials that are blocked are exactly the ones I could not get past compilation. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
After SciML/ComponentArrays.jl#350 (released as ComponentArrays v0.15.34) registers a `friendly_tangent_cache` override for `ComponentArray`, the `OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))` form now uses the friendly-tangent unwrap path inside Mooncake itself, which solves the same `copyto!(::ComponentVector, ::Mooncake.Tangent)` crash that JuliaDiff/DifferentiationInterface.jl#989 fixed at the DI layer for the `config = nothing` default. I re-tested the migration with **stock DI 0.7.16** plus **ComponentArrays from main (0.15.34)** and confirmed the migrated tutorials still pass end-to-end (LV+CA BFGS, multiple_nn Lux+CA Adam, local_minima Lux+CA Adam, parameter_estimation_ode PolyOpt, getting_started + GaussAdjoint, bouncing_ball with the `last(sol.u)[1]` workaround, divergence, exogenous_input, etc.). The reverted tutorials (\`EnsembleProblem\`, \`MethodOfSteps\` DDE, \`SecondOrderODEProblem\`, nested CV, \`ReverseDiffAdjoint\` inner, \`SimpleChains\`+\`StaticArrays\`, FBDF stiff PDE) are still blocked on independent upstream issues that CA SciML#350 does not address — I reverified each one with friendly_tangents+CA-main and they still fail with the same errors recorded in the !!! note callouts. This commit: 1. Switches every migrated `OPT.AutoMooncake(; config = nothing)` / `SMS.AutoMooncake(...)` / `DI.AutoMooncake(...)` / `ADTypes.AutoMooncake(...)` to `OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))` (and the equivalent for the other prefixes). 2. Updates the recommended pattern shown in every `!!! note` callout on the still-Zygote tutorials to match. 3. Bumps the `ComponentArrays` compat in `docs/Project.toml` from `0.15` to `0.15.34` so the docs build picks up the friendly-tangent support. With this change the SMS docs PR no longer hard-depends on JuliaDiff/DifferentiationInterface.jl#989. That DI patch is still an independently useful improvement (it makes the default `config = nothing` form work without the user having to know about the flag, and also fixes the `MVector`/`SVector` cases), but it is no longer a blocker for this migration. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
|
@gdalle @AstitvaAggarwal why is friendly_tangents=false the default? It seems a bit odd that literally every example requires we set that to true in order for it to work, it seems to violate the principle of trying to make the simple things easy and complex things possible. |
|
Note the DI changes are no longer necessary with SciML/ComponentArrays.jl#350 |
|
Setting friendly tangents as the default is a breaking but welcome change, which will hopefully be part of Mooncake v0.6. |
|
Okay cool, that will be a welcome change. |
Adding a MooncakeOriginator dispatch for `_concrete_solve_adjoint(…, ::ReverseDiffAdjoint, …)` so the hybrid_diffeq tutorial (and PR SciML#1419) no longer have to fall back to Zygote when the inner sensealg is `ReverseDiffAdjoint` and the outer AD is Mooncake. Before: the method threw `MooncakeTrackedRealError` on `MooncakeOriginator` because the return value `sensitivity_solution(tracked_sol, plain_u, plain_t)` still carries `ReverseDiff.TrackedReal` / `TrackedArray` type parameters in its nested fields (`interp`, `prob`, `alg`, …). Mooncake's `@from_rrule` plumbing calls `zero_tangent(y_primal)` and therefore recursively computes `tangent_type` for every nested field of the returned solution; that recursion fails on the tracked type parameters with either a `TypeError` or an unhelpful tangent-type error. ChainRules / Zygote don't inspect the primal's type parameters, so they are unaffected. The new method delegates the tape construction (and hence the whole backward pass) to the existing `ChainRulesOriginator` path and then replaces the primal with a fresh plain-arithmetic solve of the same problem via `SensitivityADPassThrough`. Reusing the main method for the tape keeps the outward-facing return type identical to what the non-sensitivity solve would return (`InterpolationData` / `DEStats`), which matters because Mooncake's `DerivedRule` specialises on the type Julia inference picks for `CRC.rrule(solve_up, …)` — and that inference does not narrow through the `originator` kwarg, so the compiled rule expects the *main* method's return shape even on the dedicated Mooncake dispatch. Adds `test/mooncake_reversediff_adjoint.jl` covering both a plain Lotka-Volterra ODE and a hybrid ODE with `PresetTimeCallback` (mirroring the `hybrid_diffeq.md` tutorial that was forced to stay on Zygote in SciML#1419). Gradients are checked against `ForwardDiff` and `Zygote` at `rtol = 1e-4` / `1e-3`; the looser tolerance on the hybrid case reflects the ~ULP arithmetic reordering between the tape's tracked forward and the primal's plain forward amplified by the callback-driven time grid. Also updates `MOONCAKE_TRACKED_REAL_ERROR_MESSAGE` so it no longer claims `ReverseDiffAdjoint` is incompatible with Mooncake (only `TrackerAdjoint` still is). Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
…er notes After investigating the four blocked tutorials more carefully and adding the missing `increment_and_get_rdata!` dispatch for `ComponentVector` cotangents in ComponentArrays' Mooncake extension (SciML/ComponentArrays.jl#351), three more tutorials are now Mooncake- compatible end-to-end: ## hybrid_diffeq.md (un-reverted) The original file pinned `sensealg = SMS.ReverseDiffAdjoint()` explicitly. The continuous adjoints (`BacksolveAdjoint`, `InterpolatingAdjoint`, `GaussAdjoint`, `QuadratureAdjoint`) are now compatible with callbacks for ODEs, so the explicit `ReverseDiffAdjoint` choice is no longer necessary. Drop it and let the default sensealg auto-pick. Combined with CA SciML#351 (which fixes the `increment_and_get_rdata!` mismatch on the parameter cotangent path), the example now trains under `OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`. ## delay_diffeq.md (un-reverted) Same story — the original file pinned `sensealg = SMS.ReverseDiffAdjoint()` explicitly, which Mooncake hits with a `StackOverflowError` during rule compilation. The default sensealg for DDEs is `ForwardDiffSensitivity()` for problems with fewer than 100 parameters (SciMLSensitivity.jl `concrete_solve.jl:434-454`), and that path uses ForwardDiff dual numbers inside the rrule — which Mooncake handles fine. Drop the explicit `ReverseDiffAdjoint` and let the default pick. Replace the narrative line that explained the explicit choice with a note about the automatic ForwardDiff/ReverseDiff fallback for DDEs (continuous adjoints are not yet defined for DDEs, so the discretize-then-optimize methods are the only option here). ## brusselator.md (un-reverted) CA SciML#351 also unblocks this — the FBDF stiff-PDE adjoint with Lux+CV parameters was the same `increment_and_get_rdata!` mismatch, and once that dispatch lands the default `GaussAdjoint(ZygoteVJP)` flows through Mooncake without further changes. ## simplechains.md (note expanded) I tested the full matrix (default, `QuadratureAdjoint(ZygoteVJP)`, `QuadratureAdjoint(MooncakeVJP)`, `InterpolatingAdjoint(ReverseDiffVJP)`, `GaussAdjoint(MooncakeVJP)`) and **none of them work** with the SimpleChains+`StaticArrays` out-of-place flow. Each fails for a different reason — the new note enumerates all four with the exact upstream symptom so future contributors know which layer needs to grow the missing dispatch. Notable findings: - The default sensealg picks `GaussAdjoint`, which trips an `@assert sensealg isa QuadratureAdjoint` in `adjoint_common.jl:747` because `u::SVector` is immutable and only `QuadratureAdjoint` is wired up for the immutable-state path. - `QuadratureAdjoint(autojacvec=ZygoteVJP())` (the explicit choice in the file) emits a `ChainRulesCore.Tangent` cotangent that SciMLSensitivity's `df_iip`/`df_oop` adjoint backpass can't unwrap — Mooncake's pullback fails with a `BoundsError` accessing the nested `Tangent` fields. Zygote produces a different cotangent shape that flows through cleanly, which is why the tutorial works on `AutoZygote` but not `AutoMooncake`. - `QuadratureAdjoint(autojacvec=MooncakeVJP())` and `GaussAdjoint(autojacvec=MooncakeVJP())` both fail with `setindex!(::SVector, …)` — `MooncakeVJP` mutates the cotangent buffer in place, which has no method for static arrays. - `InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))` fails with `conversion to pointer not defined for ReverseDiff.TrackedArray` — SimpleChains reaches into raw pointer storage that's incompatible with ReverseDiff-tracked types. ## second_order_neural.md (note refined) This is **not a missing rule**. `SecondOrderODEProblem` constructs an `ODEProblem{…, SciMLBase.SecondOrderODEProblem{false}}` wrapping a `DynamicalODEFunction`, so the existing `_concrete_solve_adjoint(::AbstractODEProblem, …)` methods dispatch fine. The actual blocker is a `df_iip`/`df_oop` bug in `SciMLSensitivity/src/concrete_solve.jl`: when the state is an `ArrayPartition{Tuple{Vector,Vector}}` (which is what `SecondOrderODEProblem` uses internally), the Mooncake-originated cotangent comes back shaped as `ChainRulesCore.Tangent{NamedTuple{x::Tangent{Tuple{Vector,Vector}}}}`, and the adjoint backpass calls `vec(x)` on this nested `Tangent` and raises `MethodError: no method matching vec(::ChainRulesCore.Tangent)`. Zygote happens to produce a different (recursively-array-shaped) cotangent that flows through, which is why the tutorial works on Zygote but not Mooncake. The note now records this precisely so the fix path is clear: add a `Tangent` → `ArrayPartition` unwrap inside `df_iip`/`df_oop`. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
|
The exact reason was I believe the issue noticed and worked upon here: chalk-lab/Mooncake.jl#1102 which was superseded via chalk-lab/Mooncake.jl#1103. basically for unique |
The existing `increment_and_get_rdata!` method only matched a raw
`Array{P}` tangent against a flat-`Array`-backed ComponentVector fdata.
In practice the tangent coming out of a `ChainRulesCore.rrule` for a
ComponentArray primal is usually *another* ComponentArray (e.g. via
`ComponentArray(Δ, getaxes(x))`), so downstream packages that declare a
`@from_rrule` / `@from_chainrules` boundary with a ComponentArray
argument hit
ArgumentError: The fdata type ... ComponentVector{...} combination
is not supported with @from_chainrules or @from_rrule.
This is what blocked the Mooncake migration of the SciMLSensitivity.jl
tutorials in SciML/SciMLSensitivity.jl#1419 (the `feedback_control.md`
and `second_order_neural.md` notes). Widen the dispatch to cover:
- flat-`Array`-backed ComponentVector fdata with an incoming
`ComponentArray` cotangent (unwrap to the underlying storage),
- SubArray-backed ComponentVector fdata (produced by
`getproperty(::ComponentVector, ::Symbol)`) with either an `Array`
or a `ComponentArray` cotangent — handled for the common
full-parent-coverage case, with a clear `ArgumentError` for the
partial-view case that would otherwise silently misplace gradient
mass.
Tests: exercise both native Mooncake (`prepare_gradient_cache` +
`value_and_gradient!!` over nested `ComponentArray(; u0, p_all)`) and
the `@from_rrule` round-trip path that the new methods target. Adds
Mooncake to `test/autodiff/Project.toml` (pinned to `0.5.26` to match
the `friendly_tangent_cache` symbol the extension already references).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Once both of these land: - SciML#1422 (df_iip/df_oop ArrayPartition cotangent unwrap) - SciML/RecursiveArrayTools.jl#575 (Mooncake increment_and_get_rdata! for ArrayPartition) this tutorial works end-to-end under `OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`. Verified locally with both PRs applied (Lux + StatefulLuxLayer + SecondOrderODEProblem Adam loop trains). Drops the explanatory `!!! note` and adds the `import Mooncake`. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
- second_order_adjoints.md: Phase 1 (Adam) now uses AutoMooncake, Phase 2 (NewtonTrustRegion) stays on AutoZygote (Hessian via SecondOrder(ForwardDiff, Zygote)) pending forward-over-Mooncake support (chalk-lab/Mooncake.jl#1142). Split into two OptimizationFunctions to avoid applying the wrong backend to Phase 2. - brusselator.md: switch AutoZygote → AutoMooncake with friendly_tangents. Tested locally with N_GRID=8 and shortened tspan — Mooncake gradient chain works end-to-end (loss decreasing from 0.131 to 0.059 in 3 Adam steps). Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Remove all `!!! note` blocks explaining why specific tutorials still use Zygote. The information is now tracked as GitHub issues: - SciML#1424: EnsembleProblem tutorials (optimization_sde, SDE_control, data_parallel) - SciML#1425: SimpleChains + StaticArrays tutorial - SciML#1426: Nested ComponentArray partial-cover SubArray (feedback_control) - SciML#1427: Second-order adjoints forward-over-reverse Hessian Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
WIP — depends on JuliaDiff/DifferentiationInterface.jl#989
This PR is marked as a draft because it depends on the upstream
DI Mooncake fix in JuliaDiff/DifferentiationInterface.jl#989. Until that
PR is merged and a DI release is tagged, the docs build here will fail
in the same
MethodError: no method matching iterate(::Mooncake.Tangent)spot we just patched in DI.
Summary
With the recent SciMLSensitivity-side Mooncake fixes (#1376, #1397, #1412)
and the upstream DI Mooncake patch in
JuliaDiff/DifferentiationInterface.jl#989 (which unwraps
Mooncake.Tangent/MutableTangentcotangents back to the primal typeso
OptimizationBase.gradient!no longer dies oncopyto!(::ComponentVector, ::Mooncake.Tangent)), Mooncake now worksend-to-end for the majority of the SciMLSensitivity tutorials.
This PR migrates every doc/tutorial that I could verify runs cleanly
under
OPT.AutoMooncake(; config = nothing)(or the direct Mooncake /DifferentiationInterface API for non-Optimization examples). Each
migrated example was executed locally on Julia 1.11 against
SciMLSensitivity master + the patched DI to confirm the gradient flows
and the optimizer makes progress.
Migrated to Mooncake (verified)
getting_started.mdZygote.gradient(loss, u0, p)→DI.gradient(closure, AutoMooncake, p)(also exercisesGaussAdjoint)manual/differential_equation_sensitivities.mdtutorials/parameter_estimation_ode.mdOPT.AutoZygote()→OPT.AutoMooncake(; config = nothing), PolyOpt converges to ≈2e-6 in 100 stepstutorials/chaotic_ode.mdZygote.gradient(p -> G(p), p)→DI.gradient(...)forForwardLSStutorials/training_tips/divergence.mdtutorials/training_tips/local_minima.mdAutoZygoteswapstutorials/training_tips/multiple_nn.mdInterpolatingAdjoint(ReverseDiffVJP)examples/ode/exogenous_input.mdexamples/hybrid_jump/bouncing_ball.mdOPT.AutoMooncakeswap + replacedsol[end][1]withlast(sol.u)[1]to dodge a pre-existingBoundsErrorinSciMLBaseMooncakeExt._scatter_pullbackforgetindex(::ODESolution, end)examples/optimal_control/optimal_control.mdimport Zygote(the example usesOPT.AutoForwardDiff())examples/pde/pde_constrained.mdexamples/sde/optimization_sde.md(Example 3)ForwardDiffSensitivity()Benchmark.mdZygote.gradient($loss_neuralode, $u0, $ps, $st)→DI.gradient($loss_ps, $backend, $ps)with a closure (```julia block, not executed)faq.mdZygote.pullbacktoMooncake.prepare_pullback_cache/Mooncake.value_and_pullback!!index.mddocs/Project.tomlLeft on Zygote with an explanatory `!!! note`
The remaining tutorials hit one of three independent upstream
blockers in Mooncake itself, all of which are out of scope for a docs
PR. I left them on `OPT.AutoZygote()` and added a callout pointing at
the specific failure mode so future contributors know what to monitor:
`EnsembleProblem` rule compilation fails
`StackOverflowError` inside Mooncake's rule compiler when it tries to
differentiate `__solve(::AbstractEnsembleProblem, …)`:
fit)
which would translate to
`Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(CreateGrid), Any, Any}`
once the EnsembleProblem blocker is resolved — I left a one-line note
so it's a trivial fix later)
`MethodOfSteps` DDE adjoint fails
`StackOverflowError` during rule compilation of the `DDEProblem` solve:
ComponentArrays cotangent / SciMLBase Mooncake-extension gaps
Missing `increment_and_get_rdata!` method, `ReverseDiffAdjoint`-tracked
values that don't match Mooncake's `CoDual` type expectations, nested
`ComponentVector` cotangents, or `SecondOrder(AutoMooncake, AutoMooncake)`
fallback:
Hessian; the Adam-only first half does work with Mooncake but the
point of the tutorial is the second-order optimization)
`ComponentArray(; u0, p_all)`)
auto-selected adjoint)
Each note records the exact error so it's clear which Mooncake / SciMLBase
upstream patch unblocks the migration. When that lands, switching the
remaining files is mechanical.
Test plan
matched a pattern I ran end-to-end (Lux+CA neural ODE training,
Optimization+CA loop, etc.) against SciMLSensitivity master +
the patched DI on Julia 1.11
DI release tag, then re-run the docs build here
🤖 Generated with Claude Code