@@ -168,7 +168,7 @@ function VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray
168168 VectorOfArray {T, N + 1, typeof(vec)} (vec)
169169end
170170
171- # allow multi-dimensional arrays as long as they're linearly indexed.
171+ # allow multi-dimensional arrays as long as they're linearly indexed.
172172# currently restricted to arrays whose elements are all the same type
173173function VectorOfArray (array:: AbstractArray{AT} ) where {T, N, AT <: AbstractArray{T, N} }
174174 @assert IndexStyle (typeof (array)) isa IndexLinear
@@ -402,6 +402,17 @@ function Base.lastindex(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A}
402402 return lastindex (VA. u)
403403end
404404
405+ @inline function Base. lastindex (VA:: AbstractVectorOfArray , d:: Integer )
406+ if d == ndims (VA)
407+ return lastindex (VA. u)
408+ elseif d < ndims (VA)
409+ isempty (VA. u) && return 0
410+ return _is_ragged_dim (VA, d) ? RaggedEnd {d} () : size (VA. u[1 ], d)
411+ else
412+ return 1
413+ end
414+ end
415+
405416Base. getindex (A:: AbstractVectorOfArray , I:: Int ) = A. u[I]
406417Base. getindex (A:: AbstractVectorOfArray , I:: AbstractArray{Int} ) = A. u[I]
407418Base. getindex (A:: AbstractDiffEqArray , I:: Int ) = A. u[I]
@@ -417,6 +428,35 @@ Base.getindex(A::AbstractDiffEqArray, I::AbstractArray{Int}) = A.u[I]
417428
418429__parameterless_type (T) = Base. typename (T). wrapper
419430
431+ # `end` support for ragged inner arrays
432+ struct RaggedEnd{D}
433+ offset:: Int
434+ end
435+ RaggedEnd {D} () where {D} = RaggedEnd {D} (0 )
436+
437+ Base.:+ (re:: RaggedEnd{D} , n:: Integer ) where {D} = RaggedEnd {D} (re. offset + Int (n))
438+ Base.:- (re:: RaggedEnd{D} , n:: Integer ) where {D} = RaggedEnd {D} (re. offset - Int (n))
439+ Base.:+ (n:: Integer , re:: RaggedEnd{D} ) where {D} = re + n
440+
441+ struct RaggedRange{D}
442+ start:: Int
443+ step:: Int
444+ stop:: RaggedEnd{D}
445+ end
446+
447+ Base.:(:)(stop:: RaggedEnd{D} ) where {D} = RaggedRange {D} (1 , 1 , stop)
448+ Base.:(:)(start:: Integer , stop:: RaggedEnd{D} ) where {D} = RaggedRange {D} (Int (start), 1 , stop)
449+ Base.:(:)(start:: Integer , step:: Integer , stop:: RaggedEnd{D} ) where {D} = RaggedRange {D} (Int (start), Int (step), stop)
450+
451+ @inline function _is_ragged_dim (VA:: AbstractVectorOfArray , d:: Integer )
452+ length (VA. u) <= 1 && return false
453+ first_size = size (VA. u[1 ], d)
454+ @inbounds for idx in 2 : length (VA. u)
455+ size (VA. u[idx], d) == first_size || return true
456+ end
457+ return false
458+ end
459+
420460Base. @propagate_inbounds function _getindex (
421461 A:: AbstractVectorOfArray , :: NotSymbolic , :: Colon , I:: Int )
422462 A. u[I]
@@ -487,11 +527,98 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymb
487527 return getindex (A, all_variable_symbols (A), args... )
488528end
489529
530+ @inline _column_indices (VA:: AbstractVectorOfArray , idx) = idx === Colon () ? eachindex (VA. u) : idx
531+ @inline function _column_indices (VA:: AbstractVectorOfArray , idx:: AbstractArray{Bool} )
532+ findall (idx)
533+ end
534+
535+ @inline _resolve_ragged_index (idx, :: AbstractVectorOfArray , :: Any ) = idx
536+ @inline function _resolve_ragged_index (idx:: RaggedEnd{D} , VA:: AbstractVectorOfArray , col) where {D}
537+ return lastindex (VA. u[col], D) + idx. offset
538+ end
539+ @inline function _resolve_ragged_index (idx:: RaggedRange{D} , VA:: AbstractVectorOfArray , col) where {D}
540+ stop_val = _resolve_ragged_index (idx. stop, VA, col)
541+ return Base. range (idx. start; step = idx. step, stop = stop_val)
542+ end
543+ @inline function _resolve_ragged_index (idx:: AbstractRange{<:RaggedEnd} , VA:: AbstractVectorOfArray , col)
544+ return Base. range (_resolve_ragged_index (first (idx), VA, col); step = step (idx),
545+ stop = _resolve_ragged_index (last (idx), VA, col))
546+ end
547+ @inline function _resolve_ragged_index (idx:: Base.Slice , VA:: AbstractVectorOfArray , col)
548+ return Base. Slice (_resolve_ragged_index (idx. indices, VA, col))
549+ end
550+ @inline function _resolve_ragged_index (idx:: CartesianIndex , VA:: AbstractVectorOfArray , col)
551+ return CartesianIndex (_resolve_ragged_indices (Tuple (idx), VA, col)... )
552+ end
553+ @inline function _resolve_ragged_index (idx:: AbstractArray{<:RaggedEnd} , VA:: AbstractVectorOfArray , col)
554+ return map (i -> _resolve_ragged_index (i, VA, col), idx)
555+ end
556+ @inline function _resolve_ragged_index (idx:: AbstractArray{<:RaggedRange} , VA:: AbstractVectorOfArray , col)
557+ return map (i -> _resolve_ragged_index (i, VA, col), idx)
558+ end
559+ @inline function _resolve_ragged_index (idx:: AbstractArray , VA:: AbstractVectorOfArray , col)
560+ return _has_ragged_end (idx) ? map (i -> _resolve_ragged_index (i, VA, col), idx) : idx
561+ end
562+
563+ @inline function _resolve_ragged_indices (idxs:: Tuple , VA:: AbstractVectorOfArray , col)
564+ map (i -> _resolve_ragged_index (i, VA, col), idxs)
565+ end
566+
567+ @inline function _has_ragged_end (x)
568+ x isa RaggedEnd && return true
569+ x isa RaggedRange && return true
570+ x isa Base. Slice && return _has_ragged_end (x. indices)
571+ x isa CartesianIndex && return _has_ragged_end (Tuple (x))
572+ x isa AbstractRange && return eltype (x) <: Union{RaggedEnd, RaggedRange}
573+ if x isa AbstractArray
574+ el = eltype (x)
575+ return el <: Union{RaggedEnd, RaggedRange} || (el === Any && any (_has_ragged_end, x))
576+ end
577+ x isa Tuple && return any (_has_ragged_end, x)
578+ return false
579+ end
580+ @inline _has_ragged_end (x, xs... ) = _has_ragged_end (x) || _has_ragged_end (xs)
581+
582+ @inline function _ragged_getindex (A:: AbstractVectorOfArray , I... )
583+ cols = last (I)
584+ prefix = Base. front (I)
585+ if cols isa Int
586+ resolved = _resolve_ragged_indices (prefix, A, cols)
587+ return A. u[cols][resolved... ]
588+ else
589+ col_idxs = _column_indices (A, cols)
590+ vals = map (col_idxs) do col
591+ resolved = _resolve_ragged_indices (prefix, A, col)
592+ A. u[col][resolved... ]
593+ end
594+ return stack (vals)
595+ end
596+ end
597+
598+ @inline function _checkbounds_ragged (:: Type{Bool} , VA:: AbstractVectorOfArray , idxs... )
599+ cols = _column_indices (VA, last (idxs))
600+ prefix = Base. front (idxs)
601+ if cols isa Int
602+ resolved = _resolve_ragged_indices (prefix, VA, cols)
603+ return checkbounds (Bool, VA. u, cols) && checkbounds (Bool, VA. u[cols], resolved... )
604+ else
605+ for col in cols
606+ resolved = _resolve_ragged_indices (prefix, VA, col)
607+ checkbounds (Bool, VA. u, col) || return false
608+ checkbounds (Bool, VA. u[col], resolved... ) || return false
609+ end
610+ return true
611+ end
612+ end
613+
490614Base. @propagate_inbounds function Base. getindex (A:: AbstractVectorOfArray , _arg, args... )
491615 symtype = symbolic_type (_arg)
492616 elsymtype = symbolic_type (eltype (_arg))
493617
494618 if symtype == NotSymbolic () && elsymtype == NotSymbolic ()
619+ if _has_ragged_end (_arg, args... )
620+ return _ragged_getindex (A, _arg, args... )
621+ end
495622 if _arg isa Union{Tuple, AbstractArray} &&
496623 any (x -> symbolic_type (x) != NotSymbolic (), _arg)
497624 _getindex (A, symtype, elsymtype, _arg, args... )
@@ -704,12 +831,18 @@ Base.ndims(::Type{<:AbstractVectorOfArray{T, N}}) where {T, N} = N
704831function Base. checkbounds (
705832 :: Type{Bool} , VA:: AbstractVectorOfArray{T, N, <:AbstractVector{T}} ,
706833 idxs... ) where {T, N}
834+ if _has_ragged_end (idxs... )
835+ return _checkbounds_ragged (Bool, VA, idxs... )
836+ end
707837 if length (idxs) == 2 && (idxs[1 ] == Colon () || idxs[1 ] == 1 )
708838 return checkbounds (Bool, VA. u, idxs[2 ])
709839 end
710840 return checkbounds (Bool, VA. u, idxs... )
711841end
712842function Base. checkbounds (:: Type{Bool} , VA:: AbstractVectorOfArray , idx... )
843+ if _has_ragged_end (idx... )
844+ return _checkbounds_ragged (Bool, VA, idx... )
845+ end
713846 checkbounds (Bool, VA. u, last (idx)) || return false
714847 for i in last (idx)
715848 checkbounds (Bool, VA. u[i], Base. front (idx)... ) || return false
@@ -950,13 +1083,13 @@ end
9501083# make vectorofarrays broadcastable so they aren't collected
9511084Broadcast. broadcastable (x:: AbstractVectorOfArray ) = x
9521085
953- # recurse through broadcast arguments and return a parent array for
1086+ # recurse through broadcast arguments and return a parent array for
9541087# the first VoA or DiffEqArray in the bc arguments
9551088function find_VoA_parent (args)
9561089 arg = Base. first (args)
9571090 if arg isa AbstractDiffEqArray
958- # if first(args) is a DiffEqArray, use the underlying
959- # field `u` of DiffEqArray as a parent array.
1091+ # if first(args) is a DiffEqArray, use the underlying
1092+ # field `u` of DiffEqArray as a parent array.
9601093 return arg. u
9611094 elseif arg isa AbstractVectorOfArray
9621095 return parent (arg)
9751108 map (1 : N) do i
9761109 copy (unpack_voa (bc, i))
9771110 end
978- else # if parent isa AbstractArray
1111+ else # if parent isa AbstractArray
9791112 map (enumerate (Iterators. product (axes (parent)... ))) do (i, _)
9801113 copy (unpack_voa (bc, i))
9811114 end
0 commit comments