@@ -151,26 +151,58 @@ function _pullback_via_pushforward(
151151 f:: F ,
152152 pushforward_prep:: PushforwardPrep ,
153153 backend:: AbstractADType ,
154- x:: Number ,
154+ x:: Real ,
155155 dy,
156156 contexts:: Vararg{Context,C} ,
157157) where {F,C}
158- t1 = pushforward (f, pushforward_prep, backend, x, (one (x),), contexts... )
159- dx = dot (only (t1) , dy)
158+ a = only ( pushforward (f, pushforward_prep, backend, x, (one (x),), contexts... ) )
159+ dx = dot (a , dy)
160160 return dx
161161end
162162
163163function _pullback_via_pushforward (
164164 f:: F ,
165165 pushforward_prep:: PushforwardPrep ,
166166 backend:: AbstractADType ,
167- x:: AbstractArray ,
167+ x:: Complex ,
168+ dy,
169+ contexts:: Vararg{Context,C} ,
170+ ) where {F,C}
171+ a = only (pushforward (f, pushforward_prep, backend, x, (one (x),), contexts... ))
172+ b = only (pushforward (f, pushforward_prep, backend, x, (im * one (x),), contexts... ))
173+ dx = real (dot (a, dy)) + im * real (dot (b, dy))
174+ return dx
175+ end
176+
177+ function _pullback_via_pushforward (
178+ f:: F ,
179+ pushforward_prep:: PushforwardPrep ,
180+ backend:: AbstractADType ,
181+ x:: AbstractArray{<:Real} ,
168182 dy,
169183 contexts:: Vararg{Context,C} ,
170184) where {F,C}
171185 dx = map (CartesianIndices (x)) do j
172- t1 = pushforward (f, pushforward_prep, backend, x, (basis (x, j),), contexts... )
173- dot (only (t1), dy)
186+ a = only (pushforward (f, pushforward_prep, backend, x, (basis (x, j),), contexts... ))
187+ dot (a, dy)
188+ end
189+ return dx
190+ end
191+
192+ function _pullback_via_pushforward (
193+ f:: F ,
194+ pushforward_prep:: PushforwardPrep ,
195+ backend:: AbstractADType ,
196+ x:: AbstractArray{<:Complex} ,
197+ dy,
198+ contexts:: Vararg{Context,C} ,
199+ ) where {F,C}
200+ dx = map (CartesianIndices (x)) do j
201+ a = only (pushforward (f, pushforward_prep, backend, x, (basis (x, j),), contexts... ))
202+ b = only (
203+ pushforward (f, pushforward_prep, backend, x, (im * basis (x, j),), contexts... ),
204+ )
205+ real (dot (a, dy)) + im * real (dot (b, dy))
174206 end
175207 return dx
176208end
@@ -236,12 +268,43 @@ function _pullback_via_pushforward(
236268 y,
237269 pushforward_prep:: PushforwardPrep ,
238270 backend:: AbstractADType ,
239- x:: Number ,
271+ x:: Real ,
272+ dy,
273+ contexts:: Vararg{Context,C} ,
274+ ) where {F,C}
275+ a = only (pushforward (f!, y, pushforward_prep, backend, x, (one (x),), contexts... ))
276+ dx = dot (a, dy)
277+ return dx
278+ end
279+
280+ function _pullback_via_pushforward (
281+ f!:: F ,
282+ y,
283+ pushforward_prep:: PushforwardPrep ,
284+ backend:: AbstractADType ,
285+ x:: Complex ,
240286 dy,
241287 contexts:: Vararg{Context,C} ,
242288) where {F,C}
243- t1 = pushforward (f!, y, pushforward_prep, backend, x, (one (x),), contexts... )
244- dx = dot (only (t1), dy)
289+ a = only (pushforward (f!, y, pushforward_prep, backend, x, (one (x),), contexts... ))
290+ b = only (pushforward (f!, y, pushforward_prep, backend, x, (im * one (x),), contexts... ))
291+ dx = real (dot (a, dy)) + im * real (dot (b, dy))
292+ return dx
293+ end
294+
295+ function _pullback_via_pushforward (
296+ f!:: F ,
297+ y,
298+ pushforward_prep:: PushforwardPrep ,
299+ backend:: AbstractADType ,
300+ x:: AbstractArray{<:Real} ,
301+ dy,
302+ contexts:: Vararg{Context,C} ,
303+ ) where {F,C}
304+ dx = map (CartesianIndices (x)) do j # preserve shape
305+ a = only (pushforward (f!, y, pushforward_prep, backend, x, (basis (x, j),), contexts... ))
306+ dot (a, dy)
307+ end
245308 return dx
246309end
247310
@@ -250,13 +313,18 @@ function _pullback_via_pushforward(
250313 y,
251314 pushforward_prep:: PushforwardPrep ,
252315 backend:: AbstractADType ,
253- x:: AbstractArray ,
316+ x:: AbstractArray{<:Complex} ,
254317 dy,
255318 contexts:: Vararg{Context,C} ,
256319) where {F,C}
257320 dx = map (CartesianIndices (x)) do j # preserve shape
258- t1 = pushforward (f!, y, pushforward_prep, backend, x, (basis (x, j),), contexts... )
259- dot (only (t1), dy)
321+ a = only (pushforward (f!, y, pushforward_prep, backend, x, (basis (x, j),), contexts... ))
322+ b = only (
323+ pushforward (
324+ f!, y, pushforward_prep, backend, x, (im * basis (x, j),), contexts...
325+ ),
326+ )
327+ real (dot (a, dy)) + im * real (dot (b, dy))
260328 end
261329 return dx
262330end
0 commit comments