Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion include/atoms/non_elementwise_full_dom.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,21 @@
#include "subexpr.h"
#include "utils/CSR_matrix.h"

expr *new_quad_form(expr *child, CSR_matrix *Q);
expr *new_quad_form_sparse(expr *child, CSR_matrix *Q);

/* Dense / parametric quadratic form y = x' P x over a vector expression x (a
* leaf variable, or a composition x = f(u) handled via the chain rule).
*
* P is n x n, row-major, and assumed symmetric (matching the new_quad_form_sparse
* convention where the Hessian of x'Qx is taken to be 2Q). For a leaf x the
* Hessian is materialized as a dense permuted_dense block.
*
* - constant P: P_data points to n*n doubles, param_source == NULL.
* - parametric P: P_data == NULL, param_source is the parameter node that
* supplies P (n*n doubles) and is refreshed each solve.
*/
expr *new_quad_form_dense(expr *child, int n, const double *P_data,
expr *param_source);

/* product of all entries, without axis argument */
expr *new_prod(expr *child);
Expand Down
16 changes: 13 additions & 3 deletions include/subexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,22 @@ typedef struct power_expr
double p;
} power_expr;

/* Quadratic form: y = x'*Q*x */
/* Quadratic form: y = x'*Q*x. Q is a polymorphic matrix: a sparse (CSR) backend
on the sparse path, or a dense (permuted_dense) backend on the dense path. */
typedef struct quad_form_expr
{
expr base;
CSR_matrix *Q;
CSC_matrix *QJf; /* Q * J_f in CSC_matrix (for chain rule hessian) */
matrix *Q;
/* Q * J_f for the composition chain-rule hessian; exactly one is used per
node. Sparse path: CSC (raw symmetric products, no matrix-vtable form).
Dense path: permuted_dense via the matrix dispatchers. */
CSC_matrix *QJf;
matrix *QJf_dense;
double *diag_w; /* length-n diagonal (= 2w) fed to BTDA on the dense path */
int n; /* quadratic dimension = left->size */

/* parametric dense path: param_source feeds Q each solve (NULL otherwise) */
expr *param_source;
} quad_form_expr;

/* Sum reduction along an axis */
Expand Down
237 changes: 210 additions & 27 deletions src/atoms/other/quad_form.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,53 @@
#include "subexpr.h"
#include "utils/CSC_matrix.h"
#include "utils/cblas_wrapper.h"
#include "utils/matmul_dispatchers.h"
#include "utils/matrix_sum.h"
#include "utils/permuted_dense.h"
#include "utils/sparse_matrix.h"
#include "utils/tracked_alloc.h"
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Quadratic form y = x'Qx. Sparse path: Q is a CSR matrix. Dense path: Q is an
n x n permuted_dense (optionally parameter-fed). For a leaf x the Hessian 2Q is
materialized as a dense block; for a composition x = f(u) the dense path forms the
chain rule J_f^T Q J_f via the PD matmul dispatchers. Q is assumed symmetric. */

/* Refresh Q from the parameter once per solve (no-op when Q is constant).
Q is symmetric, so column-major == row-major and the copy is verbatim. */
static void refresh_param_values_qf(quad_form_expr *qnode)
{
if (qnode->param_source == NULL || !qnode->base.needs_parameter_refresh)
{
return;
}
qnode->base.needs_parameter_refresh = false;
memcpy(qnode->Q->x, qnode->param_source->value,
(size_t) qnode->n * qnode->n * sizeof(double));
}

