Skip to content

Commit 12c7a0c

Browse files
committed
implement lastindex for ragged arrays
1 parent eb25df4 commit 12c7a0c

2 files changed

Lines changed: 149 additions & 8 deletions

File tree

src/vector_of_array.jl

Lines changed: 138 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ function VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray
168168
VectorOfArray{T, N + 1, typeof(vec)}(vec)
169169
end
170170

171-
# allow multi-dimensional arrays as long as they're linearly indexed.
171+
# allow multi-dimensional arrays as long as they're linearly indexed.
172172
# currently restricted to arrays whose elements are all the same type
173173
function VectorOfArray(array::AbstractArray{AT}) where {T, N, AT <: AbstractArray{T, N}}
174174
@assert IndexStyle(typeof(array)) isa IndexLinear
@@ -402,6 +402,17 @@ function Base.lastindex(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A}
402402
return lastindex(VA.u)
403403
end
404404

405+
@inline function Base.lastindex(VA::AbstractVectorOfArray, d::Integer)
406+
if d == ndims(VA)
407+
return lastindex(VA.u)
408+
elseif d < ndims(VA)
409+
isempty(VA.u) && return 0
410+
return _is_ragged_dim(VA, d) ? RaggedEnd{d}() : size(VA.u[1], d)
411+
else
412+
return 1
413+
end
414+
end
415+
405416
Base.getindex(A::AbstractVectorOfArray, I::Int) = A.u[I]
406417
Base.getindex(A::AbstractVectorOfArray, I::AbstractArray{Int}) = A.u[I]
407418
Base.getindex(A::AbstractDiffEqArray, I::Int) = A.u[I]
@@ -417,6 +428,35 @@ Base.getindex(A::AbstractDiffEqArray, I::AbstractArray{Int}) = A.u[I]
417428

418429
__parameterless_type(T) = Base.typename(T).wrapper
419430

431+
# `end` support for ragged inner arrays
432+
struct RaggedEnd{D}
433+
offset::Int
434+
end
435+
RaggedEnd{D}() where {D} = RaggedEnd{D}(0)
436+
437+
Base.:+(re::RaggedEnd{D}, n::Integer) where {D} = RaggedEnd{D}(re.offset + Int(n))
438+
Base.:-(re::RaggedEnd{D}, n::Integer) where {D} = RaggedEnd{D}(re.offset - Int(n))
439+
Base.:+(n::Integer, re::RaggedEnd{D}) where {D} = re + n
440+
441+
struct RaggedRange{D}
442+
start::Int
443+
step::Int
444+
stop::RaggedEnd{D}
445+
end
446+
447+
Base.:(:)(stop::RaggedEnd{D}) where {D} = RaggedRange{D}(1, 1, stop)
448+
Base.:(:)(start::Integer, stop::RaggedEnd{D}) where {D} = RaggedRange{D}(Int(start), 1, stop)
449+
Base.:(:)(start::Integer, step::Integer, stop::RaggedEnd{D}) where {D} = RaggedRange{D}(Int(start), Int(step), stop)
450+
451+
@inline function _is_ragged_dim(VA::AbstractVectorOfArray, d::Integer)
452+
length(VA.u) <= 1 && return false
453+
first_size = size(VA.u[1], d)
454+
@inbounds for idx in 2:length(VA.u)
455+
size(VA.u[idx], d) == first_size || return true
456+
end
457+
return false
458+
end
459+
420460
Base.@propagate_inbounds function _getindex(
421461
A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int)
422462
A.u[I]
@@ -487,11 +527,98 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymb
487527
return getindex(A, all_variable_symbols(A), args...)
488528
end
489529

