Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 138 additions & 5 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ function VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray
VectorOfArray{T, N + 1, typeof(vec)}(vec)
end

# allow multi-dimensional arrays as long as they're linearly indexed.
# allow multi-dimensional arrays as long as they're linearly indexed.
# currently restricted to arrays whose elements are all the same type
function VectorOfArray(array::AbstractArray{AT}) where {T, N, AT <: AbstractArray{T, N}}
@assert IndexStyle(typeof(array)) isa IndexLinear
Expand Down Expand Up @@ -402,6 +402,17 @@ function Base.lastindex(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A}
return lastindex(VA.u)
end

@inline function Base.lastindex(VA::AbstractVectorOfArray, d::Integer)
if d == ndims(VA)
return lastindex(VA.u)
elseif d < ndims(VA)
isempty(VA.u) && return 0
return _is_ragged_dim(VA, d) ? RaggedEnd{d}() : size(VA.u[1], d)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this not give inference issues?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm yes, might be. As I said, this was AI-generated. I'll see if I can fix that.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of using types if it's runtime, so enum then I think the strategy is fine.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand. RaggedEnd should be an enum instead of a struct? If yes, what would be the values?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or just 0 and interpret what zero means.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest, I only understand half of what is going on here, but aren't we loosing information (the d) when we return 0 instead of RaggedEnd{d}()? I tried making it work with the help of Claude, but there were always tests failing.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I see. The point though is to just make it be some kind of runtime value instead of a compile time value. So instead of a type with d, just like a tuple (true, d) for ragged true/false.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, thanks for the clarification! So, should Base.lastindex always return a tuple then? Also in the non-ragged case?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made a suggestion in 36561a0. This uses runtime values, but still returns RaggedEnd. As far as I understand we cannot return tuples because we need to overload + and - for that, which we obviously do not want to do for a general tuple.

else
return 1
end
end

Base.getindex(A::AbstractVectorOfArray, I::Int) = A.u[I]
Base.getindex(A::AbstractVectorOfArray, I::AbstractArray{Int}) = A.u[I]
Base.getindex(A::AbstractDiffEqArray, I::Int) = A.u[I]
Expand All @@ -417,6 +428,35 @@ Base.getindex(A::AbstractDiffEqArray, I::AbstractArray{Int}) = A.u[I]

__parameterless_type(T) = Base.typename(T).wrapper

# `end` support for ragged inner arrays
struct RaggedEnd{D}
offset::Int
end
RaggedEnd{D}() where {D} = RaggedEnd{D}(0)

Base.:+(re::RaggedEnd{D}, n::Integer) where {D} = RaggedEnd{D}(re.offset + Int(n))
Base.:-(re::RaggedEnd{D}, n::Integer) where {D} = RaggedEnd{D}(re.offset - Int(n))
Base.:+(n::Integer, re::RaggedEnd{D}) where {D} = re + n

struct RaggedRange{D}
start::Int
step::Int
stop::RaggedEnd{D}
end

Base.:(:)(stop::RaggedEnd{D}) where {D} = RaggedRange{D}(1, 1, stop)
Base.:(:)(start::Integer, stop::RaggedEnd{D}) where {D} = RaggedRange{D}(Int(start), 1, stop)
Base.:(:)(start::Integer, step::Integer, stop::RaggedEnd{D}) where {D} = RaggedRange{D}(Int(start), Int(step), stop)

@inline function _is_ragged_dim(VA::AbstractVectorOfArray, d::Integer)
length(VA.u) <= 1 && return false
first_size = size(VA.u[1], d)
@inbounds for idx in 2:length(VA.u)
size(VA.u[idx], d) == first_size || return true
end
return false
end

Base.@propagate_inbounds function _getindex(
A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int)
A.u[I]
Expand Down Expand Up @@ -487,11 +527,98 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymb
return getindex(A, all_variable_symbols(A), args...)
end

@inline _column_indices(VA::AbstractVectorOfArray, idx) = idx === Colon() ? eachindex(VA.u) : idx
@inline function _column_indices(VA::AbstractVectorOfArray, idx::AbstractArray{Bool})
findall(idx)
end

