@@ -50,7 +50,10 @@ struct ForwardOverReverseHVPPrep{G,E<:PushforwardPrep} <: HVPPrep
5050 outer_pushforward_prep:: E
5151end
5252
# Preparation result for the reverse-over-forward HVP mode.
# Both fields are parametric so the struct stays concretely typed.
struct ReverseOverForwardHVPPrep{PF,GP} <: HVPPrep
    # closure built during preparation: pushforward of `f` along a given tangent
    inner_pushforward::PF
    # result of `prepare_gradient` applied to `inner_pushforward` with the outer backend
    outer_gradient_prep::GP
end
5457
5558struct ReverseOverReverseHVPPrep{G,E<: PullbackPrep } <: HVPPrep
5659 inner_gradient:: G
@@ -111,9 +114,19 @@ function _prepare_hvp_aux(
111114 tx:: Tangents ,
112115 contexts:: Vararg{Context,C} ,
113116) where {F,C}
117+ rewrap = Rewrap (contexts... )
114118 # gradient of pushforward
115- # uses dx in the closure so it can't be prepared
116- return ReverseOverForwardHVPPrep ()
# Pushforward of `f` at `_x` along the single tangent `_dx`, evaluated with the
# inner backend. Contexts arrive unannotated and are re-annotated through `rewrap`
# before being forwarded to `pushforward`.
function inner_pushforward(_x, _dx, unannotated_contexts...)
    rewrapped_contexts = rewrap(unannotated_contexts...)
    inner_backend = nested(inner(backend))
    # one tangent in, so `pushforward` yields exactly one output tangent
    return only(pushforward(f, inner_backend, _x, Tangents(_dx), rewrapped_contexts...))
end
# Prepare the outer gradient with the same argument list that is used at call time:
# `inner_pushforward(x, dx, contexts...)`, where the tangent `dx` is supplied as a
# `Constant` context. Preparing without it would target a different signature than
# the one later passed to `gradient` / `gradient!`.
# NOTE(review): uses the first tangent as a representative; assumes `tx` is non-empty
# and that all tangents share the same type — confirm.
outer_gradient_prep = prepare_gradient(
    inner_pushforward, outer(backend), x, Constant(first(tx.d)), contexts...
)
return ReverseOverForwardHVPPrep(inner_pushforward, outer_gradient_prep)
117130end
118131
119132function _prepare_hvp_aux (
@@ -168,23 +181,15 @@ end
168181
# Hessian-vector products in reverse-over-forward mode.
#
# For every tangent `dx` in `tx`, differentiates `prep.inner_pushforward` with respect
# to `x` using the outer backend, passing `dx` as a `Constant` context. The results are
# collected by `map` over `tx` and returned.
#
# `f` is not called directly here: it is already captured by the closure stored in
# `prep.inner_pushforward`.
function hvp(
    f::F,
    prep::ReverseOverForwardHVPPrep,
    backend::AbstractADType,
    x,
    tx::Tangents,
    contexts::Vararg{Context,C},
) where {F,C}
    @compat (; inner_pushforward, outer_gradient_prep) = prep
    tg = map(tx) do dx
        # Reuse the prepared outer gradient (same call shape as in `hvp!`) instead of
        # running an unprepared gradient for every tangent.
        gradient(
            inner_pushforward,
            outer_gradient_prep,
            outer(backend),
            x,
            Constant(dx),
            contexts...,
        )
    end
    return tg
end
@@ -234,23 +239,23 @@ end
# In-place Hessian-vector products in reverse-over-forward mode.
#
# For each index shared by `tx.d` and `tg.d`, writes into `tg.d[k]` the gradient of
# `prep.inner_pushforward` with respect to `x`, computed with the outer backend and the
# stored gradient preparation, with `tx.d[k]` passed as a `Constant` context.
# Returns `tg`.
#
# `f` is not called directly here: it is already captured by the closure stored in
# `prep.inner_pushforward`.
function hvp!(
    f::F,
    tg::Tangents,
    prep::ReverseOverForwardHVPPrep,
    backend::AbstractADType,
    x,
    tx::Tangents,
    contexts::Vararg{Context,C},
) where {F,C}
    pushforward_closure = prep.inner_pushforward
    gradient_prep = prep.outer_gradient_prep
    for k in eachindex(tx.d, tg.d)
        seed = tx.d[k]
        result = tg.d[k]
        gradient!(
            pushforward_closure,
            result,
            gradient_prep,
            outer(backend),
            x,
            Constant(seed),
            contexts...,
        )
    end
    return tg
end
0 commit comments