Skip to content

Commit

Permalink
fixed bug when batch_dim > 1.
Browse files Browse the repository at this point in the history
  • Loading branch information
OlaRonning committed May 5, 2021
1 parent 15c4f3d commit 1789fb0
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 13 deletions.
17 changes: 9 additions & 8 deletions pyro/distributions/sine_bivariate_von_mises.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math
import warnings
from functools import reduce
from math import pi
Expand Down Expand Up @@ -72,7 +73,8 @@ def __init__(self, phi_loc, psi_loc, phi_concentration, psi_concentration, corre
assert (correlation is None) != (weighted_correlation is None)

if weighted_correlation is not None:
correlation = weighted_correlation * (phi_concentration * psi_concentration).sqrt() + 1e-8
sqrt_ = torch.sqrt if isinstance(phi_concentration, torch.Tensor) else math.sqrt
correlation = weighted_correlation * sqrt_(phi_concentration * psi_concentration) + 1e-8

phi_loc, psi_loc, phi_concentration, psi_concentration, correlation = broadcast_all(phi_loc, psi_loc,
phi_concentration,
Expand All @@ -96,7 +98,7 @@ def __init__(self, phi_loc, psi_loc, phi_concentration, psi_concentration, corre
@lazy_property
def norm_const(self):
corr = self.correlation.view(1, -1) + 1e-8
conc = torch.stack((self.phi_concentration, self.psi_concentration)).view(-1, 2)
conc = torch.stack((self.phi_concentration, self.psi_concentration), dim=-1).view(-1, 2)
m = torch.arange(50, device=self.phi_loc.device).view(-1, 1)
fs = SineBivariateVonMises._lbinoms(m.max() + 1).view(-1, 1) + 2 * m * torch.log(corr) - m * torch.log(
4 * torch.prod(conc, dim=-1))
Expand Down Expand Up @@ -141,7 +143,7 @@ def sample(self, sample_shape=torch.Size()):
# flatten batch_shape
conc = conc.view(2, -1, 1)
eigmin = eigmin.view(-1, 1)
corr = corr.view(-1, 1)
corr = corr.reshape(-1, 1)
eig = eig.view(2, -1)
b0 = b0.view(-1)
phi_den = log_I1(0, conc[1]).view(-1, 1)
Expand Down Expand Up @@ -188,12 +190,10 @@ def sample(self, sample_shape=torch.Size()):
beta = torch.atan(corr / conc[1] * torch.sin(phi))

psi = VonMises(beta, alpha).sample()
phi = phi.view((*self.batch_shape, total))
psi = psi.view((*self.batch_shape, total))

phi_psi = torch.vstack(((phi + self.phi_loc.view((*self.batch_shape, 1)) + pi) % (2 * pi) - pi,
(psi + self.psi_loc.view((*self.batch_shape, 1)) + pi) % (2 * pi) - pi)).T
return phi_psi.view(*sample_shape, *self.batch_shape, *self.event_shape)
phi_psi = torch.stack(((phi + self.phi_loc.reshape((-1, 1)) + pi) % (2 * pi) - pi,
(psi + self.psi_loc.reshape((-1, 1)) + pi) % (2 * pi) - pi), dim=-1).permute(1, 0, 2)
return phi_psi.reshape(*sample_shape, *self.batch_shape, *self.event_shape)

@property
def mean(self):
Expand All @@ -209,6 +209,7 @@ def expand(self, batch_shape, _instance=None):
batch_shape = torch.Size(batch_shape)
for k in SineBivariateVonMises.arg_constraints.keys():
setattr(new, k, getattr(self, k).expand(batch_shape))
new.norm_const = self.norm_const.expand(batch_shape)
super(SineBivariateVonMises, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new
Expand Down
13 changes: 9 additions & 4 deletions tests/distributions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,15 @@ def __init__(self, rate, *, validate_args=None):
]),
Fixture(pyro_dist=dist.SineBivariateVonMises,
examples=[
{'phi_loc': [0.], 'psi_loc': [0.], 'phi_concentration': [5.], 'psi_concentration': [6.],
'correlation': [2.], 'test_data': [[1., 0.]]},
{'phi_loc': [0.], 'psi_loc': [0.], 'phi_concentration': [5.], 'psi_concentration': [6.],
'weighted_correlation': [.5], 'test_data': [[1., 0.]]}
{'phi_loc': [math.pi - .2, 1.], 'psi_loc': [0., 1.],
'phi_concentration': [5., 5.], 'psi_concentration': [7., .5],
'weighted_correlation': [.5, .1], 'test_data': [[[1., -3.], [1., 59.]]]},
{'phi_loc': 0., 'psi_loc': 0., 'phi_concentration': 5., 'psi_concentration': 6.,
'correlation': 2., 'test_data': [1., 0.]},
{'phi_loc': [3.003], 'psi_loc': [-1.343], 'phi_concentration': [5.], 'psi_concentration': [6.],
'correlation': [2.], 'test_data': [[0., 1.]]},
{'phi_loc': -math.pi / 3, 'psi_loc': -1., 'phi_concentration': .5, 'psi_concentration': 10.,
'correlation': .1, 'test_data': [1., 0.555]},
]),
Fixture(pyro_dist=dist.SoftLaplace,
examples=[
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_sine_bivariate_von_mises.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_bvm_multidim():
assert_equal(bmv.sample(sample_dim).shape, torch.Size((*sample_dim, *batch_dim, 2)))


def test_mle_bvm(): # FIXME
def test_mle_bvm():
vm = VonMises(tensor(0.), tensor(1.))
hn = HalfNormal(tensor(.8))
b = Beta(tensor(2.), tensor(5.))
Expand Down

0 comments on commit 1789fb0

Please sign in to comment.