forked from JuliaDiff/DifferentiationInterface.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path batchsize.jl
More file actions
108 lines (102 loc) · 4.45 KB
/
batchsize.jl
File metadata and controls
108 lines (102 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
using ADTypes
using DifferentiationInterface
using DifferentiationInterface:
AutoSimpleFiniteDiff,
BatchSizeSettings,
pick_batchsize,
reasonable_batchsize,
threshold_batchsize
import DifferentiationInterface as DI
using StaticArrays
using Test
# Short alias used by every `isa` check in this file.
BSS = BatchSizeSettings

@testset "Default" begin
    # With AutoZygote, every tested input length must yield BSS{1,false,true},
    # and the call must be inferrable.
    for len in (0, 2, 100)
        @test (@inferred pick_batchsize(AutoZygote(), zeros(len))) isa BSS{1,false,true}
    end

    # Backends composed from other backends (AutoSparse, SecondOrder, MixedMode) are
    # expected to be rejected with an ArgumentError. Each backend is built by a thunk
    # so that its construction still happens inside `@test_throws`.
    composite_builders = (
        () -> AutoSparse(AutoZygote()),
        () -> SecondOrder(AutoZygote(), AutoZygote()),
        () -> MixedMode(AutoSimpleFiniteDiff(), AutoZygote()),
    )
    for build in composite_builders
        @test_throws ArgumentError pick_batchsize(build(), zeros(2))
    end
end
@testset "SimpleFiniteDiff (adaptive)" begin
    # Dynamic vectors: the chosen batch size grows with the length but is capped at 12
    # for lengths 24 and 100. Inferrability is not checked for these.
    dynamic_cases = (
        0 => BSS{1,false,true},
        2 => BSS{2,true,true},
        6 => BSS{6,true,true},
        12 => BSS{12,true,true},
        24 => BSS{12,false,true},
        100 => BSS{12,false,false},
    )
    for (len, expected) in dynamic_cases
        @test pick_batchsize(AutoSimpleFiniteDiff(), zeros(len)) isa expected
    end

    # Static vectors: the full length becomes the batch size (no cap), inferrably.
    static_cases = (
        @SVector(zeros(0)) => BSS{0,true,true},
        @SVector(zeros(2)) => BSS{2,true,true},
        @SVector(zeros(6)) => BSS{6,true,true},
        @SVector(zeros(100)) => BSS{100,true,true},
    )
    for (x, expected) in static_cases
        @test (@inferred pick_batchsize(AutoSimpleFiniteDiff(), x)) isa expected
    end
end
@testset "SimpleFiniteDiff (fixed)" begin
    # chunksize=4 with an input of length 2 must throw, for both dynamic and static vectors.
    @test_throws ArgumentError pick_batchsize(AutoSimpleFiniteDiff(; chunksize=4), zeros(2))
    @test_throws ArgumentError pick_batchsize(
        AutoSimpleFiniteDiff(; chunksize=4), @SVector(zeros(2))
    )

    # Dynamic vectors: pin all three type parameters, not just the batch size. The
    # expected flags mirror the static-vector cases below for the same chunksize
    # (length 6 -> {4,false,false}, length 100 -> {4,false,true}); presumably the last
    # flag tracks whether the length is a multiple of 4 — TODO confirm.
    @test pick_batchsize(AutoSimpleFiniteDiff(; chunksize=4), zeros(6)) isa
        BSS{4,false,false}
    @test pick_batchsize(AutoSimpleFiniteDiff(; chunksize=4), zeros(100)) isa
        BSS{4,false,true}
    @test pick_batchsize(AutoSimpleFiniteDiff(; chunksize=4), zeros(99)) isa
        BSS{4,false,false}

    # Static vectors: same expectations, and the result type must be inferrable.
    @test (@inferred pick_batchsize(
        AutoSimpleFiniteDiff(; chunksize=4), @SVector(zeros(6))
    )) isa BSS{4,false,false}
    @test (@inferred pick_batchsize(
        AutoSimpleFiniteDiff(; chunksize=4), @SVector(zeros(100))
    )) isa BSS{4,false,true}
end
@testset "Thresholding" begin
    # A backend without a chunksize keeps its `nothing` parameter.
    @test threshold_batchsize(AutoSimpleFiniteDiff(), 2) isa AutoSimpleFiniteDiff{nothing}

    # A fixed chunksize of 4 becomes 2 under threshold 2 and stays 4 under threshold 6.
    @test threshold_batchsize(AutoSimpleFiniteDiff(; chunksize=4), 2) isa
        AutoSimpleFiniteDiff{2}
    @test threshold_batchsize(AutoSimpleFiniteDiff(; chunksize=4), 6) isa
        AutoSimpleFiniteDiff{4}

    # The backend wrapped by AutoSparse is thresholded too.
    @test threshold_batchsize(AutoSparse(AutoSimpleFiniteDiff(; chunksize=4)), 2) isa
        AutoSparse{<:AutoSimpleFiniteDiff{2}}

    # SecondOrder: each of the two component backends is thresholded on its own.
    # Entries are (first chunksize, second chunksize, threshold) => expected type.
    second_order_cases = (
        (4, 3, 6) => SecondOrder{<:AutoSimpleFiniteDiff{4},<:AutoSimpleFiniteDiff{3}},
        (4, 3, 2) => SecondOrder{<:AutoSimpleFiniteDiff{2},<:AutoSimpleFiniteDiff{2}},
        (1, 3, 2) => SecondOrder{<:AutoSimpleFiniteDiff{1},<:AutoSimpleFiniteDiff{2}},
        (4, 1, 2) => SecondOrder{<:AutoSimpleFiniteDiff{2},<:AutoSimpleFiniteDiff{1}},
    )
    for ((c1, c2, threshold), expected) in second_order_cases
        @test threshold_batchsize(
            SecondOrder(
                AutoSimpleFiniteDiff(; chunksize=c1), AutoSimpleFiniteDiff(; chunksize=c2)
            ),
            threshold,
        ) isa expected
    end

    # MixedMode: the AutoSimpleFiniteDiff component is thresholded; AutoZygote is kept.
    @test threshold_batchsize(
        MixedMode(AutoSimpleFiniteDiff(; chunksize=4), AutoZygote()), 2
    ) isa MixedMode{<:AutoSimpleFiniteDiff{2},<:AutoZygote}
end
@testset "Reasonable" begin
    lengths = 1:10
    for Bmax in 1:5
        # No chosen batch size exceeds the upper bound Bmax.
        @test all(B -> B <= Bmax, reasonable_batchsize.(lengths, Bmax))
        # The resulting number of batches, cld(N, B), is nondecreasing in N.
        @test issorted(cld.(lengths, reasonable_batchsize.(lengths, Bmax)))
        # For Bmax > 2, a length of Bmax + 1 must get a batch size strictly below Bmax.
        if Bmax > 2
            @test reasonable_batchsize(Bmax + 1, Bmax) < Bmax
        end
    end
end