Skip to content

Commit 6454cbc

Browse files
authored
More efficient pullback for ReverseDiff (#147)
1 parent 3e6f364 commit 6454cbc

3 files changed

Lines changed: 46 additions & 22 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using DifferentiationInterface:
66
DerivativeExtras, GradientExtras, HessianExtras, JacobianExtras, NoPullbackExtras
77
using ReverseDiff.DiffResults: DiffResults, DiffResult, GradientResult
88
using DocStringExtensions
9-
using LinearAlgebra: mul!
9+
using LinearAlgebra: dot, mul!
1010
using ReverseDiff:
1111
CompiledGradient,
1212
CompiledHessian,

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/allocating.jl

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,52 +3,69 @@
33
DI.prepare_pullback(f, ::AnyAutoReverseDiff, x) = NoPullbackExtras()
44

55
function DI.value_and_pullback!!(
6+
f,
7+
dx::AbstractArray,
8+
backend::AnyAutoReverseDiff,
9+
x::AbstractArray,
10+
dy,
11+
extras::NoPullbackExtras,
12+
)
13+
return f(x), DI.pullback!!(f, dx, backend, x, dy, extras)
14+
end
15+
16+
function DI.value_and_pullback(
17+
f, backend::AnyAutoReverseDiff, x::AbstractArray, dy, extras::NoPullbackExtras
18+
)
19+
return f(x), DI.pullback(f, backend, x, dy, extras)
20+
end
21+
22+
### Number out
23+
24+
function DI.pullback!!(
625
f,
726
dx::AbstractArray,
827
::AnyAutoReverseDiff,
928
x::AbstractArray,
1029
dy::Number,
1130
::NoPullbackExtras,
1231
)
13-
y = f(x)
14-
gradient!(dx, f, x)
32+
dx = gradient!(dx, f, x)
1533
dx .*= dy
16-
return y, dx
34+
return dx
1735
end
1836

19-
function DI.value_and_pullback(
37+
function DI.pullback(
2038
f, ::AnyAutoReverseDiff, x::AbstractArray, dy::Number, ::NoPullbackExtras
2139
)
22-
y = f(x)
2340
dx = gradient(f, x)
2441
dx .*= dy
25-
return y, dx
42+
return dx
2643
end
2744

28-
function DI.value_and_pullback!!(
45+
### Array out
46+
47+
function DI.pullback!!(
2948
f,
3049
dx::AbstractArray,
3150
::AnyAutoReverseDiff,
3251
x::AbstractArray,
3352
dy::AbstractArray,
3453
::NoPullbackExtras,
3554
)
36-
y = f(x)
37-
jac = jacobian(f, x) # allocates
38-
mul!(vec(dx), transpose(jac), vec(dy))
39-
return y, dx
55+
dotproduct_closure(x) = dot(f(x), dy)
56+
dx = gradient!(dx, dotproduct_closure, x)
57+
return dx
4058
end
4159

42-
function DI.value_and_pullback(
43-
f, ::AnyAutoReverseDiff, x::AbstractArray, dy::AbstractArray, ::NoPullbackExtras
60+
function DI.pullback(
61+
f, ::AnyAutoReverseDiff, x::AbstractArray, dy::AbstractArray, extras::NoPullbackExtras
4462
)
45-
y = f(x)
46-
jac = jacobian(f, x) # allocates
47-
dx = reshape(transpose(jac) * vec(dy), size(x))
48-
return y, dx
63+
dotproduct_closure(x) = dot(f(x), dy)
64+
dx = gradient(dotproduct_closure, x)
65+
return dx
4966
end
5067

51-
### Trick for unsupported scalar input
68+
### Number in, not supported
5269

5370
function DI.value_and_pullback(
5471
f, backend::AnyAutoReverseDiff, x::Number, dy, ::NoPullbackExtras

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/mutating.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
DI.prepare_pullback(f!, ::AnyAutoReverseDiff, y, x) = NoPullbackExtras()
44

5+
### Array in
6+
57
function DI.value_and_pullback!!(
68
f!,
79
y::AbstractArray,
@@ -11,12 +13,17 @@ function DI.value_and_pullback!!(
1113
dy::AbstractArray,
1214
::NoPullbackExtras,
1315
)
14-
jac = jacobian(f!, y, x)
15-
mul!(vec(dx), transpose(jac), vec(dy))
16+
function dotproduct_closure(x)
17+
y_copy = similar(y, eltype(x))
18+
f!(y_copy, x)
19+
return dot(y_copy, dy)
20+
end
21+
dx = gradient!(dx, dotproduct_closure, x)
22+
f!(y, x)
1623
return y, dx
1724
end
1825

19-
### Trick for unsupported scalar input
26+
### Number in, not supported
2027

2128
function DI.value_and_pullback!!(
2229
f!,

0 commit comments

Comments
 (0)