Skip to content

Commit 8f93335

Browse files
committed
Add more support for zero batch size (incomplete Jacobian and Hessian)
1 parent babbf11 commit 8f93335

6 files changed

Lines changed: 73 additions & 19 deletions

File tree

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,14 @@ struct PullbackJacobianPrep{
153153
S<:AbstractVector{<:NTuple},
154154
R<:AbstractVector{<:NTuple},
155155
E<:PullbackPrep,
156+
Y,
156157
} <: StandardJacobianPrep{SIG}
157158
_sig::Val{SIG}
158159
batch_size_settings::BS
159160
batched_seeds::S
160161
batched_results::R
161162
pullback_prep::E
163+
y_example::Y
162164
end
163165

164166
function prepare_jacobian_nokwarg(
@@ -212,7 +214,7 @@ function _prepare_jacobian_aux(
212214
]
213215
batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds]
214216
pushforward_prep = prepare_pushforward_nokwarg(
215-
strict, f_or_f!y..., backend, x, batched_seeds[1], contexts...
217+
strict, f_or_f!y..., backend, x, ntuple(b -> zero(x), Val(B)), contexts...
216218
)
217219
return PushforwardJacobianPrep(
218220
_sig, batch_size_settings, batched_seeds, batched_results, pushforward_prep
@@ -237,10 +239,10 @@ function _prepare_jacobian_aux(
237239
]
238240
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
239241
pullback_prep = prepare_pullback_nokwarg(
240-
strict, f_or_f!y..., backend, x, batched_seeds[1], contexts...
242+
strict, f_or_f!y..., backend, x, ntuple(b -> zero(y), Val(B)), contexts...
241243
)
242244
return PullbackJacobianPrep(
243-
_sig, batch_size_settings, batched_seeds, batched_results, pullback_prep
245+
_sig, batch_size_settings, batched_seeds, batched_results, pullback_prep, y
244246
)
245247
end
246248

@@ -367,7 +369,7 @@ function _jacobian_aux(
367369
(; A, B_last) = batch_size_settings
368370

369371
pushforward_prep_same = prepare_pushforward_same_point(
370-
f_or_f!y..., pushforward_prep, backend, x, batched_seeds[1], contexts...
372+
f_or_f!y..., pushforward_prep, backend, x, ntuple(b -> zero(x), Val(B)), contexts...
371373
)
372374

373375
jac = mapreduce(hcat, eachindex(batched_seeds)) do a
@@ -419,11 +421,16 @@ function _jacobian_aux(
419421
x,
420422
contexts::Vararg{Context,C},
421423
) where {FY,SIG,B,aligned,C}
422-
(; batch_size_settings, batched_seeds, pullback_prep) = prep
424+
(; batch_size_settings, batched_seeds, pullback_prep, y_example) = prep
423425
(; A, B_last) = batch_size_settings
424426

425427
pullback_prep_same = prepare_pullback_same_point(
426-
f_or_f!y..., prep.pullback_prep, backend, x, batched_seeds[1], contexts...
428+
f_or_f!y...,
429+
pullback_prep,
430+
backend,
431+
x,
432+
ntuple(b -> zero(y_example), Val(B)),
433+
contexts...,
427434
)
428435

429436
jac = mapreduce(vcat, eachindex(batched_seeds)) do a
@@ -487,11 +494,16 @@ function _jacobian_aux!(
487494
x,
488495
contexts::Vararg{Context,C},
489496
) where {FY,SIG,B,C}
490-
(; batch_size_settings, batched_seeds, batched_results, pullback_prep) = prep
497+
(; batch_size_settings, batched_seeds, batched_results, pullback_prep, y_example) = prep
491498
(; N) = batch_size_settings
492499

493500
pullback_prep_same = prepare_pullback_same_point(
494-
f_or_f!y..., pullback_prep, backend, x, batched_seeds[1], contexts...
501+
f_or_f!y...,
502+
pullback_prep,
503+
backend,
504+
x,
505+
ntuple(b -> zero(y_example), Val(B)),
506+
contexts...,
495507
)
496508

497509
for a in eachindex(batched_seeds, batched_results)

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,13 @@ function _prepare_pullback_aux(
285285
contexts::Vararg{Context,C};
286286
) where {F,C}
287287
_sig = signature(f, backend, x, ty, contexts...; strict)
288-
dx = x isa Number ? oneunit(x) : basis(x, first(CartesianIndices(x)))
288+
dx = if x isa Number
289+
oneunit(x)
290+
elseif isempty(x)
291+
zero(x)
292+
else
293+
basis(x, first(CartesianIndices(x)))
294+
end
289295
pushforward_prep = prepare_pushforward_nokwarg(
290296
strict, f, backend, x, (dx,), contexts...
291297
)
@@ -303,7 +309,13 @@ function _prepare_pullback_aux(
303309
contexts::Vararg{Context,C};
304310
) where {F,C}
305311
_sig = signature(f!, y, backend, x, ty, contexts...; strict)
306-
dx = x isa Number ? oneunit(x) : basis(x, first(CartesianIndices(x)))
312+
dx = if x isa Number
313+
oneunit(x)
314+
elseif isempty(x)
315+
zero(x)
316+
else
317+
basis(x, first(CartesianIndices(x)))
318+
end
307319
pushforward_prep = prepare_pushforward_nokwarg(
308320
strict, f!, y, backend, x, (dx,), contexts...
309321
)

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,13 @@ function _prepare_pushforward_aux(
290290
) where {F,C}
291291
_sig = signature(f, backend, x, tx, contexts...; strict)
292292
y = f(x, map(unwrap, contexts)...)
293-
dy = y isa Number ? oneunit(y) : basis(y, first(CartesianIndices(y)))
293+
dy = if y isa Number
294+
oneunit(y)
295+
elseif isempty(y)
296+
zero(y)
297+
else
298+
basis(y, first(CartesianIndices(y)))
299+
end
294300
pullback_prep = prepare_pullback_nokwarg(strict, f, backend, x, (dy,), contexts...)
295301
return PullbackPushforwardPrep(_sig, pullback_prep)
296302
end
@@ -306,7 +312,13 @@ function _prepare_pushforward_aux(
306312
contexts::Vararg{Context,C};
307313
) where {F,C}
308314
_sig = signature(f!, y, backend, x, tx, contexts...; strict)
309-
dy = y isa Number ? oneunit(y) : basis(y, first(CartesianIndices(y)))
315+
dy = if y isa Number
316+
oneunit(y)
317+
elseif isempty(y)
318+
zero(y)
319+
else
320+
basis(y, first(CartesianIndices(y)))
321+
end
310322
pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...)
311323
return PullbackPushforwardPrep(_sig, pullback_prep)
312324
end

DifferentiationInterface/src/utils/batchsize.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,25 @@ end
2323

2424
function BatchSizeSettings{B,singlebatch,aligned}(N::Integer) where {B,singlebatch,aligned}
2525
B > N > 0 && throw(ArgumentError("Batch size $B larger than input size $N"))
26-
A = div(N, B, RoundUp)
27-
B_last = N % B
26+
if B == N == 0
27+
A = B_last = 0
28+
else
29+
A = div(N, B, RoundUp)
30+
B_last = N % B
31+
end
2832
return BatchSizeSettings{B,singlebatch,aligned}(N, A, B_last)
2933
end
3034

3135
function BatchSizeSettings{B}(::Val{N}) where {B,N}
3236
singlebatch = B == N
33-
aligned = N % B == 0
37+
aligned = (B == N == 0) || (N % B == 0)
3438
return BatchSizeSettings{B,singlebatch,aligned}(N)
3539
end
3640

3741
function BatchSizeSettings{B}(N::Integer) where {B}
3842
# type-unstable
3943
singlebatch = B == N
40-
aligned = N % B == 0
44+
aligned = (B == N == 0) || (N % B == 0)
4145
return BatchSizeSettings{B,singlebatch,aligned}(N)
4246
end
4347

@@ -123,9 +127,7 @@ Reproduces the heuristic from ForwardDiff to minimize
123127
Source: https://github.com/JuliaDiff/ForwardDiff.jl/blob/ec74fbc32b10bbf60b3c527d8961666310733728/src/prelude.jl#L19-L29
124128
"""
125129
function reasonable_batchsize(N::Integer, Bmax::Integer)
126-
if N == 0
127-
return 1
128-
elseif N <= Bmax
130+
if N <= Bmax
129131
return N
130132
else
131133
A = div(N, Bmax, RoundUp)

DifferentiationInterface/test/Core/Internals/batchsize.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,14 @@ BSS = BatchSizeSettings
2525
end
2626

2727
@testset "SimpleFiniteDiff (adaptive)" begin
28+
@test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(0))) isa BSS{0,true,true}
2829
@test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(2))) isa BSS{2,true,true}
2930
@test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(6))) isa BSS{6,true,true}
3031
@test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(12))) isa BSS{12,true,true}
3132
@test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(24))) isa BSS{12,false,true}
3233
@test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(100))) isa BSS{12,false,false}
34+
@test (@inferred pick_batchsize(AutoSimpleFiniteDiff(), @SVector(zeros(0)))) isa
35+
BSS{0,true,true}
3336
@test (@inferred pick_batchsize(AutoSimpleFiniteDiff(), @SVector(zeros(2)))) isa
3437
BSS{2,true,true}
3538
@test (@inferred pick_batchsize(AutoSimpleFiniteDiff(), @SVector(zeros(6)))) isa

DifferentiationInterface/test/Core/ZeroBackends/test.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using DifferentiationInterface
22
using DifferentiationInterface: AutoZeroForward, AutoZeroReverse
33
using DifferentiationInterfaceTest
4+
using LinearAlgebra
45
using ComponentArrays: ComponentArrays
56
using JLArrays: JLArrays
67
using SparseMatrixColorings
@@ -50,3 +51,15 @@ end
5051
logging=LOGGING,
5152
)
5253
end
54+
55+
@testset "Empty arrays" begin
56+
make_empty(t) = typeof(t)[]
57+
make_empty!(y, t) = nothing
58+
@test gradient(sum, AutoZeroForward(), Float64[]) == Float64[]
59+
@test derivative(make_empty, AutoZeroReverse(), 1.0) == Float64[]
60+
@test derivative(make_empty!, Float64[], AutoZeroReverse(), 1.0) == Float64[]
61+
@test_broken jacobian(copy, AutoZeroForward(), Float64[]) == I(0)
62+
@test_broken jacobian(copy, AutoZeroReverse(), Float64[]) == I(0)
63+
@test_broken jacobian(copyto!, Float64[], AutoZeroForward(), Float64[]) == I(0)
64+
@test_broken jacobian(copyto!, Float64[], AutoZeroReverse(), Float64[]) == I(0)
65+
end

0 commit comments

Comments
 (0)