Skip to content

Commit 9ec4e7e

Browse files
authored
Contexts for ReverseDiff (#505)
1 parent 959f634 commit 9ec4e7e

6 files changed

Lines changed: 332 additions & 71 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.0"
4+
version = "0.6.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,25 @@ In practice, many AD backends have custom implementations for high-level operato
6262
| `AutoTracker` | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
6363
| `AutoZygote` | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | 🔀 | ❌ |
6464

65+
Moreover, each context type is supported by a specific subset of backends:
66+
67+
| | [`Constant`](@ref) |
68+
| -------------------------- | ------------------ |
69+
| `AutoChainRules` ||
70+
| `AutoDiffractor` ||
71+
| `AutoEnzyme` (forward) ||
72+
| `AutoEnzyme` (reverse) ||
73+
| `AutoFastDifferentiation` ||
74+
| `AutoFiniteDiff` ||
75+
| `AutoFiniteDifferences` ||
76+
| `AutoForwardDiff` ||
77+
| `AutoMooncake` ||
78+
| `AutoPolyesterForwardDiff` ||
79+
| `AutoReverseDiff` ||
80+
| `AutoSymbolics` ||
81+
| `AutoTracker` ||
82+
| `AutoZygote` ||
83+
6584
## Second order
6685

6786
For second-order operators like [`second_derivative`](@ref), [`hessian`](@ref) and [`hvp`](@ref), there are two main options.
@@ -81,9 +100,9 @@ In general, using a forward outer backend over a reverse inner backend will yiel
81100
## Backend switch
82101

83102
The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends.
84-
It takes a function `f` and specifies that `f` should be differentiated with the backend of your choice, instead of whatever other backend the code is trying to use.
85-
In other words, when someone tries to differentiate `dw = DifferentiateWith(f, backend1)` with `backend2`, then `backend1` steps in and `backend2` does nothing.
86-
At the moment, `DifferentiateWith` only works when `backend2` supports [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl).
103+
It takes a function `f` and specifies that `f` should be differentiated with the substitute backend of your choice, instead of whatever true backend the surrounding code is trying to use.
104+
In other words, when someone tries to differentiate `dw = DifferentiateWith(f, substitute_backend)` with `true_backend`, then `substitute_backend` steps in and `true_backend` does not dive into the function `f` itself.
105+
At the moment, `DifferentiateWith` only works when `true_backend` is either [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) or a [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible backend.
87106

88107
## Implementations
89108

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
11
module DifferentiationInterfaceReverseDiffExt
22

33
using ADTypes: AutoReverseDiff
4+
using Base: Fix2
45
import DifferentiationInterface as DI
56
using DifferentiationInterface:
6-
DerivativePrep, GradientPrep, HessianPrep, JacobianPrep, NoPullbackPrep
7+
Context,
8+
DerivativePrep,
9+
GradientPrep,
10+
HessianPrep,
11+
JacobianPrep,
12+
NoGradientPrep,
13+
NoHessianPrep,
14+
NoJacobianPrep,
15+
NoPullbackPrep,
16+
unwrap,
17+
with_contexts
718
using ReverseDiff.DiffResults: DiffResults, DiffResult, GradientResult, MutableDiffResult
819
using LinearAlgebra: dot, mul!
920
using ReverseDiff:

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl

Lines changed: 167 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,74 @@
11
## Pullback
22

3-
DI.prepare_pullback(f, ::AutoReverseDiff, x, ty::NTuple) = NoPullbackPrep()
3+
function DI.prepare_pullback(
4+
f, ::AutoReverseDiff, x, ty::NTuple, contexts::Vararg{Context,C}
5+
) where {C}
6+
return NoPullbackPrep()
7+
end
48

59
function DI.value_and_pullback(
6-
f, ::NoPullbackPrep, ::AutoReverseDiff, x::AbstractArray, ty::NTuple
7-
)
8-
y = f(x)
10+
f,
11+
::NoPullbackPrep,
12+
::AutoReverseDiff,
13+
x::AbstractArray,
14+
ty::NTuple,
15+
contexts::Vararg{Context,C},
16+
) where {C}
17+
fc = with_contexts(f, contexts...)
18+
y = fc(x)
19+
dotclosure(z, dy) = dot(fc(z), dy)
920
tx = map(ty) do dy
1021
if y isa Number
11-
dy .* gradient(f, x)
22+
dy .* gradient(fc, x)
1223
elseif y isa AbstractArray
13-
gradient(z -> dot(f(z), dy), x)
24+
gradient(Fix2(dotclosure, dy), x)
1425
end
1526
end
1627
return y, tx
1728
end
1829

1930
function DI.value_and_pullback!(
20-
f, ::NoPullbackPrep, tx::NTuple, ::AutoReverseDiff, x::AbstractArray, ty::NTuple
21-
)
22-
y = f(x)
31+
f,
32+
::NoPullbackPrep,
33+
tx::NTuple,
34+
::AutoReverseDiff,
35+
x::AbstractArray,
36+
ty::NTuple,
37+
contexts::Vararg{Context,C},
38+
) where {C}
39+
fc = with_contexts(f, contexts...)
40+
y = fc(x)
41+
dotclosure(z, dy) = dot(fc(z), dy)
2342
for b in eachindex(tx, ty)
2443
dx, dy = tx[b], ty[b]
2544
if y isa Number
26-
dx = gradient!(dx, f, x)
45+
dx = gradient!(dx, fc, x)
2746
dx .*= dy
2847
elseif y isa AbstractArray
29-
gradient!(dx, z -> dot(f(z), dy), x)
48+
gradient!(dx, Fix2(dotclosure, dy), x)
3049
end
3150
end
3251
return y, tx
3352
end
3453

