diff --git a/include/atoms/non_elementwise_full_dom.h b/include/atoms/non_elementwise_full_dom.h index 65c6070..94a7239 100644 --- a/include/atoms/non_elementwise_full_dom.h +++ b/include/atoms/non_elementwise_full_dom.h @@ -22,7 +22,21 @@ #include "subexpr.h" #include "utils/CSR_matrix.h" -expr *new_quad_form(expr *child, CSR_matrix *Q); +expr *new_quad_form_sparse(expr *child, CSR_matrix *Q); + +/* Dense / parametric quadratic form y = x' P x over a vector expression x (a + * leaf variable, or a composition x = f(u) handled via the chain rule). + * + * P is n x n, row-major, and assumed symmetric (matching the new_quad_form_sparse + * convention where the Hessian of x'Qx is taken to be 2Q). For a leaf x the + * Hessian is materialized as a dense permuted_dense block. + * Exactly one of P_data / param_source must be non-NULL; otherwise an error is printed to stderr and exit(1) is called. + * - constant P: P_data points to n*n doubles, param_source == NULL. + * - parametric P: P_data == NULL, param_source is the parameter node that + * supplies P (n*n doubles) and is refreshed each solve; it is retained by the node (expr_retain) and handed to free_expr in free_type_data. + */ +expr *new_quad_form_dense(expr *child, int n, const double *P_data, + expr *param_source); /* product of all entries, without axis argument */ expr *new_prod(expr *child); diff --git a/include/subexpr.h b/include/subexpr.h index 0dd6c8e..26aaaf6 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -54,12 +54,22 @@ typedef struct power_expr double p; } power_expr; -/* Quadratic form: y = x'*Q*x */ +/* Quadratic form: y = x'*Q*x. Q is a polymorphic matrix: a sparse (CSR) backend + on the sparse path, or a dense (permuted_dense) backend on the dense path. */ typedef struct quad_form_expr { expr base; - CSR_matrix *Q; - CSC_matrix *QJf; /* Q * J_f in CSC_matrix (for chain rule hessian) */ + matrix *Q; + /* Q * J_f for the composition chain-rule hessian; exactly one is used per + node. Sparse path: CSC (raw symmetric products, no matrix-vtable form). + Dense path: permuted_dense via the matrix dispatchers. 
*/ + CSC_matrix *QJf; + matrix *QJf_dense; + double *diag_w; /* length-n diagonal, every entry 2w, fed to BTDA on the dense path; allocated only when the child is a composition */ + int n; /* quadratic dimension = left->size */ + + /* parametric dense path: param_source feeds Q each solve (NULL otherwise) */ + expr *param_source; } quad_form_expr; /* Sum reduction along an axis */ diff --git a/src/atoms/other/quad_form.c b/src/atoms/other/quad_form.c index d02e52c..9116569 100644 --- a/src/atoms/other/quad_form.c +++ b/src/atoms/other/quad_form.c @@ -19,31 +19,53 @@ #include "subexpr.h" #include "utils/CSC_matrix.h" #include "utils/cblas_wrapper.h" +#include "utils/matmul_dispatchers.h" #include "utils/matrix_sum.h" +#include "utils/permuted_dense.h" #include "utils/sparse_matrix.h" #include "utils/tracked_alloc.h" #include -#include #include #include #include +/* Quadratic form y = x'Qx. Sparse path: Q is a CSR matrix. Dense path: Q is an + n x n permuted_dense (optionally parameter-fed). For a leaf x the Hessian 2Q is + materialized as a dense block; for a composition x = f(u) the dense path forms the + chain rule J_f^T Q J_f via the permuted_dense (PD) matmul dispatchers. Q is assumed symmetric. */ + +/* Refresh Q from the parameter once per solve (no-op when Q is constant). + Q is symmetric, so column-major == row-major and the copy is verbatim. 
*/ +static void refresh_param_values_qf(quad_form_expr *qnode) +{ + if (qnode->param_source == NULL || !qnode->base.needs_parameter_refresh) + { + return; + } + qnode->base.needs_parameter_refresh = false; /* consume the flag: later calls return early until it is set again */ + memcpy(qnode->Q->x, qnode->param_source->value, + (size_t) qnode->n * qnode->n * sizeof(double)); +} + static void forward(expr *node, const double *u) { + quad_form_expr *qnode = (quad_form_expr *) node; expr *x = node->left; + /* parametric path: run the parameter node forward so its value is current, then copy it into Q; both steps are skipped when param_source is NULL or no refresh is pending (constant/sparse path) */ + if (qnode->param_source != NULL && node->needs_parameter_refresh) + { + qnode->param_source->forward(qnode->param_source, NULL); + } + refresh_param_values_qf(qnode); + /* child's forward pass */ x->forward(x, u); - /* local forward pass */ - CSR_matrix *Q = ((quad_form_expr *) node)->Q; - Ax_csr(Q, x->value, node->work->dwork, 0); - node->value[0] = 0.0; - - for (int i = 0; i < x->size; i++) - { - node->value[0] += x->value[i] * node->work->dwork[i]; - } + /* dwork = Q @ x; value = x' (Q x); dwork is kept and read again by the dense Hessian evaluation */ + matrix *Q = qnode->Q; + Q->block_left_mult_vec(Q, x->value, node->work->dwork, 1); + node->value[0] = cblas_ddot(qnode->n, x->value, 1, node->work->dwork, 1); } static void jacobian_init_impl(expr *node) @@ -90,15 +112,16 @@ static void jacobian_init_impl(expr *node) static void eval_jacobian(expr *node) { + quad_form_expr *qnode = (quad_form_expr *) node; expr *x = node->left; - CSR_matrix *Q = ((quad_form_expr *) node)->Q; CSR_matrix *jac = node->jacobian->to_csr(node->jacobian); if (x->var_id != NOT_A_VARIABLE) { - /* jacobian = 2 * (Q @ x)^T */ - Ax_csr(Q, x->value, jac->x, 0); - cblas_dscal(x->size, 2.0, jac->x, 1); + /* jacobian = 2 * (Q @ x)^T (leaf x: sparsity is the variable block); Q x is written straight into jac->x and then scaled in place */ + matrix *Q = qnode->Q; + Q->block_left_mult_vec(Q, x->value, jac->x, 1); + cblas_dscal(qnode->n, 2.0, jac->x, 1); } else { @@ -124,9 +147,12 @@ static void eval_jacobian(expr *node) } } -static void wsum_hess_init_impl(expr *node) +/* Sparse-backend hessian. 
The non-leaf chain rule J_f^T Q J_f uses raw CSR/CSC + symmetric products that have no matrix-vtable equivalent. */ +static void wsum_hess_init_sparse(expr *node) { - CSR_matrix *Q = ((quad_form_expr *) node)->Q; + quad_form_expr *qnode = (quad_form_expr *) node; + CSR_matrix *Q = qnode->Q->to_csr(qnode->Q); /* CSR view of the polymorphic Q, obtained through its to_csr hook */ expr *x = node->left; if (x->var_id != NOT_A_VARIABLE) @@ -160,7 +186,6 @@ static void wsum_hess_init_impl(expr *node) */ /* jacobian_csc_init(x) already called in jacobian_init */ - quad_form_expr *qnode = (quad_form_expr *) node; CSC_matrix *Jf = x->work->jacobian_csc; /* term1 = Jf^T W Jf = Jf^T B*/ @@ -181,9 +206,10 @@ static void wsum_hess_init_impl(expr *node) } } -static void eval_wsum_hess(expr *node, const double *w) +static void eval_wsum_hess_sparse(expr *node, const double *w) { - CSR_matrix *Q = ((quad_form_expr *) node)->Q; + quad_form_expr *qnode = (quad_form_expr *) node; + CSR_matrix *Q = qnode->Q->to_csr(qnode->Q); /* same CSR view as in wsum_hess_init_sparse; BA_fill_values below takes it as a raw CSR_matrix */ expr *x = node->left; double two_w = 2.0 * w[0]; @@ -209,10 +235,10 @@ static void eval_wsum_hess(expr *node, const double *w) } } - CSC_matrix *QJf = ((quad_form_expr *) node)->QJf; + CSC_matrix *QJf = qnode->QJf; CSR_matrix *term1 = node->work->hess_term1->to_csr(node->work->hess_term1); - /* term1 = J_f^T Q J_f = J_f^T B */ + /* term1 = J_f^T Q J_f = J_f^T B */ BA_fill_values(Q, Jf, QJf); BTDA_fill_values(Jf, QJf, NULL, term1); @@ -233,16 +259,124 @@ static void eval_wsum_hess(expr *node, const double *w) } } +/* Dense-backend hessian. Leaf x: 2wQ materialized as a permuted_dense block (the + fast common case). Composition x = f(u): the chain rule + H = J_f^T (2w Q) J_f + sum_i (2w Q f(u))_i nabla^2 f_i = term1 + term2, + with term1 formed via the PD matmul dispatchers (Q symmetric PD so QJf = Q J_f + is PD; J_f^T Q J_f = (Q J_f)^T J_f keeps the PD operand on the dispatch key). 
*/ +static void wsum_hess_init_dense(expr *node) +{ + quad_form_expr *qnode = (quad_form_expr *) node; + expr *x = node->left; + int n = qnode->n; + + if (x->var_id != NOT_A_VARIABLE) + { + /* Hessian is the dense block 2Q over x's contiguous variable range: perm[i] = var_id + i is passed as both the row and the column index list. Values are written later by eval_wsum_hess_dense. */ + int *perm = (int *) sp_malloc(n * sizeof(int)); + for (int i = 0; i < n; i++) + { + perm[i] = x->var_id + i; + } + node->wsum_hess = + new_permuted_dense(node->n_vars, node->n_vars, n, n, perm, perm, NULL); + sp_free(perm); + } + else + { + /* The dispatchers read a sparse child jacobian through its csc_cache. */ + if (!x->jacobian->is_permuted_dense && !x->jacobian->is_stacked_pd) + { + sparse_matrix_ensure_csc_cache((sparse_matrix *) x->jacobian); + } + + /* term1 = J_f^T Q J_f = (Q J_f)^T J_f. QJf is PD; passing it as the + transposed operand B keeps the PD type on the dispatch key. */ + permuted_dense *Q_pd = (permuted_dense *) qnode->Q; + qnode->QJf_dense = BA_pd_matrices_alloc(Q_pd, x->jacobian); + node->work->hess_term1 = BTA_matrices_alloc(x->jacobian, qnode->QJf_dense); + qnode->diag_w = (double *) sp_malloc(n * sizeof(double)); + + /* term2 = sum_i (Q f(x))_i nabla^2 f_i (pattern only: copy_sparsity of the child Hessian; values and the 2w factor are filled in eval_wsum_hess_dense) */ + wsum_hess_init(x); + node->work->hess_term2 = x->wsum_hess->copy_sparsity(x->wsum_hess); + + /* hess = term1 + term2 (CSR-backed; sum_matrices is type-agnostic) */ + int max_nnz = node->work->hess_term1->nnz + node->work->hess_term2->nnz; + node->wsum_hess = + new_sparse_matrix_alloc(node->n_vars, node->n_vars, max_nnz); + sum_matrices_alloc(node->work->hess_term1, node->work->hess_term2, + node->wsum_hess); + } +} + +static void eval_wsum_hess_dense(expr *node, const double *w) +{ + quad_form_expr *qnode = (quad_form_expr *) node; + expr *x = node->left; + double two_w = 2.0 * w[0]; + + if (x->var_id != NOT_A_VARIABLE) + { + int nn = qnode->n * qnode->n; /* NOTE(review): product is computed in int, unlike the size_t-widened copy in refresh_param_values_qf; confirm n stays small enough */ + /* Hessian = 2 w Q (Q symmetric, constant up to the weight). The PD's value + buffer (->x) aliases its dense block, so writing it updates to_csr too. 
*/ + memcpy(node->wsum_hess->x, qnode->Q->x, nn * sizeof(double)); + cblas_dscal(nn, two_w, node->wsum_hess->x, 1); + } + else + { + /* Mirror the child jacobian's current values into its csc_cache; the PD + dispatchers below read from it. NOTE(review): this hook is called for every jacobian type, whereas wsum_hess_init_dense builds the cache only for plain sparse jacobians; confirm the PD and stacked-PD hooks accept the call. */ + x->jacobian->refresh_csc_values(x->jacobian); + + /* term1 = 2w J_f^T Q J_f. The dispatcher fill is B^T diag(d) A (no plain + B^T A form); a constant diagonal d = 2w carries the weight. + Potential TODO: Add back BTA_matrices_fill_values_kernel so we don't have + to form diag_w. */ + for (int i = 0; i < qnode->n; i++) + { + qnode->diag_w[i] = two_w; + } + BA_pd_matrices_fill_values((permuted_dense *) qnode->Q, x->jacobian, + (permuted_dense *) qnode->QJf_dense); + BTDA_matrices_fill_values(x->jacobian, qnode->diag_w, qnode->QJf_dense, + node->work->hess_term1); + + /* term2 = 2w sum_i (Q f(x))_i nabla^2 f_i (dwork = Q f(x) from forward, so forward must have run at the current point before this call) */ + x->eval_wsum_hess(x, node->work->dwork); + memcpy(node->work->hess_term2->x, x->wsum_hess->x, + x->wsum_hess->nnz * sizeof(double)); + cblas_dscal(node->work->hess_term2->nnz, two_w, node->work->hess_term2->x, + 1); + + sum_matrices_fill_values(node->work->hess_term1, node->work->hess_term2, + node->wsum_hess); + } +} + static void free_type_data(expr *node) { quad_form_expr *qnode = (quad_form_expr *) node; - free_CSR_matrix(qnode->Q); + free_matrix(qnode->Q); qnode->Q = NULL; if (qnode->QJf != NULL) { free_CSC_matrix(qnode->QJf); qnode->QJf = NULL; } + if (qnode->QJf_dense != NULL) + { + free_matrix(qnode->QJf_dense); + qnode->QJf_dense = NULL; + } + if (qnode->diag_w != NULL) + { + sp_free(qnode->diag_w); + qnode->diag_w = NULL; + } + free_expr(qnode->param_source); /* NULL unless new_quad_form_dense was given a parameter; NOTE(review): unguarded, so this assumes free_expr tolerates NULL - confirm */ + qnode->param_source = NULL; } static bool is_affine(const expr *node) @@ -252,22 +386,71 @@ static bool is_affine(const expr *node) return false; } -expr *new_quad_form(expr *left, CSR_matrix *Q) +expr *new_quad_form_sparse(expr *left, CSR_matrix *Q) { assert(left->d1 == 1 || left->d2 == 1); /* left must be a vector 
*/ + quad_form_expr *qnode = (quad_form_expr *) sp_calloc(1, sizeof(quad_form_expr)); expr *node = &qnode->base; init_expr(node, 1, 1, left->n_vars, forward, jacobian_init_impl, eval_jacobian, - is_affine, wsum_hess_init_impl, eval_wsum_hess, free_type_data); + is_affine, wsum_hess_init_sparse, eval_wsum_hess_sparse, + free_type_data); node->left = left; expr_retain(left); - /* Set type-specific field */ - qnode->Q = new_CSR_matrix(Q->m, Q->n, Q->nnz); - copy_CSR_matrix(Q, qnode->Q); + /* Set type-specific field. new_sparse_matrix takes ownership, so clone. */ + qnode->Q = new_sparse_matrix(new_csr(Q)); + qnode->n = left->size; /* quadratic dimension; used by the shared forward */ /* dwork stores the result of Q @ f(x) in the forward pass */ node->work->dwork = (double *) sp_malloc(left->size * sizeof(double)); return node; } + +expr *new_quad_form_dense(expr *child, int n, const double *P_data, + expr *param_source) +{ + assert(child->d1 == 1 || child->d2 == 1); /* child must be a vector */ + assert(child->size == n); /* NOTE(review): both shape checks are asserts and are compiled out under NDEBUG */ + + quad_form_expr *qnode = (quad_form_expr *) sp_calloc(1, sizeof(quad_form_expr)); + expr *node = &qnode->base; + + init_expr(node, 1, 1, child->n_vars, forward, jacobian_init_impl, eval_jacobian, + is_affine, wsum_hess_init_dense, eval_wsum_hess_dense, free_type_data); + node->left = child; + expr_retain(child); + + qnode->n = n; + /* dwork stores Q @ x in the forward pass */ + node->work->dwork = (double *) sp_malloc(n * sizeof(double)); + + qnode->param_source = param_source; + if (param_source != NULL) + { + if (P_data != NULL) + { + fprintf(stderr, "Error in new_quad_form_dense: param and data both " + "set\n"); + exit(1); + } + + expr_retain(param_source); + + /* Q is filled from the parameter on the first forward pass. 
*/ + qnode->Q = new_permuted_dense_full(n, n, NULL); + node->needs_parameter_refresh = true; + } + else + { + if (P_data == NULL) + { + fprintf(stderr, "Error in new_quad_form_dense: need P data\n"); + exit(1); + } + + qnode->Q = new_permuted_dense_full(n, n, P_data); + } + + return node; +} diff --git a/tests/all_tests.c b/tests/all_tests.c index e571683..4c0f193 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -103,6 +103,7 @@ #include "wsum_hess/other/test_prod_axis_one.h" #include "wsum_hess/other/test_prod_axis_zero.h" #include "wsum_hess/other/test_quad_form.h" +#include "wsum_hess/other/test_quad_form_dense.h" #endif /* PROFILE_ONLY */ #ifdef PROFILE_ONLY @@ -293,6 +294,10 @@ int main(void) mu_run_test(test_wsum_hess_quad_over_lin_xy, tests_run); mu_run_test(test_wsum_hess_quad_over_lin_yx, tests_run); mu_run_test(test_wsum_hess_quad_form, tests_run); + mu_run_test(test_wsum_hess_quad_form_dense, tests_run); + mu_run_test(test_wsum_hess_quad_form_dense_affine, tests_run); + mu_run_test(test_wsum_hess_quad_form_dense_exp, tests_run); + mu_run_test(test_wsum_hess_quad_form_dense_param, tests_run); mu_run_test(test_wsum_hess_scalar_mult_log_vector, tests_run); mu_run_test(test_wsum_hess_scalar_mult_log_matrix, tests_run); mu_run_test(test_wsum_hess_vector_mult_log_vector, tests_run); diff --git a/tests/jacobian_tests/composite/test_chain_rule_jacobian.h b/tests/jacobian_tests/composite/test_chain_rule_jacobian.h index 44bb966..a7870f1 100644 --- a/tests/jacobian_tests/composite/test_chain_rule_jacobian.h +++ b/tests/jacobian_tests/composite/test_chain_rule_jacobian.h @@ -135,7 +135,7 @@ const char *test_jacobian_quad_form_Ax(void) expr *x = new_variable(4, 1, 1, 6); expr *Ax = new_left_matmul(NULL, x, A); expr *sin_Ax = new_sin(Ax); - expr *node = new_quad_form(sin_Ax, Q); + expr *node = new_quad_form_sparse(sin_Ax, Q); mu_assert("check_jacobian failed", check_jacobian_num(node, u_vals, NUMERICAL_DIFF_DEFAULT_H)); @@ -162,7 +162,7 @@ const char 
*test_jacobian_quad_form_exp(void) expr *x = new_variable(3, 1, 0, 3); expr *exp_x = new_exp(x); - expr *node = new_quad_form(exp_x, Q); + expr *node = new_quad_form_sparse(exp_x, Q); mu_assert("check_jacobian failed", check_jacobian_num(node, u_vals, NUMERICAL_DIFF_DEFAULT_H)); diff --git a/tests/jacobian_tests/other/test_quad_form.h b/tests/jacobian_tests/other/test_quad_form.h index 22d33c0..f4c1b7e 100644 --- a/tests/jacobian_tests/other/test_quad_form.h +++ b/tests/jacobian_tests/other/test_quad_form.h @@ -21,7 +21,7 @@ const char *test_quad_form(void) memcpy(Q->x, Qx, 5 * sizeof(double)); memcpy(Q->i, Qi, 5 * sizeof(int)); memcpy(Q->p, Qp, 4 * sizeof(int)); - expr *node = new_quad_form(x, Q); + expr *node = new_quad_form_sparse(x, Q); jacobian_init(node); node->forward(node, u_vals); @@ -67,7 +67,7 @@ memcpy(A->x, Ax, 10 * sizeof(double)); memcpy(A->i, Ai, 10 * sizeof(int)); memcpy(A->p, Ap, 4 * sizeof(int)); expr *Au = new_linear(u, A, NULL); -expr *node = new_quad_form(Au, Q); +expr *node = new_quad_form_sparse(Au, Q); jacobian_init(node); node->forward(node, u_vals); diff --git a/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h b/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h index 296dd3c..ed2c405 100644 --- a/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h +++ b/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h @@ -341,7 +341,7 @@ const char *test_wsum_hess_quad_form_Ax(void) expr *x = new_variable(4, 1, 1, 6); expr *Ax = new_left_matmul(NULL, x, A); - expr *node = new_quad_form(Ax, Q); + expr *node = new_quad_form_sparse(Ax, Q); mu_assert("check_wsum_hess failed", check_wsum_hess(node, u_vals, &w, NUMERICAL_DIFF_DEFAULT_H)); @@ -371,7 +371,7 @@ const char *test_wsum_hess_quad_form_sin_Ax(void) expr *x = new_variable(4, 1, 1, 6); expr *Ax = new_left_matmul(NULL, x, A); expr *sin_Ax = new_sin(Ax); - expr *node = new_quad_form(sin_Ax, Q); + expr *node = new_quad_form_sparse(sin_Ax, Q); mu_assert("check_wsum_hess failed", 
check_wsum_hess(node, u_vals, &w, NUMERICAL_DIFF_DEFAULT_H)); @@ -501,7 +501,7 @@ const char *test_wsum_hess_quad_form_exp(void) expr *x = new_variable(3, 1, 0, 3); expr *exp_x = new_exp(x); - expr *node = new_quad_form(exp_x, Q); + expr *node = new_quad_form_sparse(exp_x, Q); mu_assert("check_wsum_hess failed", check_wsum_hess(node, u_vals, &w, NUMERICAL_DIFF_DEFAULT_H)); diff --git a/tests/wsum_hess/other/test_quad_form.h b/tests/wsum_hess/other/test_quad_form.h index bd62d5a..fa8956f 100644 --- a/tests/wsum_hess/other/test_quad_form.h +++ b/tests/wsum_hess/other/test_quad_form.h @@ -26,7 +26,7 @@ const char *test_wsum_hess_quad_form(void) memcpy(Q->p, Qp, 5 * sizeof(int)); expr *x = new_variable(4, 1, 3, 10); - expr *node = new_quad_form(x, Q); + expr *node = new_quad_form_sparse(x, Q); jacobian_init(node); node->forward(node, u_vals); diff --git a/tests/wsum_hess/other/test_quad_form_dense.h b/tests/wsum_hess/other/test_quad_form_dense.h new file mode 100644 index 0000000..becaa4f --- /dev/null +++ b/tests/wsum_hess/other/test_quad_form_dense.h @@ -0,0 +1,159 @@ +#include "atoms/affine.h" +#include "atoms/elementwise_full_dom.h" +#include "atoms/non_elementwise_full_dom.h" +#include "expr.h" +#include "minunit.h" +#include "numerical_diff.h" +#include "test_helpers.h" +#include +#include + +/* Dense path of quad_form: y = x' P x with a dense (permuted_dense) P. + * x is 3x1 with global index 2, total variables = 5. + * P = [1 2 0; 2 3 0; 0 0 4] (symmetric), x = [1, 2, 3]. 
+ * value = x' P x = 57 + * gradient = 2 P x = [10, 16, 24] on columns {2, 3, 4} + * Hessian = 2 w P = 4 P for w = 2 (full dense block over rows/cols {2, 3, 4}) + */ +const char *test_wsum_hess_quad_form_dense(void) +{ + double u_vals[5] = {0.0, 0.0, 1.0, 2.0, 3.0}; + double w = 2.0; + + /* row-major 3x3 dense P */ + double P[9] = {1.0, 2.0, 0.0, 2.0, 3.0, 0.0, 0.0, 0.0, 4.0}; + + expr *x = new_variable(3, 1, 2, 5); + expr *node = new_quad_form_dense(x, 3, P, NULL); + + jacobian_init(node); + node->forward(node, u_vals); + node->eval_jacobian(node); + wsum_hess_init(node); + node->eval_wsum_hess(node, &w); + + /* forward value */ + mu_assert("dense quad_form value fail", fabs(node->value[0] - 57.0) < 1e-9); + + /* gradient = 2 P x on columns {2,3,4} */ + double expected_grad[3] = {10.0, 16.0, 24.0}; + int expected_jp[2] = {0, 3}; + int expected_ji[3] = {2, 3, 4}; + mu_assert("dense quad_form jacobian vals fail", + cmp_values(node->jacobian, expected_grad, 3)); + mu_assert("dense quad_form jacobian sparsity fail", + cmp_sparsity(node->jacobian, expected_jp, expected_ji, 1, 3)); + + /* Hessian = 2 w P = 4 P as a dense block over rows/cols {2,3,4}; row pointers {0,0,0,3,6,9} leave rows 0-1 empty and give rows 2-4 three entries each, explicit zeros included */ + mu_assert("dense quad_form hessian is not permuted_dense", + node->wsum_hess->is_permuted_dense); + int expected_hp[6] = {0, 0, 0, 3, 6, 9}; + int expected_hi[9] = {2, 3, 4, 2, 3, 4, 2, 3, 4}; + double expected_hx[9] = {4.0, 8.0, 0.0, 8.0, 12.0, 0.0, 0.0, 0.0, 16.0}; + mu_assert("dense quad_form hessian sparsity fail", + cmp_sparsity(node->wsum_hess, expected_hp, expected_hi, 5, 9)); + mu_assert("dense quad_form hessian vals fail", + cmp_values(node->wsum_hess, expected_hx, 9)); + + free_expr(node); + return 0; +} + +/* Dense quad_form over an affine composition x = A u (term2 vanishes). 
*/ +const char *test_wsum_hess_quad_form_dense_affine(void) +{ + double u_vals[3] = {0.5, 1.0, 1.5}; + double w = 2.0; + + /* row-major 3x3 dense P (symmetric) */ + double P[9] = {1.0, 2.0, 0.0, 2.0, 3.0, 0.0, 0.0, 0.0, 4.0}; + /* row-major 3x3 A (square, so A u has size n) */ + double A[9] = {0.5, -1.0, 0.2, 0.3, 0.7, -0.4, -0.6, 0.1, 0.9}; + + expr *x = new_variable(3, 1, 0, 3); + expr *Ax = new_left_matmul_dense(NULL, x, 3, 3, A); + expr *node = new_quad_form_dense(Ax, 3, P, NULL); + + mu_assert("dense quad_form affine composition wsum_hess failed", + check_wsum_hess(node, u_vals, &w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(node); + return 0; +} + +/* Dense quad_form over a nonlinear composition x = exp(u) (nonzero term2). */ +const char *test_wsum_hess_quad_form_dense_exp(void) +{ + double u_vals[3] = {0.5, 1.0, 1.5}; + double w = 3.0; + + /* row-major 3x3 dense P (symmetric) */ + double P[9] = {1.0, 2.0, 0.0, 2.0, 3.0, 0.0, 0.0, 0.0, 4.0}; + + expr *x = new_variable(3, 1, 0, 3); + expr *exp_x = new_exp(x); + expr *node = new_quad_form_dense(exp_x, 3, P, NULL); + + mu_assert("dense quad_form nonlinear composition wsum_hess failed", + check_wsum_hess(node, u_vals, &w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(node); + return 0; +} + +/* Parametric dense quad_form: P is supplied by a parameter and refreshed each + * solve. Same leaf setup as test_wsum_hess_quad_form_dense (x is 3x1 at global index 2 of 5 variables, x = [1, 2, 3], w = 2); updating the parameter changes the values but not + * the sparsity. + * P1 = [1 2 0; 2 3 0; 0 0 4]: value 57, grad [10,16,24], Hessian 4 P1. + * P2 = [2 1 0; 1 4 0; 0 0 5]: value 67, grad [8,18,30], Hessian 4 P2. 
+ */ +const char *test_wsum_hess_quad_form_dense_param(void) +{ + double u_vals[5] = {0.0, 0.0, 1.0, 2.0, 3.0}; + double w = 2.0; + + double P1[9] = {1.0, 2.0, 0.0, 2.0, 3.0, 0.0, 0.0, 0.0, 4.0}; + expr *param_P = new_parameter(3, 3, 0, 5, P1); /* param_id 0 = updatable */ + expr *x = new_variable(3, 1, 2, 5); + expr *node = new_quad_form_dense(x, 3, NULL, param_P); /* P fed by parameter */ + + jacobian_init(node); + wsum_hess_init(node); + node->forward(node, u_vals); + node->eval_jacobian(node); + node->eval_wsum_hess(node, &w); + + /* parameter P1: check values and sparsity */ + int expected_jp[2] = {0, 3}; + int expected_ji[3] = {2, 3, 4}; + int expected_hp[6] = {0, 0, 0, 3, 6, 9}; + int expected_hi[9] = {2, 3, 4, 2, 3, 4, 2, 3, 4}; + double grad1[3] = {10.0, 16.0, 24.0}; + double hess1[9] = {4.0, 8.0, 0.0, 8.0, 12.0, 0.0, 0.0, 0.0, 16.0}; + mu_assert("P1 value fail", fabs(node->value[0] - 57.0) < 1e-9); + mu_assert("P1 grad vals fail", cmp_values(node->jacobian, grad1, 3)); + mu_assert("P1 grad sparsity fail", + cmp_sparsity(node->jacobian, expected_jp, expected_ji, 1, 3)); + mu_assert("hessian not permuted_dense", node->wsum_hess->is_permuted_dense); + mu_assert("P1 hessian vals fail", cmp_values(node->wsum_hess, hess1, 9)); + mu_assert("P1 hessian sparsity fail", + cmp_sparsity(node->wsum_hess, expected_hp, expected_hi, 5, 9)); + + /* overwrite the parameter values in place and flag the node (expr_set_needs_refresh presumably re-arms needs_parameter_refresh - TODO confirm); no *_init call is repeated, so only values are recomputed */ + double P2[9] = {2.0, 1.0, 0.0, 1.0, 4.0, 0.0, 0.0, 0.0, 5.0}; + memcpy(param_P->value, P2, 9 * sizeof(double)); + expr_set_needs_refresh(node); + node->forward(node, u_vals); + node->eval_jacobian(node); + node->eval_wsum_hess(node, &w); + + /* parameter P2: only the values are asserted below; the sparsity is expected to be unchanged but is not re-checked */ + double grad2[3] = {8.0, 18.0, 30.0}; + double hess2[9] = {8.0, 4.0, 0.0, 4.0, 16.0, 0.0, 0.0, 0.0, 20.0}; + mu_assert("P2 value fail", fabs(node->value[0] - 67.0) < 1e-9); + mu_assert("P2 grad vals fail", cmp_values(node->jacobian, grad2, 3)); + mu_assert("P2 
hessian vals fail", cmp_values(node->wsum_hess, hess2, 9)); + + free_expr(node); + return 0; +}