@@ -18,6 +18,7 @@ function DI.prepare_pushforward_nokwarg(
1818 step_der_var = derivative (f (x_var + t_var * dx_var, context_vars... ), t_var)
1919 pf_var = substitute (step_der_var, Dict (t_var => zero (eltype (x))))
2020
21+ erase_cache_vars! (context_vars, contexts)
2122 res = build_function (
2223 pf_var, x_var, dx_var, context_vars... ; expression= Val (false ), cse= true
2324 )
@@ -104,6 +105,7 @@ function DI.prepare_derivative_nokwarg(
104105 context_vars = variablize (contexts)
105106 der_var = derivative (f (x_var, context_vars... ), x_var)
106107
108+ erase_cache_vars! (context_vars, contexts)
107109 res = build_function (der_var, x_var, context_vars... ; expression= Val (false ), cse= true )
108110 (der_exe, der_exe!) = if res isa Tuple
109111 res
@@ -179,6 +181,7 @@ function DI.prepare_gradient_nokwarg(
179181 # Symbolic.gradient only accepts vectors
180182 grad_var = gradient (f (x_var, context_vars... ), vec (x_var))
181183
184+ erase_cache_vars! (context_vars, contexts)
182185 res = build_function (
183186 grad_var, vec (x_var), context_vars... ; expression= Val (false ), cse= true
184187 )
@@ -258,6 +261,7 @@ function DI.prepare_jacobian_nokwarg(
258261 jacobian (f (x_var, context_vars... ), x_var)
259262 end
260263
264+ erase_cache_vars! (context_vars, contexts)
261265 res = build_function (jac_var, x_var, context_vars... ; expression= Val (false ), cse= true )
262266 (jac_exe, jac_exe!) = res
263267 return SymbolicsOneArgJacobianPrep (_sig, jac_exe, jac_exe!)
@@ -337,6 +341,7 @@ function DI.prepare_hessian_nokwarg(
337341 hessian (f (x_var, context_vars... ), vec (x_var))
338342 end
339343
344+ erase_cache_vars! (context_vars, contexts)
340345 res = build_function (
341346 hess_var, vec (x_var), context_vars... ; expression= Val (false ), cse= true
342347 )
@@ -425,6 +430,7 @@ function DI.prepare_hvp_nokwarg(
425430 hess_var = hessian (f (x_var, context_vars... ), vec (x_var))
426431 hvp_vec_var = hess_var * vec (dx_var)
427432
433+ erase_cache_vars! (context_vars, contexts)
428434 res = build_function (
429435 hvp_vec_var,
430436 vec (x_var),
@@ -519,6 +525,7 @@ function DI.prepare_second_derivative_nokwarg(
519525 der_var = derivative (f (x_var, context_vars... ), x_var)
520526 der2_var = derivative (der_var, x_var)
521527
528+ erase_cache_vars! (context_vars, contexts)
522529 res = build_function (der2_var, x_var, context_vars... ; expression= Val (false ), cse= true )
523530 (der2_exe, der2_exe!) = if res isa Tuple
524531 res
0 commit comments