Skip to content

Commit 9cbf6d0

Browse files
Make NamedArrayPartition slicing fully type-stable via Vector
The first commit on this branch fixed `x[1:2]` and `x[[1, 4]]` by making `similar(::NAP, T, dims)` degrade to a Vector when `dims != size(A)`, mirroring ArrayPartition. That worked, but it left the indexing path inferring as `Union{NamedArrayPartition, Vector{Float64}}` because `similar(::ArrayPartition, T, dims)` itself is a Union (the `dims == size(A)` branch is a runtime check). Add a `_unsafe_getindex(::IndexStyle, ::NAP, I::Vararg{Union{Real, AbstractArray}})` shortcut that mirrors the one at array_partition.jl:317. Allocate the destination directly off the underlying first array and fill it with `Base._unsafe_getindex!`. The shortcut bypasses `similar` entirely for the indexing path, so `x[1:2]`, `x[[1, 4]]`, `x[1:length(x)]` all infer to a clean `Vector{Float64}`. Trade-off: this regresses the post-05faa730 test `typeof(x .+ x[1:end]) <: NamedArrayPartition` back to `<: Vector` (the v3 behavior). That test was added in 05faa73 alongside the invalidation cleanup, but its only effect was to mask the unstable small-Union return; restoring v3 semantics here gives full type stability and matches what ArrayPartition already does. Use `similar(x)` / `copy(x)` if you want a NamedArrayPartition back. The `similar(::NAP, dims)` and `similar(::NAP, T, dims)` overloads keep the graceful-degrade-to-Vector behavior from the previous commit, so direct `similar(x, T, (2,))` calls (e.g. from downstream library code) still work. `SnoopCompile.invalidation_trees(@snoop_invalidations using RecursiveArrayTools)` still reports 0 trees, full `Pkg.test()` passes, and the new regression test asserts type stability via `@inferred`. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 2466954 commit 9cbf6d0

2 files changed

Lines changed: 47 additions & 17 deletions

File tree

src/named_array_partition.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,28 @@ Base.length(x::NamedArrayPartition) = length(ArrayPartition(x))
100100
# Use concrete index types to avoid invalidating AbstractArray's generic setindex!.
101101
Base.@propagate_inbounds Base.getindex(x::NamedArrayPartition, i::Int) = ArrayPartition(x)[i]
102102
Base.@propagate_inbounds Base.setindex!(x::NamedArrayPartition, v, i::Int) = (ArrayPartition(x)[i] = v)
103+
104+
# Indexing with non-scalar indices (UnitRange, Vector{Int}, etc.) goes through
105+
# AbstractArray's generic path, which routes via `similar(A, T, dims)`. NAP's
106+
# `similar(::NAP, T, dims)` cannot in general produce a NamedArrayPartition for
107+
# arbitrary `dims` (the partition layout is fixed by `names_to_indices`), so it
108+
# falls back to a plain Vector — making the inferred return type a small Union.
109+
#
110+
# Mirror ArrayPartition's `_unsafe_getindex` shortcut at `array_partition.jl:317`:
111+
# allocate the destination directly off the first underlying array and fill it
112+
# via `_unsafe_getindex!`. The result is always a Vector for non-scalar indexing,
113+
# so `x[I]` is type-stable. This matches the v3 indexing semantics (`x[1:end]`
114+
# returns a `Vector`, not a `NamedArrayPartition`); use `similar(x)` /
115+
# `copy(x)` if you want a NamedArrayPartition back.
116+
Base.@propagate_inbounds function Base._unsafe_getindex(
117+
::IndexStyle, A::NamedArrayPartition,
118+
I::Vararg{Union{Real, AbstractArray}, N}
119+
) where {N}
120+
shape = Base.index_shape(I...)
121+
dest = similar(getfield(A, :array_partition).x[1], shape)
122+
Base._unsafe_getindex!(dest, A, I...)
123+
return dest
124+
end
103125
function Base.map(f, x::NamedArrayPartition)
104126
return NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices))
105127
end

