Commit 4496997

Revamp batch size handling (#575)
* Batch size
* Revamp batch size computations
* Fix
* Single batch modifications
* Introduce BatchSizeSettings
* Fix PolyesterFD
* Fix
* Fix internals
* Type stab
* Fix
* Fixes
* Coverage
* Fix
* Guess activity in Enzyme
* Fix
* Fixes
* No static test
* More coverage
* AutoSparse with adaptive backends
* Proper thresholding
* Fix
* Fixes
* Fix
* Up
* Fix
1 parent fd7580c commit 4496997

25 files changed

Lines changed: 629 additions & 301 deletions


DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "DifferentiationInterface"
 uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 authors = ["Guillaume Dalle", "Adrian Hill"]
-version = "0.6.12"
+version = "0.6.13"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/docs/src/explanation/advanced.md

Lines changed: 20 additions & 0 deletions
@@ -67,3 +67,23 @@ The complexity of sparse Jacobians or Hessians grows with the number of distinct
 To reduce this number of colors, [`GreedyColoringAlgorithm`](@ref) has two main settings: the order used for vertices and the decompression method.
 Depending on your use case, you may want to modify either of these options to increase performance.
 See the documentation of [SparseMatrixColorings.jl](https://github.com/gdalle/SparseMatrixColorings.jl) for details.
+
+## Batch mode
+
+### Multiple tangents
+
+The [`jacobian`](@ref) and [`hessian`](@ref) operators compute matrices by repeatedly applying lower-level operators ([`pushforward`](@ref), [`pullback`](@ref) or [`hvp`](@ref)) to a set of tangents.
+The tangents usually correspond to basis elements of the appropriate vector space.
+We could call the lower-level operator on each tangent separately, but some packages ([ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) and [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)) provide optimized implementations that handle multiple tangents at once.
+
+This behavior is often called "vector mode" AD, but we call it "batch mode" to avoid confusion with Julia's `Vector` type.
+In fact, the optimal batch size $B$ (the number of simultaneous tangents) is usually very small, so tangents are passed within an `NTuple` rather than a `Vector`.
+When the underlying vector space has dimension $N$, the operators `jacobian` and `hessian` process $\lceil N / B \rceil$ batches of size $B$ each.
+
+### Optimal batch size
+
+For every backend which does not support batch mode, the batch size is set to $B = 1$.
+For [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff) and [`AutoEnzyme`](@extref ADTypes.AutoEnzyme), more complicated rules apply.
+If the backend object has a pre-determined batch size $B_0$, then we always set $B = B_0$.
+In particular, this throws an error whenever $N < B_0$.
+On the other hand, without a pre-determined batch size, we apply backend-specific heuristics to pick $B$ based on $N$.
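The batch arithmetic described above ($\lceil N / B \rceil$ batches of size $B$) can be sketched in a few lines. This is a language-agnostic illustration in Python, not part of the (Julia) package; the helper name `batch_count` is hypothetical.

```python
import math

def batch_count(N: int, B: int) -> int:
    """Number of batches of size B needed to cover N tangents: ceil(N / B)."""
    return math.ceil(N / B)

# With N = 10 basis tangents and batch size B = 4, jacobian/hessian
# would make ceil(10 / 4) = 3 batched calls, covering 4 + 4 + 2 tangents.
print(batch_count(10, 4))  # → 3
```

The last batch is only full when $B$ divides $N$ exactly, which is why the new `BatchSizeSettings` type introduced in this commit tracks an `aligned` flag.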

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 2 additions & 2 deletions
@@ -16,12 +16,12 @@ using DifferentiationInterface:
     NoHVPPrep,
     NoJacobianPrep,
     NoPullbackPrep,
-    NoPushforwardPrep,
-    pick_batchsize
+    NoPushforwardPrep
 using Enzyme:
     Active,
     Annotation,
     BatchDuplicated,
+    BatchMixedDuplicated,
     Const,
     Duplicated,
     DuplicatedNoNeed,

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 2 additions & 2 deletions
@@ -121,7 +121,7 @@ end
 function DI.prepare_gradient(
     f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x
 ) where {F}
-    valB = pick_batchsize(backend, length(x))
+    valB = to_val(DI.pick_batchsize(backend, x))
     shadows = create_shadows(valB, x)
     return EnzymeForwardGradientPrep(valB, shadows)
 end
@@ -190,7 +190,7 @@ function DI.prepare_jacobian(
     f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x
 ) where {F}
     y = f(x)
-    valB = pick_batchsize(backend, length(x))
+    valB = to_val(DI.pick_batchsize(backend, x))
     shadows = create_shadows(valB, x)
     return EnzymeForwardOneArgJacobianPrep(valB, shadows, length(y))
 end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 1 addition & 1 deletion
@@ -337,7 +337,7 @@ end
 function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F}
     y = f(x)
     Sy = size(y)
-    valB = pick_batchsize(backend, prod(Sy))
+    valB = to_val(DI.pick_batchsize(backend, y))
     return EnzymeReverseOneArgJacobianPrep(Val(Sy), valB)
 end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 11 additions & 3 deletions
@@ -1,5 +1,12 @@
 # until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged
-DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = Val(min(dimension, 16))
+function DI.BatchSizeSettings(::AutoEnzyme, N::Integer)
+    B = DI.reasonable_batchsize(N, 16)
+    singlebatch = B == N
+    aligned = N % B == 0
+    return DI.BatchSizeSettings{B,singlebatch,aligned}(N)
+end
+
+to_val(::DI.BatchSizeSettings{B}) where {B} = Val(B)
 
 ## Annotations
 
@@ -17,9 +24,10 @@ function get_f_and_df(
         M,
         <:Union{
             Duplicated,
-            EnzymeCore.DuplicatedNoNeed,
+            MixedDuplicated,
             BatchDuplicated,
-            EnzymeCore.BatchDuplicatedFunc,
+            BatchMixedDuplicated,
+            EnzymeCore.DuplicatedNoNeed,
             EnzymeCore.BatchDuplicatedNoNeed,
         },
     },
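The flags computed by the new `BatchSizeSettings` constructor for `AutoEnzyme` can be mirrored in plain Python. Note that `DI.reasonable_batchsize` is not shown in this diff; the sketch below assumes the simplest reading, a cap at 16 matching the replaced line `Val(min(dimension, 16))`, so the exact `B` values are an assumption.

```python
def batch_size_settings(N: int, Bmax: int = 16):
    """Hypothetical mirror of DI.BatchSizeSettings for AutoEnzyme.

    Assumes reasonable_batchsize(N, Bmax) == min(N, Bmax), i.e. the cap
    used by the code this commit replaces.
    """
    B = min(N, Bmax)          # assumed heuristic: cap the batch size
    singlebatch = B == N      # one batch covers every tangent
    aligned = N % B == 0      # the last batch is full
    return B, singlebatch, aligned

print(batch_size_settings(10))  # (10, True, True): a single full batch
print(batch_size_settings(20))  # (16, False, False): 16 + 4 tangents
```

Encoding `singlebatch` and `aligned` in the type parameters (as the Julia code does) lets downstream operators dispatch on them without runtime branching.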

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 1 addition & 18 deletions
@@ -4,6 +4,7 @@ using ADTypes: AbstractADType, AutoForwardDiff
 using Base: Fix1, Fix2
 import DifferentiationInterface as DI
 using DifferentiationInterface:
+    BatchSizeSettings,
     Context,
     DerivativePrep,
     DifferentiateWith,
@@ -49,24 +50,6 @@ using LinearAlgebra: dot, mul!
 
 DI.check_available(::AutoForwardDiff) = true
 
-function DI.pick_batchsize(
-    ::AutoForwardDiff{chunksize}, dimension::Integer
-) where {chunksize}
-    return Val{chunksize}()
-end
-
-function DI.pick_batchsize(::AutoForwardDiff{nothing}, dimension::Integer)
-    # type-unstable
-    return Val(ForwardDiff.pickchunksize(dimension))
-end
-
-function DI.threshold_batchsize(
-    backend::AutoForwardDiff{chunksize1}, chunksize2::Integer
-) where {chunksize1}
-    chunksize = (chunksize1 === nothing) ? nothing : min(chunksize1, chunksize2)
-    return AutoForwardDiff(; chunksize, tag=backend.tag)
-end
-
 include("utils.jl")
 include("onearg.jl")
 include("twoarg.jl")

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

Lines changed: 24 additions & 0 deletions
@@ -1,3 +1,27 @@
+function DI.BatchSizeSettings(::AutoForwardDiff{nothing}, N::Integer)
+    B = ForwardDiff.pickchunksize(N)
+    singlebatch = B == N
+    aligned = N % B == 0
+    return BatchSizeSettings{B,singlebatch,aligned}(N)
+end
+
+function DI.BatchSizeSettings(::AutoForwardDiff{chunksize}, N::Integer) where {chunksize}
+    if chunksize > N
+        throw(ArgumentError("Fixed chunksize $chunksize larger than input size $N"))
+    end
+    B = chunksize
+    singlebatch = B == N
+    aligned = N % B == 0
+    return BatchSizeSettings{B,singlebatch,aligned}(N)
+end
+
+function DI.threshold_batchsize(
+    backend::AutoForwardDiff{chunksize1}, chunksize2::Integer
+) where {chunksize1}
+    chunksize = isnothing(chunksize1) ? nothing : min(chunksize1, chunksize2)
+    return AutoForwardDiff(; chunksize, tag=backend.tag)
+end
+
 choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x)
 choose_chunk(::AutoForwardDiff{chunksize}, x) where {chunksize} = Chunk{chunksize}()
 
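The thresholding rule added for `AutoForwardDiff` is simple enough to state outside Julia: an adaptive backend (chunksize `nothing`) stays adaptive, while a fixed chunksize is clamped to the given limit. A Python sketch, with the hypothetical name `threshold_chunksize` standing in for `DI.threshold_batchsize`:

```python
def threshold_chunksize(chunksize, limit):
    """Adaptive backends (chunksize=None) pass through unchanged;
    a fixed chunksize is clamped to at most `limit`."""
    return None if chunksize is None else min(chunksize, limit)

print(threshold_chunksize(None, 8))  # None: adaptive stays adaptive
print(threshold_chunksize(12, 8))    # 8: fixed chunksize clamped
print(threshold_chunksize(4, 8))     # 4: already below the limit
```

Clamping rather than erroring is what lets a fixed-chunksize backend be reused on outputs smaller than its chunksize, e.g. when `AutoSparse` wraps it.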

DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl

Lines changed: 3 additions & 3 deletions
@@ -28,14 +28,14 @@ end
 
 DI.check_available(::AutoPolyesterForwardDiff) = true
 
-function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, dimension::Integer)
-    return DI.pick_batchsize(single_threaded(backend), dimension)
+function DI.BatchSizeSettings(backend::AutoPolyesterForwardDiff, x_or_N)
+    return DI.BatchSizeSettings(single_threaded(backend), x_or_N)
 end
 
 function DI.threshold_batchsize(
     backend::AutoPolyesterForwardDiff{chunksize1}, chunksize2::Integer
 ) where {chunksize1}
-    chunksize = (chunksize1 === nothing) ? nothing : min(chunksize1, chunksize2)
+    chunksize = isnothing(chunksize1) ? nothing : min(chunksize1, chunksize2)
     return AutoPolyesterForwardDiff(; chunksize, tag=backend.tag)
 end

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl

Lines changed: 4 additions & 3 deletions
@@ -10,18 +10,19 @@ using ADTypes:
     hessian_sparsity
 using DifferentiationInterface
 using DifferentiationInterface:
+    BatchSizeSettings,
     GradientPrep,
     HessianPrep,
     HVPPrep,
     JacobianPrep,
     PullbackPrep,
     PushforwardPrep,
     PushforwardFast,
-    PushforwardSlow,
+    PushforwardPerformance,
     inner,
+    outer,
     multibasis,
-    pick_hessian_batchsize,
-    pick_jacobian_batchsize,
+    pick_batchsize,
     pushforward_performance,
     unwrap,
     with_contexts
