Skip to content

Commit 102fa86

Browse files
authored
Set default batch size to 1 (#340)
* Default batch size to 1 * Add test * Remove ambiguity * Default Enzyme batch size to 8 * Actuallly 16 * Avoid ambiguity
1 parent ff529cb commit 102fa86

8 files changed

Lines changed: 48 additions & 35 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.5.6"
4+
version = "0.5.7"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ reverse_mode(::AnyAutoEnzyme{Nothing}) = Reverse
5757

5858
DI.check_available(::AutoEnzyme) = true
5959

60+
# until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged
61+
DI.pick_batchsize(::AnyAutoEnzyme, dimension::Integer) = min(dimension, 16)
62+
6063
# Enzyme's `Duplicated(x, dx)` expects both arguments to be of the same type
6164
function DI.basis(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T}
6265
b = zero(a)

DifferentiationInterface/src/misc/from_primitive.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ abstract type FromPrimitive <: AbstractADType end
33
check_available(fromprim::FromPrimitive) = check_available(fromprim.backend)
44
twoarg_support(fromprim::FromPrimitive) = twoarg_support(fromprim.backend)
55

6+
function pick_batchsize(fromprim::FromPrimitive, dimension::Integer)
7+
return pick_batchsize(fromprim.backend, dimension)
8+
end
9+
610
## Forward
711

812
struct AutoForwardFromPrimitive{B} <: FromPrimitive

DifferentiationInterface/src/utils/batch.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
pick_batchsize(backend::AbstractADType, dimension::Integer)
33
44
Pick a reasonable batch size for batched derivative evaluation with a given total `dimension`.
5+
6+
Returns `1` for backends which have not overloaded it.
57
"""
6-
function pick_batchsize(::AbstractADType, dimension::Integer)
7-
return min(dimension, 8)
8-
end
8+
pick_batchsize(::AbstractADType, dimension::Integer) = 1
99

1010
"""
1111
Batch{B,T}

DifferentiationInterface/test/Internals/autosparse.jl

Lines changed: 0 additions & 22 deletions
This file was deleted.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using ADTypes
2+
using DifferentiationInterface
3+
import DifferentiationInterface as DI
4+
using Test
5+
6+
@testset "SecondOrder" begin
7+
backend = SecondOrder(AutoForwardDiff(), AutoZygote())
8+
@test ADTypes.mode(backend) isa ADTypes.ForwardMode
9+
@test DifferentiationInterface.outer(backend) isa AutoForwardDiff
10+
@test DifferentiationInterface.inner(backend) isa AutoZygote
11+
end
12+
13+
@testset "Sparse" begin
14+
for backend in [AutoForwardDiff(), AutoZygote()]
15+
sparse_backend = AutoSparse(backend)
16+
@test ADTypes.mode(sparse_backend) == ADTypes.mode(backend)
17+
@test check_available(sparse_backend) == check_available(backend)
18+
@test DI.twoarg_support(sparse_backend) == DI.twoarg_support(backend)
19+
@test DI.pushforward_performance(sparse_backend) ==
20+
DI.pushforward_performance(backend)
21+
@test DI.pullback_performance(sparse_backend) == DI.pullback_performance(backend)
22+
end
23+
24+
for backend in [
25+
SecondOrder(AutoForwardDiff(), AutoZygote()),
26+
SecondOrder(AutoZygote(), AutoForwardDiff()),
27+
]
28+
sparse_backend = AutoSparse(backend)
29+
@test ADTypes.mode(sparse_backend) == ADTypes.mode(backend)
30+
@test DI.hvp_mode(sparse_backend) == DI.hvp_mode(backend)
31+
end
32+
end
33+
34+
@testset "Batch size" begin
35+
@test DI.pick_batchsize(AutoZygote(), 2) == 1
36+
end

DifferentiationInterface/test/Internals/second_order.jl

Lines changed: 0 additions & 9 deletions
This file was deleted.

DifferentiationInterface/test/Single/ForwardDiff/fromprimitive.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ for backend in vcat(fromprimitive_backends)
1313
@test check_available(backend)
1414
@test check_twoarg(backend)
1515
@test check_hessian(backend)
16+
@test DifferentiationInterface.pick_batchsize(backend, 100) == 5
1617
end
1718

1819
## Dense backends

0 commit comments

Comments
 (0)