Skip to content

Commit 07b67e8

Browse files
authored
Optimal sparse decompression (#223)
* Optimal sparse decompression * Fix reverse mode dispatch with major-respecting similar
1 parent 09a7da0 commit 07b67e8

7 files changed

Lines changed: 67 additions & 57 deletions

File tree

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ using ADTypes:
3131
using DocStringExtensions
3232
using FillArrays: OneElement
3333
using LinearAlgebra: Symmetric, Transpose, dot, parent, transpose
34-
using SparseArrays: SparseMatrixCSC, nzrange, rowvals, sparse
34+
using SparseArrays: SparseMatrixCSC, nonzeros, nzrange, rowvals, sparse
3535

3636
abstract type Extras end
3737

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,16 @@
11
"""
22
CompressedMatrix{dir}
33
4-
Compressed representation `B` of a sparse matrix `A ∈ ℝ^{m×n}` obtained by summing some of its columns (if `dir == :col`) or rows (if `dir == :row`), grouped by color.
4+
Compressed representation `B` of a `(m, n)` sparse matrix `A` obtained by summing some of its columns (if `dir == :col`) or rows (if `dir == :row`) if they have the same color.
55
66
# Fields
77
8-
- `sparsity::AbstractMatrix{Bool}`: boolean sparsity pattern of the matrix `A`
9-
- `colors::Vector{Int}`: vector such that
10-
- if `dir == `:col`, then `colors[j] ∈ 1:c` is the color of column `j`
11-
- if `dir == `:row`, then `colors[i] ∈ 1:c` is the color of row `i`
12-
- `groups::Vector{Vector{Int}}`: vector of length `c` such that
13-
- if `dir == :col`, then `groups[k]` is the vector of column indices assigned to the same color `k ∈ 1:c`
14-
- if `dir == :row`, then `groups[k]` is the vector of row indices assigned to the same color `k ∈ 1:c`
15-
- `aggregates::AbstractMatrix`: matrix `B` such that
16-
- if `dir == :col`, then `size(B) = (m, c)` and `B[:, c] = sum(A[:, k] for k in groups[c])`
17-
- if `dir == :row`, then `size(B) = (c, n)` and `B[c, :] = sum(A[k, :] for k in groups[c])`
8+
| field | type | size | meaning | if `dir` is `:col` | if `dir` is `:row` |
9+
| :----------- | :----------------------- | :------------------- | :------------------------- | :-------------------------------- | :-------------------------------- |
10+
| `sparsity` | `AbstractMatrix{Bool}` | `(m, n)` | sparsity pattern | column-major | row-major |
11+
| `colors` | `Vector{Int}` | `n` or `m` | color assignments in `1:c` | `colors[j]` of col `j` | `colors[i]` of row `i` |
12+
| `groups` | `Vector{Vector{Int}}` | `c` | groups with same color | `groups[k] = {j : colors[j] = k}` | `groups[k] = {i : colors[i] = k}` |
13+
| `aggregates` | `AbstractMatrix{<:Real}` | `(m, c)` or `(c, n)` | color-summed values `B` | `B[:, c] = sum(A[:, groups[k]])` | `B[c, :] = sum(A[groups[k], :])` |
1814
"""
1915
mutable struct CompressedMatrix{dir,S<:AbstractMatrix{Bool},M<:AbstractMatrix}
2016
sparsity::S
@@ -35,39 +31,62 @@ function CompressedMatrix{dir}(sparsity, colors, groups, aggregates) where {dir}
3531
)
3632
end
3733

