Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add efficient LogitNormal #406

Merged
merged 2 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ This reference provides detailed documentation for user functions in the current
.. automodule:: preliz.distributions.logistic
:members:

.. automodule:: preliz.distributions.logitnormal
:members:

.. automodule:: preliz.distributions.lognormal
:members:

Expand Down
153 changes: 2 additions & 151 deletions preliz/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@

import numpy as np
from scipy import stats
from scipy.special import beta as betaf # pylint: disable=no-name-in-module
from scipy.special import logit, expit # pylint: disable=no-name-in-module

from ..internal.optimization import optimize_ml, optimize_moments, optimize_moments_rice
from ..internal.optimization import optimize_moments_rice
from ..internal.distribution_helper import all_not_none, any_not_none
from .distributions import Continuous
from .asymmetric_laplace import AsymmetricLaplace
Expand All @@ -29,6 +27,7 @@
from .kumaraswamy import Kumaraswamy
from .laplace import Laplace
from .logistic import Logistic
from .logitnormal import LogitNormal
from .lognormal import LogNormal
from .moyal import Moyal
from .normal import Normal
Expand Down Expand Up @@ -296,154 +295,6 @@ def _fit_mle(self, sample, **kwargs):
self._update(beta)


class LogitNormal(Continuous):
r"""
Logit-Normal distribution.

The pdf of this distribution is

.. math::
f(x \mid \mu, \tau) =
\frac{1}{x(1-x)} \sqrt{\frac{\tau}{2\pi}}
\exp\left\{ -\frac{\tau}{2} (logit(x)-\mu)^2 \right\}


.. plot::
:context: close-figs

import arviz as az
from preliz import LogitNormal
az.style.use('arviz-doc')
mus = [0., 0., 0., 1.]
sigmas = [0.3, 1., 2., 1.]
for mu, sigma in zip(mus, sigmas):
LogitNormal(mu, sigma).plot_pdf()

======== ==========================================
Support :math:`x \in (0, 1)`
Mean no analytical solution
Variance no analytical solution
======== ==========================================

Parameters
----------
mu : float
Location parameter.
sigma : float
Scale parameter (sigma > 0).
tau : float
Scale parameter (tau > 0).
"""

def __init__(self, mu=None, sigma=None, tau=None):
super().__init__()
self.dist = _LogitNormal
self.support = (0, 1)
self._parametrization(mu, sigma, tau)

def _parametrization(self, mu=None, sigma=None, tau=None):
if all_not_none(sigma, tau):
raise ValueError(
"Incompatible parametrization. Either use mu and sigma, or mu and tau."
)

names = ("mu", "sigma")
self.params_support = ((-np.inf, np.inf), (eps, np.inf))

if tau is not None:
self.tau = tau
sigma = from_precision(tau)
names = ("mu", "tau")

self.mu = mu
self.sigma = sigma
self.param_names = names
if all_not_none(mu, sigma):
self._update(mu, sigma)

def _get_frozen(self):
frozen = None
if all_not_none(self.params):
frozen = self.dist(self.mu, self.sigma)
return frozen

def _update(self, mu, sigma):
self.mu = np.float64(mu)
self.sigma = np.float64(sigma)
self.tau = to_precision(sigma)

if self.param_names[1] == "sigma":
self.params = (self.mu, self.sigma)
elif self.param_names[1] == "tau":
self.params = (self.mu, self.tau)

self._update_rv_frozen()

def _fit_moments(self, mean, sigma):
mu = logit(mean)
sigma = np.diff((mean - sigma * 3, mean + sigma * 3))
self._update(mu, sigma)

def _fit_mle(self, sample, **kwargs):
mu, sigma = stats.norm.fit(logit(sample), **kwargs)
self._update(mu, sigma)


class _LogitNormal(stats.rv_continuous):
def __init__(self, mu=None, sigma=None):
super().__init__()
self.mu = mu
self.sigma = sigma

