Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion lib/RecursiveArrayToolsRaggedArrays/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

[weakdeps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"

[extensions]
RecursiveArrayToolsRaggedArraysDiffEqBaseExt = "DiffEqBase"

[compat]
Adapt = "4"
ArrayInterface = "7"
Expand All @@ -20,9 +26,10 @@ SymbolicIndexingInterface = "0.3.35"
julia = "1.10"

[extras]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["SparseArrays", "SymbolicIndexingInterface", "Test"]
test = ["DiffEqBase", "SparseArrays", "SymbolicIndexingInterface", "Test"]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is circular: these overloads need to go into DiffEqBase instead.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally fine with me. Do you (or one of your bots) create a PR in DiffEqBase.jl? Please ping me there if there is one.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah sorry Mr. Bot is a bit slow right now with all the travel and the big stuff of the recent majors.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, no worries. I removed the extension again as this will be handled in DiffEqBase.jl now. So I think this is ready from my side.

Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
module RecursiveArrayToolsRaggedArraysDiffEqBaseExt

import RecursiveArrayTools: AbstractRaggedVectorOfArray
import DiffEqBase

# Mirror the AbstractVectorOfArray dispatch in DiffEqBase so that adaptive ODE
# solvers compute the correct RMS-normalized norm instead of the unnormalized
# Euclidean norm. Without these methods, ODE_DEFAULT_NORM falls through to
# `norm(u)` = sqrt(sum_abs2), which is sqrt(n_elements) times larger than the
# intended RMS norm, making the adaptive controller target a stricter tolerance
# than requested (abstol/reltol).

function DiffEqBase.UNITLESS_ABS2(x::AbstractRaggedVectorOfArray)
return mapreduce(DiffEqBase.UNITLESS_ABS2, +, x.u;
init = zero(real(eltype(x))))
end

function DiffEqBase.recursive_length(u::AbstractRaggedVectorOfArray)
return sum(DiffEqBase.recursive_length, u.u; init = 0)
end

function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractRaggedVectorOfArray, _)
return Base.FastMath.sqrt_fast(
DiffEqBase.UNITLESS_ABS2(u) / max(DiffEqBase.recursive_length(u), 1))
end

end # module
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module RecursiveArrayToolsRaggedArrays

import RecursiveArrayTools: RecursiveArrayTools, AbstractRaggedVectorOfArray,
AbstractRaggedDiffEqArray, VectorOfArray, DiffEqArray,
AbstractVectorOfArray, AbstractDiffEqArray, AllObserved
AbstractVectorOfArray, AbstractDiffEqArray, AllObserved,
recursivefill!, recursivecopy!
using SymbolicIndexingInterface
using SymbolicIndexingInterface: ParameterTimeseriesCollection, ParameterIndexingProxy,
ScalarSymbolic, ArraySymbolic, NotSymbolic, Timeseries, SymbolCache
Expand Down Expand Up @@ -1519,6 +1520,26 @@ end

Base.map(f, A::AbstractRaggedVectorOfArray) = map(f, A.u)

# Named functor used by the nested-ragged mapreduce to ensure type-stable dispatch.
struct _RaggedMapReduce{F, Op}
f::F
op::Op
end
@inline (w::_RaggedMapReduce)(u) = mapreduce(w.f, w.op, u)

# When inner elements are themselves ragged, the view-based approach fails: view uses
# size(A.u[1]) for every column, causing BoundsErrors when inner shapes differ.
# We recurse element-by-element instead. Dispatching on the type of A.u (rather than
# using an if-check at runtime) keeps inference type-stable down to Julia 1.10.
function Base.mapreduce(
f, op,
A::AbstractRaggedVectorOfArray{T, N, <:AbstractVector{<:AbstractRaggedVectorOfArray}};
kwargs...
) where {T, N}
isempty(kwargs) || return mapreduce(f, op, view(A, ntuple(_ -> :, ndims(A))...); kwargs...)
return mapreduce(_RaggedMapReduce(f, op), op, A.u)
end

function Base.mapreduce(f, op, A::AbstractRaggedVectorOfArray; kwargs...)
return mapreduce(f, op, view(A, ntuple(_ -> :, ndims(A))...); kwargs...)
end
Expand Down Expand Up @@ -1725,4 +1746,39 @@ end
# Re-export has_discretes and get_discretes for the non-ragged types
has_discretes(::TT) where {TT <: AbstractDiffEqArray} = hasfield(TT, :discretes)

