Commit f8deaaf

docs: switch second_order_adjoints Phase 1 and brusselator to Mooncake
- second_order_adjoints.md: Phase 1 (Adam) now uses AutoMooncake; Phase 2 (NewtonTrustRegion) stays on AutoZygote (Hessian via SecondOrder(ForwardDiff, Zygote)) pending forward-over-Mooncake support (chalk-lab/Mooncake.jl#1142). Split into two OptimizationFunctions to avoid applying the wrong backend to Phase 2.
- brusselator.md: switch AutoZygote → AutoMooncake with friendly_tangents. Tested locally with N_GRID=8 and shortened tspan; the Mooncake gradient chain works end-to-end (loss decreasing from 0.131 to 0.059 in 3 Adam steps).

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 23770d7 commit f8deaaf

2 files changed

Lines changed: 29 additions & 16 deletions

docs/src/examples/ode/second_order_adjoints.md

Lines changed: 25 additions & 13 deletions
@@ -15,14 +15,18 @@ optimizations.

 !!! note

-    This example still uses Zygote because `NewtonTrustRegion` needs a true
-    Hessian, and Mooncake does not yet have a forward-over-Mooncake path that
-    Optimization.jl can use to assemble one (the auto-fallback to
-    `SecondOrder(AutoMooncake(), AutoMooncake())` raises `ArgumentError`).
-    The Adam phase below works fine with `OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`
-    if you only want first-order training; the full Newton/NewtonTrustRegion
-    pipeline will become Mooncake-compatible once forward-mode Mooncake
-    matures.
+    The Adam (first-order) phase below uses Mooncake. The
+    `NewtonTrustRegion` (second-order) phase still uses Zygote because
+    Mooncake currently has no working forward-over-reverse path through
+    `SciMLSensitivity` + `OrdinaryDiffEq`: `SecondOrder(AutoMooncake(),
+    AutoMooncake())` raises a "reverse-over-reverse not supported" error
+    and `SecondOrder(AutoForwardDiff(), AutoMooncake())` is blocked on
+    Mooncake's `IEEEFloat`-only gradient interface (it rejects
+    `ForwardDiff.Dual` as the primal type). Tracking issue:
+    [chalk-lab/Mooncake.jl#1142](https://github.com/chalk-lab/Mooncake.jl/pull/1142)
+    is the first step in unblocking this. Once forward-over-Mooncake is
+    available end-to-end, this tutorial can be switched to Mooncake for
+    both phases.

 ```@example secondorderadjoints
 import SciMLSensitivity as SMS
@@ -34,6 +38,7 @@ import OrdinaryDiffEq as ODE
 import Plots
 import Random
 import OptimizationOptimJL as OOJ
+import Mooncake

 u0 = Float32[2.0; 0.0]
 datasize = 30
@@ -94,13 +99,20 @@ callback = function (state, l; doplot = false)
     return l < 0.01
 end

-adtype = OPT.AutoZygote()
-optf = OPT.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
-
-optprob1 = OPT.OptimizationProblem(optf, ps)
+# First-order training: Mooncake gives the gradient through the
+# `SciMLSensitivity` adjoint chain.
+adtype1 = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
+optf1 = OPT.OptimizationFunction((x, p) -> loss_neuralode(x), adtype1)
+optprob1 = OPT.OptimizationProblem(optf1, ps)
 pstart = OPT.solve(optprob1, OPO.Adam(0.01); callback, maxiters = 100).u

-optprob2 = OPT.OptimizationProblem(optf, pstart)
+# Second-order training: NewtonTrustRegion needs a true Hessian, which
+# `OptimizationBase` assembles via `SecondOrder(AutoForwardDiff(),
+# AutoZygote())`. Mooncake cannot fill that role yet (see the note above),
+# so the Hessian phase keeps the Zygote VJP.
+adtype2 = OPT.AutoZygote()
+optf2 = OPT.OptimizationFunction((x, p) -> loss_neuralode(x), adtype2)
+optprob2 = OPT.OptimizationProblem(optf2, pstart)
 pmin = OPT.solve(optprob2, OOJ.NewtonTrustRegion(); callback, maxiters = 200)
 ```
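The note and comments above describe assembling the Hessian "forward-over-reverse". As a standalone illustration, not part of this commit or the tutorial, the same construction can be written directly against DifferentiationInterface on a toy function; the function `f` and the explicit `DI.hessian` call are assumptions for demonstration only, since the tutorial lets `OptimizationBase` build the Hessian internally:

```julia
# Sketch only: a forward-over-reverse Hessian via DifferentiationInterface.
# `f` is a stand-in objective, not the tutorial's neural-ODE loss.
import DifferentiationInterface as DI
import ADTypes: AutoForwardDiff, AutoZygote
import ForwardDiff, Zygote

f(x) = sum(abs2, x) + x[1] * x[2]

# The outer forward-mode backend (ForwardDiff) differentiates the gradient
# produced by the inner reverse-mode backend (Zygote): the
# SecondOrder(forward, reverse) pairing referred to in the note above.
hess_backend = DI.SecondOrder(AutoForwardDiff(), AutoZygote())
H = DI.hessian(f, hess_backend, [1.0, 2.0, 3.0])
```

Once forward-over-Mooncake is available, `AutoForwardDiff()` over `AutoMooncake(...)` would slot into the same `SecondOrder` position.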

docs/src/examples/pde/brusselator.md

Lines changed: 4 additions & 3 deletions
@@ -156,7 +156,7 @@ First, we have to define and configure the neural network that has to be used fo

 ```@example bruss
 import Lux, Random, Optimization as OPT, OptimizationOptimJL as OOJ,
-    SciMLSensitivity as SMS, Zygote
+    SciMLSensitivity as SMS, Mooncake

 model = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1))
 rng = Random.default_rng()
@@ -223,12 +223,13 @@ function loss_fn(ps, _)
 end
 ```

-Once the loss function is defined, we use the ADAM optimizer to train the neural network. The optimization problem is defined using SciML's `Optimization.jl` tools, and gradients are computed via automatic differentiation using `AutoZygote()` from `SciMLSensitivity`:
+Once the loss function is defined, we use the ADAM optimizer to train the neural network. The optimization problem is defined using SciML's `Optimization.jl` tools, and gradients are computed via automatic differentiation using Mooncake through the `SciMLSensitivity` adjoint chain:

 ```@example bruss
 println("[Training] Starting optimization...")
 import OptimizationOptimisers as OPO
-optf = OPT.OptimizationFunction(loss_fn, SMS.AutoZygote())
+adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
+optf = OPT.OptimizationFunction(loss_fn, adtype)
 optprob = OPT.OptimizationProblem(optf, ps_init)
 loss_history = Float32[]
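Since the commit message reports the Mooncake gradient chain working end-to-end on a reduced grid, a one-off check of that gradient outside the optimizer loop can be useful. This is a sketch, not part of the commit: it assumes the tutorial's `loss_fn` and `ps_init` are in scope, that `ps_init` is a flat parameter vector (e.g. a `ComponentArray`), and it calls DifferentiationInterface directly rather than going through `Optimization.jl`:

```julia
# Sketch only: evaluate the Mooncake gradient of the tutorial's loss once,
# using the same adtype as the training code above.
import DifferentiationInterface as DI
import Mooncake
import Optimization as OPT

adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
g = DI.gradient(p -> loss_fn(p, nothing), adtype, ps_init)
@assert size(g) == size(ps_init)   # gradient matches the parameter layout
```

If this call errors, the failure lies in the Mooncake/`SciMLSensitivity` adjoint chain itself rather than in the optimizer setup.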
