@@ -56,13 +56,13 @@ abstract type JacobianExtras <: Extras end
5656struct NoJacobianExtras <: JacobianExtras end
5757
5858struct PushforwardJacobianExtras{B,D,E<: PushforwardExtras ,Y} <: JacobianExtras
59- seeds :: D
59+ batched_seeds :: Vector{Batch{B,D}}
6060 pushforward_batched_extras:: E
6161 y_example:: Y
6262end
6363
6464struct PullbackJacobianExtras{B,D,E<: PullbackExtras ,Y} <: JacobianExtras
65- seeds :: D
65+ batched_seeds :: Vector{Batch{B,D}}
6666 pullback_batched_extras:: E
6767 y_example:: Y
6868end
@@ -80,26 +80,38 @@ function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardFast) wh
8080 N = length (x)
8181 B = pick_batchsize (backend, N)
8282 seeds = [basis (backend, x, ind) for ind in CartesianIndices (x)]
83+ batched_seeds =
84+ Batch .([
85+ ntuple (b -> seeds[1 + ((a - 1 ) * B + (b - 1 )) % N], Val (B)) for
86+ a in 1 : div (N, B, RoundUp)
87+ ])
8388 pushforward_batched_extras = prepare_pushforward_batched (
84- f_or_f!y... , backend, x, Batch ( ntuple ( Returns (seeds [1 ]), Val (B)))
89+ f_or_f!y... , backend, x, batched_seeds [1 ]
8590 )
86- D = typeof (seeds)
91+ D = eltype (seeds)
8792 E = typeof (pushforward_batched_extras)
8893 Y = typeof (y)
89- return PushforwardJacobianExtras {B,D,E,Y} (seeds, pushforward_batched_extras, copy (y))
94+ return PushforwardJacobianExtras {B,D,E,Y} (
95+ batched_seeds, pushforward_batched_extras, copy (y)
96+ )
9097end
9198
9299function prepare_jacobian_aux (f_or_f!y:: FY , backend, x, y, :: PushforwardSlow ) where {FY}
93100 M = length (y)
94101 B = pick_batchsize (backend, M)
95102 seeds = [basis (backend, y, ind) for ind in CartesianIndices (y)]
103+ batched_seeds =
104+ Batch .([
105+ ntuple (b -> seeds[1 + ((a - 1 ) * B + (b - 1 )) % M], Val (B)) for
106+ a in 1 : div (M, B, RoundUp)
107+ ])
96108 pullback_batched_extras = prepare_pullback_batched (
97- f_or_f!y... , backend, x, Batch ( ntuple ( Returns (seeds [1 ]), Val (B)))
109+ f_or_f!y... , backend, x, batched_seeds [1 ]
98110 )
99- D = typeof (seeds)
111+ D = eltype (seeds)
100112 E = typeof (pullback_batched_extras)
101113 Y = typeof (y)
102- return PullbackJacobianExtras {B,D,E,Y} (seeds , pullback_batched_extras, copy (y))
114+ return PullbackJacobianExtras {B,D,E,Y} (batched_seeds , pullback_batched_extras, copy (y))
103115end
104116
105117# # One argument
@@ -197,27 +209,16 @@ end
197209function jacobian_aux (
198210 f_or_f!y:: FY , backend, x:: AbstractArray , extras:: PushforwardJacobianExtras{B}
199211) where {FY,B}
200- @compat (; seeds , pushforward_batched_extras, y_example) = extras
212+ @compat (; batched_seeds , pushforward_batched_extras, y_example) = extras
201213 N = length (x)
202214
203215 pushforward_batched_extras_same = prepare_pushforward_batched_same_point (
204- f_or_f!y... ,
205- backend,
206- x,
207- Batch (ntuple (Returns (seeds[1 ]), Val (B))),
208- pushforward_batched_extras,
216+ f_or_f!y... , backend, x, batched_seeds[1 ], pushforward_batched_extras
209217 )
210218
211- jac_blocks = map (1 : div (N, B, RoundUp)) do a
212- dx_batch_elements = ntuple (Val (B)) do b
213- seeds[1 + ((a - 1 ) * B + (b - 1 )) % N]
214- end
219+ jac_blocks = map (eachindex (batched_seeds)) do a
215220 dy_batch = pushforward_batched (
216- f_or_f!y... ,
217- backend,
218- x,
219- Batch (dx_batch_elements),
220- pushforward_batched_extras_same,
221+ f_or_f!y... , backend, x, batched_seeds[a], pushforward_batched_extras_same
221222 )
222223 stack (vec, dy_batch. elements; dims= 2 )
223224 end
@@ -232,27 +233,16 @@ end
232233function jacobian_aux (
233234 f_or_f!y:: FY , backend, x:: AbstractArray , extras:: PullbackJacobianExtras{B}
234235) where {FY,B}
235- @compat (; seeds , pullback_batched_extras, y_example) = extras
236+ @compat (; batched_seeds , pullback_batched_extras, y_example) = extras
236237 M = length (y_example)
237238
238239 pullback_batched_extras_same = prepare_pullback_batched_same_point (
239- f_or_f!y... ,
240- backend,
241- x,
242- Batch (ntuple (Returns (seeds[1 ]), Val (B))),
243- extras. pullback_batched_extras,
240+ f_or_f!y... , backend, x, batched_seeds[1 ], extras. pullback_batched_extras
244241 )
245242
246- jac_blocks = map (1 : div (M, B, RoundUp)) do a
247- dy_batch_elements = ntuple (Val (B)) do b
248- seeds[1 + ((a - 1 ) * B + (b - 1 )) % M]
249- end
243+ jac_blocks = map (eachindex (batched_seeds)) do a
250244 dx_batch = pullback_batched (
251- f_or_f!y... ,
252- backend,
253- x,
254- Batch (dy_batch_elements),
255- pullback_batched_extras_same,
245+ f_or_f!y... , backend, x, batched_seeds[a], pullback_batched_extras_same
256246 )
257247 stack (vec, dx_batch. elements; dims= 1 )
258248 end
@@ -271,21 +261,14 @@ function jacobian_aux!(
271261 x:: AbstractArray ,
272262 extras:: PushforwardJacobianExtras{B} ,
273263) where {FY,B}
274- @compat (; seeds , pushforward_batched_extras, y_example) = extras
264+ @compat (; batched_seeds , pushforward_batched_extras, y_example) = extras
275265 N = length (x)
276266
277267 pushforward_batched_extras_same = prepare_pushforward_batched_same_point (
278- f_or_f!y... ,
279- backend,
280- x,
281- Batch (ntuple (Returns (seeds[1 ]), Val (B))),
282- pushforward_batched_extras,
268+ f_or_f!y... , backend, x, batched_seeds[1 ], pushforward_batched_extras
283269 )
284270
285- for a in 1 : div (N, B, RoundUp)
286- dx_batch_elements = ntuple (Val (B)) do b
287- seeds[1 + ((a - 1 ) * B + (b - 1 )) % N]
288- end
271+ for a in eachindex (batched_seeds)
289272 dy_batch_elements = ntuple (Val (B)) do b
290273 reshape (view (jac, :, 1 + ((a - 1 ) * B + (b - 1 )) % N), size (y_example))
291274 end
@@ -294,7 +277,7 @@ function jacobian_aux!(
294277 Batch (dy_batch_elements),
295278 backend,
296279 x,
297- Batch (dx_batch_elements) ,
280+ batched_seeds[a] ,
298281 pushforward_batched_extras_same,
299282 )
300283 end
@@ -309,21 +292,14 @@ function jacobian_aux!(
309292 x:: AbstractArray ,
310293 extras:: PullbackJacobianExtras{B} ,
311294) where {FY,B}
312- @compat (; seeds , pullback_batched_extras, y_example) = extras
295+ @compat (; batched_seeds , pullback_batched_extras, y_example) = extras
313296 M = length (y_example)
314297
315298 pullback_batched_extras_same = prepare_pullback_batched_same_point (
316- f_or_f!y... ,
317- backend,
318- x,
319- Batch (ntuple (Returns (seeds[1 ]), Val (B))),
320- extras. pullback_batched_extras,
299+ f_or_f!y... , backend, x, batched_seeds[1 ], extras. pullback_batched_extras
321300 )
322301
323- for a in 1 : div (M, B, RoundUp)
324- dy_batch_elements = ntuple (Val (B)) do b
325- seeds[1 + ((a - 1 ) * B + (b - 1 )) % M]
326- end
302+ for a in eachindex (batched_seeds)
327303 dx_batch_elements = ntuple (Val (B)) do b
328304 reshape (view (jac, 1 + ((a - 1 ) * B + (b - 1 )) % M, :), size (x))
329305 end
@@ -332,7 +308,7 @@ function jacobian_aux!(
332308 Batch (dx_batch_elements),
333309 backend,
334310 x,
335- Batch (dy_batch_elements) ,
311+ batched_seeds[a] ,
336312 pullback_batched_extras_same,
337313 )
338314 end
0 commit comments