Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 141 additions & 5 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ function VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray
VectorOfArray{T, N + 1, typeof(vec)}(vec)
end

# allow multi-dimensional arrays as long as they're linearly indexed.
# allow multi-dimensional arrays as long as they're linearly indexed.
# currently restricted to arrays whose elements are all the same type
function VectorOfArray(array::AbstractArray{AT}) where {T, N, AT <: AbstractArray{T, N}}
@assert IndexStyle(typeof(array)) isa IndexLinear
Expand Down Expand Up @@ -402,6 +402,17 @@ function Base.lastindex(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A}
return lastindex(VA.u)
end

@inline function Base.lastindex(VA::AbstractVectorOfArray, d::Integer)
if d == ndims(VA)
return lastindex(VA.u)
elseif d < ndims(VA)
isempty(VA.u) && return 0
return RaggedEnd(Int(d))
else
return 1
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type unstable, add JET testing to this

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's also why I asked

So, should Base.lastindex always return a tuple then? Also in the non-ragged case?

above. In 7d9e008, I now let lastindex always return a RaggedEnd to make lastindex type stable (because again, I think we need to return an object of a custom struct for the ragged case). For non-ragged arrays, this is now encoded in RaggedEnd as dim = 0. This required some special handling for this case to behave the same as before. AI helped me again for that. I also added some more tests.

end
end

Base.getindex(A::AbstractVectorOfArray, I::Int) = A.u[I]
Base.getindex(A::AbstractVectorOfArray, I::AbstractArray{Int}) = A.u[I]
Base.getindex(A::AbstractDiffEqArray, I::Int) = A.u[I]
Expand All @@ -417,6 +428,38 @@ Base.getindex(A::AbstractDiffEqArray, I::AbstractArray{Int}) = A.u[I]

__parameterless_type(T) = Base.typename(T).wrapper

# `end` support for ragged inner arrays
# Use runtime fields instead of type parameters for type stability
struct RaggedEnd
dim::Int
offset::Int
end
RaggedEnd(dim::Int) = RaggedEnd(dim, 0)

Base.:+(re::RaggedEnd, n::Integer) = RaggedEnd(re.dim, re.offset + Int(n))
Base.:-(re::RaggedEnd, n::Integer) = RaggedEnd(re.dim, re.offset - Int(n))
Base.:+(n::Integer, re::RaggedEnd) = re + n

struct RaggedRange
dim::Int
start::Int
step::Int
offset::Int
end

Base.:(:)(stop::RaggedEnd) = RaggedRange(stop.dim, 1, 1, stop.offset)
Base.:(:)(start::Integer, stop::RaggedEnd) = RaggedRange(stop.dim, Int(start), 1, stop.offset)
Base.:(:)(start::Integer, step::Integer, stop::RaggedEnd) = RaggedRange(stop.dim, Int(start), Int(step), stop.offset)

@inline function _is_ragged_dim(VA::AbstractVectorOfArray, d::Integer)
length(VA.u) <= 1 && return false
first_size = size(VA.u[1], d)
@inbounds for idx in 2:length(VA.u)
size(VA.u[idx], d) == first_size || return true
end
return false
end

Base.@propagate_inbounds function _getindex(
A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int)
A.u[I]
Expand Down Expand Up @@ -487,11 +530,98 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymb
return getindex(A, all_variable_symbols(A), args...)
end

@inline _column_indices(VA::AbstractVectorOfArray, idx) = idx === Colon() ? eachindex(VA.u) : idx
@inline function _column_indices(VA::AbstractVectorOfArray, idx::AbstractArray{Bool})
findall(idx)
end

