Skip to content

Commit 1c44183

Browse files
Remove redundant AbstractArray overrides, fix Zygote adjoints
Invalidation analysis and cleanup:
- Remove IndexStyle instance method (type method sufficient)
- Remove size(VA, i), lastindex(VA, d) (inherited from AbstractArray)
- Remove checkbounds override (inherited from AbstractArray via size)
- Remove isassigned, isapprox, CartesianIndices, adjoint overrides
- Remove reshape, vec, convert(Array, ...), maybeview overrides
- Remove +, -, *, / operator overrides (use broadcasting)
- Remove 2-arg show for AbstractVectorOfArray (use AbstractArray display)

Fix Zygote extension:
- Remove all getindex/view adjoint overrides (Zygote's AbstractArray rules apply)
- Fix VectorOfArray(u) adjoint to return .u (plain Vector) not VectorOfArray
- Fix DiffEqArray(u, t) adjoint similarly
- All 12 adjoint tests now pass (was 4 pass + 8 broken)

Invalidation trees: 8 total, all minimal (max 20 mt_backedges from Colon(::Integer, ::RaggedEnd) which is inherent to the type)

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8485bda commit 1c44183

4 files changed

Lines changed: 55 additions & 206 deletions

File tree

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 26 additions & 110 deletions
Original file line number | Diff line number | Diff line change
@@ -5,67 +5,15 @@ using RecursiveArrayTools
55
using Zygote
66
using Zygote: FillArrays, ChainRulesCore, literal_getproperty, @adjoint
77

8-
# Define a new species of projection operator for this type:
9-
# ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()
10-
118
function ChainRulesCore.rrule(
129
T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray},
1310
xs::AbstractVectorOfArray
1411
)
15-
return T(xs), ȳ -> (ChainRulesCore.NoTangent(), ȳ)
16-
end
17-
18-
@adjoint function getindex(
19-
VA::AbstractVectorOfArray,
20-
i::Union{BitArray, AbstractArray{Bool}}
21-
)
22-
function AbstractVectorOfArray_getindex_adjoint(Δ)
23-
Δ′ = [
24-
(i[j] ? Δ[j] : FillArrays.Fill(zero(eltype(x)), size(x)))
25-
for (x, j) in zip(VA.u, 1:length(VA))
26-
]
27-
(VectorOfArray(Δ′), nothing)
28-
end
29-
VA[:, i], AbstractVectorOfArray_getindex_adjoint
30-
end
31-
32-
@adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int})
33-
function AbstractVectorOfArray_getindex_adjoint(Δ)
34-
iter = 0
35-
Δ′ = [
36-
(j ∈ i ? Δ[iter += 1] : FillArrays.Fill(zero(eltype(x)), size(x)))
37-
for (x, j) in zip(VA.u, 1:length(VA))
38-
]
39-
(VectorOfArray(Δ′), nothing)
40-
end
41-
VA[:, i], AbstractVectorOfArray_getindex_adjoint
42-
end
43-
44-
@adjoint function getindex(VA::AbstractVectorOfArray, i::Colon)
45-
function AbstractVectorOfArray_getindex_adjoint(Δ)
46-
(VectorOfArray(Δ), nothing)
47-
end
48-
VA.u[i], AbstractVectorOfArray_getindex_adjoint
12+
return T(xs), ȳ -> (ChainRulesCore.NoTangent(), ȳ)
4913
end
5014

51-
@adjoint function getindex(
52-
VA::AbstractVectorOfArray, i::Int,
53-
j::Union{
54-
Int, AbstractArray{Int}, CartesianIndex,
55-
Colon, BitArray, AbstractArray{Bool},
56-
}...
57-
)
58-
function AbstractVectorOfArray_getindex_adjoint(Δ)
59-
Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))])
60-
if isempty(j)
61-
Δ′.u[i] = Δ
62-
else
63-
Δ′[i, j...] = Δ
64-
end
65-
(Δ′, nothing, map(_ -> nothing, j)...)
66-
end
67-
VA[i, j...], AbstractVectorOfArray_getindex_adjoint
68-
end
15+
# getindex adjoints are inherited from Zygote's AbstractArray rules
16+
# since AbstractVectorOfArray <: AbstractArray
6917

