Skip to content

Commit 7443b02

Browse files
authored
Make type-stable constructors part of the API (#80)
1 parent 7105681 commit 7443b02

9 files changed

Lines changed: 129 additions & 65 deletions

File tree

docs/src/api.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ CollapsedDocStrings = true
55
CurrentModule = SparseMatrixColorings
66
```
77

8+
The docstrings on this page define the public API of the package.
9+
810
```@docs
911
SparseMatrixColorings
1012
```
1113

12-
The docstrings on this page define the public API of the package.
13-
1414
## Main function
1515

1616
```@docs
@@ -39,8 +39,6 @@ decompress!
3939

4040
## Orders
4141

42-
These symbols are not exported but they are still part of the public API.
43-
4442
```@docs
4543
AbstractOrder
4644
NaturalOrder

docs/src/dev.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ SparseMatrixColorings.LinearSystemColoringResult
4343
## Testing
4444

4545
```@docs
46-
SparseMatrixColorings.same_sparsity_pattern
4746
SparseMatrixColorings.directly_recoverable_columns
4847
SparseMatrixColorings.symmetrically_orthogonal_columns
4948
SparseMatrixColorings.structurally_orthogonal_columns
@@ -54,6 +53,7 @@ SparseMatrixColorings.structurally_orthogonal_columns
5453
```@docs
5554
SparseMatrixColorings.respectful_similar
5655
SparseMatrixColorings.matrix_versions
56+
SparseMatrixColorings.same_pattern
5757
```
5858

5959
## Examples

src/SparseMatrixColorings.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,11 @@ include("decompression.jl")
4949
include("check.jl")
5050
include("examples.jl")
5151

52+
export NaturalOrder, RandomOrder, LargestFirst
5253
export ColoringProblem, GreedyColoringAlgorithm, AbstractColoringResult
5354
export coloring
5455
export column_colors, row_colors
5556
export column_groups, row_groups
5657
export compress, decompress, decompress!
5758

58-
@compat public NaturalOrder, RandomOrder, LargestFirst
59-
6059
end

src/decompression.jl

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -181,27 +181,16 @@ true
181181
- [`ColoringProblem`](@ref)
182182
- [`AbstractColoringResult`](@ref)
183183
"""
184-
function decompress!(
185-
A::AbstractMatrix{R},
186-
B::AbstractMatrix{R},
187-
result::AbstractColoringResult{structure,partition,decompression},
188-
) where {R<:Real,structure,partition,decompression}
189-
# common checks
190-
S = get_matrix(result)
191-
structure == :symmetric && checksquare(A)
192-
if !same_sparsity_pattern(A, S)
193-
throw(DimensionMismatch("`A` and `S` must have the same sparsity pattern."))
194-
end
195-
return decompress_aux!(A, B, result)
196-
end
184+
function decompress! end
197185

198186
## NonSymmetricColoringResult
199187

200-
function decompress_aux!(
188+
function decompress!(
201189
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::NonSymmetricColoringResult{:column}
202190
) where {R<:Real}
203-
A .= zero(R)
204191
S = get_matrix(result)
192+
check_same_pattern(A, S)
193+
A .= zero(R)
205194
color = column_colors(result)
206195
rvS = rowvals(S)
207196
for j in axes(S, 2)
@@ -214,11 +203,12 @@ function decompress_aux!(
214203
return A
215204
end
216205

217-
function decompress_aux!(
206+
function decompress!(
218207
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::NonSymmetricColoringResult{:row}
219208
) where {R<:Real}
220-
A .= zero(R)
221209
S = get_matrix(result)
210+
check_same_pattern(A, S)
211+
A .= zero(R)
222212
color = row_colors(result)
223213
rvS = rowvals(S)
224214
for j in axes(S, 2)
@@ -231,9 +221,11 @@ function decompress_aux!(
231221
return A
232222
end
233223

234-
function decompress_aux!(
224+
function decompress!(
235225
A::SparseMatrixCSC{R}, B::AbstractMatrix{R}, result::NonSymmetricColoringResult{:column}
236226
) where {R<:Real}
227+
S = get_matrix(result)
228+
check_same_pattern(A, S)
237229
nzA = nonzeros(A)
238230
ind = result.compressed_indices
239231
for i in eachindex(nzA, ind)
@@ -242,9 +234,11 @@ function decompress_aux!(
242234
return A
243235
end
244236

245-
function decompress_aux!(
237+
function decompress!(
246238
A::SparseMatrixCSC{R}, B::AbstractMatrix{R}, result::NonSymmetricColoringResult{:row}
247239
) where {R<:Real}
240+
S = get_matrix(result)
241+
check_same_pattern(A, S)
248242
nzA = nonzeros(A)
249243
ind = result.compressed_indices
250244
for i in eachindex(nzA, ind)
@@ -255,11 +249,12 @@ end
255249

256250
## StarSetColoringResult
257251

258-
function decompress_aux!(
252+
function decompress!(
259253
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::StarSetColoringResult
260254
) where {R<:Real}
261-
A .= zero(R)
262255
S = get_matrix(result)
256+
check_same_pattern(A, S)
257+
A .= zero(R)
263258
color = column_colors(result)
264259
rvS = rowvals(S)
265260
for j in axes(S, 2)
@@ -272,9 +267,11 @@ function decompress_aux!(
272267
return A
273268
end
274269

275-
function decompress_aux!(
270+
function decompress!(
276271
A::SparseMatrixCSC{R}, B::AbstractMatrix{R}, result::StarSetColoringResult
277272
) where {R<:Real}
273+
S = get_matrix(result)
274+
check_same_pattern(A, S)
278275
nzA = nonzeros(A)
279276
ind = result.compressed_indices
280277
for i in eachindex(nzA, ind)
@@ -287,11 +284,12 @@ end
287284

288285
# TODO: add method for A::SparseMatrixCSC
289286

290-
function decompress_aux!(
287+
function decompress!(
291288
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::TreeSetColoringResult
292289
) where {R<:Real}
293-
A .= zero(R)
294290
S = get_matrix(result)
291+
check_same_pattern(A, S)
292+
A .= zero(R)
295293
color = column_colors(result)
296294
@compat (; vertices_by_tree, reverse_bfs_orders, buffer) = result
297295

@@ -326,10 +324,11 @@ end
326324

327325
## MatrixInverseColoringResult
328326

329-
function decompress_aux!(
327+
function decompress!(
330328
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::LinearSystemColoringResult
331329
) where {R<:Real}
332330
S = get_matrix(result)
331+
check_same_pattern(A, S)
333332
color = column_colors(result)
334333
@compat (; strict_upper_nonzero_inds, T_factorization, strict_upper_nonzeros_A) = result
335334

src/interface.jl

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,62 @@
1+
function check_valid_problem(structure::Symbol, partition::Symbol)
2+
valid = (
3+
(structure == :nonsymmetric && partition in (:column, :row)) ||
4+
(structure == :symmetric && partition == :column)
5+
)
6+
if !valid
7+
throw(
8+
ArgumentError(
9+
"The combination `($(repr(structure)), $(repr(partition)))` is not supported by `ColoringProblem`.",
10+
),
11+
)
12+
end
13+
end
14+
15+
function check_valid_algorithm(decompression::Symbol)
16+
valid = decompression in (:direct, :substitution)
17+
if !valid
18+
throw(
19+
ArgumentError(
20+
"The setting `decompression=$(repr(decompression))` is not supported by `GreedyColoringAlgorithm`.",
21+
),
22+
)
23+
end
24+
end
25+
126
"""
227
ColoringProblem{structure,partition}
328
429
Selector type for the coloring problem to solve, enabling multiple dispatch.
530
631
It is passed as an argument to the main function [`coloring`](@ref).
732
8-
# Constructor
33+
# Constructors
934
10-
ColoringProblem(; structure::Symbol=:nonsymmetric, partition::Symbol=:column)
35+
ColoringProblem{structure,partition}()
36+
ColoringProblem(; structure=:nonsymmetric, partition=:column)
1137
1238
- `structure::Symbol`: either `:nonsymmetric` or `:symmetric`
1339
- `partition::Symbol`: either `:column`, `:row` or `:bidirectional`
1440
41+
!!! warning
42+
The second constructor (based on keyword arguments) is type-unstable.
43+
1544
# Link to automatic differentiation
1645
1746
Matrix coloring is often used in automatic differentiation, and here is the translation guide:
1847
19-
| matrix | mode | `structure` | `partition` |
20-
| -------- | -------------------- | --------------- | -----------------|
21-
| Jacobian | forward | `:nonsymmetric` | `:column` |
22-
| Jacobian | reverse | `:nonsymmetric` | `:row` |
23-
| Jacobian | forward + reverse | `:nonsymmetric` | `:bidirectional` |
24-
| Hessian | any | `:symmetric` | `:column` |
25-
26-
!!! warning
27-
With a `:symmetric` structure, you have to use a `:column` partition.
28-
29-
!!! warning
30-
At the moment, `:bidirectional` partitions are not implemented.
48+
| matrix | mode | `structure` | `partition` | implemented |
49+
| -------- | ------- | --------------- | ---------------- | ----------- |
50+
| Jacobian | forward | `:nonsymmetric` | `:column` | yes |
51+
| Jacobian | reverse | `:nonsymmetric` | `:row` | yes |
52+
| Jacobian | mixed | `:nonsymmetric` | `:bidirectional` | no |
53+
| Hessian | - | `:symmetric` | `:column` | yes |
54+
| Hessian | - | `:symmetric` | `:row` | no |
3155
"""
3256
struct ColoringProblem{structure,partition} end
3357

3458
function ColoringProblem(; structure::Symbol=:nonsymmetric, partition::Symbol=:column)
35-
@assert structure in (:nonsymmetric, :symmetric)
36-
@assert partition in (:column, :row, :bidirectional)
59+
check_valid_problem(structure, partition)
3760
return ColoringProblem{structure,partition}()
3861
end
3962

@@ -44,16 +67,17 @@ Greedy coloring algorithm for sparse matrices which colors columns or rows one a
4467
4568
It is passed as an argument to the main function [`coloring`](@ref).
4669
47-
# Constructor
70+
# Constructors
4871
49-
GreedyColoringAlgorithm(
50-
order::AbstractOrder=NaturalOrder();
51-
decompression::Symbol=:direct
52-
)
72+
GreedyColoringAlgorithm{decompression}(order=NaturalOrder())
73+
GreedyColoringAlgorithm(order=NaturalOrder(); decompression=:direct)
5374
5475
- `order::AbstractOrder`: the order in which the columns or rows are colored, which can impact the number of colors.
5576
- `decompression::Symbol`: either `:direct` or `:substitution`. Usually `:substitution` leads to fewer colors, at the cost of a more expensive coloring (and decompression). When `:substitution` is not applicable, it falls back on `:direct` decompression.
5677
78+
!!! warning
79+
The second constructor (based on keyword arguments) is type-unstable.
80+
5781
# ADTypes coloring interface
5882
5983
`GreedyColoringAlgorithm` is a subtype of [`ADTypes.AbstractColoringAlgorithm`](@extref ADTypes.AbstractColoringAlgorithm), which means the following methods are also applicable:
@@ -74,10 +98,17 @@ struct GreedyColoringAlgorithm{decompression,O<:AbstractOrder} <:
7498
order::O
7599
end
76100

101+
function GreedyColoringAlgorithm{decompression}(
102+
order::AbstractOrder=NaturalOrder()
103+
) where {decompression}
104+
check_valid_algorithm(decompression)
105+
return GreedyColoringAlgorithm{decompression,typeof(order)}(order)
106+
end
107+
77108
function GreedyColoringAlgorithm(
78109
order::AbstractOrder=NaturalOrder(); decompression::Symbol=:direct
79110
)
80-
@assert decompression in (:direct, :substitution)
111+
check_valid_algorithm(decompression)
81112
return GreedyColoringAlgorithm{decompression,typeof(order)}(order)
82113
end
83114

@@ -106,9 +137,9 @@ julia> S = sparse([
106137
0 1 1 0 0 0
107138
]);
108139
109-
julia> problem = ColoringProblem(structure=:nonsymmetric, partition=:column);
140+
julia> problem = ColoringProblem(; structure=:nonsymmetric, partition=:column);
110141
111-
julia> algo = GreedyColoringAlgorithm();
142+
julia> algo = GreedyColoringAlgorithm(; decompression=:direct);
112143
113144
julia> result = coloring(S, problem, algo);
114145

src/matrices.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,24 +46,23 @@ function respectful_similar(A::Adjoint, ::Type{T}) where {T}
4646
end
4747

4848
"""
49-
same_sparsity_pattern(A::AbstractMatrix, B::AbstractMatrix)
49+
same_pattern(A::AbstractMatrix, B::AbstractMatrix)
5050
5151
Perform a partial equality check on the sparsity patterns of `A` and `B`:
5252
5353
- if the return is `true`, they might have the same sparsity pattern but we're not sure
5454
- if the return is `false`, they definitely don't have the same sparsity pattern
5555
"""
56-
function same_sparsity_pattern(A::AbstractMatrix, B::AbstractMatrix)
56+
function same_pattern(A::AbstractMatrix, B::AbstractMatrix)
5757
return size(A) == size(B)
5858
end
5959

60-
function same_sparsity_pattern(A::SparseMatrixCSC, B::SparseMatrixCSC)
60+
function same_pattern(A::SparseMatrixCSC, B::SparseMatrixCSC)
6161
return size(A) == size(B) && nnz(A) == nnz(B)
6262
end
6363

64-
function same_sparsity_pattern(
65-
A::TransposeOrAdjoint{<:Any,<:SparseMatrixCSC},
66-
B::TransposeOrAdjoint{<:Any,<:SparseMatrixCSC},
67-
)
68-
return same_sparsity_pattern(parent(A), parent(B))
64+
function check_same_pattern(A::AbstractMatrix, S::AbstractMatrix)
65+
if !same_pattern(A, S)
66+
throw(DimensionMismatch("`A` and `S` must have the same sparsity pattern."))
67+
end
6968
end

test/constructors.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using SparseMatrixColorings
2+
using Test
3+
4+
@test ColoringProblem{:nonsymmetric,:column}() == ColoringProblem()
5+
@test ColoringProblem{:symmetric,:column}() ==
6+
ColoringProblem(; structure=:symmetric, partition=:column)
7+
8+
@test_throws ArgumentError ColoringProblem(; structure=:weird, partition=:column)
9+
@test_throws ArgumentError ColoringProblem(; structure=:symmetric, partition=:row)
10+
11+
@test GreedyColoringAlgorithm{:direct}() == GreedyColoringAlgorithm()
12+
@test GreedyColoringAlgorithm{:substitution}() ==
13+
GreedyColoringAlgorithm(; decompression=:substitution)
14+
15+
@test_throws ArgumentError GreedyColoringAlgorithm(decompression=:weird)

test/matrices.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
using LinearAlgebra
12
using SparseArrays
2-
using SparseMatrixColorings: matrix_versions, respectful_similar, same_sparsity_pattern
3+
using SparseMatrixColorings:
4+
check_same_pattern, matrix_versions, respectful_similar, same_pattern
35
using StableRNGs
46
using Test
57

@@ -30,3 +32,21 @@ same_view(::Adjoint, ::Adjoint) = true
3032
size(B) == size(A) && same_view(A, B)
3133
end
3234
end
35+
36+
@testset "Sparsity pattern" begin
37+
S = sparse([
38+
0 1 1
39+
0 1 0
40+
1 1 0
41+
])
42+
43+
A1 = copy(S)
44+
A2 = copy(S)
45+
A2[1, 1] = 1
46+
47+
@test same_pattern(A1, S)
48+
@test !same_pattern(A2, S)
49+
@test same_pattern(Matrix(A2), S)
50+
51+
@test_throws DimensionMismatch check_same_pattern(A2, S)
52+
end

0 commit comments

Comments
 (0)