def support(self, *args, **kwds): # pylint: disable=unused-argument
return (0, 1)

def cdf(self, x, *args, **kwds):
return stats.norm(self.mu, self.sigma, *args, **kwds).cdf(logit(x))

def pdf(self, x, *args, **kwds):
x = np.asarray(x)
mask = np.logical_or(x == 0, x == 1)
result = np.zeros_like(x, dtype=float)
result[~mask] = stats.norm(self.mu, self.sigma, *args, **kwds).pdf(logit(x[~mask])) / (
x[~mask] * (1 - x[~mask])
)
return result

def logpdf(self, x, *args, **kwds):
x = np.asarray(x)
mask = np.logical_or(x == 0, x == 1)
result = np.full_like(x, -np.inf, dtype=float)
result[~mask] = (
stats.norm(self.mu, self.sigma, *args, **kwds).logpdf(logit(x[~mask]))
- np.log(x[~mask])
- np.log1p(-x[~mask])
)
return result

def ppf(self, q, *args, **kwds):
x_vals = np.linspace(0, 1, 1000)
idx = np.searchsorted(self.cdf(x_vals[:-1], *args, **kwds), q)
return x_vals[idx]

def _stats(self, *args, **kwds): # pylint: disable=unused-argument
# https://en.wikipedia.org/wiki/Logit-normal_distribution#Moments
norm = stats.norm(self.mu, self.sigma)
logistic_inv = expit(norm.ppf(np.linspace(0, 1, 100000)))
mean = np.mean(logistic_inv)
var = np.var(logistic_inv)
return (mean, var, np.nan, np.nan)

def entropy(self): # pylint: disable=arguments-differ
moments = self._stats()
return stats.norm(moments[0], moments[1] ** 0.5).entropy()

def rvs(
self, size=1, random_state=None
): # pylint: disable=arguments-differ, disable=unused-argument
return expit(np.random.normal(self.mu, self.sigma, size))


class Rice(Continuous):
r"""
Rice distribution.
Expand Down
203 changes: 203 additions & 0 deletions preliz/distributions/logitnormal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# pylint: disable=attribute-defined-outside-init
# pylint: disable=arguments-differ
import numba as nb
import numpy as np

from .distributions import Continuous
from ..internal.distribution_helper import eps, to_precision, from_precision, all_not_none
from ..internal.special import erf, erfinv, logit, expit, mean_and_std, cdf_bounds, ppf_bounds_cont


class LogitNormal(Continuous):
r"""
Logit-Normal distribution.

The pdf of this distribution is

.. math::
f(x \mid \mu, \tau) =
\frac{1}{x(1-x)} \sqrt{\frac{\tau}{2\pi}}
\exp\left\{ -\frac{\tau}{2} (logit(x)-\mu)^2 \right\}


.. plot::
:context: close-figs

import arviz as az
from preliz import LogitNormal
az.style.use('arviz-doc')
mus = [0., 0., 0., 1.]
sigmas = [0.3, 1., 2., 1.]
for mu, sigma in zip(mus, sigmas):
LogitNormal(mu, sigma).plot_pdf()

======== ==========================================
Support :math:`x \in (0, 1)`
Mean no analytical solution
Variance no analytical solution
======== ==========================================

