Commit 9fcfadd

Support sparse hessians (#128)
* Start sparsity for hessian
* Add sparse hessian
* Remove extras second_derivative
* Improve docs
* Remove Test from docs
1 parent 4cfd3d0 commit 9fcfadd

11 files changed: 163 additions & 103 deletions


Project.toml

Lines changed: 17 additions & 17 deletions
@@ -43,25 +43,25 @@ DifferentiationInterfaceZygoteExt = "Zygote"
 
 [compat]
 ADTypes = "0.2.7"
-AbstractDifferentiation = "0.6"
-ChainRulesCore = "1.19"
-Diffractor = "0.2"
-DocStringExtensions = "0.9"
-Enzyme = "0.11"
-FastDifferentiation = "0.3"
-FillArrays = "1"
-FiniteDiff = "2.22"
-FiniteDifferences = "0.12"
-ForwardDiff = "0.10"
+AbstractDifferentiation = "0.6.2"
+ChainRulesCore = "1.23.0"
+Diffractor = "0.2.6"
+DocStringExtensions = "0.9.3"
+Enzyme = "0.11.20"
+FastDifferentiation = "0.3.7"
+FillArrays = "1.9.3"
+FiniteDiff = "2.23.0"
+FiniteDifferences = "0.12.31"
+ForwardDiff = "0.10.36"
 LinearAlgebra = "1"
-PolyesterForwardDiff = "0.1"
-ReverseDiff = "1.15"
-SparseDiffTools = "2.17"
-Symbolics = "5.27"
-Tapir = "0.1"
+PolyesterForwardDiff = "0.1.1"
+ReverseDiff = "1.15.1"
+SparseDiffTools = "2.17.0"
+Symbolics = "5.27.1"
+Tapir = "0.1.2"
 Test = "1"
-Tracker = "0.2"
-Zygote = "0.6"
+Tracker = "0.2.33"
+Zygote = "0.6.69"
 julia = "1.10"
 
 [extras]

README.md

Lines changed: 13 additions & 12 deletions
@@ -26,18 +26,19 @@ This package provides a backend-agnostic syntax to differentiate functions of th
 
 We support most of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl):
 
-| backend                                                                          | object                                                        |
-| :------------------------------------------------------------------------------ | :----------------------------------------------------------- |
-| [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl)              | `AutoChainRules(ruleconfig)`                                  |
-| [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl)                      | `AutoDiffractor()`                                            |
-| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)                               | `AutoEnzyme(Enzyme.Forward)` or `AutoEnzyme(Enzyme.Reverse)`  |
-| [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl)                      | `AutoFiniteDiff()`                                            |
-| [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl)        | `AutoFiniteDifferences(fdm)`                                  |
-| [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl)                    | `AutoForwardDiff()`                                           |
-| [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) | `AutoPolyesterForwardDiff(; chunksize)`                       |
-| [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl)                    | `AutoReverseDiff()`                                           |
-| [Tracker.jl](https://github.com/FluxML/Tracker.jl)                               | `AutoTracker()`                                               |
-| [Zygote.jl](https://github.com/FluxML/Zygote.jl)                                 | `AutoZygote()`                                                |
+| backend                                                                          | object                                                      |
+| :------------------------------------------------------------------------------ | :---------------------------------------------------------- |
+| [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl)              | `AutoChainRules(ruleconfig)`                                |
+| [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl)                      | `AutoDiffractor()`                                          |
+| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)                               | `AutoEnzyme(Enzyme.Forward)`, `AutoEnzyme(Enzyme.Reverse)`  |
+| [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl)                      | `AutoFiniteDiff()`                                          |
+| [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl)        | `AutoFiniteDifferences(fdm)`                                |
+| [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl)                    | `AutoForwardDiff()`                                         |
+| [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) | `AutoPolyesterForwardDiff(; chunksize)`                     |
+| [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl)                    | `AutoReverseDiff()`                                         |
+| [SparseDiffTools.jl](https://github.com/JuliaDiff/SparseDiffTools.jl)            | `AutoSparseForwardDiff()`, `AutoSparseFiniteDiff()`         |
+| [Tracker.jl](https://github.com/FluxML/Tracker.jl)                               | `AutoTracker()`                                             |
+| [Zygote.jl](https://github.com/FluxML/Zygote.jl)                                 | `AutoZygote()`                                              |
 
 We also support additional (experimental) backends:
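As a taste of the table above, here is a hedged quick-start sketch using the newly listed sparse backend object. The function, input, and package loading are illustrative assumptions rather than part of this commit (per the docs changed below, sparsity detection also relies on Symbolics.jl):

```julia
# Minimal sketch, not from this commit: a sparse Jacobian via the new entry above.
using DifferentiationInterface
using SparseDiffTools, Symbolics  # assumed: sparse backends + sparsity detection

f(x) = abs2.(x)  # elementwise map, so the Jacobian is diagonal (hence sparse)
J = jacobian(f, AutoSparseForwardDiff(), rand(3))  # expected: sparse 3×3 matrix
```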

docs/src/backends.md

Lines changed: 19 additions & 3 deletions
@@ -63,9 +63,6 @@ AutoZygote
 
 ### Sparse
 
-!!! danger
-    Sparsity support is still experimental, use at your own risk.
-
 ```@docs
 AutoSparseFastDifferentiation
 AutoSparseFiniteDiff
@@ -103,3 +100,22 @@ rows = map(all_backends()) do backend # hide
 end # hide
 Markdown.parse(join(vcat(header, subheader, rows...), "\n")) # hide
 ```
+
+## Hessian support
+
+Only some backends are able to compute Hessians.
+You can use [`check_hessian`](@ref) to check that feature, like we did below:
+
+```@example backends
+header = "| backend | Hessian |" # hide
+subheader = "|---|---|" # hide
+rows = map(all_backends()) do backend # hide
+    "| `$(backend_string(backend))` | $(check_hessian(backend) ? '✅' : '❌') |" # hide
+end # hide
+Markdown.parse(join(vcat(header, subheader, rows...), "\n")) # hide
+```
+
+!!! warning
+    Second-order operators can also be used with a combination of backends inside the [`SecondOrder`](@ref) struct.
+    There are many possible combinations, a lot of which will fail.
+    Due to compilation overhead, we do not currently test them all to display the working ones in the documentation, but we might if users deem it relevant.
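Outside the generated docs table above, a minimal standalone sketch of the same check; the backend choice and exact return value are assumptions:

```julia
# Hypothetical REPL query: can this backend compute Hessians?
using DifferentiationInterface
using ForwardDiff  # assumed available

check_hessian(AutoForwardDiff())  # expected to return true for ForwardDiff
```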

docs/src/overview.md

Lines changed: 22 additions & 3 deletions
@@ -49,10 +49,14 @@ Several variants of each operator are defined:
 # mistakenly keep working with grad_in: NOT OK
 ```
 Note that we don't guarantee `grad_out` will have the same type as `grad_in`.
+Its type can even depend on the choice of backend.
 
 ## Second order
 
-Second-order differentiation is also supported, with the following operators:
+Second-order differentiation is also supported.
+You can either pick a single backend to do all the work, or combine an "outer" backend with an "inner" backend using the [`SecondOrder`](@ref) struct, like so: `SecondOrder(outer, inner)`.
+
+The available operators are similar to first-order ones:
 
 | operator | input `x` | output `y` | result type | result shape |
 | --------------------------- | --------------- | ------------ | ---------------- | ------------------------ |
@@ -97,9 +101,24 @@ By default, all the preparation functions return `nothing`.
 We do not make any guarantees on their implementation for each backend, or on the performance gains that can be expected.
 
 !!! warning
-    We haven't fully figured out what must happen when an `extras` object is prepared for a specific operator but then given to a lower-level one (i.e. prepare it for `jacobian` but then give it to `pushforward` inside `jacobian`).
+    We haven't yet figured out how to deal with extras for second-order operators, because closures make our life rather complicated.
+    For now, consider that preparation doesn't work there in general, although some individual backends may be okay already.
+
+## FAQ
 
-## Multiple inputs/outputs
+### Multiple inputs/outputs
 
 Restricting the API to one input and one output has many coding advantages, but it is not very flexible.
 If you need more than that, use [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl) to wrap several objects inside a single `ComponentVector`.
+
+### Sparsity
+
+If you need to work with sparse Jacobians, you can pick one of the [sparse backends](@ref Sparse) from [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
+The sparsity pattern is computed automatically with [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl) during the preparation step.
+
+If you need to work with sparse Hessians, you can use a sparse backend as the _outer_ backend of a `SecondOrder`.
+This means the Hessian is obtained as the sparse Jacobian of the gradient.
+Since preparation does not yet work for second order, the sparsity pattern is currently recomputed every time, so you may not gain much time as things stand.
+
+!!! danger
+    Sparsity support is still experimental, use at your own risk.
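A hedged sketch of the sparse-Hessian recipe described in this FAQ entry; the test function, backend choices, and input size are illustrative assumptions:

```julia
# Sparse Hessian = sparse Jacobian of the gradient, via a sparse *outer* backend.
using DifferentiationInterface
using SparseDiffTools, ForwardDiff, Symbolics  # assumed: backends + sparsity detection

f(x) = sum(abs2, diff(x))  # banded Hessian, so sparsity actually pays off

backend = SecondOrder(AutoSparseForwardDiff(), AutoForwardDiff())  # outer, inner
H = hessian(f, backend, rand(10))  # expected: a SparseMatrixCSC
```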

ext/DifferentiationInterfaceSparseDiffToolsExt/DifferentiationInterfaceSparseDiffToolsExt.jl

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ module DifferentiationInterfaceSparseDiffToolsExt
 
 using ADTypes
 import DifferentiationInterface as DI
-using DifferentiationInterface: JacobianExtras
+using DifferentiationInterface: JacobianExtras, NoHessianExtras, SecondOrder, inner, outer
 using SparseDiffTools:
     AutoSparseEnzyme,
     JacPrototypeSparsityDetection,

ext/DifferentiationInterfaceSparseDiffToolsExt/allocating.jl

Lines changed: 16 additions & 0 deletions
@@ -39,5 +39,21 @@ for AutoSparse in SPARSE_BACKENDS
     )
         return sparse_jacobian(backend, extras.cache, f, x)
     end
+
+    ## Hessian
+
+    DI.prepare_hessian(f, ::SecondOrder{<:$AutoSparse}, x) = NoHessianExtras()
+
+    function DI.hessian(f, backend::SecondOrder{<:$AutoSparse}, x, ::NoHessianExtras)
+        gradient_closure(z) = DI.gradient(f, inner(backend), z)
+        return DI.jacobian(gradient_closure, outer(backend), x)
+    end
+
+    function DI.hessian!!(
+        f, hess, backend::SecondOrder{<:$AutoSparse}, x, ::NoHessianExtras
+    )
+        gradient_closure(z) = DI.gradient(f, inner(backend), z)
+        return DI.jacobian!!(gradient_closure, hess, outer(backend), x)
+    end
 end
 end
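For readers of the extension code above, a hedged sketch of the composition it performs; note that `inner` and `outer` are internal helpers rather than public API, and the backend choices are illustrative:

```julia
# Hand-rolled equivalent of the DI.hessian method added above.
import DifferentiationInterface as DI
using SparseDiffTools, ForwardDiff, Symbolics  # assumed setup

f(x) = sum(abs2, x)
backend = DI.SecondOrder(AutoSparseForwardDiff(), AutoForwardDiff())
x = rand(3)

gradient_closure(z) = DI.gradient(f, DI.inner(backend), z)      # inner pass
H_manual = DI.jacobian(gradient_closure, DI.outer(backend), x)  # outer pass
# H_manual should match DI.hessian(f, backend, x) up to storage details
```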

lib/DifferentiationInterfaceTest/src/tests/sparsity.jl

Lines changed: 27 additions & 0 deletions
@@ -1,3 +1,5 @@
+## Jacobian
+
 function test_sparsity(ba::AbstractADType, scen::JacobianScenario{false}; ref_backend)
     (; f, x, y) = new_scen = deepcopy(scen)
     extras = prepare_jacobian(f, ba, x)
@@ -50,3 +52,28 @@ function test_sparsity(ba::AbstractADType, scen::JacobianScenario{true}; ref_bac
     end
     return nothing
 end
+
+## Hessian
+
+function test_sparsity(ba::AbstractADType, scen::HessianScenario{false}; ref_backend)
+    (; f, x, y) = new_scen = deepcopy(scen)
+    extras = prepare_hessian(f, ba, x)
+    hess_true = if ref_backend isa AbstractADType
+        hessian(f, ref_backend, x)
+    else
+        new_scen.ref(x)
+    end
+
+    hess1 = hessian(f, ba, x, extras)
+    hess2 = hessian!!(f, mysimilar(hess_true), ba, x, extras)
+
+    @testset "Sparse type" begin
+        @test hess1 isa SparseMatrixCSC
+        @test hess2 isa SparseMatrixCSC
+    end
+    @testset "Sparsity pattern" begin
+        @test nnz(hess1) < length(hess_true)
+        @test nnz(hess2) < length(hess_true)
+    end
+    return nothing
+end

src/hvp.jl

Lines changed: 16 additions & 48 deletions
@@ -47,46 +47,30 @@ end
 
 function hvp_aux(f, backend, x, v, extras, ::ForwardOverReverse)
     # JVP of the gradient
-    function gradient_closure(z)
-        inner_extras = prepare_gradient(extras, f, inner(backend), z)
-        return gradient(f, inner(backend), z, inner_extras)
-    end
-    outer_extras = prepare_pushforward(extras, gradient_closure, outer(backend), x)
-    p = pushforward(gradient_closure, outer(backend), x, v, outer_extras)
+    gradient_closure(z) = gradient(f, inner(backend), z)
+    p = pushforward(gradient_closure, outer(backend), x, v)
     return p
 end
 
 function hvp_aux(f, backend, x, v, extras, ::ReverseOverForward)
     # gradient of the JVP
-    function jvp_closure(z)
-        inner_extras = prepare_pushforward(extras, f, inner(backend), z)
-        return pushforward(f, inner(backend), z, v, inner_extras)
-    end
-    outer_extras = prepare_gradient(extras, jvp_closure, outer(backend), x)
-    p = gradient(jvp_closure, outer(backend), x, outer_extras)
+    pushforward_closure(z) = pushforward(f, inner(backend), z, v)
+    p = gradient(pushforward_closure, outer(backend), x)
     return p
 end
 
 function hvp_aux(f, backend, x, v, extras, ::ReverseOverReverse)
     # VJP of the gradient
-    function gradient_closure(z)
-        inner_extras = prepare_gradient(extras, f, inner(backend), z)
-        return gradient(f, inner(backend), z, inner_extras)
-    end
-    outer_extras = prepare_pullback(extras, gradient_closure, outer(backend), x)
-    p = pullback(gradient_closure, outer(backend), x, v, outer_extras)
+    gradient_closure(z) = gradient(f, inner(backend), z)
+    p = pullback(gradient_closure, outer(backend), x, v)
    return p
 end
 
 function hvp_aux(f, backend, x, v, extras, ::ForwardOverForward)
     # JVPs of JVPs in theory
     # also pushforward of gradient in practice
-    function gradient_closure(z)
-        inner_extras = prepare_gradient(extras, f, inner(backend), z)
-        return gradient(f, inner(backend), z, inner_extras)
-    end
-    outer_extras = prepare_pushforward(extras, gradient_closure, outer(backend), x)
-    p = pushforward(gradient_closure, outer(backend), x, v, outer_extras)
+    gradient_closure(z) = gradient(f, inner(backend), z)
+    p = pushforward(gradient_closure, outer(backend), x, v)
     return p
 end
 
@@ -108,41 +92,25 @@ function hvp!!(
 end
 
 function hvp_aux!!(f, p, backend, x, v, extras, ::ForwardOverReverse)
-    function gradient_closure(z)
-        inner_extras = prepare_gradient(extras, f, inner(backend), z)
-        return gradient(f, inner(backend), z, inner_extras)
-    end
-    outer_extras = prepare_pushforward(extras, gradient_closure, outer(backend), x)
-    p = pushforward!!(gradient_closure, p, outer(backend), x, v, outer_extras)
+    gradient_closure(z) = gradient(f, inner(backend), z)
+    p = pushforward!!(gradient_closure, p, outer(backend), x, v)
     return p
 end
 
 function hvp_aux!!(f, p, backend, x, v, extras, ::ReverseOverForward)
-    function jvp_closure(z)
-        inner_extras = prepare_pushforward(extras, f, inner(backend), z)
-        return pushforward(f, inner(backend), z, v, inner_extras)
-    end
-    outer_extras = prepare_gradient(extras, jvp_closure, outer(backend), x)
-    p = gradient!!(jvp_closure, p, outer(backend), x, outer_extras)
+    pushforward_closure(z) = pushforward(f, inner(backend), z, v)
+    p = gradient!!(pushforward_closure, p, outer(backend), x)
     return p
 end
 
 function hvp_aux!!(f, p, backend, x, v, extras, ::ReverseOverReverse)
-    function gradient_closure(z)
-        inner_extras = prepare_gradient(extras, f, inner(backend), z)
-        return gradient(f, inner(backend), z, inner_extras)
-    end
-    outer_extras = prepare_pullback(extras, gradient_closure, outer(backend), x)
-    p = pullback!!(gradient_closure, p, outer(backend), x, v, outer_extras)
+    gradient_closure(z) = gradient(f, inner(backend), z)
+    p = pullback!!(gradient_closure, p, outer(backend), x, v)
     return p
 end
 
 function hvp_aux!!(f, p, backend, x, v, extras, ::ForwardOverForward)
-    function gradient_closure(z)
-        inner_extras = prepare_gradient(extras, f, inner(backend), z)
-        return gradient(f, inner(backend), z, inner_extras)
-    end
-    outer_extras = prepare_pushforward(extras, gradient_closure, outer(backend), x)
-    p = pushforward!!(gradient_closure, p, outer(backend), x, v, outer_extras)
+    gradient_closure(z) = gradient(f, inner(backend), z)
+    p = pushforward!!(gradient_closure, p, outer(backend), x, v)
     return p
 end
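For intuition about the four modes simplified above, a hedged sketch of the ForwardOverReverse composition; the backend pairing and test function are assumptions:

```julia
# ForwardOverReverse by hand: JVP of a reverse-mode gradient closure.
using DifferentiationInterface
using ForwardDiff, Zygote  # assumed: forward outer, reverse inner

f(x) = sum(abs2, x) / 2  # Hessian is the identity, so any HVP should return v
x, v = rand(4), rand(4)

gradient_closure(z) = gradient(f, AutoZygote(), z)          # inner: reverse mode
p = pushforward(gradient_closure, AutoForwardDiff(), x, v)  # outer: forward mode
# p ≈ v, and should agree with hvp(f, SecondOrder(AutoForwardDiff(), AutoZygote()), x, v)
```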

src/second_derivative.jl

Lines changed: 4 additions & 12 deletions
@@ -44,12 +44,8 @@ function second_derivative(
     x,
     extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x),
 )
-    function derivative_closure(z)
-        inner_extras = prepare_derivative(extras, f, inner(backend), z)
-        return derivative(f, inner(backend), z, inner_extras)
-    end
-    outer_extras = prepare_derivative(extras, derivative_closure, outer(backend), x)
-    der2 = derivative(derivative_closure, outer(backend), x, outer_extras)
+    derivative_closure(z) = derivative(f, inner(backend), z)
+    der2 = derivative(derivative_closure, outer(backend), x)
     return der2
 end
 
@@ -75,12 +71,8 @@
     x,
     extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x),
 )
-    function derivative_closure(z)
-        inner_extras = prepare_derivative(extras, f, inner(backend), z)
-        return derivative(f, inner(backend), z, inner_extras)
-    end
-    outer_extras = prepare_derivative(extras, derivative_closure, outer(backend), x)
-    der2 = derivative!!(derivative_closure, der2, outer(backend), x, outer_extras)
+    derivative_closure(z) = derivative(f, inner(backend), z)
+    der2 = derivative!!(derivative_closure, der2, outer(backend), x)
     return der2
 end
 
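A hedged usage sketch for the simplified operator above; the backend and function are illustrative:

```julia
# Second derivative as a derivative of the inner derivative closure.
using DifferentiationInterface
using ForwardDiff  # assumed backend for both layers

f(x) = sin(x)
backend = SecondOrder(AutoForwardDiff())  # same backend inside and outside
second_derivative(f, backend, 1.0)        # expected ≈ -sin(1.0)
```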

src/second_order.jl

Lines changed: 3 additions & 3 deletions
@@ -11,11 +11,11 @@ Combination of two backends for second-order differentiation.
 
 $(TYPEDFIELDS)
 """
-struct SecondOrder{AD1<:AbstractADType,AD2<:AbstractADType} <: AbstractADType
+struct SecondOrder{ADO<:AbstractADType,ADI<:AbstractADType} <: AbstractADType
     "backend for the outer differentiation"
-    outer::AD1
+    outer::ADO
     "backend for the inner differentiation"
-    inner::AD2
+    inner::ADI
 end
 
 SecondOrder(backend::AbstractADType) = SecondOrder(backend, backend)
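Construction sketch for the renamed type parameters (`ADO` = outer, `ADI` = inner); the argument order is outer first, and the backend choices are illustrative:

```julia
# Two ways to build a SecondOrder backend, per the definitions above.
using DifferentiationInterface
using ForwardDiff, Zygote  # assumed available

mixed = SecondOrder(AutoForwardDiff(), AutoZygote())  # outer forward, inner reverse
same  = SecondOrder(AutoForwardDiff())                # one backend for both layers
```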