@inline _resolve_ragged_index(idx, ::AbstractVectorOfArray, ::Any) = idx
@inline function _resolve_ragged_index(idx::RaggedEnd{D}, VA::AbstractVectorOfArray, col) where {D}
return lastindex(VA.u[col], D) + idx.offset
end
@inline function _resolve_ragged_index(idx::RaggedRange{D}, VA::AbstractVectorOfArray, col) where {D}
stop_val = _resolve_ragged_index(idx.stop, VA, col)
return Base.range(idx.start; step = idx.step, stop = stop_val)
end
@inline function _resolve_ragged_index(idx::AbstractRange{<:RaggedEnd}, VA::AbstractVectorOfArray, col)
return Base.range(_resolve_ragged_index(first(idx), VA, col); step = step(idx),
stop = _resolve_ragged_index(last(idx), VA, col))
end
@inline function _resolve_ragged_index(idx::Base.Slice, VA::AbstractVectorOfArray, col)
return Base.Slice(_resolve_ragged_index(idx.indices, VA, col))
end
@inline function _resolve_ragged_index(idx::CartesianIndex, VA::AbstractVectorOfArray, col)
return CartesianIndex(_resolve_ragged_indices(Tuple(idx), VA, col)...)
end
@inline function _resolve_ragged_index(idx::AbstractArray{<:RaggedEnd}, VA::AbstractVectorOfArray, col)
return map(i -> _resolve_ragged_index(i, VA, col), idx)
end
@inline function _resolve_ragged_index(idx::AbstractArray{<:RaggedRange}, VA::AbstractVectorOfArray, col)
return map(i -> _resolve_ragged_index(i, VA, col), idx)
end
@inline function _resolve_ragged_index(idx::AbstractArray, VA::AbstractVectorOfArray, col)
return _has_ragged_end(idx) ? map(i -> _resolve_ragged_index(i, VA, col), idx) : idx
end

@inline function _resolve_ragged_indices(idxs::Tuple, VA::AbstractVectorOfArray, col)
map(i -> _resolve_ragged_index(i, VA, col), idxs)
end

@inline function _has_ragged_end(x)
x isa RaggedEnd && return true
x isa RaggedRange && return true
x isa Base.Slice && return _has_ragged_end(x.indices)
x isa CartesianIndex && return _has_ragged_end(Tuple(x))
x isa AbstractRange && return eltype(x) <: Union{RaggedEnd, RaggedRange}
if x isa AbstractArray
el = eltype(x)
return el <: Union{RaggedEnd, RaggedRange} || (el === Any && any(_has_ragged_end, x))
end
x isa Tuple && return any(_has_ragged_end, x)
return false
end
@inline _has_ragged_end(x, xs...) = _has_ragged_end(x) || _has_ragged_end(xs)

@inline function _ragged_getindex(A::AbstractVectorOfArray, I...)
cols = last(I)
prefix = Base.front(I)
if cols isa Int
resolved = _resolve_ragged_indices(prefix, A, cols)
return A.u[cols][resolved...]
else
col_idxs = _column_indices(A, cols)
vals = map(col_idxs) do col
resolved = _resolve_ragged_indices(prefix, A, col)
A.u[col][resolved...]
end
return stack(vals)
end
end

@inline function _checkbounds_ragged(::Type{Bool}, VA::AbstractVectorOfArray, idxs...)
cols = _column_indices(VA, last(idxs))
prefix = Base.front(idxs)
if cols isa Int
resolved = _resolve_ragged_indices(prefix, VA, cols)
return checkbounds(Bool, VA.u, cols) && checkbounds(Bool, VA.u[cols], resolved...)
else
for col in cols
resolved = _resolve_ragged_indices(prefix, VA, col)
checkbounds(Bool, VA.u, col) || return false
checkbounds(Bool, VA.u[col], resolved...) || return false
end
return true
end
end

Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg, args...)
symtype = symbolic_type(_arg)
elsymtype = symbolic_type(eltype(_arg))

if symtype == NotSymbolic() && elsymtype == NotSymbolic()
if _has_ragged_end(_arg, args...)
return _ragged_getindex(A, _arg, args...)
end
if _arg isa Union{Tuple, AbstractArray} &&
any(x -> symbolic_type(x) != NotSymbolic(), _arg)
_getindex(A, symtype, elsymtype, _arg, args...)
Expand Down Expand Up @@ -704,12 +831,18 @@ Base.ndims(::Type{<:AbstractVectorOfArray{T, N}}) where {T, N} = N
function Base.checkbounds(
::Type{Bool}, VA::AbstractVectorOfArray{T, N, <:AbstractVector{T}},
idxs...) where {T, N}
if _has_ragged_end(idxs...)
return _checkbounds_ragged(Bool, VA, idxs...)
end
if length(idxs) == 2 && (idxs[1] == Colon() || idxs[1] == 1)
return checkbounds(Bool, VA.u, idxs[2])
end
return checkbounds(Bool, VA.u, idxs...)
end
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
if _has_ragged_end(idx...)
return _checkbounds_ragged(Bool, VA, idx...)
end
checkbounds(Bool, VA.u, last(idx)) || return false
for i in last(idx)
checkbounds(Bool, VA.u[i], Base.front(idx)...) || return false
Expand Down Expand Up @@ -950,13 +1083,13 @@ end
# make vectorofarrays broadcastable so they aren't collected
Broadcast.broadcastable(x::AbstractVectorOfArray) = x

