Skip to content

Commit 33475ed

Browse files
authored
Remove split reverse mode for mutating functions (#143)
* Remove split reverse mode for mutating functions * Remove tests * Rm benchmark
1 parent 28b537d commit 33475ed

6 files changed

Lines changed: 6 additions & 41 deletions

File tree

DifferentiationInterface/docs/src/overview.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ This means the Hessian is obtained as the sparse Jacobian of the gradient.
124124
### Split reverse mode
125125

126126
Many reverse mode AD backends expose a "split" option, which runs only the forward sweep, and encapsulates the reverse sweep in a closure.
127-
We make this available for everyone with the following operators:
127+
We make this available for allocating functions only, with the following operators:
128128

129129
| out-of-place | in-place (or not) |
130130
| ---------------------------------- | ------------------------------------ |

DifferentiationInterface/src/jacobian.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,13 +193,14 @@ function value_and_jacobian_aux!!(
193193
x::AbstractArray,
194194
extras::PullbackJacobianExtras,
195195
)
196-
y, pullbackfunc!! = value_and_pullback!!_split(
197-
f!, y, backend, x, extras.pullback_extras
198-
)
199196
for (k, i) in enumerate(CartesianIndices(y))
200197
dy_i = basis(backend, y, i)
201198
jac_row_i_old = reshape(view(jac, k, :), size(x))
202-
jac_row_i_new = pullbackfunc!!(y, jac_row_i_old, dy_i)
199+
jac_row_i_new = last(
200+
value_and_pullback!!(
201+
f!, y, jac_row_i_old, backend, x, dy_i, extras.pullback_extras
202+
),
203+
)
203204
# this allocates
204205
copyto!(jac_row_i_old, jac_row_i_new)
205206
end

DifferentiationInterface/src/pullback.jl

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -158,26 +158,3 @@ function value_and_pullback!!_split(
158158
pullbackfunc!!(dx, dy) = pullback!!(f, dx, backend, x, dy, extras)
159159
return f(x), pullbackfunc!!
160160
end
161-
162-
"""
163-
value_and_pullback!!_split(f!, y, backend, x, [extras])
164-
165-
Apply split reverse mode autodiff.
166-
167-
Returns a tuple `(y, pullbackfunc!!)` where the second element is a function (closure) with the following signature:
168-
169-
pullbackfunc!!(y, dx, dy) -> dx
170-
"""
171-
function value_and_pullback!!_split(
172-
f!,
173-
y,
174-
backend::AbstractADType,
175-
x,
176-
extras::PullbackExtras=prepare_pullback(f!, backend, y, x),
177-
)
178-
function pullbackfunc!!(y, dx, dy)
179-
return value_and_pullback!!(f!, y, dx, backend, x, dy, extras)[2]
180-
end
181-
f!(y, x)
182-
return y, pullbackfunc!!
183-
end

DifferentiationInterfaceTest/src/tests/benchmark.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,10 @@ function run_benchmark!(
114114
(; f, x, y, dy) = deepcopy(scen)
115115
f! = f
116116
extras = prepare_pullback(f!, ba, y, x)
117-
_, pullbackfunc!! = value_and_pullback!!_split(f!, y, ba, x, extras)
118117
bench1 = @be (mysimilar(y), mysimilar(x)) value_and_pullback!!(
119118
f!, _[1], _[2], ba, x, dy, extras
120119
)
121-
bench2 = @be (mysimilar(y), mysimilar(x)) pullbackfunc!!(_[1], _[2], dy)
122120
record!(data, ba, value_and_pullback!!, scen, bench1)
123-
record!(data, ba, pullbackfunc!!, scen, bench2)
124121
return nothing
125122
end
126123

DifferentiationInterfaceTest/src/tests/correctness.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -154,20 +154,13 @@ function test_correctness(
154154
y10 = mysimilar(y)
155155
y1, dx1 = value_and_pullback!!(f!, y10, mysimilar(x), ba, x, dy, extras)
156156

157-
y20 = mysimilar(y)
158-
y2, pullbackfunc!! = value_and_pullback!!_split(f!, y20, ba, x, extras)
159-
dx2 = pullbackfunc!!(y20, mysimilar(x), dy)
160-
161157
let ()(x, y) = isapprox(x, y; atol, rtol)
162158
@testset "Primal value" begin
163159
@test y10 y
164-
@test y20 y
165160
@test y1 y
166-
@test y2 y
167161
end
168162
@testset "Cotangent value" begin
169163
@test dx1 dx_true
170-
@test dx2 dx_true
171164
end
172165
end
173166
test_scen_intact(new_scen, scen)

DifferentiationInterfaceTest/src/tests/type_stability.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,8 @@ function test_jet(ba::AbstractADType, scen::PullbackScenario{true};)
5151
y_in = mysimilar(y)
5252
dx_in = mysimilar(x)
5353

54-
_, pullbackfunc!! = value_and_pullback!!_split(f!, y, ba, x, extras)
55-
5654
if Bool(pullback_performance(ba))
5755
@test_opt value_and_pullback!!(f!, y_in, dx_in, ba, x, dy, extras)
58-
@test_opt pullbackfunc!!(y_in, dx_in, dy)
5956
end
6057
return nothing
6158
end

0 commit comments

Comments
 (0)