7018
@adjoint function ArrayPartition(
7119
x::S,
@@ -88,15 +36,20 @@ end
8836
@adjoint function VectorOfArray(u)
8937
VectorOfArray(u),
9038
y -> begin
91-
y isa Ref && (y = VectorOfArray(y[].u))
92-
(
93-
VectorOfArray(
94-
[
39+
if y isa Ref
40+
y = VectorOfArray(y[].u)
41+
end
42+
# Return a plain Vector of arrays as gradient for `u`, not wrapped in VectorOfArray.
43+
# This avoids issues with downstream pullbacks that index into the gradient
44+
# using linear indexing (which now returns scalar elements for VectorOfArray).
45+
if y isa AbstractVectorOfArray
46+
(y.u,)
47+
else
48+
([
9549
y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
9650
for i in 1:size(y)[end]
97-
]
98-
),
99-
)
51+
],)
52+
end
10053
end
10154
end
10255

@@ -108,17 +61,19 @@ end
10861
@adjoint function DiffEqArray(u, t)
10962
DiffEqArray(u, t),
11063
y -> begin
111-
y isa Ref && (y = VectorOfArray(y[].u))
112-
(
113-
DiffEqArray(
114-
[
64+
if y isa Ref
65+
y = VectorOfArray(y[].u)
66+
end
67+
if y isa AbstractVectorOfArray
68+
(y.u, nothing)
69+
else
70+
([
11571
y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
11672
for i in 1:size(y)[end]
11773
],
118-
t
119-
),
120-
nothing,
121-
)
74+
nothing,
75+
)
76+
end
12277
end
12378
end
12479

@@ -156,6 +111,7 @@ end
156111
@adjoint function Base.Array(VA::AbstractVectorOfArray)
157112
adj = let VA = VA
158113
function Array_adjoint(y)
114+
# Return a VectorOfArray so it flows correctly back through VectorOfArray constructor
159115
VA = recursivecopy(VA)
160116
copyto!(VA, y)
161117
return (VA,)
@@ -164,44 +120,4 @@ end
164120
Array(VA), adj
165121
end
166122

167-
@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...)
168-
view_adjoint = let A = A, I = I
169-
function (y)
170-
A = recursivecopy(A)
171-
copyto!(A, y)
172-
return (A, map(_ -> nothing, I)...)
173-
end
174-
end
175-
return view(A, I...), view_adjoint
176-
end
177-
178-
@adjoint function Base.view(A::AbstractVectorOfArray, I...)
179-
view_adjoint = let A = A, I = I
180-
function (y)
181-
A = recursivecopy(A)
182-
recursivefill!(A, zero(eltype(A)))
183-
v = view(A, I...)
184-
copyto!(v, y)
185-
return (A, map(_ -> nothing, I)...)
186-
end
187-
end
188-
view(A, I...), view_adjoint
189-
end
190-
191-
# Since AbstractVectorOfArray <: AbstractArray, Zygote's built-in AbstractArray
192-
# broadcast rules apply. We only keep specific overrides that don't conflict.
193-
194-
_minus(Δ) = .-Δ
195-
_minus(::Nothing) = nothing
196-
197-
function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄)
198-
N = ndims(x̄)
199-
return if length(x) == length(x̄)
200-
Zygote._project(x, x̄)
201-
else
202-
dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄) + 1, ndims(x̄))
203-
Zygote._project(x, Zygote.accum_sum(x̄; dims = dims))
204-
end
205-
end
206-
207123
end # module

src/RecursiveArrayTools.jl

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -129,9 +129,10 @@ module RecursiveArrayTools
129129
include("array_partition.jl")
130130
include("named_array_partition.jl")
131131

132-
function Base.show(io::IO, x::Union{ArrayPartition, AbstractVectorOfArray})
132+
function Base.show(io::IO, x::ArrayPartition)
133133
return invoke(show, Tuple{typeof(io), Any}, io, x)
134134
end
135+
# AbstractVectorOfArray uses AbstractArray's show
135136

136137
import GPUArraysCore
137138
Base.convert(T::Type{<:GPUArraysCore.AnyGPUArray}, VA::AbstractVectorOfArray) = stack(VA.u)

src/vector_of_array.jl

Lines changed: 19 additions & 83 deletions
Original file line number | Diff line number | Diff line change
@@ -504,14 +504,9 @@ function SymbolicIndexingInterface.get_parameter_timeseries_collection(A::Abstra
504504
return get_discretes(A)
505505
end
506506

507-
Base.IndexStyle(A::AbstractVectorOfArray) = Base.IndexStyle(typeof(A))
508507
Base.IndexStyle(::Type{<:AbstractVectorOfArray}) = IndexCartesian()
509508

510-
# lastindex with dimension: use size(VA, d) since we now use rectangular interpretation
511-
# RaggedEnd is still used internally for ragged column access via A.u
512-
@inline function Base.lastindex(VA::AbstractVectorOfArray, d::Integer)
513-
return size(VA, Int(d))
514-
end
509+
## lastindex inherited from AbstractArray (uses size)
515510

516511
## Linear indexing: convert to Cartesian and dispatch to the N-ary getindex
517512
Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray{T, N}, i::Int) where {T, N}
@@ -1030,17 +1025,17 @@ end
10301025
end
10311026
end
10321027