# recurse through broadcast arguments and return a parent array for
# recurse through broadcast arguments and return a parent array for
# the first VoA or DiffEqArray in the bc arguments
function find_VoA_parent(args)
arg = Base.first(args)
if arg isa AbstractDiffEqArray
# if first(args) is a DiffEqArray, use the underlying
# field `u` of DiffEqArray as a parent array.
# if first(args) is a DiffEqArray, use the underlying
# field `u` of DiffEqArray as a parent array.
return arg.u
elseif arg isa AbstractVectorOfArray
return parent(arg)
Expand All @@ -975,7 +1108,7 @@ end
map(1:N) do i
copy(unpack_voa(bc, i))
end
else # if parent isa AbstractArray
else # if parent isa AbstractArray
map(enumerate(Iterators.product(axes(parent)...))) do (i, _)
copy(unpack_voa(bc, i))
end
Expand Down
33 changes: 30 additions & 3 deletions test/basic_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,33 @@ f2 = VectorOfArray([[1.0, 2.0], [3.0]])
@test collect(view(f2, :, 1)) == f2[:, 1]
@test collect(view(f2, :, 2)) == f2[:, 2]

# Test `end` with ragged arrays
ragged = VectorOfArray([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]])
@test ragged[end, 1] == 2.0
@test ragged[end, 2] == 5.0
@test ragged[end, 3] == 9.0
@test ragged[end - 1, 1] == 1.0
@test ragged[end - 1, 2] == 4.0
@test ragged[end - 1, 3] == 8.0
@test ragged[1:end, 1] == [1.0, 2.0]
@test ragged[1:end, 2] == [3.0, 4.0, 5.0]
@test ragged[1:end, 3] == [6.0, 7.0, 8.0, 9.0]

ragged2 = VectorOfArray([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0], [7.0, 8.0, 9.0]])
@test ragged2[end, 1] == 4.0
@test ragged2[end, 2] == 6.0
@test ragged2[end, 3] == 9.0
@test ragged2[end - 1, 1] == 3.0
@test ragged2[end - 1, 2] == 5.0
@test ragged2[end - 1, 3] == 8.0
@test ragged2[end - 2, 1] == 2.0
@test ragged2[1:end, 1] == [1.0, 2.0, 3.0, 4.0]
@test ragged2[1:end, 2] == [5.0, 6.0]
@test ragged2[1:end, 3] == [7.0, 8.0, 9.0]
@test ragged2[2:end, 1] == [2.0, 3.0, 4.0]
@test ragged2[2:end, 2] == [6.0]
@test ragged2[2:end, 3] == [8.0, 9.0]

# Test that views can be modified
f3 = VectorOfArray([[1.0, 2.0], [3.0, 4.0, 5.0]])
v = view(f3, :, 2)
Expand Down Expand Up @@ -259,14 +286,14 @@ a[1:8]
a[[1, 3, 8]]

####################################################################
# test when VectorOfArray is constructed from a linearly indexed
# test when VectorOfArray is constructed from a linearly indexed
# multidimensional array of arrays
####################################################################

u_matrix = VectorOfArray([[1, 2] for i in 1:2, j in 1:3])
u_vector = VectorOfArray([[1, 2] for i in 1:6])

# test broadcasting
# test broadcasting
function foo!(u)
@. u += 2 * u * abs(u)
return u
Expand All @@ -281,7 +308,7 @@ foo!(u_vector)
@test typeof(parent(similar(u_matrix))) == typeof(parent(u_matrix))
@test typeof(parent((x -> x).(u_matrix))) == typeof(parent(u_matrix))

# test efficiency
# test efficiency
num_allocs = @allocations foo!(u_matrix)
@test num_allocs == 0

Expand Down