Skip to content

Commit 38ae573

Browse files
Merge pull request #1421 from ChrisRackauckas-Claude/mooncakevjp-callback-compat
Add MooncakeVJP to callback-compatible VJP methods
2 parents c0a4647 + 6d8e14f commit 38ae573

6 files changed

Lines changed: 145 additions & 11 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ Lux = "1"
8686
Markdown = "1.10"
8787
ModelingToolkit = "10, 11"
8888
ModelingToolkitStandardLibrary = "2"
89-
Mooncake = "0.5"
89+
Mooncake = "0.5.25"
9090
Reactant = "0.2.22"
9191
NLsolve = "4.5.1"
9292
NonlinearSolve = "3.0.1, 4"

ext/SciMLSensitivityMooncakeExt.jl

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module SciMLSensitivityMooncakeExt
22

3-
using SciMLSensitivity: SciMLSensitivity
3+
using SciMLSensitivity: SciMLSensitivity, FakeIntegrator
44
using Mooncake: Mooncake
5-
import SciMLSensitivity: get_paramjac_config, mooncake_run_ad, MooncakeVJP, MooncakeLoaded,
5+
import SciMLSensitivity: get_paramjac_config, get_cb_paramjac_config, mooncake_run_ad,
6+
MooncakeVJP, MooncakeLoaded,
67
DiffEqBase, MooncakeAdjoint, _init_originator_gradient
78
using SciMLSensitivity: SciMLBase, SciMLStructures, canonicalize, Tunable, isscimlstructure,
89
SciMLStructuresCompatibilityError, convert_tspan,
@@ -35,6 +36,68 @@ function get_paramjac_config(::MooncakeLoaded, ::MooncakeVJP, pf, p, f, y, _t)
3536
return cache, pf, λ_mem, dy_mem, p_grad_buf
3637
end
3738

39+
"""
    get_cb_paramjac_config(::MooncakeLoaded, ::MooncakeVJP, raw_affect, event_idx, y, p, _t, mode)

Prepare the Mooncake pullback cache used to differentiate the affect function
of a tracked callback. This is the Mooncake counterpart of
`get_cb_paramjac_config(::ReactantLoaded, ::ReactantVJP, ...)`.

The caller passes the already-extracted `raw_affect` (`get_affect!(cb, pos_neg)`
at the call site). That keeps the recursive `TrackedAffect` unwrapping out of
the closure Mooncake traces; tracing through it would otherwise hit the
`Base.argument_datatype` ccall surfaced by that dispatch.

# Arguments
- `raw_affect`: affect function, called as `raw_affect(integrator, event_idx)`
  when `event_idx !== nothing` and as `raw_affect(integrator)` otherwise.
- `event_idx`: event index forwarded to `raw_affect`, or `nothing`.
- `y`, `p`, `_t`: sample state, parameters and time handed to
  `Mooncake.prepare_pullback_cache`. `_t` is also used as the `tprev` of the
  `FakeIntegrator` built inside the closure.
- `mode`: `:state` to differentiate the affect's effect on the state (buffers
  shaped like `y`), or `:param` for its effect on the parameters (buffers
  shaped like `p`, so the cotangent/output buffers match the flat tunables
  rather than the state).

# Returns
The tuple `(cache, pf, λ_mem, dy_mem, p_grad_buf)`, laid out exactly like the
result of `get_paramjac_config(::MooncakeLoaded, ::MooncakeVJP, ...)` so that
`_vecjacobian!(::MooncakeVJP)` / `mooncake_run_ad` can consume it unchanged.

# Throws
Calls `error` when `mode` is neither `:state` nor `:param`.
"""
function get_cb_paramjac_config(
        ::MooncakeLoaded, ::MooncakeVJP, raw_affect, event_idx, y, p, _t, mode
    )
    # Reject unsupported modes before building anything.
    mode === :state || mode === :param ||
        error("get_cb_paramjac_config: unknown mode $(mode); expected :state or :param")

    pass_event_idx = event_idx !== nothing

    # Each closure applies the affect to a FakeIntegrator built from copies of
    # its inputs and then writes the field of interest into `dest`. All captured
    # values are bound through `let` so the closure holds plain, never-reassigned
    # locals.
    if mode === :state
        buffer_proto = y
        pf = let affect = raw_affect, idx = event_idx, t_before = _t,
                with_idx = pass_event_idx

            function (dest, state, params, time)
                integ = FakeIntegrator(copy(state), copy(params), time, t_before)
                with_idx ? affect(integ, idx) : affect(integ)
                copyto!(dest, integ.u)
                return dest
            end
        end
    else
        buffer_proto = p
        pf = let affect = raw_affect, idx = event_idx, t_before = _t,
                with_idx = pass_event_idx

            function (dest, state, params, time)
                integ = FakeIntegrator(copy(state), copy(params), time, t_before)
                with_idx ? affect(integ, idx) : affect(integ)
                copyto!(dest, integ.p)
                return dest
            end
        end
    end

    # Output and cotangent buffers take the shape of whatever the closure writes.
    λ_mem = zero(buffer_proto)
    dy_mem = zero(buffer_proto)
    cache = Mooncake.prepare_pullback_cache(pf, dy_mem, y, p, _t)

    # A scratch gradient buffer is only allocated for array parameters that are
    # not plain `Array`s.
    p_grad_buf = if p isa AbstractArray && !(p isa Array)
        similar(p)
    else
        nothing
    end
    return cache, pf, λ_mem, dy_mem, p_grad_buf
