Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion SparseDiffEngine
Submodule SparseDiffEngine updated 71 files
+15 −1 include/atoms/non_elementwise_full_dom.h
+13 −3 include/subexpr.h
+10 −9 include/utils/Vec_macros.h
+14 −0 include/utils/matrix.h
+63 −5 include/utils/tracked_alloc.h
+1 −1 src/atoms/affine/add.c
+2 −2 src/atoms/affine/broadcast.c
+7 −6 src/atoms/affine/convolve.c
+3 −2 src/atoms/affine/diag_mat.c
+2 −2 src/atoms/affine/diag_vec.c
+8 −6 src/atoms/affine/hstack.c
+6 −6 src/atoms/affine/index.c
+6 −6 src/atoms/affine/left_matmul.c
+1 −1 src/atoms/affine/neg.c
+1 −1 src/atoms/affine/parameter.c
+1 −1 src/atoms/affine/promote.c
+1 −1 src/atoms/affine/reshape.c
+4 −4 src/atoms/affine/right_matmul.c
+1 −1 src/atoms/affine/scalar_mult.c
+8 −42 src/atoms/affine/sum.c
+6 −6 src/atoms/affine/trace.c
+3 −3 src/atoms/affine/transpose.c
+3 −2 src/atoms/affine/upper_tri.c
+1 −1 src/atoms/affine/variable.c
+2 −2 src/atoms/affine/vector_mult.c
+2 −2 src/atoms/affine/vstack.c
+10 −9 src/atoms/bivariate_full_dom/matmul.c
+6 −6 src/atoms/bivariate_full_dom/multiply.c
+5 −5 src/atoms/bivariate_restricted_dom/quad_over_lin.c
+1 −1 src/atoms/bivariate_restricted_dom/rel_entr.c
+2 −3 src/atoms/bivariate_restricted_dom/rel_entr_scalar_vector.c
+2 −3 src/atoms/bivariate_restricted_dom/rel_entr_vector_scalar.c
+6 −7 src/atoms/elementwise_full_dom/common.c
+1 −1 src/atoms/elementwise_full_dom/power.c
+1 −1 src/atoms/elementwise_restricted_dom/common.c
+1 −1 src/atoms/other/prod.c
+7 −7 src/atoms/other/prod_axis_one.c
+7 −7 src/atoms/other/prod_axis_zero.c
+222 −37 src/atoms/other/quad_form.c
+12 −12 src/expr.c
+3 −3 src/old-code/linear_op.c
+12 −12 src/old-code/old_permuted_dense.c
+40 −26 src/problem.c
+14 −14 src/utils/COO_matrix.c
+20 −20 src/utils/CSC_matrix.c
+10 −10 src/utils/CSR_matrix.c
+2 −2 src/utils/CSR_sum.c
+2 −2 src/utils/int_double_pair.c
+9 −9 src/utils/linalg_dense_sparse_matmuls.c
+6 −6 src/utils/linalg_sparse_matmuls.c
+2 −2 src/utils/matrix_sum.c
+172 −29 src/utils/permuted_dense.c
+2 −2 src/utils/permuted_dense_linalg.c
+51 −9 src/utils/sparse_matrix.c
+113 −22 src/utils/stacked_pd.c
+18 −18 src/utils/stacked_pd_coalesce.c
+15 −15 src/utils/stacked_pd_kron_linalg.c
+11 −11 src/utils/stacked_pd_linalg.c
+1 −0 src/utils/tracked_alloc.c
+2 −2 src/utils/utils.c
+12 −1 tests/all_tests.c
+35 −0 tests/jacobian_tests/affine/test_sum.h
+2 −2 tests/jacobian_tests/composite/test_chain_rule_jacobian.h
+2 −2 tests/jacobian_tests/other/test_quad_form.h
+2 −4 tests/problem/test_param_broadcast.h
+9 −9 tests/profiling/profile_hessian_exp_AX.h
+59 −0 tests/profiling/profile_memory.h
+96 −0 tests/utils/test_permuted_dense.h
+5 −5 tests/wsum_hess/composite/test_chain_rule_wsum_hess.h
+1 −1 tests/wsum_hess/other/test_quad_form.h
+159 −0 tests/wsum_hess/other/test_quad_form_dense.h
157 changes: 125 additions & 32 deletions sparsediffpy/_bindings/atoms/quad_form.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,38 @@
#include "common.h"
#include "non_elementwise_full_dom.h"

/* Quadratic form: y = x' * Q * x where Q is a constant matrix */
/* Quadratic form: y = x' Q x.
*
* Python signatures (mirroring make_left_matmul):
* make_quad_form(None, child, "sparse", data, indices, indptr, m, n)
* make_quad_form(None, child, "dense", P_data_flat, n)
* make_quad_form(src, child, "dense", None, n)
*
* - "sparse": Q is a constant CSR matrix (param must be None).
* - "dense": Q is a dense n x n symmetric matrix, either a constant buffer
* (src None, P_data given) or supplied by a "source" expression
* capsule (src given, P_data None) -- a parameter, or any
* matrix-valued expression of parameters (e.g. reshape(M @ theta)) --
* which is re-evaluated each solve.
*/
static PyObject *py_make_quad_form(PyObject *self, PyObject *args)
{
PyObject *child_capsule;
PyObject *data_obj, *indices_obj, *indptr_obj;
int m, n;
Py_ssize_t nargs = PyTuple_Size(args);
if (nargs < 4)
{
PyErr_SetString(PyExc_TypeError,
"make_quad_form requires at least 4 arguments");
return NULL;
}

PyObject *param_obj = PyTuple_GetItem(args, 0);
PyObject *child_capsule = PyTuple_GetItem(args, 1);
PyObject *fmt_obj = PyTuple_GetItem(args, 2);

if (!PyArg_ParseTuple(args, "OOOOii", &child_capsule, &data_obj, &indices_obj,
&indptr_obj, &m, &n))
if (!PyUnicode_Check(fmt_obj))
{
PyErr_SetString(PyExc_TypeError,
"third argument must be 'sparse' or 'dense'");
return NULL;
}

Expand All @@ -24,41 +46,112 @@ static PyObject *py_make_quad_form(PyObject *self, PyObject *args)
return NULL;
}

