Skip to content

Commit 6b55559

Browse files
authored
feat: customizable DIT benchmarks (#636)
* feat: allow custom max duration in DIT benchmarks * feat: customize benchmark aggregation * Constant fix for Lux
1 parent 2743e18 commit 6b55559

5 files changed

Lines changed: 132 additions & 103 deletions

File tree

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ function DIT.lux_scenarios(rng::AbstractRNG=default_rng())
181181
scen = Scenario{:gradient,:out}(
182182
square_loss,
183183
ComponentArray(ps);
184-
contexts=(Constant(model), Constant(x), Constant(st)),
184+
contexts=(DI.Constant(model), DI.Constant(x), DI.Constant(st)),
185185
res1=g,
186186
)
187187
push!(scens, scen)

DifferentiationInterfaceTest/src/test_differentiation.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ Each setting tests/benchmarks a different subset of calls:
5656
5757
- `count_calls=true`: whether to also count function calls during benchmarking
5858
- `benchmark_test=true`: whether to include tests which succeed iff benchmark doesn't error
59+
- `benchmark_seconds=1`: how long to run each benchmark for
60+
- `benchmark_aggregation=minimum`: function used to aggregate sample measurements
5961
"""
6062
function test_differentiation(
6163
backends::Vector{<:AbstractADType},
@@ -87,6 +89,8 @@ function test_differentiation(
8789
# benchmark options
8890
count_calls::Bool=true,
8991
benchmark_test::Bool=true,
92+
benchmark_seconds::Real=1,
93+
benchmark_aggregation=minimum,
9094
)
9195
@assert type_stability in (:none, :prepared, :full)
9296
@assert allocations in (:none, :prepared, :full)
@@ -173,6 +177,8 @@ function test_differentiation(
173177
subset=benchmark,
174178
count_calls,
175179
benchmark_test,
180+
benchmark_seconds,
181+
benchmark_aggregation,
176182
)
177183
end
178184
yield()
@@ -211,6 +217,8 @@ function benchmark_differentiation(
211217
logging::Bool=false,
212218
count_calls::Bool=true,
213219
benchmark_test::Bool=true,
220+
benchmark_seconds::Real=1,
221+
benchmark_aggregation=minimum,
214222
)
215223
return test_differentiation(
216224
backends,
@@ -223,5 +231,7 @@ function benchmark_differentiation(
223231
excluded,
224232
count_calls,
225233
benchmark_test,
234+
benchmark_seconds,
235+
benchmark_aggregation,
226236
)
227237
end

DifferentiationInterfaceTest/src/tests/benchmark.jl

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ $(TYPEDFIELDS)
5656
5757
See the documentation of [Chairmarks.jl](https://github.com/LilithHafner/Chairmarks.jl) for more details on the measurement fields.
5858
"""
59-
Base.@kwdef struct DifferentiationBenchmarkDataRow
59+
Base.@kwdef struct DifferentiationBenchmarkDataRow{T}
6060
"backend used for benchmarking"
6161
backend::AbstractADType
6262
"scenario used for benchmarking"
@@ -71,16 +71,16 @@ Base.@kwdef struct DifferentiationBenchmarkDataRow
7171
samples::Int
7272
"number of evaluations used for averaging in each sample"
7373
evals::Int
74-
"minimum runtime over all samples, in seconds"
75-
time::Float64
76-
"minimum number of allocations over all samples"
77-
allocs::Float64
78-
"minimum memory allocated over all samples, in bytes"
79-
bytes::Float64
80-
"minimum fraction of time spent in garbage collection over all samples, between 0.0 and 1.0"
81-
gc_fraction::Float64
82-
"minimum fraction of time spent compiling over all samples, between 0.0 and 1.0"
83-
compile_fraction::Float64
74+
"aggregated runtime over all samples, in seconds"
75+
time::T
76+
"aggregated number of allocations over all samples"
77+
allocs::T
78+
"aggregated memory allocated over all samples, in bytes"
79+
bytes::T
80+
"aggregated fraction of time spent in garbage collection over all samples, between 0.0 and 1.0"
81+
gc_fraction::T
82+
"aggregated fraction of time spent compiling over all samples, between 0.0 and 1.0"
83+
compile_fraction::T
8484
end
8585

8686
function record!(
@@ -91,21 +91,22 @@ function record!(
9191
prepared::Union{Nothing,Bool},
9292
bench::Benchmark,
9393
calls::Integer,
94+
aggregation,
9495
)
95-
bench_min = minimum(bench)
96+
bench_agg = aggregation(bench)
9697
row = DifferentiationBenchmarkDataRow(;
9798
backend=backend,
9899
scenario=scenario,
99100
operator=Symbol(operator),
100101
prepared=prepared,
101102
calls=calls,
102103
samples=length(bench.samples),
103-
evals=Int(bench_min.evals),
104-
time=bench_min.time,
105-
allocs=bench_min.allocs,
106-
bytes=bench_min.bytes,
107-
gc_fraction=bench_min.gc_fraction,
108-
compile_fraction=bench_min.compile_fraction,
104+
evals=Int(bench_agg.evals),
105+
time=bench_agg.time,
106+
allocs=bench_agg.allocs,
107+
bytes=bench_agg.bytes,
108+
gc_fraction=bench_agg.gc_fraction,
109+
compile_fraction=bench_agg.compile_fraction,
109110
)
110111
return push!(data, row)
111112
end

0 commit comments

Comments
 (0)