Skip to content

Commit e52dddf

Browse files
authored
Batch size fixes (#547)
* Add tests that need to pass * Batchsize fixes * Upper bound Mooncake versoin * Don't test on 1.11 * Docs on 1.10 * Test on pre * DIT
1 parent 88c48c1 commit e52dddf

20 files changed

Lines changed: 145 additions & 104 deletions

File tree

.github/workflows/Documentation.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
- uses: actions/checkout@v4
3939
- uses: julia-actions/setup-julia@v2
4040
with:
41-
version: '1'
41+
version: '1.10' # TODO: 1
4242
- uses: julia-actions/cache@v1
4343
- name: Install dependencies
4444
run: julia --project=${{ matrix.pkg.dir}}/docs/ -e '

.github/workflows/Test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
fail-fast: true
2929
matrix:
3030
version:
31-
- "1"
31+
- "1.10" # TODO: 1 (as of 2024.10.08, 1 means 1.11 and we're not ready yet)
3232
- "lts"
3333
- "pre"
3434
group:
@@ -134,7 +134,7 @@ jobs:
134134
fail-fast: true
135135
matrix:
136136
version:
137-
- "1"
137+
- "1.10" # TODO: 1
138138
- "lts"
139139
- "pre"
140140
group:

DifferentiationInterface/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.6"
4+
version = "0.6.7"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -56,7 +56,7 @@ ForwardDiff = "0.10.36"
5656
LinearAlgebra = "<0.0.1,1"
5757
Mooncake = "0.4.0"
5858
PackageExtensionCompat = "1.0.2"
59-
PolyesterForwardDiff = "0.1.1"
59+
PolyesterForwardDiff = "0.1.2"
6060
ReverseDiff = "1.15.1"
6161
SparseArrays = "<0.0.1,1"
6262
SparseConnectivityTracer = "0.5.0,0.6"

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ using DifferentiationInterface:
1010
JacobianPrep,
1111
PullbackPrep,
1212
PushforwardPrep,
13-
SecondDerivativePrep,
14-
dense_ad
13+
SecondDerivativePrep
1514
using FastDifferentiation:
1615
derivative,
1716
hessian,
@@ -33,6 +32,9 @@ monovec(x::Number) = [x]
3332
myvec(x::Number) = monovec(x)
3433
myvec(x::AbstractArray) = vec(x)
3534

35+
dense_ad(backend::AutoFastDifferentiation) = backend
36+
dense_ad(backend::AutoSparse{<:AutoFastDifferentiation}) = ADTypes.dense_ad(backend)
37+
3638
include("onearg.jl")
3739
include("twoarg.jl")
3840

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,22 @@ using LinearAlgebra: dot, mul!
5050

5151
DI.check_available(::AutoForwardDiff) = true
5252

53-
DI.pick_batchsize(::AutoForwardDiff{C}, dimension::Integer) where {C} = Val(C)
53+
function DI.pick_batchsize(
54+
::AutoForwardDiff{chunksize}, dimension::Integer
55+
) where {chunksize}
56+
return Val{chunksize}()
57+
end
5458

5559
function DI.pick_batchsize(::AutoForwardDiff{nothing}, dimension::Integer)
5660
# type-unstable
5761
return Val(ForwardDiff.pickchunksize(dimension))
5862
end
5963

60-
function DI.threshold_batchsize(backend::AutoForwardDiff{C1}, C2::Integer) where {C1}
61-
C = (C1 === nothing) ? nothing : min(C1, C2)
62-
return AutoForwardDiff(; chunksize=C, tag=backend.tag)
64+
function DI.threshold_batchsize(
65+
backend::AutoForwardDiff{chunksize1}, chunksize2::Integer
66+
) where {chunksize1}
67+
chunksize = (chunksize1 === nothing) ? nothing : min(chunksize1, chunksize2)
68+
return AutoForwardDiff(; chunksize, tag=backend.tag)
6369
end
6470

6571
include("utils.jl")

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x)
2-
choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{C}()
2+
choose_chunk(::AutoForwardDiff{chunksize}, x) where {chunksize} = Chunk{chunksize}()
33

