Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.44"
version = "0.6.45"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -22,6 +22,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Expand All @@ -41,6 +42,7 @@ DifferentiationInterfaceMooncakeExt = "Mooncake"
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer"
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
DifferentiationInterfaceStaticArraysExt = "StaticArrays"
DifferentiationInterfaceSymbolicsExt = "Symbolics"
Expand All @@ -66,7 +68,7 @@ Mooncake = "0.4.88"
PolyesterForwardDiff = "0.1.2"
ReverseDiff = "1.15.1"
SparseArrays = "<0.0.1,1"
SparseConnectivityTracer = "0.5.0,0.6"
SparseConnectivityTracer = "0.6.14"
SparseMatrixColorings = "0.4.9"
StaticArrays = "1.9.7"
Symbolics = "5.27.1, 6"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
module DifferentiationInterfaceSparseConnectivityTracerExt

using ADTypes: jacobian_sparsity, hessian_sparsity
import DifferentiationInterface as DI
using SparseConnectivityTracer:
TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer

@inline _jacobian_translate(detector, c::DI.Constant) = DI.unwrap(c)
@inline function _jacobian_translate(detector, c::DI.Cache{<:AbstractArray})
return jacobian_buffer(DI.unwrap(c), detector)

Check warning on line 10 in DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl#L8-L10

Added lines #L8 - L10 were not covered by tests
end

function jacobian_translate(detector, contexts::Vararg{DI.Context,C}) where {C}
new_contexts = map(contexts) do c
_jacobian_translate(detector, c)

Check warning on line 15 in DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl#L15

Added line #L15 was not covered by tests
end
return new_contexts
end

@inline _hessian_translate(detector, c::DI.Constant) = DI.unwrap(c)
@inline function _hessian_translate(detector, c::DI.Cache{<:AbstractArray})
return hessian_buffer(DI.unwrap(c), detector)

Check warning on line 22 in DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl#L20-L22

Added lines #L20 - L22 were not covered by tests
end

function hessian_translate(detector, contexts::Vararg{DI.Context,C}) where {C}
new_contexts = map(contexts) do c
_hessian_translate(detector, c)

Check warning on line 27 in DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl#L27

Added line #L27 was not covered by tests
end
return new_contexts
end

function DI.jacobian_sparsity_with_contexts(
f::F,
detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector},
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
contexts_tracer = jacobian_translate(detector, contexts...)
fc = DI.FixTail(f, contexts_tracer)
return jacobian_sparsity(fc, x, detector)
end

function DI.jacobian_sparsity_with_contexts(
f!::F,
y,
detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector},
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
contexts_tracer = jacobian_translate(detector, contexts...)
fc! = DI.FixTail(f!, contexts_tracer)
return jacobian_sparsity(fc!, y, x, detector)
end

function DI.hessian_sparsity_with_contexts(
f::F,
detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector},
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
contexts_tracer = hessian_translate(detector, contexts...)
fc = DI.FixTail(f, contexts_tracer)
return hessian_sparsity(fc, x, detector)
end

end
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
module DifferentiationInterfaceSparseMatrixColoringsExt

using ADTypes:
ADTypes,
AutoSparse,
coloring_algorithm,
dense_ad,
sparsity_detector,
jacobian_sparsity,
hessian_sparsity
using ADTypes: ADTypes, AutoSparse, coloring_algorithm, dense_ad, sparsity_detector
import DifferentiationInterface as DI
using SparseMatrixColorings:
AbstractColoringResult,
Expand All @@ -22,14 +15,6 @@ using SparseMatrixColorings:
decompress!
import SparseMatrixColorings as SMC

function fycont(f, contexts::Vararg{DI.Context,C}) where {C}
return (DI.with_contexts(f, contexts...),)
end

function fycont(f!, y, contexts::Vararg{DI.Context,C}) where {C}
return (DI.with_contexts(f!, contexts...), y)
end

abstract type SparseJacobianPrep <: DI.JacobianPrep end