static void forward(expr *node, const double *u)
{
quad_form_expr *qnode = (quad_form_expr *) node;
expr *x = node->left;

/* refresh Q from the parameter if needed (no-op on the constant/sparse path) */
if (qnode->param_source != NULL && node->needs_parameter_refresh)
{
qnode->param_source->forward(qnode->param_source, NULL);
}
refresh_param_values_qf(qnode);

/* child's forward pass */
x->forward(x, u);

/* local forward pass */
CSR_matrix *Q = ((quad_form_expr *) node)->Q;
Ax_csr(Q, x->value, node->work->dwork, 0);
node->value[0] = 0.0;

for (int i = 0; i < x->size; i++)
{
node->value[0] += x->value[i] * node->work->dwork[i];
}
/* dwork = Q @ x; value = x' (Q x) */
matrix *Q = qnode->Q;
Q->block_left_mult_vec(Q, x->value, node->work->dwork, 1);
node->value[0] = cblas_ddot(qnode->n, x->value, 1, node->work->dwork, 1);
}

static void jacobian_init_impl(expr *node)
Expand Down Expand Up @@ -90,15 +112,16 @@ static void jacobian_init_impl(expr *node)

static void eval_jacobian(expr *node)
{
quad_form_expr *qnode = (quad_form_expr *) node;
expr *x = node->left;
CSR_matrix *Q = ((quad_form_expr *) node)->Q;
CSR_matrix *jac = node->jacobian->to_csr(node->jacobian);

if (x->var_id != NOT_A_VARIABLE)
{
/* jacobian = 2 * (Q @ x)^T */
Ax_csr(Q, x->value, jac->x, 0);
cblas_dscal(x->size, 2.0, jac->x, 1);
/* jacobian = 2 * (Q @ x)^T (leaf x: sparsity is the variable block) */
matrix *Q = qnode->Q;
Q->block_left_mult_vec(Q, x->value, jac->x, 1);
cblas_dscal(qnode->n, 2.0, jac->x, 1);
}
else
{
Expand All @@ -124,9 +147,12 @@ static void eval_jacobian(expr *node)
}
}

static void wsum_hess_init_impl(expr *node)
/* Sparse-backend hessian. The non-leaf chain rule J_f^T Q J_f uses raw CSR/CSC
symmetric products that have no matrix-vtable equivalent. */
static void wsum_hess_init_sparse(expr *node)
{
CSR_matrix *Q = ((quad_form_expr *) node)->Q;
quad_form_expr *qnode = (quad_form_expr *) node;
CSR_matrix *Q = qnode->Q->to_csr(qnode->Q);
expr *x = node->left;

if (x->var_id != NOT_A_VARIABLE)
Expand Down Expand Up @@ -160,7 +186,6 @@ static void wsum_hess_init_impl(expr *node)
*/

/* jacobian_csc_init(x) already called in jacobian_init */
quad_form_expr *qnode = (quad_form_expr *) node;
CSC_matrix *Jf = x->work->jacobian_csc;

/* term1 = Jf^T W Jf = Jf^T B*/
Expand All @@ -181,9 +206,10 @@ static void wsum_hess_init_impl(expr *node)
}
}

static void eval_wsum_hess(expr *node, const double *w)
static void eval_wsum_hess_sparse(expr *node, const double *w)
{
CSR_matrix *Q = ((quad_form_expr *) node)->Q;
quad_form_expr *qnode = (quad_form_expr *) node;
CSR_matrix *Q = qnode->Q->to_csr(qnode->Q);
expr *x = node->left;
double two_w = 2.0 * w[0];

Expand All @@ -209,10 +235,10 @@ static void eval_wsum_hess(expr *node, const double *w)
}
}

CSC_matrix *QJf = ((quad_form_expr *) node)->QJf;
CSC_matrix *QJf = qnode->QJf;
CSR_matrix *term1 = node->work->hess_term1->to_csr(node->work->hess_term1);

/* term1 = J_f^T Q J_f = J_f^T B */
/* term1 = J_f^T Q J_f = J_f^T B */
BA_fill_values(Q, Jf, QJf);
BTDA_fill_values(Jf, QJf, NULL, term1);

Expand All @@ -233,16 +259,124 @@ static void eval_wsum_hess(expr *node, const double *w)
}
}

