Skip to content

Commit 3bfcd1f

Browse files
adrhillgdalle
andauthored
Separate testing scenarios for each operator (#120)
* Rename files for clarity * Make less confusing * Minor doc-string fix * Add Scenario types for each operator * Loosen type restrictions, add defaults * Dispatch on scenario type * Dispatch benchmarks on scenario type * Simplify compatibility tests Signed-off-by adrhill <adrian.hill@mailbox.org> * Fix dispatch * Remove operator filtering * Update docstring * Finish fixing * Remove test from docs * Tweaks * No logging in CI * Fix component scenarios * Fix component arrays * Make JET happy --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent 78d3d99 commit 3bfcd1f

34 files changed

Lines changed: 1090 additions & 1017 deletions

docs/make.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,9 @@ makedocs(;
6666
),
6767
pages=[
6868
"Home" => "index.md", #
69-
"tutorial.md", #
70-
"overview.md", #
71-
"api.md", #
72-
"backends.md", #
73-
"developer.md",
69+
"Start here" => ["tutorial.md", "overview.md", "backends.md"],
70+
"Reference" => ["core.md", "testing.md"],
71+
"Advanced" => ["design.md", "extensions.md"],
7472
],
7573
warnonly=:missing_docs, # missing docs for ADTypes.jl are normal
7674
)

docs/src/backends.md

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -103,30 +103,3 @@ rows = map(all_backends()) do backend # hide
103103
end # hide
104104
Markdown.parse(join(vcat(header, subheader, rows...), "\n")) # hide
105105
```
106-
107-
## Package extensions
108-
109-
```@meta
110-
CurrentModule = DifferentiationInterface
111-
```
112-
113-
Backend-specific extension content is not part of the public API.
114-
115-
```@autodocs
116-
Modules = [
117-
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceChainRulesCoreExt),
118-
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceDiffractorExt),
119-
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceEnzymeExt),
120-
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceFastDifferentiationExt),
121-
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceFiniteDiffExt),
122-
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceFiniteDifferencesExt),
123-
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceForwardDiffExt),
124-
Base.get_extension(DifferentiationInterface, :DifferentiationInterfacePolyesterForwardDiffExt),
125-
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceReverseDiffExt),
126-
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceSparseDiffToolsExt),
127-
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceTapirExt),
128-
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceTrackerExt),
129-
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceZygoteExt)
130-
]
131-
Filter = t -> !(t isa Type && t <: ADTypes.AbstractADType)
132-
```

docs/src/api.md renamed to docs/src/core.md

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ CurrentModule = Main
33
CollapsedDocStrings = true
44
```
55

6-
# API reference
6+
# Core API
77

88
```@docs
99
DifferentiationInterface
@@ -51,20 +51,6 @@ Modules = [DifferentiationInterface]
5151
Pages = ["backends.jl"]
5252
```
5353

54-
## Preparation
55-
56-
```@autodocs
57-
Modules = [DifferentiationInterface]
58-
Pages = ["prepare.jl"]
59-
```
60-
61-
## Testing & benchmarking
62-
63-
```@autodocs
64-
Modules = [DifferentiationInterfaceTest]
65-
Private = false
66-
```
67-
6854
## Internals
6955

