11# # Pushforward
22
33function DI. prepare_pushforward (
4- f, :: AutoEnzyme{<:Union{ForwardMode,Nothing}} , x, tx:: Tangents
5- )
4+ f:: F ,
5+ :: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
6+ x,
7+ tx:: Tangents ,
8+ contexts:: Vararg{Context,C} ,
9+ ) where {F,C}
610 return NoPushforwardExtras ()
711end
812
913function DI. value_and_pushforward (
10- f,
14+ f:: F ,
1115 :: NoPushforwardExtras ,
1216 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
1317 x,
1418 tx:: Tangents{1} ,
15- )
19+ contexts:: Vararg{Context,C} ,
20+ ) where {F,C}
1621 f_and_df = get_f_and_df (f, backend)
1722 dx_sametype = convert (typeof (x), only (tx))
1823 x_and_dx = Duplicated (x, dx_sametype)
19- dy, y = autodiff (forward_mode_withprimal (backend), f_and_df, x_and_dx)
24+ dy, y = autodiff (
25+ forward_mode_withprimal (backend), f_and_df, x_and_dx, map (translate, contexts)...
26+ )
2027 return y, Tangents (dy)
2128end
2229
2330function DI. value_and_pushforward (
24- f,
31+ f:: F ,
2532 :: NoPushforwardExtras ,
2633 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
2734 x,
2835 tx:: Tangents{B} ,
29- ) where {B}
36+ contexts:: Vararg{Context,C} ,
37+ ) where {F,B,C}
3038 f_and_df = get_f_and_df (f, backend, Val (B))
3139 dxs_sametype = map (Fix1 (convert, typeof (x)), tx. d)
3240 x_and_dxs = BatchDuplicated (x, dxs_sametype)
33- dys, y = autodiff (forward_mode_withprimal (backend), f_and_df, x_and_dxs)
41+ dys, y = autodiff (
42+ forward_mode_withprimal (backend), f_and_df, x_and_dxs, map (translate, contexts)...
43+ )
3444 return y, Tangents (dys... )
3545end
3646
3747function DI. pushforward (
38- f,
48+ f:: F ,
3949 :: NoPushforwardExtras ,
4050 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
4151 x,
4252 tx:: Tangents{1} ,
43- )
53+ contexts:: Vararg{Context,C} ,
54+ ) where {F,C}
4455 f_and_df = get_f_and_df (f, backend)
4556 dx_sametype = convert (typeof (x), only (tx))
4657 x_and_dx = Duplicated (x, dx_sametype)
47- dy = only (autodiff (forward_mode_noprimal (backend), f_and_df, x_and_dx))
58+ dy = only (
59+ autodiff (
60+ forward_mode_noprimal (backend), f_and_df, x_and_dx, map (translate, contexts)...
61+ ),
62+ )
4863 return Tangents (dy)
4964end
5065
5166function DI. pushforward (
52- f,
67+ f:: F ,
5368 :: NoPushforwardExtras ,
5469 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
5570 x,
5671 tx:: Tangents{B} ,
57- ) where {B}
72+ contexts:: Vararg{Context,C} ,
73+ ) where {F,B,C}
5874 f_and_df = get_f_and_df (f, backend, Val (B))
5975 dxs_sametype = map (Fix1 (convert, typeof (x)), tx. d)
6076 x_and_dxs = BatchDuplicated (x, dxs_sametype)
61- dys = only (autodiff (forward_mode_noprimal (backend), f_and_df, x_and_dxs))
77+ dys = only (
78+ autodiff (
79+ forward_mode_noprimal (backend), f_and_df, x_and_dxs, map (translate, contexts)...
80+ ),
81+ )
6282 return Tangents (dys... )
6383end
6484
6585function DI. value_and_pushforward! (
66- f,
86+ f:: F ,
6787 ty:: Tangents ,
6888 extras:: NoPushforwardExtras ,
6989 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
7090 x,
7191 tx:: Tangents ,
72- )
92+ contexts:: Vararg{Context,C} ,
93+ ) where {F,C}
7394 # dy cannot be passed anyway
74- y, new_ty = DI. value_and_pushforward (f, extras, backend, x, tx)
95+ y, new_ty = DI. value_and_pushforward (f, extras, backend, x, tx, contexts ... )
7596 return y, copyto! (ty, new_ty)
7697end
7798
7899function DI. pushforward! (
79- f,
100+ f:: F ,
80101 ty:: Tangents ,
81102 extras:: NoPushforwardExtras ,
82103 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
83104 x,
84105 tx:: Tangents ,
85- )
106+ contexts:: Vararg{Context,C} ,
107+ ) where {F,C}
86108 # dy cannot be passed anyway
87- return copyto! (ty, DI. pushforward (f, extras, backend, x, tx))
109+ return copyto! (ty, DI. pushforward (f, extras, backend, x, tx, contexts ... ))
88110end
89111
90112# # Gradient
@@ -94,19 +116,19 @@ struct EnzymeForwardGradientExtras{B,O} <: GradientExtras
94116end
95117
96118function DI. prepare_gradient (
97- f, backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} , x
98- )
119+ f:: F , backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} , x
120+ ) where {F}
99121 B = pick_batchsize (backend, length (x))
100122 shadows = create_shadows (Val (B), x)
101123 return EnzymeForwardGradientExtras {B,typeof(shadows)} (shadows)
102124end
103125
104126function DI. gradient (
105- f,
127+ f:: F ,
106128 extras:: EnzymeForwardGradientExtras{B} ,
107129 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
108130 x,
109- ) where {B}
131+ ) where {F, B}
110132 f_and_df = get_f_and_df (f, backend)
111133 derivs = gradient (
112134 forward_mode_noprimal (backend), f_and_df, x; chunk= Val (B), shadows= extras. shadows
@@ -115,11 +137,11 @@ function DI.gradient(
115137end
116138
117139function DI. value_and_gradient (
118- f,
140+ f:: F ,
119141 extras:: EnzymeForwardGradientExtras{B} ,
120142 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
121143 x,
122- ) where {B}
144+ ) where {F, B}
123145 f_and_df = get_f_and_df (f, backend)
124146 (; derivs, val) = gradient (
125147 forward_mode_withprimal (backend), f_and_df, x; chunk= Val (B), shadows= extras. shadows
@@ -128,22 +150,22 @@ function DI.value_and_gradient(
128150end
129151
130152function DI. gradient! (
131- f,
153+ f:: F ,
132154 grad,
133155 extras:: EnzymeForwardGradientExtras{B} ,
134156 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
135157 x,
136- ) where {B}
158+ ) where {F, B}
137159 return copyto! (grad, DI. gradient (f, extras, backend, x))
138160end
139161
140162function DI. value_and_gradient! (
141- f,
163+ f:: F ,
142164 grad,
143165 extras:: EnzymeForwardGradientExtras{B} ,
144166 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
145167 x,
146- ) where {B}
168+ ) where {F, B}
147169 y, new_grad = DI. value_and_gradient (f, extras, backend, x)
148170 return y, copyto! (grad, new_grad)
149171end
@@ -156,20 +178,20 @@ struct EnzymeForwardOneArgJacobianExtras{B,O} <: JacobianExtras
156178end
157179
158180function DI. prepare_jacobian (
159- f, backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} , x
160- )
181+ f:: F , backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} , x
182+ ) where {F}
161183 y = f (x)
162184 B = pick_batchsize (backend, length (x))
163185 shadows = create_shadows (Val (B), x)
164186 return EnzymeForwardOneArgJacobianExtras {B,typeof(shadows)} (shadows, length (y))
165187end
166188
167189function DI. jacobian (
168- f,
190+ f:: F ,
169191 extras:: EnzymeForwardOneArgJacobianExtras{B} ,
170192 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
171193 x,
172- ) where {B}
194+ ) where {F, B}
173195 f_and_df = get_f_and_df (f, backend)
174196 derivs = jacobian (
175197 forward_mode_noprimal (backend), f_and_df, x; chunk= Val (B), shadows= extras. shadows
@@ -179,11 +201,11 @@ function DI.jacobian(
179201end
180202
181203function DI. value_and_jacobian (
182- f,
204+ f:: F ,
183205 extras:: EnzymeForwardOneArgJacobianExtras{B} ,
184206 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
185207 x,
186- ) where {B}
208+ ) where {F, B}
187209 f_and_df = get_f_and_df (f, backend)
188210 (; derivs, val) = jacobian (
189211 forward_mode_withprimal (backend), f_and_df, x; chunk= Val (B), shadows= extras. shadows
@@ -193,22 +215,22 @@ function DI.value_and_jacobian(
193215end
194216
195217function DI. jacobian! (
196- f,
218+ f:: F ,
197219 jac,
198220 extras:: EnzymeForwardOneArgJacobianExtras ,
199221 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
200222 x,
201- )
223+ ) where {F}
202224 return copyto! (jac, DI. jacobian (f, extras, backend, x))
203225end
204226
205227function DI. value_and_jacobian! (
206- f,
228+ f:: F ,
207229 jac,
208230 extras:: EnzymeForwardOneArgJacobianExtras ,
209231 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
210232 x,
211- )
233+ ) where {F}
212234 y, new_jac = DI. value_and_jacobian (f, extras, backend, x)
213235 return y, copyto! (jac, new_jac)
214236end
0 commit comments