diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 48314542d..96450be9a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -18,7 +18,9 @@ function DI.prepare_pushforward_nokwarg( step_der_var = derivative(f(x_var + t_var * dx_var, context_vars...), t_var) pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x)))) - res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false)) + res = build_function( + pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true + ) (pf_exe, pf_exe!) = if res isa Tuple res elseif res isa RuntimeGeneratedFunction @@ -102,7 +104,7 @@ function DI.prepare_derivative_nokwarg( context_vars = variablize(contexts) der_var = derivative(f(x_var, context_vars...), x_var) - res = build_function(der_var, x_var, context_vars...; expression=Val(false)) + res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true) (der_exe, der_exe!) = if res isa Tuple res elseif res isa RuntimeGeneratedFunction @@ -177,7 +179,9 @@ function DI.prepare_gradient_nokwarg( # Symbolic.gradient only accepts vectors grad_var = gradient(f(x_var, context_vars...), vec(x_var)) - res = build_function(grad_var, vec(x_var), context_vars...; expression=Val(false)) + res = build_function( + grad_var, vec(x_var), context_vars...; expression=Val(false), cse=true + ) (grad_exe, grad_exe!) = res return SymbolicsOneArgGradientPrep(_sig, grad_exe, grad_exe!) end @@ -254,7 +258,7 @@ function DI.prepare_jacobian_nokwarg( jacobian(f(x_var, context_vars...), x_var) end - res = build_function(jac_var, x_var, context_vars...; expression=Val(false)) + res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true) (jac_exe, jac_exe!) = res return SymbolicsOneArgJacobianPrep(_sig, jac_exe, jac_exe!) end @@ -333,7 +337,9 @@ function DI.prepare_hessian_nokwarg( hessian(f(x_var, context_vars...), vec(x_var)) end - res = build_function(hess_var, vec(x_var), context_vars...; expression=Val(false)) + res = build_function( + hess_var, vec(x_var), context_vars...; expression=Val(false), cse=true + ) (hess_exe, hess_exe!) = res gradient_prep = DI.prepare_gradient_nokwarg( @@ -420,7 +426,12 @@ function DI.prepare_hvp_nokwarg( hvp_vec_var = hess_var * vec(dx_var) res = build_function( - hvp_vec_var, vec(x_var), vec(dx_var), context_vars...; expression=Val(false) + hvp_vec_var, + vec(x_var), + vec(dx_var), + context_vars...; + expression=Val(false), + cse=true, ) (hvp_exe, hvp_exe!) = res @@ -508,7 +519,7 @@ function DI.prepare_second_derivative_nokwarg( der_var = derivative(f(x_var, context_vars...), x_var) der2_var = derivative(der_var, x_var) - res = build_function(der2_var, x_var, context_vars...; expression=Val(false)) + res = build_function(der2_var, x_var, context_vars...; expression=Val(false), cse=true) (der2_exe, der2_exe!) = if res isa Tuple res elseif res isa RuntimeGeneratedFunction diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl index 5237a64d4..6597af01c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl @@ -26,7 +26,9 @@ function DI.prepare_pushforward_nokwarg( step_der_var = derivative(y_var, t_var) pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x)))) - res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false)) + res = build_function( + pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true + ) (pushforward_exe, pushforward_exe!) = res return SymbolicsTwoArgPushforwardPrep(_sig, pushforward_exe, pushforward_exe!) end @@ -114,7 +116,7 @@ function DI.prepare_derivative_nokwarg( f!(y_var, x_var, context_vars...) der_var = derivative(y_var, x_var) - res = build_function(der_var, x_var, context_vars...; expression=Val(false)) + res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true) (der_exe, der_exe!) = res return SymbolicsTwoArgDerivativePrep(_sig, der_exe, der_exe!) end @@ -201,7 +203,7 @@ function DI.prepare_jacobian_nokwarg( jacobian(y_var, x_var) end - res = build_function(jac_var, x_var, context_vars...; expression=Val(false)) + res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true) (jac_exe, jac_exe!) = res return SymbolicsTwoArgJacobianPrep(_sig, jac_exe, jac_exe!) end