Skip to content

Commit c3acf52

Browse files
amontoisongdalle
andauthored
More optimization for the decompression of the acyclic coloring (#73)
* More optimization for the decompression of the acyclic coloring * Optimal version! * Reduce dynamic allocations * Avoid modifying `TreeSet` a posteriori * Typo * Renaming * Typo --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent b6c38d9 commit c3acf52

3 files changed

Lines changed: 91 additions & 54 deletions

File tree

src/coloring.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,12 @@ function acyclic_coloring(g::Graph, order::AbstractOrder)
328328
end
329329
end
330330
end
331+
332+
# compress forest
333+
for edge in forest.revmap
334+
find_root!(forest, edge)
335+
end
336+
331337
return color, TreeSet(forest)
332338
end
333339

@@ -343,8 +349,8 @@ function _prevent_cycle!(
343349
forest::DisjointSets{<:Tuple{Int,Int}},
344350
)
345351
wx = _sort(w, x)
346-
root = find_root!(forest, wx) # edge wx belongs to the 2-colored tree represented by edge "root"
347-
id = forest.intmap[root] # ID of the representative edge "root" of a two-colored tree.
352+
root = find_root!(forest, wx) # edge wx belongs to the 2-colored tree T represented by edge "root"
353+
id = forest.intmap[root] # ID of the representative edge "root" of a two-colored tree T.
348354
(p, q) = first_visit_to_tree[id]
349355
if p != v # T is being visited from vertex v for the first time
350356
vw = _sort(v, w)

src/decompression.jl

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -288,14 +288,12 @@ function decompress_aux!(
288288
A .= zero(R)
289289
S = get_matrix(result)
290290
color = column_colors(result)
291-
@compat (; degrees, dfs_orders, stored_values) = result
291+
@compat (; vertices_by_tree, reverse_bfs_orders, buffer) = result
292292

293-
# stored_values holds the sum of edge values for subtrees in a tree.
294-
# For each vertex i, stored_values[i] is the sum of edge values in the subtree rooted at i.
295-
stored_values_right_type = if R == Float64
296-
stored_values
293+
if eltype(buffer) == R
294+
buffer_right_type = buffer
297295
else
298-
similar(stored_values, R)
296+
buffer_right_type = similar(buffer, R)
299297
end
300298

301299
# Recover the diagonal coefficients of A
@@ -306,16 +304,14 @@ function decompress_aux!(
306304
end
307305

308306
# Recover the off-diagonal coefficients of A
309-
for k in eachindex(degrees, dfs_orders)
310-
vertices = keys(degrees[k])
311-
for vertex in vertices
312-
stored_values_right_type[vertex] = zero(R)
307+
for k in eachindex(vertices_by_tree, reverse_bfs_orders)
308+
for vertex in vertices_by_tree[k]
309+
buffer_right_type[vertex] = zero(R)
313310
end
314311

315-
tree = dfs_orders[k]
316-
for (i, j) in tree
317-
val = B[i, color[j]] - stored_values_right_type[i]
318-
stored_values_right_type[j] = stored_values_right_type[j] + val
312+
for (i, j) in reverse_bfs_orders[k]
313+
val = B[i, color[j]] - buffer_right_type[i]
314+
buffer_right_type[j] = buffer_right_type[j] + val
319315
A[i, j] = val
320316
A[j, i] = val
321317
end

src/result.jl

Lines changed: 73 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,15 @@ struct TreeSetColoringResult{M,R} <:
198198
S::M
199199
color::Vector{Int}
200200
group::Vector{Vector{Int}}
201-
tree_set::TreeSet
202-
degrees::Vector{Dict{Int,Int}}
203-
dfs_orders::Vector{Vector{Tuple{Int,Int}}}
204-
stored_values::Vector{R}
201+
vertices_by_tree::Vector{Vector{Int}}
202+
reverse_bfs_orders::Vector{Vector{Tuple{Int,Int}}}
203+
buffer::Vector{R}
205204
end
206205

207206
function TreeSetColoringResult(
208207
S::SparseMatrixCSC, color::Vector{Int}, tree_set::TreeSet, decompression_eltype::Type{R}
209208
) where {R}
209+
nvertices = length(color)
210210
group = group_by_color(color)
211211

212212
# forest is a structure DisjointSets from DataStructures.jl
@@ -216,66 +216,101 @@ function TreeSetColoringResult(
216216
forest = tree_set.forest
217217
ntrees = forest.internal.ngroups
218218

219-
# vector of trees where each tree contains the indices of its edges
220-
trees = [Int[] for i in 1:ntrees]
221-
222219
# dictionary that maps a tree's root to the index of the tree
223220
roots = Dict{Int,Int}()
224221

222+
# vector of dictionaries where each dictionary stores the neighbors of each vertex in a tree
223+
trees = [Dict{Int,Vector{Int}}() for i in 1:ntrees]
224+
225+
# counter of the number of roots found
225226
k = 0
226227
for edge in forest.revmap
228+
i, j = edge
229+
# forest has already been compressed so this doesn't change its state
230+
# I wanted to use find_root but it is deprecated
227231
root_edge = find_root!(forest, edge)
228232
root = forest.intmap[root_edge]
233+
234+
# Update roots
229235
if !haskey(roots, root)
230236
k += 1
231237
roots[root] = k
232238
end
239+
240+
# index of the tree T that contains this edge
233241
index_tree = roots[root]
234-
push!(trees[index_tree], forest.intmap[edge])
235-
end
236242

237-
# vector of dictionaries where each dictionary stores the degree of each vertex in a tree
238-
degrees = [Dict{Int,Int}() for k in 1:ntrees]
239-
for k in 1:ntrees
240-
tree = trees[k]
241-
degree = degrees[k]
242-
for edge_index in tree
243-
i, j = forest.revmap[edge_index]
244-
!haskey(degree, i) && (degree[i] = 0)
245-
!haskey(degree, j) && (degree[j] = 0)
246-
degree[i] += 1
247-
degree[j] += 1
243+
# Update the neighbors of i in the tree T
244+
if !haskey(trees[index_tree], i)
245+
trees[index_tree][i] = [j]
246+
else
247+
push!(trees[index_tree][i], j)
248+
end
249+
250+
# Update the neighbors of j in the tree T
251+
if !haskey(trees[index_tree], j)
252+
trees[index_tree][j] = [i]
253+
else
254+
push!(trees[index_tree][j], i)
248255
end
249256
end
250257

251-
# depth-first search (DFS) traversal order for each tree in the forest
252-
dfs_orders = [Vector{Tuple{Int,Int}}() for k in 1:ntrees]
258+
# degrees is a vector of integers that stores the degree of each vertex in a tree
259+
degrees = Vector{Int}(undef, nvertices)
260+
261+
# list of vertices for each tree in the forest
262+
vertices_by_tree = [collect(keys(trees[i])) for i in 1:ntrees]
263+
264+
# reverse breadth first (BFS) traversal order for each tree in the forest
265+
reverse_bfs_orders = [Tuple{Int,Int}[] for i in 1:ntrees]
266+
253267
for k in 1:ntrees
254268
tree = trees[k]
255-
degree = degrees[k]
256-
while sum(values(degree)) != 0
257-
for (t, edge_index) in enumerate(tree)
258-
if edge_index != 0
259-
i, j = forest.revmap[edge_index]
260-
if (degree[i] == 1) || (degree[j] == 1) # leaf vertex
261-
if degree[i] > degree[j] # vertex i is the parent of vertex j
262-
i, j = j, i # ensure that i always denotes a leaf vertex
263-
end
264-
degree[i] -= 1 # decrease the degree of vertex i
265-
degree[j] -= 1 # decrease the degree of vertex j
266-
tree[t] = 0 # remove the edge (i,j)
267-
push!(dfs_orders[k], (i, j))
269+
270+
# queue to store the leaves
271+
queue = Int[]
272+
273+
# compute the degree of each vertex in the tree
274+
for (vertex, neighbors) in trees[k]
275+
degree = length(neighbors)
276+
degrees[vertex] = degree
277+
278+
# the vertex is a leaf
279+
if degree == 1
280+
push!(queue, vertex)
281+
end
282+
end
283+
284+
# continue until all leaves are treated
285+
while !isempty(queue)
286+
leaf = pop!(queue)
287+
288+
# Convenient way to specify that the vertex is removed
289+
degrees[leaf] = 0
290+
291+
for neighbor in tree[leaf]
292+
if degrees[neighbor] != 0
293+
# (leaf, neighbor) represents the next edge to visit during decompression
294+
push!(reverse_bfs_orders[k], (leaf, neighbor))
295+
296+
# reduce the degree of all neighbors
297+
degrees[neighbor] -= 1
298+
299+
# check if the neighbor is now a leaf
300+
if degrees[neighbor] == 1
301+
push!(queue, neighbor)
268302
end
269303
end
270304
end
271305
end
272306
end
273307

274-
n = checksquare(S)
275-
stored_values = Vector{R}(undef, n)
308+
# buffer holds the sum of edge values for subtrees in a tree.
309+
# For each vertex i, buffer[i] is the sum of edge values in the subtree rooted at i.
310+
buffer = Vector{R}(undef, nvertices)
276311

277312
return TreeSetColoringResult(
278-
S, color, group, tree_set, degrees, dfs_orders, stored_values
313+
S, color, group, vertices_by_tree, reverse_bfs_orders, buffer
279314
)
280315
end
281316

0 commit comments

Comments
 (0)