@@ -138,12 +138,14 @@ struct PushforwardJacobianPrep{
138138 BS<: BatchSizeSettings ,
139139 S<: AbstractVector{<:NTuple} ,
140140 R<: AbstractVector{<:NTuple} ,
141+ SE<: NTuple ,
141142 E<: PushforwardPrep ,
142143} <: StandardJacobianPrep{SIG}
143144 _sig:: Val{SIG}
144145 batch_size_settings:: BS
145146 batched_seeds:: S
146147 batched_results:: R
148+ seed_example:: SE
147149 pushforward_prep:: E
148150end
149151
@@ -152,15 +154,15 @@ struct PullbackJacobianPrep{
152154 BS<: BatchSizeSettings ,
153155 S<: AbstractVector{<:NTuple} ,
154156 R<: AbstractVector{<:NTuple} ,
157+ SE<: NTuple ,
155158 E<: PullbackPrep ,
156- Y,
157159} <: StandardJacobianPrep{SIG}
158160 _sig:: Val{SIG}
159161 batch_size_settings:: BS
160162 batched_seeds:: S
161163 batched_results:: R
164+ seed_example:: SE
162165 pullback_prep:: E
163- y_example:: Y
164166end
165167
166168function prepare_jacobian_nokwarg (
@@ -213,11 +215,17 @@ function _prepare_jacobian_aux(
213215 ntuple (b -> seeds[1 + ((a - 1 ) * B + (b - 1 )) % N], Val (B)) for a in 1 : A
214216 ]
215217 batched_results = [ntuple (b -> similar (y), Val (B)) for _ in batched_seeds]
218+ seed_example = ntuple (b -> zero (x), Val (B))
216219 pushforward_prep = prepare_pushforward_nokwarg (
217- strict, f_or_f!y... , backend, x, ntuple (b -> zero (x), Val (B)) , contexts...
220+ strict, f_or_f!y... , backend, x, seed_example , contexts...
218221 )
219222 return PushforwardJacobianPrep (
220- _sig, batch_size_settings, batched_seeds, batched_results, pushforward_prep
223+ _sig,
224+ batch_size_settings,
225+ batched_seeds,
226+ batched_results,
227+ seed_example,
228+ pushforward_prep,
221229 )
222230end
223231
@@ -238,11 +246,17 @@ function _prepare_jacobian_aux(
238246 ntuple (b -> seeds[1 + ((a - 1 ) * B + (b - 1 )) % N], Val (B)) for a in 1 : A
239247 ]
240248 batched_results = [ntuple (b -> similar (x), Val (B)) for _ in batched_seeds]
249+ seed_example = ntuple (b -> zero (y), Val (B))
241250 pullback_prep = prepare_pullback_nokwarg (
242- strict, f_or_f!y... , backend, x, ntuple (b -> zero (y), Val (B)) , contexts...
251+ strict, f_or_f!y... , backend, x, seed_example , contexts...
243252 )
244253 return PullbackJacobianPrep (
245- _sig, batch_size_settings, batched_seeds, batched_results, pullback_prep, y
254+ _sig,
255+ batch_size_settings,
256+ batched_seeds,
257+ batched_results,
258+ seed_example,
259+ pullback_prep,
246260 )
247261end
248262
@@ -365,11 +379,11 @@ function _jacobian_aux(
365379 x,
366380 contexts:: Vararg{Context,C} ,
367381) where {FY,SIG,B,aligned,C}
368- (; batch_size_settings, batched_seeds, pushforward_prep) = prep
382+ (; batch_size_settings, batched_seeds, seed_example, pushforward_prep) = prep
369383 (; A, B_last) = batch_size_settings
370384
371385 pushforward_prep_same = prepare_pushforward_same_point (
372- f_or_f!y... , pushforward_prep, backend, x, ntuple (b -> zero (x), Val (B)) , contexts...
386+ f_or_f!y... , pushforward_prep, backend, x, seed_example , contexts...
373387 )
374388
375389 jac = mapreduce (hcat, eachindex (batched_seeds)) do a
@@ -421,16 +435,11 @@ function _jacobian_aux(
421435 x,
422436 contexts:: Vararg{Context,C} ,
423437) where {FY,SIG,B,aligned,C}
424- (; batch_size_settings, batched_seeds, pullback_prep, y_example ) = prep
438+ (; batch_size_settings, batched_seeds, seed_example, pullback_prep ) = prep
425439 (; A, B_last) = batch_size_settings
426440
427441 pullback_prep_same = prepare_pullback_same_point (
428- f_or_f!y... ,
429- pullback_prep,
430- backend,
431- x,
432- ntuple (b -> zero (y_example), Val (B)),
433- contexts... ,
442+ f_or_f!y... , pullback_prep, backend, x, seed_example, contexts...
434443 )
435444
436445 jac = mapreduce (vcat, eachindex (batched_seeds)) do a
@@ -458,11 +467,13 @@ function _jacobian_aux!(
458467 x,
459468 contexts:: Vararg{Context,C} ,
460469) where {FY,SIG,B,C}
461- (; batch_size_settings, batched_seeds, batched_results, pushforward_prep) = prep
470+ (;
471+ batch_size_settings, batched_seeds, batched_results, seed_example, pushforward_prep
472+ ) = prep
462473 (; N) = batch_size_settings
463474
464475 pushforward_prep_same = prepare_pushforward_same_point (
465- f_or_f!y... , pushforward_prep, backend, x, batched_seeds[ 1 ] , contexts...
476+ f_or_f!y... , pushforward_prep, backend, x, seed_example , contexts...
466477 )
467478
468479 for a in eachindex (batched_seeds, batched_results)
@@ -494,16 +505,12 @@ function _jacobian_aux!(
494505 x,
495506 contexts:: Vararg{Context,C} ,
496507) where {FY,SIG,B,C}
497- (; batch_size_settings, batched_seeds, batched_results, pullback_prep, y_example) = prep
508+ (; batch_size_settings, batched_seeds, batched_results, seed_example, pullback_prep) =
509+ prep
498510 (; N) = batch_size_settings
499511
500512 pullback_prep_same = prepare_pullback_same_point (
501- f_or_f!y... ,
502- pullback_prep,
503- backend,
504- x,
505- ntuple (b -> zero (y_example), Val (B)),
506- contexts... ,
513+ f_or_f!y... , pullback_prep, backend, x, seed_example, contexts...
507514 )
508515
509516 for a in eachindex (batched_seeds, batched_results)
0 commit comments