Skip to content

Commit 77704d3

Browse files
authored
Make value_and_pullback the primitive (#157)
1 parent 95fe3bd commit 77704d3

18 files changed

Lines changed: 296 additions & 180 deletions

File tree

DifferentiationInterface/docs/src/backends.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function all_backends()
1818
AutoFiniteDiff(),
1919
AutoFiniteDifferences(FiniteDifferences.central_fdm(3, 1)),
2020
AutoForwardDiff(),
21-
AutoPolyesterForwardDiff(; chunksize=2),
21+
AutoPolyesterForwardDiff(; chunksize=1),
2222
AutoReverseDiff(),
2323
AutoTapir(),
2424
AutoTracker(),

DifferentiationInterface/docs/src/design.md

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,5 @@ Advanced users are welcome to code more backends and submit pull requests!
1717

1818
## Fallback call structure
1919

20-
For simplicity, we remove `value_` in the operator names below.
21-
22-
!!! note "Edge labels"
23-
24-
Full edges in the following graphs require a single call to the destination.
25-
Dotted edges require multiple calls to the destination, the number is indicated on the edge.
26-
2720
!!! warning
28-
This is still in flux, come back later!
21+
This is still in flux, come back later!

DifferentiationInterface/docs/src/overview.md

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ Most backends have custom implementations, which we reuse if possible.
77

88
We choose the following terminology for the high-level operators we provide:
99

10-
| operator | input `x` | output `y` | result type | result shape |
11-
| -------------------- | --------------- | --------------- | ---------------- | ------------------------ |
12-
| [`derivative`](@ref) | `Number` | `Any` | same as `y` | `size(y)` |
13-
| [`gradient`](@ref) | `Any` | `Number` | same as `x` | `size(x)` |
14-
| [`jacobian`](@ref) | `AbstractArray` | `AbstractArray` | `AbstractMatrix` | `(length(y), length(x))` |
10+
| operator | input `x` | output `y` | result type | result shape |
11+
| -------------------- | --------------- | --------------------------- | ---------------- | ------------------------ |
12+
| [`derivative`](@ref) | `Number` | `Number` or `AbstractArray` | same as `y` | `size(y)` |
13+
| [`gradient`](@ref) | `AbstractArray` | `Number` | same as `x` | `size(x)` |
14+
| [`jacobian`](@ref) | `AbstractArray` | `AbstractArray` | `AbstractMatrix` | `(length(y), length(x))` |
1515

1616
They are all based on the following low-level operators:
1717

@@ -58,11 +58,11 @@ You can either pick a single backend to do all the work, or combine an "outer" b
5858

5959
The available operators are similar to first-order ones:
6060

61-
| operator | input `x` | output `y` | result type | result shape |
62-
| --------------------------- | --------------- | ------------ | ---------------- | ------------------------ |
63-
| [`second_derivative`](@ref) | `Number` | `Any` | same as `y` | `size(y)` |
64-
| [`hvp`](@ref) | `Any` | `Number` | same as `x` | `size(x)` |
65-
| [`hessian`](@ref) | `AbstractArray` | `Number` | `AbstractMatrix` | `(length(x), length(x))` |
61+
| operator | input `x` | output `y` | result type | result shape |
62+
| --------------------------- | --------------- | --------------------------- | ---------------- | ------------------------ |
63+
| [`second_derivative`](@ref) | `Number` | `Number` or `AbstractArray` | same as `y` | `size(y)` |
64+
| [`hvp`](@ref) | `AbstractArray` | `Number` | same as `x` | `size(x)` |
65+
| [`hessian`](@ref) | `AbstractArray` | `Number` | `AbstractMatrix` | `(length(x), length(x))` |
6666

6767
We only define two variants for now:
6868

@@ -94,12 +94,9 @@ This is a backend-specific procedure, but we expose a common syntax to achieve i
9494
If you run `prepare_operator(backend, f, x)`, it will create an object called `extras` containing the necessary information to speed up `operator` and its variants.
9595
This information is specific to `backend` and `f`, as well as the _type and size_ of the input `x`, but it should work with different _values_ of `x`.
9696

97-
You can then call `operator(backend, f, similar_x, extras)`, which should be faster than `operator(backend, f, similar_x)`.
97+
You can then call `operator(backend, f, x2, extras)`, which should be faster than `operator(f, backend, x2)`.
9898
This is especially worth it if you plan to call `operator` several times in similar settings: you can think of it as a warm up.
9999

100-
By default, all the preparation functions return `nothing`.
101-
We do not make any guarantees on their implementation for each backend, or on the performance gains that can be expected.
102-
103100
!!! warning
104101
For `SecondOrder` backends, the inner differentiation cannot be prepared at the moment, only the outer one is.
105102

@@ -123,12 +120,13 @@ This means the Hessian is obtained as the sparse Jacobian of the gradient.
123120

124121
### Split reverse mode
125122

126-
Many reverse mode AD backends expose a "split" option, which runs only the forward sweep, and encapsulates the reverse sweep in a closure.
127-
We make this available for allocating functions only, with the following operators:
123+
Some reverse mode AD backends expose a "split" option, which runs only the forward sweep, and encapsulates the reverse sweep in a closure.
124+
We make this available for all backends with the following operators:
128125

129-
| out-of-place | in-place (or not) |
130-
| ---------------------------------- | ------------------------------------ |
131-
| [`value_and_pullback_split`](@ref) | [`value_and_pullback!!_split`](@ref) |
126+
| | out-of-place | in-place (or not) |
127+
| -------------------- | ---------------------------------- | -------------------------------------- |
128+
| allocating functions | [`value_and_pullback_split`](@ref) | [`value_and_pullback!!_split`](@ref) |
129+
| mutating functions | - | [`value_and_pullback!!_split!!`](@ref) |
132130

133131
!!! danger
134132
Split reverse mode is still experimental, use at your own risk.

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ DI.mode(::AutoReverseChainRules) = ADTypes.AbstractReverseMode
1717

1818
## Pullback
1919

20-
DI.prepare_pullback(f, ::AutoForwardChainRules, x) = NoPullbackExtras()
20+
DI.prepare_pullback(f, ::AutoReverseChainRules, x) = NoPullbackExtras()
2121

2222
function DI.value_and_pullback_split(
2323
f, backend::AutoReverseChainRules, x, ::NoPullbackExtras
@@ -28,4 +28,11 @@ function DI.value_and_pullback_split(
2828
return y, pullbackfunc
2929
end
3030

31+
function DI.value_and_pullback(
32+
f, backend::AutoReverseChainRules, x, dy, extras::NoPullbackExtras
33+
)
34+
y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
35+
return y, pullbackfunc(dy)
36+
end
37+
3138
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -69,24 +69,6 @@ function DI.pullback!!(f, dx, backend::AutoReverseEnzyme, x, dy, extras::NoPullb
6969
return DI.value_and_pullback!!(f, dx, backend, x, dy, extras)[2]
7070
end
7171

72-
### Closure
73-
74-
function DI.value_and_pullback_split(
75-
f, backend::AutoReverseEnzyme, x, extras::NoPullbackExtras
76-
)
77-
y = f(x)
78-
pullbackfunc(dy) = DI.pullback(f, backend, x, dy, extras)
79-
return y, pullbackfunc
80-
end
81-
82-
function DI.value_and_pullback!!_split(
83-
f, backend::AutoReverseEnzyme, x, extras::NoPullbackExtras
84-
)
85-
y = f(x)
86-
pullbackfunc!!(dx, dy) = DI.pullback!!(f, dx, backend, x, dy, extras)
87-
return y, pullbackfunc!!
88-
end
89-
9072
## Gradient
9173

9274
DI.prepare_gradient(f, ::AutoReverseEnzyme, x) = NoGradientExtras()

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/allocating.jl

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,40 +2,38 @@
22

33
DI.prepare_pullback(f, ::AnyAutoReverseDiff, x) = NoPullbackExtras()
44

5-
function DI.value_and_pullback_split(
6-
f, ::AnyAutoReverseDiff, x::AbstractArray, ::NoPullbackExtras
5+
function DI.value_and_pullback(
6+
f, ::AnyAutoReverseDiff, x::AbstractArray, dy, ::NoPullbackExtras
77
)
88
y = f(x)
9-
pullbackfunc = if y isa Number
10-
dy -> dy .* gradient(f, x)
9+
dx = if y isa Number
10+
dy .* gradient(f, x)
1111
elseif y isa AbstractArray
12-
dy -> gradient(z -> dot(f(z), dy), x)
12+
gradient(z -> dot(f(z), dy), x)
1313
end
14-
return y, pullbackfunc
14+
return y, dx
1515
end
1616

17-
function DI.value_and_pullback!!_split(
18-
f, ::AnyAutoReverseDiff, x::AbstractArray, ::NoPullbackExtras
17+
function DI.value_and_pullback!!(
18+
f, dx, ::AnyAutoReverseDiff, x::AbstractArray, dy, ::NoPullbackExtras
1919
)
2020
y = f(x)
21-
pullbackfunc!! = if y isa Number
22-
(dx, dy) -> begin
23-
dx = gradient!(dx, f, x)
24-
dx .*= dy
25-
end
21+
dx = if y isa Number
22+
dx = gradient!(dx, f, x)
23+
dx .*= dy
2624
elseif y isa AbstractArray
27-
(dx, dy) -> gradient!(dx, z -> dot(f(z), dy), x)
25+
gradient!(dx, z -> dot(f(z), dy), x)
2826
end
29-
return y, pullbackfunc!!
27+
return y, dx
3028
end
3129

32-
function DI.value_and_pullback_split(
33-
f, backend::AnyAutoReverseDiff, x::Number, ::NoPullbackExtras
30+
function DI.value_and_pullback(
31+
f, backend::AnyAutoReverseDiff, x::Number, dy, ::NoPullbackExtras
3432
)
3533
x_array = [x]
3634
f_array = f only
37-
y, pullbackfunc = DI.value_and_pullback_split(f_array, backend, x_array)
38-
return y, only pullbackfunc
35+
y, dx_array = DI.value_and_pullback(f_array, backend, x_array, dy)
36+
return y, only(dx_array)
3937
end
4038

4139
## Gradient

DifferentiationInterface/ext/DifferentiationInterfaceTapirExt/allocating.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,3 @@ function DI.pullback!!(
3838
)
3939
return DI.value_and_pullback!!(f, dx, backend, x, dy, extras)[2]
4040
end
41-
42-
function DI.value_and_pullback_split(
43-
f, backend::AutoTapir, x, extras::TapirAllocatingPullbackExtras
44-
)
45-
y = f(x)
46-
pullbackfunc(dy) = DI.pullback(f, backend, x, dy, extras)
47-
return y, pullbackfunc
48-
end
49-
50-
function DI.value_and_pullback!!_split(
51-
f, backend::AutoTapir, x, extras::TapirAllocatingPullbackExtras
52-
)
53-
y = f(x)
54-
pullbackfunc!!(dx, dy) = DI.pullback!!(f, dx, backend, x, dy, extras)
55-
return y, pullbackfunc!!
56-
end

DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ function DI.value_and_pullback_split(f, ::AutoTracker, x, ::NoPullbackExtras)
1717
return y, pullbackfunc
1818
end
1919

20+
function DI.value_and_pullback(f, backend::AutoTracker, x, dy, extras::NoPullbackExtras)
21+
y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
22+
return y, pullbackfunc(dy)
23+
end
24+
2025
## Gradient
2126

2227
DI.prepare_gradient(f, ::AutoTracker, x) = NoGradientExtras()

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ function DI.value_and_pullback_split(f, ::AnyAutoZygote, x, ::NoPullbackExtras)
2222
return y, pullbackfunc
2323
end
2424

25+
function DI.value_and_pullback(f, backend::AnyAutoZygote, x, dy, extras::NoPullbackExtras)
26+
y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
27+
return y, pullbackfunc(dy)
28+
end
29+
2530
## Gradient
2631

2732
DI.prepare_gradient(f, ::AnyAutoZygote, x) = NoGradientExtras()

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ export SecondOrder
107107

108108
export value_and_pushforward!!, value_and_pushforward
109109
export value_and_pullback!!, value_and_pullback
110-
export value_and_pullback!!_split, value_and_pullback_split
110+
export value_and_pullback!!_split, value_and_pullback_split, value_and_pullback!!_split!!
111111

112112
export value_and_derivative!!, value_and_derivative
113113
export value_and_gradient!!, value_and_gradient

0 commit comments

Comments
 (0)