diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index 495b31b10..f5d674cc4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -4,6 +4,7 @@ struct SMCSparseHessianPrep{ P <: AbstractMatrix, C <: AbstractColoringResult{:symmetric, :column}, M <: AbstractMatrix{<:Number}, + Sp <: NTuple, S <: AbstractVector{<:NTuple}, R <: AbstractVector{<:NTuple}, E2 <: DI.HVPPrep, @@ -14,6 +15,7 @@ struct SMCSparseHessianPrep{ sparsity::P coloring_result::C compressed_matrix::M + batched_seed_prep::Sp batched_seeds::S batched_results::R hvp_prep::E2 @@ -54,14 +56,20 @@ function _prepare_sparse_hessian_aux( (; N, A) = batch_size_settings dense_backend = dense_ad(backend) groups = column_groups(coloring_result) + seed_prep = DI.multibasis(x, eachindex(x)) seeds = [DI.multibasis(x, eachindex(x)[group]) for group in groups] - compressed_matrix = stack(_ -> vec(similar(x)), groups; dims = 2) + compressed_matrix = if isempty(groups) + similar(x, length(x), 0) + else + stack(_ -> vec(similar(x)), groups; dims = 2) + end + batched_seed_prep = ntuple(b -> copy(seed_prep), Val(B)) batched_seeds = [ ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] hvp_prep = DI.prepare_hvp_nokwarg( - strict, f, dense_backend, x, batched_seeds[1], contexts... + strict, f, dense_backend, x, batched_seed_prep, contexts... ) gradient_prep = DI.prepare_gradient_nokwarg( strict, f, DI.inner(dense_backend), x, contexts... @@ -72,6 +80,7 @@ function _prepare_sparse_hessian_aux( sparsity, coloring_result, compressed_matrix, + batched_seed_prep, batched_seeds, batched_results, hvp_prep, @@ -92,6 +101,7 @@ function DI.hessian!( batch_size_settings, coloring_result, compressed_matrix, + batched_seed_prep, batched_seeds, batched_results, hvp_prep, @@ -100,7 +110,7 @@ function DI.hessian!( dense_backend = dense_ad(backend) hvp_prep_same = DI.prepare_hvp_same_point( - f, hvp_prep, dense_backend, x, batched_seeds[1], contexts... + f, hvp_prep, dense_backend, x, batched_seed_prep, contexts... ) for a in eachindex(batched_seeds, batched_results) diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index 0d8d0f703..1bd40bfa6 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -136,7 +136,16 @@ end @test only(column_groups(hess_prep)) == 1:10 end - @testset "Empty colors for mixed mode" begin # issue 857 + @testset "Empty color groups in sparse AD" begin # issue 857 + # forward + backend = MyAutoSparse(adaptive_backends[1]) + @test jacobian(zero, backend, ones(10)) isa AbstractMatrix + @test hessian(sum ∘ zero, backend, ones(10)) isa AbstractMatrix + # reverse + backend = MyAutoSparse(adaptive_backends[2]) + @test jacobian(zero, backend, ones(10)) isa AbstractMatrix + @test hessian(sum ∘ zero, backend, ones(10)) isa AbstractMatrix + # mixed backend = MyAutoSparse(MixedMode(adaptive_backends[1], adaptive_backends[2])) @test jacobian(copyto!, zeros(10), backend, ones(10)) isa AbstractMatrix end