Skip to content

Commit ae32854

Browse files
Add MooncakeVJP to callback-compatible VJP methods
Previously only ReverseDiffVJP, EnzymeVJP, and ReactantVJP were accepted for differentiating callback affect functions; passing MooncakeVJP as the sensealg autojacvec silently fell back to ReverseDiffVJP for callback VJPs. This change adds MooncakeVJP to the compatible Union and wires up a Mooncake-specific branch in get_cb_diffcaches that builds pullback caches for both the state-affect and parameter-affect callback wrappers. The Mooncake path mirrors the ReactantVJP extension's approach: it hoists `raw_affect` out of the pullback-traced code (Mooncake cannot trace through the recursive TrackedAffect unwrapping in CallbackAffectWrapper because it trips on the `Base.argument_datatype` ccall that surfaces during dispatch) and builds flat (out, u, p, t) closures that bake the extracted affect in directly. Two caches are built per event: one with a state-sized output (for w) and one with a parameter-sized output (for wp); the existing `get_paramjac_config(::MooncakeLoaded, ...)` entry point gains an `out_sample` keyword so callers can choose the output buffer shape without duplicating the extension method. Verified against ForwardDiff on the existing Re-compile tape testset and on the full test/callbacks/continuous_callbacks.jl file (164 tests pass). Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 1ab9331 commit ae32854

4 files changed

Lines changed: 131 additions & 12 deletions

File tree