530+
@inline _column_indices(VA::AbstractVectorOfArray, idx) = idx === Colon() ? eachindex(VA.u) : idx
531+
@inline function _column_indices(VA::AbstractVectorOfArray, idx::AbstractArray{Bool})
532+
findall(idx)
533+
end
534+
535+
@inline _resolve_ragged_index(idx, ::AbstractVectorOfArray, ::Any) = idx
536+
@inline function _resolve_ragged_index(idx::RaggedEnd{D}, VA::AbstractVectorOfArray, col) where {D}
537+
return lastindex(VA.u[col], D) + idx.offset
538+
end
539+
@inline function _resolve_ragged_index(idx::RaggedRange{D}, VA::AbstractVectorOfArray, col) where {D}
540+
stop_val = _resolve_ragged_index(idx.stop, VA, col)
541+
return Base.range(idx.start; step = idx.step, stop = stop_val)
542+
end
543+
@inline function _resolve_ragged_index(idx::AbstractRange{<:RaggedEnd}, VA::AbstractVectorOfArray, col)
544+
return Base.range(_resolve_ragged_index(first(idx), VA, col); step = step(idx),
545+
stop = _resolve_ragged_index(last(idx), VA, col))
546+
end
547+
@inline function _resolve_ragged_index(idx::Base.Slice, VA::AbstractVectorOfArray, col)
548+
return Base.Slice(_resolve_ragged_index(idx.indices, VA, col))
549+
end
550+
@inline function _resolve_ragged_index(idx::CartesianIndex, VA::AbstractVectorOfArray, col)
551+
return CartesianIndex(_resolve_ragged_indices(Tuple(idx), VA, col)...)
552+
end
553+
@inline function _resolve_ragged_index(idx::AbstractArray{<:RaggedEnd}, VA::AbstractVectorOfArray, col)
554+
return map(i -> _resolve_ragged_index(i, VA, col), idx)
555+
end
556+
@inline function _resolve_ragged_index(idx::AbstractArray{<:RaggedRange}, VA::AbstractVectorOfArray, col)
557+
return map(i -> _resolve_ragged_index(i, VA, col), idx)
558+
end
559+
@inline function _resolve_ragged_index(idx::AbstractArray, VA::AbstractVectorOfArray, col)
560+
return _has_ragged_end(idx) ? map(i -> _resolve_ragged_index(i, VA, col), idx) : idx
561+
end
562+
563+
@inline function _resolve_ragged_indices(idxs::Tuple, VA::AbstractVectorOfArray, col)
564+
map(i -> _resolve_ragged_index(i, VA, col), idxs)
565+
end
566+
567+
@inline function _has_ragged_end(x)
568+
x isa RaggedEnd && return true
569+
x isa RaggedRange && return true
570+
x isa Base.Slice && return _has_ragged_end(x.indices)
571+
x isa CartesianIndex && return _has_ragged_end(Tuple(x))
572+
x isa AbstractRange && return eltype(x) <: Union{RaggedEnd, RaggedRange}
573+
if x isa AbstractArray
574+
el = eltype(x)
575+
return el <: Union{RaggedEnd, RaggedRange} || (el === Any && any(_has_ragged_end, x))
576+
end
577+
x isa Tuple && return any(_has_ragged_end, x)
578+
return false
579+
end
580+
@inline _has_ragged_end(x, xs...) = _has_ragged_end(x) || _has_ragged_end(xs)
581+
582+
@inline function _ragged_getindex(A::AbstractVectorOfArray, I...)
583+
cols = last(I)
584+
prefix = Base.front(I)
585+
if cols isa Int
586+
resolved = _resolve_ragged_indices(prefix, A, cols)
587+
return A.u[cols][resolved...]
588+
else
589+
col_idxs = _column_indices(A, cols)
590+
vals = map(col_idxs) do col
591+
resolved = _resolve_ragged_indices(prefix, A, col)
592+
A.u[col][resolved...]
593+
end
594+
return stack(vals)
595+
end
596+
end
597+
598+
@inline function _checkbounds_ragged(::Type{Bool}, VA::AbstractVectorOfArray, idxs...)
599+
cols = _column_indices(VA, last(idxs))
600+
prefix = Base.front(idxs)
601+
if cols isa Int
602+
resolved = _resolve_ragged_indices(prefix, VA, cols)
603+
return checkbounds(Bool, VA.u, cols) && checkbounds(Bool, VA.u[cols], resolved...)
604+
else
605+
for col in cols
606+
resolved = _resolve_ragged_indices(prefix, VA, col)
607+
checkbounds(Bool, VA.u, col) || return false
608+
checkbounds(Bool, VA.u[col], resolved...) || return false
609+
end
610+
return true
611+
end
612+
end
613+
490614
Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg, args...)
491615
symtype = symbolic_type(_arg)
492616
elsymtype = symbolic_type(eltype(_arg))
493617

