Skip to content

Commit 1208b44

Browse files
authored
Improve Jacobian and Hessian preparation (#535)
* Avoid slicing the whole Jacobian if batch size does not divide total size
* Make preparation type-stable and test it
* Pick batchsize
* Fix chunk size
* Fix Enzyme
* No fully static chunk size
* No correctness for zero backends
* Bump to next SMC
1 parent f71edc4 commit 1208b44

17 files changed

Lines changed: 157 additions & 72 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.4"
4+
version = "0.6.5"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -60,7 +60,7 @@ PolyesterForwardDiff = "0.1.1"
6060
ReverseDiff = "1.15.1"
6161
SparseArrays = "<0.0.1,1"
6262
SparseConnectivityTracer = "0.5.0,0.6"
63-
SparseMatrixColorings = "0.4.4"
63+
SparseMatrixColorings = "0.4.5"
6464
Symbolics = "5.27.1, 6"
6565
Tracker = "0.2.33"
6666
Zygote = "0.6.69"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,16 @@ struct EnzymeForwardGradientPrep{B,O} <: GradientPrep
114114
shadows::O
115115
end
116116

117+
function EnzymeForwardGradientPrep(::Val{B}, shadows::O) where {B,O}
118+
return EnzymeForwardGradientPrep{B,O}(shadows)
119+
end
120+
117121
function DI.prepare_gradient(
118122
f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x
119123
) where {F}
120-
B = pick_batchsize(backend, length(x))
121-
shadows = create_shadows(Val(B), x)
122-
return EnzymeForwardGradientPrep{B,typeof(shadows)}(shadows)
124+
valB = pick_batchsize(backend, length(x))
125+
shadows = create_shadows(valB, x)
126+
return EnzymeForwardGradientPrep(valB, shadows)
123127
end
124128

125129
function DI.gradient(
@@ -176,13 +180,19 @@ struct EnzymeForwardOneArgJacobianPrep{B,O} <: JacobianPrep
176180
output_length::Int
177181
end
178182

183+
function EnzymeForwardOneArgJacobianPrep(
184+
::Val{B}, shadows::O, output_length::Integer
185+
) where {B,O}
186+
return EnzymeForwardOneArgJacobianPrep{B,O}(shadows, output_length)
187+
end
188+
179189
function DI.prepare_jacobian(
180190
f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x
181191
) where {F}
182192
y = f(x)
183-
B = pick_batchsize(backend, length(x))
184-
shadows = create_shadows(Val(B), x)
185-
return EnzymeForwardOneArgJacobianPrep{B,typeof(shadows)}(shadows, length(y))
193+
valB = pick_batchsize(backend, length(x))
194+
shadows = create_shadows(valB, x)
195+
return EnzymeForwardOneArgJacobianPrep(valB, shadows, length(y))
186196
end
187197

188198
function DI.jacobian(

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,11 +349,15 @@ end
349349

350350
struct EnzymeReverseOneArgJacobianPrep{Sy,B} <: JacobianPrep end
351351

352+
function EnzymeReverseOneArgJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B}
353+
return EnzymeReverseOneArgJacobianPrep{Sy,B}()
354+
end
355+
352356
function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F}
353357
y = f(x)
354358
Sy = size(y)
355-
B = pick_batchsize(backend, prod(Sy))
356-
return EnzymeReverseOneArgJacobianPrep{Sy,B}()
359+
valB = pick_batchsize(backend, prod(Sy))
360+
return EnzymeReverseOneArgJacobianPrep(Val(Sy), valB)
357361
end
358362

359363
function DI.jacobian(

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged
2-
DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = min(dimension, 16)
2+
DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = Val(16)
33

44
## Annotations
55

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,11 @@ using LinearAlgebra: dot, mul!
5050

5151
DI.check_available(::AutoForwardDiff) = true
5252

53-
function DI.pick_batchsize(::AutoForwardDiff{C}, dimension::Integer) where {C}
54-
if isnothing(C)
55-
return ForwardDiff.pickchunksize(dimension)
56-
else
57-
return min(dimension, C)
58-
end
53+
DI.pick_batchsize(::AutoForwardDiff{C}, dimension::Integer) where {C} = Val(C)
54+
55+
function DI.pick_batchsize(::AutoForwardDiff{nothing}, dimension::Integer)
56+
# type-unstable
57+
return Val(ForwardDiff.pickchunksize(dimension))
5958
end
6059

6160
include("utils.jl")

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ using DifferentiationInterface:
2424
outer,
2525
multibasis,
2626
pick_batchsize,
27+
pick_jacobian_batchsize,
2728
pushforward_performance,
2829
unwrap,
2930
with_contexts

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ SMC.column_groups(prep::SparseHessianPrep) = column_groups(prep.coloring_result)
4242
function DI.prepare_hessian(
4343
f::F, backend::AutoSparse, x, contexts::Vararg{Context,C}
4444
) where {F,C}
45+
valB = pick_batchsize(backend, length(x))
46+
return _prepare_sparse_hessian_aux(valB, f, backend, x, contexts...)
47+
end
48+
49+
function _prepare_sparse_hessian_aux(
50+
::Val{B}, f::F, backend::AutoSparse, x, contexts::Vararg{Context,C}
51+
) where {B,F,C}
4552
dense_backend = dense_ad(backend)
4653
sparsity = hessian_sparsity(
4754
with_contexts(f, contexts...), x, sparsity_detector(backend)
@@ -52,7 +59,6 @@ function DI.prepare_hessian(
5259
)
5360
groups = column_groups(coloring_result)
5461
Ng = length(groups)
55-
B = pick_batchsize(outer(dense_backend), Ng)
5662
seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups]
5763
compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=2)
5864
batched_seeds = [

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,22 +74,28 @@ function DI.prepare_jacobian(
7474
f::F, backend::AutoSparse, x, contexts::Vararg{Context,C}
7575
) where {F,C}
7676
y = f(x, map(unwrap, contexts)...)
77-
return _prepare_sparse_jacobian_aux(
78-
pushforward_performance(backend), y, (f,), backend, x, contexts...
79-
)
77+
perf = pushforward_performance(backend)
78+
valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y))
79+
return _prepare_sparse_jacobian_aux(perf, valB, y, (f,), backend, x, contexts...)
8080
end
8181

