Skip to content

Commit 415dc6c

Browse files
authored
Fix DifferentiationInterfaceTest tutorial (#335)
* Fix DIT tutorial * Make Scenario public * More stuff
1 parent 8dc755b commit 415dc6c

7 files changed

Lines changed: 33 additions & 20 deletions

File tree

DifferentiationInterface/docs/make.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ makedocs(;
3939
"Reference" => ["operators.md", "backends.md", "api.md"],
4040
"Advanced" => ["dev_guide.md", "overloads.md"],
4141
],
42-
checkdocs=:exports,
4342
plugins=[links],
4443
)
4544

DifferentiationInterfaceTest/docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
55
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
66
DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
77
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
8+
DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
89
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
910
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1011
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"

DifferentiationInterfaceTest/docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using DifferentiationInterface
22
using DifferentiationInterfaceTest
33
using Documenter
4+
using DocumenterInterLinks
45

56
using BenchmarkTools: BenchmarkTools
67
using DataFrames: DataFrames

DifferentiationInterfaceTest/docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ static_scenarios
3030
## Scenario types
3131

3232
```@docs
33+
Scenario
3334
PushforwardScenario
3435
PullbackScenario
3536
DerivativeScenario

DifferentiationInterfaceTest/docs/src/tutorial.md

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ We present a typical workflow with DifferentiationInterfaceTest.jl, building on
55
```@repl tuto
66
using DifferentiationInterface, DifferentiationInterfaceTest
77
import ForwardDiff, Enzyme
8-
import Markdown, PrettyTables, Printf
98
```
109

1110
## Introduction
@@ -31,7 +30,8 @@ Of course we know the true gradient mapping:
3130
DifferentiationInterfaceTest.jl relies with so-called "scenarios", in which you encapsulate the information needed for your test:
3231

3332
- the function `f`
34-
- the input `x` and output `y`
33+
- the input `x` and output `y` of the function `f`
34+
- the reference output of the operator (here `grad`)
3535
- the number of arguments for `f` (either `1` or `2`)
3636
- the behavior of the operator (either `:inplace` or `:outofplace`)
3737

@@ -41,8 +41,8 @@ There is one scenario constructor per operator, and so here we will use [`Gradie
4141
xv = rand(Float32, 3)
4242
xm = rand(Float64, 3, 2)
4343
scenarios = [
44-
GradientScenario(f; x=xv, y=f(xv), nb_args=1, place=:inplace),
45-
GradientScenario(f; x=xm, y=f(xm), nb_args=1, place=:inplace)
44+
GradientScenario(f; x=xv, y=f(xv), grad=∇f(xv), nb_args=1, place=:inplace),
45+
GradientScenario(f; x=xm, y=f(xm), grad=∇f(xm), nb_args=1, place=:inplace)
4646
];
4747
nothing # hide
4848
```
@@ -73,16 +73,4 @@ This is made easy by the [`benchmark_differentiation`](@ref) function, whose syn
7373
df = benchmark_differentiation(backends, scenarios);
7474
```
7575

76-
The resulting object is `DataFrame` from [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl), whose columns correspond to the fields of [`DifferentiationBenchmarkDataRow`](@ref):
77-
Here's what it looks like with all of its columns.
78-
79-
```@example tuto
80-
table = PrettyTables.pretty_table(
81-
String,
82-
df;
83-
backend=Val(:markdown),
84-
header=names(df),
85-
)
86-
87-
Markdown.parse(table)
88-
```
76+
The resulting object is a `DataFrame` from [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl), whose columns correspond to the fields of [`DifferentiationBenchmarkDataRow`](@ref):

DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ include("tests/sparsity.jl")
9292
include("tests/benchmark.jl")
9393
include("test_differentiation.jl")
9494

95+
export Scenario
9596
export PushforwardScenario,
9697
PullbackScenario,
9798
DerivativeScenario,

