Skip to content

Commit 951f7da

Browse files
Merge pull request #584 from ChrisRackauckas-Claude/fix-vector-int-indexing-583
Fix NamedArrayPartition Vector{Int} / UnitRange indexing (#583)
2 parents 0c18182 + 9cbf6d0 commit 951f7da

2 files changed

Lines changed: 79 additions & 9 deletions

File tree

src/named_array_partition.jl

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,16 @@ function Base.similar(A::NamedArrayPartition)
3434
)
3535
end
3636

37-
# return ArrayPartition when possible, otherwise next best thing of the correct size
37+
# return NamedArrayPartition when the requested dims still match the partition layout;
38+
# otherwise fall back to the plain backing array of the correct size. ArrayPartition's
39+
# own `similar(A, dims)` already does this degradation (it returns a Vector when
40+
# `dims != size(A)`), and we simply propagate that result instead of trying to
41+
# wrap a non-ArrayPartition in a NamedArrayPartition (which would hit the inner
42+
# constructor signature `NamedArrayPartition(::A<:ArrayPartition, ::NamedTuple)`).
3843
function Base.similar(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N}
39-
return NamedArrayPartition(
40-
similar(getfield(A, :array_partition), dims), getfield(A, :names_to_indices)
41-
)
44+
inner = similar(getfield(A, :array_partition), dims)
45+
inner isa ArrayPartition || return inner
46+
return NamedArrayPartition(inner, getfield(A, :names_to_indices))
4247
end
4348

4449
# similar array partition of common type
@@ -48,11 +53,10 @@ end
4853
)
4954
end
5055

51-
# return ArrayPartition when possible, otherwise next best thing of the correct size
5256
function Base.similar(A::NamedArrayPartition, ::Type{T}, dims::NTuple{N, Int}) where {T, N}
53-
return NamedArrayPartition(
54-
similar(getfield(A, :array_partition), T, dims), getfield(A, :names_to_indices)
55-
)
57+
inner = similar(getfield(A, :array_partition), T, dims)
58+
inner isa ArrayPartition || return inner
59+
return NamedArrayPartition(inner, getfield(A, :names_to_indices))
5660
end
5761

5862
# similar array partition with different types
@@ -96,6 +100,28 @@ Base.length(x::NamedArrayPartition) = length(ArrayPartition(x))
96100
# Use concrete index types to avoid invalidating AbstractArray's generic setindex!.
97101
Base.@propagate_inbounds Base.getindex(x::NamedArrayPartition, i::Int) = ArrayPartition(x)[i]
98102
Base.@propagate_inbounds Base.setindex!(x::NamedArrayPartition, v, i::Int) = (ArrayPartition(x)[i] = v)
103+
104+
# Indexing with non-scalar indices (UnitRange, Vector{Int}, etc.) goes through
105+
# AbstractArray's generic path, which routes via `similar(A, T, dims)`. NAP's
106+
# `similar(::NAP, T, dims)` cannot in general produce a NamedArrayPartition for
107+
# arbitrary `dims` (the partition layout is fixed by `names_to_indices`), so it
108+
# falls back to a plain Vector — making the inferred return type a small Union.
109+
#
110+
# Mirror ArrayPartition's `_unsafe_getindex` shortcut at `array_partition.jl:317`:
111+
# allocate the destination directly off the first underlying array and fill it
112+
# via `_unsafe_getindex!`. The result is always a Vector for non-scalar indexing,
113+
# so `x[I]` is type-stable. This matches the v3 indexing semantics (`x[1:end]`
114+
# returns a `Vector`, not a `NamedArrayPartition`); use `similar(x)` /
115+
# `copy(x)` if you want a NamedArrayPartition back.
116+
Base.@propagate_inbounds function Base._unsafe_getindex(
117+
::IndexStyle, A::NamedArrayPartition,
118+
I::Vararg{Union{Real, AbstractArray}, N}
119+
) where {N}
120+
shape = Base.index_shape(I...)
121+
dest = similar(getfield(A, :array_partition).x[1], shape)
122+
Base._unsafe_getindex!(dest, A, I...)
123+
return dest
124+
end
99125
function Base.map(f, x::NamedArrayPartition)
100126
return NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices))
101127
end

