From a6d120d0429436940bc9104e315ef4613cf3979d Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 14 Jun 2021 14:07:26 -0600 Subject: [PATCH] Implement a SoftAsymmetricLaplace distribution (#2872) * Sketch SoftAsymmetricLaplace distrbution * Implement .log_prob() * Attempt to numerically stabilize * Use double precision * Address review comments * Attempt to fix coveralls on github * try to fix coveralls Co-authored-by: Du Phan --- .github/workflows/ci.yml | 18 ++-- docs/source/distributions.rst | 7 ++ pyro/distributions/__init__.py | 6 +- pyro/distributions/asymmetriclaplace.py | 124 +++++++++++++++++++++++- tests/distributions/conftest.py | 7 ++ 5 files changed, 149 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7e8f14396e..3cb93743cb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,13 +9,11 @@ on: env: CXX: g++-8 CC: gcc-8 + # See coveralls-python - Github Actions support: + # https://github.com/TheKevJames/coveralls-python/blob/master/docs/usage/configuration.rst#github-actions-support + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} COVERALLS_PARALLEL: true - COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }} - # those two lines seem to be required for coveralls - # see issue: https://github.com/lemurheavy/coveralls-public/issues/1435 COVERALLS_SERVICE_NAME: github - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - jobs: @@ -93,7 +91,7 @@ jobs: run: | pytest -vs --cov=pyro --cov-config .coveragerc --stage unit --durations 20 - name: Submit to coveralls - run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls + run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls --service=github examples: runs-on: ubuntu-20.04 needs: docs @@ -124,7 +122,7 @@ jobs: grep -l smoke_test tutorial/source/*.ipynb | xargs grep -L 'smoke_test = False' \ | CI=1 xargs pytest -vx --nbval-lax --current-env - name: Submit to coveralls - run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls + run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls --service=github integration_1: runs-on: ubuntu-20.04 needs: docs @@ -153,7 +151,7 @@ jobs: run: | pytest -vs --cov=pyro --cov-config .coveragerc --stage integration_batch_1 --durations 10 - name: Submit to coveralls - run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls + run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls --service=github integration_2: runs-on: ubuntu-20.04 needs: docs @@ -182,7 +180,7 @@ jobs: run: | pytest -vs --cov=pyro --cov-config .coveragerc --stage integration_batch_2 --durations 10 - name: Submit to coveralls - run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls + run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls --service=github funsor: runs-on: ubuntu-20.04 needs: docs @@ -213,7 +211,7 @@ jobs: pytest -vs --cov=pyro --cov-config .coveragerc --stage funsor --durations 10 CI=1 pytest -vs --cov=pyro --cov-config .coveragerc --stage test_examples --durations 10 -k funsor - name: Submit to coveralls - run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls + run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls --service=github finish: needs: [unit, examples, integration_1, integration_2, funsor] runs-on: ubuntu-20.04 diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 47f496171e..70dd2ddf5f 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -337,6 +337,13 @@ SineSkewed :undoc-members: :show-inheritance: +SoftAsymmetricLaplace +--------------------- +.. autoclass:: pyro.distributions.SoftAsymmetricLaplace + :members: + :undoc-members: + :show-inheritance: + SoftLaplace ------------- .. autoclass:: pyro.distributions.SoftLaplace diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index 305fbd1a0d..cdce54fe58 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -7,7 +7,10 @@ # isort: split from pyro.distributions.affine_beta import AffineBeta -from pyro.distributions.asymmetriclaplace import AsymmetricLaplace +from pyro.distributions.asymmetriclaplace import ( + AsymmetricLaplace, + SoftAsymmetricLaplace, +) from pyro.distributions.avf_mvn import AVFMultivariateNormal from pyro.distributions.coalescent import ( CoalescentRateLikelihood, @@ -135,6 +138,7 @@ "SineBivariateVonMises", "SineSkewed", "SoftLaplace", + "SoftAsymmetricLaplace", "SpanningTree", "Stable", "TorchDistribution", diff --git a/pyro/distributions/asymmetriclaplace.py b/pyro/distributions/asymmetriclaplace.py index 3206eccf9b..f14218b3f1 100644 --- a/pyro/distributions/asymmetriclaplace.py +++ b/pyro/distributions/asymmetriclaplace.py @@ -1,6 +1,8 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import math + import torch from torch.distributions import constraints from torch.distributions.utils import broadcast_all, lazy_property @@ -10,7 +12,8 @@ class AsymmetricLaplace(TorchDistribution): """ - Asymmetric version of Laplace distribution. + Asymmetric version of the :class:`~pyro.distributions.Laplace` + distribution. To the left of ``loc`` this acts like an ``-Exponential(1/(asymmetry*scale))``; to the right of ``loc`` this acts @@ -27,7 +30,7 @@ class AsymmetricLaplace(TorchDistribution): support = constraints.real has_rsample = True - def __init__(self, loc, scale, asymmetry, validate_args=None): + def __init__(self, loc, scale, asymmetry, *, validate_args=None): self.loc, self.scale, self.asymmetry = broadcast_all(loc, scale, asymmetry) super().__init__(self.loc.shape, validate_args=validate_args) @@ -74,3 +77,120 @@ def variance(self): p = left / total q = right / total return p * left ** 2 + q * right ** 2 + p * q * total ** 2 + + +class SoftAsymmetricLaplace(TorchDistribution): + """ + Soft asymmetric version of the :class:`~pyro.distributions.Laplace` + distribution. + + This has a smooth (infinitely differentiable) density with two asymmetric + asymptotically exponential tails, one on the left and one on the right. In + the limit of ``softness → 0``, this converges in distribution to the + :class:`AsymmetricLaplace` distribution. + + This is equivalent to the sum of three random variables ``z - u + v`` where:: + + z ~ Normal(loc, scale * softness) + u ~ Exponential(1 / (scale * asymmetry)) + v ~ Exponential(asymetry / scale) + + This is also equivalent the sum of two random variables ``z + a`` where:: + + z ~ Normal(loc, scale * softness) + a ~ AsymmetricLaplace(0, scale, asymmetry) + + :param loc: Location parameter, i.e. the mode. + :param scale: Scale parameter = geometric mean of left and right scales. + :param asymmetry: Square of ratio of left to right scales. Defaults to 1. + :param softness: Scale parameter of the Gaussian smoother. Defaults to 1. + """ + arg_constraints = {"loc": constraints.real, + "scale": constraints.positive, + "asymmetry": constraints.positive, + "softness": constraints.positive} + support = constraints.real + has_rsample = True + + def __init__(self, loc, scale, asymmetry=1.0, softness=1.0, *, validate_args=None): + self.loc, self.scale, self.asymmetry, self.softness = broadcast_all( + loc, scale, asymmetry, softness, + ) + super().__init__(self.loc.shape, validate_args=validate_args) + + @lazy_property + def left_scale(self): + return self.scale * self.asymmetry + + @lazy_property + def right_scale(self): + return self.scale / self.asymmetry + + @lazy_property + def soft_scale(self): + return self.scale * self.softness + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(AsymmetricLaplace, _instance) + batch_shape = torch.Size(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + new.asymmetry = self.asymmetry.expand(batch_shape) + new.softness = self.softness.expand(batch_shape) + super(AsymmetricLaplace, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + + # Standardize. + x = (value - self.loc) / self.scale + L = self.asymmetry + R = self.asymmetry.reciprocal() + S = self.softness + SS = S * S + S2 = S * math.sqrt(2) + Lx = L * x + Rx = R * x + + # This is the sum of two integrals which are proportional to: + # left = Integrate[e^(-t/L - ((x+t)/S)^2/2)/sqrt(2 pi)/S, {t,0,Infinity}] + # = 1/2 e^((2 L x + S^2)/(2 L^2)) erfc((L x + S^2)/(sqrt(2) L S)) + # right = Integrate[e^(-t/R - ((x-t)/S)^2/2)/sqrt(2 pi)/S, {t,0,Infinity}] + # = 1/2 e^((S^2 - 2 R x)/(2 R^2)) erfc((S^2 - R x)/(sqrt(2) R S)) + return math.log(0.5) + torch.logaddexp( + (SS / 2 + Lx) / L ** 2 + _logerfc((SS + Lx) / (L * S2)), + (SS / 2 - Rx) / R ** 2 + _logerfc((SS - Rx) / (R * S2)), + ) - (L + R).log() - self.scale.log() + + def rsample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + z = self.loc.new_empty(shape).normal_() + u, v = self.loc.new_empty((2,) + shape).exponential_() + return (self.loc + self.soft_scale * z - self.left_scale * u + + self.right_scale * v) + + @property + def mean(self): + total_scale = self.left_scale + self.right_scale + return self.loc + (self.right_scale ** 2 - self.left_scale ** 2) / total_scale + + @property + def variance(self): + left = self.left_scale + right = self.right_scale + total = left + right + p = left / total + q = right / total + return (p * left ** 2 + q * right ** 2 + p * q * total ** 2 + + self.soft_scale ** 2) + + +def _logerfc(x): + try: + # Requires https://github.com/pytorch/pytorch/issues/31945 + return torch.logerfc(x) + except AttributeError: + return x.double().erfc().log().to(dtype=x.dtype) diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index 35fc7c2d09..01ddff405a 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -379,6 +379,13 @@ def __init__(self, von_loc, von_conc, skewness): {'loc': [2.0, -50.0], 'scale': [2.0, 10.0], 'asymmetry': [0.5, 2.5], 'test_data': [[2.0, 10.0], [-1.0, -50.0]]}, ]), + Fixture(pyro_dist=dist.SoftAsymmetricLaplace, + examples=[ + {'loc': [1.0], 'scale': [1.0], 'asymmetry': [2.0], + 'test_data': [2.0]}, + {'loc': [2.0, -50.0], 'scale': [2.0, 10.0], 'asymmetry': [0.5, 2.5], + 'softness': [0.7, 1.4], 'test_data': [[2.0, 10.0], [-1.0, -50.0]]}, + ]), ] discrete_dists = [