Skip to content

Commit 1f2cc3a

Browse files
committed
fix: handle empty row or column colors in mixed mode sparse Jacobian
1 parent fca2ad3 commit 1f2cc3a

3 files changed

Lines changed: 32 additions & 7 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ struct SMCMixedModeSparseJacobianPrep{
77
P<:AbstractMatrix,
88
C<:AbstractColoringResult{:nonsymmetric,:bidirectional},
99
M<:AbstractMatrix{<:Number},
10+
Sfp<:NTuple,
11+
Srp<:NTuple,
1012
Sf<:Vector{<:NTuple},
1113
Sr<:Vector{<:NTuple},
1214
Rf<:Vector{<:NTuple},
@@ -21,6 +23,8 @@ struct SMCMixedModeSparseJacobianPrep{
2123
coloring_result::C
2224
compressed_matrix_forward::M
2325
compressed_matrix_reverse::M
26+
batched_seed_forward_prep::Sfp
27+
batched_seed_reverse_prep::Srp
2428
batched_seeds_forward::Sf
2529
batched_seeds_reverse::Sr
2630
batched_results_forward::Rf
@@ -111,12 +115,24 @@ function _prepare_mixed_sparse_jacobian_aux_aux(
111115
groups_forward = column_groups(coloring_result)
112116
groups_reverse = row_groups(coloring_result)
113117

118+
seed_forward_prep = DI.multibasis(x, eachindex(x))
119+
seed_reverse_prep = DI.multibasis(y, eachindex(y))
114120
seeds_forward = [DI.multibasis(x, eachindex(x)[group]) for group in groups_forward]
115121
seeds_reverse = [DI.multibasis(y, eachindex(y)[group]) for group in groups_reverse]
116122

117-
compressed_matrix_forward = stack(_ -> vec(similar(y)), groups_forward; dims=2)
118-
compressed_matrix_reverse = stack(_ -> vec(similar(x)), groups_reverse; dims=1)
123+
compressed_matrix_forward = if isempty(groups_forward)
124+
similar(vec(y), length(y), 0)
125+
else
126+
stack(_ -> vec(similar(y)), groups_forward; dims=2)
127+
end
128+
compressed_matrix_reverse = if isempty(groups_reverse)
129+
similar(vec(x), 0, length(x))
130+
else
131+
stack(_ -> vec(similar(x)), groups_reverse; dims=1)
132+
end
119133

134+
batched_seed_forward_prep = ntuple(b -> copy(seed_forward_prep), Val(Bf))
135+
batched_seed_reverse_prep = ntuple(b -> copy(seed_reverse_prep), Val(Br))
120136
batched_seeds_forward = [
121137
ntuple(b -> seeds_forward[1 + ((a - 1) * Bf + (b - 1)) % Nf], Val(Bf)) for a in 1:Af
122138
]
@@ -136,15 +152,15 @@ function _prepare_mixed_sparse_jacobian_aux_aux(
136152
f_or_f!y...,
137153
DI.forward_backend(dense_backend),
138154
x,
139-
batched_seeds_forward[1],
155+
batched_seed_forward_prep,
140156
contexts...;
141157
)
142158
pullback_prep = DI.prepare_pullback_nokwarg(
143159
strict,
144160
f_or_f!y...,
145161
DI.reverse_backend(dense_backend),
146162
x,
147-
batched_seeds_reverse[1],
163+
batched_seed_reverse_prep,
148164
contexts...;
149165
)
150166

@@ -156,6 +172,8 @@ function _prepare_mixed_sparse_jacobian_aux_aux(
156172
coloring_result,
157173
compressed_matrix_forward,
158174
compressed_matrix_reverse,
175+
batched_seed_forward_prep,
176+
batched_seed_reverse_prep,
159177
batched_seeds_forward,
160178
batched_seeds_reverse,
161179
batched_results_forward,
@@ -183,6 +201,8 @@ function _sparse_jacobian_aux!(
183201
coloring_result,
184202
compressed_matrix_forward,
185203
compressed_matrix_reverse,
204+
batched_seed_forward_prep,
205+
batched_seed_reverse_prep,
186206
batched_seeds_forward,
187207
batched_seeds_reverse,
188208
batched_results_forward,
@@ -200,15 +220,15 @@ function _sparse_jacobian_aux!(
200220
pushforward_prep,
201221
DI.forward_backend(dense_backend),
202222
x,
203-
batched_seeds_forward[1],
223+
batched_seed_forward_prep,
204224
contexts...,
205225
)
206226
pullback_prep_same = DI.prepare_pullback_same_point(
207227
f_or_f!y...,
208228
pullback_prep,
209229
DI.reverse_backend(dense_backend),
210230
x,
211-
batched_seeds_reverse[1],
231+
batched_seed_reverse_prep,
212232
contexts...,
213233
)
214234

DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ end
128128
@test only(row_groups(jac_rev_prep)) == 1:10
129129
@test only(column_groups(hess_prep)) == 1:10
130130
end
131+
132+
@testset "Empty colors for mixed mode" begin # issue 857
133+
backend = MyAutoSparse(MixedMode(adaptive_backends[1], adaptive_backends[2]))
134+
@test jacobian(copyto!, zeros(10), backend, ones(10)) isa AbstractMatrix
135+
end
131136
end
132137

133138
@testset "Misc" begin

DifferentiationInterface/test/testutils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function MyAutoSparse(backend::AbstractADType)
1717
return AutoSparse(
1818
backend;
1919
sparsity_detector=TracerSparsityDetector(),
20-
coloring_algorithm=GreedyColoringAlgorithm(),
20+
coloring_algorithm=GreedyColoringAlgorithm(; postprocessing=true),
2121
)
2222
end
2323

0 commit comments

Comments
 (0)