@@ -283,11 +283,12 @@ function _setup_reverse_callbacks(
283283 # if save_positions = [1,0] the gradient contribution is added before, and in principle we would need to correct the adjoint state again. Therefore,
284284
285285 cb. save_positions == [1 , 0 ] && error (" save_positions=[1,0] is currently not supported." )
286- # Callbacks require ReverseDiffVJP or EnzymeVJP for their own VJP computations.
287- # The ODE adjoint may use a different autojacvec (even numerical/false), but the
288- # callback affect functions (CallbackAffectWrapper) are separate and typically work
289- # with ReverseDiff even when the ODE function doesn't.
290- cb_autojacvec = if sensealg. autojacvec isa Union{ReverseDiffVJP, EnzymeVJP, ReactantVJP}
286+ # Callback affect functions (CallbackAffectWrapper) are traced separately
287+ # from the ODE function, so the ODE adjoint may use a different autojacvec
288+ # (even numerical/false) while the callback path uses its own compatible
289+ # backend. The `supports_callback_vjp` trait marks which VJP backends have
290+ # a dedicated callback path; anything else falls back to ReverseDiffVJP.
291+ cb_autojacvec = if supports_callback_vjp (sensealg. autojacvec)
291292 sensealg. autojacvec
292293 else
293294 @warn " autojacvec=$(sensealg. autojacvec) is not compatible with callbacks, using ReverseDiffVJP() for callback VJPs"
506507
507508function setup_w_wp (
508509 cb:: Union{DiscreteCallback, ContinuousCallback, VectorContinuousCallback} ,
509- autojacvec:: Union{ReverseDiffVJP, EnzymeVJP, ReactantVJP} , pos_neg, event_idx,
510- tprev
510+ autojacvec, pos_neg, event_idx, tprev
511511 )
512+ supports_callback_vjp (autojacvec) ||
513+ error (" setup_w_wp called with a VJP backend that does not support callbacks: $(autojacvec) . This is an internal error — the callback path should have redirected to a compatible backend in `_setup_reverse_callbacks`." )
512514 w = CallbackAffectWrapper (cb, autojacvec, pos_neg, event_idx, tprev)
513515 wp = CallbackAffectPWrapper (cb, autojacvec, pos_neg, event_idx, tprev)
514516 return w, wp
@@ -519,6 +521,7 @@ function get_FakeIntegrator(autojacvec::ReverseDiffVJP, u, p, t, tprev)
519521end
520522get_FakeIntegrator (autojacvec:: EnzymeVJP , u, p, t, tprev) = FakeIntegrator (u, p, t, tprev)
521523get_FakeIntegrator (autojacvec:: ReactantVJP , u, p, t, tprev) = FakeIntegrator (u, p, t, tprev)
524+ get_FakeIntegrator (autojacvec:: MooncakeVJP , u, p, t, tprev) = FakeIntegrator (u, p, t, tprev)
522525
523526function _get_wp_paramjac_config (autojacvec:: EnzymeVJP , _p, wp, y, __p, _t)
524527 return (zero (y), zero (_p), zero (_p), zero (_p), zero (y))
@@ -603,6 +606,41 @@ function get_cb_diffcaches(
603606 nothing , nothing , nothing , false ,
604607 nothing , identity
605608 )
609+ elseif autojacvec isa MooncakeVJP
610+ # MooncakeVJP: build Mooncake pullback caches for the
611+ # state-affect and parameter-affect callback wrappers via
612+ # `get_cb_paramjac_config` in the Mooncake extension, mirroring
613+ # the ReactantVJP branch above. Mooncake can't trace through
614+ # the recursive TrackedAffect unwrapping in
615+ # `CallbackAffectWrapper` (it trips on the
616+ # `Base.argument_datatype` ccall surfaced by that dispatch),
617+ # so the extension hoists `raw_affect` out and bakes it into
618+ # a flat (out, u, p, t) closure before preparing the cache.
619+ raw_affect = get_affect! (cb, pos_neg)
620+
621+ w_paramjac = get_cb_paramjac_config (
622+ MooncakeLoaded (), autojacvec, raw_affect,
623+ event_idx, y, _p, _t, :state
624+ )
625+ diffcache_w = AdjointDiffCache (
626+ nothing , nothing , nothing , nothing , nothing ,
627+ nothing , nothing , nothing , w_paramjac,
628+ nothing , nothing , nothing , nothing , nothing ,
629+ nothing , nothing , nothing , false ,
630+ nothing , identity
631+ )
632+
633+ wp_paramjac = get_cb_paramjac_config (
634+ MooncakeLoaded (), autojacvec, raw_affect,
635+ event_idx, y, _p, _t, :param
636+ )
637+ diffcache_wp = AdjointDiffCache (
638+ nothing , nothing , nothing , nothing , nothing ,
639+ nothing , nothing , nothing , wp_paramjac,
640+ nothing , nothing , nothing , nothing , nothing ,
641+ nothing , nothing , nothing , false ,
642+ nothing , identity
643+ )
606644 else
607645 w, wp = setup_w_wp (cb, autojacvec, pos_neg, event_idx, _t)
608646
0 commit comments