1+ # # Pullback
2+
# Preparation object for the pullback operator with the Mooncake backend.
#
# Fields:
# - `cache`: handed as the first argument to `value_and_pullback!!`.
# - `dy_righttype`: destination passed to `copyto!!` when the cotangent supplied by
#   the caller is not already of the expected tangent type.
struct MooncakeOneArgPullbackPrep{Tcache,DY} <: DI.PullbackPrep
    cache::Tcache
    dy_righttype::DY
end
57
68function DI. prepare_pullback (
7- f, backend:: AutoMooncake , x, ty:: NTuple , contexts:: Vararg{DI.Context,C}
8- ) where {C}
9+ f:: F , backend:: AutoMooncake , x, ty:: NTuple , contexts:: Vararg{DI.Context,C}
10+ ) where {F, C}
911 config = get_config (backend)
1012 cache = prepare_pullback_cache (
1113 f, x, map (DI. unwrap, contexts)... ; config. debug_mode, config. silence_debug_messages
@@ -27,34 +29,34 @@ function DI.value_and_pullback(
2729) where {F,Y,C}
2830 dy = only (ty)
2931 dy_righttype = dy isa tangent_type (Y) ? dy : copyto!! (prep. dy_righttype, dy)
30- new_y, (_, new_dx) = Mooncake . value_and_pullback!! (
32+ new_y, (_, new_dx) = value_and_pullback!! (
3133 prep. cache, dy_righttype, f, x, map (DI. unwrap, contexts)...
3234 )
33- return new_y, (copy (new_dx),)
35+ return new_y, (mycopy (new_dx),)
3436end
3537
# In-place pullback for exactly one cotangent (`tx` and `ty` are both 1-tuples).
#
# Delegates to the out-of-place `DI.value_and_pullback`, then copies the resulting
# cotangent into the single buffer held by `tx`. Returns `(y, tx)` where `tx` is the
# very tuple that was passed in.
function DI.value_and_pullback!(
    f::F,
    tx::NTuple{1},
    prep::MooncakeOneArgPullbackPrep{Y},
    backend::AutoMooncake,
    x,
    ty::NTuple{1},
    contexts::Vararg{DI.Context,C},
) where {F,Y,C}
    y, fresh_tx = DI.value_and_pullback(f, prep, backend, x, ty, contexts...)
    # Only the first entry of the result is used, as there is a single output buffer.
    dx_buffer = only(tx)
    copyto!(dx_buffer, first(fresh_tx))
    return y, tx
end
4951
5052function DI. value_and_pullback (
51- f,
53+ f:: F ,
5254 prep:: MooncakeOneArgPullbackPrep ,
5355 backend:: AutoMooncake ,
5456 x,
5557 ty:: NTuple ,
5658 contexts:: Vararg{DI.Context,C} ,
57- ) where {C}
59+ ) where {F, C}
5860 ys_and_tx = map (ty) do dy
5961 y, tx = DI. value_and_pullback (f, prep, backend, x, (dy,), contexts... )
6062 y, only (tx)
@@ -65,14 +67,14 @@ function DI.value_and_pullback(
6567end
6668
6769function DI. value_and_pullback! (
68- f,
70+ f:: F ,
6971 tx:: NTuple ,
7072 prep:: MooncakeOneArgPullbackPrep ,
7173 backend:: AutoMooncake ,
7274 x,
7375 ty:: NTuple ,
7476 contexts:: Vararg{DI.Context,C} ,
75- ) where {C}
77+ ) where {F, C}
7678 ys = map (tx, ty) do dx, dy
7779 y, _ = DI. value_and_pullback! (f, (dx,), prep, backend, x, (dy,), contexts... )
7880 y
@@ -82,24 +84,85 @@ function DI.value_and_pullback!(
8284end
8385
# Pullback without the primal value: evaluates `DI.value_and_pullback` and keeps only
# the second component of its result.
function DI.pullback(
    f::F,
    prep::MooncakeOneArgPullbackPrep,
    backend::AutoMooncake,
    x,
    ty::NTuple,
    contexts::Vararg{DI.Context,C},
) where {F,C}
    _, tx = DI.value_and_pullback(f, prep, backend, x, ty, contexts...)
    return tx
end
9496
# In-place pullback without the primal value: forwards everything to
# `DI.value_and_pullback!` and returns only the second component of its result.
function DI.pullback!(
    f::F,
    tx::NTuple,
    prep::MooncakeOneArgPullbackPrep,
    backend::AutoMooncake,
    x,
    ty::NTuple,
    contexts::Vararg{DI.Context,C},
) where {F,C}
    _, filled_tx = DI.value_and_pullback!(f, tx, prep, backend, x, ty, contexts...)
    return filled_tx
end
108+
109+ # # Gradient
110+
# Preparation object for the gradient operator with the Mooncake backend.
#
# Fields:
# - `cache`: result of `prepare_pullback_cache`, handed as the first argument to
#   `value_and_gradient!!`.
struct MooncakeGradientPrep{Tcache} <: DI.GradientPrep
    cache::Tcache
end
114+
# Build a `MooncakeGradientPrep` for `f` at `x` (plus unwrapped contexts).
#
# The cache is created by `prepare_pullback_cache` using the debug settings read from
# the backend configuration. The gradient is then evaluated once with the new
# preparation and its result discarded (presumably a warm-up run; confirm intent).
function DI.prepare_gradient(
    f::F, backend::AutoMooncake, x, contexts::Vararg{DI.Context,C}
) where {F,C}
    config = get_config(backend)
    unwrapped = map(DI.unwrap, contexts)
    cache = prepare_pullback_cache(
        f,
        x,
        unwrapped...;
        debug_mode=config.debug_mode,
        silence_debug_messages=config.silence_debug_messages,
    )
    prep = MooncakeGradientPrep(cache)
    DI.value_and_gradient(f, prep, backend, x, contexts...)
    return prep
end
126+
# Primal value and gradient of `f` at `x`.
#
# Calls `value_and_gradient!!` with the prepared cache and the unwrapped contexts,
# keeps the second entry of the returned tangent collection (the first one is
# ignored), and returns a `mycopy` of it next to the primal value.
function DI.value_and_gradient(
    f::F, prep::MooncakeGradientPrep, ::AutoMooncake, x, contexts::Vararg{DI.Context,C}
) where {F,C}
    unwrapped = map(DI.unwrap, contexts)
    y, tangents = value_and_gradient!!(prep.cache, f, x, unwrapped...)
    _, new_grad = tangents
    return y, mycopy(new_grad)
end
133+
# In-place variant of the gradient computation.
#
# Calls `value_and_gradient!!` with the prepared cache and the unwrapped contexts,
# copies the second entry of the returned tangent collection into `grad` with
# `copyto!`, and returns `(y, grad)`.
function DI.value_and_gradient!(
    f::F,
    grad,
    prep::MooncakeGradientPrep,
    ::AutoMooncake,
    x,
    contexts::Vararg{DI.Context,C},
) where {F,C}
    unwrapped = map(DI.unwrap, contexts)
    y, tangents = value_and_gradient!!(prep.cache, f, x, unwrapped...)
    _, new_grad = tangents
    copyto!(grad, new_grad)
    return y, grad
end
146+
# Gradient without the primal value: evaluates `DI.value_and_gradient` and returns
# only the second component of its result.
function DI.gradient(
    f::F,
    prep::MooncakeGradientPrep,
    backend::AutoMooncake,
    x,
    contexts::Vararg{DI.Context,C},
) where {F,C}
    value_and_grad = DI.value_and_gradient(f, prep, backend, x, contexts...)
    return value_and_grad[2]
end
157+
# In-place gradient without the primal value: lets `DI.value_and_gradient!` fill
# `grad`, ignores what it returns, and hands back the same `grad` object.
function DI.gradient!(
    f::F,
    grad,
    prep::MooncakeGradientPrep,
    backend::AutoMooncake,
    x,
    contexts::Vararg{DI.Context,C},
) where {F,C}
    # Called for its side effect on `grad`; the returned pair is discarded.
    _ = DI.value_and_gradient!(f, grad, prep, backend, x, contexts...)
    return grad
end
0 commit comments