PyArrayObject *data_array =
(PyArrayObject *) PyArray_FROM_OTF(data_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
PyArrayObject *indices_array = (PyArrayObject *) PyArray_FROM_OTF(
indices_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY);
PyArrayObject *indptr_array = (PyArrayObject *) PyArray_FROM_OTF(
indptr_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY);
const char *fmt = PyUnicode_AsUTF8(fmt_obj);

if (!data_array || !indices_array || !indptr_array)
if (strcmp(fmt, "sparse") == 0)
{
Py_XDECREF(data_array);
Py_XDECREF(indices_array);
Py_XDECREF(indptr_array);
return NULL;
}
/* Parse: param_or_none, child, "sparse", data, indices, indptr, m, n */
PyObject *data_obj, *indices_obj, *indptr_obj;
int m, n;
if (!PyArg_ParseTuple(args, "OOsOOOii", &param_obj, &child_capsule, &fmt,
&data_obj, &indices_obj, &indptr_obj, &m, &n))
{
return NULL;
}
if (param_obj != Py_None)
{
PyErr_SetString(PyExc_ValueError,
"parameter for a sparse quad_form is not supported");
return NULL;
}

PyArrayObject *data_array = (PyArrayObject *) PyArray_FROM_OTF(
data_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
PyArrayObject *indices_array = (PyArrayObject *) PyArray_FROM_OTF(
indices_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY);
PyArrayObject *indptr_array = (PyArrayObject *) PyArray_FROM_OTF(
indptr_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY);

int nnz = (int) PyArray_SIZE(data_array);
CSR_matrix *Q = new_CSR_matrix(m, n, nnz);
memcpy(Q->x, PyArray_DATA(data_array), nnz * sizeof(double));
memcpy(Q->i, PyArray_DATA(indices_array), nnz * sizeof(int));
memcpy(Q->p, PyArray_DATA(indptr_array), (m + 1) * sizeof(int));
if (!data_array || !indices_array || !indptr_array)
{
Py_XDECREF(data_array);
Py_XDECREF(indices_array);
Py_XDECREF(indptr_array);
return NULL;
}

Py_DECREF(data_array);
Py_DECREF(indices_array);
Py_DECREF(indptr_array);
int nnz = (int) PyArray_SIZE(data_array);
CSR_matrix *Q = new_CSR_matrix(m, n, nnz);
memcpy(Q->x, PyArray_DATA(data_array), nnz * sizeof(double));
memcpy(Q->i, PyArray_DATA(indices_array), nnz * sizeof(int));
memcpy(Q->p, PyArray_DATA(indptr_array), (m + 1) * sizeof(int));

expr *node = new_quad_form(child, Q);
free_CSR_matrix(Q);
Py_DECREF(data_array);
Py_DECREF(indices_array);
Py_DECREF(indptr_array);

if (!node)
expr *node = new_quad_form_sparse(child, Q);
free_CSR_matrix(Q);

if (!node)
{
PyErr_SetString(PyExc_RuntimeError, "failed to create quad_form node");
return NULL;
}
expr_retain(node); /* Capsule owns a reference */
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
}
else if (strcmp(fmt, "dense") == 0)
{
/* Parse: param_or_none, child, "dense", P_data_or_none, n.
* (param_node, P_data) are mutually exclusive, like left_matmul. */
PyObject *data_obj;
int n;
if (!PyArg_ParseTuple(args, "OOsOi", &param_obj, &child_capsule, &fmt,
&data_obj, &n))
{
return NULL;
}

expr *node;
if (param_obj == Py_None)
{
PyArrayObject *data_array = (PyArrayObject *) PyArray_FROM_OTF(
data_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
if (!data_array)
{
return NULL;
}
double *P_data = (double *) PyArray_DATA(data_array);
node = new_quad_form_dense(child, n, P_data, NULL);
Py_DECREF(data_array);
}
else
{
expr *param_node =
(expr *) PyCapsule_GetPointer(param_obj, EXPR_CAPSULE_NAME);
if (!param_node)
{
PyErr_SetString(PyExc_ValueError, "invalid parameter capsule");
return NULL;
}
node = new_quad_form_dense(child, n, NULL, param_node);
}

if (!node)
{
PyErr_SetString(PyExc_RuntimeError,
"failed to create dense quad_form node");
return NULL;
}
expr_retain(node); /* Capsule owns a reference */
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
}
else
{
PyErr_SetString(PyExc_RuntimeError, "failed to create quad_form node");
PyErr_SetString(PyExc_ValueError, "format must be 'sparse' or 'dense'");
return NULL;
}
expr_retain(node); /* Capsule owns a reference */
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
}

#endif /* ATOM_QUAD_FORM_H */
Loading