Skip to content

Commit 12e6b06

Browse files
Route callback Mooncake cache through get_cb_paramjac_config
Mirror the ReactantVJP layout: `src/callback_tracking.jl` now calls `get_cb_paramjac_config(MooncakeLoaded(), autojacvec, raw_affect, event_idx, y, _p, _t, mode)` with `mode = :state` / `:param`, and the MooncakeVJP branch collapses to the same shape as the ReactantVJP branch (just the two config/diffcache constructions, no inline closure building or SciMLStructures handling in the hot path). The actual Mooncake pullback-cache assembly — hoisting `raw_affect`, building the flat `(out, u, p, t)` closure that bakes it in, and sizing the `dy_mem` / `λ_mem` buffers to state-sized (for :state) or parameter-sized (for :param) — now lives in a dedicated `get_cb_paramjac_config(::MooncakeLoaded, ::MooncakeVJP, ...)` method in `ext/SciMLSensitivityMooncakeExt.jl`, with a matching fallback in `src/adjoint_common.jl` next to the ReactantVJP fallback. This also reverts the `out_sample` kwarg added to `get_paramjac_config(::MooncakeLoaded, ::MooncakeVJP, ...)` in the previous commits — the wp-sizing concern is now handled inside `get_cb_paramjac_config` where it belongs, rather than by overloading the ODE-path entry point with a callback-specific knob. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 7a5e90d commit 12e6b06

3 files changed

Lines changed: 94 additions & 95 deletions

File tree

ext/SciMLSensitivityMooncakeExt.jl

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module SciMLSensitivityMooncakeExt
22

