Skip to content

Commit ec03bfb

Browse files
authored
Merge pull request mathnet#912 from jkalias/sparse-pointwise-multiplication-division
Pointwise multiplication and division on sparse matrices
2 parents dfc9db0 + 595528e commit ec03bfb

2 files changed

Lines changed: 28 additions & 3 deletions

File tree

src/Numerics.Tests/LinearAlgebraTests/MatrixTests.cs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
// </copyright>
2929

3030
using System;
31-
using System.Linq;
3231
using MathNet.Numerics.LinearAlgebra;
3332
using MathNet.Numerics.LinearAlgebra.Double;
3433
using MathNet.Numerics.LinearAlgebra.Storage;
@@ -97,5 +96,25 @@ public void SparseCompressedRowMatrixStorage_CoordinateFormatDuplicateRemovalChe
9796
Assert.True(Math.Abs(actual - expected) < tol, $"Expected {expected:E6} at ({r}, {c}), but got {actual:E6}");
9897
}
9998
}
99+
100+
[Test]
101+
public void PointwiseMultiplication_SparseReturnsSameResultAsDenseTest()
102+
{
103+
var x = SparseMatrix.OfDiagonalArray(new double[] { 1, 2, 3, 4 });
104+
var y = DenseMatrix.OfDiagonalArray(new double[] { 5, -6, 7, -8 });
105+
x.PointwiseMultiply(y, x);
106+
var result = SparseMatrix.OfDiagonalArray(new double[] { 5, -12, 21, -32 });
107+
Assert.AreEqual(result, x);
108+
}
109+
110+
[Test]
111+
public void PointwiseDivision_SparseReturnsSameResultAsDenseTest()
112+
{
113+
var x = SparseMatrix.OfDiagonalArray(new double[] { 30, 20, 10 });
114+
var y = DenseMatrix.OfDiagonalArray(new double[] { 3, 4, 5 });
115+
x.PointwiseDivide(y, x);
116+
var result = SparseMatrix.OfDiagonalArray(new double[] { 10, 5, 2 });
117+
Assert.AreEqual(result, x);
118+
}
100119
}
101120
}

src/Numerics/LinearAlgebra/Double/SparseMatrix.cs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,7 +1176,10 @@ protected override void DoTransposeThisAndMultiply(Vector<double> rightSide, Vec
11761176
/// <param name="result">The matrix to store the result of the pointwise multiplication.</param>
11771177
protected override void DoPointwiseMultiply(Matrix<double> other, Matrix<double> result)
11781178
{
1179-
result.Clear();
1179+
if (!ReferenceEquals(this, result))
1180+
{
1181+
result.Clear();
1182+
}
11801183

11811184
var rowPointers = _storage.RowPointers;
11821185
var columnIndices = _storage.ColumnIndices;
@@ -1203,7 +1206,10 @@ protected override void DoPointwiseMultiply(Matrix<double> other, Matrix<double>
12031206
/// <param name="result">The matrix to store the result of the pointwise division.</param>
12041207
protected override void DoPointwiseDivide(Matrix<double> divisor, Matrix<double> result)
12051208
{
1206-
result.Clear();
1209+
if (!ReferenceEquals(this, result))
1210+
{
1211+
result.Clear();
1212+
}
12071213

12081214
var rowPointers = _storage.RowPointers;
12091215
var columnIndices = _storage.ColumnIndices;

0 commit comments

Comments
 (0)