Commit 7a5e90d
committed
Use supports_callback_vjp trait and preserve FakeIntegrator types
Replaces the hardcoded `Union{ReverseDiffVJP, EnzymeVJP, ReactantVJP,
MooncakeVJP}` checks in `_setup_reverse_callbacks` and `setup_w_wp` with
a `supports_callback_vjp(autojacvec)` trait defined alongside
`supports_structured_vjp` in sensitivity_algorithms.jl. Future VJP
backends opt in by adding a single method instead of threading a new
type through two Union annotations.
Also fixes `get_FakeIntegrator(::MooncakeVJP, ...)` to pass `u` and `p`
through unchanged (matching the EnzymeVJP / ReactantVJP methods) instead
of using `[x for x in u]` / `[x for x in p]`. The comprehension is only
needed by the ReverseDiffVJP path to hand ReverseDiff a fresh tracked
container; Mooncake's pullback cache is fed by the dedicated
`cb_state_fn` / `cb_param_fn` closures in `get_cb_diffcaches`, so this
dispatch only fires at the plain `w(y, y, integrator.p, integrator.t)`
call site where preserving the original container type (e.g.
ComponentArray, MVector) matters.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>1 parent ae32854 commit 7a5e90d
2 files changed
Lines changed: 27 additions & 11 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
283 | 283 | | |
284 | 284 | | |
285 | 285 | | |
286 | | - | |
287 | | - | |
288 | | - | |
289 | | - | |
290 | | - | |
291 | | - | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
292 | 292 | | |
293 | 293 | | |
294 | 294 | | |
| |||
507 | 507 | | |
508 | 508 | | |
509 | 509 | | |
510 | | - | |
511 | | - | |
| 510 | + | |
512 | 511 | | |
| 512 | + | |
| 513 | + | |
513 | 514 | | |
514 | 515 | | |
515 | 516 | | |
| |||
520 | 521 | | |
521 | 522 | | |
522 | 523 | | |
523 | | - | |
524 | | - | |
525 | | - | |
| 524 | + | |
526 | 525 | | |
527 | 526 | | |
528 | 527 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1696 | 1696 | | |
1697 | 1697 | | |
1698 | 1698 | | |
| 1699 | + | |
| 1700 | + | |
| 1701 | + | |
| 1702 | + | |
| 1703 | + | |
| 1704 | + | |
| 1705 | + | |
| 1706 | + | |
| 1707 | + | |
| 1708 | + | |
| 1709 | + | |
| 1710 | + | |
| 1711 | + | |
| 1712 | + | |
| 1713 | + | |
| 1714 | + | |
| 1715 | + | |
1699 | 1716 | | |
1700 | 1717 | | |
1701 | 1718 | | |
| |||
0 commit comments