Parameters
----------
mu : float
Location parameter.
sigma : float
Scale parameter (sigma > 0).
tau : float
Scale parameter (tau > 0).
"""

def __init__(self, mu=None, sigma=None, tau=None):
super().__init__()
self.support = (0, 1)
self._parametrization(mu, sigma, tau)

def _parametrization(self, mu=None, sigma=None, tau=None):
if all_not_none(sigma, tau):
raise ValueError(
"Incompatible parametrization. Either use mu and sigma, or mu and tau."
)

names = ("mu", "sigma")
self.params_support = ((-np.inf, np.inf), (eps, np.inf))

if tau is not None:
self.tau = tau
sigma = from_precision(tau)
names = ("mu", "tau")

self.mu = mu
self.sigma = sigma
self.param_names = names
if all_not_none(mu, sigma):
self._update(mu, sigma)

def _get_frozen(self):
frozen = None
if all_not_none(self.params):
frozen = self.dist(self.mu, self.sigma)
return frozen

def _update(self, mu, sigma):
self.mu = np.float64(mu)
self.sigma = np.float64(sigma)
self.tau = to_precision(sigma)

if self.param_names[1] == "sigma":
self.params = (self.mu, self.sigma)
elif self.param_names[1] == "tau":
self.params = (self.mu, self.tau)

self.is_frozen = True

def pdf(self, x):
"""
Compute the probability density function (PDF) at a given point x.
"""
x = np.asarray(x)
return np.exp(self.logpdf(x))

def cdf(self, x):
"""
Compute the cumulative distribution function (CDF) at a given point x.
"""
x = np.asarray(x)
return nb_cdf(x, self.mu, self.sigma)

def ppf(self, q):
"""
Compute the percent point function (PPF) at a given probability q.
"""
q = np.asarray(q)
return nb_ppf(q, self.mu, self.sigma)

def logpdf(self, x):
"""
Compute the log probability density function (log PDF) at a given point x.
"""
return nb_logpdf(x, self.mu, self.sigma)

def _neg_logpdf(self, x):
"""
Compute the neg log_pdf sum for the array x.
"""
return nb_neg_logpdf(x, self.mu, self.sigma)

def entropy(self):
x_values = self.xvals("restricted")
logpdf = self.logpdf(x_values)
return -np.trapz(np.exp(logpdf) * logpdf, x_values)

def mean(self):
x_values = self.xvals("full")
pdf = self.pdf(x_values)
return np.trapz(x_values * pdf, x_values)

def median(self):
return self.ppf(0.5)

def var(self):
x_values = self.xvals("full")
pdf = self.pdf(x_values)
return np.trapz((x_values - self.mean()) ** 2 * pdf, x_values)

def std(self):
return self.var() ** 0.5

def skewness(self):
mean = self.mean()
std = self.std()
x_values = self.xvals("full")
pdf = self.pdf(x_values)
return np.trapz(((x_values - mean) / std) ** 3 * pdf, x_values)

def kurtosis(self):
mean = self.mean()
std = self.std()
x_values = self.xvals("full")
pdf = self.pdf(x_values)
return np.trapz(((x_values - mean) / std) ** 4 * pdf, x_values) - 3

def rvs(self, size=None, random_state=None):
random_state = np.random.default_rng(random_state)
return expit(random_state.normal(self.mu, self.sigma, size))

def _fit_moments(self, mean, sigma):
mu = logit(mean)
sigma = np.diff((mean - sigma * 3, mean + sigma * 3))
self._update(mu, sigma)

def _fit_mle(self, sample):
mu, sigma = mean_and_std(logit(sample))
self._update(mu, sigma)


@nb.njit(cache=True)
def nb_cdf(x, mu, sigma):
return cdf_bounds(0.5 * (1 + erf((logit(x) - mu) / (sigma * 2**0.5))), x, 0, 1)


@nb.njit(cache=True)
def nb_ppf(q, mu, sigma):
return ppf_bounds_cont(expit(mu + sigma * 2**0.5 * erfinv(2 * q - 1)), q, 0, 1)


@nb.vectorize(nopython=True, cache=True)
def nb_logpdf(x, mu, sigma):
if x <= 0:
return -np.inf
if x >= 1:
return -np.inf
else:
return (
-np.log(sigma)
- 0.5 * np.log(2 * np.pi)
- 0.5 * ((logit(x) - mu) / sigma) ** 2
- np.log(x)
- np.log1p(-x)
)


@nb.njit(cache=True)
def nb_neg_logpdf(x, mu, sigma):
return -(nb_logpdf(x, mu, sigma)).sum()
Loading
Loading