Skip to content

Commit 1834d64

Browse files
authored
Merge pull request #894 from Silver-Fang/master
Fix in Distributions.Beta: avoid possible NaN caused by x==y==0 when a and b are both small
2 parents e989ab7 + c578a38 commit 1834d64

2 files changed

Lines changed: 24 additions & 3 deletions

File tree

src/Numerics.Tests/DistributionTests/Continuous/BetaTests.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,5 +405,13 @@ public void ValidateInverseCumulativeDistribution(double a, double b, double x,
405405
Assert.That(dist.InverseCumulativeDistribution(p), Is.EqualTo(x).Within(1e-6));
406406
Assert.That(Beta.InvCDF(a, b, p), Is.EqualTo(x).Within(1e-6));
407407
}
408+
409+
[TestCase(0.001)]
410+
public void ProbableNaNWhenABBothSmall(double ab)
411+
{
412+
Beta dist = new Beta(ab, ab);
413+
for (byte i = 0; i < 100; ++i)
414+
Assert.That(!double.IsNaN(dist.Sample()),"Generate NaN");
415+
}
408416
}
409417
}

src/Numerics/Distributions/Beta.cs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -392,9 +392,22 @@ public IEnumerable<double> Samples()
392392
/// <returns>a random number from the Beta distribution.</returns>
393393
internal static double SampleUnchecked(System.Random rnd, double a, double b)
394394
{
395-
var x = Gamma.SampleUnchecked(rnd, a, 1.0);
396-
var y = Gamma.SampleUnchecked(rnd, b, 1.0);
397-
return x/(x + y);
395+
double x, y;
396+
if (a == b)
397+
{
398+
x = Gamma.SampleUnchecked(rnd, a, 1.0);
399+
y = Gamma.SampleUnchecked(rnd, b, 1.0);
400+
//When a==b (and possibly a==b==0), return value is equally possible to be 0 or 1
401+
if (x == 0 && y == 0)
402+
return Bernoulli.Sample(0.5);//In particular, when a==b==0, Beta distribution degradates to Bernoulli distribution.
403+
}
404+
else
405+
do
406+
{
407+
x = Gamma.SampleUnchecked(rnd, a, 1.0);
408+
y = Gamma.SampleUnchecked(rnd, b, 1.0);
409+
} while (x == 0 && y == 0);//When a!=b, return value is not equally possible to be 0 or 1. Regenerate.
410+
return x / (x + y);
398411
}
399412

400413
internal static void SamplesUnchecked(System.Random rnd, double[] values, double a, double b)

0 commit comments

Comments
 (0)