Skip to content

Commit 46ce779

Browse files
committed
Merge branch 'master' into lastindex-ragged
2 parents 957641e + 0c717ab commit 46ce779

3 files changed

Lines changed: 33 additions & 3 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
4-
version = "3.40.0"
4+
version = "3.41.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/array_partition.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,19 @@ function recursivecopy!(A::ArrayPartition, B::ArrayPartition)
329329
end
330330
recursivecopy(A::ArrayPartition) = ArrayPartition(copy.(A.x))
331331

332+
function recursivecopy(A::ArrayPartition{
333+
T, S}) where {T, S <: Tuple{Vararg{AbstractVectorOfArray}}}
334+
return ArrayPartition(map(recursivecopy, A.x))
335+
end
336+
337+
function recursivecopy!(A::ArrayPartition{T, S},
338+
B::ArrayPartition{T, S}) where {T, S <: Tuple{Vararg{AbstractVectorOfArray}}}
339+
for i in eachindex(A.x, B.x)
340+
recursivecopy!(A.x[i], B.x[i])
341+
end
342+
return A
343+
end
344+
332345
recursive_mean(A::ArrayPartition) = mean((recursive_mean(x) for x in A.x))
333346

334347
# note: consider only first partition for recursive one and eltype
@@ -475,7 +488,7 @@ end
475488
## Linear Algebra
476489

477490
function ArrayInterface.zeromatrix(A::ArrayPartition)
478-
x = reduce(vcat,vec.(A.x))
491+
x = reduce(vcat, vec.(A.x))
479492
x .* x' .* false
480493
end
481494

test/partitions_test.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,23 @@ y = ArrayPartition(ArrayPartition([1], [2.0]), ArrayPartition([3], [4.0]))
132132
@inferred recursive_one(x)
133133
@inferred recursive_bottom_eltype(x)
134134

135+
src_voa = VectorOfArray([[1.0, 2.0], [3.0, 4.0]])
136+
src_ap = ArrayPartition(src_voa)
137+
138+
copied_ap = recursivecopy(src_ap)
139+
@test copied_ap.x[1].u[1] == src_ap.x[1].u[1]
140+
@test copied_ap.x[1].u[2] == src_ap.x[1].u[2]
141+
@test copied_ap.x[1].u[1] !== src_ap.x[1].u[1]
142+
@test copied_ap.x[1].u[2] !== src_ap.x[1].u[2]
143+
144+
dest_voa = VectorOfArray([zeros(2), zeros(2)])
145+
dest_ap = ArrayPartition(dest_voa)
146+
recursivecopy!(dest_ap, src_ap)
147+
@test dest_ap.x[1].u[1] == src_ap.x[1].u[1]
148+
@test dest_ap.x[1].u[2] == src_ap.x[1].u[2]
149+
@test dest_ap.x[1].u[1] !== src_ap.x[1].u[1]
150+
@test dest_ap.x[1].u[2] !== src_ap.x[1].u[2]
151+
135152
# mapreduce
136153
@inferred Union{Int, Float64} sum(x)
137154
@inferred sum(ArrayPartition(ArrayPartition(zeros(4, 4))))
@@ -149,7 +166,7 @@ y = ArrayPartition(ArrayPartition([1], [2.0]), ArrayPartition([3], [4.0]))
149166
@test any(isnan, ArrayPartition([2], [NaN]))
150167
@test any(isnan, ArrayPartition([2], ArrayPartition([NaN])))
151168

152-
# all
169+
# all
153170
@test !all(isnan, ArrayPartition([1, 2], [3.0, 4.0]))
154171
@test !all(isnan, ArrayPartition([3.0, 4.0]))
155172
@test !all(isnan, ArrayPartition([NaN], [3.0, 4.0]))

0 commit comments

Comments
 (0)