494618
if symtype == NotSymbolic() && elsymtype == NotSymbolic()
619+
if _has_ragged_end(_arg, args...)
620+
return _ragged_getindex(A, _arg, args...)
621+
end
495622
if _arg isa Union{Tuple, AbstractArray} &&
496623
any(x -> symbolic_type(x) != NotSymbolic(), _arg)
497624
_getindex(A, symtype, elsymtype, _arg, args...)
@@ -704,12 +831,18 @@ Base.ndims(::Type{<:AbstractVectorOfArray{T, N}}) where {T, N} = N
704831
function Base.checkbounds(
705832
::Type{Bool}, VA::AbstractVectorOfArray{T, N, <:AbstractVector{T}},
706833
idxs...) where {T, N}
834+
if _has_ragged_end(idxs...)
835+
return _checkbounds_ragged(Bool, VA, idxs...)
836+
end
707837
if length(idxs) == 2 && (idxs[1] == Colon() || idxs[1] == 1)
708838
return checkbounds(Bool, VA.u, idxs[2])
709839
end
710840
return checkbounds(Bool, VA.u, idxs...)
711841
end
712842
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
843+
if _has_ragged_end(idx...)
844+
return _checkbounds_ragged(Bool, VA, idx...)
845+
end
713846
checkbounds(Bool, VA.u, last(idx)) || return false
714847
for i in last(idx)
715848
checkbounds(Bool, VA.u[i], Base.front(idx)...) || return false
@@ -950,13 +1083,13 @@ end
9501083
# make vectorofarrays broadcastable so they aren't collected
9511084
Broadcast.broadcastable(x::AbstractVectorOfArray) = x
9521085

953-
# recurse through broadcast arguments and return a parent array for
1086+
# recurse through broadcast arguments and return a parent array for
9541087
# the first VoA or DiffEqArray in the bc arguments
9551088
function find_VoA_parent(args)
9561089
arg = Base.first(args)
9571090
if arg isa AbstractDiffEqArray
958-
# if first(args) is a DiffEqArray, use the underlying
959-
# field `u` of DiffEqArray as a parent array.
1091+
# if first(args) is a DiffEqArray, use the underlying
1092+
# field `u` of DiffEqArray as a parent array.
9601093
return arg.u
9611094
elseif arg isa AbstractVectorOfArray
9621095
return parent(arg)
@@ -975,7 +1108,7 @@ end
9751108
map(1:N) do i
9761109
copy(unpack_voa(bc, i))
9771110
end
978-
else # if parent isa AbstractArray
1111+
else # if parent isa AbstractArray
9791112
map(enumerate(Iterators.product(axes(parent)...))) do (i, _)
9801113
copy(unpack_voa(bc, i))
9811114
end

test/basic_indexing.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,14 @@ diffeq = DiffEqArray(recs, t)
145145
@test diffeq[:, 1] == recs[1]
146146
@test diffeq[1:2, 1:2] == [1 3; 2 5]
147147

148+
ragged = VectorOfArray([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]])
149+
@test ragged[end, 1] == 2.0
150+
@test ragged[end, 2] == 5.0
151+
@test ragged[end, 3] == 9.0
152+
@test ragged[end - 1, 3] == 8.0
153+
@test ragged[1:end, 1] == [1.0, 2.0]
154+
@test ragged[1:end, 2] == [3.0, 4.0, 5.0]
155+
148156
# Test views of heterogeneous arrays (issue #453)
149157
f = VectorOfArray([[1.0], [2.0, 3.0]])
150158
@test length(view(f, :, 1)) == 1
@@ -259,14 +267,14 @@ a[1:8]
259267
a[[1, 3, 8]]
260268

261269
####################################################################
262-
# test when VectorOfArray is constructed from a linearly indexed
270+
# test when VectorOfArray is constructed from a linearly indexed
263271
# multidimensional array of arrays
264272
####################################################################
265273

266274
u_matrix = VectorOfArray([[1, 2] for i in 1:2, j in 1:3])
267275
u_vector = VectorOfArray([[1, 2] for i in 1:6])
268276

269-
# test broadcasting
277+
# test broadcasting
270278
function foo!(u)
271279
@. u += 2 * u * abs(u)
272280
return u
@@ -281,7 +289,7 @@ foo!(u_vector)
281289
@test typeof(parent(similar(u_matrix))) == typeof(parent(u_matrix))
282290
@test typeof(parent((x -> x).(u_matrix))) == typeof(parent(u_matrix))
283291

284-
# test efficiency
292+
# test efficiency
285293
num_allocs = @allocations foo!(u_matrix)
286294
@test num_allocs == 0
287295

0 commit comments

Comments
 (0)