@@ -1121,7 +1121,8 @@ function SciMLBase._concrete_solve_adjoint(
11211121 J = du[i]
11221122 if Δ isa AbstractVector
11231123 v = Δ[i]
1124- elseif Δ isa AbstractTimeseriesSolution || Δ isa AbstractVectorOfArray
1124+ elseif Δ isa AbstractTimeseriesSolution || Δ isa AbstractVectorOfArray ||
1125+ Δ isa Tangent
11251126 v = Δ. u[i]
11261127 elseif Δ isa AbstractMatrix
11271128 v = @view Δ[:, i]
@@ -1149,6 +1150,69 @@ function SciMLBase._concrete_solve_adjoint(
11491150 return out, forward_sensitivity_backpass
11501151end
11511152
1153+ # Mooncake-specific `ForwardSensitivity` path. The main method builds an
1154+ # `ODEForwardSensitivityProblem` whose `f` is an
1155+ # `ODEForwardSensitivityFunction` carrying ForwardDiff internals
1156+ # (`ForwardDiff.JacobianConfig`, `Dual` caches, …) in its type parameters.
1157+ # The returned `sensitivity_solution(augmented_sol, u, ts)` inherits those
1158+ # types in `sol.prob.f`, which confuses Mooncake's `@from_rrule` tangent
1159+ # recursion the same way the tracked types in `ReverseDiffAdjoint` /
1160+ # `TrackerAdjoint` do. Delegate to the `ChainRulesOriginator` path for the
1161+ # sensitivity tape and re-solve the plain problem for the primal.
1162+ #
1163+ # A tempting alternative is to walk the returned `primal` and strip
1164+ # tracked / augmented types via `SciMLBase.value` recursively. That
1165+ # approach *almost* works but fails on one specific slot: `solve()` wraps
1166+ # `prob.f` in a `FunctionWrappersWrappers.FunctionWrappersWrapper` during
1167+ # `get_concrete_problem`, and the resulting `FunctionWrapper` has no
1168+ # public positional constructor for `ConstructionBase.setproperties`, so a
1169+ # generic walker can't rebuild it. Mooncake's `@from_rrule` type
1170+ # inference nonetheless expects the `FunctionWrapper` version (because
1171+ # that's what `solve()` returns), so anything less than actually invoking
1172+ # `solve()` produces a type mismatch on the `DerivedRule` assertion.
1173+ # A truly general `strip_values(sol)` in SciMLBase would need the same
1174+ # `solve()` round-trip internally, so the cost is unavoidable here.
1175+ #
1176+ # Additionally, the main `forward_sensitivity_backpass` returns `du0 =
1177+ # @not_implemented(...)` because `ForwardSensitivity` can't differentiate
1178+ # w.r.t. `u0`. Mooncake's `@from_rrule` plumbing then tries to convert that
1179+ # `ChainRulesCore.NotImplemented` tangent back through
1180+ # `increment_and_get_rdata!` against the `Vector{Float64}` fdata of `u0`,
1181+ # and Mooncake doesn't have a method for that combination (only scalar
1182+ # `IEEEFloat` + `NotImplemented` is handled). Since Mooncake will dutifully
1183+ # thread the cotangent of *every* argument through `increment_and_get_rdata!`
1184+ # regardless of whether the caller is actually differentiating `u0`, we
1185+ # replace the `du0` slot in the delegated ChainRules pullback with
1186+ # `NoTangent()` so the Mooncake conversion has a shape it understands. Any
1187+ # caller that genuinely differentiates `u0` while using `ForwardSensitivity`
1188+ # is already using the wrong sensealg (the main method's error message says
1189+ # as much).
1190+ function SciMLBase. _concrete_solve_adjoint (
1191+ prob:: SciMLBase.AbstractODEProblem , alg,
1192+ sensealg:: ForwardSensitivity ,
1193+ u0, p, originator:: SciMLBase.MooncakeOriginator ,
1194+ args... ; kwargs...
1195+ )
1196+ _, backpass = SciMLBase. _concrete_solve_adjoint (
1197+ prob, alg, sensealg, u0, p,
1198+ SciMLBase. ChainRulesOriginator (), args... ; kwargs...
1199+ )
1200+ # ChainRules branch of `forward_sensitivity_backpass` returns
1201+ # `(NoTangent(), NoTangent(), NoTangent(), du0, adj, NoTangent(), rest...)`.
1202+ # Replace position 4 (`du0`) with `NoTangent()`.
1203+ function mooncake_forward_sensitivity_backpass (Δ)
1204+ cr = backpass (Δ)
1205+ return (cr[1 ], cr[2 ], cr[3 ], NoTangent (), cr[5 : end ]. .. )
1206+ end
1207+ kwargs_filtered = NamedTuple (filter (x -> x[1 ] != :sensealg , kwargs))
1208+ primal = solve (
1209+ remake (prob; u0, p), alg, args... ;
1210+ sensealg = DiffEqBase. SensitivityADPassThrough (),
1211+ kwargs_filtered...
1212+ )
1213+ return primal, mooncake_forward_sensitivity_backpass
1214+ end
1215+
11521216function SciMLBase. _concrete_solve_forward (
11531217 prob:: SciMLBase.AbstractODEProblem , alg,
11541218 sensealg:: AbstractForwardSensitivityAlgorithm ,
@@ -1846,21 +1910,6 @@ function Base.showerror(io::IO, e::EnzymeTrackedRealError)
18461910 return println (io, ENZYME_TRACKED_REAL_ERROR_MESSAGE)
18471911end
18481912
1849- const MOONCAKE_TRACKED_REAL_ERROR_MESSAGE = """
1850- `Mooncake` is not compatible with `ReverseDiffAdjoint` nor with `TrackerAdjoint`.
1851- Either choose a different adjoint method like `GaussAdjoint`,
1852- or use a different AD system like `ReverseDiff`.
1853- For more details, on these methods see
1854- https://docs.sciml.ai/SciMLSensitivity/stable/.
1855- """
1856-
1857- struct MooncakeTrackedRealError <: Exception
1858- end
1859-
1860- function Base. showerror (io:: IO , e:: MooncakeTrackedRealError )
1861- return println (io, MOONCAKE_TRACKED_REAL_ERROR_MESSAGE)
1862- end
1863-
18641913function SciMLBase. _concrete_solve_adjoint (
18651914 prob:: Union {
18661915 SciMLBase. AbstractDiscreteProblem,
@@ -1881,10 +1930,6 @@ function SciMLBase._concrete_solve_adjoint(
18811930 throw (EnzymeTrackedRealError ())
18821931 end
18831932
1884- if originator isa SciMLBase. MooncakeOriginator
1885- throw (MooncakeTrackedRealError ())
1886- end
1887-
18881933 if ! (p === nothing || p isa SciMLBase. NullParameters)
18891934 if ! isscimlstructure (p)
18901935 throw (SciMLStructuresCompatibilityError ())
@@ -2093,6 +2138,41 @@ function SciMLBase._concrete_solve_adjoint(
20932138 tracker_adjoint_backpass
20942139end
20952140
2141+ # Mooncake-specific `TrackerAdjoint` path. Same reasoning as the
2142+ # `ReverseDiffAdjoint` + `MooncakeOriginator` method below: the main method
2143+ # returns `sensitivity_solution(tracked_sol, …)` with `Tracker.TrackedReal` /
2144+ # `TrackedArray` type parameters embedded in `tracked_sol.interp` / `.prob`
2145+ # / `.alg`, and Mooncake's `@from_rrule` plumbing chokes when recursively
2146+ # computing `tangent_type` on those fields. Delegate the tape to the
2147+ # `ChainRulesOriginator` path and re-solve with `SensitivityADPassThrough`
2148+ # for the primal.
2149+ function SciMLBase. _concrete_solve_adjoint (
2150+ prob:: Union {
2151+ SciMLBase. AbstractDiscreteProblem,
2152+ SciMLBase. AbstractODEProblem,
2153+ SciMLBase. AbstractDAEProblem,
2154+ SciMLBase. AbstractDDEProblem,
2155+ SciMLBase. AbstractSDEProblem,
2156+ SciMLBase. AbstractSDDEProblem,
2157+ SciMLBase. AbstractRODEProblem,
2158+ },
2159+ alg, sensealg:: TrackerAdjoint ,
2160+ u0, p, originator:: SciMLBase.MooncakeOriginator ,
2161+ args... ; kwargs...
2162+ )
2163+ _, backpass = SciMLBase. _concrete_solve_adjoint (
2164+ prob, alg, sensealg, u0, p,
2165+ SciMLBase. ChainRulesOriginator (), args... ; kwargs...
2166+ )
2167+ kwargs_filtered = NamedTuple (filter (x -> x[1 ] != :sensealg , kwargs))
2168+ primal = solve (
2169+ remake (prob; u0, p), alg, args... ;
2170+ sensealg = DiffEqBase. SensitivityADPassThrough (),
2171+ kwargs_filtered...
2172+ )
2173+ return primal, backpass
2174+ end
2175+
20962176const REVERSEDIFF_ADJOINT_GPU_COMPATIBILITY_MESSAGE = """
20972177ReverseDiffAdjoint is not compatible GPU-based array types. Use a different
20982178sensitivity analysis method, like InterpolatingAdjoint or TrackerAdjoint,
@@ -2148,10 +2228,6 @@ function SciMLBase._concrete_solve_adjoint(
21482228 throw (EnzymeTrackedRealError ())
21492229 end
21502230
2151- if originator isa SciMLBase. MooncakeOriginator
2152- throw (MooncakeTrackedRealError ())
2153- end
2154-
21552231 t = eltype (prob. tspan)[]
21562232 u = typeof (u0)[]
21572233
@@ -2277,6 +2353,66 @@ function SciMLBase._concrete_solve_adjoint(
22772353 reversediff_adjoint_backpass
22782354end
22792355
2356+ # Mooncake-specific `ReverseDiffAdjoint` path. The main `ReverseDiffAdjoint`
2357+ # method above returns `SciMLBase.sensitivity_solution(sol, …)` where `sol`
2358+ # still carries `ReverseDiff.TrackedReal` / `TrackedArray` type parameters in
2359+ # nested fields (`interp`, `prob`, `alg`, …). ChainRules / Zygote don't
2360+ # inspect the primal's type parameters, so they don't care. Mooncake's
2361+ # `@from_rrule` plumbing, on the other hand, calls
2362+ # `zero_tangent(y_primal)` and therefore recursively computes `tangent_type`
2363+ # for every nested field of the returned solution; that recursion fails on
2364+ # the tracked type parameters with either a `TypeError` or an unhelpful
2365+ # tangent-type error, which is what the `hybrid_diffeq` tutorial and
2366+ # PR #1419 ran into.
2367+ #
2368+ # This method delegates the tape construction (and hence the whole backward
2369+ # pass) to the `ChainRulesOriginator` path, then replaces the primal with
2370+ # a fresh plain-arithmetic solve of the same problem. The obvious
2371+ # alternative — walking the returned primal and stripping tracked scalars
2372+ # via `SciMLBase.value` recursively — *almost* works but fails on one
2373+ # specific slot: `solve()` wraps `prob.f` in a
2374+ # `FunctionWrappersWrappers.FunctionWrappersWrapper` during
2375+ # `get_concrete_problem`, and the resulting `FunctionWrapper` has no
2376+ # public positional constructor for `ConstructionBase.setproperties`, so a
2377+ # generic walker can't rebuild it. Mooncake's `@from_rrule` type
2378+ # inference nonetheless expects the `FunctionWrapper` version (because
2379+ # that's what `solve()` normally returns from a plain-arithmetic solve),
2380+ # so anything short of actually invoking `solve()` produces a type
2381+ # mismatch on the `DerivedRule` assertion. A truly general
2382+ # `strip_values(sol)` in SciMLBase would need the same `solve()`
2383+ # round-trip internally, so the cost is unavoidable here.
2384+ #
2385+ # Keeping this in a dedicated method dispatched on `MooncakeOriginator`
2386+ # stops Julia type inference from joining two different return shapes
2387+ # into a `Union{ODESolution{tracked…}, ODESolution{plain…}}`, which would
2388+ # otherwise trip Mooncake's `DerivedRule` type assertion.
2389+ function SciMLBase. _concrete_solve_adjoint (
2390+ prob:: Union {
2391+ SciMLBase. AbstractDiscreteProblem,
2392+ SciMLBase. AbstractODEProblem,
2393+ SciMLBase. AbstractDAEProblem,
2394+ SciMLBase. AbstractDDEProblem,
2395+ SciMLBase. AbstractSDEProblem,
2396+ SciMLBase. AbstractSDDEProblem,
2397+ SciMLBase. AbstractRODEProblem,
2398+ },
2399+ alg, sensealg:: ReverseDiffAdjoint ,
2400+ u0, p, originator:: SciMLBase.MooncakeOriginator ,
2401+ args... ; kwargs...
2402+ )
2403+ _, backpass = SciMLBase. _concrete_solve_adjoint (
2404+ prob, alg, sensealg, u0, p,
2405+ SciMLBase. ChainRulesOriginator (), args... ; kwargs...
2406+ )
2407+ kwargs_filtered = NamedTuple (filter (x -> x[1 ] != :sensealg , kwargs))
2408+ primal = solve (
2409+ remake (prob; u0, p), alg, args... ;
2410+ sensealg = DiffEqBase. SensitivityADPassThrough (),
2411+ kwargs_filtered...
2412+ )
2413+ return primal, backpass
2414+ end
2415+
22802416function SciMLBase. _concrete_solve_adjoint (
22812417 prob:: SciMLBase.AbstractODEProblem , alg,
22822418 sensealg:: AbstractShadowingSensitivityAlgorithm ,
0 commit comments