diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 3a66317a..ee798aef 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -63,6 +63,9 @@ This reference provides detailed documentation for user functions in the current .. automodule:: preliz.distributions.logistic :members: +.. automodule:: preliz.distributions.logitnormal + :members: + .. automodule:: preliz.distributions.lognormal :members: diff --git a/preliz/distributions/continuous.py b/preliz/distributions/continuous.py index c88eac8f..197d7225 100644 --- a/preliz/distributions/continuous.py +++ b/preliz/distributions/continuous.py @@ -10,10 +10,8 @@ import numpy as np from scipy import stats -from scipy.special import beta as betaf # pylint: disable=no-name-in-module -from scipy.special import logit, expit # pylint: disable=no-name-in-module -from ..internal.optimization import optimize_ml, optimize_moments, optimize_moments_rice +from ..internal.optimization import optimize_moments_rice from ..internal.distribution_helper import all_not_none, any_not_none from .distributions import Continuous from .asymmetric_laplace import AsymmetricLaplace @@ -29,6 +27,7 @@ from .kumaraswamy import Kumaraswamy from .laplace import Laplace from .logistic import Logistic +from .logitnormal import LogitNormal from .lognormal import LogNormal from .moyal import Moyal from .normal import Normal @@ -296,154 +295,6 @@ def _fit_mle(self, sample, **kwargs): self._update(beta) -class LogitNormal(Continuous): - r""" - Logit-Normal distribution. - - The pdf of this distribution is - - .. math:: - f(x \mid \mu, \tau) = - \frac{1}{x(1-x)} \sqrt{\frac{\tau}{2\pi}} - \exp\left\{ -\frac{\tau}{2} (logit(x)-\mu)^2 \right\} - - - .. plot:: - :context: close-figs - - import arviz as az - from preliz import LogitNormal - az.style.use('arviz-doc') - mus = [0., 0., 0., 1.] - sigmas = [0.3, 1., 2., 1.] - for mu, sigma in zip(mus, sigmas): - LogitNormal(mu, sigma).plot_pdf() - - ======== ========================================== - Support :math:`x \in (0, 1)` - Mean no analytical solution - Variance no analytical solution - ======== ========================================== - - Parameters - ---------- - mu : float - Location parameter. - sigma : float - Scale parameter (sigma > 0). - tau : float - Scale parameter (tau > 0). - """ - - def __init__(self, mu=None, sigma=None, tau=None): - super().__init__() - self.dist = _LogitNormal - self.support = (0, 1) - self._parametrization(mu, sigma, tau) - - def _parametrization(self, mu=None, sigma=None, tau=None): - if all_not_none(sigma, tau): - raise ValueError( - "Incompatible parametrization. Either use mu and sigma, or mu and tau." - ) - - names = ("mu", "sigma") - self.params_support = ((-np.inf, np.inf), (eps, np.inf)) - - if tau is not None: - self.tau = tau - sigma = from_precision(tau) - names = ("mu", "tau") - - self.mu = mu - self.sigma = sigma - self.param_names = names - if all_not_none(mu, sigma): - self._update(mu, sigma) - - def _get_frozen(self): - frozen = None - if all_not_none(self.params): - frozen = self.dist(self.mu, self.sigma) - return frozen - - def _update(self, mu, sigma): - self.mu = np.float64(mu) - self.sigma = np.float64(sigma) - self.tau = to_precision(sigma) - - if self.param_names[1] == "sigma": - self.params = (self.mu, self.sigma) - elif self.param_names[1] == "tau": - self.params = (self.mu, self.tau) - - self._update_rv_frozen() - - def _fit_moments(self, mean, sigma): - mu = logit(mean) - sigma = np.diff((mean - sigma * 3, mean + sigma * 3)) - self._update(mu, sigma) - - def _fit_mle(self, sample, **kwargs): - mu, sigma = stats.norm.fit(logit(sample), **kwargs) - self._update(mu, sigma) - - -class _LogitNormal(stats.rv_continuous): - def __init__(self, mu=None, sigma=None): - super().__init__() - self.mu = mu - self.sigma = sigma - - def support(self, *args, **kwds): # pylint: disable=unused-argument - return (0, 1) - - def cdf(self, x, *args, **kwds): - return stats.norm(self.mu, self.sigma, *args, **kwds).cdf(logit(x)) - - def pdf(self, x, *args, **kwds): - x = np.asarray(x) - mask = np.logical_or(x == 0, x == 1) - result = np.zeros_like(x, dtype=float) - result[~mask] = stats.norm(self.mu, self.sigma, *args, **kwds).pdf(logit(x[~mask])) / ( - x[~mask] * (1 - x[~mask]) - ) - return result - - def logpdf(self, x, *args, **kwds): - x = np.asarray(x) - mask = np.logical_or(x == 0, x == 1) - result = np.full_like(x, -np.inf, dtype=float) - result[~mask] = ( - stats.norm(self.mu, self.sigma, *args, **kwds).logpdf(logit(x[~mask])) - - np.log(x[~mask]) - - np.log1p(-x[~mask]) - ) - return result - - def ppf(self, q, *args, **kwds): - x_vals = np.linspace(0, 1, 1000) - idx = np.searchsorted(self.cdf(x_vals[:-1], *args, **kwds), q) - return x_vals[idx] - - def _stats(self, *args, **kwds): # pylint: disable=unused-argument - # https://en.wikipedia.org/wiki/Logit-normal_distribution#Moments - norm = stats.norm(self.mu, self.sigma) - logistic_inv = expit(norm.ppf(np.linspace(0, 1, 100000))) - mean = np.mean(logistic_inv) - var = np.var(logistic_inv) - return (mean, var, np.nan, np.nan) - - def entropy(self): # pylint: disable=arguments-differ - moments = self._stats() - return stats.norm(moments[0], moments[1] ** 0.5).entropy() - - def rvs( - self, size=1, random_state=None - ): # pylint: disable=arguments-differ, disable=unused-argument - return expit(np.random.normal(self.mu, self.sigma, size)) - - class Rice(Continuous): r""" Rice distribution. diff --git a/preliz/distributions/logitnormal.py b/preliz/distributions/logitnormal.py new file mode 100644 index 00000000..6f4113a1 --- /dev/null +++ b/preliz/distributions/logitnormal.py @@ -0,0 +1,203 @@ +# pylint: disable=attribute-defined-outside-init +# pylint: disable=arguments-differ +import numba as nb +import numpy as np + +from .distributions import Continuous +from ..internal.distribution_helper import eps, to_precision, from_precision, all_not_none +from ..internal.special import erf, erfinv, logit, expit, mean_and_std, cdf_bounds, ppf_bounds_cont + + +class LogitNormal(Continuous): + r""" + Logit-Normal distribution. + + The pdf of this distribution is + + .. math:: + f(x \mid \mu, \tau) = + \frac{1}{x(1-x)} \sqrt{\frac{\tau}{2\pi}} + \exp\left\{ -\frac{\tau}{2} (logit(x)-\mu)^2 \right\} + + + .. plot:: + :context: close-figs + + import arviz as az + from preliz import LogitNormal + az.style.use('arviz-doc') + mus = [0., 0., 0., 1.] + sigmas = [0.3, 1., 2., 1.] + for mu, sigma in zip(mus, sigmas): + LogitNormal(mu, sigma).plot_pdf() + + ======== ========================================== + Support :math:`x \in (0, 1)` + Mean no analytical solution + Variance no analytical solution + ======== ========================================== + + Parameters + ---------- + mu : float + Location parameter. + sigma : float + Scale parameter (sigma > 0). + tau : float + Scale parameter (tau > 0). + """ + + def __init__(self, mu=None, sigma=None, tau=None): + super().__init__() + self.support = (0, 1) + self._parametrization(mu, sigma, tau) + + def _parametrization(self, mu=None, sigma=None, tau=None): + if all_not_none(sigma, tau): + raise ValueError( + "Incompatible parametrization. Either use mu and sigma, or mu and tau." + ) + + names = ("mu", "sigma") + self.params_support = ((-np.inf, np.inf), (eps, np.inf)) + + if tau is not None: + self.tau = tau + sigma = from_precision(tau) + names = ("mu", "tau") + + self.mu = mu + self.sigma = sigma + self.param_names = names + if all_not_none(mu, sigma): + self._update(mu, sigma) + + def _get_frozen(self): + frozen = None + if all_not_none(self.params): + frozen = self.dist(self.mu, self.sigma) + return frozen + + def _update(self, mu, sigma): + self.mu = np.float64(mu) + self.sigma = np.float64(sigma) + self.tau = to_precision(sigma) + + if self.param_names[1] == "sigma": + self.params = (self.mu, self.sigma) + elif self.param_names[1] == "tau": + self.params = (self.mu, self.tau) + + self.is_frozen = True + + def pdf(self, x): + """ + Compute the probability density function (PDF) at a given point x. + """ + x = np.asarray(x) + return np.exp(self.logpdf(x)) + + def cdf(self, x): + """ + Compute the cumulative distribution function (CDF) at a given point x. + """ + x = np.asarray(x) + return nb_cdf(x, self.mu, self.sigma) + + def ppf(self, q): + """ + Compute the percent point function (PPF) at a given probability q. + """ + q = np.asarray(q) + return nb_ppf(q, self.mu, self.sigma) + + def logpdf(self, x): + """ + Compute the log probability density function (log PDF) at a given point x. + """ + return nb_logpdf(x, self.mu, self.sigma) + + def _neg_logpdf(self, x): + """ + Compute the neg log_pdf sum for the array x. + """ + return nb_neg_logpdf(x, self.mu, self.sigma) + + def entropy(self): + x_values = self.xvals("restricted") + logpdf = self.logpdf(x_values) + return -np.trapz(np.exp(logpdf) * logpdf, x_values) + + def mean(self): + x_values = self.xvals("full") + pdf = self.pdf(x_values) + return np.trapz(x_values * pdf, x_values) + + def median(self): + return self.ppf(0.5) + + def var(self): + x_values = self.xvals("full") + pdf = self.pdf(x_values) + return np.trapz((x_values - self.mean()) ** 2 * pdf, x_values) + + def std(self): + return self.var() ** 0.5 + + def skewness(self): + mean = self.mean() + std = self.std() + x_values = self.xvals("full") + pdf = self.pdf(x_values) + return np.trapz(((x_values - mean) / std) ** 3 * pdf, x_values) + + def kurtosis(self): + mean = self.mean() + std = self.std() + x_values = self.xvals("full") + pdf = self.pdf(x_values) + return np.trapz(((x_values - mean) / std) ** 4 * pdf, x_values) - 3 + + def rvs(self, size=None, random_state=None): + random_state = np.random.default_rng(random_state) + return expit(random_state.normal(self.mu, self.sigma, size)) + + def _fit_moments(self, mean, sigma): + mu = logit(mean) + sigma = np.diff((mean - sigma * 3, mean + sigma * 3)) + self._update(mu, sigma) + + def _fit_mle(self, sample): + mu, sigma = mean_and_std(logit(sample)) + self._update(mu, sigma) + + +@nb.njit(cache=True) +def nb_cdf(x, mu, sigma): + return cdf_bounds(0.5 * (1 + erf((logit(x) - mu) / (sigma * 2**0.5))), x, 0, 1) + + +@nb.njit(cache=True) +def nb_ppf(q, mu, sigma): + return ppf_bounds_cont(expit(mu + sigma * 2**0.5 * erfinv(2 * q - 1)), q, 0, 1) + + +@nb.vectorize(nopython=True, cache=True) +def nb_logpdf(x, mu, sigma): + if x <= 0: + return -np.inf + if x >= 1: + return -np.inf + else: + return ( + -np.log(sigma) + - 0.5 * np.log(2 * np.pi) + - 0.5 * ((logit(x) - mu) / sigma) ** 2 + - np.log(x) + - np.log1p(-x) + ) + + +@nb.njit(cache=True) +def nb_neg_logpdf(x, mu, sigma): + return -(nb_logpdf(x, mu, sigma)).sum() diff --git a/preliz/internal/special.py b/preliz/internal/special.py index 7430ae2b..408a2583 100644 --- a/preliz/internal/special.py +++ b/preliz/internal/special.py @@ -433,6 +433,26 @@ def gammaln(x): return tmp + np.log(stp * ser) +@nb.vectorize(nopython=True, cache=True) +def logit(x): + if x == 0: + return -np.inf + elif x == 1: + return np.inf + if x < 0 or x > 1: + return np.nan + else: + return np.log(x / (1 - x)) + + +@nb.vectorize(nopython=True, cache=True) +def expit(x): + if x >= 0: + return 1 / (1 + np.exp(-x)) + else: + return np.exp(x) / (1 + np.exp(x)) + + @nb.vectorize(nopython=True, cache=True) def xlogy(x, y): if x == 0: diff --git a/preliz/tests/test_maxent.py b/preliz/tests/test_maxent.py index 27ca847e..ab1a1c5b 100644 --- a/preliz/tests/test_maxent.py +++ b/preliz/tests/test_maxent.py @@ -93,7 +93,7 @@ (Logistic(), -1, 1, 0.5, (-np.inf, np.inf), (0, 0.91)), (LogNormal(), 1, 4, 0.5, (0, np.inf), (1.216, 0.859)), (LogNormal(mu=1), 1, 4, 0.5, (0, np.inf), (0.978)), - (LogitNormal(), 0.3, 0.8, 0.9, (0, 1), (0.226, 0.677)), + (LogitNormal(), 0.3, 0.8, 0.9, (0, 1), (0.213, 0.676)), (LogitNormal(mu=0.7), 0.3, 0.8, 0.9, (0, 1), (0.531)), (Moyal(), 0, 10, 0.9, (-np.inf, np.inf), (2.935, 1.6)), (Moyal(mu=4), 0, 10, 0.9, (-np.inf, np.inf), (1.445)), diff --git a/preliz/tests/test_scipy.py b/preliz/tests/test_scipy.py index f80bb598..cb3d49df 100644 --- a/preliz/tests/test_scipy.py +++ b/preliz/tests/test_scipy.py @@ -19,6 +19,7 @@ Kumaraswamy, Laplace, Logistic, + LogitNormal, LogNormal, Moyal, Normal, @@ -67,6 +68,7 @@ (Laplace, stats.laplace, {"mu": 2.5, "b": 4}, {"loc": 2.5, "scale": 4}), (Logistic, stats.logistic, {"mu": 2.5, "s": 4}, {"loc": 2.5, "scale": 4}), (LogNormal, stats.lognorm, {"mu": 0, "sigma": 2}, {"s": 2, "scale": 1}), + (LogitNormal, stats.beta, {"mu": 0, "sigma": 0.2}, {"a": 50.5, "b": 50.5}), # not in scipy (Moyal, stats.moyal, {"mu": 1, "sigma": 2}, {"loc": 1, "scale": 2}), (Normal, stats.norm, {"mu": 0, "sigma": 2}, {"loc": 0, "scale": 2}), (Pareto, stats.pareto, {"m": 1, "alpha": 4.5}, {"b": 4.5}), @@ -124,7 +126,7 @@ def test_match_scipy(p_dist, sp_dist, p_params, sp_params): expected = scipy_dist.entropy() if preliz_dist.kind == "discrete": assert_almost_equal(actual, expected, decimal=1) - elif preliz_name in ["HalfStudentT", "Moyal"]: + elif preliz_name in ["HalfStudentT", "Moyal", "LogitNormal"]: assert_almost_equal(actual, expected, decimal=2) else: assert_almost_equal(actual, expected, decimal=4) @@ -136,6 +138,7 @@ def test_match_scipy(p_dist, sp_dist, p_params, sp_params): if preliz_name in [ "HalfStudentT", "Kumaraswamy", + "LogitNormal", "Moyal", "StudentT", "Weibull", @@ -159,7 +162,9 @@ def test_match_scipy(p_dist, sp_dist, p_params, sp_params): else: expected_pdf = scipy_dist.pmf(actual_rvs) - if preliz_name == "HalfStudentT": + if preliz_name == "LogitNormal": + assert_almost_equal(actual_pdf, expected_pdf, decimal=1) + elif preliz_name == "HalfStudentT": assert_almost_equal(actual_pdf, expected_pdf, decimal=2) else: assert_almost_equal(actual_pdf, expected_pdf, decimal=4) @@ -170,7 +175,7 @@ def test_match_scipy(p_dist, sp_dist, p_params, sp_params): actual_cdf = preliz_dist.cdf(cdf_vals) expected_cdf = scipy_dist.cdf(cdf_vals) - if preliz_name == "HalfStudentT": + if preliz_name in ["HalfStudentT", "LogitNormal"]: assert_almost_equal(actual_cdf, expected_cdf, decimal=2) else: assert_almost_equal(actual_cdf, expected_cdf, decimal=6) @@ -178,7 +183,7 @@ def test_match_scipy(p_dist, sp_dist, p_params, sp_params): x_vals = [-1, 0, 0.25, 0.5, 0.75, 1, 2] actual_ppf = preliz_dist.ppf(x_vals) expected_ppf = scipy_dist.ppf(x_vals) - if preliz_name in ["HalfStudentT", "Wald"]: + if preliz_name in ["HalfStudentT", "Wald", "LogitNormal"]: assert_almost_equal(actual_ppf, expected_ppf, decimal=2) else: assert_almost_equal(actual_ppf, expected_ppf) @@ -188,14 +193,17 @@ def test_match_scipy(p_dist, sp_dist, p_params, sp_params): expected_logpdf = scipy_dist.logpdf(actual_rvs) else: expected_logpdf = scipy_dist.logpmf(actual_rvs) + if preliz_name == "HalfStudentT": assert_almost_equal(actual_logpdf, expected_logpdf, decimal=0) + elif preliz_name == "LogitNormal": + assert_almost_equal(actual_logpdf, expected_logpdf, decimal=1) else: assert_almost_equal(actual_logpdf, expected_logpdf) actual_neg_logpdf = preliz_dist._neg_logpdf(actual_rvs) expected_neg_logpdf = -expected_logpdf.sum() - if preliz_name == "HalfStudentT": + if preliz_name in ["HalfStudentT", "LogitNormal"]: assert_almost_equal(actual_neg_logpdf, expected_neg_logpdf, decimal=1) else: assert_almost_equal(actual_neg_logpdf, expected_neg_logpdf) @@ -221,7 +229,7 @@ def test_match_scipy(p_dist, sp_dist, p_params, sp_params): actual_moments = preliz_dist.moments("mv") expected_moments = scipy_dist.stats("mv") - if preliz_name == "HalfStudentT": + if preliz_name in ["HalfStudentT", "LogitNormal"]: assert_almost_equal(actual_moments, expected_moments, decimal=1) else: assert_almost_equal(actual_moments, expected_moments) diff --git a/preliz/tests/test_special.py b/preliz/tests/test_special.py index 2aeca10c..31da8d88 100644 --- a/preliz/tests/test_special.py +++ b/preliz/tests/test_special.py @@ -1,4 +1,4 @@ -import pytest +# pylint: disable=no-member from numpy.testing import assert_almost_equal import numpy as np @@ -57,3 +57,13 @@ def test_gamma(): def test_digamma(): x = np.linspace(0.1, 10, 100) assert_almost_equal(sc_special.digamma(x), pz_special.digamma(x)) + + +def test_logit(): + x = np.linspace(-0.1, 1.1, 100) + assert_almost_equal(sc_special.logit(x), pz_special.logit(x)) + + +def test_expit(): + x = np.linspace(-20, 10, 500) + assert_almost_equal(sc_special.expit(x), pz_special.expit(x))