Skip to content

Commit fc4d697

Browse files
authored
Add split reverse mode (#137)
1 parent 7d1fbc1 commit fc4d697

11 files changed

Lines changed: 163 additions & 61 deletions

File tree

DifferentiationInterface/docs/src/overview.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,15 @@ This means the Hessian is obtained as the sparse Jacobian of the gradient.
120120

121121
!!! danger
122122
Sparsity support is still experimental, use at your own risk.
123+
124+
### Split reverse mode
125+
126+
Many reverse mode AD backends expose a "split" option, which runs only the forward sweep, and encapsulates the reverse sweep in a closure.
127+
We make this available for every backend, falling back to a closure around `pullback` when the backend has no native split mode, through the following operators:
128+
129+
| out-of-place | in-place (or not) |
130+
| ---------------------------------- | ------------------------------------ |
131+
| [`value_and_pullback_split`](@ref) | [`value_and_pullback!!_split`](@ref) |
132+
133+
!!! danger
134+
Split reverse mode is still experimental, use at your own risk.

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,6 @@ DI.supports_mutation(::AutoChainRules) = DI.MutationNotSupported()
1515
DI.mode(::AutoForwardChainRules) = ADTypes.AbstractForwardMode
1616
DI.mode(::AutoReverseChainRules) = ADTypes.AbstractReverseMode
1717

18-
## Pushforward (unused)
19-
20-
#=
21-
DI.prepare_pushforward(f, ::AutoForwardChainRules, x) = NoPushforwardExtras()
22-
23-
function DI.value_and_pushforward(
24-
f, backend::AutoForwardChainRules, x, dx, ::NoPushforwardExtras
25-
)
26-
rc = ruleconfig(backend)
27-
y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
28-
return y, new_dy
29-
end
30-
=#
31-
3218
## Pullback
3319

3420
DI.prepare_pullback(f, ::AutoForwardChainRules, x) = NoPullbackExtras()
@@ -40,4 +26,22 @@ function DI.value_and_pullback(f, backend::AutoReverseChainRules, x, dy, ::NoPul
4026
return y, new_dx
4127
end
4228

29+
function DI.value_and_pullback_split(
30+
f, backend::AutoReverseChainRules, x, ::NoPullbackExtras
31+
)
32+
rc = ruleconfig(backend)
33+
y, pullback = rrule_via_ad(rc, f, x)
34+
pullbackfunc(dy) = last(pullback(dy))
35+
return y, pullbackfunc
36+
end
37+
38+
function DI.value_and_pullback!!_split(
39+
f, backend::AutoReverseChainRules, x, ::NoPullbackExtras
40+
)
41+
rc = ruleconfig(backend)
42+
y, pullback = rrule_via_ad(rc, f, x)
43+
pullbackfunc!!(_dx, dy) = last(pullback(dy))
44+
return y, pullbackfunc!!
45+
end
46+
4347
end

DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@ function DI.value_and_pullback(f, ::AutoTracker, x, dy, ::NoPullbackExtras)
1616
return y, data(only(back(dy)))
1717
end
1818

19+
function DI.value_and_pullback_split(f, ::AutoTracker, x, ::NoPullbackExtras)
20+
y, back = forward(f, x)
21+
pullbackfunc(dy) = data(only(back(dy)))
22+
return y, pullbackfunc
23+
end
24+
25+
function DI.value_and_pullback!!_split(f, ::AutoTracker, x, ::NoPullbackExtras)
26+
y, back = forward(f, x)
27+
pullbackfunc!!(_dx, dy) = data(only(back(dy)))
28+
return y, pullbackfunc!!
29+
end
30+
1931
## Gradient
2032

2133
DI.prepare_gradient(f, ::AutoTracker, x) = NoGradientExtras()

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,18 @@ function DI.value_and_pullback(f, ::AnyAutoZygote, x, dy, ::NoPullbackExtras)
2222
return y, dx
2323
end
2424

25+
function DI.value_and_pullback_split(f, ::AnyAutoZygote, x, ::NoPullbackExtras)
26+
y, back = pullback(f, x)
27+
pullbackfunc(dy) = only(back(dy))
28+
return y, pullbackfunc
29+
end
30+
31+
function DI.value_and_pullback!!_split(f, ::AnyAutoZygote, x, ::NoPullbackExtras)
32+
y, back = pullback(f, x)
33+
pullbackfunc!!(_dx, dy) = only(back(dy))
34+
return y, pullbackfunc!!
35+
end
36+
2537
## Gradient
2638

2739
DI.prepare_gradient(f, ::AnyAutoZygote, x) = NoGradientExtras()

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ export SecondOrder
107107

108108
export value_and_pushforward!!, value_and_pushforward
109109
export value_and_pullback!!, value_and_pullback
110+
export value_and_pullback!!_split, value_and_pullback_split
110111

111112
export value_and_derivative!!, value_and_derivative
112113
export value_and_gradient!!, value_and_gradient
@@ -130,34 +131,18 @@ export prepare_second_derivative, prepare_hvp, prepare_hessian
130131
export check_available, check_mutation, check_hessian
131132

132133
function __init__()
133-
Base.Experimental.register_error_hint(StackOverflowError) do io, exc, argtypes, kwargs
134-
f_name = string(exc.f)
135-
if (
136-
f_name == "mode" ||
137-
contains(f_name, "pushforward") ||
138-
contains(f_name, "pullback") ||
139-
contains(f_name, "derivative") ||
140-
contains(f_name, "gradient") ||
141-
contains(f_name, "jacobian") ||
142-
contains(f_name, "hvp") ||
143-
contains(f_name, "hessian")
134+
Base.Experimental.register_error_hint(StackOverflowError) do io, exc
135+
print(
136+
io,
137+
"""\n
138+
HINT: One of DifferentiationInterface's functions might be missing a method, which would trigger an endless loop of `pullback` calling `pushforward` and vice-versa.
139+
Some possible fixes:
140+
- switch to another backend
141+
- if you don't want to switch, load the package extension corresponding to your backend
142+
- if your backend is already loaded, define the primitive operator for the right combination of argument types
143+
""",
144144
)
145-
for T in argtypes
146-
if T <: AbstractADType
147-
print(
148-
io,
149-
"""\n
150-
HINT: One of DifferentiationInterface's functions is missing a method, which causes an endless loop of `pullback` calling `pushforward` and vice-versa.
151-
Some possible fixes:
152-
- switch to another backend
153-
- if you don't want to switch, load the package extension corresponding to backend `$T`
154-
- if the package is already loaded, define the method `$f_name` for the right combination of argument types
155-
""",
156-
)
157-
return nothing
158-
end
159-
end
160-
end
145+
return nothing
161146
end
162147
end
163148

DifferentiationInterface/src/jacobian.jl

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ function value_and_jacobian_aux(
6464
y = f(x)
6565
jac = stack(CartesianIndices(x); dims=2) do j
6666
dx_j = basis(backend, x, j)
67-
jac_col_j = last(
68-
value_and_pushforward(f, backend, x, dx_j, extras.pushforward_extras)
69-
)
67+
jac_col_j = pushforward(f, backend, x, dx_j, extras.pushforward_extras)
7068
vec(jac_col_j)
7169
end
7270
return y, jac
@@ -75,10 +73,10 @@ end
7573
function value_and_jacobian_aux(
7674
f, backend, x::AbstractArray, extras::PullbackJacobianExtras
7775
)
78-
y = f(x)
76+
y, pullbackfunc = value_and_pullback_split(f, backend, x, extras.pullback_extras)
7977
jac = stack(CartesianIndices(y); dims=1) do i
8078
dy_i = basis(backend, y, i)
81-
jac_row_i = last(value_and_pullback(f, backend, x, dy_i, extras.pullback_extras))
79+
jac_row_i = pullbackfunc(dy_i)
8280
vec(jac_row_i)
8381
end
8482
return y, jac
@@ -116,13 +114,11 @@ end
116114
function value_and_jacobian_aux!!(
117115
f, jac::AbstractMatrix, backend, x::AbstractArray, extras::PullbackJacobianExtras
118116
)
119-
y = f(x)
117+
y, pullbackfunc!! = value_and_pullback!!_split(f, backend, x, extras.pullback_extras)
120118
for (k, i) in enumerate(CartesianIndices(y))
121119
dy_i = basis(backend, y, i)
122120
jac_row_i_old = reshape(view(jac, k, :), size(x))
123-
jac_row_i_new = pullback!!(
124-
f, jac_row_i_old, backend, x, dy_i, extras.pullback_extras
125-
)
121+
jac_row_i_new = pullbackfunc!!(jac_row_i_old, dy_i)
126122
# this allocates
127123
copyto!(jac_row_i_old, jac_row_i_new)
128124
end
@@ -175,7 +171,6 @@ function value_and_jacobian_aux!!(
175171
x::AbstractArray,
176172
extras::PushforwardJacobianExtras,
177173
)
178-
f!(y, x)
179174
for (k, j) in enumerate(CartesianIndices(x))
180175
dx_j = basis(backend, x, j)
181176
jac_col_j_old = reshape(view(jac, :, k), size(y))
@@ -198,15 +193,13 @@ function value_and_jacobian_aux!!(
198193
x::AbstractArray,
199194
extras::PullbackJacobianExtras,
200195
)
201-
f!(y, x)
196+
y, pullbackfunc!! = value_and_pullback!!_split(
197+
f!, y, backend, x, extras.pullback_extras
198+
)
202199
for (k, i) in enumerate(CartesianIndices(y))
203200
dy_i = basis(backend, y, i)
204201
jac_row_i_old = reshape(view(jac, k, :), size(x))
205-
jac_row_i_new = last(
206-
value_and_pullback!!(
207-
f!, y, jac_row_i_old, backend, x, dy_i, extras.pullback_extras
208-
),
209-
)
202+
jac_row_i_new = pullbackfunc!!(y, jac_row_i_old, dy_i)
210203
# this allocates
211204
copyto!(jac_row_i_old, jac_row_i_new)
212205
end

DifferentiationInterface/src/pullback.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,60 @@ function value_and_pullback!!(
124124
end
125125
return y, dx
126126
end
127+
128+
## Closure
129+
130+
"""
131+
value_and_pullback_split(f, backend, x, [extras])
132+
133+
Apply split reverse mode autodiff.
134+
135+
Returns a tuple `(y, pullbackfunc)` where the second element is a function (closure) with the following signature:
136+
137+
pullbackfunc(dy) -> dx
138+
"""
139+
function value_and_pullback_split(
140+
f, backend::AbstractADType, x, extras::PullbackExtras=prepare_pullback(f, backend, x)
141+
)
142+
pullbackfunc(dy) = pullback(f, backend, x, dy, extras)
143+
return f(x), pullbackfunc
144+
end
145+
146+
"""
147+
value_and_pullback!!_split(f, backend, x, [extras])
148+
149+
Apply split reverse mode autodiff.
150+
151+
Returns a tuple `(y, pullbackfunc!!)` where the second element is a function (closure) with the following signature:
152+
153+
pullbackfunc!!(dx, dy) -> dx
154+
"""
155+
function value_and_pullback!!_split(
156+
f, backend::AbstractADType, x, extras::PullbackExtras=prepare_pullback(f, backend, x)
157+
)
158+
pullbackfunc!!(dx, dy) = pullback!!(f, dx, backend, x, dy, extras)
159+
return f(x), pullbackfunc!!
160+
end
161+
162+
"""
163+
value_and_pullback!!_split(f!, y, backend, x, [extras])
164+
165+
Apply split reverse mode autodiff.
166+
167+
Returns a tuple `(y, pullbackfunc!!)` where the second element is a function (closure) with the following signature:
168+
169+
pullbackfunc!!(y, dx, dy) -> dx
170+
"""
171+
function value_and_pullback!!_split(
172+
f!,
173+
y,
174+
backend::AbstractADType,
175+
x,
176+
extras::PullbackExtras=prepare_pullback(f!, backend, y, x),
177+
)
178+
function pullbackfunc!!(y, dx, dy)
179+
return value_and_pullback!!(f!, y, dx, backend, x, dy, extras)[2]
180+
end
181+
f!(y, x)
182+
return y, pullbackfunc!!
183+
end

DifferentiationInterfaceTest/src/tests/benchmark.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ end
3737
function record!(
3838
data::Vector{BenchmarkDataRow},
3939
backend::AbstractADType,
40-
operator::Function,
40+
operator,
4141
scenario::AbstractScenario,
4242
bench,
4343
)
@@ -98,10 +98,13 @@ function run_benchmark!(
9898
)
9999
(; f, x, y, dy) = deepcopy(scen)
100100
extras = prepare_pullback(f, ba, x)
101+
_, pullbackfunc!! = value_and_pullback!!_split(f, ba, x, extras)
101102
bench1 = @be mysimilar(x) value_and_pullback!!(f, _, ba, x, dy, extras)
102103
bench2 = @be mysimilar(x) pullback!!(f, _, ba, x, dy, extras)
104+
bench3 = @be mysimilar(x) pullbackfunc!!(_, dy)
103105
record!(data, ba, value_and_pullback!!, scen, bench1)
104106
record!(data, ba, pullback!!, scen, bench2)
107+
record!(data, ba, pullbackfunc!!, scen, bench3)
105108
return nothing
106109
end
107110

@@ -111,10 +114,13 @@ function run_benchmark!(
111114
(; f, x, y, dy) = deepcopy(scen)
112115
f! = f
113116
extras = prepare_pullback(f!, ba, y, x)
117+
_, pullbackfunc!! = value_and_pullback!!_split(f!, y, ba, x, extras)
114118
bench1 = @be (mysimilar(y), mysimilar(x)) value_and_pullback!!(
115119
f!, _[1], _[2], ba, x, dy, extras
116120
)
121+
bench2 = @be (mysimilar(y), mysimilar(x)) pullbackfunc!!(_[1], _[2], dy)
117122
record!(data, ba, value_and_pullback!!, scen, bench1)
123+
record!(data, ba, pullbackfunc!!, scen, bench2)
118124
return nothing
119125
end
120126

DifferentiationInterfaceTest/src/tests/correctness.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,16 +109,25 @@ function test_correctness(
109109
dx3 = pullback(f, ba, x, dy, extras)
110110
dx4 = pullback!!(f, mysimilar(x), ba, x, dy, extras)
111111

112+
y5, pullbackfunc = value_and_pullback_split(f, ba, x, extras)
113+
dx5 = pullbackfunc(dy)
114+
y6, pullbackfunc!! = value_and_pullback!!_split(f, ba, x, extras)
115+
dx6 = pullbackfunc!!(mysimilar(x), dy)
116+
112117
let (≈)(x, y) = isapprox(x, y; atol, rtol)
113118
@testset "Primal value" begin
114119
@test y1 ≈ y
115120
@test y2 ≈ y
121+
@test y5 ≈ y
122+
@test y6 ≈ y
116123
end
117124
@testset "Cotangent value" begin
118125
@test dx1 ≈ dx_true
119126
@test dx2 ≈ dx_true
120127
@test dx3 ≈ dx_true
121128
@test dx4 ≈ dx_true
129+
@test dx5 ≈ dx_true
130+
@test dx6 ≈ dx_true
122131
end
123132
end
124133
test_scen_intact(new_scen, scen)
@@ -145,13 +154,20 @@ function test_correctness(
145154
y10 = mysimilar(y)
146155
y1, dx1 = value_and_pullback!!(f!, y10, mysimilar(x), ba, x, dy, extras)
147156

157+
y20 = mysimilar(y)
158+
y2, pullbackfunc!! = value_and_pullback!!_split(f!, y20, ba, x, extras)
159+
dx2 = pullbackfunc!!(y20, mysimilar(x), dy)
160+
148161
let (≈)(x, y) = isapprox(x, y; atol, rtol)
149162
@testset "Primal value" begin
150163
@test y10 ≈ y
164+
@test y20 ≈ y
151165
@test y1 ≈ y
166+
@test y2 ≈ y
152167
end
153168
@testset "Cotangent value" begin
154169
@test dx1 ≈ dx_true
170+
@test dx2 ≈ dx_true
155171
end
156172
end
157173
test_scen_intact(new_scen, scen)

DifferentiationInterfaceTest/src/tests/type_stability.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,14 @@ function test_jet(ba::AbstractADType, scen::PullbackScenario{false};)
3232
extras = prepare_pullback(f, ba, x)
3333
dx_in = mysimilar(x)
3434

35+
_, pullbackfunc!! = value_and_pullback!!_split(f, ba, x, extras)
36+
_, pullbackfunc = value_and_pullback_split(f, ba, x, extras)
37+
3538
if Bool(pullback_performance(ba))
3639
@test_opt value_and_pullback!!(f, dx_in, ba, x, dy, extras)
3740
@test_opt value_and_pullback(f, ba, x, dy, extras)
41+
@test_opt pullbackfunc!!(dx_in, dy)
42+
@test_opt pullbackfunc(dy)
3843
end
3944
return nothing
4045
end
@@ -46,8 +51,11 @@ function test_jet(ba::AbstractADType, scen::PullbackScenario{true};)
4651
y_in = mysimilar(y)
4752
dx_in = mysimilar(x)
4853

54+
_, pullbackfunc!! = value_and_pullback!!_split(f!, y, ba, x, extras)
55+
4956
if Bool(pullback_performance(ba))
5057
@test_opt value_and_pullback!!(f!, y_in, dx_in, ba, x, dy, extras)
58+
@test_opt pullbackfunc!!(y_in, dx_in, dy)
5159
end
5260
return nothing
5361
end

0 commit comments

Comments
 (0)