Skip to content

Commit 22bd82e

Browse files
committed
fix: use zero as example
1 parent 35d9a60 commit 22bd82e

2 files changed

Lines changed: 4 additions & 28 deletions

File tree

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -285,13 +285,7 @@ function _prepare_pullback_aux(
285285
contexts::Vararg{Context,C};
286286
) where {F,C}
287287
_sig = signature(f, backend, x, ty, contexts...; strict)
288-
dx = if x isa Number
289-
oneunit(x)
290-
elseif isempty(x)
291-
zero(x)
292-
else
293-
basis(x, first(CartesianIndices(x)))
294-
end
288+
dx = zero(x)
295289
pushforward_prep = prepare_pushforward_nokwarg(
296290
strict, f, backend, x, (dx,), contexts...
297291
)
@@ -309,13 +303,7 @@ function _prepare_pullback_aux(
309303
contexts::Vararg{Context,C};
310304
) where {F,C}
311305
_sig = signature(f!, y, backend, x, ty, contexts...; strict)
312-
dx = if x isa Number
313-
oneunit(x)
314-
elseif isempty(x)
315-
zero(x)
316-
else
317-
basis(x, first(CartesianIndices(x)))
318-
end
306+
dx = zero(x)
319307
pushforward_prep = prepare_pushforward_nokwarg(
320308
strict, f!, y, backend, x, (dx,), contexts...
321309
)

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -290,13 +290,7 @@ function _prepare_pushforward_aux(
290290
) where {F,C}
291291
_sig = signature(f, backend, x, tx, contexts...; strict)
292292
y = f(x, map(unwrap, contexts)...)
293-
dy = if y isa Number
294-
oneunit(y)
295-
elseif isempty(y)
296-
zero(y)
297-
else
298-
basis(y, first(CartesianIndices(y)))
299-
end
293+
dy = zero(y)
300294
pullback_prep = prepare_pullback_nokwarg(strict, f, backend, x, (dy,), contexts...)
301295
return PullbackPushforwardPrep(_sig, pullback_prep)
302296
end
@@ -312,13 +306,7 @@ function _prepare_pushforward_aux(
312306
contexts::Vararg{Context,C};
313307
) where {F,C}
314308
_sig = signature(f!, y, backend, x, tx, contexts...; strict)
315-
dy = if y isa Number
316-
oneunit(y)
317-
elseif isempty(y)
318-
zero(y)
319-
else
320-
basis(y, first(CartesianIndices(y)))
321-
end
309+
dy = zero(y)
322310
pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...)
323311
return PullbackPushforwardPrep(_sig, pullback_prep)
324312
end

0 commit comments

Comments
 (0)