8282
function DI.prepare_jacobian(
8383
f!::F, y, backend::AutoSparse, x, contexts::Vararg{Context,C}
8484
) where {F,C}
85-
return _prepare_sparse_jacobian_aux(
86-
pushforward_performance(backend), y, (f!, y), backend, x, contexts...
87-
)
85+
perf = pushforward_performance(backend)
86+
valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y))
87+
return _prepare_sparse_jacobian_aux(perf, valB, y, (f!, y), backend, x, contexts...)
8888
end
8989

9090
function _prepare_sparse_jacobian_aux(
91-
::PushforwardFast, y, f_or_f!y::FY, backend::AutoSparse, x, contexts::Vararg{Context,C}
92-
) where {FY,C}
91+
::PushforwardFast,
92+
::Val{B},
93+
y,
94+
f_or_f!y::FY,
95+
backend::AutoSparse,
96+
x,
97+
contexts::Vararg{Context,C},
98+
) where {B,FY,C}
9399
dense_backend = dense_ad(backend)
94100

95101
sparsity = jacobian_sparsity(
@@ -104,7 +110,6 @@ function _prepare_sparse_jacobian_aux(
104110
)
105111
groups = column_groups(coloring_result)
106112
Ng = length(groups)
107-
B = pick_batchsize(dense_backend, Ng)
108113
seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups]
109114
compressed_matrix = stack(_ -> vec(similar(y)), groups; dims=2)
110115
batched_seeds = [
@@ -121,8 +126,14 @@ function _prepare_sparse_jacobian_aux(
121126
end
122127

123128
function _prepare_sparse_jacobian_aux(
124-
::PushforwardSlow, y, f_or_f!y::FY, backend::AutoSparse, x, contexts::Vararg{Context,C}
125-
) where {FY,C}
129+
::PushforwardSlow,
130+
::Val{B},
131+
y,
132+
f_or_f!y::FY,
133+
backend::AutoSparse,
134+
x,
135+
contexts::Vararg{Context,C},
136+
) where {B,FY,C}
126137
dense_backend = dense_ad(backend)
127138
sparsity = jacobian_sparsity(
128139
fy_with_contexts(f_or_f!y..., contexts...)..., x, sparsity_detector(backend)
@@ -136,7 +147,6 @@ function _prepare_sparse_jacobian_aux(
136147
)
137148
groups = row_groups(coloring_result)
138149
Ng = length(groups)
139-
B = pick_batchsize(dense_backend, Ng)
140150
seeds = [multibasis(backend, y, eachindex(y)[group]) for group in groups]
141151
compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=1)
142152
batched_seeds = [

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -86,29 +86,29 @@ function prepare_jacobian(
8686
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
8787
) where {F,C}
8888
y = f(x, map(unwrap, contexts)...)
89-
return _prepare_jacobian_aux(
90-
pushforward_performance(backend), y, (f,), backend, x, contexts...
91-
)
89+
perf = pushforward_performance(backend)
90+
valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y))
91+
return _prepare_jacobian_aux(perf, valB, y, (f,), backend, x, contexts...)
9292
end
9393

