Skip to content

Commit c0a4647

Browse files
Merge pull request #1420 from ChrisRackauckas-Claude/fix-mooncake-reversediffadjoint
Enable Mooncake + ReverseDiffAdjoint nested-AD path
2 parents 1ab9331 + 4c868c1 commit c0a4647

2 files changed

Lines changed: 184 additions & 24 deletions

File tree

src/concrete_solve.jl

Lines changed: 160 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,7 +1121,8 @@ function SciMLBase._concrete_solve_adjoint(
11211121
J = du[i]
11221122
if Δ isa AbstractVector
11231123
v = Δ[i]
1124-
elseif Δ isa AbstractTimeseriesSolution || Δ isa AbstractVectorOfArray
1124+
elseif Δ isa AbstractTimeseriesSolution || Δ isa AbstractVectorOfArray ||
1125+
Δ isa Tangent
11251126
v = Δ.u[i]
11261127
elseif Δ isa AbstractMatrix
11271128
v = @view Δ[:, i]
@@ -1149,6 +1150,69 @@ function SciMLBase._concrete_solve_adjoint(
11491150
return out, forward_sensitivity_backpass
11501151
end
11511152

1153+
# Mooncake-specific `ForwardSensitivity` path. The main method builds an
1154+
# `ODEForwardSensitivityProblem` whose `f` is an
1155+
# `ODEForwardSensitivityFunction` carrying ForwardDiff internals
1156+
# (`ForwardDiff.JacobianConfig`, `Dual` caches, …) in its type parameters.
1157+
# The returned `sensitivity_solution(augmented_sol, u, ts)` inherits those
1158+
# types in `sol.prob.f`, which confuses Mooncake's `@from_rrule` tangent
1159+
# recursion the same way the tracked types in `ReverseDiffAdjoint` /
1160+
# `TrackerAdjoint` do. Delegate to the `ChainRulesOriginator` path for the
1161+
# sensitivity tape and re-solve the plain problem for the primal.
1162+
#
1163+
# A tempting alternative is to walk the returned `primal` and strip
1164+
# tracked / augmented types via `SciMLBase.value` recursively. That
1165+
# approach *almost* works but fails on one specific slot: `solve()` wraps
1166+
# `prob.f` in a `FunctionWrappersWrappers.FunctionWrappersWrapper` during
1167+
# `get_concrete_problem`, and the resulting `FunctionWrapper` has no
1168+
# public positional constructor for `ConstructionBase.setproperties`, so a
1169+
# generic walker can't rebuild it. Mooncake's `@from_rrule` type
1170+
# inference nonetheless expects the `FunctionWrapper` version (because
1171+
# that's what `solve()` returns), so anything less than actually invoking
1172+
# `solve()` produces a type mismatch on the `DerivedRule` assertion.
1173+
# A truly general `strip_values(sol)` in SciMLBase would need the same
1174+
# `solve()` round-trip internally, so the cost is unavoidable here.
1175+
#
1176+
# Additionally, the main `forward_sensitivity_backpass` returns `du0 =
1177+
# @not_implemented(...)` because `ForwardSensitivity` can't differentiate
1178+
# w.r.t. `u0`. Mooncake's `@from_rrule` plumbing then tries to convert that
1179+
# `ChainRulesCore.NotImplemented` tangent back through
1180+
# `increment_and_get_rdata!` against the `Vector{Float64}` fdata of `u0`,
1181+
# and Mooncake doesn't have a method for that combination (only scalar
1182+
# `IEEEFloat` + `NotImplemented` is handled). Since Mooncake will dutifully
1183+
# thread the cotangent of *every* argument through `increment_and_get_rdata!`
1184+
# regardless of whether the caller is actually differentiating `u0`, we
1185+
# replace the `du0` slot in the delegated ChainRules pullback with
1186+
# `NoTangent()` so the Mooncake conversion has a shape it understands. Any
1187+
# caller that genuinely differentiates `u0` while using `ForwardSensitivity`
1188+
# is already using the wrong sensealg (the main method's error message says
1189+
# as much).
1190+
# `ForwardSensitivity` entry point used when the call originates from Mooncake.
# The pullback is borrowed from the `ChainRulesOriginator` method (with the `u0`
# cotangent slot overwritten by `NoTangent()`), while the primal comes from a
# separate solve run with `SensitivityADPassThrough`.
function SciMLBase._concrete_solve_adjoint(
        prob::SciMLBase.AbstractODEProblem, alg,
        sensealg::ForwardSensitivity,
        u0, p, originator::SciMLBase.MooncakeOriginator,
        args...; kwargs...
    )
    # Run the ChainRules path only for its pullback; its primal is discarded.
    _, chainrules_pullback = SciMLBase._concrete_solve_adjoint(
        prob, alg, sensealg, u0, p,
        SciMLBase.ChainRulesOriginator(), args...; kwargs...
    )

    # Drop any caller-supplied `sensealg` so that the pass-through value given
    # below is the only one forwarded to `solve`.
    without_sensealg = filter(kv -> first(kv) != :sensealg, kwargs)
    solve_kwargs = NamedTuple(without_sensealg)

    # Solve the problem again, with the given `u0` and `p`, to obtain the
    # solution object that is handed back as the primal.
    plain_prob = remake(prob; u0 = u0, p = p)
    primal = solve(
        plain_prob, alg, args...;
        sensealg = DiffEqBase.SensitivityADPassThrough(),
        solve_kwargs...
    )

    # The ChainRules pullback yields
    # `(NoTangent(), NoTangent(), NoTangent(), du0, adj, NoTangent(), rest...)`.
    # Entry 4 (`du0`) is swapped for `NoTangent()`; all other entries are
    # forwarded untouched.
    function mooncake_forward_sensitivity_backpass(Δ)
        grads = backpass_result = chainrules_pullback(Δ)
        return (
            grads[1], grads[2], grads[3], NoTangent(),
            backpass_result[5:end]...,
        )
    end

    return primal, mooncake_forward_sensitivity_backpass
end
1215+
11521216
function SciMLBase._concrete_solve_forward(
11531217
prob::SciMLBase.AbstractODEProblem, alg,
11541218
sensealg::AbstractForwardSensitivityAlgorithm,
@@ -1846,21 +1910,6 @@ function Base.showerror(io::IO, e::EnzymeTrackedRealError)
18461910
return println(io, ENZYME_TRACKED_REAL_ERROR_MESSAGE)
18471911
end
18481912

1849-
const MOONCAKE_TRACKED_REAL_ERROR_MESSAGE = """
1850-
`Mooncake` is not compatible with `ReverseDiffAdjoint` nor with `TrackerAdjoint`.
1851-
Either choose a different adjoint method like `GaussAdjoint`,
1852-
or use a different AD system like `ReverseDiff`.
1853-
For more details, on these methods see
1854-
https://docs.sciml.ai/SciMLSensitivity/stable/.
1855-
"""
1856-
1857-
struct MooncakeTrackedRealError <: Exception
1858-
end
1859-
1860-
function Base.showerror(io::IO, e::MooncakeTrackedRealError)
1861-
return println(io, MOONCAKE_TRACKED_REAL_ERROR_MESSAGE)
1862-
end
1863-
18641913
function SciMLBase._concrete_solve_adjoint(
18651914
prob::Union{
18661915
SciMLBase.AbstractDiscreteProblem,
@@ -1881,10 +1930,6 @@ function SciMLBase._concrete_solve_adjoint(
18811930
throw(EnzymeTrackedRealError())
18821931
end
18831932

1884-
if originator isa SciMLBase.MooncakeOriginator
1885-
throw(MooncakeTrackedRealError())
1886-
end
1887-
18881933
if !(p === nothing || p isa SciMLBase.NullParameters)
18891934
if !isscimlstructure(p)
18901935
throw(SciMLStructuresCompatibilityError())
@@ -2093,6 +2138,41 @@ function SciMLBase._concrete_solve_adjoint(
20932138
tracker_adjoint_backpass
20942139
end
20952140

2141+
# Mooncake-specific `TrackerAdjoint` path. Same reasoning as the
2142+
# `ReverseDiffAdjoint` + `MooncakeOriginator` method below: the main method
2143+
# returns `sensitivity_solution(tracked_sol, …)` with `Tracker.TrackedReal` /
2144+
# `TrackedArray` type parameters embedded in `tracked_sol.interp` / `.prob`
2145+
# / `.alg`, and Mooncake's `@from_rrule` plumbing chokes when recursively
2146+
# computing `tangent_type` on those fields. Delegate the tape to the
2147+
# `ChainRulesOriginator` path and re-solve with `SensitivityADPassThrough`
2148+
# for the primal.
2149+
# `TrackerAdjoint` entry point used when the call originates from Mooncake.
# The pullback is taken unchanged from the `ChainRulesOriginator` method, and
# the primal is produced by a separate solve run with
# `SensitivityADPassThrough`.
function SciMLBase._concrete_solve_adjoint(
        prob::Union{
            SciMLBase.AbstractDiscreteProblem,
            SciMLBase.AbstractODEProblem,
            SciMLBase.AbstractDAEProblem,
            SciMLBase.AbstractDDEProblem,
            SciMLBase.AbstractSDEProblem,
            SciMLBase.AbstractSDDEProblem,
            SciMLBase.AbstractRODEProblem,
        },
        alg, sensealg::TrackerAdjoint,
        u0, p, originator::SciMLBase.MooncakeOriginator,
        args...; kwargs...
    )
    # Run the ChainRules path only for its pullback; its primal is discarded.
    _, chainrules_pullback = SciMLBase._concrete_solve_adjoint(
        prob, alg, sensealg, u0, p,
        SciMLBase.ChainRulesOriginator(), args...; kwargs...
    )

    # Drop any caller-supplied `sensealg` so that the pass-through value given
    # below is the only one forwarded to `solve`.
    without_sensealg = filter(kv -> first(kv) != :sensealg, kwargs)
    solve_kwargs = NamedTuple(without_sensealg)

    # Solve the problem again, with the given `u0` and `p`, to obtain the
    # solution object that is handed back as the primal.
    plain_prob = remake(prob; u0 = u0, p = p)
    primal = solve(
        plain_prob, alg, args...;
        sensealg = DiffEqBase.SensitivityADPassThrough(),
        solve_kwargs...
    )

    return primal, chainrules_pullback
end
2175+
20962176
const REVERSEDIFF_ADJOINT_GPU_COMPATIBILITY_MESSAGE = """
20972177
ReverseDiffAdjoint is not compatible GPU-based array types. Use a different
20982178
sensitivity analysis method, like InterpolatingAdjoint or TrackerAdjoint,
@@ -2148,10 +2228,6 @@ function SciMLBase._concrete_solve_adjoint(
21482228
throw(EnzymeTrackedRealError())
21492229
end
21502230

2151-
if originator isa SciMLBase.MooncakeOriginator
2152-
throw(MooncakeTrackedRealError())
2153-
end
2154-
21552231
t = eltype(prob.tspan)[]
21562232
u = typeof(u0)[]
21572233

@@ -2277,6 +2353,66 @@ function SciMLBase._concrete_solve_adjoint(
22772353
reversediff_adjoint_backpass
22782354
end
22792355

2356+
# Mooncake-specific `ReverseDiffAdjoint` path. The main `ReverseDiffAdjoint`
2357+
# method above returns `SciMLBase.sensitivity_solution(sol, …)` where `sol`
2358+
# still carries `ReverseDiff.TrackedReal` / `TrackedArray` type parameters in
2359+
# nested fields (`interp`, `prob`, `alg`, …). ChainRules / Zygote don't
2360+
# inspect the primal's type parameters, so they don't care. Mooncake's
2361+
# `@from_rrule` plumbing, on the other hand, calls
2362+
# `zero_tangent(y_primal)` and therefore recursively computes `tangent_type`
2363+
# for every nested field of the returned solution; that recursion fails on
2364+
# the tracked type parameters with either a `TypeError` or an unhelpful
2365+
# tangent-type error, which is what the `hybrid_diffeq` tutorial and
2366+
# PR #1419 ran into.
2367+
#
2368+
# This method delegates the tape construction (and hence the whole backward
2369+
# pass) to the `ChainRulesOriginator` path, then replaces the primal with
2370+
# a fresh plain-arithmetic solve of the same problem. The obvious
2371+
# alternative — walking the returned primal and stripping tracked scalars
2372+
# via `SciMLBase.value` recursively — *almost* works but fails on one
2373+
# specific slot: `solve()` wraps `prob.f` in a
2374+
# `FunctionWrappersWrappers.FunctionWrappersWrapper` during
2375+
# `get_concrete_problem`, and the resulting `FunctionWrapper` has no
2376+
# public positional constructor for `ConstructionBase.setproperties`, so a
2377+
# generic walker can't rebuild it. Mooncake's `@from_rrule` type
2378+
# inference nonetheless expects the `FunctionWrapper` version (because
2379+
# that's what `solve()` normally returns from a plain-arithmetic solve),
2380+
# so anything short of actually invoking `solve()` produces a type
2381+
# mismatch on the `DerivedRule` assertion. A truly general
2382+
# `strip_values(sol)` in SciMLBase would need the same `solve()`
2383+
# round-trip internally, so the cost is unavoidable here.
2384+
#
2385+
# Keeping this in a dedicated method dispatched on `MooncakeOriginator`
2386+
# stops Julia type inference from joining two different return shapes
2387+
# into a `Union{ODESolution{tracked…}, ODESolution{plain…}}`, which would
2388+
# otherwise trip Mooncake's `DerivedRule` type assertion.
2389+
# `ReverseDiffAdjoint` entry point used when the call originates from Mooncake.
# The pullback is taken unchanged from the `ChainRulesOriginator` method, and
# the primal is produced by a separate solve run with
# `SensitivityADPassThrough`.
function SciMLBase._concrete_solve_adjoint(
        prob::Union{
            SciMLBase.AbstractDiscreteProblem,
            SciMLBase.AbstractODEProblem,
            SciMLBase.AbstractDAEProblem,
            SciMLBase.AbstractDDEProblem,
            SciMLBase.AbstractSDEProblem,
            SciMLBase.AbstractSDDEProblem,
            SciMLBase.AbstractRODEProblem,
        },
        alg, sensealg::ReverseDiffAdjoint,
        u0, p, originator::SciMLBase.MooncakeOriginator,
        args...; kwargs...
    )
    # Run the ChainRules path only for its pullback; its primal is discarded.
    _, chainrules_pullback = SciMLBase._concrete_solve_adjoint(
        prob, alg, sensealg, u0, p,
        SciMLBase.ChainRulesOriginator(), args...; kwargs...
    )

    # Drop any caller-supplied `sensealg` so that the pass-through value given
    # below is the only one forwarded to `solve`.
    without_sensealg = filter(kv -> first(kv) != :sensealg, kwargs)
    solve_kwargs = NamedTuple(without_sensealg)

    # Solve the problem again, with the given `u0` and `p`, to obtain the
    # solution object that is handed back as the primal.
    plain_prob = remake(prob; u0 = u0, p = p)
    primal = solve(
        plain_prob, alg, args...;
        sensealg = DiffEqBase.SensitivityADPassThrough(),
        solve_kwargs...
    )

    return primal, chainrules_pullback
end
2415+
22802416
function SciMLBase._concrete_solve_adjoint(
22812417
prob::SciMLBase.AbstractODEProblem, alg,
22822418
sensealg::AbstractShadowingSensitivityAlgorithm,

test/concrete_solve_derivatives.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,20 @@ Tests callable structs with different AD backends
442442
end
443443
end
444444

445+
# Mooncake is not in `REVERSE_BACKENDS` because it doesn't yet compose
446+
# with every sensealg, but it does compose with `ReverseDiffAdjoint` and
447+
# `TrackerAdjoint` via the dedicated `MooncakeOriginator` dispatches
448+
# added in #1420 (the hybrid_diffeq.md pattern from #1419).
449+
@testset "Mooncake with ReverseDiffAdjoint" begin
450+
result = gradient_mooncake(senseloss(ReverseDiffAdjoint()), u0p)
451+
@test result ≈ ref_grad_senseloss
452+
end
453+
454+
@testset "Mooncake with TrackerAdjoint" begin
455+
result = gradient_mooncake(senseloss(TrackerAdjoint()), u0p)
456+
@test result ref_grad_senseloss
457+
end
458+
445459
# Test with p-only differentiation (senseloss3 and senseloss4 from alternative_ad_frontend.jl)
446460
struct senseloss_p{T}
447461
sense::T
@@ -470,6 +484,16 @@ Tests callable structs with different AD backends
470484
@test result ref_grad_p
471485
end
472486
end
487+
488+
# Mooncake + `ForwardSensitivity` via the dedicated `MooncakeOriginator`
489+
# dispatch added in #1420. p-only because `ForwardSensitivity` can't
490+
# differentiate `u0`, and the Mooncake dispatch rewrites the `du0`
491+
# slot to `NoTangent()` so Mooncake's cotangent threading doesn't trip
492+
# on the main method's `@not_implemented` stub for `du0`.
493+
@testset "Mooncake with ForwardSensitivity (p-only)" begin
494+
result = gradient_mooncake(senseloss_p(ForwardSensitivity()), p_only)
495+
@test result ≈ ref_grad_p
496+
end
473497
end
474498

475499
#=

0 commit comments

Comments
 (0)