Skip to content

Commit c888aa7

Browse files
Merge pull request #582 from JoshuaLampert/recursivefill-raggedarrays
Add `recursivefill!` and `recursivecopy!` for ragged arrays
2 parents 45eb808 + 546b1ab commit c888aa7

2 files changed

Lines changed: 98 additions & 1 deletion

File tree

lib/RecursiveArrayToolsRaggedArrays/src/RecursiveArrayToolsRaggedArrays.jl

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ module RecursiveArrayToolsRaggedArrays
22

33
import RecursiveArrayTools: RecursiveArrayTools, AbstractRaggedVectorOfArray,
44
AbstractRaggedDiffEqArray, VectorOfArray, DiffEqArray,
5-
AbstractVectorOfArray, AbstractDiffEqArray, AllObserved
5+
AbstractVectorOfArray, AbstractDiffEqArray, AllObserved,
6+
recursivefill!, recursivecopy!
67
using SymbolicIndexingInterface
78
using SymbolicIndexingInterface: ParameterTimeseriesCollection, ParameterIndexingProxy,
89
ScalarSymbolic, ArraySymbolic, NotSymbolic, Timeseries, SymbolCache
@@ -1519,6 +1520,26 @@ end
15191520

15201521
Base.map(f, A::AbstractRaggedVectorOfArray) = map(f, A.u)
15211522

1523+
# Named functor used by the nested-ragged mapreduce to ensure type-stable dispatch.
1524+
struct _RaggedMapReduce{F, Op}
1525+
f::F
1526+
op::Op
1527+
end
1528+
@inline (w::_RaggedMapReduce)(u) = mapreduce(w.f, w.op, u)
1529+
1530+
# When inner elements are themselves ragged, the view-based approach fails: view uses
1531+
# size(A.u[1]) for every column, causing BoundsErrors when inner shapes differ.
1532+
# We recurse element-by-element instead. Dispatching on the type of A.u (rather than
1533+
# using an if-check at runtime) keeps inference type-stable down to Julia 1.10.
1534+
function Base.mapreduce(
1535+
f, op,
1536+
A::AbstractRaggedVectorOfArray{T, N, <:AbstractVector{<:AbstractRaggedVectorOfArray}};
1537+
kwargs...
1538+
) where {T, N}
1539+
isempty(kwargs) || return mapreduce(f, op, view(A, ntuple(_ -> :, ndims(A))...); kwargs...)
1540+
return mapreduce(_RaggedMapReduce(f, op), op, A.u)
1541+
end
1542+
15221543
function Base.mapreduce(f, op, A::AbstractRaggedVectorOfArray; kwargs...)
15231544
return mapreduce(f, op, view(A, ntuple(_ -> :, ndims(A))...); kwargs...)
15241545
end
@@ -1725,4 +1746,39 @@ end
17251746
# Re-export has_discretes and get_discretes for the non-ragged types
17261747
has_discretes(::TT) where {TT <: AbstractDiffEqArray} = hasfield(TT, :discretes)
17271748

1749+
function recursivecopy!(b::AbstractRaggedVectorOfArray, a::AbstractRaggedVectorOfArray)
1750+
@inbounds for i in eachindex(b.u, a.u)
1751+
if ArrayInterface.ismutable(b.u[i]) || b.u[i] isa AbstractRaggedVectorOfArray
1752+
recursivecopy!(b.u[i], a.u[i])
1753+
else
1754+
b.u[i] = copy(a.u[i])
1755+
end
1756+
end
1757+
return b
1758+
end
1759+
1760+
function recursivefill!(
1761+
b::AbstractRaggedVectorOfArray{T, N},
1762+
a::T2
1763+
) where {T <: Union{Number, Bool}, T2 <: Union{Number, Bool}, N}
1764+
return fill!(b, a)
1765+
end
1766+
1767+
function recursivefill!(
1768+
b::AbstractRaggedVectorOfArray{T, N},
1769+
a::T2
1770+
) where {T <: StaticArraysCore.SArray, T2 <: Union{Number, Bool}, N}
1771+
@inbounds for arr in b.u, i in eachindex(arr)
1772+
arr[i] = map(_ -> a, arr[i])
1773+
end
1774+
return b
1775+
end
1776+
1777+
function recursivefill!(b::AbstractRaggedVectorOfArray{T, N}, a) where {T <: AbstractArray, N}
1778+
@inbounds for arr in b.u
1779+
recursivefill!(arr, a)
1780+
end
1781+
return b
1782+
end
1783+
17281784
end # module RecursiveArrayToolsRaggedArrays

lib/RecursiveArrayToolsRaggedArrays/test/runtests.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,4 +986,45 @@ using Test
986986
@test rd isa RecursiveArrayTools.AbstractRaggedVectorOfArray
987987
@test !(rd isa AbstractArray)
988988
end
989+
990+
@testset "recursivefill! for RaggedVectorOfArray" begin
991+
# Bool argument — the pattern used by ODE solver cache initialisation
992+
r = RaggedVectorOfArray([ones(2), ones(3)])
993+
recursivefill!(r, false)
994+
@test r[:, 1] == [0.0, 0.0]
995+
@test r[:, 2] == [0.0, 0.0, 0.0]
996+
997+
# Numeric argument
998+
r2 = RaggedVectorOfArray([zeros(2), zeros(3)])
999+
recursivefill!(r2, 1.0)
1000+
@test r2[:, 1] == [1.0, 1.0]
1001+
@test r2[:, 2] == [1.0, 1.0, 1.0]
1002+
1003+
# Ragged sizes are preserved
1004+
@test length(r[:, 1]) == 2
1005+
@test length(r[:, 2]) == 3
1006+
end
1007+
1008+
@testset "recursivecopy! for RaggedVectorOfArray" begin
1009+
src = RaggedVectorOfArray([ones(2), 2 * ones(3)])
1010+
dst = RaggedVectorOfArray([zeros(2), zeros(3)])
1011+
recursivecopy!(dst, src)
1012+
@test dst[:, 1] == [1.0, 1.0]
1013+
@test dst[:, 2] == [2.0, 2.0, 2.0]
1014+
1015+
# Verify deep copy — modifying src must not affect dst
1016+
src[:, 1] .= 99.0
1017+
@test dst[:, 1] == [1.0, 1.0]
1018+
end
1019+
1020+
@testset "mapreduce over nested ragged arrays" begin
1021+
# Outer array whose inner RaggedVoA elements have different column counts.
1022+
# mapreduce must recurse over A.u rather than building a fixed-shape view.
1023+
inner1 = RaggedVectorOfArray([ones(3), ones(3)]) # 2 columns
1024+
inner2 = RaggedVectorOfArray([ones(3), ones(3), ones(3)]) # 3 columns — ragged!
1025+
u = RaggedVectorOfArray([inner1, inner2])
1026+
1027+
@test mapreduce(identity, +, u) == 15.0 # (2+3)*3
1028+
end
1029+
9891030
end

0 commit comments

Comments
 (0)