Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a SoftAsymmetricLaplace distribution #2872

Merged
merged 8 commits into from
Jun 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@ on:
env:
CXX: g++-8
CC: gcc-8
# See coveralls-python - Github Actions support:
# https://github.com/TheKevJames/coveralls-python/blob/master/docs/usage/configuration.rst#github-actions-support
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COVERALLS_PARALLEL: true
COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }}
# those two lines seem to be required for coveralls
# see issue: https://github.com/lemurheavy/coveralls-public/issues/1435
COVERALLS_SERVICE_NAME: github
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}



jobs:
Expand Down Expand Up @@ -93,7 +91,7 @@ jobs:
run: |
pytest -vs --cov=pyro --cov-config .coveragerc --stage unit --durations 20
- name: Submit to coveralls
run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls
run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls --service=github
examples:
runs-on: ubuntu-20.04
needs: docs
Expand Down Expand Up @@ -124,7 +122,7 @@ jobs:
grep -l smoke_test tutorial/source/*.ipynb | xargs grep -L 'smoke_test = False' \
| CI=1 xargs pytest -vx --nbval-lax --current-env
- name: Submit to coveralls
run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls
run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls --service=github
integration_1:
runs-on: ubuntu-20.04
needs: docs
Expand Down Expand Up @@ -153,7 +151,7 @@ jobs:
run: |
pytest -vs --cov=pyro --cov-config .coveragerc --stage integration_batch_1 --durations 10
- name: Submit to coveralls
run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls
run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls --service=github
integration_2:
runs-on: ubuntu-20.04
needs: docs
Expand Down Expand Up @@ -182,7 +180,7 @@ jobs:
run: |
pytest -vs --cov=pyro --cov-config .coveragerc --stage integration_batch_2 --durations 10
- name: Submit to coveralls
run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls
run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls --service=github
funsor:
runs-on: ubuntu-20.04
needs: docs
Expand Down Expand Up @@ -213,7 +211,7 @@ jobs:
pytest -vs --cov=pyro --cov-config .coveragerc --stage funsor --durations 10
CI=1 pytest -vs --cov=pyro --cov-config .coveragerc --stage test_examples --durations 10 -k funsor
- name: Submit to coveralls
run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls
run: GITHUB_SHA="$GITHUB_RUN_ID" GITHUB_REF="" coveralls --service=github
finish:
needs: [unit, examples, integration_1, integration_2, funsor]
runs-on: ubuntu-20.04
Expand Down
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,13 @@ SineSkewed
:undoc-members:
:show-inheritance:

SoftAsymmetricLaplace
---------------------
.. autoclass:: pyro.distributions.SoftAsymmetricLaplace
:members:
:undoc-members:
:show-inheritance:

SoftLaplace
-------------
.. autoclass:: pyro.distributions.SoftLaplace
Expand Down
6 changes: 5 additions & 1 deletion pyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
# isort: split

from pyro.distributions.affine_beta import AffineBeta
from pyro.distributions.asymmetriclaplace import AsymmetricLaplace
from pyro.distributions.asymmetriclaplace import (
AsymmetricLaplace,
SoftAsymmetricLaplace,
)
from pyro.distributions.avf_mvn import AVFMultivariateNormal
from pyro.distributions.coalescent import (
CoalescentRateLikelihood,
Expand Down Expand Up @@ -135,6 +138,7 @@
"SineBivariateVonMises",
"SineSkewed",
"SoftLaplace",
"SoftAsymmetricLaplace",
"SpanningTree",
"Stable",
"TorchDistribution",
Expand Down
124 changes: 122 additions & 2 deletions pyro/distributions/asymmetriclaplace.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math

import torch
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all, lazy_property
Expand All @@ -10,7 +12,8 @@

class AsymmetricLaplace(TorchDistribution):
"""
Asymmetric version of Laplace distribution.
Asymmetric version of the :class:`~pyro.distributions.Laplace`
distribution.

To the left of ``loc`` this acts like an
``-Exponential(1/(asymmetry*scale))``; to the right of ``loc`` this acts
Expand All @@ -27,7 +30,7 @@ class AsymmetricLaplace(TorchDistribution):
support = constraints.real
has_rsample = True

def __init__(self, loc, scale, asymmetry, validate_args=None):
def __init__(self, loc, scale, asymmetry, *, validate_args=None):
self.loc, self.scale, self.asymmetry = broadcast_all(loc, scale, asymmetry)
super().__init__(self.loc.shape, validate_args=validate_args)

Expand Down Expand Up @@ -74,3 +77,120 @@ def variance(self):
p = left / total
q = right / total
return p * left ** 2 + q * right ** 2 + p * q * total ** 2


class SoftAsymmetricLaplace(TorchDistribution):
"""
Soft asymmetric version of the :class:`~pyro.distributions.Laplace`
distribution.

This has a smooth (infinitely differentiable) density with two asymmetric
asymptotically exponential tails, one on the left and one on the right. In
the limit of ``softness → 0``, this converges in distribution to the
:class:`AsymmetricLaplace` distribution.

This is equivalent to the sum of three random variables ``z - u + v`` where::

z ~ Normal(loc, scale * softness)
u ~ Exponential(1 / (scale * asymmetry))
v ~ Exponential(asymetry / scale)

This is also equivalent the sum of two random variables ``z + a`` where::

z ~ Normal(loc, scale * softness)
a ~ AsymmetricLaplace(0, scale, asymmetry)

:param loc: Location parameter, i.e. the mode.
:param scale: Scale parameter = geometric mean of left and right scales.
:param asymmetry: Square of ratio of left to right scales. Defaults to 1.
:param softness: Scale parameter of the Gaussian smoother. Defaults to 1.
"""
arg_constraints = {"loc": constraints.real,
"scale": constraints.positive,
"asymmetry": constraints.positive,
"softness": constraints.positive}
support = constraints.real
has_rsample = True

def __init__(self, loc, scale, asymmetry=1.0, softness=1.0, *, validate_args=None):
self.loc, self.scale, self.asymmetry, self.softness = broadcast_all(
loc, scale, asymmetry, softness,
)
super().__init__(self.loc.shape, validate_args=validate_args)

@lazy_property
def left_scale(self):
return self.scale * self.asymmetry

@lazy_property
def right_scale(self):
return self.scale / self.asymmetry

@lazy_property
def soft_scale(self):
return self.scale * self.softness

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(AsymmetricLaplace, _instance)
batch_shape = torch.Size(batch_shape)
new.loc = self.loc.expand(batch_shape)
new.scale = self.scale.expand(batch_shape)
new.asymmetry = self.asymmetry.expand(batch_shape)
new.softness = self.softness.expand(batch_shape)
super(AsymmetricLaplace, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)

# Standardize.
x = (value - self.loc) / self.scale
L = self.asymmetry
R = self.asymmetry.reciprocal()
S = self.softness
SS = S * S
S2 = S * math.sqrt(2)
Lx = L * x
Rx = R * x

# This is the sum of two integrals which are proportional to:
# left = Integrate[e^(-t/L - ((x+t)/S)^2/2)/sqrt(2 pi)/S, {t,0,Infinity}]
# = 1/2 e^((2 L x + S^2)/(2 L^2)) erfc((L x + S^2)/(sqrt(2) L S))
# right = Integrate[e^(-t/R - ((x-t)/S)^2/2)/sqrt(2 pi)/S, {t,0,Infinity}]
# = 1/2 e^((S^2 - 2 R x)/(2 R^2)) erfc((S^2 - R x)/(sqrt(2) R S))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the approach sounds reasonable to me. I haven't checked the details but we can trust the tests I think.

return math.log(0.5) + torch.logaddexp(
(SS / 2 + Lx) / L ** 2 + _logerfc((SS + Lx) / (L * S2)),
(SS / 2 - Rx) / R ** 2 + _logerfc((SS - Rx) / (R * S2)),
) - (L + R).log() - self.scale.log()

def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
z = self.loc.new_empty(shape).normal_()
u, v = self.loc.new_empty((2,) + shape).exponential_()
return (self.loc + self.soft_scale * z - self.left_scale * u
+ self.right_scale * v)

@property
def mean(self):
total_scale = self.left_scale + self.right_scale
return self.loc + (self.right_scale ** 2 - self.left_scale ** 2) / total_scale

@property
def variance(self):
left = self.left_scale
right = self.right_scale
total = left + right
p = left / total
q = right / total
return (p * left ** 2 + q * right ** 2 + p * q * total ** 2
+ self.soft_scale ** 2)


def _logerfc(x):
try:
# Requires https://github.com/pytorch/pytorch/issues/31945
return torch.logerfc(x)
except AttributeError:
return x.double().erfc().log().to(dtype=x.dtype)
7 changes: 7 additions & 0 deletions tests/distributions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,13 @@ def __init__(self, von_loc, von_conc, skewness):
{'loc': [2.0, -50.0], 'scale': [2.0, 10.0],
'asymmetry': [0.5, 2.5], 'test_data': [[2.0, 10.0], [-1.0, -50.0]]},
]),
Fixture(pyro_dist=dist.SoftAsymmetricLaplace,
examples=[
{'loc': [1.0], 'scale': [1.0], 'asymmetry': [2.0],
'test_data': [2.0]},
{'loc': [2.0, -50.0], 'scale': [2.0, 10.0], 'asymmetry': [0.5, 2.5],
'softness': [0.7, 1.4], 'test_data': [[2.0, 10.0], [-1.0, -50.0]]},
]),
]

discrete_dists = [
Expand Down