SMC.sparsity_pattern(prep::SparseJacobianPrep) = sparsity_pattern(prep.coloring_result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ function DI.prepare_hessian(
f::F, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}
) where {F,C}
dense_backend = dense_ad(backend)
sparsity = hessian_sparsity(
DI.with_contexts(f, contexts...), x, sparsity_detector(backend)
sparsity = DI.hessian_sparsity_with_contexts(
f, sparsity_detector(backend), x, contexts...
)
problem = ColoringProblem{:symmetric,:column}()
coloring_result = coloring(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ function _prepare_sparse_jacobian_aux(
contexts::Vararg{DI.Context,C},
) where {FY,C}
dense_backend = dense_ad(backend)
sparsity = jacobian_sparsity(
fycont(f_or_f!y..., contexts...)..., x, sparsity_detector(backend)
sparsity = DI.jacobian_sparsity_with_contexts(
f_or_f!y..., sparsity_detector(backend), x, contexts...
)
if perf isa DI.PushforwardFast
problem = ColoringProblem{:nonsymmetric,:column}()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
y, f_or_f!y::FY, backend::AutoSparse{<:DI.MixedMode}, x, contexts::Vararg{DI.Context,C}
) where {FY,C}
dense_backend = dense_ad(backend)
sparsity = jacobian_sparsity(
fycont(f_or_f!y..., contexts...)..., x, sparsity_detector(backend)
sparsity = DI.jacobian_sparsity_with_contexts(

Check warning on line 45 in DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl#L45

Added line #L45 was not covered by tests
f_or_f!y..., sparsity_detector(backend), x, contexts...
)
problem = ColoringProblem{:nonsymmetric,:bidirectional}()
coloring_result = coloring(
Expand Down
6 changes: 5 additions & 1 deletion DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ module DifferentiationInterface
using ADTypes:
ADTypes,
AbstractADType,
AbstractSparsityDetector,
AutoSparse,
ForwardMode,
ForwardOrReverseMode,
ReverseMode,
SymbolicMode,
dense_ad,
mode
mode,
jacobian_sparsity,
hessian_sparsity
using ADTypes:
AutoChainRules,
AutoDiffractor,
Expand Down Expand Up @@ -45,6 +48,7 @@ include("utils/check.jl")
include("utils/printing.jl")
include("utils/context.jl")
include("utils/linalg.jl")
include("utils/sparse.jl")

include("first_order/pushforward.jl")
include("first_order/pullback.jl")
Expand Down
17 changes: 17 additions & 0 deletions DifferentiationInterface/src/utils/sparse.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
function jacobian_sparsity_with_contexts(
f::F, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C}
) where {F,C}
return jacobian_sparsity(with_contexts(f, contexts...), x, detector)
end

function jacobian_sparsity_with_contexts(
f!::F, y, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C}
) where {F,C}
return jacobian_sparsity(with_contexts(f!, contexts...), y, x, detector)
end

function hessian_sparsity_with_contexts(
f::F, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C}
) where {F,C}
return hessian_sparsity(with_contexts(f, contexts...), x, detector)
end
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ end
MyAutoSparse.(
vcat(adaptive_backends, MixedMode(adaptive_backends[1], adaptive_backends[2]))
),
sparse_scenarios(; include_constantified=true);
sparse_scenarios(; include_constantified=true, include_cachified=true);
sparsity=true,
logging=LOGGING,
)
Expand Down
12 changes: 9 additions & 3 deletions DifferentiationInterfaceTest/src/scenarios/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,9 @@ end

Create a vector of [`Scenario`](@ref)s with sparse array types, focused on sparse Jacobians and Hessians.
"""
function sparse_scenarios(; band_sizes=[5, 10, 20], include_constantified=false)
function sparse_scenarios(;
band_sizes=[5, 10, 20], include_constantified=false, include_cachified=false
)
x_6 = float.(1:6)
x_2_3 = float.(reshape(1:6, 2, 3))
x_50 = float.(range(1, 2, 50))
Expand All @@ -341,6 +343,10 @@ function sparse_scenarios(; band_sizes=[5, 10, 20], include_constantified=false)
append!(scens, squarelinearmap_scenarios(x_50, band_sizes))
append!(scens, squarequadraticform_scenarios(x_50, band_sizes))
end
include_constantified && append!(scens, constantify(scens))
return scens

final_scens = Scenario[]
append!(final_scens, scens)
include_constantified && append!(final_scens, constantify(scens))
include_cachified && append!(final_scens, cachify(scens))
return final_scens
end
Loading