@inline _resolve_ragged_index(idx, ::AbstractVectorOfArray, ::Any) = idx
@inline function _resolve_ragged_index(idx::RaggedEnd, VA::AbstractVectorOfArray, col)
return lastindex(VA.u[col], idx.dim) + idx.offset
end
@inline function _resolve_ragged_index(idx::RaggedRange, VA::AbstractVectorOfArray, col)
stop_val = lastindex(VA.u[col], idx.dim) + idx.offset
return Base.range(idx.start; step = idx.step, stop = stop_val)
end
@inline function _resolve_ragged_index(idx::AbstractRange{<:RaggedEnd}, VA::AbstractVectorOfArray, col)
return Base.range(_resolve_ragged_index(first(idx), VA, col); step = step(idx),
stop = _resolve_ragged_index(last(idx), VA, col))
end
@inline function _resolve_ragged_index(idx::Base.Slice, VA::AbstractVectorOfArray, col)
return Base.Slice(_resolve_ragged_index(idx.indices, VA, col))
end
@inline function _resolve_ragged_index(idx::CartesianIndex, VA::AbstractVectorOfArray, col)
return CartesianIndex(_resolve_ragged_indices(Tuple(idx), VA, col)...)
end
@inline function _resolve_ragged_index(idx::AbstractArray{<:RaggedEnd}, VA::AbstractVectorOfArray, col)
return map(i -> _resolve_ragged_index(i, VA, col), idx)
end
@inline function _resolve_ragged_index(idx::AbstractArray{<:RaggedRange}, VA::AbstractVectorOfArray, col)
return map(i -> _resolve_ragged_index(i, VA, col), idx)
end
@inline function _resolve_ragged_index(idx::AbstractArray, VA::AbstractVectorOfArray, col)
return _has_ragged_end(idx) ? map(i -> _resolve_ragged_index(i, VA, col), idx) : idx
end

@inline function _resolve_ragged_indices(idxs::Tuple, VA::AbstractVectorOfArray, col)
map(i -> _resolve_ragged_index(i, VA, col), idxs)
end

@inline function _has_ragged_end(x)
x isa RaggedEnd && return true
x isa RaggedRange && return true
x isa Base.Slice && return _has_ragged_end(x.indices)
x isa CartesianIndex && return _has_ragged_end(Tuple(x))
x isa AbstractRange && return eltype(x) <: Union{RaggedEnd, RaggedRange}
if x isa AbstractArray
el = eltype(x)
return el <: Union{RaggedEnd, RaggedRange} || (el === Any && any(_has_ragged_end, x))
end
x isa Tuple && return any(_has_ragged_end, x)
return false
end
@inline _has_ragged_end(x, xs...) = _has_ragged_end(x) || _has_ragged_end(xs)

@inline function _ragged_getindex(A::AbstractVectorOfArray, I...)
cols = last(I)
prefix = Base.front(I)
if cols isa Int
resolved = _resolve_ragged_indices(prefix, A, cols)
return A.u[cols][resolved...]
else
col_idxs = _column_indices(A, cols)
vals = map(col_idxs) do col
resolved = _resolve_ragged_indices(prefix, A, col)
A.u[col][resolved...]
end
return stack(vals)
end
end

@inline function _checkbounds_ragged(::Type{Bool}, VA::AbstractVectorOfArray, idxs...)
cols = _column_indices(VA, last(idxs))
prefix = Base.front(idxs)
if cols isa Int
resolved = _resolve_ragged_indices(prefix, VA, cols)
return checkbounds(Bool, VA.u, cols) && checkbounds(Bool, VA.u[cols], resolved...)
else
for col in cols
resolved = _resolve_ragged_indices(prefix, VA, col)
checkbounds(Bool, VA.u, col) || return false
checkbounds(Bool, VA.u[col], resolved...) || return false
end
return true
end
end

Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg, args...)
symtype = symbolic_type(_arg)
elsymtype = symbolic_type(eltype(_arg))

if symtype == NotSymbolic() && elsymtype == NotSymbolic()
if _has_ragged_end(_arg, args...)
return _ragged_getindex(A, _arg, args...)
end
if _arg isa Union{Tuple, AbstractArray} &&
any(x -> symbolic_type(x) != NotSymbolic(), _arg)
_getindex(A, symtype, elsymtype, _arg, args...)
Expand Down Expand Up @@ -704,12 +834,18 @@ Base.ndims(::Type{<:AbstractVectorOfArray{T, N}}) where {T, N} = N
function Base.checkbounds(
::Type{Bool}, VA::AbstractVectorOfArray{T, N, <:AbstractVector{T}},
idxs...) where {T, N}
if _has_ragged_end(idxs...)
return _checkbounds_ragged(Bool, VA, idxs...)
end
if length(idxs) == 2 && (idxs[1] == Colon() || idxs[1] == 1)
return checkbounds(Bool, VA.u, idxs[2])
end
return checkbounds(Bool, VA.u, idxs...)
end
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
if _has_ragged_end(idx...)
return _checkbounds_ragged(Bool, VA, idx...)
end
checkbounds(Bool, VA.u, last(idx)) || return false
for i in last(idx)
checkbounds(Bool, VA.u[i], Base.front(idx)...) || return false
Expand Down Expand Up @@ -950,13 +1086,13 @@ end
# make vectorofarrays broadcastable so they aren't collected
Broadcast.broadcastable(x::AbstractVectorOfArray) = x