function recursivecopy!(b::AbstractRaggedVectorOfArray, a::AbstractRaggedVectorOfArray)
@inbounds for i in eachindex(b.u, a.u)
if ArrayInterface.ismutable(b.u[i]) || b.u[i] isa AbstractRaggedVectorOfArray
recursivecopy!(b.u[i], a.u[i])
else
b.u[i] = copy(a.u[i])
end
end
return b
end

function recursivefill!(
b::AbstractRaggedVectorOfArray{T, N},
a::T2
) where {T <: Union{Number, Bool}, T2 <: Union{Number, Bool}, N}
return fill!(b, a)
end

function recursivefill!(
b::AbstractRaggedVectorOfArray{T, N},
a::T2
) where {T <: StaticArraysCore.SArray, T2 <: Union{Number, Bool}, N}
@inbounds for arr in b.u, i in eachindex(arr)
arr[i] = map(_ -> a, arr[i])
end
return b
end

function recursivefill!(b::AbstractRaggedVectorOfArray{T, N}, a) where {T <: AbstractArray, N}
@inbounds for arr in b.u
recursivefill!(arr, a)
end
return b
end

end # module RecursiveArrayToolsRaggedArrays
55 changes: 55 additions & 0 deletions lib/RecursiveArrayToolsRaggedArrays/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using RecursiveArrayTools, RecursiveArrayToolsRaggedArrays
using RecursiveArrayToolsRaggedArrays: RaggedEnd, RaggedRange
using SymbolicIndexingInterface
using SymbolicIndexingInterface: SymbolCache
import DiffEqBase: ODE_DEFAULT_NORM, UNITLESS_ABS2, recursive_length
using Test

@testset "RecursiveArrayToolsRaggedArrays" begin
Expand Down Expand Up @@ -986,4 +987,58 @@ using Test
@test rd isa RecursiveArrayTools.AbstractRaggedVectorOfArray
@test !(rd isa AbstractArray)
end

@testset "recursivefill! for RaggedVectorOfArray" begin
# Bool argument — the pattern used by ODE solver cache initialisation
r = RaggedVectorOfArray([ones(2), ones(3)])
recursivefill!(r, false)
@test r[:, 1] == [0.0, 0.0]
@test r[:, 2] == [0.0, 0.0, 0.0]

# Numeric argument
r2 = RaggedVectorOfArray([zeros(2), zeros(3)])
recursivefill!(r2, 1.0)
@test r2[:, 1] == [1.0, 1.0]
@test r2[:, 2] == [1.0, 1.0, 1.0]

# Ragged sizes are preserved
@test length(r[:, 1]) == 2
@test length(r[:, 2]) == 3
end

@testset "recursivecopy! for RaggedVectorOfArray" begin
src = RaggedVectorOfArray([ones(2), 2 * ones(3)])
dst = RaggedVectorOfArray([zeros(2), zeros(3)])
recursivecopy!(dst, src)
@test dst[:, 1] == [1.0, 1.0]
@test dst[:, 2] == [2.0, 2.0, 2.0]

# Verify deep copy — modifying src must not affect dst
src[:, 1] .= 99.0
@test dst[:, 1] == [1.0, 1.0]
end

@testset "mapreduce over nested ragged arrays" begin
# Outer array whose inner RaggedVoA elements have different column counts.
# mapreduce must recurse over A.u rather than building a fixed-shape view.
inner1 = RaggedVectorOfArray([ones(3), ones(3)]) # 2 columns
inner2 = RaggedVectorOfArray([ones(3), ones(3), ones(3)]) # 3 columns — ragged!
u = RaggedVectorOfArray([inner1, inner2])

@test mapreduce(identity, +, u) == 15.0 # (2+3)*3
end

@testset "ODE_DEFAULT_NORM: RMS-normalised for RaggedVectorOfArray" begin
# Loading OrdinaryDiffEqTsit5 (which depends on DiffEqBase) triggers the weakdep
# extension, giving the correct RMS-normalised norm instead of the unnormalised
# Euclidean norm used by the generic fallback.
r = RaggedVectorOfArray([ones(3), ones(3)]) # 6 ones
@test UNITLESS_ABS2(r) ≈ 6.0
@test recursive_length(r) == 6
# RMS norm of 6 ones = sqrt(6/6) = 1
@test ODE_DEFAULT_NORM(r, 0.0) ≈ 1.0
# Unnormalised Euclidean norm would be sqrt(6) ≈ 2.449 — make sure we don't get that
@test ODE_DEFAULT_NORM(r, 0.0) < 2.0
end

end
Loading