Skip to content

Commit ac96b93

Browse files
authored
use cse in AutoSymbolics
closes #758
1 parent fca02b8 commit ac96b93

2 files changed

Lines changed: 10 additions & 10 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function DI.prepare_pushforward_nokwarg(
1818
step_der_var = derivative(f(x_var + t_var * dx_var, context_vars...), t_var)
1919
pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x))))
2020

21-
res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false))
21+
res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true)
2222
(pf_exe, pf_exe!) = if res isa Tuple
2323
res
2424
elseif res isa RuntimeGeneratedFunction
@@ -102,7 +102,7 @@ function DI.prepare_derivative_nokwarg(
102102
context_vars = variablize(contexts)
103103
der_var = derivative(f(x_var, context_vars...), x_var)
104104

105-
res = build_function(der_var, x_var, context_vars...; expression=Val(false))
105+
res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true)
106106
(der_exe, der_exe!) = if res isa Tuple
107107
res
108108
elseif res isa RuntimeGeneratedFunction
@@ -177,7 +177,7 @@ function DI.prepare_gradient_nokwarg(
177177
# Symbolic.gradient only accepts vectors
178178
grad_var = gradient(f(x_var, context_vars...), vec(x_var))
179179

180-
res = build_function(grad_var, vec(x_var), context_vars...; expression=Val(false))
180+
res = build_function(grad_var, vec(x_var), context_vars...; expression=Val(false), cse=true)
181181
(grad_exe, grad_exe!) = res
182182
return SymbolicsOneArgGradientPrep(_sig, grad_exe, grad_exe!)
183183
end
@@ -254,7 +254,7 @@ function DI.prepare_jacobian_nokwarg(
254254
jacobian(f(x_var, context_vars...), x_var)
255255
end
256256

257-
res = build_function(jac_var, x_var, context_vars...; expression=Val(false))
257+
res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true)
258258
(jac_exe, jac_exe!) = res
259259
return SymbolicsOneArgJacobianPrep(_sig, jac_exe, jac_exe!)
260260
end
@@ -333,7 +333,7 @@ function DI.prepare_hessian_nokwarg(
333333
hessian(f(x_var, context_vars...), vec(x_var))
334334
end
335335

336-
res = build_function(hess_var, vec(x_var), context_vars...; expression=Val(false))
336+
res = build_function(hess_var, vec(x_var), context_vars...; expression=Val(false), cse=true)
337337
(hess_exe, hess_exe!) = res
338338

339339
gradient_prep = DI.prepare_gradient_nokwarg(
@@ -420,7 +420,7 @@ function DI.prepare_hvp_nokwarg(
420420
hvp_vec_var = hess_var * vec(dx_var)
421421

422422
res = build_function(
423-
hvp_vec_var, vec(x_var), vec(dx_var), context_vars...; expression=Val(false)
423+
hvp_vec_var, vec(x_var), vec(dx_var), context_vars...; expression=Val(false), cse=true
424424
)
425425
(hvp_exe, hvp_exe!) = res
426426

@@ -508,7 +508,7 @@ function DI.prepare_second_derivative_nokwarg(
508508
der_var = derivative(f(x_var, context_vars...), x_var)
509509
der2_var = derivative(der_var, x_var)
510510

511-
res = build_function(der2_var, x_var, context_vars...; expression=Val(false))
511+
res = build_function(der2_var, x_var, context_vars...; expression=Val(false), cse=true)
512512
(der2_exe, der2_exe!) = if res isa Tuple
513513
res
514514
elseif res isa RuntimeGeneratedFunction

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function DI.prepare_pushforward_nokwarg(
2626
step_der_var = derivative(y_var, t_var)
2727
pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x))))
2828

29-
res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false))
29+
res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true)
3030
(pushforward_exe, pushforward_exe!) = res
3131
return SymbolicsTwoArgPushforwardPrep(_sig, pushforward_exe, pushforward_exe!)
3232
end
@@ -114,7 +114,7 @@ function DI.prepare_derivative_nokwarg(
114114
f!(y_var, x_var, context_vars...)
115115
der_var = derivative(y_var, x_var)
116116

117-
res = build_function(der_var, x_var, context_vars...; expression=Val(false))
117+
res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true)
118118
(der_exe, der_exe!) = res
119119
return SymbolicsTwoArgDerivativePrep(_sig, der_exe, der_exe!)
120120
end
@@ -201,7 +201,7 @@ function DI.prepare_jacobian_nokwarg(
201201
jacobian(y_var, x_var)
202202
end
203203

204-
res = build_function(jac_var, x_var, context_vars...; expression=Val(false))
204+
res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true)
205205
(jac_exe, jac_exe!) = res
206206
return SymbolicsTwoArgJacobianPrep(_sig, jac_exe, jac_exe!)
207207
end

0 commit comments

Comments
 (0)