Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
step_der_var = derivative(f(x_var + t_var * dx_var, context_vars...), t_var)
pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x))))

res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false))
res = build_function(

Check warning on line 21 in DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl#L21

Added line #L21 was not covered by tests
pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true
)
(pf_exe, pf_exe!) = if res isa Tuple
res
elseif res isa RuntimeGeneratedFunction
Expand Down Expand Up @@ -102,7 +104,7 @@
context_vars = variablize(contexts)
der_var = derivative(f(x_var, context_vars...), x_var)

res = build_function(der_var, x_var, context_vars...; expression=Val(false))
res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true)

Check warning on line 107 in DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl#L107

Added line #L107 was not covered by tests
(der_exe, der_exe!) = if res isa Tuple
res
elseif res isa RuntimeGeneratedFunction
Expand Down Expand Up @@ -177,7 +179,9 @@
# Symbolic.gradient only accepts vectors
grad_var = gradient(f(x_var, context_vars...), vec(x_var))

res = build_function(grad_var, vec(x_var), context_vars...; expression=Val(false))
res = build_function(

Check warning on line 182 in DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl#L182

Added line #L182 was not covered by tests
grad_var, vec(x_var), context_vars...; expression=Val(false), cse=true
)
(grad_exe, grad_exe!) = res
return SymbolicsOneArgGradientPrep(_sig, grad_exe, grad_exe!)
end
Expand Down Expand Up @@ -254,7 +258,7 @@
jacobian(f(x_var, context_vars...), x_var)
end

res = build_function(jac_var, x_var, context_vars...; expression=Val(false))
res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true)

Check warning on line 261 in DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl#L261

Added line #L261 was not covered by tests
(jac_exe, jac_exe!) = res
return SymbolicsOneArgJacobianPrep(_sig, jac_exe, jac_exe!)
end
Expand Down Expand Up @@ -333,7 +337,9 @@
hessian(f(x_var, context_vars...), vec(x_var))
end

res = build_function(hess_var, vec(x_var), context_vars...; expression=Val(false))
res = build_function(

Check warning on line 340 in DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl#L340

Added line #L340 was not covered by tests
hess_var, vec(x_var), context_vars...; expression=Val(false), cse=true
)
(hess_exe, hess_exe!) = res

gradient_prep = DI.prepare_gradient_nokwarg(
Expand Down Expand Up @@ -420,7 +426,12 @@
hvp_vec_var = hess_var * vec(dx_var)

res = build_function(
hvp_vec_var, vec(x_var), vec(dx_var), context_vars...; expression=Val(false)
hvp_vec_var,
vec(x_var),
vec(dx_var),
context_vars...;
expression=Val(false),
cse=true,
)
(hvp_exe, hvp_exe!) = res

Expand Down Expand Up @@ -508,7 +519,7 @@
der_var = derivative(f(x_var, context_vars...), x_var)
der2_var = derivative(der_var, x_var)

res = build_function(der2_var, x_var, context_vars...; expression=Val(false))
res = build_function(der2_var, x_var, context_vars...; expression=Val(false), cse=true)

Check warning on line 522 in DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl#L522

Added line #L522 was not covered by tests
(der2_exe, der2_exe!) = if res isa Tuple
res
elseif res isa RuntimeGeneratedFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
step_der_var = derivative(y_var, t_var)
pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x))))

res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false))
res = build_function(

Check warning on line 29 in DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl#L29

Added line #L29 was not covered by tests
pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true
)
(pushforward_exe, pushforward_exe!) = res
return SymbolicsTwoArgPushforwardPrep(_sig, pushforward_exe, pushforward_exe!)
end
Expand Down Expand Up @@ -114,7 +116,7 @@
f!(y_var, x_var, context_vars...)
der_var = derivative(y_var, x_var)

res = build_function(der_var, x_var, context_vars...; expression=Val(false))
res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true)

Check warning on line 119 in DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl#L119

Added line #L119 was not covered by tests
(der_exe, der_exe!) = res
return SymbolicsTwoArgDerivativePrep(_sig, der_exe, der_exe!)
end
Expand Down Expand Up @@ -201,7 +203,7 @@
jacobian(y_var, x_var)
end

res = build_function(jac_var, x_var, context_vars...; expression=Val(false))
res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true)

Check warning on line 206 in DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl#L206

Added line #L206 was not covered by tests
(jac_exe, jac_exe!) = res
return SymbolicsTwoArgJacobianPrep(_sig, jac_exe, jac_exe!)
end
Expand Down
Loading