11# # Pullback
22
3- struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG}
3+ struct MooncakeOneArgPullbackPrep{SIG, Tcache, N} <: DI.PullbackPrep{SIG}
44 _sig:: Val{SIG}
55 cache:: Tcache
6- dy_righttype:: DY
76 args_to_zero:: NTuple{N, Bool}
87end
98
@@ -15,15 +14,13 @@ function DI.prepare_pullback_nokwarg(
1514 cache = prepare_pullback_cache (
1615 f, x, map (DI. unwrap, contexts)... ; config
1716 )
18- y = f (x, map (DI. unwrap, contexts)... )
19- dy_righttype = zero_tangent (y)
2017 contexts_tup_false = map (_ -> false , contexts)
2118 args_to_zero = (
2219 false , # f
2320 true , # x
2421 contexts_tup_false... ,
2522 )
26- prep = MooncakeOneArgPullbackPrep (_sig, cache, dy_righttype, args_to_zero)
23+ prep = MooncakeOneArgPullbackPrep (_sig, cache, args_to_zero)
2724 return prep
2825end
2926
@@ -37,10 +34,8 @@ function DI.value_and_pullback(
3734 ) where {F, Y, C}
3835 DI. check_prep (f, prep, backend, x, ty, contexts... )
3936 dy = only (ty)
40- dy_righttype = dy isa tangent_type (Y) ? dy : _copy_to_output!! (prep. dy_righttype, dy)
4137 new_y, (_, new_dx) = value_and_pullback!! (
42- prep. cache, dy_righttype, f, x, map (DI. unwrap, contexts)... ;
43- prep. args_to_zero
38+ prep. cache, dy, f, x, map (DI. unwrap, contexts)... ; prep. args_to_zero
4439 )
4540 return new_y, (_copy_output (new_dx),)
4641end
@@ -55,11 +50,8 @@ function DI.value_and_pullback(
5550 ) where {F, Y, C}
5651 DI. check_prep (f, prep, backend, x, ty, contexts... )
5752 ys_and_tx = map (ty) do dy
58- dy_righttype =
59- dy isa tangent_type (Y) ? dy : _copy_to_output!! (prep. dy_righttype, dy)
6053 y, (_, new_dx) = value_and_pullback!! (
61- prep. cache, dy_righttype, f, x, map (DI. unwrap, contexts)... ;
62- prep. args_to_zero
54+ prep. cache, dy, f, x, map (DI. unwrap, contexts)... ; prep. args_to_zero
6355 )
6456 y, _copy_output (new_dx)
6557 end
0 commit comments