Skip to content

Commit ccf5247

Browse files
authored
[BREAKING] Force use of Tangents in pushforward, pullback and hvp (#455)
* Force use of `Tangents` * Fix * Add `map` and other utilities * Add methods * Scenarios * Fixes * Fixes * Fix * Coverage and docs * Avoid duplicate imports * Map * Dix focs
1 parent 25289d9 commit ccf5247

54 files changed

Lines changed: 437 additions & 386 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

DifferentiationInterface/docs/src/api.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ DifferentiationInterface
1010

1111
## First order
1212

13+
```@docs
14+
Tangents
15+
```
16+
1317
### Pushforward
1418

1519
```@docs

DifferentiationInterface/docs/src/operators.md

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,17 @@ These operators are computed using only the input `x`.
3030

3131
### Low-level operators
3232

33-
These operators are computed using the input `x` and a "seed" `v`, which lives either
33+
These operators are computed using the input `x` and a tangent `t` of type [`Tangents`](@ref).
34+
This tangent is essentially an `NTuple`, whose elements live either
3435

35-
- in the same space as `x` (we call it `dx`)
36-
- or in the same space as `y` (we call it `dy`)
36+
- in the same space as `x` (we call it `tx`)
37+
- or in the same space as `y` (we call it `ty`)
3738

38-
| operator | order | input `x` | output `y` | seed `v` | operator result type | operator result shape |
39-
| :-------------------------- | :---- | :-------------- | :----------- | :------- | :------------------- | :-------------------- |
40-
| [`pushforward`](@ref) (JVP) | 1 | `Any` | `Any` | `dx` | same as `y` | `size(y)` |
41-
| [`pullback`](@ref) (VJP) | 1 | `Any` | `Any` | `dy` | same as `x` | `size(x)` |
42-
| [`hvp`](@ref) | 2 | `AbstractArray` | `Number` | `dx` | same as `x` | `size(x)` |
39+
| operator | order | input `x` | output `y` | tangent `t` | operator result type | operator result shape |
40+
| :-------------------------- | :---- | :-------------- | :----------- | :---------- | :------------------- | :-------------------- |
41+
| [`pushforward`](@ref) (JVP) | 1 | `Any` | `Any` | `tx` | same as `y` | `size(y)` |
42+
| [`pullback`](@ref) (VJP) | 1 | `Any` | `Any` | `ty` | same as `x` | `size(x)` |
43+
| [`hvp`](@ref) | 2 | `AbstractArray` | `Number` | `tx` | same as `x` | `size(x)` |
4344

4445
## Variants
4546

@@ -73,8 +74,8 @@ This results in various operator signatures (the necessary arguments and their o
7374

7475
| function signature | out-of-place operator | in-place operator |
7576
| :-------------------- | :--------------------------- | :------------------------------------ |
76-
| out-of-place function | `op(f, backend, x, [v])` | `op!(f, result, backend, x, [v])` |
77-
| in-place function | `op(f!, y, backend, x, [v])` | `op!(f!, y, result, backend, x, [v])` |
77+
| out-of-place function | `op(f, backend, x, [t])` | `op!(f, result, backend, x, [t])` |
78+
| in-place function | `op(f!, y, backend, x, [t])` | `op!(f!, y, result, backend, x, [t])` |
7879

7980
!!! warning
8081
The positional arguments between `f`/`f!` and `backend` are always mutated.
@@ -103,15 +104,15 @@ In addition, the preparation syntax depends on the number of arguments accepted
103104

104105
| function signature | preparation signature |
105106
| :-------------------- | :----------------------------------- |
106-
| out-of-place function | `prepare_op(f, backend, x, [v])` |
107-
| in-place function | `prepare_op(f!, y, backend, x, [v])` |
107+
| out-of-place function | `prepare_op(f, backend, x, [t])` |
108+
| in-place function | `prepare_op(f!, y, backend, x, [t])` |
108109

109110
Preparation creates an object called `extras` which contains the the necessary information to speed up an operator and its variants.
110111
The idea is that you prepare only once, which can be costly, but then call the operator several times while reusing the same `extras`.
111112

112113
```julia
113-
op(f, backend, x, [v]) # slow because it includes preparation
114-
op(f, extras, backend, x, [v]) # fast because it skips preparation
114+
op(f, backend, x, [t]) # slow because it includes preparation
115+
op(f, extras, backend, x, [t]) # fast because it skips preparation
115116
```
116117

117118
!!! warning
@@ -124,9 +125,9 @@ Here are the general rules that we strive to implement:
124125

125126
| | different point | same point |
126127
| :------------------------ | :--------------------------------------- | :--------------------------------------- |
127-
| the output `extras` of... | `prepare_op(f, b, x)` | `prepare_op_same_point(f, b, x, v)` |
128-
| can be used in... | `op(f, extras, b, other_x)` | `op(f, extras, b, x, other_v)` |
129-
| provided that... | `other_x` has same type and shape as `x` | `other_v` has same type and shape as `v` |
128+
| the output `extras` of... | `prepare_op(f, b, x)` | `prepare_op_same_point(f, b, x, t)` |
129+
| can be used in... | `op(f, extras, b, other_x)` | `op(f, extras, b, x, other_t)` |
130+
| provided that... | `other_x` has same type and shape as `x` | `other_t` has same type and shape as `t` |
130131

131132
These rules hold for the majority of backends, but there are some exceptions: see [this page](@ref "Preparation") to know more.
132133

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
1-
# forward mode unused for lack of implementations
2-
#=
3-
function ChainRulesCore.frule((_, dx), dw::DifferentiateWith, x)
4-
@compat (; f, backend) = dw
5-
y, dy = DI.value_and_pushforward(f, backend, x, dx)
6-
return y, dy
7-
end
8-
=#
9-
101
function ChainRulesCore.rrule(dw::DifferentiateWith, x)
112
@compat (; f, backend) = dw
123
y = f(x)
13-
extras_same = DI.prepare_pullback_same_point(f, backend, x, y)
14-
pullbackfunc(dy) = (NoTangent(), DI.pullback(f, extras_same, backend, x, dy))
4+
extras_same = DI.prepare_pullback_same_point(f, backend, x, DI.Tangents(y))
5+
function pullbackfunc(dy)
6+
tx = DI.pullback(f, extras_same, backend, x, DI.Tangents(dy))
7+
return (NoTangent(), only(tx))
8+
end
159
return y, pullbackfunc
1610
end

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,28 @@ function DI.value_and_pullback(
2020
)
2121
rc = ruleconfig(backend)
2222
y, pb = rrule_via_ad(rc, f, x)
23-
return y, Tangents(last.(pb.(ty.d)))
23+
tx = map(ty) do dy
24+
last(pb(dy))
25+
end
26+
return y, tx
2427
end
2528

2629
function DI.value_and_pullback(
2730
f, extras::ChainRulesPullbackExtrasSamePoint, ::AutoReverseChainRules, x, ty::Tangents
2831
)
2932
@compat (; y, pb) = extras
30-
return copy(y), Tangents(last.(pb.(ty.d)))
33+
tx = map(ty) do dy
34+
last(pb(dy))
35+
end
36+
return copy(y), tx
3137
end
3238

3339
function DI.pullback(
3440
f, extras::ChainRulesPullbackExtrasSamePoint, ::AutoReverseChainRules, x, ty::Tangents
3541
)
3642
@compat (; pb) = extras
37-
return Tangents(last.(pb.(ty.d)))
43+
tx = map(ty) do dy
44+
last(pb(dy))
45+
end
46+
return tx
3847
end

DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()
1414
DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::Tangents) = NoPushforwardExtras()
1515

1616
function DI.pushforward(f, ::NoPushforwardExtras, ::AutoDiffractor, x, tx::Tangents)
17-
dys = map(tx.d) do dx
17+
ty = map(tx) do dx
1818
# code copied from Diffractor.jl
1919
z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx))
2020
dy = z[TaylorTangentIndex(1)]
2121
end
22-
return Tangents(dys)
22+
return ty
2323
end
2424

2525
function DI.value_and_pushforward(

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ using DifferentiationInterface:
1616
NoPullbackExtras,
1717
NoPushforwardExtras,
1818
Tangents,
19-
SingleTangent,
2019
pick_batchsize
2120
using Enzyme:
2221
Active,

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ function DI.value_and_pushforward(
1313
x,
1414
tx::Tangents,
1515
)
16-
dys = map(tx.d) do dx
17-
DI.pushforward(f, extras, backend, x, dx)
16+
ty = map(tx) do dx
17+
only(DI.pushforward(f, extras, backend, x, Tangents(dx)))
1818
end
1919
y = f(x)
20-
return y, Tangents(dys)
20+
return y, ty
2121
end
2222

2323
function DI.value_and_pushforward(
@@ -36,7 +36,7 @@ function DI.value_and_pushforward(
3636
else
3737
autodiff(forward_mode(backend), f_and_df, Duplicated, x_and_dx)
3838
end
39-
return y, SingleTangent(new_dy)
39+
return y, Tangents(new_dy)
4040
end
4141

4242
function DI.pushforward(
@@ -55,7 +55,7 @@ function DI.pushforward(
5555
else
5656
only(autodiff(forward_mode(backend), f_and_df, DuplicatedNoNeed, x_and_dx))
5757
end
58-
return SingleTangent(new_dy)
58+
return Tangents(new_dy)
5959
end
6060

6161
function DI.value_and_pushforward!(

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ function DI.value_and_pushforward(
1414
x,
1515
tx::Tangents,
1616
)
17-
dys = map(tx.d) do dx
18-
DI.pushforward(f!, y, extras, backend, x, dx)
17+
ty = map(tx) do dx
18+
only(DI.pushforward(f!, y, extras, backend, x, Tangents(dx)))
1919
end
2020
f!(y, x)
21-
return y, Tangents(dys)
21+
return y, ty
2222
end
2323

2424
function DI.value_and_pushforward(
@@ -40,5 +40,5 @@ function DI.value_and_pushforward(
4040
else
4141
autodiff(forward_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx)
4242
end
43-
return y, SingleTangent(dy_sametype)
43+
return y, Tangents(dy_sametype)
4444
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ function DI.value_and_pullback(
1313
x,
1414
ty::Tangents,
1515
)
16-
dxs = map(ty.d) do dy
17-
only(DI.pullback(f, extras, backend, x, SingleTangent(dy)))
16+
tx = map(ty) do dy
17+
only(DI.pullback(f, extras, backend, x, Tangents(dy)))
1818
end
1919
y = f(x)
20-
return y, Tangents(dxs)
20+
return y, tx
2121
end
2222

2323
### Out-of-place
@@ -38,7 +38,7 @@ function DI.value_and_pullback(
3838
autodiff(ReverseWithPrimal, f_and_df, Active, Active(x))
3939
end
4040
new_dx = dy * only(der)
41-
return y, SingleTangent(new_dx)
41+
return y, Tangents(new_dx)
4242
else
4343
dy = only(ty)
4444
f_and_df = force_annotation(get_f_and_df(f, backend))
@@ -51,7 +51,7 @@ function DI.value_and_pullback(
5151
tape, y, new_dy = forw(f_and_df, Active(x))
5252
copyto!(new_dy, dy)
5353
new_dx = only(only(rev(f_and_df, Active(x), tape)))
54-
return y, SingleTangent(new_dx)
54+
return y, Tangents(new_dx)
5555
end
5656
end
5757

@@ -76,10 +76,10 @@ function DI.value_and_pullback(
7676
# TODO: generalize beyond Arrays?
7777
dx_sametype .*= dy
7878
end
79-
return y, SingleTangent(dx_sametype)
79+
return y, Tangents(dx_sametype)
8080
else
8181
dx = make_zero(x)
82-
return DI.value_and_pullback!(f, SingleTangent(dx), extras, backend, x, ty)
82+
return DI.value_and_pullback!(f, Tangents(dx), extras, backend, x, ty)
8383
end
8484
end
8585

@@ -201,7 +201,8 @@ function DI.value_and_gradient(
201201
backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
202202
x,
203203
)
204-
return DI.value_and_pullback(f, NoPullbackExtras(), backend, x, true)
204+
y, tx = DI.value_and_pullback(f, NoPullbackExtras(), backend, x, Tangents(true))
205+
return y, only(tx)
205206
end
206207

207208
function DI.value_and_gradient!(
@@ -211,7 +212,10 @@ function DI.value_and_gradient!(
211212
backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
212213
x,
213214
)
214-
return DI.value_and_pullback!(f, grad, NoPullbackExtras(), backend, x, true)
215+
y, _ = DI.value_and_pullback!(
216+
f, Tangents(grad), NoPullbackExtras(), backend, x, Tangents(true)
217+
)
218+
return y, grad
215219
end
216220

217221
## Jacobian

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ function DI.value_and_pullback(
1414
x,
1515
ty::Tangents,
1616
)
17-
dxs = map(ty.d) do dy
18-
only(DI.pullback(f!, y, extras, backend, x, SingleTangent(dy)))
17+
tx = map(ty) do dy
18+
only(DI.pullback(f!, y, extras, backend, x, Tangents(dy)))
1919
end
2020
f!(y, x)
21-
return y, Tangents(dxs)
21+
return y, tx
2222
end
2323

2424
function DI.value_and_pullback(
@@ -38,7 +38,7 @@ function DI.value_and_pullback(
3838
else
3939
only(autodiff(reverse_mode(backend), f!_and_df!, Const, y_and_dy, Active(x)))
4040
end
41-
return y, SingleTangent(new_dx)
41+
return y, Tangents(new_dx)
4242
end
4343

4444
function DI.value_and_pullback(
@@ -60,5 +60,5 @@ function DI.value_and_pullback(
6060
else
6161
autodiff(reverse_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx)
6262
end
63-
return y, SingleTangent(dx_sametype)
63+
return y, Tangents(dx_sametype)
6464
end

0 commit comments

Comments
 (0)