7056
This is not part of the public API.
@@ -75,8 +61,3 @@ Public = false
7561
Order = [:function, :type]
7662
Filter = t -> !(t isa Type && t <: ADTypes.AbstractADType)
7763
```
78-
79-
```@autodocs
80-
Modules = [DifferentiationInterfaceTest]
81-
Public = false
82-
```
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# For AD developers
1+
# Package design
22

33
## Backend requirements
44

docs/src/extensions.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Package extensions
2+
3+
```@meta
4+
CurrentModule = DifferentiationInterface
5+
```
6+
7+
Backend-specific extension content is not part of the public API.
8+
If any docstrings are present in an extension, they will appear below.
9+
10+
```@autodocs
11+
Modules = [
12+
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceChainRulesCoreExt),
13+
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceDiffractorExt),
14+
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceEnzymeExt),
15+
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceFastDifferentiationExt),
16+
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceFiniteDiffExt),
17+
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceFiniteDifferencesExt),
18+
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceForwardDiffExt),
19+
Base.get_extension(DifferentiationInterface, :DifferentiationInterfacePolyesterForwardDiffExt),
20+
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceReverseDiffExt),
21+
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceSparseDiffToolsExt),
22+
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceTapirExt),
23+
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceTrackerExt),
24+
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceZygoteExt)
25+
]
26+
Filter = t -> !(t isa Type && t <: ADTypes.AbstractADType)
27+
```

docs/src/overview.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,13 @@ Several variants of each operator are defined:
4040
```julia
4141
# work with grad_in
4242
grad_out = gradient!!(f, grad_in, backend, x)
43-
# work with grad_out
43+
# work with grad_out: OK
4444
```
4545
On the other hand, this is bad, because if `grad_in` has not been mutated, you will forget the results:
4646
```julia
4747
# work with grad_in
4848
gradient!!(f, grad_in, backend, x)
49-
# mistakenly keep working with grad_in
49+
# mistakenly keep working with grad_in: NOT OK
5050
```
5151
Note that we don't guarantee `grad_out` will have the same type as `grad_in`.
5252

docs/src/testing.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
```@meta
2+
CurrentModule = Main
3+
CollapsedDocStrings = true
4+
```
5+
6+
# Testing API
7+
8+
```@docs
9+
DifferentiationInterfaceTest
10+
```
11+
12+
## Entry points
13+
14+
```@docs
15+
test_differentiation
16+
benchmark_differentiation
17+
```
18+
19+
## Pre-made scenario lists
20+
21+
```@docs
22+
default_scenarios
23+
component_scenarios
24+
gpu_scenarios
25+
static_scenarios
26+
```
27+
28+
## Scenario types
29+
30+
```@docs
31+
AbstractScenario
32+
PushforwardScenario
33+
PullbackScenario
34+
DerivativeScenario
35+
GradientScenario
36+
JacobianScenario
37+
SecondDerivativeScenario
38+
HVPScenario
39+
HessianScenario
40+
```
41+
42+
## Internals
43+
44+
This is not part of the public API.
45+
46+
```@autodocs
47+
Modules = [DifferentiationInterfaceTest]
48+
Public = false
49+
```

docs/src/tutorial.md

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,9 @@ For testing, you can use [`test_differentiation`](@ref) as follows:
149149
```@repl tuto
150150
test_differentiation(
151151
[AutoForwardDiff(), AutoEnzyme(Enzyme.Reverse)], # backends to compare
152-
[gradient, pullback], # operators to try
153-
[Scenario(f; x=rand(3)), Scenario(f; x=rand(3,3))]; # test scenarios
152+
[GradientScenario(f; x=rand(3)), GradientScenario(f; x=rand(3,3))]; # test scenarios
154153
correctness=AutoZygote(), # compare results to a "ground truth" from Zygote
155-
detailed=true, # print detailed test set
154+
detailed=true, # print detailed test log
156155
);
157156
```
158157

@@ -164,15 +163,14 @@ The interface of [`benchmark_differentiation`](@ref) is very similar to the one
164163
```@repl tuto
165164
data = benchmark_differentiation(
166165
[AutoForwardDiff(), AutoEnzyme(Enzyme.Reverse)],
167-
[gradient, pullback],
168-
[Scenario(f; x=rand(3)), Scenario(f; x=rand(3,3))];
166+
[GradientScenario(f; x=rand(3)), GradientScenario(f; x=rand(3,3))];
169167
);
170168
```
171169

172-
The `BenchmarkData` object is just a struct of vectors, and you can easily convert to a `DataFrame` from [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl):
170+
The resulting object is just a vector of structs, and you can easily convert to a `DataFrame` from [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl):
173171

174172
```@repl tuto
175-
df = DataFrames.DataFrame(pairs(data)...)
173+
df = DataFrames.DataFrame(data)
176174
```
177175

178176
Here's what the resulting `DataFrame` looks like with all its columns.

lib/DifferentiationInterfaceTest/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1212
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1313
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
15+
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1516
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1617
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1718
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

lib/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,15 @@ using Chairmarks: @be, Benchmark, Sample
1717
using ComponentArrays: ComponentVector
1818
using DifferentiationInterface
1919
using DifferentiationInterface:
20-
AutoTapir,
21-
inner,
22-
mode,
23-
outer,
24-
supports_mutation,
25-
pushforward_performance,
26-
pullback_performance
20+
inner, mode, outer, supports_mutation, pushforward_performance, pullback_performance
2721
using DocStringExtensions
2822
import DifferentiationInterface as DI
2923
using JET: @test_call, @test_opt
3024
using JLArrays: jl
3125
using LinearAlgebra: Diagonal, dot
26+
using ProgressMeter: ProgressUnknown, next!
3227
using SparseArrays: SparseArrays, nnz, SparseMatrixCSC
33-
using StaticArrays: SVector, SMatrix
28+
using StaticArrays: MMatrix, MVector, SMatrix, SVector
3429
using Test: @testset, @test
3530

3631
include("scenarios/scenario.jl")
@@ -39,7 +34,7 @@ include("scenarios/static.jl")
3934
include("scenarios/component.jl")
4035
include("scenarios/gpu.jl")
4136

42-
include("utils/zero.jl")
37+
include("utils/zero_backends.jl")
4338
include("utils/compatibility.jl")
4439
include("utils/printing.jl")
4540
include("utils/misc.jl")
@@ -50,13 +45,19 @@ include("tests/type_stability.jl")
5045
include("tests/call_count.jl")
5146
include("tests/sparsity.jl")
5247
include("tests/benchmark.jl")
53-
include("tests/test.jl")
48+
include("test_differentiation.jl")
5449

55-
export all_operators
56-
export Scenario
50+
export AbstractScenario
51+
export PushforwardScenario,
52+
PullbackScenario,
53+
DerivativeScenario,
54+
GradientScenario,
55+
JacobianScenario,
56+
SecondDerivativeScenario,
57+
HVPScenario,
58+
HessianScenario
5759
export default_scenarios
5860
export static_scenarios, component_scenarios, gpu_scenarios
59-
export BenchmarkData
6061
export test_differentiation, benchmark_differentiation
6162

6263
end

0 commit comments

Comments
 (0)