test/named_array_partition_tests.jl

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using RecursiveArrayTools, ArrayInterface, Test
77
@test typeof(similar(x)) <: NamedArrayPartition
88
@test typeof(similar(x, Int)) <: NamedArrayPartition
99
@test x.a ones(10)
10-
@test typeof(x .+ x[1:end]) <: NamedArrayPartition # x[1:end] preserves type
10+
@test typeof(x .+ x[1:end]) <: Vector # x[1:end] is a plain Vector (type-stable slicing)
1111
@test all(x .== x[1:end])
1212
@test ArrayInterface.zeromatrix(x) isa Matrix
1313
@test size(ArrayInterface.zeromatrix(x)) == (30, 30)
@@ -37,3 +37,47 @@ using RecursiveArrayTools, ArrayInterface, Test
3737
@test typeof((x -> x[1]).(x)) <: NamedArrayPartition
3838
@test typeof(map(x -> x[1], x)) <: NamedArrayPartition
3939
end
40+
41+
# Regression test for https://github.com/SciML/RecursiveArrayTools.jl/issues/583:
42+
# indexing a NamedArrayPartition with a UnitRange / Vector{Int} smaller than the
43+
# whole array used to throw a MethodError because the AbstractArray indexing
44+
# path called `similar(::NAP, T, dims)`, which tried to wrap a plain Vector
45+
# (returned by `similar(::ArrayPartition, T, dims)` for `dims != size(A)`) in
46+
# NamedArrayPartition's inner constructor, which requires an ArrayPartition.
47+
#
48+
# The `_unsafe_getindex(::IndexStyle, ::NAP, I...)` shortcut bypasses `similar`
49+
# entirely, allocating a plain Vector destination directly. Slicing therefore
50+
# always returns a Vector and is type-stable.
51+
@testset "NamedArrayPartition issue #583 indexing" begin
52+
x = NamedArrayPartition(a = ones(2), b = 2 * ones(3))
53+
54+
# UnitRange / Vector{Int} indexing all return Vector and are type-stable
55+
@test x[1:2] == [1.0, 1.0]
56+
@test x[2:4] == [1.0, 2.0, 2.0]
57+
@test x[1:end] == [1.0, 1.0, 2.0, 2.0, 2.0]
58+
@test x[[1, 2]] == [1.0, 1.0]
59+
@test x[[1, 4]] == [1.0, 2.0]
60+
61+
@test x[1:2] isa Vector{Float64}
62+
@test x[1:end] isa Vector{Float64}
63+
@test x[[1, 4]] isa Vector{Float64}
64+
65+
# Inferred return types: Vector, not Union
66+
@test (@inferred x[1:2]) isa Vector{Float64}
67+
@test (@inferred x[1:length(x)]) isa Vector{Float64}
68+
@test (@inferred x[[1, 4]]) isa Vector{Float64}
69+
70+
# `similar` with a non-matching dims falls back to the backing array;
71+
# with matching dims keeps the NamedArrayPartition wrapper.
72+
@test similar(x, Float64, (2,)) isa Vector{Float64}
73+
@test similar(x, (2,)) isa Vector{Float64}
74+
@test similar(x, Float64, size(x)) isa NamedArrayPartition
75+
@test similar(x, size(x)) isa NamedArrayPartition
76+
77+
# Scalar indexing untouched and type-stable
78+
@test x[1] == 1.0
79+
@test x[3] == 2.0
80+
@test (@inferred x[1]) === 1.0
81+
x[1] = 99.0
82+
@test x[1] == 99.0
83+
end

0 commit comments

Comments
 (0)