Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add more distributions to roulette #287

Merged
merged 1 commit into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions preliz/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,8 @@ def _parametrization(self, mu=None, sigma=None, nu=None):
self.sigma = sigma
self.param_names = ("mu", "sigma", "nu")
self.params = (mu, sigma, nu)
self.params_support = ((-np.inf, np.inf), (eps, np.inf), (eps, np.inf))
# if nu is too small we get a non-smooth distribution
self.params_support = ((-np.inf, np.inf), (eps, np.inf), (1e-4, np.inf))
if all_not_none(mu, sigma, nu):
self._update(mu, sigma, nu)

Expand All @@ -602,7 +603,7 @@ def _update(self, mu, sigma, nu):

def _fit_moments(self, mean, sigma):
# Just assume this is a approximately Gaussian
self._update(mean, sigma, 1e-6)
self._update(mean, sigma, 1e-4)

def _fit_mle(self, sample, **kwargs):
K, mu, sigma = self.dist.fit(sample, **kwargs)
Expand Down
16 changes: 7 additions & 9 deletions preliz/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,9 +856,7 @@ def _update(self, mu, alpha):
self._update_rv_frozen()

def _fit_moments(self, mean, sigma):
mu = mean
alpha = mean**2 / (sigma**2 - mean)
self._update(mu, alpha)
optimize_moments(self, mean, sigma)

def _fit_mle(self, sample):
optimize_ml(self, sample)
Expand Down Expand Up @@ -981,7 +979,7 @@ def __init__(self, psi=None, n=None, p=None):
self.psi = psi
self.n = n
self.p = p
self.dist = ZIBinomial
self.dist = _ZIBinomial
self.support = (0, np.inf)
self._parametrization(psi, n, p)

Expand Down Expand Up @@ -1092,7 +1090,7 @@ def __init__(self, psi=None, mu=None, alpha=None, p=None, n=None):
self.p = p
self.alpha = alpha
self.mu = mu
self.dist = ZINegativeBinomial
self.dist = _ZINegativeBinomial
self.support = (0, np.inf)
self._parametrization(psi, mu, alpha, p, n)

Expand Down Expand Up @@ -1204,7 +1202,7 @@ def __init__(self, psi=None, mu=None):
super().__init__()
self.psi = psi
self.mu = mu
self.dist = ZIPoisson
self.dist = _ZIPoisson
self.support = (0, np.inf)
self._parametrization(psi, mu)

Expand Down Expand Up @@ -1238,7 +1236,7 @@ def _fit_mle(self, sample):
optimize_ml(self, sample)


class ZIBinomial(stats.rv_continuous):
class _ZIBinomial(stats.rv_continuous):
def __init__(self, psi=None, n=None, p=None):
super().__init__()
self.psi = psi
Expand Down Expand Up @@ -1296,7 +1294,7 @@ def rvs(self, size=1): # pylint: disable=arguments-differ
return samples


class ZINegativeBinomial(stats.rv_continuous):
class _ZINegativeBinomial(stats.rv_continuous):
def __init__(self, psi=None, p=None, n=None):
super().__init__()
self.psi = psi
Expand Down Expand Up @@ -1355,7 +1353,7 @@ def rvs(self, size=1): # pylint: disable=arguments-differ
return samples


class ZIPoisson(stats.rv_continuous):
class _ZIPoisson(stats.rv_continuous):
def __init__(self, psi=None, mu=None):
super().__init__()
self.psi = psi
Expand Down
4 changes: 3 additions & 1 deletion preliz/internal/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ def func(params, dist, x_vals, ecdf):
return loss

init_vals = np.array(dist.params)[none_idx]
bounds = np.array(dist.params_support)[none_idx]
bounds = list(zip(*bounds))

opt = least_squares(func, x0=init_vals, args=(dist, x_vals, ecdf))
opt = least_squares(func, x0=init_vals, args=(dist, x_vals, ecdf), bounds=bounds)
dist._update(*opt["x"])
loss = opt["cost"]
return loss
Expand Down
2 changes: 2 additions & 0 deletions preliz/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def test_moments(distribution, params):
"BetaBinomial",
"Binomial",
"DiscreteWeibull",
"ExGaussian",
"NegativeBinomial",
"Kumaraswamy",
"LogitNormal",
"Rice",
Expand Down
2 changes: 1 addition & 1 deletion preliz/tests/test_maxent.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
(Cauchy(), -1, 1, 0.6, (-np.inf, np.inf), (0, 0.726)),
(Cauchy(alpha=0.5), -1, 1, 0.6, (-np.inf, np.inf), (0.6000)),
(ChiSquared(), 2, 7, 0.6, (0, np.inf), (4.002)),
(ExGaussian(), 9, 10, 0.8, (-np.inf, np.inf), (9.112, 0.133, 0.495)),
(ExGaussian(), 9, 10, 0.8, (-np.inf, np.inf), (9.496, 0.390, 0.003)),
(ExGaussian(sigma=0.2), 9, 10, 0.8, (-np.inf, np.inf), (9.168, 0.423)),
(Exponential(), 0, 4, 0.9, (0, np.inf), (0.575)),
(Gamma(), 0, 10, 0.7, (0, np.inf), (0.868, 0.103)),
Expand Down
45 changes: 41 additions & 4 deletions preliz/unidimensional/roulette.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,16 +300,17 @@ def reset_dist_panel(x_min, x_max, ax, yticks):
def get_widgets(x_min, x_max, nrows, ncols):

width_entry_text = widgets.Layout(width="150px")
width_distribution_text = widgets.Layout(width="150px", height="125px")

w_x_min = widgets.IntText(
w_x_min = widgets.FloatText(
value=x_min,
step=1,
description="x_min:",
disabled=False,
layout=width_entry_text,
)

w_x_max = widgets.IntText(
w_x_max = widgets.FloatText(
value=x_max,
step=1,
description="x_max:",
Expand Down Expand Up @@ -343,10 +344,46 @@ def get_widgets(x_min, x_max, nrows, ncols):
layout=width_entry_text,
)

dist_names = ["Normal", "BetaScaled", "Gamma", "LogNormal", "StudentT"]
default_dist = ["Normal", "BetaScaled", "Gamma", "LogNormal", "StudentT"]

dist_names = [
"AsymmetricLaplace",
"BetaScaled",
"ChiSquared",
"ExGaussian",
"Exponential",
"Gamma",
"Gumbel",
"HalfNormal",
"HalfStudentT",
"InverseGamma",
"Laplace",
"LogNormal",
"Logistic",
# "LogitNormal", # fails if we add chips at x_value= 1
"Moyal",
"Normal",
"Pareto",
"Rice",
"SkewNormal",
"StudentT",
"Triangular",
"VonMises",
"Wald",
"Weibull",
"BetaBinomial",
"DiscreteWeibull",
"Geometric",
"NegativeBinomial",
"Poisson",
]

w_distributions = widgets.SelectMultiple(
options=dist_names, value=dist_names, description="", disabled=False
options=dist_names,
value=default_dist,
description="",
disabled=False,
layout=width_distribution_text,
)

return w_x_min, w_x_max, w_ncols, w_nrows, w_repr, w_distributions