diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index c574a4a90..cd48a328b 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.44" +version = "0.6.45" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -22,6 +22,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" @@ -41,6 +42,7 @@ DifferentiationInterfaceMooncakeExt = "Mooncake" DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff" DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] DifferentiationInterfaceSparseArraysExt = "SparseArrays" +DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer" DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings" DifferentiationInterfaceStaticArraysExt = "StaticArrays" DifferentiationInterfaceSymbolicsExt = "Symbolics" @@ -66,7 +68,7 @@ Mooncake = "0.4.88" PolyesterForwardDiff = "0.1.2" ReverseDiff = "1.15.1" SparseArrays = "<0.0.1,1" -SparseConnectivityTracer = "0.5.0,0.6" +SparseConnectivityTracer = "0.6.14" SparseMatrixColorings = "0.4.9" StaticArrays = "1.9.7" Symbolics = "5.27.1, 6" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl new file mode 100644 index 000000000..80508c511 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl @@ -0,0 +1,66 @@ +module DifferentiationInterfaceSparseConnectivityTracerExt + +using ADTypes: jacobian_sparsity, hessian_sparsity +import DifferentiationInterface as DI +using SparseConnectivityTracer: + TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer + +@inline _jacobian_translate(detector, c::DI.Constant) = DI.unwrap(c) +@inline function _jacobian_translate(detector, c::DI.Cache{<:AbstractArray}) + return jacobian_buffer(DI.unwrap(c), detector) +end + +function jacobian_translate(detector, contexts::Vararg{DI.Context,C}) where {C} + new_contexts = map(contexts) do c + _jacobian_translate(detector, c) + end + return new_contexts +end + +@inline _hessian_translate(detector, c::DI.Constant) = DI.unwrap(c) +@inline function _hessian_translate(detector, c::DI.Cache{<:AbstractArray}) + return hessian_buffer(DI.unwrap(c), detector) +end + +function hessian_translate(detector, contexts::Vararg{DI.Context,C}) where {C} + new_contexts = map(contexts) do c + _hessian_translate(detector, c) + end + return new_contexts +end + +function DI.jacobian_sparsity_with_contexts( + f::F, + detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector}, + x, + contexts::Vararg{DI.Context,C}, +) where {F,C} + contexts_tracer = jacobian_translate(detector, contexts...) + fc = DI.FixTail(f, contexts_tracer) + return jacobian_sparsity(fc, x, detector) +end + +function DI.jacobian_sparsity_with_contexts( + f!::F, + y, + detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector}, + x, + contexts::Vararg{DI.Context,C}, +) where {F,C} + contexts_tracer = jacobian_translate(detector, contexts...) + fc! = DI.FixTail(f!, contexts_tracer) + return jacobian_sparsity(fc!, y, x, detector) +end + +function DI.hessian_sparsity_with_contexts( + f::F, + detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector}, + x, + contexts::Vararg{DI.Context,C}, +) where {F,C} + contexts_tracer = hessian_translate(detector, contexts...) + fc = DI.FixTail(f, contexts_tracer) + return hessian_sparsity(fc, x, detector) +end + +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl index 877ed05a2..ea8ef94b9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl @@ -1,13 +1,6 @@ module DifferentiationInterfaceSparseMatrixColoringsExt -using ADTypes: - ADTypes, - AutoSparse, - coloring_algorithm, - dense_ad, - sparsity_detector, - jacobian_sparsity, - hessian_sparsity +using ADTypes: ADTypes, AutoSparse, coloring_algorithm, dense_ad, sparsity_detector import DifferentiationInterface as DI using SparseMatrixColorings: AbstractColoringResult, @@ -22,14 +15,6 @@ using SparseMatrixColorings: decompress! import SparseMatrixColorings as SMC -function fycont(f, contexts::Vararg{DI.Context,C}) where {C} - return (DI.with_contexts(f, contexts...),) -end - -function fycont(f!, y, contexts::Vararg{DI.Context,C}) where {C} - return (DI.with_contexts(f!, contexts...), y) -end - abstract type SparseJacobianPrep <: DI.JacobianPrep end SMC.sparsity_pattern(prep::SparseJacobianPrep) = sparsity_pattern(prep.coloring_result) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index d9e82bd33..61131dc71 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -27,8 +27,8 @@ function DI.prepare_hessian( f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C} ) where {F,C} dense_backend = dense_ad(backend) - sparsity = hessian_sparsity( - DI.with_contexts(f, contexts...), x, sparsity_detector(backend) + sparsity = DI.hessian_sparsity_with_contexts( + f, sparsity_detector(backend), x, contexts... ) problem = ColoringProblem{:symmetric,:column}() coloring_result = coloring( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 17928d44c..4193fd7a0 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -58,8 +58,8 @@ function _prepare_sparse_jacobian_aux( contexts::Vararg{DI.Context,C}, ) where {FY,C} dense_backend = dense_ad(backend) - sparsity = jacobian_sparsity( - fycont(f_or_f!y..., contexts...)..., x, sparsity_detector(backend) + sparsity = DI.jacobian_sparsity_with_contexts( + f_or_f!y..., sparsity_detector(backend), x, contexts... ) if perf isa DI.PushforwardFast problem = ColoringProblem{:nonsymmetric,:column}() diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl index 4f1fc765b..4ab180aad 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl @@ -42,8 +42,8 @@ function _prepare_mixed_sparse_jacobian_aux( y, f_or_f!y::FY, backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C} ) where {FY,C} dense_backend = dense_ad(backend) - sparsity = jacobian_sparsity( - fycont(f_or_f!y..., contexts...)..., x, sparsity_detector(backend) + sparsity = DI.jacobian_sparsity_with_contexts( + f_or_f!y..., sparsity_detector(backend), x, contexts... ) problem = ColoringProblem{:nonsymmetric,:bidirectional}() coloring_result = coloring( diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 99ecfc2d4..ec6be730f 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -8,13 +8,16 @@ module DifferentiationInterface using ADTypes: ADTypes, AbstractADType, + AbstractSparsityDetector, AutoSparse, ForwardMode, ForwardOrReverseMode, ReverseMode, SymbolicMode, dense_ad, - mode + mode, + jacobian_sparsity, + hessian_sparsity using ADTypes: AutoChainRules, AutoDiffractor, @@ -45,6 +48,7 @@ include("utils/check.jl") include("utils/printing.jl") include("utils/context.jl") include("utils/linalg.jl") +include("utils/sparse.jl") include("first_order/pushforward.jl") include("first_order/pullback.jl") diff --git a/DifferentiationInterface/src/utils/sparse.jl b/DifferentiationInterface/src/utils/sparse.jl new file mode 100644 index 000000000..dc5749138 --- /dev/null +++ b/DifferentiationInterface/src/utils/sparse.jl @@ -0,0 +1,17 @@ +function jacobian_sparsity_with_contexts( + f::F, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C} +) where {F,C} + return jacobian_sparsity(with_contexts(f, contexts...), x, detector) +end + +function jacobian_sparsity_with_contexts( + f!::F, y, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C} +) where {F,C} + return jacobian_sparsity(with_contexts(f!, contexts...), y, x, detector) +end + +function hessian_sparsity_with_contexts( + f::F, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C} +) where {F,C} + return hessian_sparsity(with_contexts(f, contexts...), x, detector) +end diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index b72377b3b..e1c2974cd 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -57,7 +57,7 @@ end MyAutoSparse.( vcat(adaptive_backends, MixedMode(adaptive_backends[1], adaptive_backends[2])) ), - sparse_scenarios(; include_constantified=true); + sparse_scenarios(; include_constantified=true, include_cachified=true); sparsity=true, logging=LOGGING, ) diff --git a/DifferentiationInterfaceTest/src/scenarios/sparse.jl b/DifferentiationInterfaceTest/src/scenarios/sparse.jl index 125ca0049..5f1cad7fa 100644 --- a/DifferentiationInterfaceTest/src/scenarios/sparse.jl +++ b/DifferentiationInterfaceTest/src/scenarios/sparse.jl @@ -324,7 +324,9 @@ end Create a vector of [`Scenario`](@ref)s with sparse array types, focused on sparse Jacobians and Hessians. """ -function sparse_scenarios(; band_sizes=[5, 10, 20], include_constantified=false) +function sparse_scenarios(; + band_sizes=[5, 10, 20], include_constantified=false, include_cachified=false +) x_6 = float.(1:6) x_2_3 = float.(reshape(1:6, 2, 3)) x_50 = float.(range(1, 2, 50)) @@ -341,6 +343,10 @@ function sparse_scenarios(; band_sizes=[5, 10, 20], include_constantified=false) append!(scens, squarelinearmap_scenarios(x_50, band_sizes)) append!(scens, squarequadraticform_scenarios(x_50, band_sizes)) end - include_constantified && append!(scens, constantify(scens)) - return scens + + final_scens = Scenario[] + append!(final_scens, scens) + include_constantified && append!(final_scens, constantify(scens)) + include_cachified && append!(final_scens, cachify(scens)) + return final_scens end