4-
tag_type(f, ::AutoForwardDiff{C,T}, x) where {C,T} = T
5-
tag_type(f, ::AutoForwardDiff{C,Nothing}, x) where {C} = typeof(Tag(f, eltype(x)))
4+
tag_type(f, ::AutoForwardDiff{chunksize,T}, x) where {chunksize,T} = T
5+
6+
function tag_type(f, ::AutoForwardDiff{chunksize,Nothing}, x) where {chunksize}
7+
return typeof(Tag(f, eltype(x)))
8+
end
69

710
function make_dual_similar(::Type{T}, x::Number, tx::NTuple{B}) where {T,B}
811
return Dual{T}(x, tx...)

DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ using PolyesterForwardDiff: threaded_gradient!, threaded_jacobian!
2222
using PolyesterForwardDiff.ForwardDiff: Chunk
2323
using PolyesterForwardDiff.ForwardDiff.DiffResults: DiffResults
2424

25-
function single_threaded(backend::AutoPolyesterForwardDiff{C,T}) where {C,T}
26-
return AutoForwardDiff{C,T}(backend.tag)
25+
function single_threaded(backend::AutoPolyesterForwardDiff{chunksize,T}) where {chunksize,T}
26+
return AutoForwardDiff{chunksize,T}(backend.tag)
2727
end
2828

2929
DI.check_available(::AutoPolyesterForwardDiff) = true
@@ -33,10 +33,10 @@ function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, dimension::Integer
3333
end
3434

3535
function DI.threshold_batchsize(
36-
backend::AutoPolyesterForwardDiff{C1}, C2::Integer
37-
) where {C1}
38-
C = (C1 === nothing) ? nothing : min(C1, C2)
39-
return AutoPolyesterForwardDiff(; chunksize=C, tag=backend.tag)
36+
backend::AutoPolyesterForwardDiff{chunksize1}, chunksize2::Integer
37+
) where {chunksize1}
38+
chunksize = (chunksize1 === nothing) ? nothing : min(chunksize1, chunksize2)
39+
return AutoPolyesterForwardDiff(; chunksize, tag=backend.tag)
4040
end
4141

4242
include("onearg.jl")

DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl

Lines changed: 55 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -107,109 +107,113 @@ end
107107

108108
## Gradient
109109

110+
struct PolyesterForwardDiffGradientPrep{chunksize} <: GradientPrep
111+
chunk::Chunk{chunksize}
112+
end
113+
110114
function DI.prepare_gradient(
111-
f, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{Context,C}
112-
) where {C}
113-
return DI.prepare_gradient(f, single_threaded(backend), x, contexts...)
115+
f, ::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
116+
) where {chunksize,C}
117+
if isnothing(chunksize)
118+
chunk = Chunk(x)
119+
else
120+
chunk = Chunk{chunksize}()
121+
end
122+
return PolyesterForwardDiffGradientPrep(chunk)
114123
end
115124

116125
function DI.value_and_gradient!(
117126
f,
118127
grad,
119-
::GradientPrep,
120-
::AutoPolyesterForwardDiff{K},
121-
x::AbstractVector,
128+
prep::PolyesterForwardDiffGradientPrep,
129+
::AutoPolyesterForwardDiff,
130+
x,
122131
contexts::Vararg{Context,C},
123-
) where {K,C}
132+
) where {C}
124133
fc = with_contexts(f, contexts...)
125-
threaded_gradient!(fc, grad, x, Chunk{K}())
134+
threaded_gradient!(fc, grad, x, prep.chunk)
126135
return fc(x), grad
127136
end
128137

129138
function DI.gradient!(
130139
f,
131140
grad,
132-
::GradientPrep,
133-
::AutoPolyesterForwardDiff{K},
134-
x::AbstractVector,
141+
prep::PolyesterForwardDiffGradientPrep,
142+
::AutoPolyesterForwardDiff,
143+
x,
135144
contexts::Vararg{Context,C},
136-
) where {K,C}
145+
) where {C}
137146
fc = with_contexts(f, contexts...)
138-
threaded_gradient!(fc, grad, x, Chunk{K}())
147+
threaded_gradient!(fc, grad, x, prep.chunk)
139148
return grad
140149
end
141150

