Skip to content

Commit 7a5e90d

Browse files
Use supports_callback_vjp trait and preserve FakeIntegrator types
Replaces the hardcoded `Union{ReverseDiffVJP, EnzymeVJP, ReactantVJP, MooncakeVJP}` checks in `_setup_reverse_callbacks` and `setup_w_wp` with a `supports_callback_vjp(autojacvec)` trait defined alongside `supports_structured_vjp` in sensitivity_algorithms.jl. Future VJP backends opt in by adding a single method instead of threading a new type through two Union annotations. Also fixes `get_FakeIntegrator(::MooncakeVJP, ...)` to pass `u` and `p` through unchanged (matching the EnzymeVJP / ReactantVJP methods) instead of using `[x for x in u]` / `[x for x in p]`. The comprehension is only needed by the ReverseDiffVJP path to hand ReverseDiff a fresh tracked container; Mooncake's pullback cache is fed by the dedicated `cb_state_fn` / `cb_param_fn` closures in `get_cb_diffcaches`, so this dispatch only fires at the plain `w(y, y, integrator.p, integrator.t)` call site where preserving the original container type (e.g. ComponentArray, MVector) matters. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent ae32854 commit 7a5e90d

2 files changed

Lines changed: 27 additions & 11 deletions

File tree

src/callback_tracking.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -283,12 +283,12 @@ function _setup_reverse_callbacks(
283283
# if save_positions = [1,0] the gradient contribution is added before, and in principle we would need to correct the adjoint state again. Therefore,
284284

285285
cb.save_positions == [1, 0] && error("save_positions=[1,0] is currently not supported.")
286-
# Callbacks require ReverseDiffVJP, EnzymeVJP, ReactantVJP, or MooncakeVJP for
287-
# their own VJP computations. The ODE adjoint may use a different autojacvec
288-
# (even numerical/false), but the callback affect functions (CallbackAffectWrapper)
289-
# are separate and typically work with ReverseDiff even when the ODE function doesn't.
290-
cb_autojacvec = if sensealg.autojacvec isa
291-
Union{ReverseDiffVJP, EnzymeVJP, ReactantVJP, MooncakeVJP}
286+
# Callback affect functions (CallbackAffectWrapper) are traced separately
287+
# from the ODE function, so the ODE adjoint may use a different autojacvec
288+
# (even numerical/false) while the callback path uses its own compatible
289+
# backend. The `supports_callback_vjp` trait marks which VJP backends have
290+
# a dedicated callback path; anything else falls back to ReverseDiffVJP.
291+
cb_autojacvec = if supports_callback_vjp(sensealg.autojacvec)
292292
sensealg.autojacvec
293293
else
294294
@warn "autojacvec=$(sensealg.autojacvec) is not compatible with callbacks, using ReverseDiffVJP() for callback VJPs"
@@ -507,9 +507,10 @@ end
507507

508508
function setup_w_wp(
509509
cb::Union{DiscreteCallback, ContinuousCallback, VectorContinuousCallback},
510-
autojacvec::Union{ReverseDiffVJP, EnzymeVJP, ReactantVJP, MooncakeVJP}, pos_neg,
511-
event_idx, tprev
510+
autojacvec, pos_neg, event_idx, tprev
512511
)
512+
supports_callback_vjp(autojacvec) ||
513+
error("setup_w_wp called with a VJP backend that does not support callbacks: $(autojacvec). This is an internal error — the callback path should have redirected to a compatible backend in `_setup_reverse_callbacks`.")
513514
w = CallbackAffectWrapper(cb, autojacvec, pos_neg, event_idx, tprev)
514515
wp = CallbackAffectPWrapper(cb, autojacvec, pos_neg, event_idx, tprev)
515516
return w, wp
@@ -520,9 +521,7 @@ function get_FakeIntegrator(autojacvec::ReverseDiffVJP, u, p, t, tprev)
520521
end
521522
get_FakeIntegrator(autojacvec::EnzymeVJP, u, p, t, tprev) = FakeIntegrator(u, p, t, tprev)
522523
get_FakeIntegrator(autojacvec::ReactantVJP, u, p, t, tprev) = FakeIntegrator(u, p, t, tprev)
523-
function get_FakeIntegrator(autojacvec::MooncakeVJP, u, p, t, tprev)
524-
return FakeIntegrator([x for x in u], [x for x in p], t, tprev)
525-
end
524+
get_FakeIntegrator(autojacvec::MooncakeVJP, u, p, t, tprev) = FakeIntegrator(u, p, t, tprev)
526525

527526
function _get_wp_paramjac_config(autojacvec::EnzymeVJP, _p, wp, y, __p, _t)
528527
return (zero(y), zero(_p), zero(_p), zero(_p), zero(y))

src/sensitivity_algorithms.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,6 +1696,23 @@ supports_structured_vjp(::ReverseDiffVJP) = false
16961696
supports_structured_vjp(::Bool) = false
16971697
supports_structured_vjp(::Nothing) = false
16981698

1699+
"""
1700+
supports_callback_vjp(autojacvec) -> Bool
1701+
1702+
Return `true` if the VJP backend can differentiate the affect functions of
1703+
tracked callbacks (i.e. has a dedicated branch in `get_cb_diffcaches` and a
1704+
`_vecjacobian!` path that works with `CallbackSensitivityFunction`).
1705+
1706+
When `false`, the callback adjoint code falls back to `ReverseDiffVJP()` for
1707+
its own VJP computations — the ODE adjoint may still use the user-requested
1708+
backend, since the callback affect functions are traced separately.
1709+
"""
1710+
supports_callback_vjp(::ReverseDiffVJP) = true
1711+
supports_callback_vjp(::EnzymeVJP) = true
1712+
supports_callback_vjp(::ReactantVJP) = true
1713+
supports_callback_vjp(::MooncakeVJP) = true
1714+
supports_callback_vjp(::Any) = false
1715+
16991716
"""
17001717
```julia
17011718
ForwardDiffOverAdjoint{A} <: AbstractSecondOrderSensitivityAlgorithm{nothing, true, nothing}

0 commit comments

Comments
 (0)