@@ -4,7 +4,7 @@ function DI.prepare_pushforward(
44 f:: F ,
55 :: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
66 x,
7- tx:: Tangents ,
7+ tx:: NTuple ,
88 contexts:: Vararg{Context,C} ,
99) where {F,C}
1010 return NoPushforwardPrep ()
@@ -15,7 +15,7 @@ function DI.value_and_pushforward(
1515 :: NoPushforwardPrep ,
1616 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
1717 x,
18- tx:: Tangents {1} ,
18+ tx:: NTuple {1} ,
1919 contexts:: Vararg{Context,C} ,
2020) where {F,C}
2121 f_and_df = get_f_and_df (f, backend)
@@ -24,32 +24,32 @@ function DI.value_and_pushforward(
2424 dy, y = autodiff (
2525 forward_mode_withprimal (backend), f_and_df, x_and_dx, map (translate, contexts)...
2626 )
27- return y, Tangents (dy)
27+ return y, (dy, )
2828end
2929
3030function DI. value_and_pushforward (
3131 f:: F ,
3232 :: NoPushforwardPrep ,
3333 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
3434 x,
35- tx:: Tangents {B} ,
35+ tx:: NTuple {B} ,
3636 contexts:: Vararg{Context,C} ,
3737) where {F,B,C}
3838 f_and_df = get_f_and_df (f, backend, Val (B))
39- dxs_sametype = map (Fix1 (convert, typeof (x)), tx. d )
40- x_and_dxs = BatchDuplicated (x, dxs_sametype )
41- dys , y = autodiff (
42- forward_mode_withprimal (backend), f_and_df, x_and_dxs , map (translate, contexts)...
39+ tx_sametype = map (Fix1 (convert, typeof (x)), tx)
40+ x_and_tx = BatchDuplicated (x, tx_sametype )
41+ ty , y = autodiff (
42+ forward_mode_withprimal (backend), f_and_df, x_and_tx , map (translate, contexts)...
4343 )
44- return y, Tangents (dys ... )
44+ return y, values (ty )
4545end
4646
4747function DI. pushforward (
4848 f:: F ,
4949 :: NoPushforwardPrep ,
5050 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
5151 x,
52- tx:: Tangents {1} ,
52+ tx:: NTuple {1} ,
5353 contexts:: Vararg{Context,C} ,
5454) where {F,C}
5555 f_and_df = get_f_and_df (f, backend)
@@ -60,53 +60,56 @@ function DI.pushforward(
6060 forward_mode_noprimal (backend), f_and_df, x_and_dx, map (translate, contexts)...
6161 ),
6262 )
63- return Tangents (dy)
63+ return (dy, )
6464end
6565
6666function DI. pushforward (
6767 f:: F ,
6868 :: NoPushforwardPrep ,
6969 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
7070 x,
71- tx:: Tangents {B} ,
71+ tx:: NTuple {B} ,
7272 contexts:: Vararg{Context,C} ,
7373) where {F,B,C}
7474 f_and_df = get_f_and_df (f, backend, Val (B))
75- dxs_sametype = map (Fix1 (convert, typeof (x)), tx. d )
76- x_and_dxs = BatchDuplicated (x, dxs_sametype )
77- dys = only (
75+ tx_sametype = map (Fix1 (convert, typeof (x)), tx)
76+ x_and_tx = BatchDuplicated (x, tx_sametype )
77+ ty = only (
7878 autodiff (
79- forward_mode_noprimal (backend), f_and_df, x_and_dxs , map (translate, contexts)...
79+ forward_mode_noprimal (backend), f_and_df, x_and_tx , map (translate, contexts)...
8080 ),
8181 )
82- return Tangents (dys ... )
82+ return values (ty )
8383end
8484
8585function DI. value_and_pushforward! (
8686 f:: F ,
87- ty:: Tangents ,
87+ ty:: NTuple ,
8888 prep:: NoPushforwardPrep ,
8989 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
9090 x,
91- tx:: Tangents ,
91+ tx:: NTuple ,
9292 contexts:: Vararg{Context,C} ,
9393) where {F,C}
9494 # dy cannot be passed anyway
9595 y, new_ty = DI. value_and_pushforward (f, prep, backend, x, tx, contexts... )
96- return y, copyto! (ty, new_ty)
96+ foreach (copyto!, ty, new_ty)
97+ return y, ty
9798end
9899
99100function DI. pushforward! (
100101 f:: F ,
101- ty:: Tangents ,
102+ ty:: NTuple ,
102103 prep:: NoPushforwardPrep ,
103104 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
104105 x,
105- tx:: Tangents ,
106+ tx:: NTuple ,
106107 contexts:: Vararg{Context,C} ,
107108) where {F,C}
108109 # dy cannot be passed anyway
109- return copyto! (ty, DI. pushforward (f, prep, backend, x, tx, contexts... ))
110+ new_ty = DI. pushforward (f, prep, backend, x, tx, contexts... )
111+ foreach (copyto!, ty, new_ty)
112+ return ty
110113end
111114
112115# # Gradient
0 commit comments