Skip to content

Commit 3d6e41a

Browse files
authored
Move coloring soundness tests to DIT (#246)
1 parent 5aaed92 commit 3d6e41a

5 files changed

Lines changed: 80 additions & 57 deletions

File tree

DifferentiationInterface/src/sparse/coloring.jl

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -89,36 +89,6 @@ function distance2_row_coloring(g::BipartiteGraph)
8989
return colors
9090
end
9191

92-
function check_structurally_orthogonal_columns(
93-
A::AbstractMatrix, colors::AbstractVector{<:Integer}
94-
)
95-
for c in unique(colors)
96-
js = filter(j -> colors[j] == c, axes(A, 2))
97-
Ajs = @view A[:, js]
98-
nonzeros_per_row = count(!iszero, Ajs; dims=2)
99-
if maximum(nonzeros_per_row) > 1
100-
@warn "Color $c has columns $js sharing nonzeros"
101-
return false
102-
end
103-
end
104-
return true
105-
end
106-
107-
function check_structurally_orthogonal_rows(
108-
A::AbstractMatrix, colors::AbstractVector{<:Integer}
109-
)
110-
for c in unique(colors)
111-
is = filter(i -> colors[i] == c, axes(A, 1))
112-
Ais = @view A[is, :]
113-
nonzeros_per_column = count(!iszero, Ais; dims=1)
114-
if maximum(nonzeros_per_column) > 1
115-
@warn "Color $c has rows $is sharing nonzeros"
116-
return false
117-
end
118-
end
119-
return true
120-
end
121-
12292
## Hessian coloring
12393

12494
"""
@@ -189,30 +159,6 @@ function star_coloring(g::AdjacencyGraph)
189159
return colors
190160
end
191161

192-
function check_symmetrically_structurally_orthogonal(
193-
A::AbstractMatrix, colors::AbstractVector{<:Integer}
194-
)
195-
for i in axes(A, 2), j in axes(A, 2)
196-
if !iszero(A[i, j])
197-
group_i = filter(i2 -> (i2 != i) && (colors[i2] == colors[i]), axes(A, 2))
198-
group_j = filter(j2 -> (j2 != j) && (colors[j2] == colors[j]), axes(A, 2))
199-
A_group_i_column_j = @view A[group_i, j]
200-
A_group_j_column_i = @view A[group_j, i]
201-
nonzeros_group_i_column_j = count(!iszero, A_group_i_column_j)
202-
nonzeros_group_j_column_i = count(!iszero, A_group_j_column_i)
203-
if nonzeros_group_i_column_j > 0 && nonzeros_group_j_column_i > 0
204-
@warn """
205-
For coefficient $((i, j)), both of the following have confounding zeros:
206-
- color $(colors[j]) with group $group_j
207-
- color $(colors[i]) with group $group_i
208-
"""
209-
return false
210-
end
211-
end
212-
end
213-
return true
214-
end
215-
216162
## ADTypes overloads
217163

218164
"""
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using ADTypes: ADTypes
22
import DifferentiationInterface as DI
3+
import DifferentiationInterfaceTest as DIT
34
using LinearAlgebra: I, Symmetric
45
using SparseArrays: sprand
56

@@ -8,14 +9,14 @@ alg = DI.GreedyColoringAlgorithm()
89
A = sprand(Bool, 100, 200, 0.05)
910

1011
column_colors = ADTypes.column_coloring(A, alg)
11-
@test DI.check_structurally_orthogonal_columns(A, column_colors)
12+
@test DIT.check_structurally_orthogonal_columns(A, column_colors)
1213
@test maximum(column_colors) < size(A, 2) ÷ 2
1314

1415
row_colors = ADTypes.row_coloring(A, alg)
15-
@test DI.check_structurally_orthogonal_rows(A, row_colors)
16+
@test DIT.check_structurally_orthogonal_rows(A, row_colors)
1617
@test maximum(row_colors) < size(A, 1) ÷ 2
1718

1819
S = Symmetric(sprand(Bool, 100, 100, 0.05)) + I
1920
symmetric_colors = ADTypes.symmetric_coloring(S, alg)
20-
@test DI.check_symmetrically_structurally_orthogonal(S, symmetric_colors)
21+
@test DIT.check_symmetrically_structurally_orthogonal(S, symmetric_colors)
2122
@test maximum(symmetric_colors) < size(A, 2) ÷ 2

DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ include("utils/zero_backends.jl")
6161
include("utils/misc.jl")
6262
include("utils/filter.jl")
6363

64+
include("tests/coloring.jl")
6465
include("tests/correctness.jl")
6566
include("tests/type_stability.jl")
6667
include("tests/sparsity.jl")
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
function check_structurally_orthogonal_columns(
2+
A::AbstractMatrix, colors::AbstractVector{<:Integer}
3+
)
4+
for c in unique(colors)
5+
js = filter(j -> colors[j] == c, axes(A, 2))
6+
Ajs = @view A[:, js]
7+
nonzeros_per_row = count(!iszero, Ajs; dims=2)
8+
if maximum(nonzeros_per_row) > 1
9+
@warn "Color $c has columns $js sharing nonzeros"
10+
return false
11+
end
12+
end
13+
return true
14+
end
15+
16+
function check_structurally_orthogonal_rows(
17+
A::AbstractMatrix, colors::AbstractVector{<:Integer}
18+
)
19+
for c in unique(colors)
20+
is = filter(i -> colors[i] == c, axes(A, 1))
21+
Ais = @view A[is, :]
22+
nonzeros_per_column = count(!iszero, Ais; dims=1)
23+
if maximum(nonzeros_per_column) > 1
24+
@warn "Color $c has rows $is sharing nonzeros"
25+
return false
26+
end
27+
end
28+
return true
29+
end
30+
31+
function check_symmetrically_structurally_orthogonal(
32+
A::AbstractMatrix, colors::AbstractVector{<:Integer}
33+
)
34+
for i in axes(A, 2), j in axes(A, 2)
35+
if !iszero(A[i, j])
36+
group_i = filter(i2 -> (i2 != i) && (colors[i2] == colors[i]), axes(A, 2))
37+
group_j = filter(j2 -> (j2 != j) && (colors[j2] == colors[j]), axes(A, 2))
38+
A_group_i_column_j = @view A[group_i, j]
39+
A_group_j_column_i = @view A[group_j, i]
40+
nonzeros_group_i_column_j = count(!iszero, A_group_i_column_j)
41+
nonzeros_group_j_column_i = count(!iszero, A_group_j_column_i)
42+
if nonzeros_group_i_column_j > 0 && nonzeros_group_j_column_i > 0
43+
@warn """
44+
For coefficient $((i, j)), both of the following have confounding zeros:
45+
- color $(colors[j]) with group $group_j
46+
- color $(colors[i]) with group $group_i
47+
"""
48+
return false
49+
end
50+
end
51+
end
52+
return true
53+
end
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using ADTypes: ADTypes
2+
import DifferentiationInterface as DI
3+
import DifferentiationInterfaceTest as DIT
4+
using LinearAlgebra: I, Symmetric
5+
using SparseArrays: sprand
6+
7+
alg = DI.GreedyColoringAlgorithm()
8+
9+
A = sprand(Bool, 100, 200, 0.05)
10+
11+
column_colors = ADTypes.column_coloring(A, alg)
12+
@test DIT.check_structurally_orthogonal_columns(A, column_colors)
13+
@test maximum(column_colors) < size(A, 2) ÷ 2
14+
15+
row_colors = ADTypes.row_coloring(A, alg)
16+
@test DIT.check_structurally_orthogonal_rows(A, row_colors)
17+
@test maximum(row_colors) < size(A, 1) ÷ 2
18+
19+
S = Symmetric(sprand(Bool, 100, 100, 0.05)) + I
20+
symmetric_colors = ADTypes.symmetric_coloring(S, alg)
21+
@test DIT.check_symmetrically_structurally_orthogonal(S, symmetric_colors)
22+
@test maximum(symmetric_colors) < size(A, 2) ÷ 2

0 commit comments

Comments
 (0)