Skip to content

Commit 2c5d35c

Browse files
authored
Better sparsity pattern check (#9)
* Better sparsity pattern check * Coverage
1 parent 65fa107 commit 2c5d35c

2 files changed

Lines changed: 77 additions & 3 deletions

File tree

src/decompression.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,30 @@ function transpose_respecting_similar(A::Transpose, ::Type{T}) where {T}
44
return transpose(similar(parent(A), T))
55
end
66

7+
function same_sparsity_pattern(A::SparseMatrixCSC, B::SparseMatrixCSC)
8+
if size(A) != size(B)
9+
return false
10+
elseif nnz(A) != nnz(B)
11+
return false
12+
else
13+
for j in axes(A, 2)
14+
rA = nzrange(A, j)
15+
rB = nzrange(B, j)
16+
if rA != rB
17+
return false
18+
end
19+
# TODO: check rowvals?
20+
end
21+
return true
22+
end
23+
end
24+
25+
function same_sparsity_pattern(
26+
A::Transpose{<:Any,<:SparseMatrixCSC}, B::Transpose{<:Any,<:SparseMatrixCSC}
27+
)
28+
return same_sparsity_pattern(parent(A), parent(B))
29+
end
30+
731
"""
832
color_groups(colors)
933
@@ -60,7 +84,7 @@ function decompress_columns!(
6084
C::AbstractMatrix{R},
6185
colors::AbstractVector{<:Integer},
6286
) where {R<:Real}
63-
if nnz(parent(A)) != nnz(parent(S))
87+
if !same_sparsity_pattern(A, S)
6488
throw(DimensionMismatch("`A` and `S` must have the same sparsity pattern."))
6589
end
6690
Anz, Arv = nonzeros(A), rowvals(A)
@@ -133,7 +157,7 @@ function decompress_rows!(
133157
C::AbstractMatrix{R},
134158
colors::AbstractVector{<:Integer},
135159
) where {R<:Real}
136-
if nnz(parent(A)) != nnz(parent(S))
160+
if !same_sparsity_pattern(A, S)
137161
throw(DimensionMismatch("`A` and `S` must have the same sparsity pattern."))
138162
end
139163
PA = parent(A)

test/decompression_correctness.jl

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,55 @@
11
using ADTypes: column_coloring, row_coloring, symmetric_coloring
22
using Compat
3+
using LinearAlgebra
34
using SparseArrays
45
using SparseMatrixColorings
5-
using SparseMatrixColorings: color_groups, decompress_columns, decompress_rows
6+
using SparseMatrixColorings:
7+
color_groups,
8+
decompress_columns,
9+
decompress_columns!,
10+
decompress_rows,
11+
decompress_rows!,
12+
same_sparsity_pattern
613
using StableRNGs
714
using Test
815

916
rng = StableRNG(63)
1017

1118
algo = GreedyColoringAlgorithm()
1219

20+
@testset "Sparsity pattern comparison" begin
21+
A = [
22+
1 1
23+
0 1
24+
0 0
25+
]
26+
B1 = [
27+
1 1
28+
0 1
29+
1 0
30+
]
31+
B2 = [
32+
1 1
33+
0 0
34+
0 1
35+
]
36+
C = [
37+
1 1 0
38+
0 1 0
39+
0 0 0
40+
]
41+
42+
@test same_sparsity_pattern(sparse(A), sparse(A))
43+
@test !same_sparsity_pattern(sparse(A), sparse(B1))
44+
@test_broken !same_sparsity_pattern(sparse(A), sparse(B2))
45+
@test !same_sparsity_pattern(sparse(A), sparse(C))
46+
47+
@test same_sparsity_pattern(transpose(sparse(A)), transpose(sparse(A)))
48+
@test !same_sparsity_pattern(transpose(sparse(A)), transpose(sparse(B1)))
49+
@test_broken !same_sparsity_pattern(transpose(sparse(A)), transpose(sparse(B2)))
50+
@test !same_sparsity_pattern(transpose(sparse(A)), transpose(sparse(C)))
51+
end;
52+
1353
@testset "Column decompression" begin
1454
@testset "Small" begin
1555
A0 = [
@@ -33,6 +73,11 @@ algo = GreedyColoringAlgorithm()
3373
(sparse(A0), sparse(S0)),
3474
]
3575
@test decompress_columns(S, C, colors) == A
76+
if A isa SparseMatrixCSC
77+
@test_throws DimensionMismatch decompress_columns!(
78+
similar(A), false .* S, C, colors
79+
)
80+
end
3681
end
3782
end
3883
@testset "Medium" begin
@@ -76,6 +121,11 @@ end
76121
(transpose(sparse(transpose(A0))), transpose(sparse(transpose(S0)))),
77122
]
78123
@test decompress_rows(S, C, colors) == A
124+
if A isa Transpose{<:Any,<:SparseMatrixCSC}
125+
@test_throws DimensionMismatch decompress_rows!(
126+
transpose(similar(parent(A))), transpose(false .* parent(S)), C, colors
127+
)
128+
end
79129
end
80130
end
81131
@testset "Medium" begin

0 commit comments

Comments
 (0)