Skip to content

Commit 1069266

Browse files
authored
[BREAKING] Move sparse functionality into package extensions (#448)
* Move extras in core code * Update backend extensions * Update docs * Typos * Typos * Fixes * Typos * Fix * Fix ForwardDiff * Fixes * Fixes * Fix Enzyme * Bump versions and compats * Move sparse functionality to extensions * Remove prefixes and add docs
1 parent de23245 commit 1069266

13 files changed

Lines changed: 226 additions & 171 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1010
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
13-
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
14-
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
1513

1614
[weakdeps]
1715
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -23,6 +21,8 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2321
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2422
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
2523
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
24+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
25+
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
2626
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
2727
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
2828
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -38,6 +38,8 @@ DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
3838
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
3939
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
4040
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
41+
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
42+
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
4143
DifferentiationInterfaceSymbolicsExt = "Symbolics"
4244
DifferentiationInterfaceTapirExt = "Tapir"
4345
DifferentiationInterfaceTrackerExt = "Tracker"

DifferentiationInterface/docs/src/operators.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,16 +168,17 @@ For this to work, three ingredients are needed (read [this survey](https://epubs
168168
- [`DenseSparsityDetector`](@ref) from DifferentiationInterface.jl (beware that this detector only gives a locally valid pattern)
169169
3. A coloring algorithm: [`GreedyColoringAlgorithm`](@extref SparseMatrixColorings.GreedyColoringAlgorithm) from [SparseMatrixColorings.jl](https://github.com/gdalle/SparseMatrixColorings.jl) is the only one we support.
170170

171+
!!! warning
172+
Generic sparse AD is now located in a package extension which depends on SparseMatrixColorings.jl.
173+
171174
These ingredients can be combined within the [`AutoSparse`](@extref ADTypes.AutoSparse) wrapper, which DifferentiationInterface.jl re-exports.
172175
Note that for sparse Hessians, you need to put the `SecondOrder` backend inside `AutoSparse`, and not the other way around.
176+
`AutoSparse` backends only support operators [`jacobian`](@ref) and [`hessian`](@ref) (as well as their variants).
173177

174178
The preparation step of `jacobian` or `hessian` with an `AutoSparse` backend can be long, because it needs to detect the sparsity pattern and color the resulting sparse matrix.
175179
But after preparation, the more zeros are present in the matrix, the greater the speedup will be compared to dense differentiation.
176180

177181
!!! danger
178-
`AutoSparse` backends only support operators [`jacobian`](@ref) and [`hessian`](@ref) (as well as their variants).
179-
180-
!!! warning
181182
The result of preparation for an `AutoSparse` backend cannot be reused if the sparsity pattern changes.
182183

183184
!!! info
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module DifferentiationInterfaceSparseArraysExt
2+
3+
using ADTypes: ADTypes
4+
using Compat
5+
using DifferentiationInterface
6+
using DifferentiationInterface:
7+
DenseSparsityDetector, PushforwardFast, PushforwardSlow, basis, pushforward_performance
8+
import DifferentiationInterface as DI
9+
using SparseArrays: SparseMatrixCSC, nonzeros, nzrange, rowvals, sparse
10+
11+
include("sparsity_detector.jl")
12+
13+
end
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
## Direct
2+
3+
function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:direct})
4+
@compat (; backend, atol) = detector
5+
J = jacobian(f, backend, x)
6+
return sparse(abs.(J) .> atol)
7+
end
8+
9+
function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:direct})
10+
@compat (; backend, atol) = detector
11+
J = jacobian(f!, y, backend, x)
12+
return sparse(abs.(J) .> atol)
13+
end
14+
15+
function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:direct})
16+
@compat (; backend, atol) = detector
17+
H = hessian(f, backend, x)
18+
return sparse(abs.(H) .> atol)
19+
end
20+
21+
## Iterative
22+
23+
function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:iterative})
24+
@compat (; backend, atol) = detector
25+
y = f(x)
26+
n, m = length(x), length(y)
27+
I, J = Int[], Int[]
28+
if pushforward_performance(backend) isa PushforwardFast
29+
p = similar(y)
30+
extras = prepare_pushforward_same_point(
31+
f, backend, x, basis(backend, x, first(CartesianIndices(x)))
32+
)
33+
for (kj, j) in enumerate(CartesianIndices(x))
34+
pushforward!(f, p, extras, backend, x, basis(backend, x, j))
35+
for ki in LinearIndices(p)
36+
if abs(p[ki]) > atol
37+
push!(I, ki)
38+
push!(J, kj)
39+
end
40+
end
41+
end
42+
else
43+
p = similar(x)
44+
extras = prepare_pullback_same_point(
45+
f, backend, x, basis(backend, y, first(CartesianIndices(y)))
46+
)
47+
for (ki, i) in enumerate(CartesianIndices(y))
48+
pullback!(f, p, extras, backend, x, basis(backend, y, i))
49+
for kj in LinearIndices(p)
50+
if abs(p[kj]) > atol
51+
push!(I, ki)
52+
push!(J, kj)
53+
end
54+
end
55+
end
56+
end
57+
return sparse(I, J, ones(Bool, length(I)), m, n)
58+
end
59+
60+
function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:iterative})
61+
@compat (; backend, atol) = detector
62+
n, m = length(x), length(y)
63+
I, J = Int[], Int[]
64+
if pushforward_performance(backend) isa PushforwardFast
65+
p = similar(y)
66+
extras = prepare_pushforward_same_point(
67+
f!, y, backend, x, basis(backend, x, first(CartesianIndices(x)))
68+
)
69+
for (kj, j) in enumerate(CartesianIndices(x))
70+
pushforward!(f!, y, p, extras, backend, x, basis(backend, x, j))
71+
for ki in LinearIndices(p)
72+
if abs(p[ki]) > atol
73+
push!(I, ki)
74+
push!(J, kj)
75+
end
76+
end
77+
end
78+
else
79+
p = similar(x)
80+
extras = prepare_pullback_same_point(
81+
f!, y, backend, x, basis(backend, y, first(CartesianIndices(y)))
82+
)
83+
for (ki, i) in enumerate(CartesianIndices(y))
84+
pullback!(f!, y, p, extras, backend, x, basis(backend, y, i))
85+
for kj in LinearIndices(p)
86+
if abs(p[kj]) > atol
87+
push!(I, ki)
88+
push!(J, kj)
89+
end
90+
end
91+
end
92+
end
93+
return sparse(I, J, ones(Bool, length(I)), m, n)
94+
end
95+
96+
function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:iterative})
97+
@compat (; backend, atol) = detector
98+
n = length(x)
99+
I, J = Int[], Int[]
100+
p = similar(x)
101+
extras = prepare_hvp_same_point(
102+
f, backend, x, basis(backend, x, first(CartesianIndices(x)))
103+
)
104+
for (kj, j) in enumerate(CartesianIndices(x))
105+
hvp!(f, p, extras, backend, x, basis(backend, x, j))
106+
for ki in LinearIndices(p)
107+
if abs(p[ki]) > atol
108+
push!(I, ki)
109+
push!(J, kj)
110+
end
111+
end
112+
end
113+
return sparse(I, J, ones(Bool, length(I)), n, n)
114+
end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
module DifferentiationInterfaceSparseMatrixColoringsExt
2+
3+
using ADTypes:
4+
ADTypes,
5+
AbstractADType,
6+
AutoSparse,
7+
dense_ad,
8+
coloring_algorithm,
9+
sparsity_detector,
10+
jacobian_sparsity,
11+
hessian_sparsity
12+
using Compat
13+
using DifferentiationInterface
14+
using DifferentiationInterface:
15+
GradientExtras,
16+
HessianExtras,
17+
HVPExtras,
18+
JacobianExtras,
19+
PullbackExtras,
20+
PushforwardExtras,
21+
PushforwardFast,
22+
PushforwardSlow,
23+
Tangents,
24+
dense_ad,
25+
maybe_dense_ad,
26+
maybe_inner,
27+
maybe_outer,
28+
multibasis,
29+
pick_batchsize,
30+
pushforward_performance
31+
import DifferentiationInterface as DI
32+
using SparseMatrixColorings:
33+
AbstractColoringResult,
34+
ColoringProblem,
35+
GreedyColoringAlgorithm,
36+
coloring,
37+
column_colors,
38+
row_colors,
39+
column_groups,
40+
row_groups,
41+
decompress,
42+
decompress!
43+
44+
include("jacobian.jl")
45+
include("hessian.jl")
46+
47+
end