# recurse through broadcast arguments and return a parent array for
# recurse through broadcast arguments and return a parent array for
# the first VoA or DiffEqArray in the bc arguments
function find_VoA_parent(args)
arg = Base.first(args)
if arg isa AbstractDiffEqArray
# if first(args) is a DiffEqArray, use the underlying
# field `u` of DiffEqArray as a parent array.
# if first(args) is a DiffEqArray, use the underlying
# field `u` of DiffEqArray as a parent array.
return arg.u
elseif arg isa AbstractVectorOfArray
return parent(arg)
Expand All @@ -975,7 +1111,7 @@ end
map(1:N) do i
copy(unpack_voa(bc, i))
end
else # if parent isa AbstractArray
else # if parent isa AbstractArray
map(enumerate(Iterators.product(axes(parent)...))) do (i, _)
copy(unpack_voa(bc, i))
end
Expand Down
36 changes: 33 additions & 3 deletions test/basic_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,36 @@ f2 = VectorOfArray([[1.0, 2.0], [3.0]])
@test collect(view(f2, :, 1)) == f2[:, 1]
@test collect(view(f2, :, 2)) == f2[:, 2]

# Test `end` with ragged arrays
ragged = VectorOfArray([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]])
@test ragged[end, 1] == 2.0
@test ragged[end, 2] == 5.0
@test ragged[end, 3] == 9.0
@test ragged[end - 1, 1] == 1.0
@test ragged[end - 1, 2] == 4.0
@test ragged[end - 1, 3] == 8.0
@test ragged[1:end, 1] == [1.0, 2.0]
@test ragged[1:end, 2] == [3.0, 4.0, 5.0]
@test ragged[1:end, 3] == [6.0, 7.0, 8.0, 9.0]

ragged2 = VectorOfArray([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0], [7.0, 8.0, 9.0]])
@test ragged2[end, 1] == 4.0
@test ragged2[end, 2] == 6.0
@test ragged2[end, 3] == 9.0
@test ragged2[end - 1, 1] == 3.0
@test ragged2[end - 1, 2] == 5.0
@test ragged2[end - 1, 3] == 8.0
@test ragged2[end - 2, 1] == 2.0
@test ragged2[1:end, 1] == [1.0, 2.0, 3.0, 4.0]
@test ragged2[1:end, 2] == [5.0, 6.0]
@test ragged2[1:end, 3] == [7.0, 8.0, 9.0]
@test ragged2[2:end, 1] == [2.0, 3.0, 4.0]
@test ragged2[2:end, 2] == [6.0]
@test ragged2[2:end, 3] == [8.0, 9.0]
@test ragged2[1:(end - 1), 1] == [1.0, 2.0, 3.0]
@test ragged2[1:(end - 1), 2] == [5.0]
@test ragged2[1:(end - 1), 3] == [7.0, 8.0]

# Test that views can be modified
f3 = VectorOfArray([[1.0, 2.0], [3.0, 4.0, 5.0]])
v = view(f3, :, 2)
Expand Down Expand Up @@ -259,14 +289,14 @@ a[1:8]
a[[1, 3, 8]]

####################################################################
# test when VectorOfArray is constructed from a linearly indexed
# test when VectorOfArray is constructed from a linearly indexed
# multidimensional array of arrays
####################################################################

u_matrix = VectorOfArray([[1, 2] for i in 1:2, j in 1:3])
u_vector = VectorOfArray([[1, 2] for i in 1:6])

# test broadcasting
# test broadcasting
function foo!(u)
@. u += 2 * u * abs(u)
return u
Expand All @@ -281,7 +311,7 @@ foo!(u_vector)
@test typeof(parent(similar(u_matrix))) == typeof(parent(u_matrix))
@test typeof(parent((x -> x).(u_matrix))) == typeof(parent(u_matrix))

# test efficiency
# test efficiency
num_allocs = @allocations foo!(u_matrix)
@test num_allocs == 0

Expand Down
Loading