142-
function DI.value_and_gradient!(
151+
function DI.value_and_gradient(
143152
f,
144-
grad,
145-
prep::GradientPrep,
153+
prep::PolyesterForwardDiffGradientPrep,
146154
backend::AutoPolyesterForwardDiff,
147155
x,
148156
contexts::Vararg{Context,C},
149157
) where {C}
150-
return DI.value_and_gradient!(f, grad, prep, single_threaded(backend), x, contexts...)
158+
return DI.value_and_gradient!(f, similar(x), prep, backend, x, contexts...)
151159
end
152160

153-
function DI.gradient!(
161+
function DI.gradient(
154162
f,
155-
grad,
156-
prep::GradientPrep,
163+
prep::PolyesterForwardDiffGradientPrep,
157164
backend::AutoPolyesterForwardDiff,
158165
x,
159166
contexts::Vararg{Context,C},
160-
) where {C}
161-
return DI.gradient!(f, grad, prep, single_threaded(backend), x, contexts...)
162-
end
163-
164-
function DI.value_and_gradient(
165-
f, prep::GradientPrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{Context,C}
166-
) where {C}
167-
return DI.value_and_gradient!(f, similar(x), prep, backend, x, contexts...)
168-
end
169-
170-
function DI.gradient(
171-
f, prep::GradientPrep, backend::AutoPolyesterForwardDiff, x, contexts::Vararg{Context,C}
172167
) where {C}
173168
return DI.gradient!(f, similar(x), prep, backend, x, contexts...)
174169
end
175170

176171
## Jacobian
177172

173+
struct PolyesterForwardDiffOneArgJacobianPrep{chunksize} <: JacobianPrep
174+
chunk::Chunk{chunksize}
175+
end
176+
178177
function DI.prepare_jacobian(
179-
f, ::AutoPolyesterForwardDiff, x, contexts::Vararg{Context,C}
180-
) where {C}
181-
return NoJacobianPrep()
178+
f, ::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
179+
) where {chunksize,C}
180+
if isnothing(chunksize)
181+
chunk = Chunk(x)
182+
else
183+
chunk = Chunk{chunksize}()
184+
end
185+
return PolyesterForwardDiffOneArgJacobianPrep(chunk)
182186
end
183187

184188
function DI.value_and_jacobian!(
185189
f,
186-
jac::AbstractMatrix,
187-
::NoJacobianPrep,
188-
::AutoPolyesterForwardDiff{K},
189-
x::AbstractArray,
190+
jac,
191+
prep::PolyesterForwardDiffOneArgJacobianPrep,
192+
::AutoPolyesterForwardDiff,
193+
x,
190194
contexts::Vararg{Context,C},
191-
) where {K,C}
195+
) where {C}
192196
fc = with_contexts(f, contexts...)
193-
return fc(x), threaded_jacobian!(fc, jac, x, Chunk{K}())
197+
return fc(x), threaded_jacobian!(fc, jac, x, prep.chunk)
194198
end
195199

196200
function DI.jacobian!(
197201
f,
198-
jac::AbstractMatrix,
199-
::NoJacobianPrep,
200-
::AutoPolyesterForwardDiff{K},
201-
x::AbstractArray,
202+
jac,
203+
prep::PolyesterForwardDiffOneArgJacobianPrep,
204+
::AutoPolyesterForwardDiff,
205+
x,
202206
contexts::Vararg{Context,C},
203-
) where {K,C}
207+
) where {C}
204208
fc = with_contexts(f, contexts...)
205-
return threaded_jacobian!(fc, jac, x, Chunk{K}())
209+
return threaded_jacobian!(fc, jac, x, prep.chunk)
206210
end
207211

208212
function DI.value_and_jacobian(
209213
f,
210-
prep::NoJacobianPrep,
214+
prep::PolyesterForwardDiffOneArgJacobianPrep,
211215
backend::AutoPolyesterForwardDiff,
212-
x::AbstractArray,
216+
x,
213217
contexts::Vararg{Context,C},
214218
) where {C}
215219
y = f(x, map(unwrap, contexts)...)
@@ -220,9 +224,9 @@ end
220224

