Skip to content

Commit 1220488

Browse files
refactor: hacky fix for autodiff
1 parent 6a5e303 commit 1220488

1 file changed

Lines changed: 8 additions & 4 deletions

File tree

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,20 @@ end
9494
@adjoint function VectorOfArray(u)
9595
VectorOfArray(u),
9696
y -> begin
97-
(VectorOfArray([y[].u[ntuple(x -> Colon(), ndims(y[].u) - 1)..., i]
98-
for i in 1:size(y[].u)[end]]),)
97+
y isa Ref && (y = VectorOfArray(y[].u))
98+
(VectorOfArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
99+
for i in 1:size(y.u)[end]]),)
99100
end
100101
end
101102

102103
@adjoint function DiffEqArray(u, t)
103104
DiffEqArray(u, t),
104-
y -> (DiffEqArray([y[].u[ntuple(x -> Colon(), ndims(y[].u) - 1)..., i]
105-
for i in 1:size(y[].u)[end]],
105+
y -> begin
106+
y isa Ref && (y = VectorOfArray(y[].u))
107+
(DiffEqArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
108+
for i in 1:size(y.u)[end]],
106109
t), nothing)
110+
end
107111
end
108112

109113
@adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x})

0 commit comments

Comments
 (0)