34+
## Column decompression
35+
3836
function decompress!(A::AbstractMatrix, compressed::CompressedMatrix{:col})
3937
(; sparsity, colors, aggregates) = compressed
4038
A .= zero(eltype(A))
4139
@views for j in axes(A, 2)
4240
k = colors[j]
43-
nz_rows_j = (!iszero).(sparsity[:, j])
44-
copyto!(A[nz_rows_j, j], aggregates[nz_rows_j, k])
41+
rows_j = (!iszero).(sparsity[:, j])
42+
copyto!(A[rows_j, j], aggregates[rows_j, k])
43+
end
44+
return A
45+
end
46+
47+
function decompress!(
48+
A::SparseMatrixCSC, compressed::CompressedMatrix{:col,<:SparseMatrixCSC}
49+
)
50+
# A and compressed.sparsity have the same pattern
51+
(; colors, aggregates) = compressed
52+
Anz, Arv = nonzeros(A), rowvals(A)
53+
Anz .= zero(eltype(A))
54+
@views for j in axes(A, 2)
55+
k = colors[j]
56+
nzrange_j = nzrange(A, j)
57+
rows_j = Arv[nzrange_j]
58+
copyto!(Anz[nzrange_j], aggregates[rows_j, k])
4559
end
4660
return A
4761
end
4862

63+
## Row decompression
64+
4965
function decompress!(A::AbstractMatrix, compressed::CompressedMatrix{:row})
5066
(; sparsity, colors, aggregates) = compressed
5167
A .= zero(eltype(A))
5268
@views for i in axes(A, 1)
5369
k = colors[i]
54-
nz_cols_i = (!iszero).(sparsity[i, :])
55-
copyto!(A[i, nz_cols_i], aggregates[k, nz_cols_i])
70+
cols_i = (!iszero).(sparsity[i, :])
71+
copyto!(A[i, cols_i], aggregates[k, cols_i])
5672
end
5773
return A
5874
end
5975

60-
function decompress_symmetric!(A::AbstractMatrix, compressed::CompressedMatrix{:col})
61-
(; sparsity, colors, groups, aggregates) = compressed
62-
@views for j in axes(A, 2)
63-
k = colors[j]
64-
group = groups[k]
65-
for i in axes(A, 1)
66-
if (!iszero(sparsity[i, j]) && count(!iszero, sparsity[i, group]) == 1)
67-
A[i, j] = aggregates[i, k]
68-
A[j, i] = aggregates[i, k]
69-
end
70-
end
76+
function decompress!(
77+
A::Transpose{<:Any,<:SparseMatrixCSC},
78+
compressed::CompressedMatrix{:row,<:Transpose{<:Any,<:SparseMatrixCSC}},
79+
)
80+
# A and compressed.sparsity have the same pattern
81+
(; colors, aggregates) = compressed
82+
PA = parent(A)
83+
PAnz, PArv = nonzeros(PA), rowvals(PA)
84+
PAnz .= zero(eltype(A))
85+
@views for i in axes(A, 1)
86+
k = colors[i]
87+
nzrange_i = nzrange(PA, i)
88+
cols_i = PArv[nzrange_i]
89+
copyto!(PAnz[nzrange_i], aggregates[k, cols_i])
7190
end
7291
return A
7392
end

DifferentiationInterface/src/sparse/hessian.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ end
1010
## Hessian, one argument
1111

1212
function prepare_hessian(f, backend::AutoSparse, x)
13-
sparsity = hessian_sparsity(f, x, sparsity_detector(backend))
14-
colors = symmetric_coloring(sparsity, coloring_algorithm(backend))
13+
initial_sparsity = hessian_sparsity(f, x, sparsity_detector(backend))
14+
sparsity = col_major(initial_sparsity)
15+
colors = column_coloring(sparsity, coloring_algorithm(backend))
1516
groups = get_groups(colors)
1617
seeds = map(groups) do group
1718
seed = zero(x)
@@ -33,7 +34,7 @@ function hessian!(f, hess, backend::AutoSparse, x, extras::SparseHessianExtras)
3334
hvp!(f, products[k], backend, x, seeds[k], hvp_extras)
3435
copyto!(view(compressed.aggregates, :, k), vec(products[k]))
3536
end
36-
decompress_symmetric!(hess, compressed)
37+
decompress!(hess, compressed)
3738
return hess
3839
end
3940