test/named_array_partition_tests.jl

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using RecursiveArrayTools, ArrayInterface, Test
77
@test typeof(similar(x)) <: NamedArrayPartition
88
@test typeof(similar(x, Int)) <: NamedArrayPartition
99
@test x.a ones(10)
10-
@test typeof(x .+ x[1:end]) <: NamedArrayPartition # x[1:end] preserves type
10+
@test typeof(x .+ x[1:end]) <: Vector # x[1:end] is a plain Vector (type-stable slicing)
1111
@test all(x .== x[1:end])
1212
@test ArrayInterface.zeromatrix(x) isa Matrix
1313
@test size(ArrayInterface.zeromatrix(x)) == (30, 30)
@@ -39,37 +39,45 @@ using RecursiveArrayTools, ArrayInterface, Test
3939
end
4040

4141
# Regression test for https://github.com/SciML/RecursiveArrayTools.jl/issues/583:
42-
# indexing a NamedArrayPartition with a UnitRange / Vector{Int} smaller than
43-
# the whole array used to throw a MethodError because `similar(::NAP, T, dims)`
44-
# tried to wrap a plain Vector (returned by `similar(::ArrayPartition, T, dims)`
45-
# when `dims != size(A)`) in NamedArrayPartition's inner constructor, which
46-
# requires an ArrayPartition.
42+
# indexing a NamedArrayPartition with a UnitRange / Vector{Int} smaller than the
43+
# whole array used to throw a MethodError because the AbstractArray indexing
44+
# path called `similar(::NAP, T, dims)`, which tried to wrap a plain Vector
45+
# (returned by `similar(::ArrayPartition, T, dims)` for `dims != size(A)`) in
46+
# NamedArrayPartition's inner constructor, which requires an ArrayPartition.
47+
#
48+
# The `_unsafe_getindex(::IndexStyle, ::NAP, I...)` shortcut bypasses `similar`
49+
# entirely, allocating a plain Vector destination directly. Slicing therefore
50+
# always returns a Vector and is type-stable.
4751
@testset "NamedArrayPartition issue #583 indexing" begin
4852
x = NamedArrayPartition(a = ones(2), b = 2 * ones(3))
4953

50-
# UnitRange that doesn't span the whole array: returns a plain Vector
54+
# UnitRange / Vector{Int} indexing all return Vector and are type-stable
5155
@test x[1:2] == [1.0, 1.0]
52-
@test x[1:2] isa Vector{Float64}
5356
@test x[2:4] == [1.0, 2.0, 2.0]
54-
55-
# Vector{Int} indexing
57+
@test x[1:end] == [1.0, 1.0, 2.0, 2.0, 2.0]
5658
@test x[[1, 2]] == [1.0, 1.0]
5759
@test x[[1, 4]] == [1.0, 2.0]
60+
61+
@test x[1:2] isa Vector{Float64}
62+
@test x[1:end] isa Vector{Float64}
5863
@test x[[1, 4]] isa Vector{Float64}
5964

60-
# Existing behavior: full-range slice still preserves NamedArrayPartition
61-
@test typeof(x[1:end]) <: NamedArrayPartition
65+
# Inferred return types: Vector, not Union
66+
@test (@inferred x[1:2]) isa Vector{Float64}
67+
@test (@inferred x[1:length(x)]) isa Vector{Float64}
68+
@test (@inferred x[[1, 4]]) isa Vector{Float64}
6269

63-
# `similar` with a non-matching dims tuple gives the backing-array fallback
70+
# `similar` with a non-matching dims falls back to the backing array;
71+
# with matching dims keeps the NamedArrayPartition wrapper.
6472
@test similar(x, Float64, (2,)) isa Vector{Float64}
65-
@test similar(x, (2,)) isa Vector{Float64}
66-
# `similar` with matching dims preserves the NamedArrayPartition wrapper
73+
@test similar(x, (2,)) isa Vector{Float64}
6774
@test similar(x, Float64, size(x)) isa NamedArrayPartition
68-
@test similar(x, size(x)) isa NamedArrayPartition
75+
@test similar(x, size(x)) isa NamedArrayPartition
6976

70-
# Scalar indexing untouched
77+
# Scalar indexing untouched and type-stable
7178
@test x[1] == 1.0
7279
@test x[3] == 2.0
80+
@test (@inferred x[1]) === 1.0
7381
x[1] = 99.0
7482
@test x[1] == 99.0
7583
end

0 commit comments

Comments
 (0)