From ac96b934ce84c5bc87308fe78a03ee1026c9d09d Mon Sep 17 00:00:00 2001 From: Fredrik Bagge Carlson Date: Thu, 27 Mar 2025 14:24:32 +0100 Subject: [PATCH 1/2] use cse in AutoSymbolics closes #758 --- .../DifferentiationInterfaceSymbolicsExt/onearg.jl | 14 +++++++------- .../DifferentiationInterfaceSymbolicsExt/twoarg.jl | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 48314542d..c3eb72c2f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -18,7 +18,7 @@ function DI.prepare_pushforward_nokwarg( step_der_var = derivative(f(x_var + t_var * dx_var, context_vars...), t_var) pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x)))) - res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false)) + res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true) (pf_exe, pf_exe!) = if res isa Tuple res elseif res isa RuntimeGeneratedFunction @@ -102,7 +102,7 @@ function DI.prepare_derivative_nokwarg( context_vars = variablize(contexts) der_var = derivative(f(x_var, context_vars...), x_var) - res = build_function(der_var, x_var, context_vars...; expression=Val(false)) + res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true) (der_exe, der_exe!) = if res isa Tuple res elseif res isa RuntimeGeneratedFunction @@ -177,7 +177,7 @@ function DI.prepare_gradient_nokwarg( # Symbolic.gradient only accepts vectors grad_var = gradient(f(x_var, context_vars...), vec(x_var)) - res = build_function(grad_var, vec(x_var), context_vars...; expression=Val(false)) + res = build_function(grad_var, vec(x_var), context_vars...; expression=Val(false), cse=true) (grad_exe, grad_exe!) = res return SymbolicsOneArgGradientPrep(_sig, grad_exe, grad_exe!) end @@ -254,7 +254,7 @@ function DI.prepare_jacobian_nokwarg( jacobian(f(x_var, context_vars...), x_var) end - res = build_function(jac_var, x_var, context_vars...; expression=Val(false)) + res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true) (jac_exe, jac_exe!) = res return SymbolicsOneArgJacobianPrep(_sig, jac_exe, jac_exe!) end @@ -333,7 +333,7 @@ function DI.prepare_hessian_nokwarg( hessian(f(x_var, context_vars...), vec(x_var)) end - res = build_function(hess_var, vec(x_var), context_vars...; expression=Val(false)) + res = build_function(hess_var, vec(x_var), context_vars...; expression=Val(false), cse=true) (hess_exe, hess_exe!) = res gradient_prep = DI.prepare_gradient_nokwarg( @@ -420,7 +420,7 @@ function DI.prepare_hvp_nokwarg( hvp_vec_var = hess_var * vec(dx_var) res = build_function( - hvp_vec_var, vec(x_var), vec(dx_var), context_vars...; expression=Val(false) + hvp_vec_var, vec(x_var), vec(dx_var), context_vars...; expression=Val(false), cse=true ) (hvp_exe, hvp_exe!) = res @@ -508,7 +508,7 @@ function DI.prepare_second_derivative_nokwarg( der_var = derivative(f(x_var, context_vars...), x_var) der2_var = derivative(der_var, x_var) - res = build_function(der2_var, x_var, context_vars...; expression=Val(false)) + res = build_function(der2_var, x_var, context_vars...; expression=Val(false), cse=true) (der2_exe, der2_exe!) = if res isa Tuple res elseif res isa RuntimeGeneratedFunction diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl index 5237a64d4..10462ab2a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl @@ -26,7 +26,7 @@ function DI.prepare_pushforward_nokwarg( step_der_var = derivative(y_var, t_var) pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x)))) - res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false)) + res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true) (pushforward_exe, pushforward_exe!) = res return SymbolicsTwoArgPushforwardPrep(_sig, pushforward_exe, pushforward_exe!) end @@ -114,7 +114,7 @@ function DI.prepare_derivative_nokwarg( f!(y_var, x_var, context_vars...) der_var = derivative(y_var, x_var) - res = build_function(der_var, x_var, context_vars...; expression=Val(false)) + res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true) (der_exe, der_exe!) = res return SymbolicsTwoArgDerivativePrep(_sig, der_exe, der_exe!) end @@ -201,7 +201,7 @@ function DI.prepare_jacobian_nokwarg( jacobian(y_var, x_var) end - res = build_function(jac_var, x_var, context_vars...; expression=Val(false)) + res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true) (jac_exe, jac_exe!) = res return SymbolicsTwoArgJacobianPrep(_sig, jac_exe, jac_exe!) end From 639a196104e6af48d769dd289b6f51958b59e809 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 27 Mar 2025 15:28:19 +0100 Subject: [PATCH 2/2] Format --- .../onearg.jl | 19 +++++++++++++++---- .../twoarg.jl | 4 +++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index c3eb72c2f..96450be9a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -18,7 +18,9 @@ function DI.prepare_pushforward_nokwarg( step_der_var = derivative(f(x_var + t_var * dx_var, context_vars...), t_var) pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x)))) - res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true) + res = build_function( + pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true + ) (pf_exe, pf_exe!) = if res isa Tuple res elseif res isa RuntimeGeneratedFunction @@ -177,7 +179,9 @@ function DI.prepare_gradient_nokwarg( # Symbolic.gradient only accepts vectors grad_var = gradient(f(x_var, context_vars...), vec(x_var)) - res = build_function(grad_var, vec(x_var), context_vars...; expression=Val(false), cse=true) + res = build_function( + grad_var, vec(x_var), context_vars...; expression=Val(false), cse=true + ) (grad_exe, grad_exe!) = res return SymbolicsOneArgGradientPrep(_sig, grad_exe, grad_exe!) end @@ -333,7 +337,9 @@ function DI.prepare_hessian_nokwarg( hessian(f(x_var, context_vars...), vec(x_var)) end - res = build_function(hess_var, vec(x_var), context_vars...; expression=Val(false), cse=true) + res = build_function( + hess_var, vec(x_var), context_vars...; expression=Val(false), cse=true + ) (hess_exe, hess_exe!) = res gradient_prep = DI.prepare_gradient_nokwarg( @@ -420,7 +426,12 @@ function DI.prepare_hvp_nokwarg( hvp_vec_var = hess_var * vec(dx_var) res = build_function( - hvp_vec_var, vec(x_var), vec(dx_var), context_vars...; expression=Val(false), cse=true + hvp_vec_var, + vec(x_var), + vec(dx_var), + context_vars...; + expression=Val(false), + cse=true, ) (hvp_exe, hvp_exe!) = res diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl index 10462ab2a..6597af01c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl @@ -26,7 +26,9 @@ function DI.prepare_pushforward_nokwarg( step_der_var = derivative(y_var, t_var) pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x)))) - res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true) + res = build_function( + pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true + ) (pushforward_exe, pushforward_exe!) = res return SymbolicsTwoArgPushforwardPrep(_sig, pushforward_exe, pushforward_exe!) end