3-
using SciMLSensitivity: SciMLSensitivity
3+
using SciMLSensitivity: SciMLSensitivity, FakeIntegrator
44
using Mooncake: Mooncake
5-
import SciMLSensitivity: get_paramjac_config, mooncake_run_ad, MooncakeVJP, MooncakeLoaded,
5+
import SciMLSensitivity: get_paramjac_config, get_cb_paramjac_config, mooncake_run_ad,
6+
MooncakeVJP, MooncakeLoaded,
67
DiffEqBase, MooncakeAdjoint, _init_originator_gradient
78
using SciMLSensitivity: SciMLBase, SciMLStructures, canonicalize, Tunable, isscimlstructure,
89
SciMLStructuresCompatibilityError, convert_tspan,
@@ -23,16 +24,9 @@ function _init_originator_gradient(
2324
return igs
2425
end
2526

26-
function get_paramjac_config(
27-
::MooncakeLoaded, ::MooncakeVJP, pf, p, f, y, _t; out_sample = nothing
28-
)
29-
# `out_sample` lets callers size the output/cotangent buffers for functions
30-
# whose output is not state-sized (e.g. the parameter-affect callback wrapper
31-
# whose output has the shape of the flat tunables). Defaults to `y` which is
32-
# correct for the usual ODE/state-output case.
33-
_out_sample = out_sample === nothing ? y : out_sample
34-
dy_mem = zero(_out_sample)
35-
λ_mem = zero(_out_sample)
27+
function get_paramjac_config(::MooncakeLoaded, ::MooncakeVJP, pf, p, f, y, _t)
28+
dy_mem = zero(y)
29+
λ_mem = zero(y)
3630
cache = Mooncake.prepare_pullback_cache(pf, dy_mem, y, p, _t)
3731
# Pre-allocate buffer for tangent_to_primal!! conversion of struct-based
3832
# array types (e.g. ComponentArray) whose Mooncake tangent is Mooncake.Tangent.
@@ -42,6 +36,68 @@ function get_paramjac_config(
4236
return cache, pf, λ_mem, dy_mem, p_grad_buf
4337
end
4438

39+
"""
40+
get_cb_paramjac_config(::MooncakeLoaded, ::MooncakeVJP, raw_affect, event_idx, y, p, _t, mode)
41+
42+
Build a Mooncake pullback cache for a tracked callback affect function. Mirrors
43+
the `get_cb_paramjac_config(::ReactantLoaded, ::ReactantVJP, ...)` entry point:
44+
`raw_affect` is extracted upfront (`get_affect!(cb, pos_neg)` at the call site)
45+
so the Mooncake-traced closure does not need to recursively unwrap
46+
`TrackedAffect`, which would otherwise trip on the `Base.argument_datatype`
47+
ccall surfaced by that dispatch.
48+
49+
`mode === :state` builds a cache for the state-affect closure (state-sized
50+
output); `mode === :param` builds one for the parameter-affect closure
51+
(parameter-sized output) so its Mooncake cotangent/output buffers match the
52+
flat tunables shape rather than the state shape. The returned 5-tuple has the
53+
same layout as `get_paramjac_config(::MooncakeLoaded, ::MooncakeVJP, ...)` so
54+
`_vecjacobian!(::MooncakeVJP)` / `mooncake_run_ad` can consume it unchanged.
55+
"""
56+
function get_cb_paramjac_config(
        ::MooncakeLoaded, ::MooncakeVJP, raw_affect, event_idx, y, p, _t, mode
    )
    # Only the two documented modes are supported; reject anything else before
    # any closure or buffer is built.
    if !(mode === :state || mode === :param)
        error("get_cb_paramjac_config: unknown mode $(mode); expected :state or :param")
    end
    wants_state = mode === :state
    has_idx = event_idx !== nothing

    # Flat `(out, u, p, t)` wrapper around the already-unwrapped affect. It runs
    # the affect on a `FakeIntegrator` built from copies of `u` and `p`, then
    # writes the post-affect state (:state) or the post-affect parameters
    # (:param) into `out`. `_t` as seen at construction time is baked in as the
    # fourth `FakeIntegrator` argument. The `let` block captures fixed bindings
    # so the closure does not close over reassignable locals.
    affect_wrapper = if wants_state
        let affect = raw_affect, idx = event_idx, t0 = _t, indexed = has_idx
            function (out, u, p, t)
                integ = FakeIntegrator(copy(u), copy(p), t, t0)
                indexed ? affect(integ, idx) : affect(integ)
                copyto!(out, integ.u)
                return out
            end
        end
    else
        let affect = raw_affect, idx = event_idx, t0 = _t, indexed = has_idx
            function (out, u, p, t)
                integ = FakeIntegrator(copy(u), copy(p), t, t0)
                indexed ? affect(integ, idx) : affect(integ)
                copyto!(out, integ.p)
                return out
            end
        end
    end

    # The output and cotangent buffers take the shape of whatever the wrapper
    # writes into `out`: `y` for :state, `p` for :param.
    shape_ref = wants_state ? y : p
    dy_mem = zero(shape_ref)
    λ_mem = zero(shape_ref)
    cache = Mooncake.prepare_pullback_cache(affect_wrapper, dy_mem, y, p, _t)

    # Scratch buffer only for array-like parameters that are not a plain `Array`;
    # `nothing` otherwise.
    p_grad_buf = if p isa AbstractArray && !(p isa Array)
        similar(p)
    else
        nothing
    end

    return cache, affect_wrapper, λ_mem, dy_mem, p_grad_buf
end
100+
45101
function mooncake_run_ad(paramjac_config::Tuple, y, p, t, λ)
46102
cache, pf, λ_mem, dy_mem, p_grad_buf = paramjac_config
47103
λ_mem .= λ

src/adjoint_common.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ struct ReactantVJPConfig{FK, DK, FC, DC}
555555
chunk_size::Int
556556
end
557557

558-
function get_paramjac_config(::Any, ::MooncakeVJP, pf, p, f, y, _t; out_sample = nothing)
558+
# Generic method (first argument `::Any`) for `MooncakeVJP`: it never builds a
# config and always raises an error telling the user to load Mooncake.jl.
function get_paramjac_config(::Any, ::MooncakeVJP, pf, p, f, y, _t)
    return error(
        "MooncakeVJP requires Mooncake.jl is loaded. Install the package and do " *
            "`using Mooncake` to use this functionality"
    )
end
@@ -666,6 +666,12 @@ function get_cb_paramjac_config(::Any, ::ReactantVJP, raw_affect, event_idx, y,
666666
error(msg)
667667
end
668668

669+
# Generic method (first argument `::Any`) for the callback variant with
# `MooncakeVJP`: it never builds a config and always raises an error telling the
# user to load Mooncake.jl.
function get_cb_paramjac_config(::Any, ::MooncakeVJP, raw_affect, event_idx, y, p, _t, mode)
    return error(
        "MooncakeVJP requires Mooncake.jl is loaded. Install the package and do " *
            "`using Mooncake` to use this functionality"
    )
end
674+
669675
# Return the problem attached to a sensitivity function. An
# `ODEBacksolveSensitivityFunction` carries it directly in `S.prob`; every other
# `SensitivityFunction` is read through its stored solution, `S.sol.prob`.
function getprob(S::SensitivityFunction)
    if S isa ODEBacksolveSensitivityFunction
        return S.prob
    end
    return S.sol.prob
end

src/callback_tracking.jl

Lines changed: 19 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -608,98 +608,35 @@ function get_cb_diffcaches(
608608
)
609609
elseif autojacvec isa MooncakeVJP
610610
# MooncakeVJP: build Mooncake pullback caches for the
611-
# state-affect and parameter-affect callback wrappers.
612-
# Mooncake can't trace through the recursive TrackedAffect
613-
# unwrapping in `CallbackAffectWrapper` (it trips on the
614-
# `Base.argument_datatype` ccall that surfaces during the
615-
# dispatch), so we mirror the ReactantVJP path and build a
616-
# flat closure that bakes `raw_affect` directly in. The
617-
# state-output closure has y-sized output; the parameter-
618-
# output closure has (flat) tunables-sized output.
611+
# state-affect and parameter-affect callback wrappers via
612+
# `get_cb_paramjac_config` in the Mooncake extension, mirroring
613+
# the ReactantVJP branch above. Mooncake can't trace through
614+
# the recursive TrackedAffect unwrapping in
615+
# `CallbackAffectWrapper` (it trips on the
616+
# `Base.argument_datatype` ccall surfaced by that dispatch),
617+
# so the extension hoists `raw_affect` out and bakes it into
618+
# a flat (out, u, p, t) closure before preparing the cache.
619619
raw_affect = get_affect!(cb, pos_neg)
620-
_has_event_idx = event_idx !== nothing
621-
_ev = event_idx
622-
_tprev0 = _t
623-
624-
cb_state_fn = let raw = raw_affect, ev = _ev, tprev = _tprev0,
625-
has_ev = _has_event_idx
626-
627-
(out, u, p, t) -> begin
628-
fakeinteg = FakeIntegrator(copy(u), copy(p), t, tprev)
629-
if has_ev
630-
raw(fakeinteg, ev)
631-
else
632-
raw(fakeinteg)
633-
end
634-
copyto!(out, fakeinteg.u)
635-
return out
636-
end
637-
end
638-
639-
cb_param_fn = let raw = raw_affect, ev = _ev, tprev = _tprev0,
640-
has_ev = _has_event_idx
641-
642-
(out, u, p, t) -> begin
643-
fakeinteg = FakeIntegrator(copy(u), copy(p), t, tprev)
644-
if has_ev
645-
raw(fakeinteg, ev)
646-
else
647-
raw(fakeinteg)
648-
end
649-
copyto!(out, fakeinteg.p)
650-
return out
651-
end
652-
end
653-
654-
if _p === nothing || _p isa SciMLBase.NullParameters
655-
tunables, repack = _p, identity
656-
else
657-
tunables, repack, _ = canonicalize(Tunable(), _p)
658-
end
659-
_needs_repack = !(
660-
_p === nothing ||
661-
_p isa SciMLBase.NullParameters
662-
) &&
663-
isscimlstructure(_p) && !(_p isa AbstractArray)
664-
665-
pf_w = if _needs_repack
666-
let f = cb_state_fn, repack = repack
667-
(out, u, _tunables, t) -> f(out, u, repack(_tunables), t)
668-
end
669-
else
670-
cb_state_fn
671-
end
672-
673-
pf_wp = if _needs_repack
674-
let f = cb_param_fn, repack = repack
675-
(out, u, _tunables, t) -> f(out, u, repack(_tunables), t)
676-
end
677-
else
678-
cb_param_fn
679-
end
680620

681-
paramjac_config_w = get_paramjac_config(
682-
MooncakeLoaded(), autojacvec, pf_w, tunables, cb_state_fn, y, _t
683-
)
684-
# For wp, the output is parameter-sized (flat tunables) rather
685-
# than state-sized, so pass `out_sample = tunables` to size the
686-
# Mooncake cotangent/output buffers correctly.
687-
paramjac_config_wp = get_paramjac_config(
688-
MooncakeLoaded(), autojacvec, pf_wp, tunables, cb_param_fn,
689-
y, _t; out_sample = tunables
621+
w_paramjac = get_cb_paramjac_config(
622+
MooncakeLoaded(), autojacvec, raw_affect,
623+
event_idx, y, _p, _t, :state
690624
)
691-
692625
diffcache_w = AdjointDiffCache(
693-
nothing, pf_w, nothing, nothing, nothing,
694-
nothing, nothing, nothing, paramjac_config_w,
626+
nothing, nothing, nothing, nothing, nothing,
627+
nothing, nothing, nothing, w_paramjac,
695628
nothing, nothing, nothing, nothing, nothing,
696629
nothing, nothing, nothing, false,
697630
nothing, identity
698631
)
699632

633+
wp_paramjac = get_cb_paramjac_config(
634+
MooncakeLoaded(), autojacvec, raw_affect,
635+
event_idx, y, _p, _t, :param
636+
)
700637
diffcache_wp = AdjointDiffCache(
701-
nothing, pf_wp, nothing, nothing, nothing,
702-
nothing, nothing, nothing, paramjac_config_wp,
638+
nothing, nothing, nothing, nothing, nothing,
639+
nothing, nothing, nothing, wp_paramjac,
703640
nothing, nothing, nothing, nothing, nothing,
704641
nothing, nothing, nothing, false,
705642
nothing, identity

0 commit comments

Comments
 (0)