Skip to content

Commit 7c97dd8

Browse files
committed
fix: respect array type in wrong-mode pushforward/pullback
1 parent 9a524d3 commit 7c97dd8

2 files changed

Lines changed: 22 additions & 20 deletions

File tree

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,8 @@ function _pullback_via_pushforward(
167167
dy,
168168
contexts::Vararg{Context,C},
169169
) where {F,C}
170-
t1 = pushforward(f, pushforward_prep, backend, x, (one(x),), contexts...)
171-
dx = dot(dy, only(t1))
170+
t = pushforward(f, pushforward_prep, backend, x, (one(x),), contexts...)
171+
dx = dot(dy, only(t))
172172
return dx
173173
end
174174

@@ -180,9 +180,10 @@ function _pullback_via_pushforward(
180180
dy,
181181
contexts::Vararg{Context,C},
182182
) where {F,C}
183-
dx = map(CartesianIndices(x)) do j
184-
t1 = pushforward(f, pushforward_prep, backend, x, (basis(backend, x, j),), contexts...)
185-
dot(dy, only(t1))
183+
dx = map(x, CartesianIndices(x)) do xj, j
184+
bj = basis(backend, x, j)
185+
tj = pushforward(f, pushforward_prep, backend, x, (bj,), contexts...)
186+
dot(dy, only(tj))
186187
end
187188
return dx
188189
end
@@ -252,8 +253,8 @@ function _pullback_via_pushforward(
252253
dy,
253254
contexts::Vararg{Context,C},
254255
) where {F,C}
255-
t1 = pushforward(f!, y, pushforward_prep, backend, x, (one(x),), contexts...)
256-
dx = dot(dy, only(t1))
256+
t = pushforward(f!, y, pushforward_prep, backend, x, (one(x),), contexts...)
257+
dx = dot(dy, only(t))
257258
return dx
258259
end
259260

@@ -266,11 +267,10 @@ function _pullback_via_pushforward(
266267
dy,
267268
contexts::Vararg{Context,C},
268269
) where {F,C}
269-
dx = map(CartesianIndices(x)) do j # preserve shape
270-
t1 = pushforward(
271-
f!, y, pushforward_prep, backend, x, (basis(backend, x, j),), contexts...
272-
)
273-
dot(dy, only(t1))
270+
dx = map(x, CartesianIndices(x)) do xj, j # preserve shape
271+
bj = basis(backend, x, j)
272+
tj = pushforward(f!, y, pushforward_prep, backend, x, (bj,), contexts...)
273+
dot(dy, only(tj))
274274
end
275275
return dx
276276
end

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ function _pushforward_via_pullback(
175175
dx,
176176
contexts::Vararg{Context,C},
177177
) where {F,C}
178-
t1 = pullback(f, pullback_prep, backend, x, (one(y),), contexts...)
179-
dy = dot(dx, only(t1))
178+
t = pullback(f, pullback_prep, backend, x, (one(y),), contexts...)
179+
dy = dot(dx, only(t))
180180
return dy
181181
end
182182

@@ -189,9 +189,10 @@ function _pushforward_via_pullback(
189189
dx,
190190
contexts::Vararg{Context,C},
191191
) where {F,C}
192-
dy = map(CartesianIndices(y)) do i
193-
t1 = pullback(f, pullback_prep, backend, x, (basis(backend, y, i),), contexts...)
194-
dot(dx, only(t1))
192+
dy = map(y, CartesianIndices(y)) do yi, i
193+
bi = basis(backend, y, i)
194+
ti = pullback(f, pullback_prep, backend, x, (bi,), contexts...)
195+
dot(dx, only(ti))
195196
end
196197
return dy
197198
end
@@ -261,9 +262,10 @@ function _pushforward_via_pullback(
261262
dx,
262263
contexts::Vararg{Context,C},
263264
) where {F,C}
264-
dy = map(CartesianIndices(y)) do i # preserve shape
265-
t1 = pullback(f!, y, pullback_prep, backend, x, (basis(backend, y, i),), contexts...)
266-
dot(dx, only(t1))
265+
dy = map(y, CartesianIndices(y)) do yi, i # preserve shape
266+
bi = basis(backend, y, i)
267+
ti = pullback(f!, y, pullback_prep, backend, x, (bi,), contexts...)
268+
dot(dx, only(ti))
267269
end
268270
return dy
269271
end

0 commit comments

Comments
 (0)