
Commit 0047722

adrhill and gdalle authored

Simplify tutorial and README (#216)

* Simplify Tutorial and README
* Safer random test

---------

Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent 0f71a42 commit 0047722

3 files changed

Lines changed: 53 additions & 41 deletions

File tree

DifferentiationInterface/README.md

Lines changed: 11 additions & 9 deletions
````diff
@@ -20,7 +20,7 @@ This package provides a backend-agnostic syntax to differentiate functions of th
 
 ## Features
 
-- First- and second-order operators
+- First- and second-order operators (gradients, Jacobians, Hessians and [more](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/stable/overview/))
 - In-place and out-of-place differentiation
 - Preparation mechanism (e.g. to create a config or tape)
 - Thorough validation on standard inputs and outputs (numbers, vectors, matrices)
````
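The updated features line mentions Jacobians and Hessians alongside gradients. As a minimal sketch (not part of this commit), assuming the `jacobian` and `hessian` operators exported by DifferentiationInterface.jl and the `AutoForwardDiff` backend type used elsewhere in the diff, the second-order operators follow the same calling convention:

```julia
using DifferentiationInterface
import ForwardDiff

f(x) = sum(abs2, x)   # scalar output: has a gradient and a Hessian
g(x) = abs2.(x)       # vector output: has a Jacobian

x = [1.0, 2.0, 3.0]

jacobian(g, AutoForwardDiff(), x)  # expected: 3×3 matrix with 2 .* x on the diagonal
hessian(f, AutoForwardDiff(), x)   # expected: 3×3 matrix equal to 2I
```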
````diff
@@ -68,19 +68,21 @@ julia> Pkg.add(
 
 ## Example
 
-```jldoctest readme
-julia> import ADTypes, ForwardDiff
-
-julia> using DifferentiationInterface
+```julia
+using DifferentiationInterface
+import ForwardDiff, Enzyme, Zygote # import automatic differentiation backends you want to use
 
-julia> backend = ADTypes.AutoForwardDiff();
+f(x) = sum(abs2, x)
 
-julia> f(x) = sum(abs2, x);
+x = [1.0, 2.0, 3.0]
 
-julia> value_and_gradient(f, backend, [1., 2., 3.])
-(14.0, [2.0, 4.0, 6.0])
+value_and_gradient(f, AutoForwardDiff(), x) # returns (14.0, [2.0, 4.0, 6.0]) using ForwardDiff.jl
+value_and_gradient(f, AutoEnzyme(), x) # returns (14.0, [2.0, 4.0, 6.0]) using Enzyme.jl
+value_and_gradient(f, AutoZygote(), x) # returns (14.0, [2.0, 4.0, 6.0]) using Zygote.jl
 ```
 
+For more performance, take a look at the [DifferentiationInterface tutorial](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/stable/tutorial/).
+
 ## Related packages
 
 - [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) is the original inspiration for DifferentiationInterface.jl.
````

DifferentiationInterface/docs/src/tutorial.md

Lines changed: 40 additions & 30 deletions
````diff
@@ -6,45 +6,50 @@ CurrentModule = Main
 
 We present a typical workflow with DifferentiationInterface.jl and showcase its potential performance benefits.
 
-```@repl tuto
+```@example tuto
 using DifferentiationInterface
-import ADTypes, ForwardDiff, Enzyme
-using BenchmarkTools
+
+import ForwardDiff, Enzyme # ⚠️ import the backends you want to use ⚠️
 ```
 
-## Computing a gradient
+!!! tip
+    Importing backends with `import` instead of `using` avoids name conflicts and makes sure you are using operators from DifferentiationInterface.jl.
+    This is useful since most backends also export operators like `gradient` and `jacobian`.
 
-A common use case of AD is optimizing real-valued functions with first- or second-order methods.
-Let's define a simple objective
 
-```@repl tuto
-f(x::AbstractArray) = sum(abs2, x)
-```
+## Computing a gradient
 
-and a random input vector
+A common use case of automatic differentiation (AD) is optimizing real-valued functions with first- or second-order methods.
+Let's define a simple objective and a random input vector
 
-```@repl tuto
-x = [1.0, 2.0, 3.0];
+```@example tuto
+f(x) = sum(abs2, x)
+
+x = [1.0, 2.0, 3.0]
+nothing # hide
 ```
 
 To compute its gradient, we need to choose a "backend", i.e. an AD package that DifferentiationInterface.jl will call under the hood.
 Most backend types are defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl) and re-exported by DifferentiationInterface.jl.
+
 [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) is very generic and efficient for low-dimensional inputs, so it's a good starting point:
 
-```@repl tuto
-backend = ADTypes.AutoForwardDiff()
+```@example tuto
+backend = AutoForwardDiff()
+nothing # hide
 ```
 
 Now you can use DifferentiationInterface.jl to get the gradient:
 
-```@repl tuto
+```@example tuto
 gradient(f, backend, x)
 ```
 
 Was that fast?
 [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl) helps you answer that question.
 
 ```@repl tuto
+using BenchmarkTools
 @btime gradient($f, $backend, $x);
 ```
 
````
````diff
@@ -58,13 +63,14 @@ Not bad, but you can do better.
 
 ## Overwriting a gradient
 
-Since you know how much space your gradient will occupy, you can pre-allocate that memory and offer it to AD.
+Since you know how much space your gradient will occupy (the same as your input `x`), you can pre-allocate that memory and offer it to AD.
 Some backends get a speed boost from this trick.
 
-```@repl tuto
-grad = zero(x)
-gradient!(f, grad, backend, x);
-grad
+```@example tuto
+grad = similar(x)
+gradient!(f, grad, backend, x)
+
+grad # has been mutated
 ```
 
 The bang indicates that one of the arguments of `gradient!` might be mutated.
````
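A sketch (not part of the commit) of the pattern this section is aiming at, reusing one pre-allocated buffer across several `gradient!` calls, with the same operators and backend as the tutorial:

```julia
using DifferentiationInterface
import ForwardDiff

f(x) = sum(abs2, x)
backend = AutoForwardDiff()

xs = [rand(3) for _ in 1:5]   # several inputs of the same size
grad = similar(first(xs))     # allocate the output buffer once

for x in xs
    gradient!(f, grad, backend, x)  # overwrites `grad` in place
    # use `grad` here before the next iteration overwrites it
end
```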
````diff
@@ -76,24 +82,26 @@ More precisely, our convention is that _every positional argument between the fu
 
 For some reason the in-place version is not much better than your first attempt.
 However, it has one less allocation, which corresponds to the gradient vector you provided.
-Don't worry, you're not done yet.
+Don't worry, you can get even more performance.
 
 ## Preparing for multiple gradients
 
 Internally, ForwardDiff.jl creates some data structures to keep track of things.
 These objects can be reused between gradient computations, even on different input values.
 We abstract away the preparation step behind a backend-agnostic syntax:
 
-```@repl tuto
+```@example tuto
 extras = prepare_gradient(f, backend, x)
+nothing # hide
 ```
 
 You don't need to know what this object is, you just need to pass it to the gradient operator.
 
-```@repl tuto
-grad = zero(x);
-gradient!(f, grad, backend, x, extras);
-grad
+```@example tuto
+grad = similar(x)
+gradient!(f, grad, backend, x, extras)
+
+grad # has been mutated
 ```
 
 Preparation makes the gradient computation much faster, and (in this case) allocation-free.
````
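A sketch (not part of the commit) of how the prepared `extras` object pays off when many gradients of the same shape are needed, using the `prepare_gradient` and `gradient!` signatures shown in the diff:

```julia
using DifferentiationInterface
import ForwardDiff

f(x) = sum(abs2, x)
backend = AutoForwardDiff()

x0 = rand(10)
extras = prepare_gradient(f, backend, x0)  # pay the setup cost once
grad = similar(x0)

for _ in 1:100
    x = rand(10)                            # same size as the preparation input
    gradient!(f, grad, backend, x, extras)  # reuses the prepared internal structures
end
```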
````diff
@@ -115,13 +123,14 @@ So let's try the state-of-the-art [Enzyme.jl](https://github.com/EnzymeAD/Enzyme
 
 For this one, the backend definition is slightly more involved, because you need to feed the "mode" to the object from ADTypes.jl:
 
-```@repl tuto
-backend2 = ADTypes.AutoEnzyme(; mode=Enzyme.Reverse)
+```@example tuto
+backend2 = AutoEnzyme(; mode=Enzyme.Reverse)
+nothing # hide
 ```
 
 But once it is done, things run smoothly with exactly the same syntax:
 
-```@repl tuto
+```@example tuto
 gradient(f, backend2, x)
 ```
 
````
````diff
@@ -136,4 +145,5 @@ And you can run the same benchmarks:
 
 Not only is it blazingly fast, you achieved this speedup without looking at the docs of either ForwardDiff.jl or Enzyme.jl!
 In short, DifferentiationInterface.jl allows for easy testing and comparison of AD backends.
-If you want to go further, check out the [DifferentiationTest.jl tutorial](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterfaceTest/dev/tutorial/).
+If you want to go further, check out the [DifferentiationInterfaceTest.jl tutorial](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterfaceTest/dev/tutorial/).
+It provides benchmarking utilities to compare backends and help you select the one that is best suited for your problem.
````
Lines changed: 2 additions & 2 deletions (third changed file; its path is not shown in this view)
````diff
@@ -1,6 +1,6 @@
 alg = DI.GreedyColoringAlgorithm()
 
-A = sprand(Bool, 100, 200, 0.1)
+A = sprand(Bool, 100, 200, 0.05)
 
 column_colors = ADTypes.column_coloring(A, alg)
 @test DI.check_structurally_orthogonal_columns(A, column_colors)
@@ -10,7 +10,7 @@ row_colors = ADTypes.row_coloring(A, alg)
 @test DI.check_structurally_orthogonal_rows(A, row_colors)
 @test maximum(row_colors) < size(A, 1) ÷ 2
 
-S = Symmetric(sprand(Bool, 100, 100, 0.1)) + I
+S = Symmetric(sprand(Bool, 100, 100, 0.05)) + I
 symmetric_colors = ADTypes.symmetric_coloring(S, alg)
 @test DI.check_symmetrically_structurally_orthogonal(S, symmetric_colors)
 @test maximum(symmetric_colors) < size(A, 2) ÷ 2
````
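The commit message calls this a "safer random test": lowering the `sprand` density from 0.1 to 0.05 makes it less likely that an unlucky random sparsity pattern needs so many colors that the `maximum(colors) < n ÷ 2` assertions fail by chance. A sketch (not part of the diff) using the same names as the test file; the import line for `DI` is an assumption:

```julia
using SparseArrays
import ADTypes
import DifferentiationInterface as DI

alg = DI.GreedyColoringAlgorithm()

# Fewer nonzeros means fewer column pairs sharing a row, so the greedy coloring
# needs fewer colors and the bound tested in the file is much harder to violate.
for density in (0.1, 0.05)
    A = sprand(Bool, 100, 200, density)
    column_colors = ADTypes.column_coloring(A, alg)
    println("density = $density  =>  colors used: $(maximum(column_colors))")
end
```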
