Skip to content

Commit 5460c32

Browse files
authored
Add first sparsity support (#124)
1 parent 50e2a52 commit 5460c32

23 files changed

Lines changed: 334 additions & 101 deletions

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,12 @@ JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
8181
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
8282
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
8383
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
84+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
8485
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
8586
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
8687
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8788
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
8889
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8990

9091
[targets]
91-
test = ["ADTypes", "Aqua", "Chairmarks", "DataFrames", "Diffractor", "Documenter", "Enzyme", "FastDifferentiation", "FiniteDiff", "FiniteDifferences", "ForwardDiff", "JET", "JuliaFormatter", "Pkg", "PolyesterForwardDiff", "ReverseDiff", "SparseDiffTools", "Symbolics", "Test", "Tracker", "Zygote"]
92+
test = ["ADTypes", "Aqua", "Chairmarks", "DataFrames", "Diffractor", "Documenter", "Enzyme", "FastDifferentiation", "FiniteDiff", "FiniteDifferences", "ForwardDiff", "JET", "JuliaFormatter", "Pkg", "PolyesterForwardDiff", "ReverseDiff", "SparseArrays", "SparseDiffTools", "Symbolics", "Test", "Tracker", "Zygote"]

ext/DifferentiationInterfaceSparseDiffToolsExt/DifferentiationInterfaceSparseDiffToolsExt.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,26 @@ module DifferentiationInterfaceSparseDiffToolsExt
22

33
using ADTypes
44
import DifferentiationInterface as DI
5-
using SparseDiffTools: JacPrototypeSparsityDetection, SymbolicsSparsityDetection
5+
using SparseDiffTools:
6+
AutoSparseEnzyme,
7+
JacPrototypeSparsityDetection,
8+
SymbolicsSparsityDetection,
9+
sparse_jacobian,
10+
sparse_jacobian!,
11+
sparse_jacobian_cache
612
using Symbolics: Symbolics
713

14+
# used with @eval to avoid Unions and thus ambiguities
15+
SPARSE_BACKENDS = [
16+
AutoSparseEnzyme,
17+
AutoSparseFiniteDiff,
18+
AutoSparseForwardDiff,
19+
AutoSparsePolyesterForwardDiff,
20+
AutoSparseReverseDiff,
21+
AutoSparseZygote,
22+
]
23+
24+
include("allocating.jl")
25+
include("mutating.jl")
26+
827
end
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
for AutoSparse in SPARSE_BACKENDS
2+
@eval begin
3+
4+
## Jacobian
5+
6+
function DI.prepare_jacobian(f, backend::$AutoSparse, x::AbstractArray)
7+
return sparse_jacobian_cache(
8+
backend, SymbolicsSparsityDetection(), f, x; fx=f(x)
9+
)
10+
end
11+
12+
function DI.value_and_jacobian!!(f, jac, backend::$AutoSparse, x, cache)
13+
sparse_jacobian!(jac, backend, cache, f, x)
14+
return f(x), jac
15+
end
16+
17+
function DI.jacobian!!(f, jac, backend::$AutoSparse, x, cache)
18+
sparse_jacobian!(jac, backend, cache, f, x)
19+
return jac
20+
end
21+
22+
function DI.value_and_jacobian(f, backend::$AutoSparse, x, cache)
23+
return f(x), sparse_jacobian(backend, cache, f, x)
24+
end
25+
26+
function DI.jacobian(f, backend::$AutoSparse, x, cache)
27+
return sparse_jacobian(backend, cache, f, x)
28+
end
29+
end
30+
end
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
for AutoSparse in SPARSE_BACKENDS
2+
@eval begin
3+
## Jacobian
4+
5+
function DI.prepare_jacobian(
6+
f!, backend::$AutoSparse, y::AbstractArray, x::AbstractArray
7+
)
8+
return sparse_jacobian_cache(
9+
backend, SymbolicsSparsityDetection(), f!, similar(y), x
10+
)
11+
end
12+
13+
function DI.value_and_jacobian!!(f!, y, jac, backend::$AutoSparse, x, cache)
14+
sparse_jacobian!(jac, backend, cache, f!, y, x)
15+
f!(y, x)
16+
return y, jac
17+
end
18+
end
19+
end

ext/DifferentiationInterfaceTapirExt/DifferentiationInterfaceTapirExt.jl

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,34 @@ using Tapir: CoDual, build_rrule, value_and_pullback!!, zero_codual
77

88
DI.supports_mutation(::AutoTapir) = DI.MutationNotSupported()
99

10+
function zero_sametype!!(x_target, x::Number)
11+
return zero(x)
12+
end
13+
14+
function zero_sametype!!(x_target, x::AbstractArray)
15+
x_sametype = convert(typeof(x), x_target)
16+
x_sametype .= zero(eltype(x))
17+
return x_sametype
18+
end
19+
1020
function DI.value_and_pullback(f, ::AutoTapir, x, dy, rrule)
1121
y = f(x)
1222
dy_righttype = convert(typeof(y), dy)
1323
_, (_, dx) = value_and_pullback!!(rrule, dy_righttype, f, x)
1424
return y, dx
1525
end
1626

17-
for op in [
18-
:pushforward,
19-
:pullback,
20-
:derivative,
21-
:gradient,
22-
:jacobian,
23-
:second_derivative,
24-
:hvp,
25-
:hessian,
26-
]
27+
function DI.value_and_pullback!!(f, dx, ::AutoTapir, x, dy, rrule)
28+
y = f(x)
29+
dy_righttype = convert(typeof(y), dy)
30+
dx_righttype = zero_sametype!!(dx, x)
31+
new_y, (_, new_dx) = value_and_pullback!!(
32+
rrule, dy_righttype, zero_codual(f), CoDual(x, dx_righttype)
33+
)
34+
return new_y, new_dx
35+
end
36+
37+
for op in [:pushforward, :pullback, :derivative, :gradient, :jacobian]
2738
prep_op = Symbol(:prepare_, op)
2839
@eval function DI.$prep_op(f, backend::AutoTapir, x)
2940
return build_rrule(f, x)

lib/DifferentiationInterfaceTest/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1212
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1313
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
15+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1516
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1617
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1718

lib/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import DifferentiationInterface as DI
2929
using JET: @test_call, @test_opt
3030
using JLArrays: jl
3131
using LinearAlgebra: Diagonal, dot
32+
using SparseArrays: SparseArrays, nnz, SparseMatrixCSC
3233
using StaticArrays: SVector, SMatrix
3334
using Test: @testset, @test
3435

@@ -47,6 +48,7 @@ include("utils/filter.jl")
4748
include("tests/correctness.jl")
4849
include("tests/type_stability.jl")
4950
include("tests/call_count.jl")
51+
include("tests/sparsity.jl")
5052
include("tests/benchmark.jl")
5153
include("tests/test.jl")
5254

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
function test_sparsity(ba::AbstractADType, ::typeof(jacobian), scen::Scenario{false};)
2+
(; f, x, y, ref) = new_scen = deepcopy(scen)
3+
extras = prepare_jacobian(f, ba, x)
4+
jac_true = if ref isa AbstractADType
5+
jacobian(f, ref, x)
6+
else
7+
ref.jacobian(x)
8+
end
9+
10+
_, jac1 = value_and_jacobian(f, ba, x, extras)
11+
_, jac2 = value_and_jacobian!!(f, mysimilar(jac_true), ba, x, extras)
12+
13+
jac3 = jacobian(f, ba, x, extras)
14+
jac4 = jacobian!!(f, mysimilar(jac_true), ba, x, extras)
15+
16+
@testset "Sparse type" begin
17+
@test jac1 isa SparseMatrixCSC
18+
@test jac2 isa SparseMatrixCSC
19+
@test jac3 isa SparseMatrixCSC
20+
@test jac4 isa SparseMatrixCSC
21+
end
22+
@testset "Sparsity pattern" begin
23+
@test nnz(jac1) < length(jac_true)
24+
@test nnz(jac2) < length(jac_true)
25+
@test nnz(jac3) < length(jac_true)
26+
@test nnz(jac4) < length(jac_true)
27+
end
28+
return nothing
29+
end
30+
31+
function test_sparsity(ba::AbstractADType, ::typeof(jacobian), scen::Scenario{true};)
32+
(; f, x, y, dy, ref) = new_scen = deepcopy(scen)
33+
f! = f
34+
extras = prepare_jacobian(f!, ba, y, x)
35+
jac_shape = Matrix{eltype(y)}(undef, length(y), length(x))
36+
jac_true = if ref isa AbstractADType
37+
last(value_and_jacobian!!(f!, mysimilar(y), mysimilar(jac_shape), ref, x))
38+
else
39+
ref.jacobian(x)
40+
end
41+
42+
y10 = mysimilar(y)
43+
_, jac1 = value_and_jacobian!!(f!, y10, mysimilar(jac_true), ba, x, extras)
44+
45+
@testset "Sparse type" begin
46+
@test jac1 isa SparseMatrixCSC
47+
end
48+
@testset "Sparsity pattern" begin
49+
@test nnz(jac1) < length(jac_true)
50+
end
51+
return nothing
52+
end

lib/DifferentiationInterfaceTest/src/tests/test.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Testing:
1515
- `correctness=true`: whether to compare the differentiation results with the theoretical values specified in each scenario. If a backend object like `correctness=AutoForwardDiff()` is passed instead of a boolean, the results will be compared using that reference backend as the ground truth.
1616
- `call_count=false`: whether to check that the function is called the right number of times
1717
- `type_stability=false`: whether to check type stability with JET.jl (thanks to `@test_opt`)
18+
- `sparsity`: whether to check sparsity of the jacobian / hessian
1819
- `detailed=false`: whether to print a detailed or condensed test log
1920
2021
Filtering:
@@ -42,6 +43,7 @@ function test_differentiation(
4243
correctness::Union{Bool,AbstractADType}=true,
4344
type_stability::Bool=false,
4445
call_count::Bool=false,
46+
sparsity::Bool=false,
4547
detailed=false,
4648
# filtering
4749
input_type::Type=Any,
@@ -68,7 +70,8 @@ function test_differentiation(
6870
"Differentiation tests -" *
6971
(correctness != false ? " correctness" : "") *
7072
(call_count ? " calls" : "") *
71-
(type_stability ? " types" : "")
73+
(type_stability ? " types" : "") *
74+
(sparsity ? " sparsity" : "")
7275

7376
@testset verbose = true "$title" begin
7477
@testset verbose = detailed "$(backend_string(backend))" for backend in backends
@@ -77,7 +80,7 @@ function test_differentiation(
7780
compatible(backend, op, scen)
7881
end
7982
logging &&
80-
@info "Testing: $(backend_string(backend)) - $op - $(string(scen))"
83+
@info "$title: $(backend_string(backend)) - $op - $(string(scen))"
8184
correctness != false && @testset "Correctness" begin
8285
test_correctness(backend, op, scen; isapprox, atol, rtol)
8386
end
@@ -87,6 +90,9 @@ function test_differentiation(
8790
type_stability && @testset "Type stability" begin
8891
test_jet(backend, op, scen)
8992
end
93+
sparsity && @testset "Sparsity" begin
94+
test_sparsity(backend, op, scen)
95+
end
9096
end
9197
end
9298
end

src/DifferentiationInterface.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ struct AutoTapir <: AbstractReverseMode end
5353
include("second_order.jl")
5454
include("traits.jl")
5555
include("utils.jl")
56-
include("prepare.jl")
5756

5857
include("pushforward.jl")
5958
include("pullback.jl")

0 commit comments

Comments
 (0)