
Commit 21cf0f4

Gradient for AutoFastDifferentiation + some renaming (#180)
* Gradient for AutoFastDifferentiation + some renaming
* Reactivate first order
* Documenter mutation better
1 parent c241361 commit 21cf0f4

20 files changed

Lines changed: 129 additions & 48 deletions


DifferentiationInterface/README.md

Lines changed: 13 additions & 13 deletions
@@ -30,19 +30,19 @@ This package provides a backend-agnostic syntax to differentiate functions of th

 We support most of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl):

-| Backend | Object |
-| :--- | :--- |
-| [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) | `AutoChainRules(ruleconfig)` |
-| [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) | `AutoDiffractor()` |
-| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | `AutoEnzyme(Enzyme.Forward)`, `AutoEnzyme(Enzyme.Reverse)` |
-| [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) | `AutoFiniteDiff()` |
-| [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl) | `AutoFiniteDifferences(fdm)` |
-| [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) | `AutoForwardDiff()` |
-| [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) | `AutoPolyesterForwardDiff(; chunksize)` |
-| [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) | `AutoReverseDiff()` |
-| [SparseDiffTools.jl](https://github.com/JuliaDiff/SparseDiffTools.jl) | `AutoSparseForwardDiff()`, `AutoSparseFiniteDiff()` |
-| [Tracker.jl](https://github.com/FluxML/Tracker.jl) | `AutoTracker()` |
-| [Zygote.jl](https://github.com/FluxML/Zygote.jl) | `AutoZygote()` |
+| Backend | Object |
+| :--- | :--- |
+| [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) | `AutoChainRules(; ruleconfig)` |
+| [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) | `AutoDiffractor()` |
+| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | `AutoEnzyme(; mode=Enzyme.Forward)`, `AutoEnzyme(; mode=Enzyme.Reverse)` |
+| [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) | `AutoFiniteDiff()` |
+| [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl) | `AutoFiniteDifferences(; fdm)` |
+| [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) | `AutoForwardDiff()` |
+| [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) | `AutoPolyesterForwardDiff(; chunksize)` |
+| [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) | `AutoReverseDiff()` |
+| [SparseDiffTools.jl](https://github.com/JuliaDiff/SparseDiffTools.jl) | `AutoSparseForwardDiff()`, `AutoSparseFiniteDiff()` |
+| [Tracker.jl](https://github.com/FluxML/Tracker.jl) | `AutoTracker()` |
+| [Zygote.jl](https://github.com/FluxML/Zygote.jl) | `AutoZygote()` |

 We also provide some experimental backends ourselves:
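For context on the table change above, here is a minimal sketch of how one of these backend objects is passed to the package's operators. The function `f`, the input `x`, and the choice of ForwardDiff.jl are illustrative assumptions, and the constructor is assumed to be in scope via DifferentiationInterface (it comes from ADTypes.jl):

```julia
using DifferentiationInterface
import ForwardDiff  # the backend package must be loaded alongside its AutoForwardDiff object

f(x) = sum(abs2, x)          # toy function, not part of this commit
backend = AutoForwardDiff()  # backend object from the table above
x = [1.0, 2.0, 3.0]

gradient(f, backend, x)      # == [2.0, 4.0, 6.0]
```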

DifferentiationInterface/docs/src/api.md

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ value_and_pullback!_split

 ```@docs
 check_available
-check_mutation
+check_twoarg
 check_hessian
 ```

DifferentiationInterface/docs/src/backends.md

Lines changed: 2 additions & 2 deletions
@@ -93,13 +93,13 @@ Markdown.parse(join(vcat(header, subheader, rows...), "\n")) # hide

 All backends are compatible with one-argument functions `f(x) = y`.
 Only some are compatible with two-argument functions `f!(y, x) = nothing`.
-You can use [`check_mutation`](@ref) to check that feature, like we did below:
+You can use [`check_twoarg`](@ref) to check that feature, like we did below:

 ```@example backends
 header = "| backend | mutation |" # hide
 subheader = "|:---|:---:|" # hide
 rows = map(all_backends()) do backend # hide
-    "| `$(backend_string(backend))` | $(check_mutation(backend) ? '✅' : '❌') |" # hide
+    "| `$(backend_string(backend))` | $(check_twoarg(backend) ? '✅' : '❌') |" # hide
 end # hide
 Markdown.parse(join(vcat(header, subheader, rows...), "\n")) # hide
 ```
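A minimal sketch of the renamed query in isolation, assuming `check_twoarg` and `check_available` are exported as the docs above suggest; Diffractor.jl is picked only because this commit marks it `MutationNotSupported`:

```julia
using DifferentiationInterface
import Diffractor

backend = AutoDiffractor()
check_available(backend)  # true once Diffractor.jl is loaded
check_twoarg(backend)     # false: two-argument functions f!(y, x) are not supported
```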

DifferentiationInterface/docs/src/overview.md

Lines changed: 19 additions & 9 deletions
@@ -47,13 +47,16 @@ Several variants of each operator are defined:
 In order to ensure symmetry between one-argument functions `f(x) = y` and two-argument functions `f!(y, x) = nothing`, we define the same operators for both cases.
 However they have different signatures:

-| signature  | out-of-place                       | in-place                                  |
-| :--------- | :--------------------------------- | :---------------------------------------- |
-| `f(x)`     | `operator(f, backend, x, ...)`     | `operator!(f, res, backend, x, ...)`      |
-| `f!(y, x)` | `operator(f!, y, backend, x, ...)` | `operator!(f!, y, res, backend, x, ...)`  |
+| signature  | out-of-place                       | in-place                                     |
+| :--------- | :--------------------------------- | :------------------------------------------- |
+| `f(x)`     | `operator(f, backend, x, ...)`     | `operator!(f, result, backend, x, ...)`      |
+| `f!(y, x)` | `operator(f!, y, backend, x, ...)` | `operator!(f!, y, result, backend, x, ...)`  |

 !!! warning
-    Every variant of the operator will mutate `y` when applied to a two-argument function `f!(y, x) = nothing`, even if it does not have a `!` in its name.
+    Our mutation convention is that all positional arguments between `f`/`f!` and `backend` are mutated (the `extras` as well, see below).
+    This convention holds regardless of the bang `!` in the operator name, because we assume that a user passing a two-argument function `f!(y, x)` anticipates mutation anyway.
+
+    Still, better be careful with two-argument functions, because every variant of the operator will mutate `y`... even if it does not have a `!` in its name (see the bottom left cell in the table).

 ## Preparation

@@ -71,14 +74,21 @@ This is a backend-specific procedure, but we expose a common syntax to achieve i
 | `pullback` | [`prepare_pullback`](@ref) |
 | `hvp`      | [`prepare_hvp`](@ref)      |

-If you run `prepare_operator(backend, f, x, [seed])`, it will create an object called `extras` containing the necessary information to speed up `operator` and its variants.
+Unsurprisingly, preparation syntax depends on the number of arguments:
+
+| signature  | preparation signature                       |
+| :--------- | :------------------------------------------ |
+| `f(x)`     | `prepare_operator(f, backend, x, ...)`      |
+| `f!(y, x)` | `prepare_operator(f!, y, backend, x, ...)`  |
+
+The preparation `prepare_operator(f, backend, x)` will create an object called `extras` containing the necessary information to speed up `operator` and its variants.
 This information is specific to `backend` and `f`, as well as the _type and size_ of the input `x` and the _control flow_ within the function, but it should work with different _values_ of `x`.

-You can then call `operator(backend, f, x2, extras)`, which should be faster than `operator(f, backend, x2)`.
+You can then call e.g. `operator(f, backend, x2, extras)`, which should be faster than `operator(f, backend, x2)`.
 This is especially worth it if you plan to call `operator` several times in similar settings: you can think of it as a warm up.

 !!! warning
-    The `extras` object is nearly always mutated, even if the operator does not have a `!` in its name.
+    The `extras` object is nearly always mutated when given to an operator, even when said operator does not have a bang `!` in its name.

 ### Second order

@@ -129,7 +139,7 @@ We make this available for all backends with the following operators:
 ### Non-standard types

 The package is thoroughly tested with inputs and outputs of the following types: `Float64`, `Vector{Float64}` and `Matrix{Float64}`.
-We also expect it to work on all kinds of `Number` and `AbstractArray` variables.
+We also expect it to work on most kinds of `Number` and `AbstractArray` variables.
 Beyond that, you are in uncharted territory.
 We voluntarily keep the type annotations minimal, so that passing more complex objects or custom structs _might work with some backends_, but we make no guarantees about that.

DifferentiationInterface/docs/src/tutorial.md

Lines changed: 8 additions & 3 deletions
@@ -63,10 +63,12 @@ Some backends get a speed boost from this trick.

 ```@repl tuto
 grad = zero(x)
-grad = gradient!(f, grad, backend, x)
+gradient!(f, grad, backend, x);
+grad
 ```

-Note the double exclamation mark, which is a convention telling you that `grad` _may or may not_ be overwritten, but will be returned either way (see [this section](@ref Variants) for more details).
+The bang indicates that one of the arguments of `gradient!` might be mutated.
+More precisely, our convention is that _every positional argument between the function and the backend is mutated (and the `extras` too, see below)_.

 ```@repl tuto
 @btime gradient!($f, _grad, $backend, $x) evals=1 setup=(_grad=similar($x));
@@ -90,7 +92,8 @@ You don't need to know what this object is, you just need to pass it to the grad

 ```@repl tuto
 grad = zero(x);
-grad = gradient!(f, grad, backend, x, extras)
+gradient!(f, grad, backend, x, extras);
+grad
 ```

 Preparation makes the gradient computation much faster, and (in this case) allocation-free.
@@ -102,6 +105,8 @@ Preparation makes the gradient computation much faster, and (in this case) alloc
 );
 ```

+Beware that the `extras` object is nearly always mutated by differentiation operators, even though it is given as the last positional argument.
+
 ## Switching backends

 The whole point of DifferentiationInterface.jl is that you can easily experiment with different AD solutions.
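As a sketch of that last sentence, switching AD solutions only means swapping the backend object; the ReverseDiff.jl comparison is an illustrative assumption, not something this commit touches:

```julia
using DifferentiationInterface
import ForwardDiff, ReverseDiff

f(x) = sum(abs2, x)
x = rand(10)

grad_forward = gradient(f, AutoForwardDiff(), x)
grad_reverse = gradient(f, AutoReverseDiff(), x)
grad_forward ≈ grad_reverse  # same gradient, different backend
```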

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ const AutoForwardChainRules = AutoChainRules{<:RuleConfig{>:HasForwardsMode}}
 const AutoReverseChainRules = AutoChainRules{<:RuleConfig{>:HasReverseMode}}

 DI.check_available(::AutoChainRules) = true
-DI.supports_mutation(::AutoChainRules) = DI.MutationNotSupported()
+DI.mutation_support(::AutoChainRules) = DI.MutationNotSupported()
 DI.mode(::AutoForwardChainRules) = ADTypes.AbstractForwardMode
 DI.mode(::AutoReverseChainRules) = ADTypes.AbstractReverseMode

DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ using DifferentiationInterface: NoPushforwardExtras
 using Diffractor: DiffractorRuleConfig, TaylorTangentIndex, ZeroBundle, bundle, ∂☆

 DI.check_available(::AutoDiffractor) = true
-DI.supports_mutation(::AutoDiffractor) = DI.MutationNotSupported()
+DI.mutation_support(::AutoDiffractor) = DI.MutationNotSupported()
 DI.mode(::AutoDiffractor) = ADTypes.AbstractForwardMode
 DI.mode(::AutoChainRules{<:DiffractorRuleConfig}) = ADTypes.AbstractForwardMode

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ const AnyAutoFastDifferentiation = Union{
     AutoFastDifferentiation,AutoSparseFastDifferentiation
 }

-DI.check_available(::AutoFastDifferentiation) = true
+DI.check_available(::AnyAutoFastDifferentiation) = true
 DI.mode(::AnyAutoFastDifferentiation) = ADTypes.AbstractSymbolicDifferentiationMode
 DI.pushforward_performance(::AnyAutoFastDifferentiation) = DI.PushforwardFast()
 DI.pullback_performance(::AnyAutoFastDifferentiation) = DI.PullbackSlow()

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl

Lines changed: 58 additions & 0 deletions
@@ -136,6 +136,64 @@ function DI.value_and_derivative!(
     return f(x), DI.derivative!(f, der, backend, x, extras)
 end

+## Gradient
+
+struct FastDifferentiationOneArgGradientExtras{E1,E2} <: GradientExtras
+    jac_exe::E1
+    jac_exe!::E2
+end
+
+function DI.prepare_gradient(f, backend::AnyAutoFastDifferentiation, x)
+    y_prototype = f(x)
+    x_var = make_variables(:x, size(x)...)
+    y_var = f(x_var)
+
+    x_vec_var = vec(x_var)
+    y_vec_var = monovec(y_var)
+    jac_var = jacobian(y_vec_var, x_vec_var)
+    jac_exe = make_function(jac_var, x_vec_var; in_place=false)
+    jac_exe! = make_function(jac_var, x_vec_var; in_place=true)
+    return FastDifferentiationOneArgGradientExtras(jac_exe, jac_exe!)
+end
+
+function DI.gradient(
+    f, ::AnyAutoFastDifferentiation, x, extras::FastDifferentiationOneArgGradientExtras
+)
+    jac = extras.jac_exe(vec(x))
+    grad_vec = @view jac[1, :]
+    return reshape(grad_vec, size(x))
+end
+
+function DI.gradient!(
+    f,
+    grad,
+    ::AnyAutoFastDifferentiation,
+    x,
+    extras::FastDifferentiationOneArgGradientExtras,
+)
+    extras.jac_exe!(reshape(grad, 1, length(grad)), vec(x))
+    return grad
+end
+
+function DI.value_and_gradient(
+    f,
+    backend::AnyAutoFastDifferentiation,
+    x,
+    extras::FastDifferentiationOneArgGradientExtras,
+)
+    return f(x), DI.gradient(f, backend, x, extras)
+end
+
+function DI.value_and_gradient!(
+    f,
+    grad,
+    backend::AnyAutoFastDifferentiation,
+    x,
+    extras::FastDifferentiationOneArgGradientExtras,
+)
+    return f(x), DI.gradient!(f, grad, backend, x, extras)
+end
+
 ## Jacobian

 struct FastDifferentiationOneArgJacobianExtras{Y,E1,E2} <: JacobianExtras
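A minimal sketch of what this new extension code enables, assuming `AutoFastDifferentiation` has a zero-argument constructor and that loading FastDifferentiation.jl activates the extension; the function and input are illustrative:

```julia
using DifferentiationInterface
import FastDifferentiation

f(x) = sum(abs2, x)
backend = AutoFastDifferentiation()
x = rand(5)

# prepare_gradient builds the symbolic Jacobian and compiles the executables (jac_exe, jac_exe!)
extras = prepare_gradient(f, backend, x)
grad = gradient(f, backend, x, extras)

# In-place variant, reusing the compiled executable
grad2 = similar(x)
gradient!(f, grad2, backend, x, extras)
```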

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ using FiniteDifferences: FiniteDifferences, grad, jacobian, jvp, j′vp
 using LinearAlgebra: dot

 DI.check_available(::AutoFiniteDifferences) = true
-DI.supports_mutation(::AutoFiniteDifferences) = DI.MutationNotSupported()
+DI.mutation_support(::AutoFiniteDifferences) = DI.MutationNotSupported()

 function FiniteDifferences.to_vec(a::OneElement) # TODO: remove type piracy (https://github.com/JuliaDiff/FiniteDifferences.jl/issues/141)
     return FiniteDifferences.to_vec(collect(a))
