Skip to content

Commit c8c59cc

Browse files
authored
Implement StarSetColoringResult and TreeSetColoringResult (#57)
* StarSet-based decompression * Add result types to doc * Add tests of standard and star/tree-set decompression * Remove useless getters * Docstrings
1 parent d774220 commit c8c59cc

6 files changed

Lines changed: 94 additions & 21 deletions

File tree

docs/src/dev.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ SparseMatrixColorings.TreeSet
3535

3636
```@docs
3737
SparseMatrixColorings.DefaultColoringResult
38+
SparseMatrixColorings.StarSetColoringResult
39+
SparseMatrixColorings.TreeSetColoringResult
3840
SparseMatrixColorings.DirectSparseColoringResult
3941
```
4042

src/decompression.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,20 @@ function decompress_aux!(
9292
return A
9393
end
9494

95+
function decompress_aux!(
96+
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::StarSetColoringResult{:column}
97+
) where {R<:Real}
98+
A .= zero(R)
99+
S = get_matrix(result)
100+
color = column_colors(result)
101+
for ij in findall(!iszero, S)
102+
i, j = Tuple(ij)
103+
k, l = symmetric_coefficient(i, j, color, result.star_set)
104+
A[i, j] = B[k, l]
105+
end
106+
return A
107+
end
108+
95109
function decompress_aux!(
96110
A::AbstractMatrix{R},
97111
B::AbstractMatrix{R},

src/interface.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,7 @@ function coloring(
150150
)
151151
ag = adjacency_graph(S)
152152
color, star_set = star_coloring(ag, algo.order)
153-
# TODO: handle star_set
154-
return DefaultColoringResult{:symmetric,:column,:direct}(S, color)
153+
return StarSetColoringResult{:column}(S, color, star_set)
155154
end
156155

157156
function coloring(
@@ -161,8 +160,7 @@ function coloring(
161160
)
162161
ag = adjacency_graph(S)
163162
color, tree_set = acyclic_coloring(ag, algo.order)
164-
# TODO: handle tree_set
165-
return DefaultColoringResult{:symmetric,:column,:substitution}(S, color)
163+
return TreeSetColoringResult{:column}(S, color, tree_set)
166164
end
167165

168166
## ADTypes interface

src/result.jl

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ function group_by_color(color::AbstractVector{<:Integer})
7070
return group
7171
end
7272

73+
get_matrix(result::AbstractColoringResult) = result.matrix
74+
75+
column_colors(result::AbstractColoringResult{s,:column}) where {s} = result.color
76+
column_groups(result::AbstractColoringResult{s,:column}) where {s} = result.group
77+
78+
row_colors(result::AbstractColoringResult{s,:row}) where {s} = result.color
79+
row_groups(result::AbstractColoringResult{s,:row}) where {s} = result.group
80+
7381
## Concrete subtypes
7482

7583
"""
@@ -103,10 +111,64 @@ function DefaultColoringResult{structure,partition,decompression}(
103111
)
104112
end
105113

106-
get_matrix(result::DefaultColoringResult) = result.matrix
114+
"""
115+
$TYPEDEF
116+
117+
Storage for the result of a symmetric coloring algorithm with direct decompression.
118+
119+
Similar to [`DefaultColoringResult`](@ref) but contains an additional [`StarSet`](@ref).
120+
121+
# Fields
122+
123+
$TYPEDFIELDS
124+
125+
# See also
126+
127+
- [`AbstractColoringResult`](@ref)
128+
"""
129+
struct StarSetColoringResult{partition,M} <:
130+
AbstractColoringResult{:symmetric,partition,:direct,M}
131+
matrix::M
132+
color::Vector{Int}
133+
group::Vector{Vector{Int}}
134+
star_set::StarSet
135+
end
136+
137+
function StarSetColoringResult{partition}(
138+
matrix::M, color::Vector{Int}, star_set::StarSet
139+
) where {partition,M}
140+
return StarSetColoringResult{partition,M}(
141+
matrix, color, group_by_color(color), star_set
142+
)
143+
end
144+
145+
"""
146+
$TYPEDEF
147+
148+
Storage for the result of a symmetric coloring algorithm with decompression by substitution.
107149
108-
column_colors(result::DefaultColoringResult{s,:column}) where {s} = result.color
109-
column_groups(result::DefaultColoringResult{s,:column}) where {s} = result.group
150+
Similar to [`DefaultColoringResult`](@ref) but contains an additional [`TreeSet`](@ref).
110151
111-
row_colors(result::DefaultColoringResult{s,:row}) where {s} = result.color
112-
row_groups(result::DefaultColoringResult{s,:row}) where {s} = result.group
152+
# Fields
153+
154+
$TYPEDFIELDS
155+
156+
# See also
157+
158+
- [`AbstractColoringResult`](@ref)
159+
"""
160+
struct TreeSetColoringResult{partition,M} <:
161+
AbstractColoringResult{:symmetric,partition,:substitution,M}
162+
matrix::M
163+
color::Vector{Int}
164+
group::Vector{Vector{Int}}
165+
tree_set::TreeSet
166+
end
167+
168+
function TreeSetColoringResult{partition}(
169+
matrix::M, color::Vector{Int}, tree_set::TreeSet
170+
) where {partition,M}
171+
return TreeSetColoringResult{partition,M}(
172+
matrix, color, group_by_color(color), tree_set
173+
)
174+
end

src/sparsematrixcsc.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,6 @@ function DirectSparseColoringResult{structure,partition}(
3333
)
3434
end
3535

36-
get_matrix(result::DirectSparseColoringResult) = result.matrix
37-
38-
column_colors(result::DirectSparseColoringResult{s,:column}) where {s} = result.color
39-
column_groups(result::DirectSparseColoringResult{s,:column}) where {s} = result.group
40-
41-
row_colors(result::DirectSparseColoringResult{s,:row}) where {s} = result.color
42-
row_groups(result::DirectSparseColoringResult{s,:row}) where {s} = result.group
43-
4436
## Coloring
4537

4638
function coloring(

test/small.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,22 +77,24 @@ end;
7777
@testset "Substitution - Fig 6.1 from 'What color is your Jacobian'" begin
7878
example = what_fig_61()
7979
A0, B0, color0 = example.A, example.B, example.color
80+
result0 = DefaultColoringResult{:symmetric,:column,:substitution}(A0, color0)
8081
result = coloring(
8182
A0,
8283
ColoringProblem(;
8384
structure=:symmetric, partition=:column, decompression=:substitution
8485
),
8586
GreedyColoringAlgorithm(),
86-
)
87-
color = column_colors(result)
87+
) # returns a TreeSetColoringResult
8888
group = column_groups(result)
8989
B = stack(group; dims=2) do g
9090
dropdims(sum(A0[:, g]; dims=2); dims=2)
9191
end
92-
@test color != color0
92+
@test column_colors(result) != color0
9393
@test B != B0
9494
@test decompress(B, result) A0
95+
@test decompress(B0, result0) A0
9596
for A in matrix_versions(A0)
97+
@test decompress!(respectful_similar(A), B0, result0) A
9698
@test decompress!(respectful_similar(A), B, result) A
9799
end
98100
end
@@ -112,16 +114,19 @@ end;
112114
@testset "Substitution - Fig 4 from 'Efficient computation of sparse hessians using coloring and AD'" begin
113115
example = efficient_fig_4()
114116
A0, B0, color0 = example.A, example.B, example.color
117+
result0 = DefaultColoringResult{:symmetric,:column,:substitution}(A0, color0)
115118
result = coloring(
116119
A0,
117120
ColoringProblem(;
118121
structure=:symmetric, partition=:column, decompression=:substitution
119122
),
120123
GreedyColoringAlgorithm(),
121-
)
124+
) # returns a TreeSetColoringResult
122125
@test column_colors(result) == color0
126+
@test decompress(B0, result0) A0
123127
@test decompress(B0, result) A0
124128
for A in matrix_versions(A0)
129+
@test decompress!(respectful_similar(A), B0, result0) A
125130
@test decompress!(respectful_similar(A), B0, result) A
126131
end
127132
end

0 commit comments

Comments
 (0)