Skip to content

Commit 8501655

Browse files
authored
Implement single-color decompression (#81)
* Add single-color decompression * Encode undefined hubs as negative * Typos * Typo * Fix test utils
1 parent 7443b02 commit 8501655

11 files changed

Lines changed: 198 additions & 111 deletions

File tree

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ row_groups
3535
compress
3636
decompress
3737
decompress!
38+
decompress_single_color!
3839
```
3940

4041
## Orders

docs/src/dev.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ SparseMatrixColorings.symmetric_coefficient
2626
SparseMatrixColorings.star_coloring
2727
SparseMatrixColorings.acyclic_coloring
2828
SparseMatrixColorings.group_by_color
29-
SparseMatrixColorings.get_matrix
3029
SparseMatrixColorings.StarSet
3130
SparseMatrixColorings.TreeSet
3231
```
3332

3433
## Concrete coloring results
3534

3635
```@docs
37-
SparseMatrixColorings.NonSymmetricColoringResult
36+
SparseMatrixColorings.ColumnColoringResult
37+
SparseMatrixColorings.RowColoringResult
3838
SparseMatrixColorings.StarSetColoringResult
3939
SparseMatrixColorings.TreeSetColoringResult
4040
SparseMatrixColorings.LinearSystemColoringResult

src/SparseMatrixColorings.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,6 @@ export ColoringProblem, GreedyColoringAlgorithm, AbstractColoringResult
5454
export coloring
5555
export column_colors, row_colors
5656
export column_groups, row_groups
57-
export compress, decompress, decompress!
57+
export compress, decompress, decompress!, decompress_single_color!
5858

5959
end

src/coloring.jl

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,23 @@ $TYPEDFIELDS
131131
struct StarSet
132132
"a mapping from edges (pair of vertices) to their star index"
133133
star::Dict{Tuple{Int,Int},Int}
134-
"a mapping from star indices to their hub (the hub is `0` if the star only contains one edge)"
134+
"a mapping from star indices to their hub (undefined hubs for single-edge stars are the negative value of one of the vertices, picked arbitrarily)"
135135
hub::Vector{Int}
136+
"a mapping from star indices to the vector of their spokes"
137+
spokes::Vector{Vector{Int}}
138+
end
139+
140+
function StarSet(star, hub)
141+
spokes = [Int[] for s in eachindex(hub)]
142+
for ((i, j), s) in pairs(star)
143+
h = hub[s]
144+
if i == abs(h)
145+
push!(spokes[s], j)
146+
elseif j == abs(h)
147+
push!(spokes[s], i)
148+
end
149+
end
150+
return StarSet(star, hub, spokes)
136151
end
137152

138153
_sort(u, v) = (min(u, v), max(u, v))
@@ -185,7 +200,7 @@ function _update_stars!(
185200
hub[star[vq]] = v # this may already be true
186201
star[vw] = star[vq]
187202
else # vw forms a new star
188-
push!(hub, 0) # hub is yet undefined
203+
push!(hub, -max(v, w)) # hub is undefined so we set it to a negative value, but it allows us to remember one of the two vertices
189204
star[vw] = length(hub)
190205
end
191206
end
@@ -194,13 +209,6 @@ function _update_stars!(
194209
end
195210

196211
"""
197-
symmetric_coefficient(
198-
i::Integer, j::Integer,
199-
color::AbstractVector{<:Integer},
200-
group::AbstractVector{<:AbstractVector{<:Integer}},
201-
S::AbstractMatrix{Bool}
202-
)
203-
204212
symmetric_coefficient(
205213
i::Integer, j::Integer,
206214
color::AbstractVector{<:Integer},
@@ -209,30 +217,12 @@ end
209217
210218
Return the indices `(k, c)` such that `A[i, j] = B[k, c]`, where `A` is the uncompressed symmetric matrix and `B` is the column-compressed matrix.
211219
212-
The first version corresponds to algorithm `DirectRecover1` in the paper, the second to `DirectRecover2`.
220+
This function corresponds to algorithm `DirectRecover2` in the paper.
213221
214222
# References
215223
216-
> [_Efficient Computation of Sparse Hessians Using Coloring and Automatic Differentiation_](https://pubsonline.informs.org/doi/abs/10.1287/ijoc.1080.0286), Gebremedhin et al. (2009), Figures 2 and 3
224+
> [_Efficient Computation of Sparse Hessians Using Coloring and Automatic Differentiation_](https://pubsonline.informs.org/doi/abs/10.1287/ijoc.1080.0286), Gebremedhin et al. (2009), Figure 3
217225
"""
218-
function symmetric_coefficient end
219-
220-
function symmetric_coefficient(
221-
i::Integer,
222-
j::Integer,
223-
color::AbstractVector{<:Integer},
224-
group::AbstractVector{<:AbstractVector{<:Integer}},
225-
S::AbstractMatrix,
226-
)
227-
for j2 in group[color[j]]
228-
j2 == j && continue
229-
if !iszero(S[i, j2])
230-
return j, color[i]
231-
end
232-
end
233-
return i, color[j]
234-
end
235-
236226
function symmetric_coefficient(
237227
i::Integer, j::Integer, color::AbstractVector{<:Integer}, star_set::StarSet
238228
)
@@ -246,11 +236,7 @@ function symmetric_coefficient(
246236
i, j = j, i
247237
end
248238
star_id = star[i, j]
249-
h = hub[star_id]
250-
if h == 0
251-
# pick arbitrary hub
252-
h = i
253-
end
239+
h = abs(hub[star_id])
254240
if h == j
255241
# i is the spoke
256242
return i, color[h]

src/decompression.jl

Lines changed: 116 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ true
116116
- [`AbstractColoringResult`](@ref)
117117
"""
118118
function decompress(B::AbstractMatrix{R}, result::AbstractColoringResult) where {R<:Real}
119-
S = get_matrix(result)
119+
@compat (; S) = result
120120
A = respectful_similar(S, R)
121121
return decompress!(A, B, result)
122122
end
@@ -183,33 +183,84 @@ true
183183
"""
184184
function decompress! end
185185

186-
## NonSymmetricColoringResult
186+
"""
187+
decompress_single_color!(
188+
A::AbstractMatrix, b::AbstractVector, c::Integer,
189+
result::AbstractColoringResult,
190+
)
191+
192+
Decompress the vector `b` corresponding to color `c` in-place into `A`, given a coloring `result` of the sparsity pattern of `A`.
193+
194+
- If `result` comes from a `:nonsymmetric` structure with `:column` partition, this will update the columns of `A` that share color `c` (whose sum makes up `b`).
195+
- If `result` comes from a `:nonsymmetric` structure with `:row` partition, this will update the rows of `A` that share color `c` (whose sum makes up `b`).
196+
- If `result` comes from a `:symmetric` structure with `:column` partition, this will update the coefficients of `A` whose value is deduced from color `c`.
197+
198+
!!! warning
199+
This function will only update some coefficients of `A`, without resetting the rest to zero.
200+
201+
# See also
202+
203+
- [`ColoringProblem`](@ref)
204+
- [`AbstractColoringResult`](@ref)
205+
- [`decompress!`](@ref)
206+
"""
207+
function decompress_single_color! end
208+
209+
## ColumnColoringResult
187210

188211
function decompress!(
189-
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::NonSymmetricColoringResult{:column}
212+
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::ColumnColoringResult
190213
) where {R<:Real}
191-
S = get_matrix(result)
214+
@compat (; S, color) = result
192215
check_same_pattern(A, S)
193216
A .= zero(R)
194-
color = column_colors(result)
195217
rvS = rowvals(S)
196218
for j in axes(S, 2)
219+
cj = color[j]
197220
for k in nzrange(S, j)
198221
i = rvS[k]
199-
cj = color[j]
200222
A[i, j] = B[i, cj]
201223
end
202224
end
203225
return A
204226
end
205227

228+
function decompress_single_color!(
229+
A::AbstractMatrix{R}, b::AbstractVector{R}, c::Integer, result::ColumnColoringResult
230+
) where {R<:Real}
231+
@compat (; S, group) = result
232+
check_same_pattern(A, S)
233+
view(A, :, group[c]) .= zero(R)
234+
rvS = rowvals(S)
235+
for j in group[c]
236+
for k in nzrange(S, j)
237+
i = rvS[k]
238+
A[i, j] = b[i]
239+
end
240+
end
241+
return A
242+
end
243+
206244
function decompress!(
207-
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::NonSymmetricColoringResult{:row}
245+
A::SparseMatrixCSC{R}, B::AbstractMatrix{R}, result::ColumnColoringResult
208246
) where {R<:Real}
209-
S = get_matrix(result)
247+
@compat (; S, compressed_indices) = result
248+
check_same_pattern(A, S)
249+
nzA = nonzeros(A)
250+
for k in eachindex(nzA, compressed_indices)
251+
nzA[k] = B[compressed_indices[k]]
252+
end
253+
return A
254+
end
255+
256+
## RowColoringResult
257+
258+
function decompress!(
259+
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::RowColoringResult
260+
) where {R<:Real}
261+
@compat (; S, color) = result
210262
check_same_pattern(A, S)
211263
A .= zero(R)
212-
color = row_colors(result)
213264
rvS = rowvals(S)
214265
for j in axes(S, 2)
215266
for k in nzrange(S, j)
@@ -221,28 +272,30 @@ function decompress!(
221272
return A
222273
end
223274

224-
function decompress!(
225-
A::SparseMatrixCSC{R}, B::AbstractMatrix{R}, result::NonSymmetricColoringResult{:column}
275+
function decompress_single_color!(
276+
A::AbstractMatrix{R}, b::AbstractVector{R}, c::Integer, result::RowColoringResult
226277
) where {R<:Real}
227-
S = get_matrix(result)
278+
@compat (; S, Sᵀ, group) = result
228279
check_same_pattern(A, S)
229-
nzA = nonzeros(A)
230-
ind = result.compressed_indices
231-
for i in eachindex(nzA, ind)
232-
nzA[i] = B[ind[i]]
280+
view(A, group[c], :) .= zero(R)
281+
rvSᵀ = rowvals(Sᵀ)
282+
for i in group[c]
283+
for k in nzrange(Sᵀ, i)
284+
j = rvSᵀ[k]
285+
A[i, j] = b[j]
286+
end
233287
end
234288
return A
235289
end
236290

237291
function decompress!(
238-
A::SparseMatrixCSC{R}, B::AbstractMatrix{R}, result::NonSymmetricColoringResult{:row}
292+
A::SparseMatrixCSC{R}, B::AbstractMatrix{R}, result::RowColoringResult
239293
) where {R<:Real}
240-
S = get_matrix(result)
294+
@compat (; S, compressed_indices) = result
241295
check_same_pattern(A, S)
242296
nzA = nonzeros(A)
243-
ind = result.compressed_indices
244-
for i in eachindex(nzA, ind)
245-
nzA[i] = B[ind[i]]
297+
for k in eachindex(nzA, compressed_indices)
298+
nzA[k] = B[compressed_indices[k]]
246299
end
247300
return A
248301
end
@@ -252,16 +305,43 @@ end
252305
function decompress!(
253306
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::StarSetColoringResult
254307
) where {R<:Real}
255-
S = get_matrix(result)
308+
@compat (; S, color, star_set) = result
309+
@compat (; star, hub, spokes) = star_set
256310
check_same_pattern(A, S)
257311
A .= zero(R)
258-
color = column_colors(result)
259-
rvS = rowvals(S)
260-
for j in axes(S, 2)
261-
for k in nzrange(S, j)
262-
i = rvS[k]
263-
k, c = symmetric_coefficient(i, j, color, result.star_set)
264-
A[i, j] = B[k, c]
312+
for i in axes(A, 1)
313+
if !iszero(S[i, i])
314+
A[i, i] = B[i, color[i]]
315+
end
316+
end
317+
for s in eachindex(hub, spokes)
318+
j = abs(hub[s])
319+
for i in spokes[s]
320+
A[i, j] = B[i, color[j]]
321+
A[j, i] = B[i, color[j]]
322+
end
323+
end
324+
return A
325+
end
326+
327+
function decompress_single_color!(
328+
A::AbstractMatrix{R}, b::AbstractVector{R}, c::Integer, result::StarSetColoringResult
329+
) where {R<:Real}
330+
@compat (; S, color, group, star_set) = result
331+
@compat (; hub, spokes) = star_set
332+
check_same_pattern(A, S)
333+
for i in axes(A, 1)
334+
if !iszero(S[i, i]) && color[i] == c
335+
A[i, i] = b[i]
336+
end
337+
end
338+
for s in eachindex(hub, spokes)
339+
j = abs(hub[s])
340+
if color[j] == c
341+
for i in spokes[s]
342+
A[i, j] = b[i]
343+
A[j, i] = b[i]
344+
end
265345
end
266346
end
267347
return A
@@ -270,12 +350,11 @@ end
270350
function decompress!(
271351
A::SparseMatrixCSC{R}, B::AbstractMatrix{R}, result::StarSetColoringResult
272352
) where {R<:Real}
273-
S = get_matrix(result)
353+
@compat (; S, compressed_indices) = result
274354
check_same_pattern(A, S)
275355
nzA = nonzeros(A)
276-
ind = result.compressed_indices
277-
for i in eachindex(nzA, ind)
278-
nzA[i] = B[ind[i]]
356+
for k in eachindex(nzA, compressed_indices)
357+
nzA[k] = B[compressed_indices[k]]
279358
end
280359
return A
281360
end
@@ -287,11 +366,9 @@ end
287366
function decompress!(
288367
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::TreeSetColoringResult
289368
) where {R<:Real}
290-
S = get_matrix(result)
369+
@compat (; S, color, vertices_by_tree, reverse_bfs_orders, buffer) = result
291370
check_same_pattern(A, S)
292371
A .= zero(R)
293-
color = column_colors(result)
294-
@compat (; vertices_by_tree, reverse_bfs_orders, buffer) = result
295372

296373
if eltype(buffer) == R
297374
buffer_right_type = buffer
@@ -327,10 +404,10 @@ end
327404
function decompress!(
328405
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::LinearSystemColoringResult
329406
) where {R<:Real}
330-
S = get_matrix(result)
407+
@compat (;
408+
S, color, strict_upper_nonzero_inds, T_factorization, strict_upper_nonzeros_A
409+
) = result
331410
check_same_pattern(A, S)
332-
color = column_colors(result)
333-
@compat (; strict_upper_nonzero_inds, T_factorization, strict_upper_nonzeros_A) = result
334411

335412
# TODO: for some reason I cannot use ldiv! with a sparse QR
336413
strict_upper_nonzeros_A = T_factorization \ vec(B)

src/interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ function coloring(
178178
S = sparse(A)
179179
bg = bipartite_graph(S)
180180
color = partial_distance2_coloring(bg, Val(2), algo.order)
181-
return NonSymmetricColoringResult{:column}(S, color)
181+
return ColumnColoringResult(S, color)
182182
end
183183

184184
function coloring(
@@ -190,7 +190,7 @@ function coloring(
190190
S = sparse(A)
191191
bg = bipartite_graph(S)
192192
color = partial_distance2_coloring(bg, Val(1), algo.order)
193-
return NonSymmetricColoringResult{:row}(S, color)
193+
return RowColoringResult(S, color)
194194
end
195195

196196
function coloring(

0 commit comments

Comments
 (0)