end
100+
38101
function mooncake_run_ad(paramjac_config::Tuple, y, p, t, λ)
39102
cache, pf, λ_mem, dy_mem, p_grad_buf = paramjac_config
40103
λ_mem .= λ

src/adjoint_common.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,12 @@ function get_cb_paramjac_config(::Any, ::ReactantVJP, raw_affect, event_idx, y,
666666
error(msg)
667667
end
668668

669+
# Generic fallback for MooncakeVJP callback configs. The first argument is
# untyped, so this method is only selected when no more specific method (such
# as the `::MooncakeLoaded` one) applies; it always raises an error telling the
# user to load Mooncake.
function get_cb_paramjac_config(::Any, ::MooncakeVJP, raw_affect, event_idx, y, p, _t, mode)
    return error(
        string(
            "MooncakeVJP requires Mooncake.jl is loaded. Install the package and do ",
            "`using Mooncake` to use this functionality"
        )
    )
end
674+
669675
# Return the problem associated with a sensitivity function: read directly from
# `S.prob` for `ODEBacksolveSensitivityFunction`, and from `S.sol.prob` for
# every other `SensitivityFunction`.
function getprob(S::SensitivityFunction)
    if S isa ODEBacksolveSensitivityFunction
        return S.prob
    end
    return S.sol.prob
end

src/callback_tracking.jl

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -283,11 +283,12 @@ function _setup_reverse_callbacks(
283283
# if save_positions = [1,0] the gradient contribution is added before, and in principle we would need to correct the adjoint state again. Therefore,
284284

285285
cb.save_positions == [1, 0] && error("save_positions=[1,0] is currently not supported.")
286-
# Callbacks require ReverseDiffVJP or EnzymeVJP for their own VJP computations.
287-
# The ODE adjoint may use a different autojacvec (even numerical/false), but the
288-
# callback affect functions (CallbackAffectWrapper) are separate and typically work
289-
# with ReverseDiff even when the ODE function doesn't.
290-
cb_autojacvec = if sensealg.autojacvec isa Union{ReverseDiffVJP, EnzymeVJP, ReactantVJP}
286+
# Callback affect functions (CallbackAffectWrapper) are traced separately
287+
# from the ODE function, so the ODE adjoint may use a different autojacvec
288+
# (even numerical/false) while the callback path uses its own compatible
289+
# backend. The `supports_callback_vjp` trait marks which VJP backends have
290+
# a dedicated callback path; anything else falls back to ReverseDiffVJP.
291+
cb_autojacvec = if supports_callback_vjp(sensealg.autojacvec)
291292
sensealg.autojacvec
292293
else
293294
@warn "autojacvec=$(sensealg.autojacvec) is not compatible with callbacks, using ReverseDiffVJP() for callback VJPs"
@@ -506,9 +507,10 @@ end
506507

507508
function setup_w_wp(
508509
cb::Union{DiscreteCallback, ContinuousCallback, VectorContinuousCallback},
509-
autojacvec::Union{ReverseDiffVJP, EnzymeVJP, ReactantVJP}, pos_neg, event_idx,
510-
tprev
510+
autojacvec, pos_neg, event_idx, tprev
511511
)
512+
supports_callback_vjp(autojacvec) ||
513+
error("setup_w_wp called with a VJP backend that does not support callbacks: $(autojacvec). This is an internal error — the callback path should have redirected to a compatible backend in `_setup_reverse_callbacks`.")
512514
w = CallbackAffectWrapper(cb, autojacvec, pos_neg, event_idx, tprev)
513515
wp = CallbackAffectPWrapper(cb, autojacvec, pos_neg, event_idx, tprev)
514516
return w, wp
@@ -519,6 +521,7 @@ function get_FakeIntegrator(autojacvec::ReverseDiffVJP, u, p, t, tprev)
519521
end
520522
# EnzymeVJP, ReactantVJP and MooncakeVJP all build the FakeIntegrator directly
# from the values they are given, so one method covers the three backends.
function get_FakeIntegrator(
        autojacvec::Union{EnzymeVJP, ReactantVJP, MooncakeVJP}, u, p, t, tprev
    )
    return FakeIntegrator(u, p, t, tprev)
end
522525

523526
function _get_wp_paramjac_config(autojacvec::EnzymeVJP, _p, wp, y, __p, _t)
524527
return (zero(y), zero(_p), zero(_p), zero(_p), zero(y))
@@ -603,6 +606,41 @@ function get_cb_diffcaches(
603606
nothing, nothing, nothing, false,
604607
nothing, identity
605608
)
609+
elseif autojacvec isa MooncakeVJP
610+
# MooncakeVJP: build Mooncake pullback caches for the
611+
# state-affect and parameter-affect callback wrappers via
612+
# `get_cb_paramjac_config` in the Mooncake extension, mirroring
613+
# the ReactantVJP branch above. Mooncake can't trace through
614+
# the recursive TrackedAffect unwrapping in
615+
# `CallbackAffectWrapper` (it trips on the
616+
# `Base.argument_datatype` ccall surfaced by that dispatch),
617+
# so the extension hoists `raw_affect` out and bakes it into
618+
# a flat (out, u, p, t) closure before preparing the cache.
619+
raw_affect = get_affect!(cb, pos_neg)
620+
621+
w_paramjac = get_cb_paramjac_config(
622+
MooncakeLoaded(), autojacvec, raw_affect,
623+
event_idx, y, _p, _t, :state
624+
)
625+
diffcache_w = AdjointDiffCache(
626+
nothing, nothing, nothing, nothing, nothing,
627+
nothing, nothing, nothing, w_paramjac,
628+
nothing, nothing, nothing, nothing, nothing,
629+
nothing, nothing, nothing, false,
630+
nothing, identity
631+
)
632+
633+
wp_paramjac = get_cb_paramjac_config(
634+
MooncakeLoaded(), autojacvec, raw_affect,
635+
event_idx, y, _p, _t, :param
636+
)
637+
diffcache_wp = AdjointDiffCache(
638+
nothing, nothing, nothing, nothing, nothing,
639+
nothing, nothing, nothing, wp_paramjac,
640+
nothing, nothing, nothing, nothing, nothing,
641+
nothing, nothing, nothing, false,
642+
nothing, identity
643+
)
606644
else
607645
w, wp = setup_w_wp(cb, autojacvec, pos_neg, event_idx, _t)
608646

src/sensitivity_algorithms.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,6 +1696,23 @@ supports_structured_vjp(::ReverseDiffVJP) = false
16961696
supports_structured_vjp(::Bool) = false
16971697
supports_structured_vjp(::Nothing) = false
16981698

1699+
"""
    supports_callback_vjp(autojacvec) -> Bool

Trait telling whether the VJP backend `autojacvec` can differentiate the affect
functions of tracked callbacks, i.e. whether it has a dedicated branch in
`get_cb_diffcaches` and a `_vecjacobian!` path that works with
`CallbackSensitivityFunction`.

Returns `true` for `ReverseDiffVJP`, `EnzymeVJP`, `ReactantVJP` and
`MooncakeVJP`, and `false` for anything else. On `false`, the callback adjoint
code falls back to `ReverseDiffVJP()` for its own VJP computations; the ODE
adjoint may still use the user-requested backend, since the callback affect
functions are traced separately.
"""
supports_callback_vjp(::Union{ReverseDiffVJP, EnzymeVJP, ReactantVJP, MooncakeVJP}) = true
supports_callback_vjp(::Any) = false
1715+
16991716
"""
17001717
```julia
17011718
ForwardDiffOverAdjoint{A} <: AbstractSecondOrderSensitivityAlgorithm{nothing, true, nothing}

test/callbacks/continuous_callbacks.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
using OrdinaryDiffEq, Zygote, Reactant
1+
using OrdinaryDiffEq, Zygote, Reactant, Mooncake
22
using SciMLSensitivity, Test, ForwardDiff, FiniteDiff
3+
using SciMLSensitivity: MooncakeVJP
34

45
abstol = 1.0e-12
56
reltol = 1.0e-12
@@ -362,5 +363,14 @@ println("Continuous Callbacks")
362363
sensealg = GaussAdjoint(autojacvec = ReactantVJP(allow_scalar = true))
363364
gZy = Zygote.gradient(p -> loss(p, cb, sensealg), p)[1]
364365
@test gFD ≈ gZy rtol = 1.0e-10
366+
367+
# MooncakeVJP with callbacks
368+
sensealg = InterpolatingAdjoint(autojacvec = MooncakeVJP())
369+
gZy = Zygote.gradient(p -> loss(p, cb, sensealg), p)[1]
370+
@test gFD ≈ gZy rtol = 1.0e-10
371+
372+
sensealg = GaussAdjoint(autojacvec = MooncakeVJP())
373+
gZy = Zygote.gradient(p -> loss(p, cb, sensealg), p)[1]
374+
@test gFD ≈ gZy rtol = 1.0e-10
365375
end
366376
end

0 commit comments

Comments
 (0)