
Commit 84d98a3

Renounce BangBang, enforce mutation (#169)
1 parent cf19040 · commit 84d98a3

71 files changed · 3295 additions & 1814 deletions

Large commit: only a subset of the 71 changed files is shown below.

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
````diff
@@ -1,7 +1,7 @@
 name = "DifferentiationInterface"
 uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 authors = ["Guillaume Dalle", "Adrian Hill"]
-version = "0.1.0"
+version = "0.2.0"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
````

DifferentiationInterface/README.md

Lines changed: 4 additions & 4 deletions
````diff
@@ -15,15 +15,15 @@ An interface to various automatic differentiation (AD) backends in Julia.

 This package provides a backend-agnostic syntax to differentiate functions of the following types:

-- _allocating_: `f(x) = y`
-- _mutating_: `f!(y, x) = nothing`
+- _one-argument functions_ (allocating): `f(x) = y`
+- _two-argument functions_ (mutating): `f!(y, x) = nothing`

 ## Features

-- First and second order operators
+- First- and second-order operators
 - In-place and out-of-place differentiation
 - Preparation mechanism (e.g. to create a config or tape)
-- Thorough validation on standard inputs and outputs (scalars, vectors, matrices)
+- Thorough validation on standard inputs and outputs (numbers, vectors, matrices)
 - Testing and benchmarking utilities accessible to users with [DifferentiationInterfaceTest](https://github.com/gdalle/DifferentiationInterface.jl/tree/main/DifferentiationInterfaceTest)

 ## Compatibility
````
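
A minimal usage sketch of the two function types named above (not taken from the diff; it assumes ForwardDiff.jl is installed so that the `AutoForwardDiff` backend from ADTypes.jl is functional):

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff  # loads the ForwardDiff backend implementation

backend = AutoForwardDiff()

# one-argument (allocating) function: f(x) = y
f(x) = sum(abs2, x)
gradient(f, backend, [1.0, 2.0, 3.0])  # ≈ [2.0, 4.0, 6.0]

# two-argument (mutating) function: f!(y, x) = nothing
f!(y, x) = (y .= x .^ 2; nothing)
y = zeros(3)
jacobian(f!, y, backend, [1.0, 2.0, 3.0])  # 3×3 Jacobian; y is filled with x .^ 2
```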

DifferentiationInterface/docs/src/api.md

Lines changed: 58 additions & 24 deletions
````diff
@@ -11,50 +11,84 @@ DifferentiationInterface

 ## Derivative

-```@autodocs
-Modules = [DifferentiationInterface]
-Pages = ["src/derivative.jl"]
-Private = false
+```@docs
+prepare_derivative
+derivative
+derivative!
+value_and_derivative
+value_and_derivative!
 ```

 ## Gradient

-```@autodocs
-Modules = [DifferentiationInterface]
-Pages = ["gradient.jl"]
-Private = false
+```@docs
+prepare_gradient
+gradient
+gradient!
+value_and_gradient
+value_and_gradient!
 ```

 ## Jacobian

-```@autodocs
-Modules = [DifferentiationInterface]
-Pages = ["jacobian.jl"]
-Private = false
+```@docs
+prepare_jacobian
+jacobian
+jacobian!
+value_and_jacobian
+value_and_jacobian!
 ```

 ## Second order

-```@autodocs
-Modules = [DifferentiationInterface]
-Pages = ["second_order.jl", "second_derivative.jl", "hessian.jl", "hvp.jl"]
-Private = false
+```@docs
+SecondOrder
+```
+
+```@docs
+prepare_second_derivative
+second_derivative
+second_derivative!
+```
+
+```@docs
+prepare_hvp
+hvp
+hvp!
+```
+
+```@docs
+prepare_hessian
+hessian
+hessian!
 ```

 ## Primitives

-```@autodocs
-Modules = [DifferentiationInterface]
-Pages = ["pushforward.jl", "pullback.jl"]
-Private = false
+```@docs
+prepare_pushforward
+pushforward
+pushforward!
+value_and_pushforward
+value_and_pushforward!
+```
+
+```@docs
+prepare_pullback
+pullback
+pullback!
+value_and_pullback
+value_and_pullback!
+value_and_pullback_split
+value_and_pullback!_split
+```

 ## Backend queries

-```@autodocs
-Modules = [DifferentiationInterface]
-Pages = ["backends.jl"]
-Private = false
+```@docs
+check_available
+check_mutation
+check_hessian
 ```

 ## Internals
````
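
To illustrate the naming scheme in this list, here is a hedged sketch of the derivative operators (argument order follows the signatures documented in the overview below):

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff

backend = AutoForwardDiff()
f(x) = x^2 + sin(x)

der     = derivative(f, backend, 1.0)            # plain operator
y, der2 = value_and_derivative(f, backend, 1.0)  # primal and derivative in one call
```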

DifferentiationInterface/docs/src/backends.md

Lines changed: 2 additions & 7 deletions
````diff
@@ -88,8 +88,8 @@ Markdown.parse(join(vcat(header, subheader, rows...), "\n")) # hide

 ## Mutation support

-All backends are compatible with allocating functions `f(x) = y`.
-Only some are compatible with mutating functions `f!(y, x) = nothing`.
+All backends are compatible with one-argument functions `f(x) = y`.
+Only some are compatible with two-argument functions `f!(y, x) = nothing`.
 You can use [`check_mutation`](@ref) to check that feature, like we did below:

 ```@example backends
@@ -114,8 +114,3 @@ rows = map(all_backends()) do backend # hide
 end # hide
 Markdown.parse(join(vcat(header, subheader, rows...), "\n")) # hide
 ```
-
-!!! warning
-    Second-order operators can also be used with a combination of backends inside the [`SecondOrder`](@ref) struct.
-    There are many possible combinations, a lot of which will fail.
-    Due to compilation overhead, we do not currently test them all to display the working ones in the documentation, but we might if users deem it relevant.
````
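
A sketch of the mutation-support query mentioned above, assuming `check_mutation` takes only the backend and returns a `Bool`:

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff

backend = AutoForwardDiff()
if check_mutation(backend)  # true if two-argument functions f!(y, x) are supported
    @info "mutation supported" backend
end
```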

DifferentiationInterface/docs/src/overview.md

Lines changed: 84 additions & 77 deletions
````diff
@@ -3,20 +3,26 @@
 ## Operators

 Depending on the type of input and output, differentiation operators can have various names.
-Most backends have custom implementations, which we reuse if possible.

-We choose the following terminology for the high-level operators we provide:
+We provide the following high-level operators:

-| operator             | input `x`       | output `y`                  | result type      | result shape             |
-| :------------------- | :-------------- | :-------------------------- | :--------------- | :----------------------- |
-| [`derivative`](@ref) | `Number`        | `Number` or `AbstractArray` | same as `y`      | `size(y)`                |
-| [`gradient`](@ref)   | `AbstractArray` | `Number`                    | same as `x`      | `size(x)`                |
-| [`jacobian`](@ref)   | `AbstractArray` | `AbstractArray`             | `AbstractMatrix` | `(length(y), length(x))` |
+| operator                    | order | input `x`       | output `y`                  | result type      | result shape             |
+| :-------------------------- | :---- | :-------------- | :-------------------------- | :--------------- | :----------------------- |
+| [`derivative`](@ref)        | 1     | `Number`        | `Number` or `AbstractArray` | same as `y`      | `size(y)`                |
+| [`second_derivative`](@ref) | 2     | `Number`        | `Number` or `AbstractArray` | same as `y`      | `size(y)`                |
+| [`gradient`](@ref)          | 1     | `AbstractArray` | `Number`                    | same as `x`      | `size(x)`                |
+| [`hvp`](@ref)               | 2     | `AbstractArray` | `Number`                    | same as `x`      | `size(x)`                |
+| [`hessian`](@ref)           | 2     | `AbstractArray` | `Number`                    | `AbstractMatrix` | `(length(x), length(x))` |
+| [`jacobian`](@ref)          | 1     | `AbstractArray` | `AbstractArray`             | `AbstractMatrix` | `(length(y), length(x))` |

-They are all based on the following low-level operators:
+They can all be derived from two low-level operators:

-- [`pushforward`](@ref) (or JVP), to propagate input tangents
-- [`pullback`](@ref) (or VJP), to backpropagate output cotangents
+| operator                       | order | input `x` | output `y` | result type | result shape |
+| :----------------------------- | :---- | :-------- | :--------- | :---------- | :----------- |
+| [`pushforward`](@ref) (or JVP) | 1     | `Any`     | `Any`      | same as `y` | `size(y)`    |
+| [`pullback`](@ref) (or VJP)    | 1     | `Any`     | `Any`      | same as `x` | `size(x)`    |
+
+Luckily, most backends have custom implementations, which we reuse if possible instead of relying on fallbacks.

 !!! tip
     See the book [The Elements of Differentiable Programming](https://arxiv.org/abs/2403.14606) for details on these concepts.
````
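
To make the two tables concrete, a hypothetical tour of the operators (assuming ForwardDiff.jl is loaded; the trailing seed arguments for `pushforward` and `pullback` are an assumption consistent with the `operator(f, backend, x, ...)` signature given further down):

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff

backend = AutoForwardDiff()
x = [1.0, 2.0]

gradient(x -> sum(abs2, x), backend, x)               # result has size(x)
jacobian(x -> [x[1] + x[2], x[1] * x[2]], backend, x) # (length(y), length(x)) matrix
pushforward(x -> x .^ 2, backend, x, [1.0, 0.0])      # JVP: tangent pushed to outputs
pullback(x -> x .^ 2, backend, x, [1.0, 0.0])         # VJP: cotangent pulled to inputs
```
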
````diff
@@ -25,59 +31,33 @@ They are all based on the following low-level operators:

 Several variants of each operator are defined:

-| out-of-place          | in-place (or not)       | out-of-place + primal           | in-place (or not) + primal        |
-| :-------------------- | :---------------------- | :------------------------------ | :-------------------------------- |
-| [`derivative`](@ref)  | [`derivative!!`](@ref)  | [`value_and_derivative`](@ref)  | [`value_and_derivative!!`](@ref)  |
-| [`gradient`](@ref)    | [`gradient!!`](@ref)    | [`value_and_gradient`](@ref)    | [`value_and_gradient!!`](@ref)    |
-| [`jacobian`](@ref)    | [`jacobian!!`](@ref)    | [`value_and_jacobian`](@ref)    | [`value_and_jacobian!!`](@ref)    |
-| [`pushforward`](@ref) | [`pushforward!!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!!`](@ref) |
-| [`pullback`](@ref)    | [`pullback!!`](@ref)    | [`value_and_pullback`](@ref)    | [`value_and_pullback!!`](@ref)    |
+| out-of-place                | in-place                     | out-of-place + primal           | in-place + primal                |
+| :-------------------------- | :--------------------------- | :------------------------------ | :------------------------------- |
+| [`derivative`](@ref)        | [`derivative!`](@ref)        | [`value_and_derivative`](@ref)  | [`value_and_derivative!`](@ref)  |
+| [`second_derivative`](@ref) | [`second_derivative!`](@ref) | NA                              | NA                               |
+| [`gradient`](@ref)          | [`gradient!`](@ref)          | [`value_and_gradient`](@ref)    | [`value_and_gradient!`](@ref)    |
+| [`hvp`](@ref)               | [`hvp!`](@ref)               | NA                              | NA                               |
+| [`hessian`](@ref)           | [`hessian!`](@ref)           | NA                              | NA                               |
+| [`jacobian`](@ref)          | [`jacobian!`](@ref)          | [`value_and_jacobian`](@ref)    | [`value_and_jacobian!`](@ref)    |
+| [`pushforward`](@ref)       | [`pushforward!`](@ref)       | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) |
+| [`pullback`](@ref)          | [`pullback!`](@ref)          | [`value_and_pullback`](@ref)    | [`value_and_pullback!`](@ref)    |
````
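
For instance, the four `gradient` variants would be called as follows (a sketch, using the one-argument signatures introduced below):

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff

backend = AutoForwardDiff()
f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]
grad = similar(x)

gradient(f, backend, x)                   # out-of-place
gradient!(f, grad, backend, x)            # in-place: grad is overwritten
value_and_gradient(f, backend, x)         # out-of-place + primal
value_and_gradient!(f, grad, backend, x)  # in-place + primal
```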

````diff
-!!! warning
-    We use the syntactic convention `!!` to signal that some of the arguments _can_ be mutated, but they do not _have to be_.
-    Such arguments will always be part of the return, so that one can simply reuse the operator's output and forget its input.
-    In other words, this is good:
-    ```julia
-    # work with grad_in
-    grad_out = gradient!!(f, grad_in, backend, x)
-    # work with grad_out: OK
-    ```
-    On the other hand, this is bad, because if `grad_in` has not been mutated, you will forget the results:
-    ```julia
-    # work with grad_in
-    gradient!!(f, grad_in, backend, x)
-    # mistakenly keep working with grad_in: NOT OK
-    ```
-    Note that we don't guarantee `grad_out` will have the same type as `grad_in`.
-    Its type can even depend on the choice of backend.
-
-## Second order
-
-Second-order differentiation is also supported.
-You can either pick a single backend to do all the work, or combine an "outer" backend with an "inner" backend using the [`SecondOrder`](@ref) struct, like so: `SecondOrder(outer, inner)`.
-
-The available operators are similar to first-order ones:
-
-| operator                    | input `x`       | output `y`                  | result type      | result shape             |
-| :-------------------------- | :-------------- | :-------------------------- | :--------------- | :----------------------- |
-| [`second_derivative`](@ref) | `Number`        | `Number` or `AbstractArray` | same as `y`      | `size(y)`                |
-| [`hvp`](@ref)               | `AbstractArray` | `Number`                    | same as `x`      | `size(x)`                |
-| [`hessian`](@ref)           | `AbstractArray` | `Number`                    | `AbstractMatrix` | `(length(x), length(x))` |
-
-We only define two variants for now:
-
-| out-of-place                | in-place (or not)             |
-| :-------------------------- | :---------------------------- |
-| [`second_derivative`](@ref) | [`second_derivative!!`](@ref) |
-| [`hvp`](@ref)               | [`hvp!!`](@ref)               |
-| [`hessian`](@ref)           | [`hessian!!`](@ref)           |
+## Mutation and signatures

-!!! danger
-    Second-order differentiation is still experimental, use at your own risk.
+In order to ensure symmetry between one-argument functions `f(x) = y` and two-argument functions `f!(y, x) = nothing`, we define the same operators for both cases.
+However they have different signatures:
+
+| signature  | out-of-place                       | in-place                                  |
+| :--------- | :--------------------------------- | :---------------------------------------- |
+| `f(x)`     | `operator(f, backend, x, ...)`     | `operator!(f, res, backend, x, ...)`      |
+| `f!(y, x)` | `operator(f!, y, backend, x, ...)` | `operator!(f!, y, res, backend, x, ...)`  |
+
+!!! warning
+    Every variant of the operator will mutate `y` when applied to a two-argument function `f!(y, x) = nothing`, even if it does not have a `!` in its name.
````
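
A sketch of both signatures for `jacobian`, illustrating the warning above: the output buffer `y` is mutated either way.

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff

backend = AutoForwardDiff()
f!(y, x) = (y .= 2 .* x; nothing)

x = [1.0, 2.0]
y = zeros(2)
jac = zeros(2, 2)

jacobian(f!, y, backend, x)        # no `!` in the name, but y is mutated anyway
jacobian!(f!, y, jac, backend, x)  # fills both y and jac in place
```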

````diff
 ## Preparation

-In many cases, AD can be accelerated if the function has been run at least once (e.g. to record a tape) and if some cache objects are provided.
+In many cases, AD can be accelerated if the function has been run at least once (e.g. to create a config or record a tape) and if some cache objects are provided.
 This is a backend-specific procedure, but we expose a common syntax to achieve it.

 | operator   | preparation function       |
@@ -91,42 +71,69 @@ This is a backend-specific procedure, but we expose a common syntax to achieve it.
 | `pullback` | [`prepare_pullback`](@ref) |
 | `hvp`      | [`prepare_hvp`](@ref)      |

-If you run `prepare_operator(backend, f, x)`, it will create an object called `extras` containing the necessary information to speed up `operator` and its variants.
-This information is specific to `backend` and `f`, as well as the _type and size_ of the input `x`, but it should work with different _values_ of `x`.
+If you run `prepare_operator(backend, f, x, [seed])`, it will create an object called `extras` containing the necessary information to speed up `operator` and its variants.
+This information is specific to `backend` and `f`, as well as the _type and size_ of the input `x` and the _control flow_ within the function, but it should work with different _values_ of `x`.

 You can then call `operator(f, backend, x2, extras)`, which should be faster than `operator(f, backend, x2)`.
 This is especially worth it if you plan to call `operator` several times in similar settings: you can think of it as a warm up.

 !!! warning
-    For `SecondOrder` backends, the inner differentiation cannot be prepared at the moment, only the outer one is.
+    The `extras` object is nearly always mutated, even if the operator does not have a `!` in its name.
````
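
A preparation sketch that follows the argument order quoted above (`prepare_*` with the backend first, extras as a trailing argument); this convention is taken from this page and may differ in other versions:

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff

backend = AutoForwardDiff()
f(x) = sum(abs2, x)
x = rand(10)

extras = prepare_gradient(backend, f, x)  # one-time warm-up
gradient(f, backend, x, extras)           # reuses the prepared extras
gradient(f, backend, rand(10), extras)    # same type and size, new values
```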

````diff
-## FAQ
+### Second order

-### Multiple inputs/outputs
+We offer two ways to perform second-order differentiation (for [`second_derivative`](@ref), [`hvp`](@ref) and [`hessian`](@ref)):

-Restricting the API to one input and one output has many coding advantages, but it is not very flexible.
-If you need more than that, use [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl) to wrap several objects inside a single `ComponentVector`.
+- pick a single backend to do all the work
+- combine an "outer" and "inner" backend within the [`SecondOrder`](@ref) struct: the inner backend will be called first, and the outer backend will differentiate the generated code
+
+!!! warning
+    There are many possible backend combinations, a lot of which will fail.
+    At the moment, trial and error is your best friend.
+    Usually, the most efficient approach for Hessians is forward-over-reverse, i.e. a forward-mode outer backend and a reverse-mode inner backend.
````
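
A hypothetical forward-over-reverse combination, with ForwardDiff.jl outside and Zygote.jl inside; the `SecondOrder(outer, inner)` argument order comes from the text removed earlier in this diff:

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff, AutoZygote
import ForwardDiff, Zygote

backend = SecondOrder(AutoForwardDiff(), AutoZygote())  # outer, inner

f(x) = sum(abs2, x)
x = [1.0, 2.0]

hessian(f, backend, x)          # 2×2 matrix, here ≈ 2I
hvp(f, backend, x, [1.0, 0.0])  # Hessian-vector product, same shape as x
```
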
````diff
+## Experimental
+
+!!! danger
+    Everything in this section is still experimental, use it at your own risk.

 ### Sparsity

-If you need to work with sparse Jacobians, you can pick one of the [sparse backends](@ref Sparse) from [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
-The sparsity pattern is computed automatically with [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl) during the preparation step.
+[ADTypes.jl](https://github.com/SciML/ADTypes.jl) provides [sparse versions](@ref Sparse) of many common AD backends.
+They can accelerate the computation of sparse Jacobians and Hessians:

-If you need to work with sparse Hessians, you can use a sparse backend as the _outer_ backend of a `SecondOrder`.
-This means the Hessian is obtained as the sparse Jacobian of the gradient.
+- for sparse Jacobians, just select one of them as your first-order backend.
+- for sparse Hessians, select one of them as the _outer part_ of a [`SecondOrder`](@ref) backend (in that case, the Hessian is obtained as the sparse Jacobian of the gradient).

-!!! danger
-    Sparsity support is still experimental, use at your own risk.
+The sparsity pattern is computed automatically with [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl) during the preparation step.
+
+!!! info "Planned feature"
+    Modular sparsity pattern computation, with other algorithms beyond those from Symbolics.jl.
````
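
A sparsity sketch based on the two bullets above; `AutoSparseForwardDiff` is assumed to be provided by the ADTypes.jl version this commit targets:

```julia
using DifferentiationInterface
using ADTypes: AutoSparseForwardDiff, AutoZygote
import ForwardDiff, Zygote

f(x) = x .^ 2                                  # diagonal Jacobian
jacobian(f, AutoSparseForwardDiff(), rand(3))  # sparse first-order backend

g(x) = sum(abs2, x)                            # sparse Hessian: sparse outer backend
hessian(g, SecondOrder(AutoSparseForwardDiff(), AutoZygote()), rand(3))
```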

````diff
 ### Split reverse mode

 Some reverse mode AD backends expose a "split" option, which runs only the forward sweep, and encapsulates the reverse sweep in a closure.
 We make this available for all backends with the following operators:

-|                      | out-of-place                       | in-place (or not)                      |
-| :------------------- | :--------------------------------- | :------------------------------------- |
-| allocating functions | [`value_and_pullback_split`](@ref) | [`value_and_pullback!!_split`](@ref)   |
-| mutating functions   | -                                  | [`value_and_pullback!!_split!!`](@ref) |
+| out-of-place                       | in-place                            |
+| :--------------------------------- | :---------------------------------- |
+| [`value_and_pullback_split`](@ref) | [`value_and_pullback!_split`](@ref) |
````
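
A sketch of split reverse mode with Zygote.jl; the calling convention of the returned closure is an assumption based on the description above:

```julia
using DifferentiationInterface
using ADTypes: AutoZygote
import Zygote

f(x) = x .^ 2
x = [1.0, 2.0]

y, pullbackfunc = value_and_pullback_split(f, AutoZygote(), x)  # forward sweep runs once
pullbackfunc([1.0, 0.0])  # reverse sweep with a first cotangent
pullbackfunc([0.0, 1.0])  # reverse sweep reused with another cotangent
```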

````diff
-!!! danger
-    Split reverse mode is still experimental, use at your own risk.
+## Not supported
+
+### Batched evaluation
+
+!!! info "Planned feature"
+    Interface for providing several pushforward / pullback seeds at once, similar to the chunking in ForwardDiff.jl or the batches in Enzyme.jl.
+
+### Non-standard types
+
+The package is thoroughly tested with inputs and outputs of the following types: `Float64`, `Vector{Float64}` and `Matrix{Float64}`.
+We also expect it to work on all kinds of `Number` and `AbstractArray` variables.
+Beyond that, you are in uncharted territory.
+We voluntarily keep the type annotations minimal, so that passing more complex objects or custom structs _might work with some backends_, but we make no guarantees about that.
+
+### Multiple inputs/outputs
+
+Restricting the API to one input and one output has many coding advantages, but it is not very flexible.
+If you need more than that, use [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl) to wrap several objects inside a single `ComponentVector`.
````
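
A sketch of the `ComponentVector` workaround, assuming ComponentArrays.jl and a backend (here Zygote.jl) that can differentiate through this wrapper:

```julia
using DifferentiationInterface
using ADTypes: AutoZygote
import Zygote
using ComponentArrays: ComponentVector

loss(v) = sum(abs2, v.weights) + v.bias^2  # "two inputs" packed into one vector
v = ComponentVector(weights = [1.0, 2.0], bias = 3.0)

gradient(loss, AutoZygote(), v)            # result keeps the labeled structure
```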