3554
function DI.value_and_pullback(
36-
f, ::NoPullbackPrep, backend::AutoReverseDiff, x::Number, ty::NTuple
37-
)
55+
f,
56+
::NoPullbackPrep,
57+
backend::AutoReverseDiff,
58+
x::Number,
59+
ty::NTuple,
60+
contexts::Vararg{Context,C},
61+
) where {C}
3862
x_array = [x]
39-
f_array = f ∘ only
40-
y, tx_array = DI.value_and_pullback(f_array, backend, x_array, ty)
63+
f_array(x_array, args...) = f(only(x_array), args...)
64+
y, tx_array = DI.value_and_pullback(f_array, backend, x_array, ty, contexts...)
4165
return y, only.(tx_array)
4266
end
4367

4468
## Gradient
4569

70+
### Without contexts
71+
4672
struct ReverseDiffGradientPrep{T} <: GradientPrep
4773
tape::T
4874
end
@@ -56,7 +82,7 @@ function DI.prepare_gradient(f, ::AutoReverseDiff{Compile}, x) where {Compile}
5682
end
5783

5884
function DI.value_and_gradient!(
59-
f, grad::AbstractArray, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x
85+
f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x
6086
)
6187
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
6288
result = MutableDiffResult(y, (grad,))
@@ -71,18 +97,55 @@ function DI.value_and_gradient(
7197
return DI.value_and_gradient!(f, grad, prep, backend, x)
7298
end
7399

74-
function DI.gradient!(
75-
_f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x::AbstractArray
76-
)
100+
function DI.gradient!(_f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x)
77101
return gradient!(grad, prep.tape, x)
78102
end
79103

80104
function DI.gradient(_f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x)
81105
return gradient!(prep.tape, x)
82106
end
83107

108+
### With contexts
109+
110+
function DI.prepare_gradient(f, ::AutoReverseDiff, x, contexts::Vararg{Context,C}) where {C}
111+
return NoGradientPrep()
112+
end
113+
114+
function DI.value_and_gradient!(
115+
f, grad, ::NoGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
116+
) where {C}
117+
fc = with_contexts(f, contexts...)
118+
y = fc(x) # TODO: remove once ReverseDiff#251 is fixed
119+
result = MutableDiffResult(y, (grad,))
120+
result = gradient!(result, fc, x)
121+
return DiffResults.value(result), DiffResults.derivative(result)
122+
end
123+
124+
function DI.value_and_gradient(
125+
f, prep::NoGradientPrep, backend::AutoReverseDiff, x, contexts::Vararg{Context,C}
126+
) where {C}
127+
grad = similar(x)
128+
return DI.value_and_gradient!(f, grad, prep, backend, x, contexts...)
129+
end
130+
131+
function DI.gradient!(
132+
f, grad, ::NoGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
133+
) where {C}
134+
fc = with_contexts(f, contexts...)
135+
return gradient!(grad, fc, x)
136+
end
137+
138+
function DI.gradient(
139+
f, ::NoGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
140+
) where {C}
141+
fc = with_contexts(f, contexts...)
142+
return gradient(fc, x)
143+
end
144+
84145
## Jacobian
85146

147+
### Without contexts
148+
86149
struct ReverseDiffOneArgJacobianPrep{T} <: JacobianPrep
87150
tape::T
88151
end
@@ -116,8 +179,47 @@ function DI.jacobian(f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff,
116179
return jacobian!(prep.tape, x)
117180
end
118181

182+
### With contexts
183+
184+
function DI.prepare_jacobian(f, ::AutoReverseDiff, x, contexts::Vararg{Context,C}) where {C}
185+
return NoJacobianPrep()
186+
end
187+
188+
function DI.value_and_jacobian!(
189+
f, jac, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
190+
) where {C}
191+
fc = with_contexts(f, contexts...)
192+
y = fc(x)
193+
result = MutableDiffResult(y, (jac,))
194+
result = jacobian!(result, fc, x)
195+
return DiffResults.value(result), DiffResults.derivative(result)
196+
end
197+
198+
function DI.value_and_jacobian(
199+
f, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
200+
) where {C}
201+
fc = with_contexts(f, contexts...)
202+
return fc(x), jacobian(fc, x)
203+
end
204+
205+
function DI.jacobian!(
206+
f, jac, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
207+
) where {C}
208+
fc = with_contexts(f, contexts...)
209+
return jacobian!(jac, fc, x)
210+
end
211+
212+
function DI.jacobian(
213+
f, ::NoJacobianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
214+
) where {C}
215+
fc = with_contexts(f, contexts...)
216+
return jacobian(fc, x)
217+
end
218+
119219
## Hessian
120220

221+
### Without contexts
222+
121223
struct ReverseDiffHessianPrep{T} <: HessianPrep
122224
tape::T
123225
end
@@ -152,11 +254,54 @@ end
152254
function DI.value_gradient_and_hessian(
153255
f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x
154256
)
155-
result = MutableDiffResult(
156-
one(eltype(x)), (similar(x), similar(x, length(x), length(x)))
157-
)
257+
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
258+
result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x))))
158259
result = hessian!(result, prep.tape, x)
159260
return (
160261
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
161262
)
162263
end
264+
265+
### With contexts
266+
267+
function DI.prepare_hessian(f, ::AutoReverseDiff, x, contexts::Vararg{Context,C}) where {C}
268+
return NoHessianPrep()
269+
end
270+
271+
function DI.hessian!(
272+
f, hess, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
273+
) where {C}
274+
fc = with_contexts(f, contexts...)
275+
return hessian!(hess, fc, x)
276+
end
277+
278+
function DI.hessian(
279+
f, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
280+
) where {C}
281+
fc = with_contexts(f, contexts...)
282+
return hessian(fc, x)
283+
end
284+
285+
function DI.value_gradient_and_hessian!(
286+
f, grad, hess, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
287+
) where {C}
288+
fc = with_contexts(f, contexts...)
289+
y = fc(x) # TODO: remove once ReverseDiff#251 is fixed
290+
result = MutableDiffResult(y, (grad, hess))
291+
result = hessian!(result, fc, x)
292+
return (
293+
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
294+
)
295+
end
296+
297+
function DI.value_gradient_and_hessian(
298+
f, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
299+
) where {C}
300+
fc = with_contexts(f, contexts...)
301+
y = fc(x) # TODO: remove once ReverseDiff#251 is fixed
302+
result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x))))
303+
result = hessian!(result, fc, x)
304+
return (
305+
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
306+
)
307+
end

0 commit comments

Comments
 (0)