ext/SciMLSensitivityMooncakeExt.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,16 @@ function _init_originator_gradient(
2323
return igs
2424
end
2525

26-
function get_paramjac_config(::MooncakeLoaded, ::MooncakeVJP, pf, p, f, y, _t)
27-
dy_mem = zero(y)
28-
λ_mem = zero(y)
26+
function get_paramjac_config(
27+
::MooncakeLoaded, ::MooncakeVJP, pf, p, f, y, _t; out_sample = nothing
28+
)
29+
# `out_sample` lets callers size the output/cotangent buffers for functions
30+
# whose output is not state-sized (e.g. the parameter-affect callback wrapper
31+
# whose output has the shape of the flat tunables). Defaults to `y` which is
32+
# correct for the usual ODE/state-output case.
33+
_out_sample = out_sample === nothing ? y : out_sample
34+
dy_mem = zero(_out_sample)
35+
λ_mem = zero(_out_sample)
2936
cache = Mooncake.prepare_pullback_cache(pf, dy_mem, y, p, _t)
3037
# Pre-allocate buffer for tangent_to_primal!! conversion of struct-based
3138
# array types (e.g. ComponentArray) whose Mooncake tangent is Mooncake.Tangent.

src/adjoint_common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ struct ReactantVJPConfig{FK, DK, FC, DC}
555555
chunk_size::Int
556556
end
557557

558-
function get_paramjac_config(::Any, ::MooncakeVJP, pf, p, f, y, _t)
558+
function get_paramjac_config(::Any, ::MooncakeVJP, pf, p, f, y, _t; out_sample = nothing)
559559
msg = "MooncakeVJP requires Mooncake.jl is loaded. Install the package and do " * "`using Mooncake` to use this functionality"
560560
error(msg)
561561
end

src/callback_tracking.jl

Lines changed: 109 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -283,11 +283,12 @@ function _setup_reverse_callbacks(
283283
# if save_positions = [1,0] the gradient contribution is added before, and in principle we would need to correct the adjoint state again. Therefore,
284284

285285
cb.save_positions == [1, 0] && error("save_positions=[1,0] is currently not supported.")
286-
# Callbacks require ReverseDiffVJP or EnzymeVJP for their own VJP computations.
287-
# The ODE adjoint may use a different autojacvec (even numerical/false), but the
288-
# callback affect functions (CallbackAffectWrapper) are separate and typically work
289-
# with ReverseDiff even when the ODE function doesn't.
290-
cb_autojacvec = if sensealg.autojacvec isa Union{ReverseDiffVJP, EnzymeVJP, ReactantVJP}
286+
# Callbacks require ReverseDiffVJP, EnzymeVJP, ReactantVJP, or MooncakeVJP for
287+
# their own VJP computations. The ODE adjoint may use a different autojacvec
288+
# (even numerical/false), but the callback affect functions (CallbackAffectWrapper)
289+
# are separate and typically work with ReverseDiff even when the ODE function doesn't.
290+
cb_autojacvec = if sensealg.autojacvec isa
291+
Union{ReverseDiffVJP, EnzymeVJP, ReactantVJP, MooncakeVJP}
291292
sensealg.autojacvec
292293
else
293294
@warn "autojacvec=$(sensealg.autojacvec) is not compatible with callbacks, using ReverseDiffVJP() for callback VJPs"
@@ -506,8 +507,8 @@ end
506507

507508
function setup_w_wp(
508509
cb::Union{DiscreteCallback, ContinuousCallback, VectorContinuousCallback},
509-
autojacvec::Union{ReverseDiffVJP, EnzymeVJP, ReactantVJP}, pos_neg, event_idx,
510-
tprev
510+
autojacvec::Union{ReverseDiffVJP, EnzymeVJP, ReactantVJP, MooncakeVJP}, pos_neg,
511+
event_idx, tprev
511512
)
512513
w = CallbackAffectWrapper(cb, autojacvec, pos_neg, event_idx, tprev)
513514
wp = CallbackAffectPWrapper(cb, autojacvec, pos_neg, event_idx, tprev)
@@ -519,6 +520,9 @@ function get_FakeIntegrator(autojacvec::ReverseDiffVJP, u, p, t, tprev)
519520
end
520521
get_FakeIntegrator(autojacvec::EnzymeVJP, u, p, t, tprev) = FakeIntegrator(u, p, t, tprev)
521522
get_FakeIntegrator(autojacvec::ReactantVJP, u, p, t, tprev) = FakeIntegrator(u, p, t, tprev)
523+
function get_FakeIntegrator(autojacvec::MooncakeVJP, u, p, t, tprev)
524+
return FakeIntegrator([x for x in u], [x for x in p], t, tprev)
525+
end
522526

523527
function _get_wp_paramjac_config(autojacvec::EnzymeVJP, _p, wp, y, __p, _t)
524528
return (zero(y), zero(_p), zero(_p), zero(_p), zero(y))
@@ -603,6 +607,104 @@ function get_cb_diffcaches(
603607
nothing, nothing, nothing, false,
604608
nothing, identity
605609
)
610+
elseif autojacvec isa MooncakeVJP
611+
# MooncakeVJP: build Mooncake pullback caches for the
612+
# state-affect and parameter-affect callback wrappers.
613+
# Mooncake can't trace through the recursive TrackedAffect
614+
# unwrapping in `CallbackAffectWrapper` (it trips on the
615+
# `Base.argument_datatype` ccall that surfaces during the
616+
# dispatch), so we mirror the ReactantVJP path and build a
617+
# flat closure that bakes `raw_affect` directly in. The
618+
# state-output closure has y-sized output; the parameter-
619+
# output closure has (flat) tunables-sized output.
620+
raw_affect = get_affect!(cb, pos_neg)
621+
_has_event_idx = event_idx !== nothing
622+
_ev = event_idx
623+
_tprev0 = _t
624+
625+
cb_state_fn = let raw = raw_affect, ev = _ev, tprev = _tprev0,
626+
has_ev = _has_event_idx
627+
628+
(out, u, p, t) -> begin
629+
fakeinteg = FakeIntegrator(copy(u), copy(p), t, tprev)
630+
if has_ev
631+
raw(fakeinteg, ev)
632+
else
633+
raw(fakeinteg)
634+
end
635+
copyto!(out, fakeinteg.u)
636+
return out
637+
end
638+
end
639+
640+
cb_param_fn = let raw = raw_affect, ev = _ev, tprev = _tprev0,
641+
has_ev = _has_event_idx
642+
643+
(out, u, p, t) -> begin
644+
fakeinteg = FakeIntegrator(copy(u), copy(p), t, tprev)
645+
if has_ev
646+
raw(fakeinteg, ev)
647+
else
648+
raw(fakeinteg)
649+
end
650+
copyto!(out, fakeinteg.p)
651+
return out
652+
end
653+
end
654+
655+
if _p === nothing || _p isa SciMLBase.NullParameters
656+
tunables, repack = _p, identity
657+
else
658+
tunables, repack, _ = canonicalize(Tunable(), _p)
659+
end
660+
_needs_repack = !(
661+
_p === nothing ||
662+
_p isa SciMLBase.NullParameters
663+
) &&
664+
isscimlstructure(_p) && !(_p isa AbstractArray)
665+
666+
pf_w = if _needs_repack
667+
let f = cb_state_fn, repack = repack
668+
(out, u, _tunables, t) -> f(out, u, repack(_tunables), t)
669+
end
670+
else
671+
cb_state_fn
672+
end
673+
674+
pf_wp = if _needs_repack
675+
let f = cb_param_fn, repack = repack
676+
(out, u, _tunables, t) -> f(out, u, repack(_tunables), t)
677+
end
678+
else
679+
cb_param_fn
680+
end
681+
682+
paramjac_config_w = get_paramjac_config(
683+
MooncakeLoaded(), autojacvec, pf_w, tunables, cb_state_fn, y, _t
684+
)
685+
# For wp, the output is parameter-sized (flat tunables) rather
686+
# than state-sized, so pass `out_sample = tunables` to size the
687+
# Mooncake cotangent/output buffers correctly.
688+
paramjac_config_wp = get_paramjac_config(
689+
MooncakeLoaded(), autojacvec, pf_wp, tunables, cb_param_fn,
690+
y, _t; out_sample = tunables
691+
)
692+
693+
diffcache_w = AdjointDiffCache(
694+
nothing, pf_w, nothing, nothing, nothing,
695+
nothing, nothing, nothing, paramjac_config_w,
696+
nothing, nothing, nothing, nothing, nothing,
697+
nothing, nothing, nothing, false,
698+
nothing, identity
699+
)
700+
701+
diffcache_wp = AdjointDiffCache(
702+
nothing, pf_wp, nothing, nothing, nothing,
703+
nothing, nothing, nothing, paramjac_config_wp,
704+
nothing, nothing, nothing, nothing, nothing,
705+
nothing, nothing, nothing, false,
706+
nothing, identity
707+
)
606708
else
607709
w, wp = setup_w_wp(cb, autojacvec, pos_neg, event_idx, _t)
608710

test/callbacks/continuous_callbacks.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
using OrdinaryDiffEq, Zygote, Reactant
1+
using OrdinaryDiffEq, Zygote, Reactant, Mooncake
22
using SciMLSensitivity, Test, ForwardDiff, FiniteDiff
3+
using SciMLSensitivity: MooncakeVJP
34

45
abstol = 1.0e-12
56
reltol = 1.0e-12
@@ -362,5 +363,14 @@ println("Continuous Callbacks")
362363
sensealg = GaussAdjoint(autojacvec = ReactantVJP(allow_scalar = true))
363364
gZy = Zygote.gradient(p -> loss(p, cb, sensealg), p)[1]
364365
@test gFD ≈ gZy rtol = 1.0e-10
366+
367+
# MooncakeVJP with callbacks
368+
sensealg = InterpolatingAdjoint(autojacvec = MooncakeVJP())
369+
gZy = Zygote.gradient(p -> loss(p, cb, sensealg), p)[1]
370+
@test gFD ≈ gZy rtol = 1.0e-10
371+
372+
sensealg = GaussAdjoint(autojacvec = MooncakeVJP())
373+
gZy = Zygote.gradient(p -> loss(p, cb, sensealg), p)[1]
374+
@test gFD ≈ gZy rtol = 1.0e-10
365375
end
366376
end

0 commit comments

Comments
 (0)