@@ -84,13 +84,15 @@ struct HVPGradientHessianPrep{
8484 BS<: BatchSizeSettings ,
8585 S<: AbstractVector{<:NTuple} ,
8686 R<: AbstractVector{<:NTuple} ,
87+ SE<: NTuple ,
8788 E2<: HVPPrep ,
8889 E1<: GradientPrep ,
8990} <: HessianPrep{SIG}
9091 _sig:: Val{SIG}
9192 batch_size_settings:: BS
9293 batched_seeds:: S
9394 batched_results:: R
95+ seed_example:: SE
9496 hvp_prep:: E2
9597 gradient_prep:: E1
9698end
@@ -119,10 +121,17 @@ function _prepare_hessian_aux(
119121 ntuple (b -> seeds[1 + ((a - 1 ) * B + (b - 1 )) % N], Val (B)) for a in 1 : A
120122 ]
121123 batched_results = [ntuple (b -> similar (x), Val (B)) for _ in batched_seeds]
122- hvp_prep = prepare_hvp_nokwarg (strict, f, backend, x, batched_seeds[1 ], contexts... )
124+ seed_example = ntuple (b -> basis (x), Val (B))
125+ hvp_prep = prepare_hvp_nokwarg (strict, f, backend, x, seed_example, contexts... )
123126 gradient_prep = prepare_gradient_nokwarg (strict, f, inner (backend), x, contexts... )
124127 return HVPGradientHessianPrep (
125- _sig, batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep
128+ _sig,
129+ batch_size_settings,
130+ batched_seeds,
131+ batched_results,
132+ seed_example,
133+ hvp_prep,
134+ gradient_prep,
126135 )
127136end
128137
@@ -154,7 +163,7 @@ function hessian(
154163 (; A, B_last) = batch_size_settings
155164
156165 hvp_prep_same = prepare_hvp_same_point (
157- f, hvp_prep, backend, x, batched_seeds[ 1 ] , contexts...
166+ f, hvp_prep, backend, x, seed_example , contexts...
158167 )
159168
160169 hess = mapreduce (hcat, eachindex (batched_seeds)) do a
@@ -182,7 +191,7 @@ function hessian!(
182191 (; N) = batch_size_settings
183192
184193 hvp_prep_same = prepare_hvp_same_point (
185- f, hvp_prep, backend, x, batched_seeds[ 1 ] , contexts...
194+ f, hvp_prep, backend, x, seed_example , contexts...
186195 )
187196
188197 for a in eachindex (batched_seeds, batched_results)
0 commit comments