Skip to content

Commit d2c6734

Browse files
authored
Call count in benchmark (#151)
* Call count in benchmark * Fix
1 parent b54c981 commit d2c6734

10 files changed

Lines changed: 222 additions & 226 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ const AnyAutoFastDifferentiation = Union{
3131
}
3232

3333
DI.mode(::AnyAutoFastDifferentiation) = ADTypes.AbstractSymbolicDifferentiationMode
34+
DI.pushforward_performance(::AnyAutoFastDifferentiation) = DI.PushforwardFast()
35+
DI.pullback_performance(::AnyAutoFastDifferentiation) = DI.PullbackSlow()
3436

3537
myvec(x::Number) = [x]
3638
myvec(x::AbstractArray) = vec(x)

DifferentiationInterface/src/pullback.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,12 @@ prepare_pullback(f!, ::AbstractADType, y, x) = NoPullbackExtras()
2929
function value_and_pullback_split(
3030
f, backend::AbstractADType, x, extras::PullbackExtras=prepare_pullback(f, backend, x)
3131
)
32-
if !Bool(pushforward_performance(backend))
33-
error("Pushforward not available for backend $backend")
34-
end
32+
return value_and_pullback_split_aux(
33+
f, backend, x, extras, pullback_performance(backend)
34+
)
35+
end
36+
37+
function value_and_pullback_split_aux(f, backend, x, extras, ::PullbackSlow)
3538
pushforward_extras = prepare_pushforward(f, backend, x)
3639
y = f(x)
3740
pullbackfunc = if x isa Number && y isa Number

DifferentiationInterface/src/pushforward.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,12 @@ function value_and_pushforward(
3333
dx,
3434
extras::PushforwardExtras=prepare_pushforward(f, backend, x),
3535
)
36-
if !Bool(pullback_performance(backend))
37-
error("Pullback not available for backend $backend")
38-
end
36+
return value_and_pushforward_aux(
37+
f, backend, x, dx, extras, pushforward_performance(backend)
38+
)
39+
end
40+
41+
function value_and_pushforward_aux(f, backend, x, dx, extras, ::PushforwardSlow)
3942
pullback_extras = prepare_pullback(f, backend, x)
4043
y, pullbackfunc = value_and_pullback_split(f, backend, x, pullback_extras)
4144
dy = if x isa Number && y isa Number

DifferentiationInterfaceTest/docs/src/tutorial.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ Let us experiment with various input types and sizes:
4545
scenarios = [
4646
GradientScenario(f; x=rand(Float64, 3), ref=∇f),
4747
GradientScenario(f; x=rand(Float32, 3, 4), ref=∇f),
48-
GradientScenario(f; x=rand(Float16, 3, 4, 5), ref=∇f),
4948
];
5049
```
5150

DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ include("utils/filter.jl")
5252

5353
include("tests/correctness.jl")
5454
include("tests/type_stability.jl")
55-
include("tests/call_count.jl")
5655
include("tests/sparsity.jl")
5756
include("tests/benchmark.jl")
5857
include("test_differentiation.jl")

DifferentiationInterfaceTest/src/test_differentiation.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ Testing:
1414
- `correctness=true`: whether to compare the differentiation results with the theoretical values specified in each scenario
1515
- If a backend object like `correctness=AutoForwardDiff()` is passed instead of a boolean, the results will be compared using that reference backend as the ground truth.
1616
- Otherwise, the scenario-specific reference operator will be used as the ground truth instead, see [`AbstractScenario`](@ref) for details.
17-
- `call_count=false`: whether to check that the function is called the right number of times
1817
- `type_stability=false`: whether to check type stability with JET.jl (thanks to `@test_opt`)
1918
- `sparsity`: whether to check sparsity of the jacobian / hessian
2019
- `detailed=false`: whether to print a detailed or condensed test log
@@ -79,13 +78,18 @@ function test_differentiation(
7978
prog = ProgressUnknown(; desc="$title", spinner=true, enabled=logging)
8079

8180
@testset verbose = true "$title" begin
82-
@testset verbose = detailed "$(backend_string(backend))" for backend in backends
83-
@testset "$scen" for scen in filter(s -> compatible(backend, s), scenarios)
81+
@testset verbose = detailed "$(backend_string(backend))" for (i, backend) in
82+
enumerate(backends)
83+
filtered_scenarios = filter(s -> compatible(backend, s), scenarios)
84+
@testset "$scen" for (j, scen) in enumerate(filtered_scenarios)
8485
next!(
8586
prog;
8687
showvalues=[
87-
(:backend, backend_string(backend)),
88-
(:scenario, typeof(scen).name.name),
88+
(:backend, "$(backend_string(backend)) - $i/$(length(backends))"),
89+
(
90+
:scenario,
91+
"$(typeof(scen).name.name) - $j/$(length(filtered_scenarios))",
92+
),
8993
(:function, scen.f),
9094
(:input, typeof(scen.x)),
9195
(:output, typeof(scen.y)),
@@ -102,9 +106,6 @@ function test_differentiation(
102106
)
103107
end
104108
end
105-
call_count && @testset "Call count" begin
106-
test_call_count(backend, scen)
107-
end
108109
type_stability && @testset "Type stability" begin
109110
test_jet(backend, scen)
110111
end
@@ -163,13 +164,17 @@ function benchmark_differentiation(
163164

164165
benchmark_data = BenchmarkDataRow[]
165166
prog = ProgressUnknown(; desc="Benchmarking", spinner=true, enabled=logging)
166-
for backend in backends
167-
for scen in filter(s -> compatible(backend, s), scenarios)
167+
for (i, backend) in enumerate(backends)
168+
filtered_scenarios = filter(s -> compatible(backend, s), scenarios)
169+
for (j, scen) in enumerate(filtered_scenarios)
168170
next!(
169171
prog;
170172
showvalues=[
171-
(:backend, backend_string(backend)),
172-
(:scenario, typeof(scen).name.name),
173+
(:backend, "$(backend_string(backend)) - $i/$(length(backends))"),
174+
(
175+
:scenario,
176+
"$(typeof(scen).name.name) - $j/$(length(filtered_scenarios))",
177+
),
173178
(:function, scen.f),
174179
(:input, typeof(scen.x)),
175180
(:output, typeof(scen.y)),

0 commit comments

Comments
 (0)