Skip to content

Commit

Permalink
Allow zero rate in Poisson.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 317865805
  • Loading branch information
Googler authored and tensorflower-gardener committed Jun 23, 2020
1 parent f05f819 commit 0b66e74
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
6 changes: 3 additions & 3 deletions tensorflow_probability/python/distributions/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def _log_unnormalized_prob(self, x, log_rate):
# The log-probability at negative points is always -inf.
# Catch such x's and set the output value accordingly.
safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x), 0.)
y = safe_x * log_rate - tf.math.lgamma(1. + safe_x)
y = tf.math.multiply_no_nan(log_rate, safe_x) - tf.math.lgamma(1. + safe_x)
return tf.where(
tf.equal(x, safe_x), y, dtype_util.as_numpy_dtype(y.dtype)(-np.inf))

Expand Down Expand Up @@ -332,9 +332,9 @@ def _parameter_control_dependencies(self, is_init):
assertions = []
if self._rate is not None:
if is_init != tensor_util.is_ref(self._rate):
assertions.append(assert_util.assert_positive(
assertions.append(assert_util.assert_non_negative(
self._rate,
message='Argument `rate` must be positive.'))
message='Argument `rate` must be non-negative.'))
return assertions

def _sample_control_dependencies(self, x):
Expand Down
24 changes: 16 additions & 8 deletions tensorflow_probability/python/distributions/poisson_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,20 @@ def testPoissonShape(self):
self.assertEqual(poisson.event_shape, tf.TensorShape([]))

def testInvalidLam(self):
invalid_lams = [-.01, 0., -2.]
invalid_lams = [-.01, -1., -2.]
for lam in invalid_lams:
with self.assertRaisesOpError('Argument `rate` must be positive.'):
with self.assertRaisesOpError('Argument `rate` must be non-negative.'):
poisson = self._make_poisson(rate=lam)
self.evaluate(poisson.rate_parameter())

def testZeroLam(self):
lam = 0.
poisson = tfd.Poisson(rate=lam, validate_args=True)
self.assertAllClose(lam, self.evaluate(poisson.rate))
self.assertAllClose(0., poisson.prob(3))
self.assertAllClose(1., poisson.prob(0))
self.assertAllClose(0., poisson.log_prob(0))

def testPoissonLogPmfDiscreteMatchesScipy(self):
batch_size = 12
lam = tf.constant([3.0] * batch_size)
Expand Down Expand Up @@ -333,19 +341,19 @@ def testGradientThroughRate(self):
self.assertLen(grad, 1)
self.assertAllNotNone(grad)

def testAssertsPositiveRate(self):
def testAssertsNonNegativeRate(self):
rate = tf.Variable([1., 2., -3.])
self.evaluate(rate.initializer)
with self.assertRaisesOpError('Argument `rate` must be positive.'):
with self.assertRaisesOpError('Argument `rate` must be non-negative.'):
dist = self._make_poisson(rate=rate, validate_args=True)
self.evaluate(dist.sample(seed=test_util.test_seed()))

def testAssertsPositiveRateAfterMutation(self):
def testAssertsNonNegativeRateAfterMutation(self):
rate = tf.Variable([1., 2., 3.])
self.evaluate(rate.initializer)
dist = self._make_poisson(rate=rate, validate_args=True)
self.evaluate(dist.mean())
with self.assertRaisesOpError('Argument `rate` must be positive.'):
with self.assertRaisesOpError('Argument `rate` must be non-negative.'):
with tf.control_dependencies([rate.assign([1., 2., -3.])]):
self.evaluate(dist.sample(seed=test_util.test_seed()))

Expand All @@ -367,10 +375,10 @@ def _make_poisson(self,
def testInvalidLam(self):
pass

def testAssertsPositiveRate(self):
def testAssertsNonNegativeRate(self):
pass

def testAssertsPositiveRateAfterMutation(self):
def testAssertsNonNegativeRateAfterMutation(self):
pass

# The gradient is not tracked through tf.math.log(rate) in _make_poisson(),
Expand Down

0 comments on commit 0b66e74

Please sign in to comment.