Skip to content

Commit 9a524d3

Browse files
authored
Mixed-mode sparse Jacobians (#554)
1 parent 3094dca commit 9a524d3

17 files changed

Lines changed: 352 additions & 28 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.21"
4+
version = "0.6.22"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -61,7 +61,7 @@ ReverseDiff = "1.15.1"
6161
SparseArrays = "<0.0.1,1"
6262
SparseConnectivityTracer = "0.5.0,0.6"
6363
StaticArrays = "1.9.7"
64-
SparseMatrixColorings = "0.4.5"
64+
SparseMatrixColorings = "0.4.9"
6565
Symbolics = "5.27.1, 6"
6666
Tracker = "0.2.33"
6767
Zygote = "0.6.69"
@@ -99,4 +99,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
9999
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
100100

101101
[targets]
102-
test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "ExplicitImports", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test"]
102+
test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "ExplicitImports", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]

DifferentiationInterface/docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ jacobian
6868
jacobian!
6969
value_and_jacobian
7070
value_and_jacobian!
71+
MixedMode
7172
```
7273

7374
## Second order

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ using DifferentiationInterface:
2121
PushforwardPerformance,
2222
inner,
2323
outer,
24+
forward_backend,
25+
reverse_backend,
2426
multibasis,
2527
pick_batchsize,
2628
pushforward_performance,
@@ -33,13 +35,32 @@ using SparseMatrixColorings:
3335
coloring,
3436
column_colors,
3537
row_colors,
38+
ncolors,
3639
column_groups,
3740
row_groups,
3841
sparsity_pattern,
3942
decompress!
4043
import SparseMatrixColorings as SMC
4144

45+
# Out-of-place variant: wrap `f` with `with_contexts` and return it as a one-element
# tuple, so that callers can splat the result into signatures of the form `(f, ...)`.
# NOTE(review): `with_contexts` is defined elsewhere; presumably it closes `f` over the
# given contexts -- confirm against its definition.
function fy_with_contexts(f, contexts::Vararg{Context,C}) where {C}
    f_closed = with_contexts(f, contexts...)
    return (f_closed,)
end
48+
49+
# In-place variant: wrap `f!` with `with_contexts` and return it together with the
# output buffer `y`, so that callers can splat the result into signatures of the form
# `(f!, y, ...)`.
# NOTE(review): `with_contexts` is defined elsewhere; presumably it closes `f!` over the
# given contexts -- confirm against its definition.
function fy_with_contexts(f!, y, contexts::Vararg{Context,C}) where {C}
    f!_closed = with_contexts(f!, contexts...)
    return (f!_closed, y)
end
52+
53+
# Common supertype of the preparation objects used for sparse Jacobians.
# The accessors below read `prep.coloring_result`, so every concrete subtype is
# expected to store its coloring in a field with that name.
abstract type SparseJacobianPrep <: JacobianPrep end

# Each SparseMatrixColorings query is forwarded to the stored coloring result.

function SMC.sparsity_pattern(prep::SparseJacobianPrep)
    return SMC.sparsity_pattern(prep.coloring_result)
end

function SMC.column_colors(prep::SparseJacobianPrep)
    return SMC.column_colors(prep.coloring_result)
end

function SMC.column_groups(prep::SparseJacobianPrep)
    return SMC.column_groups(prep.coloring_result)
end

function SMC.row_colors(prep::SparseJacobianPrep)
    return SMC.row_colors(prep.coloring_result)
end

function SMC.row_groups(prep::SparseJacobianPrep)
    return SMC.row_groups(prep.coloring_result)
end

function SMC.ncolors(prep::SparseJacobianPrep)
    return SMC.ncolors(prep.coloring_result)
end
61+
4262
include("jacobian.jl")
63+
include("jacobian_mixed.jl")
4364
include("hessian.jl")
4465

4566
end

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ end
1919
# SparseMatrixColorings queries on a sparse Hessian preparation object are forwarded
# to the coloring result stored in its `coloring_result` field.

function SMC.sparsity_pattern(prep::SparseHessianPrep)
    return SMC.sparsity_pattern(prep.coloring_result)
end

function SMC.column_colors(prep::SparseHessianPrep)
    return SMC.column_colors(prep.coloring_result)
end

function SMC.column_groups(prep::SparseHessianPrep)
    return SMC.column_groups(prep.coloring_result)
end

function SMC.ncolors(prep::SparseHessianPrep)
    return SMC.ncolors(prep.coloring_result)
end
2223

2324
## Hessian, one argument
2425

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,5 @@
1-
function fy_with_contexts(f, contexts::Vararg{Context,C}) where {C}
2-
return (with_contexts(f, contexts...),)
3-
end
4-
5-
function fy_with_contexts(f!, y, contexts::Vararg{Context,C}) where {C}
6-
return (with_contexts(f!, contexts...), y)
7-
end
8-
91
## Preparation
102

11-
abstract type SparseJacobianPrep <: JacobianPrep end
12-
13-
SMC.sparsity_pattern(prep::SparseJacobianPrep) = sparsity_pattern(prep.coloring_result)
14-
SMC.column_colors(prep::SparseJacobianPrep) = column_colors(prep.coloring_result)
15-
SMC.column_groups(prep::SparseJacobianPrep) = column_groups(prep.coloring_result)
16-
SMC.row_colors(prep::SparseJacobianPrep) = row_colors(prep.coloring_result)
17-
SMC.row_groups(prep::SparseJacobianPrep) = row_groups(prep.coloring_result)
18-
193
struct PushforwardSparseJacobianPrep{
204
BS<:BatchSizeSettings,
215
C<:AbstractColoringResult{:nonsymmetric,:column},
DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
## Preparation
2+
3+
# Preparation object for sparse Jacobians computed with a `MixedMode` backend, i.e. with
# pushforwards for the column color groups and pullbacks for the row color groups of a
# bidirectional coloring.
#
# Fields:
# - `batch_size_settings_forward` / `batch_size_settings_reverse`: batch size settings
#   picked for the forward and reverse backends.
# - `coloring_result`: bidirectional coloring of the Jacobian sparsity pattern.
# - `compressed_matrix_forward`: storage for the pushforward results, one column per
#   column color group.
# - `compressed_matrix_reverse`: storage for the pullback results, one row per row
#   color group.
# - `batched_seeds_forward` / `batched_seeds_reverse`: seed tuples, one tuple per batch.
# - `batched_results_forward` / `batched_results_reverse`: result buffers matching the
#   batched seeds.
# - `pushforward_prep` / `pullback_prep`: preparation objects of the inner operators.
#
# The two compressed matrices have independent type parameters (`Mf`, `Mr`): the forward
# one is allocated from `similar(y)` and the reverse one from `similar(x)`, so their
# types differ whenever `x` and `y` do not share an element type. A single shared
# parameter would make the default constructor fail in that case.
struct MixedModeSparseJacobianPrep{
    BSf<:BatchSizeSettings,
    BSr<:BatchSizeSettings,
    C<:AbstractColoringResult{:nonsymmetric,:bidirectional},
    Mf<:AbstractMatrix{<:Real},
    Mr<:AbstractMatrix{<:Real},
    Sf<:Vector{<:NTuple},
    Sr<:Vector{<:NTuple},
    Rf<:Vector{<:NTuple},
    Rr<:Vector{<:NTuple},
    Ef<:PushforwardPrep,
    Er<:PullbackPrep,
} <: SparseJacobianPrep
    batch_size_settings_forward::BSf
    batch_size_settings_reverse::BSr
    coloring_result::C
    compressed_matrix_forward::Mf
    compressed_matrix_reverse::Mr
    batched_seeds_forward::Sf
    batched_seeds_reverse::Sr
    batched_results_forward::Rf
    batched_results_reverse::Rr
    pushforward_prep::Ef
    pullback_prep::Er
end
27+
28+
# Out-of-place entry point for mixed-mode sparse Jacobian preparation.
# The function is evaluated once at `x` (with unwrapped contexts) because the output
# `y` is needed downstream to build reverse-mode seeds and result buffers.
function DI.prepare_jacobian(
    f::F, backend::AutoSparse{<:MixedMode}, x, contexts::Vararg{Context,C}
) where {F,C}
    context_values = map(unwrap, contexts)
    y = f(x, context_values...)
    f_tuple = (f,)
    return _prepare_mixed_sparse_jacobian_aux(y, f_tuple, backend, x, contexts...)
end
34+
35+
# In-place entry point for mixed-mode sparse Jacobian preparation.
# The caller-provided buffer `y` is used directly as the output template, so `f!` is
# not called here.
function DI.prepare_jacobian(
    f!::F, y, backend::AutoSparse{<:MixedMode}, x, contexts::Vararg{Context,C}
) where {F,C}
    f!y_tuple = (f!, y)
    return _prepare_mixed_sparse_jacobian_aux(y, f!y_tuple, backend, x, contexts...)
end
40+
41+
# First stage of mixed-mode sparse Jacobian preparation: detect the sparsity pattern,
# compute a bidirectional coloring, and pick one batch size per direction.
# `f_or_f!y` is either `(f,)` or `(f!, y)`; `y` is always passed separately as well.
# The remaining work is delegated to `_prepare_mixed_sparse_jacobian_aux_aux`, which
# receives the batch size settings as arguments and can dispatch on their types.
function _prepare_mixed_sparse_jacobian_aux(
    y, f_or_f!y::FY, backend::AutoSparse{<:MixedMode}, x, contexts::Vararg{Context,C}
) where {FY,C}
    mixed_backend = dense_ad(backend)

    # Sparsity detection runs on the function with its contexts already applied.
    fy_closed = fy_with_contexts(f_or_f!y..., contexts...)
    sparsity = jacobian_sparsity(fy_closed..., x, sparsity_detector(backend))

    coloring_result = coloring(
        sparsity,
        ColoringProblem{:nonsymmetric,:bidirectional}(),
        coloring_algorithm(backend);
        decompression_eltype=promote_type(eltype(x), eltype(y)),
    )

    # The forward backend handles column groups, the reverse backend handles row groups.
    n_column_groups = length(column_groups(coloring_result))
    n_row_groups = length(row_groups(coloring_result))
    settings_forward = pick_batchsize(forward_backend(mixed_backend), n_column_groups)
    settings_reverse = pick_batchsize(reverse_backend(mixed_backend), n_row_groups)

    return _prepare_mixed_sparse_jacobian_aux_aux(
        settings_forward, settings_reverse, coloring_result, y, f_or_f!y, backend, x,
        contexts...,
    )
end
72+
73+
# Second stage of mixed-mode sparse Jacobian preparation: allocate seeds, result buffers
# and compressed matrices, prepare the inner pushforward and pullback, and bundle
# everything into a `MixedModeSparseJacobianPrep`.
# The batch sizes `Bf` and `Br` are read from the type parameters of the batch size
# settings, which allows them to be passed to `ntuple` as `Val`s.
function _prepare_mixed_sparse_jacobian_aux_aux(
    batch_size_settings_forward::BatchSizeSettings{Bf},
    batch_size_settings_reverse::BatchSizeSettings{Br},
    coloring_result::AbstractColoringResult{:nonsymmetric,:bidirectional},
    y,
    f_or_f!y::FY,
    backend::AutoSparse{<:MixedMode},
    x,
    contexts::Vararg{Context,C},
) where {Bf,Br,FY,C}
    # NOTE(review): `N` is the group count given to `pick_batchsize` by the caller;
    # `A` is presumably the number of batches -- confirm in `BatchSizeSettings`.
    Nf = batch_size_settings_forward.N
    Af = batch_size_settings_forward.A
    Nr = batch_size_settings_reverse.N
    Ar = batch_size_settings_reverse.A

    mixed_backend = dense_ad(backend)

    column_color_groups = column_groups(coloring_result)
    row_color_groups = row_groups(coloring_result)

    # One seed per color group: input-shaped for columns, output-shaped for rows.
    x_indices = eachindex(x)
    y_indices = eachindex(y)
    seeds_forward = [multibasis(backend, x, x_indices[g]) for g in column_color_groups]
    seeds_reverse = [multibasis(backend, y, y_indices[g]) for g in row_color_groups]

    # Forward results are stored as columns, reverse results as rows.
    compressed_matrix_forward = stack(column_color_groups; dims=2) do _
        vec(similar(y))
    end
    compressed_matrix_reverse = stack(row_color_groups; dims=1) do _
        vec(similar(x))
    end

    # Slot `b` of batch `a` holds seed number `(a - 1) * B + b`, wrapped around with
    # `mod1` so that the last batch is filled up with seeds from the beginning.
    forward_batch(a) = ntuple(b -> seeds_forward[mod1((a - 1) * Bf + b, Nf)], Val(Bf))
    reverse_batch(a) = ntuple(b -> seeds_reverse[mod1((a - 1) * Br + b, Nr)], Val(Br))
    batched_seeds_forward = [forward_batch(a) for a in 1:Af]
    batched_seeds_reverse = [reverse_batch(a) for a in 1:Ar]

    # One tuple of result buffers per batch of seeds.
    batched_results_forward = map(batched_seeds_forward) do _
        ntuple(_ -> similar(y), Val(Bf))
    end
    batched_results_reverse = map(batched_seeds_reverse) do _
        ntuple(_ -> similar(x), Val(Br))
    end

    # The inner operators are prepared with the first batch of seeds.
    first_seeds_forward = batched_seeds_forward[1]
    pushforward_prep = prepare_pushforward(
        f_or_f!y..., forward_backend(mixed_backend), x, first_seeds_forward, contexts...
    )
    first_seeds_reverse = batched_seeds_reverse[1]
    pullback_prep = prepare_pullback(
        f_or_f!y..., reverse_backend(mixed_backend), x, first_seeds_reverse, contexts...
    )

    return MixedModeSparseJacobianPrep(
        batch_size_settings_forward,
        batch_size_settings_reverse,
        coloring_result,
        compressed_matrix_forward,
        compressed_matrix_reverse,
        batched_seeds_forward,
        batched_seeds_reverse,
        batched_results_forward,
        batched_results_reverse,
        pushforward_prep,
        pullback_prep,
    )
end
144+
145+
## Common auxiliaries
146+
147+
# Fill `jac` with the sparse Jacobian at `x` using a mixed-mode preparation object.
# Steps: (1) run batched pushforwards and copy each result into a column of the forward
# compressed matrix, (2) run batched pullbacks and copy each result into a row of the
# reverse compressed matrix, (3) decompress both matrices into `jac`.
# `f_or_f!y` is either `(f,)` or `(f!, y)`. Returns `jac`.
function _sparse_jacobian_aux!(
    f_or_f!y::FY,
    jac,
    prep::MixedModeSparseJacobianPrep{<:BatchSizeSettings{Bf},<:BatchSizeSettings{Br}},
    backend::AutoSparse,
    x,
    contexts::Vararg{Context,C},
) where {FY,Bf,Br,C}
    Nf = prep.batch_size_settings_forward.N
    Nr = prep.batch_size_settings_reverse.N
    seeds_forward = prep.batched_seeds_forward
    seeds_reverse = prep.batched_seeds_reverse
    results_forward = prep.batched_results_forward
    results_reverse = prep.batched_results_reverse
    compressed_columns = prep.compressed_matrix_forward
    compressed_rows = prep.compressed_matrix_reverse

    mixed_backend = dense_ad(backend)
    backend_forward = forward_backend(mixed_backend)
    backend_reverse = reverse_backend(mixed_backend)

    # All batches are evaluated at the same `x`, so the inner preparations are
    # specialized to that point once, before the loops.
    pushforward_prep_same = prepare_pushforward_same_point(
        f_or_f!y..., prep.pushforward_prep, backend_forward, x, seeds_forward[1],
        contexts...,
    )
    pullback_prep_same = prepare_pullback_same_point(
        f_or_f!y..., prep.pullback_prep, backend_reverse, x, seeds_reverse[1],
        contexts...,
    )

    # Forward sweep: slot `b` of batch `a` belongs to the column color given by the
    # same wrap-around formula that was used to build the batched seeds.
    for a in eachindex(seeds_forward, results_forward)
        batch_results = results_forward[a]
        pushforward!(
            f_or_f!y..., batch_results, pushforward_prep_same, backend_forward, x,
            seeds_forward[a], contexts...,
        )
        for b in eachindex(batch_results)
            color = mod1((a - 1) * Bf + b, Nf)
            copyto!(view(compressed_columns, :, color), vec(batch_results[b]))
        end
    end

    # Reverse sweep: same scheme, writing into rows instead of columns.
    for a in eachindex(seeds_reverse, results_reverse)
        batch_results = results_reverse[a]
        pullback!(
            f_or_f!y..., batch_results, pullback_prep_same, backend_reverse, x,
            seeds_reverse[a], contexts...,
        )
        for b in eachindex(batch_results)
            color = mod1((a - 1) * Br + b, Nr)
            copyto!(view(compressed_rows, color, :), vec(batch_results[b]))
        end
    end

    decompress!(jac, compressed_rows, compressed_columns, prep.coloring_result)

    return jac
end

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ using LinearAlgebra: dot
3333

3434
include("compat.jl")
3535

36+
include("first_order/mixed_mode.jl")
3637
include("second_order/second_order.jl")
3738

3839
include("utils/prep.jl")
@@ -66,7 +67,7 @@ include("misc/zero_backends.jl")
6667
## Exported
6768

6869
export Context, Constant, Cache
69-
export SecondOrder
70+
export MixedMode, SecondOrder
7071

7172
export value_and_pushforward!, value_and_pushforward
7273
export value_and_pullback!, value_and_pullback
DifferentiationInterface/src/first_order/mixed_mode.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""
    MixedMode

Combination of a forward and a reverse mode backend for mixed-mode Jacobian computation.

!!! danger
    `MixedMode` backends only support [`jacobian`](@ref) and its variants.

# Constructor

    MixedMode(forward_backend, reverse_backend)

Throw an `ArgumentError` if `forward_backend` does not have fast pushforwards or if `reverse_backend` does not have fast pullbacks.
"""
struct MixedMode{F<:AbstractADType,R<:AbstractADType} <: AbstractADType
    forward::F
    reverse::R
    function MixedMode(forward::AbstractADType, reverse::AbstractADType)
        # Explicit checks instead of `@assert`: this validates user input, and
        # assertions may be disabled at some optimization levels.
        if !(pushforward_performance(forward) isa PushforwardFast)
            throw(
                ArgumentError(
                    "The forward backend of MixedMode must have fast pushforwards, got $forward",
                ),
            )
        end
        if !(pullback_performance(reverse) isa PullbackFast)
            throw(
                ArgumentError(
                    "The reverse backend of MixedMode must have fast pullbacks, got $reverse",
                ),
            )
        end
        return new{typeof(forward),typeof(reverse)}(forward, reverse)
    end
end
22+
23+
# Return the forward-mode component of a `MixedMode` backend.
function forward_backend(m::MixedMode)
    return m.forward
end

# Return the reverse-mode component of a `MixedMode` backend.
function reverse_backend(m::MixedMode)
    return m.reverse
end

# Mode object reported by `ADTypes.mode` for `MixedMode` backends.
struct ForwardAndReverseMode <: ADTypes.AbstractMode end

function ADTypes.mode(::MixedMode)
    return ForwardAndReverseMode()
end

DifferentiationInterface/src/utils/batchsize.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ function pick_batchsize(backend::AbstractADType, x_or_N::Union{AbstractArray,Int
5252
"You should select the batch size for the dense backend of $backend"
5353
),
5454
)
55+
elseif backend isa MixedMode
56+
throw(
57+
ArgumentError(
58+
"You should select the batch size for the forward or reverse backend of $backend",
59+
),
60+
)
5561
else
5662
return BatchSizeSettings(backend, x_or_N)
5763
end

0 commit comments

Comments
 (0)