Skip to content

Commit

Permalink
improve mle for NegativeBinomial (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Dec 17, 2022
1 parent 2ea67b4 commit 5572840
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 13 deletions.
10 changes: 2 additions & 8 deletions preliz/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


from .distributions import Discrete
from ..utils.optimization import optimize_matching_moments
from ..utils.optimization import optimize_matching_moments, optimize_ml

_log = logging.getLogger("preliz")

Expand Down Expand Up @@ -298,13 +298,7 @@ def _fit_moments(self, mean, sigma):
self._update(mu, alpha)

def _fit_mle(self, sample):
# the upper bound is based on a quick heuristic. The fit will underestimate
# the value of n when p is very close to 1.
fitted = stats.fit(self.dist, sample, bounds={"n": (1, max(sample) * 2)})
if not fitted.success:
_log.info("Optimization did not terminate successfully.")
mu, alpha = self._from_p_n(fitted.params.p, fitted.params.n) # pylint: disable=no-member
self._update(mu, alpha)
optimize_ml(self, sample)


class Poisson(Discrete):
Expand Down
4 changes: 2 additions & 2 deletions preliz/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_moments(distribution, params):
(Gumbel, (0, 1)),
(HalfCauchy, (1,)),
(HalfNormal, (1,)),
(HalfStudent, (10, 1)),
(HalfStudent, (100, 1)),
(InverseGamma, (3, 0.5)),
(Laplace, (0, 1)),
(Logistic, (0, 1)),
Expand Down Expand Up @@ -127,7 +127,7 @@ def test_mle(distribution, params):

assert_almost_equal(dist.rv_frozen.mean(), dist_.rv_frozen.mean(), 1)
assert_almost_equal(dist.rv_frozen.std(), dist_.rv_frozen.std(), 1)
if dist.name == "student":
if dist.name in "student":
assert_almost_equal(params[1:], dist_.params[1:], 0)
else:
assert_almost_equal(params, dist_.params, 0)
Expand Down
10 changes: 7 additions & 3 deletions preliz/utils/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,14 @@ def func(params, dist, mean, sigma):
def optimize_ml(dist, sample):
def negll(params, dist, sample):
dist._update(*params)
return -dist.rv_frozen.logpdf(sample).sum()
if dist.kind == "continuous":
neg = -dist.rv_frozen.logpdf(sample).sum()
else:
neg = -dist.rv_frozen.logpmf(sample).sum()
return neg

dist._fit_moments(0, np.std(sample))
init_vals = dist.params[::-1]
dist._fit_moments(np.mean(sample), np.std(sample))
init_vals = dist.params

opt = minimize(negll, x0=init_vals, bounds=dist.params_support, args=(dist, sample))

Expand Down

0 comments on commit 5572840

Please sign in to comment.