Skip to content

Commit 88c48c1

Browse files
authored
More type stability tests (#543)
* More type stability tests * Remove old kwarg * Don't test everything for now * Use fill! * LTS JET
1 parent 84378d7 commit 88c48c1

6 files changed

Lines changed: 248 additions & 94 deletions

File tree

DifferentiationInterface/src/misc/zero_backends.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ end
44

55
(rz::ReturnZero)(i) = zero(rz.template)
66

7-
_zero!(x::AbstractArray) = x .= zero(eltype(x))
7+
_zero!(x::AbstractArray) = fill!(x, zero(eltype(x)))
88

99
## Forward
1010

DifferentiationInterface/test/Back/ForwardDiff/test.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ test_differentiation(
3838
test_differentiation(
3939
AutoForwardDiff(; chunksize=5);
4040
correctness=false,
41-
type_stability=true,
42-
preparation_type_stability=true,
41+
type_stability=(; preparation=true, prepared_op=true, unprepared_op=false),
4342
logging=LOGGING,
4443
);
4544

DifferentiationInterface/test/Misc/ZeroBackends/test.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,19 @@ end
1919
## Type stability
2020

2121
test_differentiation(
22-
zero_backends,
23-
default_scenarios(; include_constantified=true);
22+
AutoZeroForward(),
23+
default_scenarios(; include_batchified=false, include_constantified=true);
2424
correctness=false,
2525
type_stability=true,
26-
preparation_type_stability=true,
26+
logging=LOGGING,
27+
)
28+
29+
test_differentiation(
30+
AutoZeroReverse(),
31+
default_scenarios(; include_batchified=false, include_constantified=true);
32+
correctness=false,
33+
# TODO: set unprepared_op=true after ignoring DataFrames
34+
type_stability=(; preparation=true, prepared_op=true, unprepared_op=false),
2735
logging=LOGGING,
2836
)
2937

@@ -32,10 +40,9 @@ test_differentiation(
3240
SecondOrder(AutoZeroForward(), AutoZeroReverse()),
3341
SecondOrder(AutoZeroReverse(), AutoZeroForward()),
3442
],
35-
default_scenarios();
43+
default_scenarios(; include_batchified=false, include_constantified=true);
3644
correctness=false,
37-
type_stability=true,
38-
preparation_type_stability=true,
45+
type_stability=(; preparation=true, prepared_op=true, unprepared_op=true),
3946
first_order=false,
4047
logging=LOGGING,
4148
)
@@ -44,8 +51,7 @@ test_differentiation(
4451
AutoSparse.(zero_backends, coloring_algorithm=GreedyColoringAlgorithm()),
4552
default_scenarios(; include_constantified=true);
4653
correctness=false,
47-
type_stability=true,
48-
preparation_type_stability=true,
54+
type_stability=(; preparation=true, prepared_op=true, unprepared_op=false),
4955
excluded=[:pushforward, :pullback, :gradient, :derivative, :hvp, :second_derivative],
5056
logging=LOGGING,
5157
)

DifferentiationInterfaceTest/src/test_differentiation.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Cross-test a list of `backends` on a list of `scenarios`, running a variety of d
2929
Testing:
3030
3131
- `correctness=true`: whether to compare the differentiation results with the theoretical values specified in each scenario
32-
- `type_stability=false`: whether to check type stability of operators with JET.jl (thanks to `JET.@test_opt`)
32+
- `type_stability=false`: whether to check type stability of operators with JET.jl (thanks to `JET.@test_opt`). It can be either a `Bool` or a more detailed named tuple `(; preparation, prepared_op, unprepared_op)` to specify which variants should be analyzed.
3333
- `sparsity`: whether to check sparsity of the jacobian / hessian
3434
- `detailed=false`: whether to print a detailed or condensed test log
3535
@@ -51,8 +51,7 @@ function test_differentiation(
5151
scenarios::Vector{<:Scenario}=default_scenarios();
5252
# testing
5353
correctness::Bool=true,
54-
type_stability::Bool=false,
55-
preparation_type_stability::Bool=false,
54+
type_stability=false,
5655
call_count::Bool=false,
5756
sparsity::Bool=false,
5857
detailed=false,
@@ -73,10 +72,12 @@ function test_differentiation(
7372
scenarios; first_order, second_order, input_type, output_type, excluded
7473
)
7574

75+
bool_type_stability = (type_stability == true || type_stability isa NamedTuple)
76+
7677
title_additions =
7778
(correctness != false ? " + correctness" : "") *
7879
(call_count ? " + calls" : "") *
79-
(type_stability ? " + types" : "") *
80+
(bool_type_stability ? " + type stability" : "") *
8081
(sparsity ? " + sparsity" : "")
8182
title = "Testing" * title_additions[3:end]
8283

@@ -115,13 +116,14 @@ function test_differentiation(
115116
adapted_backend, scen; isapprox, atol, rtol, scenario_intact
116117
)
117118
end
118-
type_stability && @testset "Type stability" begin
119+
kwargs_type_stability = if type_stability isa NamedTuple
120+
type_stability
121+
else
122+
(; preparation=false, prepared_op=type_stability, unprepared_op=false)
123+
end
124+
bool_type_stability && @testset "Type stability" begin
119125
@static if VERSION >= v"1.7"
120-
test_jet(
121-
adapted_backend,
122-
scen;
123-
test_preparation=preparation_type_stability,
124-
)
126+
test_jet(adapted_backend, scen; kwargs_type_stability...)
125127
end
126128
end
127129
sparsity && @testset "Sparsity" begin

0 commit comments

Comments
 (0)