/* Dense-backend hessian. Leaf x: 2wQ materialized as a permuted_dense block (the
fast common case). Composition x = f(u): the chain rule
H = J_f^T (2w Q) J_f + sum_i (2w Q f(u))_i nabla^2 f_i = term1 + term2,
with term1 formed via the PD matmul dispatchers (Q symmetric PD so QJf = Q J_f
is PD; J_f^T Q J_f = (Q J_f)^T J_f keeps the PD operand on the dispatch key). */
static void wsum_hess_init_dense(expr *node)
{
quad_form_expr *qnode = (quad_form_expr *) node;
expr *x = node->left;
int n = qnode->n;

if (x->var_id != NOT_A_VARIABLE)
{
/* Hessian is the dense block 2Q over x's contiguous variable range. */
int *perm = (int *) sp_malloc(n * sizeof(int));
for (int i = 0; i < n; i++)
{
perm[i] = x->var_id + i;
}
node->wsum_hess =
new_permuted_dense(node->n_vars, node->n_vars, n, n, perm, perm, NULL);
sp_free(perm);
}
else
{
/* The dispatchers read a sparse child jacobian through its csc_cache. */
if (!x->jacobian->is_permuted_dense && !x->jacobian->is_stacked_pd)
{
sparse_matrix_ensure_csc_cache((sparse_matrix *) x->jacobian);
}

/* term1 = J_f^T Q J_f = (Q J_f)^T J_f. QJf is PD; passing it as the
transposed operand B keeps the PD type on the dispatch key. */
permuted_dense *Q_pd = (permuted_dense *) qnode->Q;
qnode->QJf_dense = BA_pd_matrices_alloc(Q_pd, x->jacobian);
node->work->hess_term1 = BTA_matrices_alloc(x->jacobian, qnode->QJf_dense);
qnode->diag_w = (double *) sp_malloc(n * sizeof(double));

/* term2 = sum_i (Q f(x))_i nabla^2 f_i */
wsum_hess_init(x);
node->work->hess_term2 = x->wsum_hess->copy_sparsity(x->wsum_hess);

/* hess = term1 + term2 (CSR-backed; sum_matrices is type-agnostic) */
int max_nnz = node->work->hess_term1->nnz + node->work->hess_term2->nnz;
node->wsum_hess =
new_sparse_matrix_alloc(node->n_vars, node->n_vars, max_nnz);
sum_matrices_alloc(node->work->hess_term1, node->work->hess_term2,
node->wsum_hess);
}
}

static void eval_wsum_hess_dense(expr *node, const double *w)
{
quad_form_expr *qnode = (quad_form_expr *) node;
expr *x = node->left;
double two_w = 2.0 * w[0];

if (x->var_id != NOT_A_VARIABLE)
{
int nn = qnode->n * qnode->n;
/* Hessian = 2 w Q (Q symmetric, constant up to the weight). The PD's value
buffer (->x) aliases its dense block, so writing it updates to_csr too. */
memcpy(node->wsum_hess->x, qnode->Q->x, nn * sizeof(double));
cblas_dscal(nn, two_w, node->wsum_hess->x, 1);
}
else
{
/* Mirror the child jacobian's current values into its csc_cache; the PD
dispatchers below read from it. */
x->jacobian->refresh_csc_values(x->jacobian);

/* term1 = 2w J_f^T Q J_f. The dispatcher fill is B^T diag(d) A (no plain
Comment thread
Transurgeon marked this conversation as resolved.
B^T A form); a constant diagonal d = 2w carries the weight.
Potential TODO: Add back BTA_matrices_fill_values_kernel so we don't have
to form diag_w. */
for (int i = 0; i < qnode->n; i++)
{
qnode->diag_w[i] = two_w;
}
BA_pd_matrices_fill_values((permuted_dense *) qnode->Q, x->jacobian,
(permuted_dense *) qnode->QJf_dense);
BTDA_matrices_fill_values(x->jacobian, qnode->diag_w, qnode->QJf_dense,
node->work->hess_term1);

/* term2 = 2w sum_i (Q f(x))_i nabla^2 f_i (dwork = Q f(x) from forward) */
x->eval_wsum_hess(x, node->work->dwork);
memcpy(node->work->hess_term2->x, x->wsum_hess->x,
x->wsum_hess->nnz * sizeof(double));
cblas_dscal(node->work->hess_term2->nnz, two_w, node->work->hess_term2->x,
1);

sum_matrices_fill_values(node->work->hess_term1, node->work->hess_term2,
node->wsum_hess);
}
}

