From 0b66e747d985bafb2fdc31119986bebd6efc5946 Mon Sep 17 00:00:00 2001 From: Googler Date: Tue, 23 Jun 2020 08:02:20 -0700 Subject: [PATCH] Allow zero rate in Poisson. PiperOrigin-RevId: 317865805 --- .../python/distributions/poisson.py | 6 ++--- .../python/distributions/poisson_test.py | 24 ++++++++++++------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/tensorflow_probability/python/distributions/poisson.py b/tensorflow_probability/python/distributions/poisson.py index 5fce497b08..ff85d39e3d 100644 --- a/tensorflow_probability/python/distributions/poisson.py +++ b/tensorflow_probability/python/distributions/poisson.py @@ -276,7 +276,7 @@ def _log_unnormalized_prob(self, x, log_rate): # The log-probability at negative points is always -inf. # Catch such x's and set the output value accordingly. safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x), 0.) - y = safe_x * log_rate - tf.math.lgamma(1. + safe_x) + y = tf.math.multiply_no_nan(log_rate, safe_x) - tf.math.lgamma(1. + safe_x) return tf.where( tf.equal(x, safe_x), y, dtype_util.as_numpy_dtype(y.dtype)(-np.inf)) @@ -332,9 +332,9 @@ def _parameter_control_dependencies(self, is_init): assertions = [] if self._rate is not None: if is_init != tensor_util.is_ref(self._rate): - assertions.append(assert_util.assert_positive( + assertions.append(assert_util.assert_non_negative( self._rate, - message='Argument `rate` must be positive.')) + message='Argument `rate` must be non-negative.')) return assertions def _sample_control_dependencies(self, x): diff --git a/tensorflow_probability/python/distributions/poisson_test.py b/tensorflow_probability/python/distributions/poisson_test.py index 8e0b8d6c75..8eb2daaa06 100644 --- a/tensorflow_probability/python/distributions/poisson_test.py +++ b/tensorflow_probability/python/distributions/poisson_test.py @@ -51,12 +51,20 @@ def testPoissonShape(self): self.assertEqual(poisson.event_shape, tf.TensorShape([])) def testInvalidLam(self): - invalid_lams = [-.01, 0., -2.] + invalid_lams = [-.01, -1., -2.] for lam in invalid_lams: - with self.assertRaisesOpError('Argument `rate` must be positive.'): + with self.assertRaisesOpError('Argument `rate` must be non-negative.'): poisson = self._make_poisson(rate=lam) self.evaluate(poisson.rate_parameter()) + def testZeroLam(self): + lam = 0. + poisson = tfd.Poisson(rate=lam, validate_args=True) + self.assertAllClose(lam, self.evaluate(poisson.rate)) + self.assertAllClose(0., poisson.prob(3)) + self.assertAllClose(1., poisson.prob(0)) + self.assertAllClose(0., poisson.log_prob(0)) + def testPoissonLogPmfDiscreteMatchesScipy(self): batch_size = 12 lam = tf.constant([3.0] * batch_size) @@ -333,19 +341,19 @@ def testGradientThroughRate(self): self.assertLen(grad, 1) self.assertAllNotNone(grad) - def testAssertsPositiveRate(self): + def testAssertsNonNegativeRate(self): rate = tf.Variable([1., 2., -3.]) self.evaluate(rate.initializer) - with self.assertRaisesOpError('Argument `rate` must be positive.'): + with self.assertRaisesOpError('Argument `rate` must be non-negative.'): dist = self._make_poisson(rate=rate, validate_args=True) self.evaluate(dist.sample(seed=test_util.test_seed())) - def testAssertsPositiveRateAfterMutation(self): + def testAssertsNonNegativeRateAfterMutation(self): rate = tf.Variable([1., 2., 3.]) self.evaluate(rate.initializer) dist = self._make_poisson(rate=rate, validate_args=True) self.evaluate(dist.mean()) - with self.assertRaisesOpError('Argument `rate` must be positive.'): + with self.assertRaisesOpError('Argument `rate` must be non-negative.'): with tf.control_dependencies([rate.assign([1., 2., -3.])]): self.evaluate(dist.sample(seed=test_util.test_seed())) @@ -367,10 +375,10 @@ def _make_poisson(self, def testInvalidLam(self): pass - def testAssertsPositiveRate(self): + def testAssertsNonNegativeRate(self): pass - def testAssertsPositiveRateAfterMutation(self): + def testAssertsNonNegativeRateAfterMutation(self): pass # The gradient is not tracked through tf.math.log(rate) in _make_poisson(),