@@ -153,12 +153,14 @@ struct PullbackJacobianPrep{
153153 S<: AbstractVector{<:NTuple} ,
154154 R<: AbstractVector{<:NTuple} ,
155155 E<: PullbackPrep ,
156+ Y,
156157} <: StandardJacobianPrep{SIG}
157158 _sig:: Val{SIG}
158159 batch_size_settings:: BS
159160 batched_seeds:: S
160161 batched_results:: R
161162 pullback_prep:: E
163+ y_example:: Y
162164end
163165
164166function prepare_jacobian_nokwarg (
@@ -212,7 +214,7 @@ function _prepare_jacobian_aux(
212214 ]
213215 batched_results = [ntuple (b -> similar (y), Val (B)) for _ in batched_seeds]
214216 pushforward_prep = prepare_pushforward_nokwarg (
215- strict, f_or_f!y... , backend, x, batched_seeds[ 1 ] , contexts...
217+ strict, f_or_f!y... , backend, x, ntuple (b -> zero (x), Val (B)) , contexts...
216218 )
217219 return PushforwardJacobianPrep (
218220 _sig, batch_size_settings, batched_seeds, batched_results, pushforward_prep
@@ -237,10 +239,10 @@ function _prepare_jacobian_aux(
237239 ]
238240 batched_results = [ntuple (b -> similar (x), Val (B)) for _ in batched_seeds]
239241 pullback_prep = prepare_pullback_nokwarg (
240- strict, f_or_f!y... , backend, x, batched_seeds[ 1 ] , contexts...
242+ strict, f_or_f!y... , backend, x, ntuple (b -> zero (y), Val (B)) , contexts...
241243 )
242244 return PullbackJacobianPrep (
243- _sig, batch_size_settings, batched_seeds, batched_results, pullback_prep
245+ _sig, batch_size_settings, batched_seeds, batched_results, pullback_prep, y
244246 )
245247end
246248
@@ -367,7 +369,7 @@ function _jacobian_aux(
367369 (; A, B_last) = batch_size_settings
368370
369371 pushforward_prep_same = prepare_pushforward_same_point (
370- f_or_f!y... , pushforward_prep, backend, x, batched_seeds[ 1 ] , contexts...
372+ f_or_f!y... , pushforward_prep, backend, x, ntuple (b -> zero (x), Val (B)) , contexts...
371373 )
372374
373375 jac = mapreduce (hcat, eachindex (batched_seeds)) do a
@@ -419,11 +421,16 @@ function _jacobian_aux(
419421 x,
420422 contexts:: Vararg{Context,C} ,
421423) where {FY,SIG,B,aligned,C}
422- (; batch_size_settings, batched_seeds, pullback_prep) = prep
424+ (; batch_size_settings, batched_seeds, pullback_prep, y_example ) = prep
423425 (; A, B_last) = batch_size_settings
424426
425427 pullback_prep_same = prepare_pullback_same_point (
426- f_or_f!y... , prep. pullback_prep, backend, x, batched_seeds[1 ], contexts...
428+ f_or_f!y... ,
429+ pullback_prep,
430+ backend,
431+ x,
432+ ntuple (b -> zero (y_example), Val (B)),
433+ contexts... ,
427434 )
428435
429436 jac = mapreduce (vcat, eachindex (batched_seeds)) do a
@@ -487,11 +494,16 @@ function _jacobian_aux!(
487494 x,
488495 contexts:: Vararg{Context,C} ,
489496) where {FY,SIG,B,C}
490- (; batch_size_settings, batched_seeds, batched_results, pullback_prep) = prep
497+ (; batch_size_settings, batched_seeds, batched_results, pullback_prep, y_example ) = prep
491498 (; N) = batch_size_settings
492499
493500 pullback_prep_same = prepare_pullback_same_point (
494- f_or_f!y... , pullback_prep, backend, x, batched_seeds[1 ], contexts...
501+ f_or_f!y... ,
502+ pullback_prep,
503+ backend,
504+ x,
505+ ntuple (b -> zero (y_example), Val (B)),
506+ contexts... ,
495507 )
496508
497509 for a in eachindex (batched_seeds, batched_results)
0 commit comments