Skip to content

Commit cb1c2c6

Browse files
amontoisongdalle
andauthored
Implement TreeSet-based acyclic decompression (#55)
* Implement TreeSet-based acyclic decompression * Add comments in decompression.jl * Optimize acyclic decompression * Format * Second decompression test --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent 6fba67e commit cb1c2c6

2 files changed

Lines changed: 103 additions & 2 deletions

File tree

src/decompression.jl

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,102 @@ function decompress_aux!(
295295
end
296296
return A
297297
end
298+
299+
function decompress_aux!(
300+
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::TreeSetColoringResult
301+
) where {R<:Real}
302+
n = checksquare(A)
303+
A .= zero(R)
304+
S = get_matrix(result)
305+
color = column_colors(result)
306+
307+
# forest is a structure DisjointSets from DataStructures.jl
308+
# - forest.intmap: a dictionary that maps an edge (i, j) to an integer k
309+
# - forest.revmap: a dictionary that does the reverse of intmap, mapping an integer k to an edge (i, j)
310+
# - forest.internal.ngroups: the number of trees in the forest
311+
forest = result.tree_set.forest
312+
ntrees = forest.internal.ngroups
313+
314+
# vector of trees where each tree contains the indices of its edges
315+
trees = [Int[] for i in 1:ntrees]
316+
317+
# dictionary that maps a tree's root to the index of the tree
318+
roots = Dict{Int,Int}()
319+
320+
k = 0
321+
for edge in forest.revmap
322+
root_edge = find_root!(forest, edge)
323+
root = forest.intmap[root_edge]
324+
if !haskey(roots, root)
325+
k += 1
326+
roots[root] = k
327+
end
328+
index_tree = roots[root]
329+
push!(trees[index_tree], forest.intmap[edge])
330+
end
331+
332+
# vector of dictionaries where each dictionary stores the degree of each vertex in a tree
333+
degrees = [Dict{Int,Int}() for k in 1:ntrees]
334+
for k in 1:ntrees
335+
tree = trees[k]
336+
degree = degrees[k]
337+
for edge_index in tree
338+
i, j = forest.revmap[edge_index]
339+
!haskey(degree, i) && (degree[i] = 0)
340+
!haskey(degree, j) && (degree[j] = 0)
341+
degree[i] += 1
342+
degree[j] += 1
343+
end
344+
end
345+
346+
# depth-first search (DFS) traversal order for each tree in the forest
347+
dfs_orders = [Vector{Tuple{Int,Int}}() for k in 1:ntrees]
348+
for k in 1:ntrees
349+
tree = trees[k]
350+
degree = degrees[k]
351+
while sum(values(degree)) != 0
352+
for (t, edge_index) in enumerate(tree)
353+
if edge_index != 0
354+
i, j = forest.revmap[edge_index]
355+
if (degree[i] == 1) || (degree[j] == 1) # leaf vertex
356+
if degree[i] > degree[j] # vertex i is the parent of vertex j
357+
i, j = j, i # ensure that i always denotes a leaf vertex
358+
end
359+
degree[i] -= 1 # decrease the degree of vertex i
360+
degree[j] -= 1 # decrease the degree of vertex j
361+
tree[t] = 0 # remove the edge (i,j)
362+
push!(dfs_orders[k], (i, j))
363+
end
364+
end
365+
end
366+
end
367+
end
368+
369+
# stored_values holds the sum of edge values for subtrees in a tree.
370+
# For each vertex i, stored_values[i] is the sum of edge values in the subtree rooted at i.
371+
stored_values = Vector{R}(undef, n)
372+
373+
# Recover the diagonal coefficients of A
374+
for i in axes(A, 1)
375+
if !iszero(S[i, i])
376+
A[i, i] = B[i, color[i]]
377+
end
378+
end
379+
380+
# Recover the off-diagonal coefficients of A
381+
for k in 1:ntrees
382+
vertices = keys(degrees[k])
383+
for vertex in vertices
384+
stored_values[vertex] = zero(R)
385+
end
386+
387+
tree = dfs_orders[k]
388+
for (i, j) in tree
389+
val = B[i, color[j]] - stored_values[i]
390+
stored_values[j] = stored_values[j] + val
391+
A[i, j] = val
392+
A[j, i] = val
393+
end
394+
end
395+
return A
396+
end

test/utils.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@ function test_coloring_decompression(
2222
B = compress(A, result)
2323
!isnothing(color0) && @test color == color0
2424
!isnothing(B0) && @test B == B0
25-
@test decompress(B, result) A0
2625
@test decompress(B, default_result) A0
27-
@test decompress!(respectful_similar(A), B, result) A0
26+
@test decompress(B, result) A0
27+
@test decompress(B, result) A0 # check result wasn't modified
2828
@test decompress!(respectful_similar(A), B, default_result) A0
29+
@test decompress!(respectful_similar(A), B, result) A0
30+
@test decompress!(respectful_similar(A), B, result) A0
2931
end
3032
@test all(color_vec .== Ref(color_vec[1]))
3133
end

0 commit comments

Comments
 (0)