DifferentiationInterfaceTest/src/scenarios/scenario.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ This generic type should never be used directly: use the specific constructor co
2525
# Fields
2626
2727
$(TYPEDFIELDS)
28+
29+
Note that the `res1` and `res2` fields are given more meaningful names in the keyword arguments of each specialized constructor.
30+
For example:
31+
32+
- the keyword `grad` of `GradientScenario` becomes `res1`
33+
- the keyword `hess` of `HessianScenario` becomes `res2`, and the keyword `grad` becomes `res1`
2834
"""
2935
struct Scenario{op,args,pl,F,X,Y,D,R1,R2}
3036
"function `f` (if `args==1`) or `f!` (if `args==2`) to apply"
@@ -35,9 +41,9 @@ struct Scenario{op,args,pl,F,X,Y,D,R1,R2}
3541
y::Y
3642
"seed for pushforward, pullback or HVP"
3743
seed::D
38-
"first-order result"
44+
"first-order result of the operator"
3945
res1::R1
40-
"second-order result"
46+
"second-order result of the operator (when it makes sense)"
4147
res2::R2
4248

4349
function Scenario{op,args,pl}(
@@ -120,20 +126,26 @@ end
120126

121127
"""
122128
$(SIGNATURES)
129+
130+
Construct a [`Scenario`](@ref) to test `pushforward` and its variants.
123131
"""
124132
function PushforwardScenario(f; x, y, dx, dy=nothing, nb_args, place=:inplace)
125133
return Scenario{:pushforward,nb_args,place}(f; x, y, seed=dx, res1=dy, res2=nothing)
126134
end
127135

128136
"""
129137
$(SIGNATURES)
138+
139+
Construct a [`Scenario`](@ref) to test `pullback` and its variants.
130140
"""
131141
function PullbackScenario(f; x, y, dy, dx=nothing, nb_args, place=:inplace)
132142
return Scenario{:pullback,nb_args,place}(f; x, y, seed=dy, res1=dx, res2=nothing)
133143
end
134144

135145
"""
136146
$(SIGNATURES)
147+
148+
Construct a [`Scenario`](@ref) to test `derivative` and its variants.
137149
"""
138150
function DerivativeScenario(f; x, y, der=nothing, nb_args, place=:inplace)
139151
return Scenario{:derivative,nb_args,place}(
@@ -143,20 +155,26 @@ end
143155

144156
"""
145157
$(SIGNATURES)
158+
159+
Construct a [`Scenario`](@ref) to test `gradient` and its variants.
146160
"""
147161
function GradientScenario(f; x, y, grad=nothing, nb_args, place=:inplace)
148162
return Scenario{:gradient,nb_args,place}(f; x, y, seed=nothing, res1=grad, res2=nothing)
149163
end
150164

151165
"""
152166
$(SIGNATURES)
167+
168+
Construct a [`Scenario`](@ref) to test `jacobian` and its variants.
153169
"""
154170
function JacobianScenario(f; x, y, jac=nothing, nb_args, place=:inplace)
155171
return Scenario{:jacobian,nb_args,place}(f; x, y, seed=nothing, res1=jac, res2=nothing)
156172
end
157173

158174
"""
159175
$(SIGNATURES)
176+
177+
Construct a [`Scenario`](@ref) to test `second_derivative` and its variants.
160178
"""
161179
function SecondDerivativeScenario(
162180
f; x, y, der=nothing, der2=nothing, nb_args, place=:inplace
@@ -168,13 +186,17 @@ end
168186

169187
"""
170188
$(SIGNATURES)
189+
190+
Construct a [`Scenario`](@ref) to test `hvp` and its variants.
171191
"""
172192
function HVPScenario(f; x, y, dx, grad=nothing, dg=nothing, nb_args, place=:inplace)
173193
return Scenario{:hvp,nb_args,place}(f; x, y, seed=dx, res1=grad, res2=dg)
174194
end
175195

176196
"""
177197
$(SIGNATURES)
198+
199+
Construct a [`Scenario`](@ref) to test `hessian` and its variants.
178200
"""
179201
function HessianScenario(f; x, y, grad=nothing, hess=nothing, nb_args, place=:inplace)
180202
return Scenario{:hessian,nb_args,place}(f; x, y, seed=nothing, res1=grad, res2=hess)

0 commit comments

Comments
 (0)