Skip to content

Commit a22d73a

Browse files
Merge pull request #575 from ChrisRackauckas-Claude/mooncake-arraypartition-rdata
Add Mooncake extension for ArrayPartition cotangents
2 parents 3f59673 + c547a94 commit a22d73a

4 files changed

Lines changed: 90 additions & 2 deletions

File tree

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
3-
version = "4.0.1"
43
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
4+
version = "4.0.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -20,6 +20,7 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
2020
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2121
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
2222
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
23+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2324
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
2425
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
2526
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
@@ -38,6 +39,7 @@ RecursiveArrayToolsForwardDiffExt = "ForwardDiff"
3839
RecursiveArrayToolsKernelAbstractionsExt = "KernelAbstractions"
3940
RecursiveArrayToolsMeasurementsExt = "Measurements"
4041
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
42+
RecursiveArrayToolsMooncakeExt = "Mooncake"
4143
RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"]
4244
RecursiveArrayToolsSparseArraysExt = ["SparseArrays"]
4345
RecursiveArrayToolsStatisticsExt = "Statistics"
@@ -59,6 +61,7 @@ KernelAbstractions = "0.9.36"
5961
LinearAlgebra = "1.10"
6062
Measurements = "2.11"
6163
MonteCarloMeasurements = "1.2"
64+
Mooncake = "0.5"
6265
NLsolve = "4.5"
6366
Pkg = "1"
6467
Polyester = "0.7.16"
@@ -86,6 +89,7 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
8689
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8790
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
8891
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
92+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
8993
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
9094
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
9195
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -102,4 +106,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
102106
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
103107

104108
[targets]
105-
test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Polyester", "Random", "SafeTestsets", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]
109+
test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "Mooncake", "NLsolve", "Pkg", "Polyester", "Random", "SafeTestsets", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
module RecursiveArrayToolsMooncakeExt
2+
3+
using RecursiveArrayTools
4+
using Mooncake
5+
6+
# `ArrayPartition` cotangent handling for `@from_chainrules` /
7+
# `@from_rrule`-generated rules.
8+
#
9+
# When an upstream ChainRules-based adjoint (e.g. SciMLSensitivity's
10+
# `_concrete_solve_adjoint` for an ODE whose state is an `ArrayPartition`,
11+
# such as the one produced by `SecondOrderODEProblem`) returns a parameter
12+
# / state cotangent as an `ArrayPartition`, Mooncake's
13+
# `@from_chainrules`/`@from_rrule` accumulator looks for an
14+
# `increment_and_get_rdata!` method matching
15+
# `(FData{NamedTuple{(:x,), Tuple{Tuple{Vector, …}}}}, NoRData, ArrayPartition)`
16+
# — and there isn't one by default, so the call falls through to the
17+
# generic error path:
18+
#
19+
# ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float32}, Vector{Float32}}}},
20+
# rdata type Mooncake.NoRData, and tangent type
21+
# RecursiveArrayTools.ArrayPartition{Float32, Tuple{Vector{Float32}, Vector{Float32}}}
22+
# combination is not supported with @from_chainrules or @from_rrule.
23+
#
24+
# Add the missing dispatch. An `ArrayPartition`'s only field is `x::Tuple`
25+
# of inner arrays, so the FData layout is
26+
# `FData{@NamedTuple{x::Tuple{...}}}` and the inner tuple positions line up
27+
# with `t.x`. Walk the tuple element-by-element and forward each leaf to
28+
# the existing `increment_and_get_rdata!` for the leaf's array type, which
29+
# does the actual in-place accumulation.
30+
function Mooncake.increment_and_get_rdata!(
31+
f::Mooncake.FData{@NamedTuple{x::T}},
32+
r::Mooncake.NoRData,
33+
t::ArrayPartition{P, T},
34+
) where {P, T <: Tuple}
35+
fxs = f.data[:x]
36+
txs = t.x
37+
@assert length(fxs) == length(txs)
38+
for i in eachindex(fxs)
39+
Mooncake.increment_and_get_rdata!(fxs[i], Mooncake.NoRData(), txs[i])
40+
end
41+
return Mooncake.NoRData()
42+
end
43+
44+
end

test/mooncake.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using RecursiveArrayTools, Mooncake, Test
2+
3+
# Regression test for the `RecursiveArrayToolsMooncakeExt` dispatch that
4+
# lets Mooncake's `@from_chainrules`/`@from_rrule` accumulator handle an
5+
# `ArrayPartition` cotangent returned by an upstream ChainRule (e.g.
6+
# SciMLSensitivity's `_concrete_solve_adjoint` for a `SecondOrderODEProblem`).
7+
# Without the extension, the call below fell through to Mooncake's generic
8+
# error path:
9+
#
10+
# ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float64}, Vector{Float64}}}},
11+
# rdata type Mooncake.NoRData, and tangent type
12+
# RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}
13+
# combination is not supported with @from_chainrules or @from_rrule.
14+
15+
@testset "ArrayPartition increment_and_get_rdata!" begin
16+
@test Base.get_extension(RecursiveArrayTools, :RecursiveArrayToolsMooncakeExt) !==
17+
nothing
18+
19+
# Tangent produced by an upstream ChainRule.
20+
t = ArrayPartition([1.0, 2.0], [3.0, 4.0])
21+
# Pre-existing FData that the method should accumulate into in place.
22+
f = Mooncake.FData((x = ([10.0, 20.0], [30.0, 40.0]),))
23+
24+
r = Mooncake.increment_and_get_rdata!(f, Mooncake.NoRData(), t)
25+
26+
@test r === Mooncake.NoRData()
27+
@test f.data.x[1] == [11.0, 22.0]
28+
@test f.data.x[2] == [33.0, 44.0]
29+
30+
# Three-way partition with Float32 leaves — exercises the inner
31+
# per-leaf dispatch on a different eltype and arity.
32+
t32 = ArrayPartition(Float32[1, 2], Float32[3, 4, 5], Float32[6])
33+
f32 = Mooncake.FData((x = (Float32[10, 20], Float32[30, 40, 50], Float32[60]),))
34+
r32 = Mooncake.increment_and_get_rdata!(f32, Mooncake.NoRData(), t32)
35+
@test r32 === Mooncake.NoRData()
36+
@test f32.data.x[1] == Float32[11, 22]
37+
@test f32.data.x[2] == Float32[33, 44, 55]
38+
@test f32.data.x[3] == Float32[66]
39+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ end
4646
@time @safetestset "StaticArrays Tests" include("copy_static_array_test.jl")
4747
@time @safetestset "Linear Algebra Tests" include("linalg.jl")
4848
@time @safetestset "Adjoint Tests" include("adjoints.jl")
49+
@time @safetestset "Mooncake Tests" include("mooncake.jl")
4950
@time @safetestset "Measurement Tests" include("measurements.jl")
5051
end
5152

0 commit comments

Comments
 (0)