9494
function prepare_jacobian(
9595
f!::F, y, backend::AbstractADType, x, contexts::Vararg{Context,C}
9696
) where {F,C}
97-
return _prepare_jacobian_aux(
98-
pushforward_performance(backend), y, (f!, y), backend, x, contexts...
99-
)
97+
perf = pushforward_performance(backend)
98+
valB = pick_jacobian_batchsize(perf, backend; N=length(x), M=length(y))
99+
return _prepare_jacobian_aux(perf, valB, y, (f!, y), backend, x, contexts...)
100100
end
101101

102102
function _prepare_jacobian_aux(
103103
::PushforwardFast,
104+
::Val{B},
104105
y,
105106
f_or_f!y::FY,
106107
backend::AbstractADType,
107108
x,
108109
contexts::Vararg{Context,C},
109-
) where {FY,C}
110+
) where {B,FY,C}
110111
N = length(x)
111-
B = pick_batchsize(backend, N)
112112
seeds = [basis(backend, x, ind) for ind in eachindex(x)]
113113
batched_seeds = [
114114
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for
@@ -128,14 +128,14 @@ end
128128

129129
function _prepare_jacobian_aux(
130130
::PushforwardSlow,
131+
::Val{B},
131132
y,
132133
f_or_f!y::FY,
133134
backend::AbstractADType,
134135
x,
135136
contexts::Vararg{Context,C},
136-
) where {FY,C}
137+
) where {B,FY,C}
137138
M = length(y)
138-
B = pick_batchsize(backend, M)
139139
seeds = [basis(backend, y, ind) for ind in eachindex(y)]
140140
batched_seeds = [
141141
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % M], Val(B)) for
@@ -241,13 +241,14 @@ function _jacobian_aux(
241241
batched_seeds[a],
242242
contexts...,
243243
)
244-
stack(vec, dy_batch; dims=2)
244+
block = stack(vec, dy_batch; dims=2)
245+
if N % B != 0 && a == lastindex(batched_seeds)
246+
block = block[:, 1:(N - (a - 1) * B)]
247+
end
248+
block
245249
end
246250

247251
jac = reduce(hcat, jac_blocks)
248-
if N < size(jac, 2)
249-
jac = jac[:, 1:N]
250-
end
251252
return jac
252253
end
253254

@@ -268,13 +269,14 @@ function _jacobian_aux(
268269
dx_batch = pullback(
269270
f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts...
270271
)
271-
stack(vec, dx_batch; dims=1)
272+
block = stack(vec, dx_batch; dims=1)
273+
if M % B != 0 && a == lastindex(batched_seeds)
274+
block = block[1:(M - (a - 1) * B), :]
275+
end
276+
block
272277
end
273278

274279
jac = reduce(vcat, jac_blocks)
275-
if M < size(jac, 1)
276-
jac = jac[1:M, :]
277-
end
278280
return jac
279281
end
280282

DifferentiationInterface/src/second_order/hessian.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,14 @@ end
7272
function prepare_hessian(
7373
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
7474
) where {F,C}
75+
valB = pick_batchsize(backend, length(x))
76+
return _prepare_hessian_aux(valB, f, backend, x, contexts...)
77+
end
78+
79+
function _prepare_hessian_aux(
80+
::Val{B}, f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
81+
) where {B,F,C}
7582
N = length(x)
76-
B = pick_batchsize(outer(backend), N)
7783
seeds = [basis(backend, x, ind) for ind in eachindex(x)]
7884
batched_seeds = [
7985
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for
@@ -107,7 +113,11 @@ function hessian(
107113

108114
hess_blocks = map(eachindex(batched_seeds)) do a
109115
dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...)
110-
stack(vec, dg_batch; dims=2)
116+
block = stack(vec, dg_batch; dims=2)
117+
if N % B != 0 && a == lastindex(batched_seeds)
118+
block = block[:, 1:(N - (a - 1) * B)]
119+
end
120+
block
111121
end
112122

113123
hess = reduce(hcat, hess_blocks)

0 commit comments

Comments
 (0)