@@ -12,8 +12,6 @@ Create an `extras` object that can be given to [`pullback`](@ref) and its varian
1212"""
1313function prepare_pullback end
1414
15- function prepare_pullback_batched end
16-
1715"""
1816 prepare_pullback_same_point(f, backend, x, dy) -> extras_same
1917 prepare_pullback_same_point(f!, y, backend, x, dy) -> extras_same
@@ -26,8 +24,6 @@ Create an `extras_same` object that can be given to [`pullback`](@ref) and its v
2624"""
2725function prepare_pullback_same_point end
2826
29- function prepare_pullback_batched_same_point end
30-
3127"""
3228 value_and_pullback(f, backend, x, dy, [extras]) -> (y, dx)
3329 value_and_pullback(f!, y, backend, x, dy, [extras]) -> (y, dx)
@@ -55,8 +51,6 @@ Compute the pullback of the function `f` at point `x` with seed `dy`.
5551"""
5652function pullback end
5753
58- function pullback_batched end
59-
6054"""
6155 pullback!(f, dx, backend, x, dy, [extras]) -> dx
6256 pullback!(f!, y, dx, backend, x, dy, [extras]) -> dx
@@ -65,8 +59,6 @@ Compute the pullback of the function `f` at point `x` with seed `dy`, overwritin
6559"""
6660function pullback! end
6761
68- function pullback_batched! end
69-
7062# # Preparation
7163
7264# ## Extras types
@@ -84,7 +76,7 @@ struct PushforwardPullbackExtras{E} <: PullbackExtras
8476 pushforward_extras:: E
8577end
8678
87- # # Standard
79+ # # Different point
8880
8981function prepare_pullback (f:: F , backend:: AbstractADType , x, dy) where {F}
9082 return prepare_pullback_aux (f, backend, x, dy, pullback_performance (backend))
@@ -114,7 +106,7 @@ function prepare_pullback_aux(f!, y, backend, x, dy, ::PullbackFast)
114106 throw (MissingBackendError (backend))
115107end
116108
117- # ## Standard, same point
109+ # ## Same point
118110
119111function prepare_pullback_same_point (
120112 f:: F , backend:: AbstractADType , x, dy, extras:: PullbackExtras
@@ -138,33 +130,9 @@ function prepare_pullback_same_point(f!::F, y, backend::AbstractADType, x, dy) w
138130 return prepare_pullback_same_point (f!, y, backend, x, dy, extras)
139131end
140132
141- # ## Batched
142-
143- function prepare_pullback_batched (f:: F , backend:: AbstractADType , x, dy:: Batch ) where {F}
144- return prepare_pullback (f, backend, x, first (dy. elements))
145- end
146-
147- function prepare_pullback_batched (f!:: F , y, backend:: AbstractADType , x, dy:: Batch ) where {F}
148- return prepare_pullback (f!, y, backend, x, first (dy. elements))
149- end
150-
151- # ## Batched, same point
152-
153- function prepare_pullback_batched_same_point (
154- f:: F , backend:: AbstractADType , x, dy:: Batch , extras:: PullbackExtras
155- ) where {F}
156- return prepare_pullback_same_point (f, backend, x, first (dy. elements), extras)
157- end
158-
159- function prepare_pullback_batched_same_point (
160- f!:: F , y, backend:: AbstractADType , x, dy:: Batch , extras:: PullbackExtras
161- ) where {F}
162- return prepare_pullback_same_point (f!, y, backend, x, first (dy. elements), extras)
163- end
164-
165133# # One argument
166134
167- # ## Standard
135+ # ## Without extras
168136
169137function value_and_pullback (f:: F , backend:: AbstractADType , x, dy) where {F}
170138 return value_and_pullback (f, backend, x, dy, prepare_pullback (f, backend, x, dy))
@@ -182,6 +150,8 @@ function pullback!(f::F, dx, backend::AbstractADType, x, dy) where {F}
182150 return pullback! (f, dx, backend, x, dy, prepare_pullback (f, backend, x, dy))
183151end
184152
153+ # ## With extras
154+
185155function value_and_pullback (
186156 f:: F , backend, x, dy, extras:: PushforwardPullbackExtras
187157) where {F}
@@ -220,29 +190,9 @@ function pullback!(
220190 return value_and_pullback! (f, dx, backend, x, dy, extras)[2 ]
221191end
222192
223- # ## Batched
224-
225- function pullback_batched (
226- f:: F , backend:: AbstractADType , x, dy:: Batch{B} , extras:: PullbackExtras
227- ) where {F,B}
228- dx_elements = ntuple (Val (B)) do b
229- pullback (f, backend, x, dy. elements[b], extras)
230- end
231- return Batch (dx_elements)
232- end
233-
234- function pullback_batched! (
235- f:: F , dx:: Batch , backend:: AbstractADType , x, dy:: Batch , extras:: PullbackExtras
236- ) where {F}
237- for b in eachindex (dx. elements, dy. elements)
238- pullback! (f, dx. elements[b], backend, x, dy. elements[b], extras)
239- end
240- return dx
241- end
242-
243193# # Two arguments
244194
245- # ## Standard
195+ # ## Without extras
246196
247197function value_and_pullback (f!:: F , y, backend:: AbstractADType , x, dy) where {F}
248198 return value_and_pullback (
@@ -264,6 +214,8 @@ function pullback!(f!::F, y, dx, backend::AbstractADType, x, dy) where {F}
264214 return pullback! (f!, y, dx, backend, x, dy, prepare_pullback (f!, y, backend, x, dy))
265215end
266216
217+ # ## With extras
218+
267219function value_and_pullback (
268220 f!:: F , y, backend, x, dy, extras:: PushforwardPullbackExtras
269221) where {F}
@@ -297,23 +249,3 @@ function pullback!(
297249) where {F}
298250 return value_and_pullback! (f!, y, dx, backend, x, dy, extras)[2 ]
299251end
300-
301- # ## Batched
302-
303- function pullback_batched (
304- f!:: F , y, backend:: AbstractADType , x, dy:: Batch{B} , extras:: PullbackExtras
305- ) where {F,B}
306- dx_elements = ntuple (Val (B)) do b
307- pullback (f!, y, backend, x, dy. elements[b], extras)
308- end
309- return Batch (dx_elements)
310- end
311-
312- function pullback_batched! (
313- f!:: F , y, dx:: Batch , backend:: AbstractADType , x, dy:: Batch , extras:: PullbackExtras
314- ) where {F}
315- for b in eachindex (dx. elements, dy. elements)
316- pullback! (f!, y, dx. elements[b], backend, x, dy. elements[b], extras)
317- end
318- return dx
319- end
0 commit comments