
Commit c768c3c

Merge pull request #1419 from ChrisRackauckas-Claude/docs-prefer-mooncake
[WIP] docs: prefer Mooncake over Zygote where it works end-to-end
2 parents 38ae573 + f4e8553 commit c768c3c

25 files changed, 124 additions & 81 deletions

docs/Project.toml

Lines changed: 5 additions & 1 deletion
````diff
@@ -7,6 +7,7 @@ DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
 DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
 DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
 DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -15,6 +16,7 @@ IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
 OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
@@ -35,13 +37,14 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
 Calculus = "0.5"
-ComponentArrays = "0.15"
+ComponentArrays = "0.15.34"
 DataInterpolations = "3.10, 4, 5, 6, 7, 8"
 DelayDiffEq = "5"
 DelimitedFiles = "1"
 DiffEqCallbacks = "2.24, 3, 4"
 DiffEqNoiseProcess = "5.14"
 DifferentialEquations = "7"
+DifferentiationInterface = "0.6, 0.7"
 Documenter = "1"
 Enzyme = "0.12, 0.13"
 Flux = "0.14, 0.15, 0.16"
@@ -50,6 +53,7 @@ IterTools = "1"
 Lux = "1"
 LuxCUDA = "0.3"
 MLUtils = "0.4"
+Mooncake = "0.5"
 Optimization = "3.9, 4, 5"
 OptimizationOptimJL = "0.2, 0.3, 0.4"
 OptimizationOptimisers = "0.2, 0.3"
````
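
For anyone mirroring this docs environment locally, the two new dependencies can be added with Pkg. This is a sketch of my own, not a command the commit contains, and the resolved versions will depend on the rest of the environment:

```julia
# Add the two packages this PR introduces into docs/Project.toml,
# constrained to the same compat ranges as above.
import Pkg
Pkg.add(name = "Mooncake", version = "0.5")
Pkg.add(name = "DifferentiationInterface", version = "0.7")
```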

docs/src/Benchmark.md

Lines changed: 5 additions & 2 deletions
````diff
@@ -40,7 +40,8 @@ Quick summary:
 import OrdinaryDiffEq as ODE
 import Lux
 import SciMLSensitivity as SMS
-import Zygote
+import Mooncake
+import DifferentiationInterface as DI
 import BenchmarkTools
 import Random
 import ComponentArrays as CA
@@ -80,7 +81,9 @@ for sensealg in (SMS.InterpolatingAdjoint(autojacvec = SMS.ZygoteVJP()),
        return loss
    end

-   t = BenchmarkTools.@belapsed Zygote.gradient($loss_neuralode, $u0, $ps, $st)
+   backend = DI.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
+   loss_ps = p -> loss_neuralode(u0, p, st)
+   t = BenchmarkTools.@belapsed DI.gradient($loss_ps, $backend, $ps)
    println("$(sensealg) took $(t)s")
 end
````
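
The new benchmarking pattern is easy to try in isolation. A minimal sketch, where the toy `f` stands in for the page's `p -> loss_neuralode(u0, p, st)` closure and everything else follows the API used in the hunk above:

```julia
# DifferentiationInterface + Mooncake gradient call: close over every argument
# except the parameters, then hand the closure and an AutoMooncake backend
# to DI.gradient.
import DifferentiationInterface as DI
import Mooncake

f(p) = sum(abs2, p)  # toy stand-in for the neural ODE loss

backend = DI.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
DI.gradient(f, backend, [1.0, 2.0, 3.0])  # == [2.0, 4.0, 6.0]
```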

docs/src/examples/dde/delay_diffeq.md

Lines changed: 10 additions & 5 deletions
````diff
@@ -10,6 +10,7 @@ import Optimization as OPT
 import SciMLSensitivity as SMS
 import OptimizationPolyalgorithms as OPA
 import DelayDiffEq as DDE
+import Mooncake

 # Define the same LV equation, but including a delay parameter
 function delay_lotka_volterra!(du, u, h, p, t)
@@ -35,7 +36,7 @@ prob_dde = DDE.DDEProblem(delay_lotka_volterra!, u0, h, (0.0, 10.0),

 function predict_dde(p)
     return Array(ODE.solve(prob_dde, DDE.MethodOfSteps(ODE.Tsit5());
-        u0, p, saveat = 0.1, sensealg = SMS.ReverseDiffAdjoint()))
+        u0, p, saveat = 0.1))
 end

 loss_dde(p) = sum(abs2, x - 1 for x in predict_dde(p))
@@ -50,14 +51,18 @@ callback = function (state, l; doplot = false)
     return false
 end

-adtype = OPT.AutoZygote()
+adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
 optf = OPT.OptimizationFunction((x, p) -> loss_dde(x), adtype)
 optprob = OPT.OptimizationProblem(optf, p)
 result_dde = OPT.solve(optprob, OPA.PolyOpt(); maxiters = 300, callback)
 ```

-Notice that we chose `sensealg = ReverseDiffAdjoint()` to utilize the ReverseDiff.jl
-reverse-mode to handle the delay differential equation.
+The `sensealg` is left at its default. For DDEs the automatic choice is
+[`ForwardDiffSensitivity`](@ref) (which differentiates through
+`MethodOfSteps` via dual numbers) for problems with fewer than 100
+parameters, and [`ReverseDiffAdjoint`](@ref) for larger ones —
+[continuous adjoints](@ref sensitivity_diffeq) are not yet defined for
+DDEs, so the discretize-then-optimize methods are the only option.

 We define a callback to display the solution at the current parameters for each step of the training:

@@ -76,7 +81,7 @@ end
 We use `Optimization.solve` to optimize the parameters for our loss function:

 ```@example dde
-adtype = OPT.AutoZygote()
+adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
 optf = OPT.OptimizationFunction((x, p) -> loss_dde(x), adtype)
 optprob = OPT.OptimizationProblem(optf, p)
 result_dde = OPT.solve(optprob, OPA.PolyOpt(); callback)
````
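
If you would rather pin the sensitivity method than rely on the default described above, it can still be passed explicitly. A sketch under the page's setup (`prob_dde`, `u0`, and `p` come from earlier cells not shown in this hunk; `predict_fwd`/`predict_rev` are hypothetical names of mine):

```julia
# Both options are discretize-then-optimize methods, consistent with the note
# above that continuous adjoints are not defined for DDEs.
import SciMLSensitivity as SMS

# Forward mode via dual numbers; the default for fewer than 100 parameters.
predict_fwd(p) = Array(ODE.solve(prob_dde, DDE.MethodOfSteps(ODE.Tsit5());
    u0, p, saveat = 0.1, sensealg = SMS.ForwardDiffSensitivity()))

# Taped reverse mode; the default for larger parameter counts.
predict_rev(p) = Array(ODE.solve(prob_dde, DDE.MethodOfSteps(ODE.Tsit5());
    u0, p, saveat = 0.1, sensealg = SMS.ReverseDiffAdjoint()))
```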

docs/src/examples/hybrid_jump/bouncing_ball.md

Lines changed: 7 additions & 2 deletions
````diff
@@ -13,6 +13,7 @@ import OptimizationPolyalgorithms as OPA
 import SciMLSensitivity as SMS
 import OrdinaryDiffEq as ODE
 import DiffEqCallbacks as DEC
+import Mooncake

 function f(du, u, p, t)
     du[1] = u[2]
@@ -44,11 +45,15 @@ the value 20:
 function loss(θ)
     sol = ODE.solve(prob, ODE.Tsit5(), p = [9.8, θ[1]]; callback)
     target = 20.0
-    abs2(sol[end][1] - target)
+    # Use `last(sol.u)[1]` instead of `sol[end][1]` — Mooncake's pullback for
+    # `getindex(::ODESolution, end)` currently has a `BoundsError` bug
+    # (`SciMLBaseMooncakeExt._scatter_pullback`). Indexing the underlying
+    # `sol.u::Vector{Vector{Float64}}` directly avoids the bad path.
+    abs2(last(sol.u)[1] - target)
 end

 loss([0.8])
-adtype = OPT.AutoZygote()
+adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
 optf = OPT.OptimizationFunction((x, p) -> loss(x), adtype)
 optprob = OPT.OptimizationProblem(optf, [0.8])
 @time res = OPT.solve(optprob, OPA.PolyOpt(), maxiters = 300)
````
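
The indexing workaround in this hunk can be checked on its own. A standalone sketch with a toy ODE of my choosing (unrelated to the page's `prob`):

```julia
# last(sol.u)[1] and sol[end][1] return the same final state; the first reads
# the plain Vector-of-states field sol.u, which is the path the diff's comment
# says Mooncake differentiates correctly, while the second goes through
# getindex(::ODESolution, end).
import OrdinaryDiffEq as ODE

prob = ODE.ODEProblem((u, p, t) -> -u, [1.0], (0.0, 1.0))
sol = ODE.solve(prob, ODE.Tsit5())

last(sol.u)[1]  # differentiation-friendly final state
sol[end][1]     # same value, but the path flagged as buggy under Mooncake
```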

docs/src/examples/hybrid_jump/hybrid_diffeq.md

Lines changed: 6 additions & 4 deletions
````diff
@@ -12,6 +12,7 @@ import ComponentArrays as CA
 import Random
 import SciMLSensitivity as SMS
 import Lux
+import Mooncake
 import OrdinaryDiffEq as ODE
 import Plots
 import Optimization as OPT
@@ -50,9 +51,7 @@ cb = DEC.PresetTimeCallback(dosetimes, affect!, save_positions = (false, false))

 function predict_n_ode(p)
     _prob = ODE.remake(prob; p)
-    Array(ODE.solve(_prob, ODE.Tsit5(); u0 = z0, p, callback = cb, saveat = t,
-        sensealg = SMS.ReverseDiffAdjoint()))[1:2, :]
-    #Array(solve(prob,Tsit5();u0=z0,p,saveat=t))[1:2,:]
+    Array(ODE.solve(_prob, ODE.Tsit5(); u0 = z0, p, callback = cb, saveat = t))[1:2, :]
 end

 function loss_n_ode(p, _)
@@ -73,7 +72,10 @@ cba = function (state, l; doplot = false) #callback function to observe training
 end

 res = OPT.solve(
-    OPT.OptimizationProblem(OPT.OptimizationFunction(loss_n_ode, OPT.AutoZygote()),
+    OPT.OptimizationProblem(
+        OPT.OptimizationFunction(
+            loss_n_ode, OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
+        ),
         CA.ComponentArray(ps)),
     OPO.Adam(0.05); callback = cba, maxiters = 1000)
 ```
````

docs/src/examples/neural_ode/simplechains.md

Lines changed: 1 addition & 0 deletions
````diff
@@ -1,5 +1,6 @@
 # Faster Neural Ordinary Differential Equations with SimpleChains

+
 [SimpleChains](https://github.com/PumasAI/SimpleChains.jl) has demonstrated performance boosts of ~5x and ~30x when compared to other mainstream deep learning frameworks like Pytorch for the training and evaluation in the specific case of small neural networks. For the nitty-gritty details, as well as, some SciML related videos around the need and applications of such a library, we can refer to this [blogpost](https://julialang.org/blog/2022/04/simple-chains/). As for doing Scientific Machine Learning, how do we even begin with training neural ODEs with any generic deep learning library?

 ## Training Data
````

docs/src/examples/ode/exogenous_input.md

Lines changed: 2 additions & 1 deletion
````diff
@@ -49,6 +49,7 @@ import OptimizationPolyalgorithms as OPA
 import OptimizationOptimisers as OPO
 import Plots
 import Random
+import Mooncake

 rng = Random.default_rng()
 tspan = (0.1, 10.0)
@@ -93,7 +94,7 @@ function loss(p)
     return sum(abs2.(y[1:N] .- sol')) / N
 end

-adtype = OPT.AutoZygote()
+adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
 optf = OPT.OptimizationFunction((x, p) -> loss(x), adtype)
 optprob = OPT.OptimizationProblem(optf, CA.ComponentArray{Float64}(p_model))
````

docs/src/examples/ode/second_order_adjoints.md

Lines changed: 8 additions & 5 deletions
````diff
@@ -13,6 +13,7 @@ optimization, while `KrylovTrustRegion` will utilize a Krylov-based method
 with Hessian-vector products (never forming the Hessian) for large parameter
 optimizations.

+
 ```@example secondorderadjoints
 import SciMLSensitivity as SMS
 import Lux
@@ -23,6 +24,7 @@ import OrdinaryDiffEq as ODE
 import Plots
 import Random
 import OptimizationOptimJL as OOJ
+import Mooncake

 u0 = Float32[2.0; 0.0]
 datasize = 30
@@ -83,13 +85,14 @@ callback = function (state, l; doplot = false)
     return l < 0.01
 end

-adtype = OPT.AutoZygote()
-optf = OPT.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
-
-optprob1 = OPT.OptimizationProblem(optf, ps)
+adtype1 = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
+optf1 = OPT.OptimizationFunction((x, p) -> loss_neuralode(x), adtype1)
+optprob1 = OPT.OptimizationProblem(optf1, ps)
 pstart = OPT.solve(optprob1, OPO.Adam(0.01); callback, maxiters = 100).u

-optprob2 = OPT.OptimizationProblem(optf, pstart)
+adtype2 = OPT.AutoZygote()
+optf2 = OPT.OptimizationFunction((x, p) -> loss_neuralode(x), adtype2)
+optprob2 = OPT.OptimizationProblem(optf2, pstart)
 pmin = OPT.solve(optprob2, OOJ.NewtonTrustRegion(); callback, maxiters = 200)
 ```
````
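
Note the split this hunk introduces: Mooncake handles the first-order Adam warm-up, while the `NewtonTrustRegion` stage keeps `AutoZygote`, since that stage performs the second-order optimization described at the top of the page. When mixing backends like this, a cheap cross-check that they agree on the gradient can be worthwhile; a sketch on a toy function of my own, not part of the commit:

```julia
# Compare Mooncake and Zygote gradients through DifferentiationInterface;
# on a smooth function the two backends should agree to numerical precision.
import DifferentiationInterface as DI
import Mooncake, Zygote

f(x) = sum(abs2, x)
x = [1.0, -2.0, 3.0]

g_mooncake = DI.gradient(f, DI.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)), x)
g_zygote = DI.gradient(f, DI.AutoZygote(), x)
isapprox(g_mooncake, g_zygote)  # expected: true
```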

docs/src/examples/ode/second_order_neural.md

Lines changed: 2 additions & 1 deletion
````diff
@@ -29,6 +29,7 @@ import OptimizationOptimisers as OPO
 import RecursiveArrayTools
 import Random
 import ComponentArrays as CA
+import Mooncake

 u0 = Float32[0.0; 2.0]
 du0 = Float32[0.0; 0.0]
@@ -61,7 +62,7 @@ callback = function (state, l)
     l < 0.01
 end

-adtype = OPT.AutoZygote()
+adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
 optf = OPT.OptimizationFunction((x, p) -> loss_n_ode(x), adtype)
 optprob = OPT.OptimizationProblem(optf, ps)
````

docs/src/examples/optimal_control/feedback_control.md

Lines changed: 1 addition & 0 deletions
````diff
@@ -4,6 +4,7 @@ You can also mix a known differential equation and a neural differential
 equation, so that the parameters and the neural network are estimated
 simultaneously!

+
 We will assume that we know the dynamics of the second equation
 (linear dynamics), and our goal is to find a neural network that is dependent
 on the current state of the dynamical system that will control the second
````
