Skip to content

Commit e1e0c07

Browse files
committed
add SciMLBase extension
1 parent 08d3450 commit e1e0c07

3 files changed

Lines changed: 49 additions & 1 deletion

File tree

lib/RecursiveArrayToolsRaggedArrays/Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1010
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1111
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1212

13+
[weakdeps]
14+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
15+
16+
[extensions]
17+
RecursiveArrayToolsRaggedArraysDiffEqBaseExt = "DiffEqBase"
18+
1319
[compat]
1420
Adapt = "4"
1521
ArrayInterface = "7"
@@ -20,9 +26,10 @@ SymbolicIndexingInterface = "0.3.35"
2026
julia = "1.10"
2127

2228
[extras]
29+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
2330
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2431
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
2532
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2633

2734
[targets]
28-
test = ["SparseArrays", "SymbolicIndexingInterface", "Test"]
35+
test = ["DiffEqBase", "SparseArrays", "SymbolicIndexingInterface", "Test"]
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
module RecursiveArrayToolsRaggedArraysDiffEqBaseExt
2+
3+
import RecursiveArrayTools: AbstractRaggedVectorOfArray
4+
import DiffEqBase
5+
6+
# Mirror the AbstractVectorOfArray dispatch in DiffEqBase so that adaptive ODE
7+
# solvers compute the correct RMS-normalized norm instead of the unnormalized
8+
# Euclidean norm. Without these methods, ODE_DEFAULT_NORM falls through to
9+
# `norm(u)` = sqrt(sum_abs2), which is sqrt(n_elements) times larger than the
10+
# intended RMS norm, making the adaptive controller target a stricter tolerance
11+
# than requested (abstol/reltol).
12+
13+
function DiffEqBase.UNITLESS_ABS2(x::AbstractRaggedVectorOfArray)
14+
return mapreduce(DiffEqBase.UNITLESS_ABS2, +, x.u;
15+
init = zero(real(eltype(x))))
16+
end
17+
18+
function DiffEqBase.recursive_length(u::AbstractRaggedVectorOfArray)
19+
return sum(DiffEqBase.recursive_length, u.u; init = 0)
20+
end
21+
22+
function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractRaggedVectorOfArray, _)
23+
return Base.FastMath.sqrt_fast(
24+
DiffEqBase.UNITLESS_ABS2(u) / max(DiffEqBase.recursive_length(u), 1))
25+
end
26+
27+
end # module

lib/RecursiveArrayToolsRaggedArrays/test/runtests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using RecursiveArrayTools, RecursiveArrayToolsRaggedArrays
22
using RecursiveArrayToolsRaggedArrays: RaggedEnd, RaggedRange
33
using SymbolicIndexingInterface
44
using SymbolicIndexingInterface: SymbolCache
5+
import DiffEqBase: ODE_DEFAULT_NORM, UNITLESS_ABS2, recursive_length
56
using Test
67

78
@testset "RecursiveArrayToolsRaggedArrays" begin
@@ -1027,4 +1028,17 @@ using Test
10271028
@test mapreduce(identity, +, u) == 15.0 # (2+3)*3
10281029
end
10291030

1031+
@testset "ODE_DEFAULT_NORM: RMS-normalised for RaggedVectorOfArray" begin
1032+
# Loading OrdinaryDiffEqTsit5 (which depends on DiffEqBase) triggers the weakdep
1033+
# extension, giving the correct RMS-normalised norm instead of the unnormalised
1034+
# Euclidean norm used by the generic fallback.
1035+
r = RaggedVectorOfArray([ones(3), ones(3)]) # 6 ones
1036+
@test UNITLESS_ABS2(r) 6.0
1037+
@test recursive_length(r) == 6
1038+
# RMS norm of 6 ones = sqrt(6/6) = 1
1039+
@test ODE_DEFAULT_NORM(r, 0.0) 1.0
1040+
# Unnormalised Euclidean norm would be sqrt(6) ≈ 2.449 — make sure we don't get that
1041+
@test ODE_DEFAULT_NORM(r, 0.0) < 2.0
1042+
end
1043+
10301044
end

0 commit comments

Comments
 (0)