Skip to content

Commit 959f634

Browse files
authored
Replace Tangents with NTuple (#501)
* Replace Tangents with NTuple
* Improve Mooncake
* Second der type stability
* Values for Enzyme
* Fixes
* Fix
* Fix type params
* Fix
1 parent 1307890 commit 959f634

64 files changed

Lines changed: 706 additions & 934 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

DifferentiationInterface/docs/src/api.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ DifferentiationInterface
1313
```@docs
1414
Context
1515
Constant
16-
Tangents
1716
```
1817

1918
## First order

DifferentiationInterface/docs/src/explanation/operators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ These operators are computed using only the input `x`.
3333

3434
### Low-level operators
3535

36-
These operators are computed using the input `x` and another argument `t` of type [`Tangents`](@ref), which contains one or more tangents.
36+
These operators are computed using the input `x` and another argument `t` of type `NTuple`, which contains one or more tangents.
3737
You can think of tangents as perturbations propagated through the function; they live either in the same space as `x` or in the same space as `y`.
3838

3939
| operator | order | input `x` | output `y` | element type of `t` | operator result type | operator result shape |

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,7 @@ using ChainRulesCore:
1212
using Compat
1313
import DifferentiationInterface as DI
1414
using DifferentiationInterface:
15-
Constant,
16-
DifferentiateWith,
17-
NoPullbackPrep,
18-
NoPushforwardPrep,
19-
PullbackPrep,
20-
Tangents,
21-
unwrap
15+
Constant, DifferentiateWith, NoPullbackPrep, NoPushforwardPrep, PullbackPrep, unwrap
2216

2317
ruleconfig(backend::AutoChainRules) = backend.ruleconfig
2418

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
function ChainRulesCore.rrule(dw::DifferentiateWith, x)
22
@compat (; f, backend) = dw
33
y = f(x)
4-
prep_same = DI.prepare_pullback_same_point(f, backend, x, DI.Tangents(y))
4+
prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,))
55
function pullbackfunc(dy)
6-
tx = DI.pullback(f, prep_same, backend, x, DI.Tangents(dy))
6+
tx = DI.pullback(f, prep_same, backend, x, (dy,))
77
return (NoTangent(), only(tx))
88
end
99
return y, pullbackfunc

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ struct ChainRulesPullbackPrepSamePoint{Y,PB} <: PullbackPrep
66
end
77

88
function DI.prepare_pullback(
9-
f, ::AutoReverseChainRules, x, ty::Tangents, contexts::Vararg{Constant,C}
9+
f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{Constant,C}
1010
) where {C}
1111
return NoPullbackPrep()
1212
end
@@ -16,7 +16,7 @@ function DI.prepare_pullback_same_point(
1616
::NoPullbackPrep,
1717
backend::AutoReverseChainRules,
1818
x,
19-
ty::Tangents,
19+
ty::NTuple,
2020
contexts::Vararg{Constant,C},
2121
) where {C}
2222
rc = ruleconfig(backend)
@@ -29,7 +29,7 @@ function DI.value_and_pullback(
2929
::NoPullbackPrep,
3030
backend::AutoReverseChainRules,
3131
x,
32-
ty::Tangents,
32+
ty::NTuple,
3333
contexts::Vararg{Constant,C},
3434
) where {C}
3535
rc = ruleconfig(backend)
@@ -45,7 +45,7 @@ function DI.value_and_pullback(
4545
prep::ChainRulesPullbackPrepSamePoint,
4646
::AutoReverseChainRules,
4747
x,
48-
ty::Tangents,
48+
ty::NTuple,
4949
contexts::Vararg{Constant,C},
5050
) where {C}
5151
@compat (; y, pb) = prep
@@ -60,7 +60,7 @@ function DI.pullback(
6060
prep::ChainRulesPullbackPrepSamePoint,
6161
::AutoReverseChainRules,
6262
x,
63-
ty::Tangents,
63+
ty::NTuple,
6464
contexts::Vararg{Constant,C},
6565
) where {C}
6666
@compat (; pb) = prep

DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module DifferentiationInterfaceDiffractorExt
22

33
using ADTypes: ADTypes, AutoDiffractor
44
import DifferentiationInterface as DI
5-
using DifferentiationInterface: NoPushforwardPrep, Tangents
5+
using DifferentiationInterface: NoPushforwardPrep
66
using Diffractor: DiffractorRuleConfig, TaylorTangentIndex, ZeroBundle, bundle, ∂☆
77

88
DI.check_available(::AutoDiffractor) = true
@@ -11,19 +11,20 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()
1111