221225
function DI.jacobian(
222226
f,
223-
prep::NoJacobianPrep,
227+
prep::PolyesterForwardDiffOneArgJacobianPrep,
224228
backend::AutoPolyesterForwardDiff,
225-
x::AbstractArray,
229+
x,
226230
contexts::Vararg{Context,C},
227231
) where {C}
228232
y = f(x, map(unwrap, contexts)...)

DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -118,18 +118,32 @@ end
118118

119119
## Jacobian
120120

121+
struct PolyesterForwardDiffTwoArgJacobianPrep{chunksize} <: JacobianPrep
122+
chunk::Chunk{chunksize}
123+
end
124+
121125
function DI.prepare_jacobian(
122-
f!, y, ::AutoPolyesterForwardDiff, x, contexts::Vararg{Context,C}
123-
) where {C}
124-
return NoJacobianPrep()
126+
f!, y, ::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
127+
) where {chunksize,C}
128+
if isnothing(chunksize)
129+
chunk = Chunk(x)
130+
else
131+
chunk = Chunk{chunksize}()
132+
end
133+
return PolyesterForwardDiffTwoArgJacobianPrep(chunk)
125134
end
126135

127136
function DI.value_and_jacobian(
128-
f!, y, ::NoJacobianPrep, ::AutoPolyesterForwardDiff{K}, x, contexts::Vararg{Context,C}
137+
f!,
138+
y,
139+
prep::PolyesterForwardDiffTwoArgJacobianPrep,
140+
::AutoPolyesterForwardDiff{K},
141+
x,
142+
contexts::Vararg{Context,C},
129143
) where {K,C}
130144
fc! = with_contexts(f!, contexts...)
131145
jac = similar(y, length(y), length(x))
132-
threaded_jacobian!(fc!, y, jac, x, Chunk{K}())
146+
threaded_jacobian!(fc!, y, jac, x, prep.chunk)
133147
fc!(y, x)
134148
return y, jac
135149
end
@@ -138,36 +152,41 @@ function DI.value_and_jacobian!(
138152
f!,
139153
y,
140154
jac,
141-
::NoJacobianPrep,
155+
prep::PolyesterForwardDiffTwoArgJacobianPrep,
142156
::AutoPolyesterForwardDiff{K},
143157
x,
144158
contexts::Vararg{Context,C},
145159
) where {K,C}
146160
fc! = with_contexts(f!, contexts...)
147-
threaded_jacobian!(fc!, y, jac, x, Chunk{K}())
161+
threaded_jacobian!(fc!, y, jac, x, prep.chunk)
148162
fc!(y, x)
149163
return y, jac
150164
end
151165

152166
function DI.jacobian(
153-
f!, y, ::NoJacobianPrep, ::AutoPolyesterForwardDiff{K}, x, contexts::Vararg{Context,C}
154-
) where {K,C}
167+
f!,
168+
y,
169+
prep::PolyesterForwardDiffTwoArgJacobianPrep,
170+
::AutoPolyesterForwardDiff,
171+
x,
172+
contexts::Vararg{Context,C},
173+
) where {C}
155174
fc! = with_contexts(f!, contexts...)
156175
jac = similar(y, length(y), length(x))
157-
threaded_jacobian!(fc!, y, jac, x, Chunk{K}())
176+
threaded_jacobian!(fc!, y, jac, x, prep.chunk)
158177
return jac
159178
end
160179

161180
function DI.jacobian!(
162181
f!,
163182
y,
164183
jac,
165-
::NoJacobianPrep,
166-
::AutoPolyesterForwardDiff{K},
184+
prep::PolyesterForwardDiffTwoArgJacobianPrep,
185+
::AutoPolyesterForwardDiff,
167186
x,
168187
contexts::Vararg{Context,C},
169-
) where {K,C}
188+
) where {C}
170189
fc! = with_contexts(f!, contexts...)
171-
threaded_jacobian!(fc!, y, jac, x, Chunk{K}())
190+
threaded_jacobian!(fc!, y, jac, x, prep.chunk)
172191
return jac
173192
end

0 commit comments

Comments
 (0)