static void free_type_data(expr *node)
{
quad_form_expr *qnode = (quad_form_expr *) node;
free_CSR_matrix(qnode->Q);
free_matrix(qnode->Q);
qnode->Q = NULL;
if (qnode->QJf != NULL)
{
free_CSC_matrix(qnode->QJf);
qnode->QJf = NULL;
}
if (qnode->QJf_dense != NULL)
{
free_matrix(qnode->QJf_dense);
qnode->QJf_dense = NULL;
}
if (qnode->diag_w != NULL)
{
sp_free(qnode->diag_w);
qnode->diag_w = NULL;
}
free_expr(qnode->param_source);
qnode->param_source = NULL;
}

static bool is_affine(const expr *node)
Expand All @@ -252,22 +386,71 @@ static bool is_affine(const expr *node)
return false;
}

expr *new_quad_form(expr *left, CSR_matrix *Q)
expr *new_quad_form_sparse(expr *left, CSR_matrix *Q)
{
assert(left->d1 == 1 || left->d2 == 1); /* left must be a vector */
quad_form_expr *qnode = (quad_form_expr *) sp_calloc(1, sizeof(quad_form_expr));
expr *node = &qnode->base;

init_expr(node, 1, 1, left->n_vars, forward, jacobian_init_impl, eval_jacobian,
is_affine, wsum_hess_init_impl, eval_wsum_hess, free_type_data);
is_affine, wsum_hess_init_sparse, eval_wsum_hess_sparse,
free_type_data);
node->left = left;
expr_retain(left);

/* Set type-specific field */
qnode->Q = new_CSR_matrix(Q->m, Q->n, Q->nnz);
copy_CSR_matrix(Q, qnode->Q);
/* Set type-specific field. new_sparse_matrix takes ownership, so clone. */
qnode->Q = new_sparse_matrix(new_csr(Q));
qnode->n = left->size; /* quadratic dimension; used by the shared forward */

/* dwork stores the result of Q @ f(x) in the forward pass */
node->work->dwork = (double *) sp_malloc(left->size * sizeof(double));
return node;
}

expr *new_quad_form_dense(expr *child, int n, const double *P_data,
expr *param_source)
{
assert(child->d1 == 1 || child->d2 == 1); /* child must be a vector */
assert(child->size == n);

quad_form_expr *qnode = (quad_form_expr *) sp_calloc(1, sizeof(quad_form_expr));
expr *node = &qnode->base;

init_expr(node, 1, 1, child->n_vars, forward, jacobian_init_impl, eval_jacobian,
is_affine, wsum_hess_init_dense, eval_wsum_hess_dense, free_type_data);
node->left = child;
expr_retain(child);

qnode->n = n;
/* dwork stores Q @ x in the forward pass */
node->work->dwork = (double *) sp_malloc(n * sizeof(double));

qnode->param_source = param_source;
if (param_source != NULL)
{
if (P_data != NULL)
{
fprintf(stderr, "Error in new_quad_form_dense: param and data both "
"set\n");
exit(1);
}

expr_retain(param_source);

/* Q is filled from the parameter on the first forward pass. */
qnode->Q = new_permuted_dense_full(n, n, NULL);
node->needs_parameter_refresh = true;
}
else
{
if (P_data == NULL)
{
fprintf(stderr, "Error in new_quad_form_dense: need P data\n");
exit(1);
}

qnode->Q = new_permuted_dense_full(n, n, P_data);
}

return node;
}
Loading
Loading