Skip to content

Commit 2ee1bf1

Browse files
authored
Preparation for outer backend of SecondOrder (#135)
1 parent 3e6164f commit 2ee1bf1

24 files changed

Lines changed: 320 additions & 220 deletions

File tree

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ We support most of the backends defined by [ADTypes.jl](https://github.com/SciML
4040
| [Tracker.jl](https://github.com/FluxML/Tracker.jl) | `AutoTracker()` |
4141
| [Zygote.jl](https://github.com/FluxML/Zygote.jl) | `AutoZygote()` |
4242

43-
We also support additional (experimental) backends:
43+
We also provide some experimental backends ourselves:
4444

45-
| backend | object |
46-
| :------------------------------------------------------------------------------- | :-------------------------- |
47-
| [FastDifferentiation.jl](https://github.com/brianguenter/FastDifferentiation.jl) | `AutoFastDifferentiation()` |
48-
| [Tapir.jl](https://github.com/withbayes/Tapir.jl) | `AutoTapir()` |
45+
| backend | object |
46+
| :------------------------------------------------------------------------------- | :------------------------------------------------------------- |
47+
| [FastDifferentiation.jl](https://github.com/brianguenter/FastDifferentiation.jl) | `AutoFastDifferentiation()`, `AutoSparseFastDifferentiation()` |
48+
| [Tapir.jl](https://github.com/withbayes/Tapir.jl) | `AutoTapir()` |
4949

5050
## Example
5151

docs/src/overview.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,7 @@ By default, all the preparation functions return `nothing`.
101101
We do not make any guarantees on their implementation for each backend, or on the performance gains that can be expected.
102102

103103
!!! warning
104-
We haven't yet figured out how to deal with extras for second-order operators, because closures make our life rather complicated.
105-
For now, consider that preparation doesn't work there in general, although some individual backends may be okay already.
104+
For `SecondOrder` backends, the inner differentiation cannot be prepared at the moment, only the outer one is.
106105

107106
## FAQ
108107

@@ -118,7 +117,6 @@ The sparsity pattern is computed automatically with [Symbolics.jl](https://githu
118117

119118
If you need to work with sparse Hessians, you can use a sparse backend as the _outer_ backend of a `SecondOrder`.
120119
This means the Hessian is obtained as the sparse Jacobian of the gradient.
121-
Since preparation does not yet work for second order, the sparsity pattern is currently recomputed every time, so you may not gain much time as things stand.
122120

123121
!!! danger
124122
Sparsity support is still experimental, use at your own risk.

ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@ DI.supports_mutation(::AutoChainRules) = DI.MutationNotSupported()
1515
DI.mode(::AutoForwardChainRules) = ADTypes.AbstractForwardMode
1616
DI.mode(::AutoReverseChainRules) = ADTypes.AbstractReverseMode
1717

18-
## Pushforward
18+
## Pushforward (unused)
1919

20+
#=
2021
DI.prepare_pushforward(f, ::AutoForwardChainRules, x) = NoPushforwardExtras()
2122
2223
function DI.value_and_pushforward(
@@ -26,6 +27,7 @@ function DI.value_and_pushforward(
2627
y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
2728
return y, new_dy
2829
end
30+
=#
2931

3032
## Pullback
3133

ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,15 @@ function DI.gradient!!(f, grad, ::AutoReverseEnzyme, x::AbstractArray, ::NoGradi
6868
gradient!(Reverse, grad_sametype, f, x)
6969
return grad_sametype
7070
end
71+
72+
function DI.value_and_gradient(
73+
f, backend::AutoReverseEnzyme, x::AbstractArray, ::NoGradientExtras
74+
)
75+
return DI.value_and_pullback(f, backend, x, one(eltype(x)), NoPullbackExtras())
76+
end
77+
78+
function DI.value_and_gradient!!(
79+
f, grad, backend::AutoReverseEnzyme, x::AbstractArray, ::NoGradientExtras
80+
)
81+
return DI.value_and_pullback!!(f, grad, backend, x, one(eltype(x)), NoPullbackExtras())
82+
end

ext/DifferentiationInterfaceFastDifferentiationExt/allocating.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,25 @@ function DI.value_and_derivative!!(
122122
return DI.value_and_derivative(f, backend, x, extras)
123123
end
124124

125+
function DI.derivative(
126+
f,
127+
backend::AnyAutoFastDifferentiation,
128+
x,
129+
extras::FastDifferentiationAllocatingDerivativeExtras,
130+
)
131+
return DI.value_and_derivative(f, backend, x, extras)[2]
132+
end
133+
134+
function DI.derivative!!(
135+
f,
136+
der,
137+
backend::AnyAutoFastDifferentiation,
138+
x,
139+
extras::FastDifferentiationAllocatingDerivativeExtras,
140+
)
141+
return DI.derivative(f, backend, x, extras)
142+
end
143+
125144
## Jacobian
126145

127146
struct FastDifferentiationAllocatingJacobianExtras{E} <: JacobianExtras
@@ -226,7 +245,7 @@ struct FastDifferentiationHVPExtras{E} <: HVPExtras
226245
hvp_exe::E
227246
end
228247

229-
function DI.prepare_hvp(f, ::AnyAutoFastDifferentiation, x)
248+
function DI.prepare_hvp(f, ::AnyAutoFastDifferentiation, x, v)
230249
x_var = if x isa Number
231250
only(make_variables(:x))
232251
else

ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@ module DifferentiationInterfaceForwardDiffExt
33
using ADTypes: AbstractADType, AutoForwardDiff, AutoSparseForwardDiff
44
import DifferentiationInterface as DI
55
using DifferentiationInterface:
6-
DerivativeExtras, GradientExtras, HessianExtras, JacobianExtras, NoPushforwardExtras
6+
DerivativeExtras,
7+
GradientExtras,
8+
HessianExtras,
9+
JacobianExtras,
10+
NoDerivativeExtras,
11+
NoPushforwardExtras
712
using ForwardDiff.DiffResults: DiffResults, DiffResult, GradientResult
813
using ForwardDiff:
914
Chunk,

ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ using DifferentiationInterface:
1111
GradientExtras,
1212
HessianExtras,
1313
JacobianExtras,
14+
NoDerivativeExtras,
1415
NoGradientExtras,
16+
NoHessianExtras,
1517
NoJacobianExtras,
1618
PushforwardExtras
1719
using DocStringExtensions

ext/DifferentiationInterfacePolyesterForwardDiffExt/allocating.jl

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,38 +59,44 @@ end
5959

6060
## Gradient
6161

62-
DI.prepare_gradient(f, ::AnyAutoPolyForwardDiff, x) = NoGradientExtras()
62+
function DI.prepare_gradient(f, backend::AnyAutoPolyForwardDiff, x)
63+
return DI.prepare_gradient(f, single_threaded(backend), x)
64+
end
6365

6466
function DI.value_and_gradient!!(
65-
f,
66-
grad::AbstractVector,
67-
::AnyAutoPolyForwardDiff{C},
68-
x::AbstractVector,
69-
::NoGradientExtras,
67+
f, grad, ::AnyAutoPolyForwardDiff{C}, x::AbstractVector, ::GradientExtras
7068
) where {C}
7169
threaded_gradient!(f, grad, x, Chunk{C}())
7270
return f(x), grad
7371
end
7472

7573
function DI.gradient!!(
76-
f,
77-
grad::AbstractVector,
78-
::AnyAutoPolyForwardDiff{C},
79-
x::AbstractVector,
80-
::NoGradientExtras,
74+
f, grad, ::AnyAutoPolyForwardDiff{C}, x::AbstractVector, ::GradientExtras
8175
) where {C}
8276
threaded_gradient!(f, grad, x, Chunk{C}())
8377
return grad
8478
end
8579

80+
function DI.value_and_gradient!!(
81+
f, grad, backend::AnyAutoPolyForwardDiff{C}, x::AbstractArray, extras::GradientExtras
82+
) where {C}
83+
return DI.value_and_gradient!!(f, grad, single_threaded(backend), x, extras)
84+
end
85+
86+
function DI.gradient!!(
87+
f, grad, backend::AnyAutoPolyForwardDiff{C}, x::AbstractArray, extras::GradientExtras
88+
) where {C}
89+
return DI.gradient!!(f, grad, single_threaded(backend), x, extras)
90+
end
91+
8692
function DI.value_and_gradient(
87-
f, backend::AnyAutoPolyForwardDiff, x::AbstractVector, extras::NoGradientExtras
93+
f, backend::AnyAutoPolyForwardDiff, x::AbstractArray, extras::GradientExtras
8894
)
8995
return DI.value_and_gradient!!(f, similar(x), backend, x, extras)
9096
end
9197

9298
function DI.gradient(
93-
f, backend::AnyAutoPolyForwardDiff, x::AbstractVector, extras::NoGradientExtras
99+
f, backend::AnyAutoPolyForwardDiff, x::AbstractArray, extras::GradientExtras
94100
)
95101
return DI.gradient!!(f, similar(x), backend, x, extras)
96102
end

ext/DifferentiationInterfaceSparseDiffToolsExt/DifferentiationInterfaceSparseDiffToolsExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ module DifferentiationInterfaceSparseDiffToolsExt
22

33
using ADTypes
44
import DifferentiationInterface as DI
5-
using DifferentiationInterface: JacobianExtras, NoHessianExtras, SecondOrder, inner, outer
5+
using DifferentiationInterface:
6+
HessianExtras, JacobianExtras, NoHessianExtras, SecondOrder, inner, outer
67
using SparseDiffTools:
78
AutoSparseEnzyme,
89
JacPrototypeSparsityDetection,

ext/DifferentiationInterfaceSparseDiffToolsExt/allocating.jl

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@ struct SparseDiffToolsAllocatingJacobianExtras{C} <: JacobianExtras
22
cache::C
33
end
44

5+
struct SparseDiffToolsHessianExtras{C,E} <: HessianExtras
6+
inner_gradient_closure::C
7+
outer_jacobian_extras::E
8+
end
9+
510
for AutoSparse in SPARSE_BACKENDS
611
@eval begin
712

@@ -42,18 +47,41 @@ for AutoSparse in SPARSE_BACKENDS
4247

4348
## Hessian
4449

45-
DI.prepare_hessian(f, ::SecondOrder{<:$AutoSparse}, x) = NoHessianExtras()
50+
function DI.prepare_hessian(f, backend::SecondOrder{<:$AutoSparse}, x)
51+
inner_gradient_closure(z) = DI.gradient(f, inner(backend), z)
52+
outer_jacobian_extras = DI.prepare_jacobian(
53+
inner_gradient_closure, outer(backend), x
54+
)
55+
return SparseDiffToolsHessianExtras(
56+
inner_gradient_closure, outer_jacobian_extras
57+
)
58+
end
4659

47-
function DI.hessian(f, backend::SecondOrder{<:$AutoSparse}, x, ::NoHessianExtras)
48-
gradient_closure(z) = DI.gradient(f, inner(backend), z)
49-
return DI.jacobian(gradient_closure, outer(backend), x)
60+
function DI.hessian(
61+
f, backend::SecondOrder{<:$AutoSparse}, x, extras::SparseDiffToolsHessianExtras
62+
)
63+
return DI.jacobian(
64+
extras.inner_gradient_closure,
65+
outer(backend),
66+
x,
67+
extras.outer_jacobian_extras,
68+
)
5069
end
5170

5271
function DI.hessian!!(
53-
f, hess, backend::SecondOrder{<:$AutoSparse}, x, ::NoHessianExtras
72+
f,
73+
hess,
74+
backend::SecondOrder{<:$AutoSparse},
75+
x,
76+
extras::SparseDiffToolsHessianExtras,
5477
)
55-
gradient_closure(z) = DI.gradient(f, inner(backend), z)
56-
return DI.jacobian!!(gradient_closure, hess, outer(backend), x)
78+
return DI.jacobian!!(
79+
extras.inner_gradient_closure,
80+
hess,
81+
outer(backend),
82+
x,
83+
extras.outer_jacobian_extras,
84+
)
5785
end
5886
end
5987
end

0 commit comments

Comments (0)