We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ec2ee49 commit 4425ce9Copy full SHA for 4425ce9
3 files changed
src/array_partition.jl
@@ -579,3 +579,7 @@ end
579
end
580
return sum_expr
581
582
+
583
+function Adapt.adapt_structure(to, ap::ArrayPartition)
584
+ ArrayPartition(map(x -> Adapt.adapt(to, x), ap.x)...)
585
+end
test/gpu/arraypartition_gpu.jl
@@ -1,4 +1,4 @@
1
-using RecursiveArrayTools, CUDA, Test
+using RecursiveArrayTools, CUDA, Test, Adapt
2
CUDA.allowscalar(false)
3
4
# Test indexing with colon
@@ -21,3 +21,23 @@ fill!(pA, false)
21
a = ArrayPartition(([1.0f0] |> cu, [2.0f0] |> cu, [3.0f0] |> cu))
22
b = ArrayPartition(([0.0f0] |> cu, [0.0f0] |> cu, [0.0f0] |> cu))
23
@. a + b
24
25
+# Test adapt from ArrayPartition with CuArrays to ArrayPartition with CPU arrays
26
27
+a = CuArray(Float64.([1., 2., 3., 4.]))
28
+b = CuArray(Float64.([1., 2., 3., 4.]))
29
+part_a_gpu = ArrayPartition(a, b)
30
+part_a = adapt(Array{Float32}, part_a_gpu)
31
32
+c = Float32.([1., 2., 3., 4.])
33
+d = Float32.([1., 2., 3., 4.])
34
+part_b = ArrayPartition(c, d)
35
36
+@test part_a == part_b # Test equality
37
38
+for i in 1:length(part_a.x)
39
+ sub_a = part_a.x[i]
40
+ sub_b = part_b.x[i]
41
+ @test sub_a == sub_b # Test for value equality in sub-arrays
42
+ @test typeof(sub_a) === typeof(sub_b) # Test type equality
43
test/partitions_test.jl
-using RecursiveArrayTools, Test, Statistics, ArrayInterface
+using RecursiveArrayTools, Test, Statistics, ArrayInterface, Adapt
@test length(ArrayPartition()) == 0
@test isempty(ArrayPartition())
@@ -306,3 +306,22 @@ end
306
copyto!(u, ArrayPartition(1.0, -1.2))
307
@test u == [1.0, -1.2]
308
309
310
+# Test adapt on ArrayPartition from Float64 to Float32 arrays
311
+a = Float64.([1., 2., 3., 4.])
312
+b = Float64.([1., 2., 3., 4.])
313
+part_a_64 = ArrayPartition(a, b)
314
+part_a = adapt(Array{Float32}, part_a_64)
315
316
317
318
319
320
+@test part_a == part_b # Test equality of partitions
321
322
323
324
325
+ @test sub_a == sub_b # Test for value equality
326
327
0 commit comments