DifferentiationInterface/src/sparse/hessian.jl renamed to DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ end
3535

3636
## Hessian, one argument
3737

38-
function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
38+
function DI.prepare_hessian(f::F, backend::AutoSparse, x) where {F}
3939
dense_backend = dense_ad(backend)
4040
sparsity = hessian_sparsity(f, x, sparsity_detector(backend))
4141
problem = ColoringProblem{:symmetric,:column}()
@@ -64,7 +64,9 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
6464
)
6565
end
6666

67-
function hessian(f::F, extras::SparseHessianExtras{B}, backend::AutoSparse, x) where {F,B}
67+
function DI.hessian(
68+
f::F, extras::SparseHessianExtras{B}, backend::AutoSparse, x
69+
) where {F,B}
6870
@compat (; coloring_result, batched_seeds, hvp_extras) = extras
6971
dense_backend = dense_ad(backend)
7072
Ng = length(column_groups(coloring_result))
@@ -85,7 +87,7 @@ function hessian(f::F, extras::SparseHessianExtras{B}, backend::AutoSparse, x) w
8587
return decompress(compressed_matrix, coloring_result)
8688
end
8789

88-
function hessian!(
90+
function DI.hessian!(
8991
f::F, hess, extras::SparseHessianExtras{B}, backend::AutoSparse, x
9092
) where {F,B}
9193
@compat (;
@@ -113,7 +115,7 @@ function hessian!(
113115
return hess
114116
end
115117

116-
function value_gradient_and_hessian!(
118+
function DI.value_gradient_and_hessian!(
117119
f::F, grad, hess, extras::SparseHessianExtras, backend::AutoSparse, x
118120
) where {F}
119121
y, _ = value_and_gradient!(
@@ -123,7 +125,7 @@ function value_gradient_and_hessian!(
123125
return y, grad, hess
124126
end
125127

126-
function value_gradient_and_hessian(
128+
function DI.value_gradient_and_hessian(
127129
f::F, extras::SparseHessianExtras, backend::AutoSparse, x
128130
) where {F}
129131
y, grad = value_and_gradient(

DifferentiationInterface/src/sparse/jacobian.jl renamed to DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,14 @@ function PullbackSparseJacobianExtras{B}(;
6060
)
6161
end
6262

63-
function prepare_jacobian(f::F, backend::AutoSparse, x) where {F}
63+
function DI.prepare_jacobian(f::F, backend::AutoSparse, x) where {F}
6464
y = f(x)
6565
return _prepare_sparse_jacobian_aux(
6666
(f,), backend, x, y, pushforward_performance(backend)
6767
)
6868
end
6969

70-
function prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F}
70+
function DI.prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F}
7171
return _prepare_sparse_jacobian_aux(
7272
(f!, y), backend, x, y, pushforward_performance(backend)
7373
)
@@ -137,49 +137,51 @@ end
137137

138138
## One argument
139139

140-
function jacobian(f::F, extras::SparseJacobianExtras, backend::AutoSparse, x) where {F}
140+
function DI.jacobian(f::F, extras::SparseJacobianExtras, backend::AutoSparse, x) where {F}
141141
return _sparse_jacobian_aux((f,), extras, backend, x)
142142
end
143143

144-
function jacobian!(
144+
function DI.jacobian!(
145145
f::F, jac, extras::SparseJacobianExtras, backend::AutoSparse, x
146146
) where {F}
147147
return _sparse_jacobian_aux!((f,), jac, extras, backend, x)
148148
end
149149

150-
function value_and_jacobian(
150+
function DI.value_and_jacobian(
151151
f::F, extras::SparseJacobianExtras, backend::AutoSparse, x
152152
) where {F}
153153
return f(x), jacobian(f, extras, backend, x)
154154
end
155155

156-
function value_and_jacobian!(
156+
function DI.value_and_jacobian!(
157157
f::F, jac, extras::SparseJacobianExtras, backend::AutoSparse, x
158158
) where {F}
159159
return f(x), jacobian!(f, jac, extras, backend, x)
160160
end
161161

162162
## Two arguments
163163

164-
function jacobian(f!::F, y, extras::SparseJacobianExtras, backend::AutoSparse, x) where {F}
164+
function DI.jacobian(
165+
f!::F, y, extras::SparseJacobianExtras, backend::AutoSparse, x
166+
) where {F}
165167
return _sparse_jacobian_aux((f!, y), extras, backend, x)
166168
end
167169

168-
function jacobian!(
170+
function DI.jacobian!(
169171
f!::F, y, jac, extras::SparseJacobianExtras, backend::AutoSparse, x
170172
) where {F}
171173
return _sparse_jacobian_aux!((f!, y), jac, extras, backend, x)
172174
end
173175

174-
function value_and_jacobian(
176+
function DI.value_and_jacobian(
175177
f!::F, y, extras::SparseJacobianExtras, backend::AutoSparse, x
176178
) where {F}
177179
jac = jacobian(f!, y, extras, backend, x)
178180
f!(y, x)
179181
return y, jac
180182
end
181183

182-
function value_and_jacobian!(
184+
function DI.value_and_jacobian!(
183185
f!::F, y, jac, extras::SparseJacobianExtras, backend::AutoSparse, x
184186
) where {F}
185187
jacobian!(f!, y, jac, extras, backend, x)

0 commit comments

Comments
 (0)