Skip to content

Commit d4a9d9e

Browse files
authored
Separate compressed evaluation and decompression for sparse (#212)
1 parent 7b672a0 commit d4a9d9e

19 files changed

Lines changed: 556 additions & 367 deletions

File tree

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,6 @@
88
**/Manifest.toml
99
**/docs/Manifest.toml
1010

11-
*.csv
11+
*.csv
12+
13+
playground.jl

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@ module DifferentiationInterfaceEnzymeExt
33
using ADTypes: ADTypes, AutoEnzyme
44
import DifferentiationInterface as DI
55
using DifferentiationInterface:
6+
DerivativeExtras,
7+
GradientExtras,
8+
JacobianExtras,
9+
PullbackExtras,
10+
PushforwardExtras,
611
NoDerivativeExtras,
712
NoGradientExtras,
813
NoJacobianExtras,

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ end
3737

3838
## Gradient
3939

40-
struct EnzymeForwardGradientExtras{C,O}
40+
struct EnzymeForwardGradientExtras{C,O} <: GradientExtras
4141
shadow::O
4242
end
4343

@@ -76,7 +76,7 @@ end
7676

7777
## Jacobian
7878

79-
struct EnzymeForwardOneArgJacobianExtras{C,O}
79+
struct EnzymeForwardOneArgJacobianExtras{C,O} <: JacobianExtras
8080
shadow::O
8181
end
8282

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,18 @@ module DifferentiationInterfaceFiniteDiffExt
33
using ADTypes: AutoFiniteDiff
44
import DifferentiationInterface as DI
55
using DifferentiationInterface:
6+
DerivativeExtras,
7+
GradientExtras,
8+
HessianExtras,
9+
JacobianExtras,
10+
PullbackExtras,
11+
PushforwardExtras,
612
NoDerivativeExtras,
713
NoGradientExtras,
814
NoHessianExtras,
915
NoJacobianExtras,
1016
NoPullbackExtras,
11-
NoPushforwardExtras,
12-
GradientExtras
17+
NoPushforwardExtras
1318
using FiniteDiff:
1419
DerivativeCache,
1520
GradientCache,

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ end
1919

2020
## Derivative
2121

22-
struct FiniteDiffOneArgDerivativeExtras{C}
22+
struct FiniteDiffOneArgDerivativeExtras{C} <: DerivativeExtras
2323
cache::C
2424
end
2525

@@ -128,7 +128,7 @@ end
128128

129129
## Jacobian
130130

131-
struct FiniteDiffOneArgJacobianExtras{C}
131+
struct FiniteDiffOneArgJacobianExtras{C} <: JacobianExtras
132132
cache::C
133133
end
134134

@@ -166,7 +166,7 @@ end
166166

167167
## Hessian
168168

169-
struct FiniteDiffHessianExtras{C}
169+
struct FiniteDiffHessianExtras{C} <: HessianExtras
170170
cache::C
171171
end
172172

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ end
1919

2020
## Derivative
2121

22-
struct FiniteDiffTwoArgDerivativeExtras{C}
22+
struct FiniteDiffTwoArgDerivativeExtras{C} <: DerivativeExtras
2323
cache::C
2424
end
2525

@@ -61,7 +61,7 @@ end
6161

6262
## Jacobian
6363

64-
struct FiniteDiffTwoArgJacobianExtras{C}
64+
struct FiniteDiffTwoArgJacobianExtras{C} <: JacobianExtras
6565
cache::C
6666
end
6767

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
1-
struct SymbolicsSparsityDetector <: ADTypes.AbstractSparsityDetector end
2-
3-
function ADTypes.jacobian_sparsity(f, x, ::SymbolicsSparsityDetector)
1+
function ADTypes.jacobian_sparsity(f, x, ::DI.SymbolicsSparsityDetector)
42
y = similar(f(x))
53
f!(y, x) = copyto!(y, f(x))
64
return jacobian_sparsity(f!, y, x)
75
end
86

9-
function ADTypes.jacobian_sparsity(f!, y, x, ::SymbolicsSparsityDetector)
7+
function ADTypes.jacobian_sparsity(f!, y, x, ::DI.SymbolicsSparsityDetector)
108
f!_vec(y_vec, x_vec) = f!(reshape(y_vec, size(y)), reshape(x_vec, size(x)))
119
return jacobian_sparsity(f!_vec, vec(y), vec(x))
1210
end
1311

14-
function ADTypes.hessian_sparsity(f, x, ::SymbolicsSparsityDetector)
12+
function ADTypes.hessian_sparsity(f, x, ::DI.SymbolicsSparsityDetector)
1513
f_vec(x_vec) = f(reshape(x_vec, size(x)))
1614
return hessian_sparsity(f_vec, vec(x))
1715
end

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,12 @@ include("second_order/second_derivative.jl")
5252
include("second_order/hvp.jl")
5353
include("second_order/hessian.jl")
5454

55+
include("sparse/detector.jl")
5556
include("sparse/coloring.jl")
56-
include("sparse/sparse.jl")
57+
include("sparse/compressed_matrix.jl")
58+
include("sparse/fallbacks.jl")
59+
include("sparse/jacobian.jl")
60+
include("sparse/hessian.jl")
5761

5862
export SecondOrder
5963

DifferentiationInterface/src/sparse/coloring.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ This graph is defined as `G = (R, C, E)` where `R = 1:m` is the set of row indic
2626
# Fields
2727
2828
- `A::AbstractMatrix`
29+
30+
# Reference
31+
32+
> [What Color Is Your Jacobian? Graph Coloring for Computing Derivatives](https://epubs.siam.org/doi/abs/10.1137/S0036144504444711), Gebremedhin et al. (2005)
2933
"""
3034
struct BipartiteGraph{M<:AbstractMatrix} <: AbstractMatrixGraph
3135
A::M
@@ -127,6 +131,10 @@ This graph is defined as `G = (C, E)` where `C = 1:n` is the set of columns and
127131
# Fields
128132
129133
- `A::AbstractMatrix`
134+
135+
# Reference
136+
137+
> [What Color Is Your Jacobian? Graph Coloring for Computing Derivatives](https://epubs.siam.org/doi/abs/10.1137/S0036144504444711), Gebremedhin et al. (2005)
130138
"""
131139
struct AdjacencyGraph{M<:AbstractMatrix} <: AbstractMatrixGraph
132140
A::M
@@ -196,6 +204,23 @@ end
196204

197205
## ADTypes overloads
198206

207+
"""
208+
GreedyColoringAlgorithm <: ADTypes.AbstractColoringAlgorithm
209+
210+
Matrix coloring algorithm for sparse Jacobians and Hessians.
211+
212+
Compatible with the [ADTypes.jl coloring framework](https://sciml.github.io/ADTypes.jl/stable/#Coloring-algorithm).
213+
214+
# See also
215+
216+
- `ADTypes.column_coloring`
217+
- `ADTypes.row_coloring`
218+
- `ADTypes.symmetric_coloring`
219+
220+
# Reference
221+
222+
> [What Color Is Your Jacobian? Graph Coloring for Computing Derivatives](https://epubs.siam.org/doi/abs/10.1137/S0036144504444711), Gebremedhin et al. (2005)
223+
"""
199224
struct GreedyColoringAlgorithm <: ADTypes.AbstractColoringAlgorithm end
200225

201226
function ADTypes.column_coloring(A, ::GreedyColoringAlgorithm)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""
2+
CompressedMatrix{dir}
3+
4+
Compressed representation `B` of a sparse matrix `A ∈ ℝ^{m×n}` obtained by summing some of its columns (if `dir == :col`) or rows (if `dir == :row`), grouped by color.
5+
6+
# Fields
7+
8+
- `sparsity::AbstractMatrix{Bool}`: boolean sparsity pattern of the matrix `A`
9+
- `colors::Vector{Int}`: vector such that
10+
- if `dir == `:col`, then `colors[j] ∈ 1:c` is the color of column `j`
11+
- if `dir == `:row`, then `colors[i] ∈ 1:c` is the color of row `i`
12+
- `groups::Vector{Vector{Int}}`: vector of length `c` such that
13+
- if `dir == :col`, then `groups[k]` is the vector of column indices assigned to the same color `k ∈ 1:c`
14+
- if `dir == :row`, then `groups[k]` is the vector of row indices assigned to the same color `k ∈ 1:c`
15+
- `aggregates::AbstractMatrix`: matrix `B` such that
16+
- if `dir == :col`, then `size(B) = (m, c)` and `B[:, c] = sum(A[:, k] for k in groups[c])`
17+
- if `dir == :row`, then `size(B) = (c, n)` and `B[c, :] = sum(A[k, :] for k in groups[c])`
18+
"""
19+
mutable struct CompressedMatrix{dir,S<:AbstractMatrix{Bool},M<:AbstractMatrix}
20+
sparsity::S
21+
colors::Vector{Int}
22+
groups::Vector{Vector{Int}}
23+
aggregates::M
24+
end
25+
26+
"""
27+
CompressedMatrix{dir}(sparsity, groups, aggregates)
28+
29+
Constructor for [`CompressedMatrix`](@ref).
30+
"""
31+
function CompressedMatrix{dir}(sparsity, colors, groups, aggregates) where {dir}
32+
@assert dir in (:col, :row)
33+
return CompressedMatrix{dir,typeof(sparsity),typeof(aggregates)}(
34+
sparsity, colors, groups, aggregates
35+
)
36+
end
37+
38+
function decompress!(A::AbstractMatrix, compressed::CompressedMatrix{:col})
39+
(; sparsity, colors, aggregates) = compressed
40+
A .= zero(eltype(A))
41+
@views for j in axes(A, 2)
42+
k = colors[j]
43+
nz_rows_j = (!iszero).(sparsity[:, j])
44+
copyto!(A[nz_rows_j, j], aggregates[nz_rows_j, k])
45+
end
46+
return A
47+
end
48+
49+
function decompress!(A::AbstractMatrix, compressed::CompressedMatrix{:row})
50+
(; sparsity, colors, aggregates) = compressed
51+
A .= zero(eltype(A))
52+
@views for i in axes(A, 1)
53+
k = colors[i]
54+
nz_cols_i = (!iszero).(sparsity[i, :])
55+
copyto!(A[i, nz_cols_i], aggregates[k, nz_cols_i])
56+
end
57+
return A
58+
end
59+
60+
function decompress_symmetric!(A::AbstractMatrix, compressed::CompressedMatrix{:col})
61+
(; sparsity, colors, groups, aggregates) = compressed
62+
@views for j in axes(A, 2)
63+
k = colors[j]
64+
group = groups[k]
65+
for i in axes(A, 1)
66+
if (!iszero(sparsity[i, j]) && count(!iszero, sparsity[i, group]) == 1)
67+
A[i, j] = aggregates[i, k]
68+
A[j, i] = aggregates[i, k]
69+
end
70+
end
71+
end
72+
return A
73+
end

0 commit comments

Comments
 (0)