@@ -102,33 +102,33 @@ function prepare_hvp(f::F, backend::AbstractADType, x, dx) where {F}
102102end
103103
104104function prepare_hvp (f:: F , backend:: SecondOrder , x, dx) where {F}
105- return prepare_hvp (f, backend, x, dx, hvp_mode (backend))
105+ return prepare_hvp_aux (f, backend, x, dx, hvp_mode (backend))
106106end
107107
108- function prepare_hvp (f:: F , backend:: SecondOrder , x, dx, :: ForwardOverForward ) where {F}
108+ function prepare_hvp_aux (f:: F , backend:: SecondOrder , x, dx, :: ForwardOverForward ) where {F}
109109 # pushforward of many pushforwards in theory, but pushforward of gradient in practice
110110 inner_gradient = InnerGradient (f, nested (inner (backend)))
111111 outer_pushforward_extras = prepare_pushforward (inner_gradient, outer (backend), x, dx)
112112 return ForwardOverForwardHVPExtras (inner_gradient, outer_pushforward_extras)
113113end
114114
115- function prepare_hvp (f:: F , backend:: SecondOrder , x, dx, :: ForwardOverReverse ) where {F}
115+ function prepare_hvp_aux (f:: F , backend:: SecondOrder , x, dx, :: ForwardOverReverse ) where {F}
116116 # pushforward of gradient
117117 inner_gradient = InnerGradient (f, nested (inner (backend)))
118118 outer_pushforward_extras = prepare_pushforward (inner_gradient, outer (backend), x, dx)
119119 return ForwardOverReverseHVPExtras (inner_gradient, outer_pushforward_extras)
120120end
121121
122- function prepare_hvp (f:: F , backend:: SecondOrder , x, dx, :: ReverseOverForward ) where {F}
122+ function prepare_hvp_aux (f:: F , backend:: SecondOrder , x, dx, :: ReverseOverForward ) where {F}
123123 # gradient of pushforward
124124 # uses dx in the closure so it can't be stored
125125 inner_pushforward = InnerPushforwardFixedSeed (f, nested (inner (backend)), dx)
126126 outer_gradient_extras = prepare_gradient (inner_pushforward, outer (backend), x)
127127 return ReverseOverForwardHVPExtras (outer_gradient_extras)
128128end
129129
130- function prepare_hvp (f:: F , backend:: SecondOrder , x, dx, :: ReverseOverReverse ) where {F}
131- # pullback of the gradient
130+ function prepare_hvp_aux (f:: F , backend:: SecondOrder , x, dx, :: ReverseOverReverse ) where {F}
131+ # pullback of gradient
132132 inner_gradient = InnerGradient (f, nested (inner (backend)))
133133 outer_pullback_extras = prepare_pullback (inner_gradient, outer (backend), x, dx)
134134 return ReverseOverReverseHVPExtras (inner_gradient, outer_pullback_extras)
@@ -149,16 +149,58 @@ end
149149
150150# ## Batched
151151
152- function prepare_hvp_batched (f:: F , backend:: AbstractADType , x, dx:: Batch{B} ) where {F,B}
153- return prepare_hvp (f, backend, x, first (dx. elements))
152+ function prepare_hvp_batched (f:: F , backend:: AbstractADType , x, dx:: Batch ) where {F}
153+ return prepare_hvp_batched (f, SecondOrder (backend, backend), x, dx)
154+ end
155+
156+ function prepare_hvp_batched (f:: F , backend:: SecondOrder , x, dx:: Batch ) where {F}
157+ return prepare_hvp_batched_aux (f, backend, x, dx, hvp_mode (backend))
158+ end
159+
160+ function prepare_hvp_batched_aux (
161+ f:: F , backend:: SecondOrder , x, dx:: Batch , :: ForwardOverForward
162+ ) where {F}
163+ # batched pushforward of gradient
164+ inner_gradient = InnerGradient (f, nested (inner (backend)))
165+ outer_pushforward_extras = prepare_pushforward_batched (
166+ inner_gradient, outer (backend), x, dx
167+ )
168+ return ForwardOverForwardHVPExtras (inner_gradient, outer_pushforward_extras)
169+ end
170+
171+ function prepare_hvp_batched_aux (
172+ f:: F , backend:: SecondOrder , x, dx:: Batch , :: ForwardOverReverse
173+ ) where {F}
174+ # batched pushforward of gradient
175+ inner_gradient = InnerGradient (f, nested (inner (backend)))
176+ outer_pushforward_extras = prepare_pushforward_batched (
177+ inner_gradient, outer (backend), x, dx
178+ )
179+ return ForwardOverReverseHVPExtras (inner_gradient, outer_pushforward_extras)
180+ end
181+
182+ function prepare_hvp_batched_aux (
183+ f:: F , backend:: SecondOrder , x, dx:: Batch , :: ReverseOverForward
184+ ) where {F}
185+ # TODO : batched version replacing the outer gradient with a pullback
186+ return prepare_hvp_aux (f, backend, x, first (dx. elements), ReverseOverForward ())
187+ end
188+
189+ function prepare_hvp_batched_aux (
190+ f:: F , backend:: SecondOrder , x, dx:: Batch , :: ReverseOverReverse
191+ ) where {F}
192+ # batched pullback of gradient
193+ inner_gradient = InnerGradient (f, nested (inner (backend)))
194+ outer_pullback_extras = prepare_pullback_batched (inner_gradient, outer (backend), x, dx)
195+ return ReverseOverReverseHVPExtras (inner_gradient, outer_pullback_extras)
154196end
155197
156198# ## Batched, same point
157199
158200function prepare_hvp_batched_same_point (
159- f:: F , backend:: AbstractADType , x, dx:: Batch{B} , extras:: HVPExtras
160- ) where {F,B }
161- return prepare_hvp_same_point (f, backend, x, first (dx . elements), extras)
201+ f:: F , backend:: AbstractADType , x, dx:: Batch , extras:: HVPExtras
202+ ) where {F}
203+ return extras
162204end
163205
164206# # One argument
@@ -241,27 +283,89 @@ end
241283
242284# ## Batched
243285
244- function hvp_batched (f:: F , backend:: AbstractADType , x, dx, extras:: HVPExtras ) where {F}
286+ function hvp_batched (
287+ f:: F , backend:: AbstractADType , x, dx:: Batch , extras:: HVPExtras
288+ ) where {F}
245289 return hvp_batched (f, SecondOrder (backend, backend), x, dx, extras)
246290end
247291
248292function hvp_batched (
249- f:: F , backend:: SecondOrder , x, dx:: Batch{B} , extras:: HVPExtras
293+ f:: F , backend:: SecondOrder , x, dx:: Batch , extras:: ForwardOverForwardHVPExtras
294+ ) where {F}
295+ @compat (; inner_gradient, outer_pushforward_extras) = extras
296+ return pushforward_batched (
297+ inner_gradient, outer (backend), x, dx, outer_pushforward_extras
298+ )
299+ end
300+
301+ function hvp_batched (
302+ f:: F , backend:: SecondOrder , x, dx:: Batch , extras:: ForwardOverReverseHVPExtras
303+ ) where {F}
304+ @compat (; inner_gradient, outer_pushforward_extras) = extras
305+ return pushforward_batched (
306+ inner_gradient, outer (backend), x, dx, outer_pushforward_extras
307+ )
308+ end
309+
310+ function hvp_batched (
311+ f:: F , backend:: SecondOrder , x, dx:: Batch{B} , extras:: ReverseOverForwardHVPExtras
250312) where {F,B}
251- dg_elements = ntuple (Val (B)) do l
252- hvp (f, backend, x, dx. elements[l ], extras)
313+ dg_elements = ntuple (Val (B)) do b
314+ hvp (f, backend, x, dx. elements[b ], extras)
253315 end
254316 return Batch (dg_elements)
255317end
256318
257- function hvp_batched! (f:: F , dg, backend:: AbstractADType , x, dx, extras:: HVPExtras ) where {F}
319+ function hvp_batched (
320+ f:: F , backend:: SecondOrder , x, dx:: Batch , extras:: ReverseOverReverseHVPExtras
321+ ) where {F}
322+ @compat (; inner_gradient, outer_pullback_extras) = extras
323+ return pullback_batched (inner_gradient, outer (backend), x, dx, outer_pullback_extras)
324+ end
325+
326+ function hvp_batched! (
327+ f:: F , dg:: Batch , backend:: AbstractADType , x, dx:: Batch , extras:: HVPExtras
328+ ) where {F}
258329 return hvp_batched! (f, dg, SecondOrder (backend, backend), x, dx, extras)
259330end
260331
261332function hvp_batched! (
262- f:: F , dg:: Batch{B} , backend:: SecondOrder , x, dx:: Batch{B} , extras:: HVPExtras
333+ f:: F , dg:: Batch , backend:: SecondOrder , x, dx:: Batch , extras:: ForwardOverForwardHVPExtras
334+ ) where {F}
335+ @compat (; inner_gradient, outer_pushforward_extras) = extras
336+ return pushforward_batched! (
337+ inner_gradient, dg, outer (backend), x, dx, outer_pushforward_extras
338+ )
339+ end
340+
341+ function hvp_batched! (
342+ f:: F , dg:: Batch , backend:: SecondOrder , x, dx:: Batch , extras:: ForwardOverReverseHVPExtras
343+ ) where {F}
344+ @compat (; inner_gradient, outer_pushforward_extras) = extras
345+ return pushforward_batched! (
346+ inner_gradient, dg, outer (backend), x, dx, outer_pushforward_extras
347+ )
348+ end
349+
350+ function hvp_batched! (
351+ f:: F ,
352+ dg:: Batch{B} ,
353+ backend:: SecondOrder ,
354+ x,
355+ dx:: Batch{B} ,
356+ extras:: ReverseOverForwardHVPExtras ,
263357) where {F,B}
264- for l in 1 : B
265- hvp! (f, dg. elements[l ], backend, x, dx. elements[l ], extras)
358+ for b in eachindex (dg . elements, dx . elements)
359+ hvp! (f, dg. elements[b ], backend, x, dx. elements[b ], extras)
266360 end
361+ return dg
362+ end
363+
364+ function hvp_batched! (
365+ f:: F , dg:: Batch , backend:: SecondOrder , x, dx:: Batch , extras:: ReverseOverReverseHVPExtras
366+ ) where {F}
367+ @compat (; inner_gradient, outer_pullback_extras) = extras
368+ return pullback_batched! (
369+ inner_gradient, dg, outer (backend), x, dx, outer_pullback_extras
370+ )
267371end
0 commit comments