diff --git a/lib/RecursiveArrayToolsRaggedArrays/src/RecursiveArrayToolsRaggedArrays.jl b/lib/RecursiveArrayToolsRaggedArrays/src/RecursiveArrayToolsRaggedArrays.jl index 009cd044..c5bf97e4 100644 --- a/lib/RecursiveArrayToolsRaggedArrays/src/RecursiveArrayToolsRaggedArrays.jl +++ b/lib/RecursiveArrayToolsRaggedArrays/src/RecursiveArrayToolsRaggedArrays.jl @@ -2,7 +2,8 @@ module RecursiveArrayToolsRaggedArrays import RecursiveArrayTools: RecursiveArrayTools, AbstractRaggedVectorOfArray, AbstractRaggedDiffEqArray, VectorOfArray, DiffEqArray, - AbstractVectorOfArray, AbstractDiffEqArray, AllObserved + AbstractVectorOfArray, AbstractDiffEqArray, AllObserved, + recursivefill!, recursivecopy! using SymbolicIndexingInterface using SymbolicIndexingInterface: ParameterTimeseriesCollection, ParameterIndexingProxy, ScalarSymbolic, ArraySymbolic, NotSymbolic, Timeseries, SymbolCache @@ -1519,6 +1520,26 @@ end Base.map(f, A::AbstractRaggedVectorOfArray) = map(f, A.u) +# Named functor used by the nested-ragged mapreduce to ensure type-stable dispatch. +struct _RaggedMapReduce{F, Op} + f::F + op::Op +end +@inline (w::_RaggedMapReduce)(u) = mapreduce(w.f, w.op, u) + +# When inner elements are themselves ragged, the view-based approach fails: view uses +# size(A.u[1]) for every column, causing BoundsErrors when inner shapes differ. +# We recurse element-by-element instead. Dispatching on the type of A.u (rather than +# using an if-check at runtime) keeps inference type-stable down to Julia 1.10. +function Base.mapreduce( + f, op, + A::AbstractRaggedVectorOfArray{T, N, <:AbstractVector{<:AbstractRaggedVectorOfArray}}; + kwargs... + ) where {T, N} + isempty(kwargs) || return mapreduce(f, op, view(A, ntuple(_ -> :, ndims(A))...); kwargs...) + return mapreduce(_RaggedMapReduce(f, op), op, A.u) +end + function Base.mapreduce(f, op, A::AbstractRaggedVectorOfArray; kwargs...) return mapreduce(f, op, view(A, ntuple(_ -> :, ndims(A))...); kwargs...) end @@ -1725,4 +1746,39 @@ end # Re-export has_discretes and get_discretes for the non-ragged types has_discretes(::TT) where {TT <: AbstractDiffEqArray} = hasfield(TT, :discretes) +function recursivecopy!(b::AbstractRaggedVectorOfArray, a::AbstractRaggedVectorOfArray) + @inbounds for i in eachindex(b.u, a.u) + if ArrayInterface.ismutable(b.u[i]) || b.u[i] isa AbstractRaggedVectorOfArray + recursivecopy!(b.u[i], a.u[i]) + else + b.u[i] = copy(a.u[i]) + end + end + return b +end + +function recursivefill!( + b::AbstractRaggedVectorOfArray{T, N}, + a::T2 + ) where {T <: Union{Number, Bool}, T2 <: Union{Number, Bool}, N} + return fill!(b, a) +end + +function recursivefill!( + b::AbstractRaggedVectorOfArray{T, N}, + a::T2 + ) where {T <: StaticArraysCore.SArray, T2 <: Union{Number, Bool}, N} + @inbounds for arr in b.u, i in eachindex(arr) + arr[i] = map(_ -> a, arr[i]) + end + return b +end + +function recursivefill!(b::AbstractRaggedVectorOfArray{T, N}, a) where {T <: AbstractArray, N} + @inbounds for arr in b.u + recursivefill!(arr, a) + end + return b +end + end # module RecursiveArrayToolsRaggedArrays diff --git a/lib/RecursiveArrayToolsRaggedArrays/test/runtests.jl b/lib/RecursiveArrayToolsRaggedArrays/test/runtests.jl index b3489c41..923cc63f 100644 --- a/lib/RecursiveArrayToolsRaggedArrays/test/runtests.jl +++ b/lib/RecursiveArrayToolsRaggedArrays/test/runtests.jl @@ -986,4 +986,45 @@ using Test @test rd isa RecursiveArrayTools.AbstractRaggedVectorOfArray @test !(rd isa AbstractArray) end + + @testset "recursivefill! for RaggedVectorOfArray" begin + # Bool argument — the pattern used by ODE solver cache initialisation + r = RaggedVectorOfArray([ones(2), ones(3)]) + recursivefill!(r, false) + @test r[:, 1] == [0.0, 0.0] + @test r[:, 2] == [0.0, 0.0, 0.0] + + # Numeric argument + r2 = RaggedVectorOfArray([zeros(2), zeros(3)]) + recursivefill!(r2, 1.0) + @test r2[:, 1] == [1.0, 1.0] + @test r2[:, 2] == [1.0, 1.0, 1.0] + + # Ragged sizes are preserved + @test length(r[:, 1]) == 2 + @test length(r[:, 2]) == 3 + end + + @testset "recursivecopy! for RaggedVectorOfArray" begin + src = RaggedVectorOfArray([ones(2), 2 * ones(3)]) + dst = RaggedVectorOfArray([zeros(2), zeros(3)]) + recursivecopy!(dst, src) + @test dst[:, 1] == [1.0, 1.0] + @test dst[:, 2] == [2.0, 2.0, 2.0] + + # Verify deep copy — modifying src must not affect dst + src[:, 1] .= 99.0 + @test dst[:, 1] == [1.0, 1.0] + end + + @testset "mapreduce over nested ragged arrays" begin + # Outer array whose inner RaggedVoA elements have different column counts. + # mapreduce must recurse over A.u rather than building a fixed-shape view. + inner1 = RaggedVectorOfArray([ones(3), ones(3)]) # 2 columns + inner2 = RaggedVectorOfArray([ones(3), ones(3), ones(3)]) # 3 columns — ragged! + u = RaggedVectorOfArray([inner1, inner2]) + + @test mapreduce(identity, +, u) == 15.0 # (2+3)*3 + end + end