Commit 4c868c1

Document why the Mooncake dispatch re-solves instead of stripping in place
The previous commit attempted to replace the second `solve()` call in each Mooncake dispatch with a recursive `SciMLBase.value`-based walker (via `ConstructionBase.setproperties`) that would rebuild the returned `ODESolution` with plain type parameters.

The walker does successfully strip `ReverseDiff.TrackedReal` / `Tracker.TrackedReal` / `Dual` types from nested fields (`u`, `t`, `k`, `interp.timeseries`, `interp.ts`, `interp.ks`, cache scratch arrays, …) and produces a solution on which `Mooncake.tangent_type` succeeds. It does not, however, satisfy Mooncake's `DerivedRule` type assertion: `solve()` wraps `prob.f` in a `FunctionWrappersWrappers.FunctionWrappersWrapper` during `get_concrete_problem`, and the resulting `FunctionWrapper` has no public positional constructor for `ConstructionBase.setproperties`, so a generic walker cannot rebuild it. Mooncake's inference on `CRC.rrule(solve_up, …)` nonetheless expects the `FunctionWrapper`-wrapped `ODEFunction` (because that's what a plain-arithmetic `solve()` normally returns), so anything short of actually invoking `solve()` (including substituting `sol.prob` with `remake(prob; u0, p)`, which keeps the raw `typeof(f)` instead of wrapping) produces a `TypeError` mismatch on the `ODEFunction{…, FunctionWrapper{…}, …}` slot.

This commit therefore keeps the re-solve approach, and adds a comment next to each Mooncake dispatch explaining why the walker alternative was tried and abandoned. A truly general `strip_values(sol)` helper in SciMLBase would need the same `solve()` round-trip internally, so the one-extra-solve cost is unavoidable in this rrule.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 3469ee8 commit 4c868c1
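
The type mismatch described in the commit message can be illustrated with a toy model. This is only a sketch under stated assumptions: `Wrapped`, `Sol`, `Tracked`, `plain_solve`, and `strip_tracked` are hypothetical stand-ins invented for illustration and are not part of SciMLBase, Mooncake, or FunctionWrappers. It shows that a walker which converts tracked values to plain ones but keeps the raw function produces a value whose concrete type differs from what a plain solve would return, so a type assertion on the plain-solve type rejects it.

```julia
# Hypothetical stand-in for a FunctionWrapper-wrapped callable.
struct Wrapped{F}
    f::F
end
(w::Wrapped)(x) = w.f(x)

# Hypothetical stand-in for a solution type parameterised on value and function types.
struct Sol{T, F}
    u::Vector{T}
    f::F
end

# Hypothetical stand-in for a tracked scalar.
struct Tracked
    value::Float64
end

rhs(x) = 2x

# In this toy, a plain solve wraps the function, mimicking what the
# commit message says `solve()` does during `get_concrete_problem`.
plain_solve() = Sol([1.0, 2.0], Wrapped(rhs))

# A walker that strips tracked values but cannot rebuild the wrapper,
# so it carries the raw function through unchanged.
strip_tracked(sol::Sol{Tracked}) = Sol([t.value for t in sol.u], sol.f)

tracked  = Sol([Tracked(1.0), Tracked(2.0)], rhs)
stripped = strip_tracked(tracked)

# The type a compiled rule would assert on, taken from the plain solve.
Expected = typeof(plain_solve())

# The values are plain, but the function slot has the wrong type.
mismatch = !(stripped isa Expected)
```

In this toy the only way to obtain a value of type `Expected` is to call `plain_solve()` itself, which mirrors the re-solve that the commit keeps.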

1 file changed

Lines changed: 32 additions & 8 deletions

src/concrete_solve.jl

@@ -1160,6 +1160,19 @@ end
 # `TrackerAdjoint` do. Delegate to the `ChainRulesOriginator` path for the
 # sensitivity tape and re-solve the plain problem for the primal.
 #
+# A tempting alternative is to walk the returned `primal` and strip
+# tracked / augmented types via `SciMLBase.value` recursively. That
+# approach *almost* works but fails on one specific slot: `solve()` wraps
+# `prob.f` in a `FunctionWrappersWrappers.FunctionWrappersWrapper` during
+# `get_concrete_problem`, and the resulting `FunctionWrapper` has no
+# public positional constructor for `ConstructionBase.setproperties`, so a
+# generic walker can't rebuild it. Mooncake's `@from_rrule` type
+# inference nonetheless expects the `FunctionWrapper` version (because
+# that's what `solve()` returns), so anything less than actually invoking
+# `solve()` produces a type mismatch on the `DerivedRule` assertion.
+# A truly general `strip_values(sol)` in SciMLBase would need the same
+# `solve()` round-trip internally, so the cost is unavoidable here.
+#
 # Additionally, the main `forward_sensitivity_backpass` returns `du0 =
 # @not_implemented(...)` because `ForwardSensitivity` can't differentiate
 # w.r.t. `u0`. Mooncake's `@from_rrule` plumbing then tries to convert that
@@ -2354,14 +2367,25 @@ end
 #
 # This method delegates the tape construction (and hence the whole backward
 # pass) to the `ChainRulesOriginator` path, then replaces the primal with
-# a fresh plain-arithmetic solve of the same problem. This keeps the
-# outward-facing return type identical to what the non-sensitivity solve
-# would return (i.e. `InterpolationData` and `DEStats`, not the
-# `LinearInterpolation` / `Nothing` shape `build_solution` would produce),
-# which is important because Mooncake's `DerivedRule` specialises on the
-# inferred return type of the underlying `solve_up` call — that inference
-# does not narrow through the `originator` kwarg, so the compiled rule
-# expects the *main* method's return shape even on the Mooncake dispatch.
+# a fresh plain-arithmetic solve of the same problem. The obvious
+# alternative — walking the returned primal and stripping tracked scalars
+# via `SciMLBase.value` recursively — *almost* works but fails on one
+# specific slot: `solve()` wraps `prob.f` in a
+# `FunctionWrappersWrappers.FunctionWrappersWrapper` during
+# `get_concrete_problem`, and the resulting `FunctionWrapper` has no
+# public positional constructor for `ConstructionBase.setproperties`, so a
+# generic walker can't rebuild it. Mooncake's `@from_rrule` type
+# inference nonetheless expects the `FunctionWrapper` version (because
+# that's what `solve()` normally returns from a plain-arithmetic solve),
+# so anything short of actually invoking `solve()` produces a type
+# mismatch on the `DerivedRule` assertion. A truly general
+# `strip_values(sol)` in SciMLBase would need the same `solve()`
+# round-trip internally, so the cost is unavoidable here.
+#
+# Keeping this in a dedicated method dispatched on `MooncakeOriginator`
+# stops Julia type inference from joining two different return shapes
+# into a `Union{ODESolution{tracked…}, ODESolution{plain…}}`, which would
+# otherwise trip Mooncake's `DerivedRule` type assertion.
 function SciMLBase._concrete_solve_adjoint(
 prob::Union{
 SciMLBase.AbstractDiscreteProblem,
