Skip to content

Commit

Permalink
Support 4D ProjectedNormal distribution (#3011)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored Jan 26, 2022
1 parent 65225ae commit 5ab7da2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
21 changes: 21 additions & 0 deletions pyro/distributions/projected_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,24 @@ def _log_prob_3(concentration, value):
).log()

return para_part + perp_part


@ProjectedNormal._register_log_prob(dim=4)
def _log_prob_4(concentration, value):
# We integrate along a ray, factorizing the integrand as a product of:
# a truncated normal distribution over coordinate t parallel to the ray, and
# a bivariate normal distribution over coordinate r perpendicular to the ray.
t = _dot(concentration, value)
t2 = t.square()
r2 = _dot(concentration, concentration) - t2
perp_part = r2.mul(-0.5) - 1.5 * math.log(2 * math.pi)

# This is the log of a definite integral, computed by mathematica:
# Integrate[x^3/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}]
# = (2 + t^2)/(E^(t^2/2) Sqrt[2 Pi]) + (t (3 + t^2) (1 + Erf[t/Sqrt[2]]))/2
para_part = (
(2 + t2) * t2.mul(-0.5).exp() / (2 * math.pi) ** 0.5
+ t * (3 + t2) * (1 + (t * 0.5 ** 0.5).erf()) / 2
).log()

return para_part + perp_part
5 changes: 5 additions & 0 deletions tests/distributions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,11 @@ def __init__(self, von_loc, von_conc, skewness):
{"concentration": [2.0, 3.0], "test_data": [0.0, 1.0]},
{"concentration": [0.0, 0.0, 0.0], "test_data": [1.0, 0.0, 0.0]},
{"concentration": [-1.0, 2.0, 3.0], "test_data": [0.0, 0.0, 1.0]},
{"concentration": [0.0, 0.0, 0.0, 0.0], "test_data": [1.0, 0.0, 0.0, 0.0]},
{
"concentration": [-1.0, 2.0, 0.5, -0.5],
"test_data": [0.0, 1.0, 0.0, 0.0],
},
],
),
Fixture(
Expand Down

0 comments on commit 5ab7da2

Please sign in to comment.