@@ -24,7 +24,7 @@ function DI.prepare_pullback_nokwarg(
2424 map (DI. unwrap, contexts)... ;
2525 config,
2626 )
27- dy_backup_after = zero_tangent (y )
27+ dy_backup = zero_tangent_or_primal (y, backend )
2828 contexts_tup_false = map (_ -> false , contexts)
2929 args_to_zero = (
3030 false , # call_and_return
@@ -34,7 +34,7 @@ function DI.prepare_pullback_nokwarg(
3434 contexts_tup_false... ,
3535 )
3636 prep = MooncakeTwoArgPullbackPrep (
37- _sig, cache, dy_backup_after , args_to_zero
37+ _sig, cache, dy_backup , args_to_zero
3838 )
3939 return prep
4040end
@@ -51,11 +51,11 @@ function DI.value_and_pullback(
5151 DI. check_prep (f!, y, prep, backend, x, ty, contexts... )
5252 dy = only (ty)
5353 # Prepare cotangent to add after the forward pass.
54- dy_backup_after = copyto! (prep. dy_backup, dy)
54+ dy_backup = copyto! (prep. dy_backup, dy)
5555 # Run the reverse-pass and return the results.
5656 y_after, (_, _, _, dx) = value_and_pullback!! (
5757 prep. cache,
58- dy_backup_after ,
58+ dy_backup ,
5959 call_and_return,
6060 f!,
6161 y,
@@ -78,10 +78,10 @@ function DI.value_and_pullback(
7878 ) where {F, C}
7979 DI. check_prep (f!, y, prep, backend, x, ty, contexts... )
8080 tx = map (ty) do dy
81- dy_backup_after = copyto! (prep. dy_backup, dy)
81+ dy_backup = copyto! (prep. dy_backup, dy)
8282 y_after, (_, _, _, dx) = value_and_pullback!! (
8383 prep. cache,
84- dy_backup_after ,
84+ dy_backup ,
8585 call_and_return,
8686 f!,
8787 y,
0 commit comments