1033-
# Handle mixed Int + CartesianIndex by flattening to plain indices
1034-
# This is needed for sum(A; dims=d) and similar operations
1035-
Base.@propagate_inbounds function Base.getindex(
1036-
A::AbstractVectorOfArray, i::Int, ci::CartesianIndex
1037-
)
1028+
## Mixed Int + CartesianIndex (needed for sum(A; dims=d) etc.)
1029+
## Use @inline to avoid invalidation issues with overly broad signatures
1030+
@inline Base.@propagate_inbounds function Base.getindex(
1031+
A::AbstractVectorOfArray{T, N}, i::Int, ci::CartesianIndex
1032+
) where {T, N}
10381033
return A[i, Tuple(ci)...]
10391034
end
10401035

1041-
Base.@propagate_inbounds function Base.setindex!(
1042-
A::AbstractVectorOfArray, v, i::Int, ci::CartesianIndex
1043-
)
1036+
@inline Base.@propagate_inbounds function Base.setindex!(
1037+
A::AbstractVectorOfArray{T, N}, v, i::Int, ci::CartesianIndex
1038+
) where {T, N}
10441039
return A[i, Tuple(ci)...] = v
10451040
end
10461041

@@ -1164,9 +1159,7 @@ end
11641159
end
11651160
return (leading..., length(VA.u))
11661161
end
1167-
@inline Base.size(VA::AbstractVectorOfArray, i) = size(VA)[i]
11681162
@inline Base.size(A::Adjoint{T, <:AbstractVectorOfArray}) where {T} = reverse(size(A.parent))
1169-
@inline Base.size(A::Adjoint{T, <:AbstractVectorOfArray}, i) where {T} = size(A)[i]
11701163

11711164
Base.@propagate_inbounds function Base.setindex!(
11721165
VA::AbstractVectorOfArray{T, N}, v,
@@ -1319,36 +1312,14 @@ function Base.SubArray(parent::AbstractVectorOfArray, indices::Tuple)
13191312
Base.ensure_indexable(indices), Base.index_dimsum(indices...)
13201313
)
13211314
end
1322-
Base.isassigned(VA::AbstractVectorOfArray, idxs...) = checkbounds(Bool, VA, idxs...)
1315+
## isassigned, ndims, eltype inherited from AbstractArray
13231316
function Base.check_parent_index_match(
13241317
::RecursiveArrayTools.AbstractVectorOfArray{T, N}, ::NTuple{N, Bool}
13251318
) where {T, N}
13261319
return nothing
13271320
end
1328-
# ndims and eltype inherited from AbstractArray{T, N}
13291321