1212
## Pushforward
1313

14-
DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::Tangents) = NoPushforwardPrep()
14+
DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::NTuple) = NoPushforwardPrep()
1515

16-
function DI.pushforward(f, ::NoPushforwardPrep, ::AutoDiffractor, x, tx::Tangents)
16+
function DI.pushforward(f, ::NoPushforwardPrep, ::AutoDiffractor, x, tx::NTuple)
1717
ty = map(tx) do dx
1818
# code copied from Diffractor.jl
1919
z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx))
2020
dy = z[TaylorTangentIndex(1)]
21+
dy
2122
end
2223
return ty
2324
end
2425

2526
function DI.value_and_pushforward(
26-
f, prep::NoPushforwardPrep, backend::AutoDiffractor, x, tx::Tangents
27+
f, prep::NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple
2728
)
2829
return f(x), DI.pushforward(f, prep, backend, x, tx)
2930
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ using DifferentiationInterface:
1717
NoJacobianPrep,
1818
NoPullbackPrep,
1919
NoPushforwardPrep,
20-
Tangents,
2120
pick_batchsize
2221
using Enzyme:
2322
Active,

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ function DI.prepare_pushforward(
44
f::F,
55
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
66
x,
7-
tx::Tangents,
7+
tx::NTuple,
88
contexts::Vararg{Context,C},
99
) where {F,C}
1010
return NoPushforwardPrep()
@@ -15,7 +15,7 @@ function DI.value_and_pushforward(
1515
::NoPushforwardPrep,
1616
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
1717
x,
18-
tx::Tangents{1},
18+
tx::NTuple{1},
1919
contexts::Vararg{Context,C},
2020
) where {F,C}
2121
f_and_df = get_f_and_df(f, backend)
@@ -24,32 +24,32 @@ function DI.value_and_pushforward(
2424
dy, y = autodiff(
2525
forward_mode_withprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...
2626
)
27-
return y, Tangents(dy)
27+
return y, (dy,)
2828
end
2929

3030
function DI.value_and_pushforward(
3131
f::F,
3232
::NoPushforwardPrep,
3333
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
3434
x,
35-
tx::Tangents{B},
35+
tx::NTuple{B},
3636
contexts::Vararg{Context,C},
3737
) where {F,B,C}
3838
f_and_df = get_f_and_df(f, backend, Val(B))
39-
dxs_sametype = map(Fix1(convert, typeof(x)), tx.d)
40-
x_and_dxs = BatchDuplicated(x, dxs_sametype)
41-
dys, y = autodiff(
42-
forward_mode_withprimal(backend), f_and_df, x_and_dxs, map(translate, contexts)...
39+
tx_sametype = map(Fix1(convert, typeof(x)), tx)
40+
x_and_tx = BatchDuplicated(x, tx_sametype)
41+
ty, y = autodiff(
42+
forward_mode_withprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...
4343
)
44-
return y, Tangents(dys...)
44+
return y, values(ty)
4545
end
4646

4747
function DI.pushforward(
4848
f::F,
4949
::NoPushforwardPrep,
5050
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
5151
x,
52-
tx::Tangents{1},
52+
tx::NTuple{1},
5353
contexts::Vararg{Context,C},
5454
) where {F,C}
5555
f_and_df = get_f_and_df(f, backend)
@@ -60,53 +60,56 @@ function DI.pushforward(
6060
forward_mode_noprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...
6161
),
6262
)
63-
return Tangents(dy)
63+
return (dy,)
6464
end
6565

6666
function DI.pushforward(
6767
f::F,
6868
::NoPushforwardPrep,
6969
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
7070
x,
71-
tx::Tangents{B},
71+
tx::NTuple{B},
7272
contexts::Vararg{Context,C},
7373
) where {F,B,C}
7474
f_and_df = get_f_and_df(f, backend, Val(B))
75-
dxs_sametype = map(Fix1(convert, typeof(x)), tx.d)
76-
x_and_dxs = BatchDuplicated(x, dxs_sametype)
77-
dys = only(
75+
tx_sametype = map(Fix1(convert, typeof(x)), tx)
76+
x_and_tx = BatchDuplicated(x, tx_sametype)
77+
ty = only(
7878
autodiff(
79-
forward_mode_noprimal(backend), f_and_df, x_and_dxs, map(translate, contexts)...
79+
forward_mode_noprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...
8080
),
8181
)
82-
return Tangents(dys...)
82+
return values(ty)
8383
end
8484

8585
function DI.value_and_pushforward!(
8686
f::F,
87-
ty::Tangents,
87+
ty::NTuple,
8888
prep::NoPushforwardPrep,
8989
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
9090
x,
91-
tx::Tangents,
91+
tx::NTuple,
9292
contexts::Vararg{Context,C},
9393
) where {F,C}
9494
# dy cannot be passed anyway
9595
y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
96-
return y, copyto!(ty, new_ty)
96+
foreach(copyto!, ty, new_ty)
97+
return y, ty
9798
end
9899

99100
function DI.pushforward!(
100101
f::F,
101-
ty::Tangents,
102+
ty::NTuple,
102103
prep::NoPushforwardPrep,
103104
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
104105
x,
105-
tx::Tangents,
106+
tx::NTuple,
106107
contexts::Vararg{Context,C},
107108
) where {F,C}
108109
# dy cannot be passed anyway
109-
return copyto!(ty, DI.pushforward(f, prep, backend, x, tx, contexts...))
110+
new_ty = DI.pushforward(f, prep, backend, x, tx, contexts...)
111+
foreach(copyto!, ty, new_ty)
112+
return ty
110113
end
111114

112115
## Gradient

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ function DI.prepare_pushforward(
55
y,
66
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
77
x,
8-
tx::Tangents,
8+
tx::NTuple,
99
contexts::Vararg{Context,C},
1010
) where {F,C}
1111
return NoPushforwardPrep()
@@ -17,7 +17,7 @@ function DI.value_and_pushforward(
1717
::NoPushforwardPrep,
1818
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
1919
x,
20-
tx::Tangents{1},
20+
tx::NTuple{1},
2121
contexts::Vararg{Context,C},
2222
) where {F,C}
2323
f!_and_df! = get_f_and_df(f!, backend)
@@ -33,7 +33,7 @@ function DI.value_and_pushforward(
3333
x_and_dx,
3434
map(translate, contexts)...,
3535
)
36-
return y, Tangents(dy_sametype)
36+
return y, (dy_sametype,)
3737
end
3838

3939
function DI.value_and_pushforward(
@@ -42,21 +42,64 @@ function DI.value_and_pushforward(
4242
::NoPushforwardPrep,
4343
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
4444
x,
45-
tx::Tangents{B},
45+
tx::NTuple{B},
4646
contexts::Vararg{Context,C},
4747
) where {F,B,C}
4848
f!_and_df! = get_f_and_df(f!, backend, Val(B))
49-
dxs_sametype = map(Fix1(convert, typeof(x)), tx.d)
50-
dys_sametype = ntuple(_ -> make_zero(y), Val(B))
51-
x_and_dxs = BatchDuplicated(x, dxs_sametype)
52-
y_and_dys = BatchDuplicated(y, dys_sametype)
49+
tx_sametype = map(Fix1(convert, typeof(x)), tx)
50+
ty_sametype = ntuple(_ -> make_zero(y), Val(B))
51+
x_and_tx = BatchDuplicated(x, tx_sametype)
52+
y_and_ty = BatchDuplicated(y, ty_sametype)
5353
autodiff(
5454
forward_mode_noprimal(backend),
5555
f!_and_df!,
5656
Const,
57-
y_and_dys,
58-
x_and_dxs,
57+
y_and_ty,
58+
x_and_tx,
5959
map(translate, contexts)...,
6060
)
61-
return y, Tangents(dys_sametype...)
61+
return y, ty_sametype
62+
end
63+
64+
function DI.pushforward(
65+
f!::F,
66+
y,
67+
prep::NoPushforwardPrep,
68+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
69+
x,
70+
tx::NTuple,
71+
contexts::Vararg{Context,C},
72+
) where {F,C}
73+
_, ty = DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)
74+
return ty
75+
end
76+
77+
function DI.value_and_pushforward!(
78+
f!::F,
79+
y,
80+
ty::NTuple,
81+
prep::NoPushforwardPrep,
82+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
83+
x,
84+
tx::NTuple,
85+
contexts::Vararg{Context,C},
86+
) where {F,C}
87+
y, new_ty = DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)
88+
foreach(copyto!, ty, new_ty)
89+
return y, ty
90+
end
91+
92+
function DI.pushforward!(
93+
f!::F,
94+
y,
95+
ty::NTuple,
96+
prep::NoPushforwardPrep,
97+
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
98+
x,
99+
tx::NTuple,
100+
contexts::Vararg{Context,C},
101+
) where {F,C}
102+
new_ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...)
103+
foreach(copyto!, ty, new_ty)
104+
return ty
62105
end

0 commit comments

Comments
 (0)