@@ -7,6 +7,8 @@ struct SMCMixedModeSparseJacobianPrep{
77 P<: AbstractMatrix ,
88 C<: AbstractColoringResult{:nonsymmetric,:bidirectional} ,
99 M<: AbstractMatrix{<:Number} ,
10+ Sfp<: NTuple ,
11+ Srp<: NTuple ,
1012 Sf<: Vector{<:NTuple} ,
1113 Sr<: Vector{<:NTuple} ,
1214 Rf<: Vector{<:NTuple} ,
@@ -21,6 +23,8 @@ struct SMCMixedModeSparseJacobianPrep{
2123 coloring_result:: C
2224 compressed_matrix_forward:: M
2325 compressed_matrix_reverse:: M
26+ batched_seed_forward_prep:: Sfp
27+ batched_seed_reverse_prep:: Srp
2428 batched_seeds_forward:: Sf
2529 batched_seeds_reverse:: Sr
2630 batched_results_forward:: Rf
@@ -111,12 +115,24 @@ function _prepare_mixed_sparse_jacobian_aux_aux(
111115 groups_forward = column_groups (coloring_result)
112116 groups_reverse = row_groups (coloring_result)
113117
118+ seed_forward_prep = DI. multibasis (x, eachindex (x))
119+ seed_reverse_prep = DI. multibasis (y, eachindex (y))
114120 seeds_forward = [DI. multibasis (x, eachindex (x)[group]) for group in groups_forward]
115121 seeds_reverse = [DI. multibasis (y, eachindex (y)[group]) for group in groups_reverse]
116122
117- compressed_matrix_forward = stack (_ -> vec (similar (y)), groups_forward; dims= 2 )
118- compressed_matrix_reverse = stack (_ -> vec (similar (x)), groups_reverse; dims= 1 )
123+ compressed_matrix_forward = if isempty (groups_forward)
124+ similar (vec (y), length (y), 0 )
125+ else
126+ stack (_ -> vec (similar (y)), groups_forward; dims= 2 )
127+ end
128+ compressed_matrix_reverse = if isempty (groups_reverse)
129+ similar (vec (x), 0 , length (x))
130+ else
131+ stack (_ -> vec (similar (x)), groups_reverse; dims= 1 )
132+ end
119133
134+ batched_seed_forward_prep = ntuple (b -> copy (seed_forward_prep), Val (Bf))
135+ batched_seed_reverse_prep = ntuple (b -> copy (seed_reverse_prep), Val (Br))
120136 batched_seeds_forward = [
121137 ntuple (b -> seeds_forward[1 + ((a - 1 ) * Bf + (b - 1 )) % Nf], Val (Bf)) for a in 1 : Af
122138 ]
@@ -136,15 +152,15 @@ function _prepare_mixed_sparse_jacobian_aux_aux(
136152 f_or_f!y... ,
137153 DI. forward_backend (dense_backend),
138154 x,
139- batched_seeds_forward[ 1 ] ,
155+ batched_seed_forward_prep ,
140156 contexts... ;
141157 )
142158 pullback_prep = DI. prepare_pullback_nokwarg (
143159 strict,
144160 f_or_f!y... ,
145161 DI. reverse_backend (dense_backend),
146162 x,
147- batched_seeds_reverse[ 1 ] ,
163+ batched_seed_reverse_prep ,
148164 contexts... ;
149165 )
150166
@@ -156,6 +172,8 @@ function _prepare_mixed_sparse_jacobian_aux_aux(
156172 coloring_result,
157173 compressed_matrix_forward,
158174 compressed_matrix_reverse,
175+ batched_seed_forward_prep,
176+ batched_seed_reverse_prep,
159177 batched_seeds_forward,
160178 batched_seeds_reverse,
161179 batched_results_forward,
@@ -183,6 +201,8 @@ function _sparse_jacobian_aux!(
183201 coloring_result,
184202 compressed_matrix_forward,
185203 compressed_matrix_reverse,
204+ batched_seed_forward_prep,
205+ batched_seed_reverse_prep,
186206 batched_seeds_forward,
187207 batched_seeds_reverse,
188208 batched_results_forward,
@@ -200,15 +220,15 @@ function _sparse_jacobian_aux!(
200220 pushforward_prep,
201221 DI. forward_backend (dense_backend),
202222 x,
203- batched_seeds_forward[ 1 ] ,
223+ batched_seed_forward_prep ,
204224 contexts... ,
205225 )
206226 pullback_prep_same = DI. prepare_pullback_same_point (
207227 f_or_f!y... ,
208228 pullback_prep,
209229 DI. reverse_backend (dense_backend),
210230 x,
211- batched_seeds_reverse[ 1 ] ,
231+ batched_seed_reverse_prep ,
212232 contexts... ,
213233 )
214234
0 commit comments