1330-
# checkbounds: Use size(VA) for bounds checking (which uses max sizes for ragged).
1331-
# This means indices within the "virtual" rectangular shape are valid,
1332-
# and out-of-ragged-bounds returns zero on getindex.
1333-
# The default AbstractArray checkbounds handles most cases via size(VA).
1334-
# We only need a custom method for RaggedEnd/RaggedRange indices.
1335-
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
1336-
if _has_ragged_end(idx...)
1337-
return _checkbounds_ragged(Bool, VA, idx...)
1338-
end
1339-
# For non-ragged indices, delegate to the standard AbstractArray checkbounds
1340-
# which uses axes(VA) derived from size(VA)
1341-
s = size(VA)
1342-
if length(idx) == length(s)
1343-
return all(checkbounds(Bool, Base.OneTo(s[d]), idx[d]) for d in 1:length(s))
1344-
elseif length(idx) == 1
1345-
# Linear index
1346-
return checkbounds(Bool, 1:prod(s), idx[1])
1347-
else
1348-
# Let Julia's standard machinery handle it
1349-
return Base.checkbounds_indices(Bool, axes(VA), idx)
1350-
end
1351-
end
1322+
## checkbounds inherited from AbstractArray (uses axes derived from size)
13521323
function Base.copyto!(
13531324
dest::AbstractVectorOfArray{T, N},
13541325
src::AbstractVectorOfArray{T2, N}
@@ -1381,45 +1352,13 @@ function Base.copyto!(
13811352
copyto!(dest.u, src)
13821353
return dest
13831354
end
1384-
# Required for broadcasted setindex! when slicing across subarrays
1385-
# E.g. if `va = VectorOfArray([rand(3, 3) for i in 1:5])`
1386-
# Need this method for `va[2, :, :] .= 3.0`
1387-
Base.@propagate_inbounds function Base.maybeview(A::AbstractVectorOfArray, I...)
1388-
return view(A, I...)
1389-
end
1355+
## maybeview inherited from AbstractArray
13901356

1391-
# Operations
1392-
function Base.isapprox(
1393-
A::AbstractVectorOfArray,
1394-
B::Union{AbstractVectorOfArray, AbstractArray};
1395-
kwargs...
1396-
)
1397-
return all(isapprox.(A, B; kwargs...))
1398-
end
1399-
1400-
function Base.isapprox(A::AbstractArray, B::AbstractVectorOfArray; kwargs...)
1401-
return all(isapprox.(A, B; kwargs...))
1402-
end
1403-
1404-
for op in [:(Base.:-), :(Base.:+)]
1405-
@eval function ($op)(A::AbstractVectorOfArray, B::AbstractVectorOfArray)
1406-
return ($op).(A, B)
1407-
end
1408-
end
1357+
## isapprox inherited from AbstractArray
14091358

1410-
for op in [:(Base.:/), :(Base.:\), :(Base.:*)]
1411-
if op !== :(Base.:/)
1412-
@eval ($op)(A::Number, B::AbstractVectorOfArray) = ($op).(A, B)
1413-
end
1414-
if op !== :(Base.:\)
1415-
@eval ($op)(A::AbstractVectorOfArray, B::Number) = ($op).(A, B)
1416-
end
1417-
end
1359+
## Arithmetic (+, -, *, /) inherited from AbstractArray / broadcasting
14181360

1419-
function Base.CartesianIndices(VA::AbstractVectorOfArray)
1420-
# Use size(VA) which handles ragged arrays via maximum sizes
1421-
return CartesianIndices(size(VA))
1422-
end
1361+
## CartesianIndices inherited from AbstractArray (uses axes/size)
14231362

14241363
# Tools for creating similar objects
14251364
# eltype is inherited from AbstractArray{T, N}
@@ -1492,21 +1431,18 @@ function Base.fill!(VA::AbstractVectorOfArray, x)
14921431
return VA
14931432
end
14941433

1495-
Base.reshape(A::AbstractVectorOfArray, dims...) = Base.reshape(Array(A), dims...)
1434+
## reshape inherited from AbstractArray
14961435

14971436
# any/all inherited from AbstractArray (iterates over all elements including ragged zeros)
14981437

14991438
# conversion tools
15001439
vecarr_to_vectors(VA::AbstractVectorOfArray) = [VA[i, :] for i in eachindex(VA.u[1])]
1501-
Base.vec(VA::AbstractVectorOfArray) = vec(convert(Array, VA)) # Allocates
1502-
# Convert to dense Array, zero-padding ragged arrays
1503-
function Base.convert(::Type{Array}, VA::AbstractVectorOfArray)
1504-
return Array(VA)
1505-
end
1440+
## vec inherited from AbstractArray
1441+
## convert(Array, VA) inherited from AbstractArray (calls Array(VA))
15061442

15071443
# sum, prod inherited from AbstractArray
15081444

1509-
@inline Base.adjoint(VA::AbstractVectorOfArray) = Adjoint(VA)
1445+
## adjoint inherited from AbstractArray
15101446

15111447
# linear algebra
15121448
ArrayInterface.issingular(va::AbstractVectorOfArray) = ArrayInterface.issingular(Matrix(va))

test/adjoints.jl

Lines changed: 8 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -80,22 +80,18 @@ end
8080

8181
x = float.(6:10)
8282
loss(x)
83-
# Zygote adjoints need updating for AbstractVectorOfArray <: AbstractArray
84-
# ForwardDiff tests still pass since they don't use Zygote's ProjectTo
85-
@test ForwardDiff.gradient(loss, x) isa Vector
86-
@test ForwardDiff.gradient(loss3, x) isa Vector
87-
@test_broken Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x)
88-
@test_broken Zygote.gradient(loss2, x)[1] == ForwardDiff.gradient(loss2, x)
89-
@test_broken Zygote.gradient(loss3, x)[1] == ForwardDiff.gradient(loss3, x)
90-
@test_broken Zygote.gradient(loss4, x)[1] == ForwardDiff.gradient(loss4, x)
83+
@test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x)
84+
@test Zygote.gradient(loss2, x)[1] == ForwardDiff.gradient(loss2, x)
85+
@test Zygote.gradient(loss3, x)[1] == ForwardDiff.gradient(loss3, x)
86+
@test Zygote.gradient(loss4, x)[1] == ForwardDiff.gradient(loss4, x)
9187
@test Zygote.gradient(loss5, x)[1] == ForwardDiff.gradient(loss5, x)
9288
@test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x)
93-
@test_broken Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x)
94-
@test_broken Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)
89+
@test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x)
90+
@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)
9591
@test ForwardDiff.derivative(loss9, 0.0) ==
9692
VectorOfArray([collect((3i):(3i + 3)) for i in 1:5])
97-
@test_broken Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x)
98-
@test_broken Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x)
93+
@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x)
94+
@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x)
9995

10096
voa = RecursiveArrayTools.VectorOfArray(fill(rand(3), 3))
10197
voa_gs, = Zygote.gradient(voa) do x

0 commit comments

Comments
 (0)