11# # Pullback
22
3- struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY} <: DI.PullbackPrep{SIG}
3+ struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N } <: DI.PullbackPrep{SIG}
44 _sig:: Val{SIG}
55 cache:: Tcache
66 dy_righttype:: DY
7+ args_to_zero:: NTuple{N, Bool}
78end
89
910function DI. prepare_pullback_nokwarg (
@@ -16,7 +17,12 @@ function DI.prepare_pullback_nokwarg(
1617 )
1718 y = f (x, map (DI. unwrap, contexts)... )
1819 dy_righttype = zero_tangent (y)
19- prep = MooncakeOneArgPullbackPrep (_sig, cache, dy_righttype)
20+ args_to_zero = (
21+ false , # f
22+ true , # x
23+ map (_ -> false , contexts)... ,
24+ )
25+ prep = MooncakeOneArgPullbackPrep (_sig, cache, dy_righttype, args_to_zero)
2026 return prep
2127end
2228
@@ -32,7 +38,8 @@ function DI.value_and_pullback(
3238 dy = only (ty)
3339 dy_righttype = dy isa tangent_type (Y) ? dy : _copy_to_output!! (prep. dy_righttype, dy)
3440 new_y, (_, new_dx) = value_and_pullback!! (
35- prep. cache, dy_righttype, f, x, map (DI. unwrap, contexts)...
41+ prep. cache, dy_righttype, f, x, map (DI. unwrap, contexts)... ;
42+ prep. args_to_zero
3643 )
3744 return new_y, (_copy_output (new_dx),)
3845end
@@ -50,7 +57,8 @@ function DI.value_and_pullback(
5057 dy_righttype =
5158 dy isa tangent_type (Y) ? dy : _copy_to_output!! (prep. dy_righttype, dy)
5259 y, (_, new_dx) = value_and_pullback!! (
53- prep. cache, dy_righttype, f, x, map (DI. unwrap, contexts)...
60+ prep. cache, dy_righttype, f, x, map (DI. unwrap, contexts)... ;
61+ prep. args_to_zero
5462 )
5563 y, _copy_output (new_dx)
5664 end
101109
102110# # Gradient
103111
104- struct MooncakeGradientPrep{SIG, Tcache} <: DI.GradientPrep{SIG}
112+ struct MooncakeGradientPrep{SIG, Tcache, N } <: DI.GradientPrep{SIG}
105113 _sig:: Val{SIG}
106114 cache:: Tcache
115+ args_to_zero:: NTuple{N, Bool}
107116end
108117
109118function DI. prepare_gradient_nokwarg (
@@ -114,7 +123,12 @@ function DI.prepare_gradient_nokwarg(
114123 cache = prepare_gradient_cache (
115124 f, x, map (DI. unwrap, contexts)... ; config. debug_mode, config. silence_debug_messages
116125 )
117- prep = MooncakeGradientPrep (_sig, cache)
126+ args_to_zero = (
127+ false , # f
128+ true , # x
129+ map (_ -> false , contexts)... ,
130+ )
131+ prep = MooncakeGradientPrep (_sig, cache, args_to_zero)
118132 return prep
119133end
120134
@@ -126,7 +140,10 @@ function DI.value_and_gradient(
126140 contexts:: Vararg{DI.Context, C} ,
127141 ) where {F, C}
128142 DI. check_prep (f, prep, backend, x, contexts... )
129- y, (_, new_grad) = value_and_gradient!! (prep. cache, f, x, map (DI. unwrap, contexts)... )
143+ y, (_, new_grad) = value_and_gradient!! (
144+ prep. cache, f, x, map (DI. unwrap, contexts)... ;
145+ prep. args_to_zero
146+ )
130147 return y, _copy_output (new_grad)
131148end
132149
@@ -139,7 +156,10 @@ function DI.value_and_gradient!(
139156 contexts:: Vararg{DI.Context, C} ,
140157 ) where {F, C}
141158 DI. check_prep (f, prep, backend, x, contexts... )
142- y, (_, new_grad) = value_and_gradient!! (prep. cache, f, x, map (DI. unwrap, contexts)... )
159+ y, (_, new_grad) = value_and_gradient!! (
160+ prep. cache, f, x, map (DI. unwrap, contexts)... ;
161+ prep. args_to_zero
162+ )
143163 copyto! (grad, new_grad)
144164 return y, grad
145165end
0 commit comments