From 557284006a88325c77ddec3535b19fcd78b9e43f Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Sat, 17 Dec 2022 18:37:08 -0300 Subject: [PATCH] improve mle for NegativeBinomial (#147) --- preliz/distributions/discrete.py | 10 ++-------- preliz/tests/test_distributions.py | 4 ++-- preliz/utils/optimization.py | 10 +++++++--- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/preliz/distributions/discrete.py b/preliz/distributions/discrete.py index d0d563af..429d6ef3 100644 --- a/preliz/distributions/discrete.py +++ b/preliz/distributions/discrete.py @@ -11,7 +11,7 @@ from .distributions import Discrete -from ..utils.optimization import optimize_matching_moments +from ..utils.optimization import optimize_matching_moments, optimize_ml _log = logging.getLogger("preliz") @@ -298,13 +298,7 @@ def _fit_moments(self, mean, sigma): self._update(mu, alpha) def _fit_mle(self, sample): - # the upper bound is based on a quick heuristic. The fit will underestimate - # the value of n when p is very close to 1. - fitted = stats.fit(self.dist, sample, bounds={"n": (1, max(sample) * 2)}) - if not fitted.success: - _log.info("Optimization did not terminate successfully.") - mu, alpha = self._from_p_n(fitted.params.p, fitted.params.n) # pylint: disable=no-member - self._update(mu, alpha) + optimize_ml(self, sample) class Poisson(Discrete): diff --git a/preliz/tests/test_distributions.py b/preliz/tests/test_distributions.py index ceee7e22..6a721c38 100644 --- a/preliz/tests/test_distributions.py +++ b/preliz/tests/test_distributions.py @@ -93,7 +93,7 @@ def test_moments(distribution, params): (Gumbel, (0, 1)), (HalfCauchy, (1,)), (HalfNormal, (1,)), - (HalfStudent, (10, 1)), + (HalfStudent, (100, 1)), (InverseGamma, (3, 0.5)), (Laplace, (0, 1)), (Logistic, (0, 1)), @@ -127,7 +127,7 @@ def test_mle(distribution, params): assert_almost_equal(dist.rv_frozen.mean(), dist_.rv_frozen.mean(), 1) assert_almost_equal(dist.rv_frozen.std(), dist_.rv_frozen.std(), 1) - if dist.name == "student": + if dist.name in "student": assert_almost_equal(params[1:], dist_.params[1:], 0) else: assert_almost_equal(params, dist_.params, 0) diff --git a/preliz/utils/optimization.py b/preliz/utils/optimization.py index 50f398dd..617d467e 100644 --- a/preliz/utils/optimization.py +++ b/preliz/utils/optimization.py @@ -102,10 +102,14 @@ def func(params, dist, mean, sigma): def optimize_ml(dist, sample): def negll(params, dist, sample): dist._update(*params) - return -dist.rv_frozen.logpdf(sample).sum() + if dist.kind == "continuous": + neg = -dist.rv_frozen.logpdf(sample).sum() + else: + neg = -dist.rv_frozen.logpmf(sample).sum() + return neg - dist._fit_moments(0, np.std(sample)) - init_vals = dist.params[::-1] + dist._fit_moments(np.mean(sample), np.std(sample)) + init_vals = dist.params opt = minimize(negll, x0=init_vals, bounds=dist.params_support, args=(dist, sample))