Skip to content

Commit 09a7da0

Browse files
authored
More efficient sparsity handling (#222)
1 parent e788f03 commit 09a7da0

9 files changed

Lines changed: 129 additions & 27 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
99
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1112

1213
[weakdeps]
1314
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -54,6 +55,7 @@ ForwardDiff = "0.10.36"
5455
LinearAlgebra = "1"
5556
PolyesterForwardDiff = "0.1.1"
5657
ReverseDiff = "1.15.1"
58+
SparseArrays = "1"
5759
Symbolics = "5.27.1"
5860
Tapir = "0.1.2"
5961
Test = "1"
@@ -84,4 +86,25 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
8486
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8587

8688
[targets]
87-
test = ["ADTypes", "Aqua", "Diffractor", "Documenter", "Enzyme", "FastDifferentiation", "FiniteDiff", "FiniteDifferences", "ForwardDiff", "JET", "JuliaFormatter", "Pkg", "PolyesterForwardDiff", "ReverseDiff", "SparseArrays", "Symbolics", "Tapir", "Test", "Tracker", "Zygote"]
89+
test = [
90+
"ADTypes",
91+
"Aqua",
92+
"Diffractor",
93+
"Documenter",
94+
"Enzyme",
95+
"FastDifferentiation",
96+
"FiniteDiff",
97+
"FiniteDifferences",
98+
"ForwardDiff",
99+
"JET",
100+
"JuliaFormatter",
101+
"Pkg",
102+
"PolyesterForwardDiff",
103+
"ReverseDiff",
104+
"SparseArrays",
105+
"Symbolics",
106+
"Tapir",
107+
"Test",
108+
"Tracker",
109+
"Zygote",
110+
]

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ using ADTypes:
3030
AutoZygote
3131
using DocStringExtensions
3232
using FillArrays: OneElement
33-
using LinearAlgebra: Symmetric, dot
33+
using LinearAlgebra: Symmetric, Transpose, dot, parent, transpose
34+
using SparseArrays: SparseMatrixCSC, nzrange, rowvals, sparse
3435

3536
abstract type Extras end
3637

@@ -54,6 +55,7 @@ include("second_order/hvp.jl")
5455
include("second_order/hessian.jl")
5556

5657
include("sparse/detector.jl")
58+
include("sparse/matrices.jl")
5759
include("sparse/coloring.jl")
5860
include("sparse/compressed_matrix.jl")
5961
include("sparse/fallbacks.jl")

DifferentiationInterface/src/sparse/coloring.jl

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@ end
1010

1111
abstract type AbstractMatrixGraph end
1212

13-
Base.size(g::AbstractMatrixGraph, args...) = size(g.A, args...)
14-
rows(g::AbstractMatrixGraph) = axes(g.A, 1)
15-
columns(g::AbstractMatrixGraph) = axes(g.A, 2)
13+
rows(g::AbstractMatrixGraph) = axes(g.A_colmajor, 1)
14+
columns(g::AbstractMatrixGraph) = axes(g.A_colmajor, 2)
1615

1716
## Jacobian coloring
1817

@@ -25,22 +24,33 @@ This graph is defined as `G = (R, C, E)` where `R = 1:m` is the set of row indic
2524
2625
# Fields
2726
28-
- `A::AbstractMatrix`
27+
- `A_colmajor::AbstractMatrix`: output of [`col_major`](@ref) applied to `A`
28+
- `A_rowmajor::AbstractMatrix`: output of [`row_major`](@ref) applied to `A`
2929
3030
# Reference
3131
3232
> [What Color Is Your Jacobian? Graph Coloring for Computing Derivatives](https://epubs.siam.org/doi/abs/10.1137/S0036144504444711), Gebremedhin et al. (2005)
3333
"""
34-
struct BipartiteGraph{M<:AbstractMatrix} <: AbstractMatrixGraph
35-
A::M
34+
struct BipartiteGraph{M1<:AbstractMatrix,M2<:AbstractMatrix} <: AbstractMatrixGraph
35+
A_colmajor::M1
36+
A_rowmajor::M2
37+
38+
function BipartiteGraph(A::AbstractMatrix)
39+
A_colmajor = col_major(A)
40+
A_rowmajor = row_major(A)
41+
return new{typeof(A_colmajor),typeof(A_rowmajor)}(A_colmajor, A_rowmajor)
42+
end
3643
end
3744

38-
function neighbors_of_row(g::BipartiteGraph, i::Integer)
39-
return filter(j -> !iszero(g.A[i, j]), columns(g))
40-
end
45+
neighbors_of_column(g::BipartiteGraph, j::Integer) = nz_in_col(g.A_colmajor, j)
46+
neighbors_of_row(g::BipartiteGraph, i::Integer) = nz_in_row(g.A_rowmajor, i)
4147

42-
function neighbors_of_column(g::BipartiteGraph, j::Integer)
43-
return filter(i -> !iszero(g.A[i, j]), rows(g))
48+
function colored_neighbors_of_column(
49+
g::BipartiteGraph, j::Integer, colors::AbstractVector{<:Integer}
50+
)
51+
return filter(neighbors_of_column(g, j)) do i
52+
!iszero(colors[i])
53+
end
4454
end
4555

4656
function colored_neighbors_of_row(
@@ -51,14 +61,6 @@ function colored_neighbors_of_row(
5161
end
5262
end
5363

54-
function colored_neighbors_of_column(
55-
g::BipartiteGraph, j::Integer, colors::AbstractVector{<:Integer}
56-
)
57-
return filter(neighbors_of_column(g, j)) do i
58-
!iszero(colors[i])
59-
end
60-
end
61-
6264
function distance2_column_coloring(g::BipartiteGraph)
6365
n = length(columns(g))
6466
colors = zeros(Int, n)
@@ -130,18 +132,23 @@ This graph is defined as `G = (C, E)` where `C = 1:n` is the set of columns and
130132
131133
# Fields
132134
133-
- `A::AbstractMatrix`
135+
- `A_colmajor::AbstractMatrix`: output of [`col_major`](@ref) applied to `A`
134136
135137
# Reference
136138
137139
> [What Color Is Your Jacobian? Graph Coloring for Computing Derivatives](https://epubs.siam.org/doi/abs/10.1137/S0036144504444711), Gebremedhin et al. (2005)
138140
"""
139141
struct AdjacencyGraph{M<:AbstractMatrix} <: AbstractMatrixGraph
140-
A::M
142+
A_colmajor::M
143+
144+
function AdjacencyGraph(A::AbstractMatrix)
145+
A_colmajor = col_major(A)
146+
return new{typeof(A_colmajor)}(A_colmajor)
147+
end
141148
end
142149

143150
function neighbors(g::AdjacencyGraph, j::Integer)
144-
return filter(i -> (i != j) && !iszero(g.A[i, j]), columns(g))
151+
return filter(!isequal(j), nz_in_col(g.A_colmajor, j))
145152
end
146153

147154
function colored_neighbors(g::AdjacencyGraph, j::Integer, colors::AbstractVector{<:Integer})
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
## Conversion between row- and col-major
2+
3+
"""
4+
col_major(A::AbstractMatrix)
5+
6+
Construct a column-major representation of the matrix `A`.
7+
"""
8+
col_major(A::M) where {M<:AbstractMatrix} = A
9+
col_major(A::Transpose{<:Any,M}) where {M<:AbstractMatrix} = M(A)
10+
11+
"""
12+
row_major(A::AbstractMatrix)
13+
14+
Construct a row-major representation of the matrix `A`.
15+
"""
16+
row_major(A::M) where {M<:AbstractMatrix} = transpose(M(transpose(A)))
17+
row_major(A::Transpose{<:Any,M}) where {M<:AbstractMatrix} = A
18+
19+
## Generic nz
20+
21+
function nz_in_col(A_colmajor::AbstractMatrix, j::Integer)
22+
return filter(i -> !iszero(A_colmajor[i, j]), axes(A_colmajor, 1))
23+
end
24+
25+
function nz_in_row(A_rowmajor::AbstractMatrix, i::Integer)
26+
return filter(j -> !iszero(A_rowmajor[i, j]), axes(A_rowmajor, 2))
27+
end
28+
29+
## Sparse nz
30+
31+
function nz_in_col(A_colmajor::SparseMatrixCSC{T}, j::Integer) where {T}
32+
rv = rowvals(A_colmajor)
33+
ind = nzrange(A_colmajor, j)
34+
return view(rv, ind)
35+
end
36+
37+
function nz_in_row(A_rowmajor::Transpose{T,<:SparseMatrixCSC{T}}, i::Integer) where {T}
38+
A_transpose_colmajor = parent(A_rowmajor)
39+
rv = rowvals(A_transpose_colmajor)
40+
ind = nzrange(A_transpose_colmajor, i)
41+
return view(rv, ind)
42+
end
File renamed without changes.
File renamed without changes.
File renamed without changes.
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
@testset "$(typeof(A))" for A in (
2+
rand(0:1, 10, 10),
3+
transpose(rand(0:1, 10, 10)),
4+
sprand(100, 100, 0.1),
5+
transpose(sprand(100, 100, 0.1)),
6+
)
7+
A_colmajor = DI.col_major(A)
8+
A_rowmajor = DI.row_major(A)
9+
10+
@test A_colmajor == A
11+
@test A_rowmajor == A
12+
end
13+
14+
@testset "$(typeof(A))" for A in (rand(0:1, 10, 10), sprand(100, 100, 0.1))
15+
A_colmajor = DI.col_major(A)
16+
A_rowmajor = DI.row_major(A)
17+
18+
for i in axes(A, 1)
19+
@test DI.nz_in_row(A_rowmajor, i) == DI.nz_in_row(Matrix(A_rowmajor), i)
20+
end
21+
for j in axes(A, 2)
22+
@test DI.nz_in_col(A_colmajor, j) == DI.nz_in_col(Matrix(A_colmajor), j)
23+
end
24+
end

DifferentiationInterface/test/runtests.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,19 @@ include("test_imports.jl")
5151

5252
@testset verbose = true "Internals" begin
5353
@testset verbose = true "Exception handling" begin
54-
include("test_exceptions.jl")
54+
include("internals/exceptions.jl")
5555
end
5656

5757
@testset "Chunks" begin
58-
include("chunk.jl")
58+
include("internals/chunk.jl")
59+
end
60+
61+
@testset "Matrices" begin
62+
include("internals/matrices.jl")
5963
end
6064

6165
@testset verbose = true "Coloring" begin
62-
include("coloring.jl")
66+
include("internals/coloring.jl")
6367
end
6468
end
6569
end;

0 commit comments

Comments
 (0)