DifferentiationInterface/src/sparse/jacobian.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ end
2323

2424
function prepare_jacobian(f, backend::AutoSparse, x)
2525
y = f(x)
26-
sparsity = jacobian_sparsity(f, x, sparsity_detector(backend))
26+
initial_sparsity = jacobian_sparsity(f, x, sparsity_detector(backend))
2727
if Bool(pushforward_performance(backend))
28+
sparsity = col_major(initial_sparsity)
2829
colors = column_coloring(sparsity, coloring_algorithm(backend))
2930
groups = get_groups(colors)
3031
seeds = map(groups) do group
@@ -39,6 +40,7 @@ function prepare_jacobian(f, backend::AutoSparse, x)
3940
aggregates = stack(vec, products; dims=2)
4041
compressed = CompressedMatrix{:col}(sparsity, colors, groups, aggregates)
4142
else
43+
sparsity = row_major(initial_sparsity)
4244
colors = row_coloring(sparsity, coloring_algorithm(backend))
4345
groups = get_groups(colors)
4446
seeds = map(groups) do group
@@ -77,7 +79,7 @@ function jacobian!(f, jac, backend::AutoSparse, x, extras::SparseJacobianExtras{
7779
end
7880

7981
function jacobian(f, backend::AutoSparse, x, extras::SparseJacobianExtras{1})
80-
jac = similar(extras.compressed.sparsity, eltype(x))
82+
jac = major_respecting_similar(extras.compressed.sparsity, eltype(x))
8183
return jacobian!(f, jac, backend, x, extras)
8284
end
8385

@@ -94,8 +96,9 @@ end
9496
## Jacobian, two arguments
9597

9698
function prepare_jacobian(f!, y, backend::AutoSparse, x)
97-
sparsity = jacobian_sparsity(f!, y, x, sparsity_detector(backend))
99+
initial_sparsity = jacobian_sparsity(f!, y, x, sparsity_detector(backend))
98100
if Bool(pushforward_performance(backend))
101+
sparsity = col_major(initial_sparsity)
99102
colors = column_coloring(sparsity, coloring_algorithm(backend))
100103
groups = get_groups(colors)
101104
seeds = map(groups) do group
@@ -110,6 +113,7 @@ function prepare_jacobian(f!, y, backend::AutoSparse, x)
110113
aggregates = stack(vec, products; dims=2)
111114
compressed = CompressedMatrix{:col}(sparsity, colors, groups, aggregates)
112115
else
116+
sparsity = row_major(initial_sparsity)
113117
colors = row_coloring(sparsity, coloring_algorithm(backend))
114118
groups = get_groups(colors)
115119
seeds = map(groups) do group
@@ -148,7 +152,7 @@ function jacobian!(f!, y, jac, backend::AutoSparse, x, extras::SparseJacobianExt
148152
end
149153

150154
function jacobian(f!, y, backend::AutoSparse, x, extras::SparseJacobianExtras{2})
151-
jac = similar(extras.compressed.sparsity, eltype(x))
155+
jac = major_respecting_similar(extras.compressed.sparsity, eltype(x))
152156
return jacobian!(f!, y, jac, backend, x, extras)
153157
end
154158

DifferentiationInterface/src/sparse/matrices.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ Construct a row-major representation of the matrix `A`.
1616
row_major(A::M) where {M<:AbstractMatrix} = transpose(M(transpose(A)))
1717
row_major(A::Transpose{<:Any,M}) where {M<:AbstractMatrix} = A
1818

19+
## Similar
20+
21+
major_respecting_similar(A::AbstractMatrix, ::Type{T}) where {T} = similar(A, T)
22+
23+
function major_respecting_similar(A::Transpose, ::Type{T}) where {T}
24+
return transpose(similar(parent(A), T))
25+
end
26+
1927
## Generic nz
2028

2129
function nz_in_col(A_colmajor::AbstractMatrix, j::Integer)

DifferentiationInterface/test/sparsity.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ sparse_backends = [
55
AutoSparse(AutoFastDifferentiation()),
66
AutoSparse(AutoSymbolics()),
77
AutoSparse(AutoForwardDiff(); sparsity_detector, coloring_algorithm),
8-
AutoSparse(AutoZygote(); sparsity_detector, coloring_algorithm),
8+
AutoSparse(AutoEnzyme(Enzyme.Reverse); sparsity_detector, coloring_algorithm),
99
]
1010

1111
sparse_second_order_backends = [

DifferentiationInterfaceTest/src/tests/sparsity.jl

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@ function test_sparsity(
1414
_, jac1 = value_and_jacobian(f, ba, x, extras)
1515
jac2 = jacobian(f, ba, x, extras)
1616

17-
@testset "Sparse type" begin
18-
@test jac1 isa SparseMatrixCSC
19-
@test jac2 isa SparseMatrixCSC
20-
end
2117
@testset "Sparsity pattern" begin
2218
@test nnz(jac1) == nnz(jac_true)
2319
@test nnz(jac2) == nnz(jac_true)
@@ -37,10 +33,6 @@ function test_sparsity(ba::AbstractADType, scen::JacobianScenario{1,:inplace}; r
3733
_, jac1 = value_and_jacobian!(f, mysimilar(jac_true), ba, x, extras)
3834
jac2 = jacobian!(f, mysimilar(jac_true), ba, x, extras)
3935

40-
@testset "Sparse type" begin
41-
@test jac1 isa SparseMatrixCSC
42-
@test jac2 isa SparseMatrixCSC
43-
end
4436
@testset "Sparsity pattern" begin
4537
@test nnz(jac1) == nnz(jac_true)
4638
@test nnz(jac2) == nnz(jac_true)
@@ -63,10 +55,6 @@ function test_sparsity(
6355
_, jac1 = value_and_jacobian(f!, mysimilar(y), ba, x, extras)
6456
jac2 = jacobian(f!, mysimilar(y), ba, x, extras)
6557

66-
@testset "Sparse type" begin
67-
@test jac1 isa SparseMatrixCSC
68-
@test jac2 isa SparseMatrixCSC
69-
end
7058
@testset "Sparsity pattern" begin
7159
@test nnz(jac1) == nnz(jac_true)
7260
@test nnz(jac2) == nnz(jac_true)
@@ -87,10 +75,6 @@ function test_sparsity(ba::AbstractADType, scen::JacobianScenario{2,:inplace}; r
8775
_, jac1 = value_and_jacobian!(f!, mysimilar(y), mysimilar(jac_true), ba, x, extras)
8876
jac2 = jacobian!(f!, mysimilar(y), mysimilar(jac_true), ba, x, extras)
8977

90-
@testset "Sparse type" begin
91-
@test jac1 isa SparseMatrixCSC
92-
@test jac2 isa SparseMatrixCSC
93-
end
9478
@testset "Sparsity pattern" begin
9579
@test nnz(jac1) == nnz(jac_true)
9680
@test nnz(jac2) == nnz(jac_true)
@@ -113,9 +97,6 @@ function test_sparsity(
11397

11498
hess1 = hessian(f, ba, x, extras)
11599

116-
@testset "Sparse type" begin
117-
@test hess1 isa SparseMatrixCSC
118-
end
119100
@testset "Sparsity pattern" begin
120101
@test nnz(hess1) == nnz(hess_true)
121102
end
@@ -133,9 +114,6 @@ function test_sparsity(ba::AbstractADType, scen::HessianScenario{1,:inplace}; re
133114

134115
hess1 = hessian!(f, mysimilar(hess_true), ba, x, extras)
135116

136-
@testset "Sparse type" begin
137-
@test hess1 isa SparseMatrixCSC
138-
end
139117
@testset "Sparsity pattern" begin
140118
@test nnz(hess1) == nnz(hess_true)
141119
end

0 commit comments

Comments
 (0)