Skip to content

Commit 19b5be4

Browse files
authored
Add AutoSymbolics (#182)
* Add AutoSymbolics * Detoggle * Fix FiniteDiff Hessian
1 parent a22ccc7 commit 19b5be4

17 files changed

Lines changed: 560 additions & 94 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
3636
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
3737
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
3838
DifferentiationInterfaceSparseDiffToolsExt = ["SparseDiffTools", "Symbolics"]
39+
DifferentiationInterfaceSymbolicsExt = "Symbolics"
3940
DifferentiationInterfaceTapirExt = "Tapir"
4041
DifferentiationInterfaceTrackerExt = "Tracker"
4142
DifferentiationInterfaceZygoteExt = "Zygote"

DifferentiationInterface/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ We also provide some experimental backends ourselves:
4949
| Backend | Object |
5050
| :------------------------------------------------------------------------------- | :------------------------------------------------------------- |
5151
| [FastDifferentiation.jl](https://github.com/brianguenter/FastDifferentiation.jl) | `AutoFastDifferentiation()`, `AutoSparseFastDifferentiation()` |
52+
| [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl) | `AutoSymbolics()`, `AutoSparseSymbolics()` |
5253
| [Tapir.jl](https://github.com/withbayes/Tapir.jl) | `AutoTapir()` |
5354

5455
## Installation

DifferentiationInterface/docs/src/backends.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ function all_backends()
2020
AutoForwardDiff(),
2121
AutoPolyesterForwardDiff(; chunksize=1),
2222
AutoReverseDiff(),
23+
AutoSymbolics(),
2324
AutoTapir(),
2425
AutoTracker(),
2526
AutoZygote(),
@@ -56,6 +57,7 @@ AutoFiniteDifferences
5657
AutoPolyesterForwardDiff
5758
AutoPolyesterForwardDiff()
5859
AutoReverseDiff
60+
AutoSymbolics
5961
AutoTapir
6062
AutoTracker
6163
AutoZygote
@@ -73,6 +75,7 @@ AutoSparseForwardDiff
7375
AutoSparseForwardDiff()
7476
AutoSparsePolyesterForwardDiff
7577
AutoSparseReverseDiff
78+
AutoSparseSymbolics
7679
AutoSparseZygote
7780
```
7881

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ function DI.value_and_pushforward(
6060
)
6161
return f(x), DI.pushforward(f, backend, x, dx, extras)
6262
end
63+
6364
function DI.value_and_pushforward!(
6465
f,
6566
dy,
@@ -73,7 +74,7 @@ end
7374

7475
## Pullback
7576

76-
# TODO: this only fails for scalar -> matrix, not sure why
77+
# TODO: fix https://github.com/gdalle/DifferentiationInterface.jl/issues/131
7778

7879
## Derivative
7980

@@ -209,10 +210,10 @@ function DI.prepare_jacobian(f, backend::AnyAutoFastDifferentiation, x)
209210

210211
x_vec_var = vec(x_var)
211212
y_vec_var = vec(y_var)
212-
if issparse(backend)
213-
jac_var = sparse_jacobian(y_vec_var, x_vec_var)
213+
jac_var = if issparse(backend)
214+
sparse_jacobian(y_vec_var, x_vec_var)
214215
else
215-
jac_var = jacobian(y_vec_var, x_vec_var)
216+
jacobian(y_vec_var, x_vec_var)
216217
end
217218
jac_exe = make_function(jac_var, x_vec_var; in_place=false)
218219
jac_exe! = make_function(jac_var, x_vec_var; in_place=true)
@@ -341,10 +342,10 @@ end
341342
function DI.prepare_hessian(f, backend::AnyAutoFastDifferentiation, x)
342343
x_vec_var = make_variables(:x, size(x)...)
343344
y_vec_var = f(x_vec_var)
344-
if issparse(backend)
345-
hess_var = sparse_hessian(y_vec_var, vec(x_vec_var))
345+
hess_var = if issparse(backend)
346+
sparse_hessian(y_vec_var, vec(x_vec_var))
346347
else
347-
hess_var = hessian(y_vec_var, vec(x_vec_var))
348+
hessian(y_vec_var, vec(x_vec_var))
348349
end
349350
hess_exe = make_function(hess_var, vec(x_vec_var); in_place=false)
350351
hess_exe! = make_function(hess_var, vec(x_vec_var); in_place=true)

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,10 @@ function DI.prepare_jacobian(f!, y, backend::AnyAutoFastDifferentiation, x)
160160

161161
x_vec_var = vec(x_var)
162162
y_vec_var = vec(y_var)
163-
if issparse(backend)
164-
jac_var = sparse_jacobian(y_vec_var, x_vec_var)
163+
jac_var = if issparse(backend)
164+
sparse_jacobian(y_vec_var, x_vec_var)
165165
else
166-
jac_var = jacobian(y_vec_var, x_vec_var)
166+
jacobian(y_vec_var, x_vec_var)
167167
end
168168
jac_exe = make_function(jac_var, x_vec_var; in_place=false)
169169
jac_exe! = make_function(jac_var, x_vec_var; in_place=true)

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,12 @@ function DI.prepare_hessian(f, backend::AutoFiniteDiff, x)
175175
return FiniteDiffHessianExtras(cache)
176176
end
177177

178-
function DI.hessian(f, ::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras)
179-
return finite_difference_hessian(f, x, extras.cache)
178+
# cache cannot be reused because of https://github.com/JuliaDiff/FiniteDiff.jl/issues/185
179+
180+
function DI.hessian(f, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras)
181+
return finite_difference_hessian(f, x, HessianCache(x, fdhtype(backend)))
180182
end
181183

182-
function DI.hessian!(f, hess, ::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras)
183-
return finite_difference_hessian!(hess, f, x, extras.cache)
184+
function DI.hessian!(f, hess, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras)
185+
return finite_difference_hessian!(hess, f, x, HessianCache(x, fdhtype(backend)))
184186
end
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
module DifferentiationInterfaceSymbolicsExt
2+
3+
using ADTypes: ADTypes
4+
import DifferentiationInterface as DI
5+
using DifferentiationInterface: AutoSymbolics, AutoSparseSymbolics
6+
using DifferentiationInterface:
7+
DerivativeExtras,
8+
GradientExtras,
9+
HessianExtras,
10+
HVPExtras,
11+
JacobianExtras,
12+
PullbackExtras,
13+
PushforwardExtras,
14+
SecondDerivativeExtras
15+
using FillArrays: Fill
16+
using LinearAlgebra: dot
17+
using Symbolics:
18+
build_function,
19+
derivative,
20+
gradient,
21+
hessian,
22+
jacobian,
23+
sparsehessian,
24+
sparsejacobian,
25+
substitute,
26+
variable,
27+
variables
28+
using Symbolics.RuntimeGeneratedFunctions: RuntimeGeneratedFunction
29+
30+
const AnyAutoSymbolics = Union{AutoSymbolics,AutoSparseSymbolics}
31+
32+
DI.check_available(::AnyAutoSymbolics) = true
33+
DI.mode(::AnyAutoSymbolics) = ADTypes.AbstractSymbolicDifferentiationMode
34+
DI.pushforward_performance(::AnyAutoSymbolics) = DI.PushforwardFast()
35+
DI.pullback_performance(::AnyAutoSymbolics) = DI.PullbackSlow()
36+
37+
monovec(x::Number) = Fill(x, 1)
38+
39+
myvec(x::Number) = monovec(x)
40+
myvec(x::AbstractArray) = vec(x)
41+
42+
issparse(::AutoSymbolics) = false
43+
issparse(::AutoSparseSymbolics) = true
44+
45+
include("onearg.jl")
46+
include("twoarg.jl")
47+
48+
end
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
## Pushforward
2+
3+
struct SymbolicsOneArgPushforwardExtras{E1,E2} <: PushforwardExtras
4+
pf_exe::E1
5+
pf_exe!::E2
6+
end
7+
8+
function DI.prepare_pushforward(f, ::AnyAutoSymbolics, x, dx)
9+
x_var = if x isa Number
10+
variable(:x)
11+
else
12+
variables(:x, axes(x)...)
13+
end
14+
dx_var = if dx isa Number
15+
variable(:dx)
16+
else
17+
variables(:dx, axes(dx)...)
18+
end
19+
t_var = variable(:t)
20+
step_der_var = derivative(f(x_var + t_var * dx_var), t_var)
21+
pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x))))
22+
23+
res = build_function(pf_var, vcat(myvec(x_var), myvec(dx_var)); expression=Val(false))
24+
(pf_exe, pf_exe!) = if res isa Tuple
25+
res
26+
elseif res isa RuntimeGeneratedFunction
27+
res, nothing
28+
end
29+
return SymbolicsOneArgPushforwardExtras(pf_exe, pf_exe!)
30+
end
31+
32+
function DI.pushforward(
33+
f, ::AnyAutoSymbolics, x, dx, extras::SymbolicsOneArgPushforwardExtras
34+
)
35+
v_vec = vcat(myvec(x), myvec(dx))
36+
dy = extras.pf_exe(v_vec)
37+
return dy
38+
end
39+
40+
function DI.pushforward!(
41+
f, dy, ::AnyAutoSymbolics, x, dx, extras::SymbolicsOneArgPushforwardExtras
42+
)
43+
v_vec = vcat(myvec(x), myvec(dx))
44+
extras.pf_exe!(dy, v_vec)
45+
return dy
46+
end
47+
48+
function DI.value_and_pushforward(
49+
f, backend::AnyAutoSymbolics, x, dx, extras::SymbolicsOneArgPushforwardExtras
50+
)
51+
return f(x), DI.pushforward(f, backend, x, dx, extras)
52+
end
53+
54+
function DI.value_and_pushforward!(
55+
f, dy, backend::AnyAutoSymbolics, x, dx, extras::SymbolicsOneArgPushforwardExtras
56+
)
57+
return f(x), DI.pushforward!(f, dy, backend, x, dx, extras)
58+
end
59+
60+
## Derivative
61+
62+
struct SymbolicsOneArgDerivativeExtras{E1,E2} <: DerivativeExtras
63+
der_exe::E1
64+
der_exe!::E2
65+
end
66+
67+
function DI.prepare_derivative(f, ::AnyAutoSymbolics, x)
68+
x_var = variable(:x)
69+
der_var = derivative(f(x_var), x_var)
70+
71+
res = build_function(der_var, x_var; expression=Val(false))
72+
(der_exe, der_exe!) = if res isa Tuple
73+
res
74+
elseif res isa RuntimeGeneratedFunction
75+
res, nothing
76+
end
77+
return SymbolicsOneArgDerivativeExtras(der_exe, der_exe!)
78+
end
79+
80+
function DI.derivative(f, ::AnyAutoSymbolics, x, extras::SymbolicsOneArgDerivativeExtras)
81+
return extras.der_exe(x)
82+
end
83+
84+
function DI.derivative!(
85+
f, der, ::AnyAutoSymbolics, x, extras::SymbolicsOneArgDerivativeExtras
86+
)
87+
extras.der_exe!(der, x)
88+
return der
89+
end
90+
91+
function DI.value_and_derivative(
92+
f, backend::AnyAutoSymbolics, x, extras::SymbolicsOneArgDerivativeExtras
93+
)
94+
return f(x), DI.derivative(f, backend, x, extras)
95+
end
96+
97+
function DI.value_and_derivative!(
98+
f, der, backend::AnyAutoSymbolics, x, extras::SymbolicsOneArgDerivativeExtras
99+
)
100+
return f(x), DI.derivative!(f, der, backend, x, extras)
101+
end
102+
103+
## Gradient
104+
105+
struct SymbolicsOneArgGradientExtras{E1,E2} <: GradientExtras
106+
grad_exe::E1
107+
grad_exe!::E2
108+
end
109+
110+
function DI.prepare_gradient(f, ::AnyAutoSymbolics, x)
111+
x_var = variables(:x, axes(x)...)
112+
# Symbolic.gradient only accepts vectors
113+
grad_var = gradient(f(x_var), vec(x_var))
114+
115+
res = build_function(grad_var, vec(x_var); expression=Val(false))
116+
(grad_exe, grad_exe!) = if res isa Tuple
117+
res
118+
elseif res isa RuntimeGeneratedFunction
119+
res, nothing
120+
end
121+
return SymbolicsOneArgGradientExtras(grad_exe, grad_exe!)
122+
end
123+
124+
function DI.gradient(f, ::AnyAutoSymbolics, x, extras::SymbolicsOneArgGradientExtras)
125+
return reshape(extras.grad_exe(vec(x)), size(x))
126+
end
127+
128+
function DI.gradient!(f, grad, ::AnyAutoSymbolics, x, extras::SymbolicsOneArgGradientExtras)
129+
extras.grad_exe!(vec(grad), vec(x))
130+
return grad
131+
end
132+
133+
function DI.value_and_gradient(
134+
f, backend::AnyAutoSymbolics, x, extras::SymbolicsOneArgGradientExtras
135+
)
136+
return f(x), DI.gradient(f, backend, x, extras)
137+
end
138+
139+
function DI.value_and_gradient!(
140+
f, grad, backend::AnyAutoSymbolics, x, extras::SymbolicsOneArgGradientExtras
141+
)
142+
return f(x), DI.gradient!(f, grad, backend, x, extras)
143+
end
144+
145+
## Jacobian
146+
147+
struct SymbolicsOneArgJacobianExtras{E1,E2} <: JacobianExtras
148+
jac_exe::E1
149+
jac_exe!::E2
150+
end
151+
152+
function DI.prepare_jacobian(f, backend::AnyAutoSymbolics, x)
153+
x_var = variables(:x, axes(x)...)
154+
jac_var = if issparse(backend)
155+
sparsejacobian(f(x_var), x_var)
156+
else
157+
jacobian(f(x_var), x_var)
158+
end
159+
160+
res = build_function(jac_var, x_var; expression=Val(false))
161+
(jac_exe, jac_exe!) = if res isa Tuple
162+
res
163+
elseif res isa RuntimeGeneratedFunction
164+
res, nothing
165+
end
166+
return SymbolicsOneArgJacobianExtras(jac_exe, jac_exe!)
167+
end
168+
169+
function DI.jacobian(f, ::AnyAutoSymbolics, x, extras::SymbolicsOneArgJacobianExtras)
170+
return extras.jac_exe(x)
171+
end
172+
173+
function DI.jacobian!(f, jac, ::AnyAutoSymbolics, x, extras::SymbolicsOneArgJacobianExtras)
174+
extras.jac_exe!(jac, x)
175+
return jac
176+
end
177+
178+
function DI.value_and_jacobian(
179+
f, backend::AnyAutoSymbolics, x, extras::SymbolicsOneArgJacobianExtras
180+
)
181+
return f(x), DI.jacobian(f, backend, x, extras)
182+
end
183+
184+
function DI.value_and_jacobian!(
185+
f, jac, backend::AnyAutoSymbolics, x, extras::SymbolicsOneArgJacobianExtras
186+
)
187+
return f(x), DI.jacobian!(f, jac, backend, x, extras)
188+
end
189+
190+
## Hessian
191+
192+
struct SymbolicsOneArgHessianExtras{E1,E2} <: HessianExtras
193+
hess_exe::E1
194+
hess_exe!::E2
195+
end
196+
197+
function DI.prepare_hessian(f, backend::AnyAutoSymbolics, x)
198+
x_var = variables(:x, axes(x)...)
199+
# Symbolic.gradient only accepts vectors
200+
hess_var = if issparse(backend)
201+
sparsehessian(f(x_var), vec(x_var))
202+
else
203+
hessian(f(x_var), vec(x_var))
204+
end
205+
206+
res = build_function(hess_var, vec(x_var); expression=Val(false))
207+
(hess_exe, hess_exe!) = if res isa Tuple
208+
res
209+
elseif res isa RuntimeGeneratedFunction
210+
res, nothing
211+
end
212+
return SymbolicsOneArgHessianExtras(hess_exe, hess_exe!)
213+
end
214+
215+
function DI.hessian(f, ::AnyAutoSymbolics, x, extras::SymbolicsOneArgHessianExtras)
216+
return extras.hess_exe(vec(x))
217+
end
218+
219+
function DI.hessian!(f, hess, ::AnyAutoSymbolics, x, extras::SymbolicsOneArgHessianExtras)
220+
extras.hess_exe!(hess, vec(x))
221+
return hess
222+
end

0 commit comments

Comments
 (0)