diff --git a/src/named_array_partition.jl b/src/named_array_partition.jl index b2546334..887f658e 100644 --- a/src/named_array_partition.jl +++ b/src/named_array_partition.jl @@ -34,11 +34,16 @@ function Base.similar(A::NamedArrayPartition) ) end -# return ArrayPartition when possible, otherwise next best thing of the correct size +# return NamedArrayPartition when the requested dims still match the partition layout; +# otherwise fall back to the plain backing array of the correct size. ArrayPartition's +# own `similar(A, dims)` already does this degradation (it returns a Vector when +# `dims != size(A)`), and we simply propagate that result instead of trying to +# wrap a non-ArrayPartition in a NamedArrayPartition (which would hit the inner +# constructor signature `NamedArrayPartition(::A<:ArrayPartition, ::NamedTuple)`). function Base.similar(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} - return NamedArrayPartition( - similar(getfield(A, :array_partition), dims), getfield(A, :names_to_indices) - ) + inner = similar(getfield(A, :array_partition), dims) + inner isa ArrayPartition || return inner + return NamedArrayPartition(inner, getfield(A, :names_to_indices)) end # similar array partition of common type @@ -48,11 +53,10 @@ end ) end -# return ArrayPartition when possible, otherwise next best thing of the correct size function Base.similar(A::NamedArrayPartition, ::Type{T}, dims::NTuple{N, Int}) where {T, N} - return NamedArrayPartition( - similar(getfield(A, :array_partition), T, dims), getfield(A, :names_to_indices) - ) + inner = similar(getfield(A, :array_partition), T, dims) + inner isa ArrayPartition || return inner + return NamedArrayPartition(inner, getfield(A, :names_to_indices)) end # similar array partition with different types @@ -96,6 +100,28 @@ Base.length(x::NamedArrayPartition) = length(ArrayPartition(x)) # Use concrete index types to avoid invalidating AbstractArray's generic setindex!. Base.@propagate_inbounds Base.getindex(x::NamedArrayPartition, i::Int) = ArrayPartition(x)[i] Base.@propagate_inbounds Base.setindex!(x::NamedArrayPartition, v, i::Int) = (ArrayPartition(x)[i] = v) + +# Indexing with non-scalar indices (UnitRange, Vector{Int}, etc.) goes through +# AbstractArray's generic path, which routes via `similar(A, T, dims)`. NAP's +# `similar(::NAP, T, dims)` cannot in general produce a NamedArrayPartition for +# arbitrary `dims` (the partition layout is fixed by `names_to_indices`), so it +# falls back to a plain Vector — making the inferred return type a small Union. +# +# Mirror ArrayPartition's `_unsafe_getindex` shortcut at `array_partition.jl:317`: +# allocate the destination directly off the first underlying array and fill it +# via `_unsafe_getindex!`. The result is always a Vector for non-scalar indexing, +# so `x[I]` is type-stable. This matches the v3 indexing semantics (`x[1:end]` +# returns a `Vector`, not a `NamedArrayPartition`); use `similar(x)` / +# `copy(x)` if you want a NamedArrayPartition back. +Base.@propagate_inbounds function Base._unsafe_getindex( + ::IndexStyle, A::NamedArrayPartition, + I::Vararg{Union{Real, AbstractArray}, N} + ) where {N} + shape = Base.index_shape(I...) + dest = similar(getfield(A, :array_partition).x[1], shape) + Base._unsafe_getindex!(dest, A, I...) + return dest +end function Base.map(f, x::NamedArrayPartition) return NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices)) end diff --git a/test/named_array_partition_tests.jl b/test/named_array_partition_tests.jl index 9a012635..1cc9f952 100644 --- a/test/named_array_partition_tests.jl +++ b/test/named_array_partition_tests.jl @@ -7,7 +7,7 @@ using RecursiveArrayTools, ArrayInterface, Test @test typeof(similar(x)) <: NamedArrayPartition @test typeof(similar(x, Int)) <: NamedArrayPartition @test x.a ≈ ones(10) - @test typeof(x .+ x[1:end]) <: NamedArrayPartition # x[1:end] preserves type + @test typeof(x .+ x[1:end]) <: Vector # x[1:end] is a plain Vector (type-stable slicing) @test all(x .== x[1:end]) @test ArrayInterface.zeromatrix(x) isa Matrix @test size(ArrayInterface.zeromatrix(x)) == (30, 30) @@ -37,3 +37,47 @@ using RecursiveArrayTools, ArrayInterface, Test @test typeof((x -> x[1]).(x)) <: NamedArrayPartition @test typeof(map(x -> x[1], x)) <: NamedArrayPartition end + +# Regression test for https://github.com/SciML/RecursiveArrayTools.jl/issues/583: +# indexing a NamedArrayPartition with a UnitRange / Vector{Int} smaller than the +# whole array used to throw a MethodError because the AbstractArray indexing +# path called `similar(::NAP, T, dims)`, which tried to wrap a plain Vector +# (returned by `similar(::ArrayPartition, T, dims)` for `dims != size(A)`) in +# NamedArrayPartition's inner constructor, which requires an ArrayPartition. +# +# The `_unsafe_getindex(::IndexStyle, ::NAP, I...)` shortcut bypasses `similar` +# entirely, allocating a plain Vector destination directly. Slicing therefore +# always returns a Vector and is type-stable. +@testset "NamedArrayPartition issue #583 indexing" begin + x = NamedArrayPartition(a = ones(2), b = 2 * ones(3)) + + # UnitRange / Vector{Int} indexing all return Vector and are type-stable + @test x[1:2] == [1.0, 1.0] + @test x[2:4] == [1.0, 2.0, 2.0] + @test x[1:end] == [1.0, 1.0, 2.0, 2.0, 2.0] + @test x[[1, 2]] == [1.0, 1.0] + @test x[[1, 4]] == [1.0, 2.0] + + @test x[1:2] isa Vector{Float64} + @test x[1:end] isa Vector{Float64} + @test x[[1, 4]] isa Vector{Float64} + + # Inferred return types: Vector, not Union + @test (@inferred x[1:2]) isa Vector{Float64} + @test (@inferred x[1:length(x)]) isa Vector{Float64} + @test (@inferred x[[1, 4]]) isa Vector{Float64} + + # `similar` with a non-matching dims falls back to the backing array; + # with matching dims keeps the NamedArrayPartition wrapper. + @test similar(x, Float64, (2,)) isa Vector{Float64} + @test similar(x, (2,)) isa Vector{Float64} + @test similar(x, Float64, size(x)) isa NamedArrayPartition + @test similar(x, size(x)) isa NamedArrayPartition + + # Scalar indexing untouched and type-stable + @test x[1] == 1.0 + @test x[3] == 2.0 + @test (@inferred x[1]) === 1.0 + x[1] = 99.0 + @test x[1] == 99.0 +end