We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 8ac3c79 commit 0955815Copy full SHA for 0955815
2 files changed
DifferentiationInterface/src/first_order/pullback.jl
@@ -168,9 +168,11 @@ function _pullback_via_pushforward(
168
dy,
169
contexts::Vararg{Context,C},
170
) where {F,C}
171
- dx = map(x, CartesianIndices(x)) do xj, j
+ ind = CartesianIndices(x)
172
+ T = typeof(similar(x, eltype(ind)))
173
+ dx = map(x, T(ind)) do xj, j
174
t1 = pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...)
- convert(eltype(x), dot(only(t1), dy))
175
+ dot(only(t1), dy)
176
end
177
return dx
178
@@ -254,9 +256,11 @@ function _pullback_via_pushforward(
254
256
255
257
258
- dx = map(x, CartesianIndices(x)) do xj, j # preserve shape
259
260
261
+ dx = map(x, T(ind)) do xj, j # preserve shape
262
t1 = pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)
263
264
265
266
DifferentiationInterface/src/first_order/pushforward.jl
@@ -171,9 +171,11 @@ function _pushforward_via_pullback(
dx,
- dy = map(y, CartesianIndices(y)) do yi, i
+ ind = CartesianIndices(y)
+ T = typeof(similar(y, eltype(ind)))
+ dy = map(y, T(ind)) do yi, i
t1 = pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...)
- convert(eltype(y), dot(only(t1), dx))
+ dot(only(t1), dx)
179
180
return dy
181
@@ -243,9 +245,11 @@ function _pushforward_via_pullback(
243
245
244
246
247
- dy = map(y, CartesianIndices(y)) do yi, i # preserve shape
248
249
250
+ dy = map(y, T(ind)) do yi, i # preserve shape
251
t1 = pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...)
252
253
0 commit comments