diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 57a7d6bbc3..dcc66f2d5a 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -232,6 +232,13 @@ LKJCorrCholesky :undoc-members: :show-inheritance: +Logistic +-------- +.. autoclass:: pyro.distributions.Logistic + :members: + :undoc-members: + :show-inheritance: + MaskedDistribution ------------------ .. autoclass:: pyro.distributions.MaskedDistribution diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index e797d50864..e89c95dda2 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -49,7 +49,7 @@ from pyro.distributions.improper_uniform import ImproperUniform from pyro.distributions.inverse_gamma import InverseGamma from pyro.distributions.lkj import LKJ, LKJCorrCholesky -from pyro.distributions.logistic import SkewLogistic +from pyro.distributions.logistic import Logistic, SkewLogistic from pyro.distributions.mixture import MaskedMixture from pyro.distributions.multivariate_studentt import MultivariateStudentT from pyro.distributions.omt_mvn import OMTMultivariateNormal @@ -120,9 +120,10 @@ "ImproperUniform", "IndependentHMM", "InverseGamma", - "LinearHMM", "LKJ", "LKJCorrCholesky", + "LinearHMM", + "Logistic", "MaskedDistribution", "MaskedMixture", "MixtureOfDiagNormals", diff --git a/pyro/distributions/logistic.py b/pyro/distributions/logistic.py index b98ea10139..354aaef933 100644 --- a/pyro/distributions/logistic.py +++ b/pyro/distributions/logistic.py @@ -1,6 +1,8 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import math + import torch from torch.distributions import constraints from torch.distributions.utils import broadcast_all @@ -9,6 +11,77 @@ from .torch_distribution import TorchDistribution +class Logistic(TorchDistribution): + r""" + Logistic distribution. + + This is a smooth distribution with symmetric asymptotically exponential + tails and a concave log density. For standard ``loc=0``, ``scale=1``, the + density is given by + + .. math:: + + p(x) = \frac {e^{-x}} {(1 + e^{-x})^2} + + Like the :class:`~pyro.distributions.Laplace` density, this density has the + heaviest possible tails (asymptotically) while still being log-convex. + Unlike the :class:`~pyro.distributions.Laplace` distribution, this + distribution is infinitely differentiable everywhere, and is thus suitable + for constructing Laplace approximations. + + :param loc: Location parameter. + :param scale: Scale parameter. + """ + + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.real + has_rsample = True + + def __init__(self, loc, scale, *, validate_args=None): + self.loc, self.scale = broadcast_all(loc, scale) + super().__init__(self.loc.shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(SkewLogistic, _instance) + batch_shape = torch.Size(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + super(Logistic, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + z = (value - self.loc) / self.scale + return logsigmoid(z) * 2 - z - self.scale.log() + + def rsample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + u = self.loc.new_empty(shape).uniform_() + return self.icdf(u) + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + z = (value - self.loc) / self.scale + return z.sigmoid() + + def icdf(self, value): + return self.loc + self.scale * value.logit() + + @property + def mean(self): + return self.loc + + @property + def variance(self): + return self.scale ** 2 * (math.pi ** 2 / 3) + + def entropy(self): + return self.scale.log() + 2 + + class SkewLogistic(TorchDistribution): r""" Skewed generalization of the Logistic distribution (Type I in [1]). diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index 91d0c3d8da..28680a881e 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -646,6 +646,17 @@ def __init__(self, von_loc, von_conc, skewness): }, ], ), + Fixture( + pyro_dist=dist.Logistic, + examples=[ + {"loc": [1.0], "scale": [1.0], "test_data": [2.0]}, + { + "loc": [2.0, -50.0], + "scale": [2.0, 10.0], + "test_data": [[2.0, 10.0], [-1.0, -50.0]], + }, + ], + ), ] discrete_dists = [