Skip to content

Commit ce5819e

Browse files
committed
Add tests
1 parent a8a628d commit ce5819e

2 files changed

Lines changed: 23 additions & 1 deletion

File tree

DifferentiationInterface/test/Core/Internals/linalg.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
using DifferentiationInterface: recursive_similar, get_pattern
1+
using DifferentiationInterface: recursive_similar, get_pattern, arroftup_to_tupofarr
22
using SparseArrays
33
using Test
4+
using JLArrays, ComponentArrays
45

56
@testset "Recursive similar" begin
67
@test recursive_similar(ones(Int, 2), Float32) isa Vector{Float32}
@@ -16,3 +17,19 @@ end
1617
@test_broken get_pattern(D) == Diagonal(trues(10))
1718
@test get_pattern(sparse(D)) == Diagonal(trues(10))
1819
end
20+
21+
@testset "Wrong-mode array conversion" begin
22+
x = [1.0, 3.0, 5.0]
23+
xt = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]
24+
y = ComponentVector(a = [1.0, 3.0], b = [5.0])
25+
yt = ComponentVector(a = [(1.0, 2.0), (3.0, 4.0)], b = [(5.0, 6.0)])
26+
z = jl([1.0, 3.0, 5.0])
27+
zt = jl([(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)])
28+
@test arroftup_to_tupofarr((1.0, 2.0), 1.0) == (1.0, 2.0)
29+
@test arroftup_to_tupofarr(xt, x) == ([1.0, 3.0, 5.0], [2.0, 4.0, 6.0])
30+
@test arroftup_to_tupofarr(yt, y) == (ComponentVector(a = [1.0, 3.0], b = [5.0]), ComponentVector(a = [2.0, 4.0], b = [6.0]))
31+
@test arroftup_to_tupofarr(zt, z) == (jl([1.0, 3.0, 5.0]), jl([2.0, 4.0, 6.0]))
32+
@test arroftup_to_tupofarr(xt, x)[1] isa Vector
33+
@test arroftup_to_tupofarr(yt, y)[1] isa ComponentVector
34+
@test arroftup_to_tupofarr(zt, z)[1] isa JLVector
35+
end

DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,8 @@ end
169169
logging = LOGGING,
170170
)
171171
end;
172+
173+
@testset "Array format preservation in wrong mode" begin
174+
@test gradient(sum, AutoSimpleFiniteDiff(), jl(ones(2))) isa JLVector
175+
@test derivative(t -> jl(fill(t, 2)), AutoReverseFromPrimitive(AutoSimpleFiniteDiff()), 1.0) isa JLVector
176+
end

0 commit comments

Comments
 (0)