From e6335e1f5db68def7e930ef49bebe946ed313fc5 Mon Sep 17 00:00:00 2001 From: Sayam753 Date: Mon, 1 Feb 2021 19:26:43 +0530 Subject: [PATCH 1/5] Fix Dirichlet.logp by checking number of categories > 1 only at event dims --- pymc3/distributions/multivariate.py | 2 +- pymc3/tests/test_distributions.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py index 8e99ccee228..a0ca1514a97 100755 --- a/pymc3/distributions/multivariate.py +++ b/pymc3/distributions/multivariate.py @@ -522,7 +522,7 @@ def logp(self, value): tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1) + gammaln(tt.sum(a, axis=-1)), tt.all(value >= 0), tt.all(value <= 1), - np.logical_not(a.broadcastable), + np.logical_not(a.broadcastable[-1]), tt.all(a > 0), broadcast_conditions=False, ) diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 2648952bb32..2ac2be243e5 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -1696,6 +1696,11 @@ def test_lkj(self, x, eta, n, lp): def test_dirichlet(self, n): self.pymc3_matches_scipy(Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf) + def test_dirichlet_with_unit_batch_shape(self): + with pm.Model() as model: + a = pm.Dirichlet("a", a=np.ones((1, 2))) + np.isfinite(model.check_test_point()[0]) + def test_dirichlet_shape(self): a = tt.as_tensor_variable(np.r_[1, 2]) with pytest.warns(DeprecationWarning): From 18fd1f57481a012503599651caaad87e12cda9f9 Mon Sep 17 00:00:00 2001 From: Sayam Kumar Date: Tue, 2 Feb 2021 23:37:05 +0530 Subject: [PATCH 2/5] Update test_distributions.py --- pymc3/tests/test_distributions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 2ac2be243e5..c85969a82e1 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -1699,7 +1699,7 @@ def 
test_dirichlet(self, n): def test_dirichlet_with_unit_batch_shape(self): with pm.Model() as model: a = pm.Dirichlet("a", a=np.ones((1, 2))) - np.isfinite(model.check_test_point()[0]) + assert np.isfinite(model.check_test_point()[0]) def test_dirichlet_shape(self): a = tt.as_tensor_variable(np.r_[1, 2]) From a6a08bb72d94b0385dde7b4eeea54500ddac869b Mon Sep 17 00:00:00 2001 From: Sayam753 Date: Wed, 3 Feb 2021 04:01:22 +0530 Subject: [PATCH 3/5] Removed the shape validation check to even work for last dimensional shape as 1. Modified the `test_dirichlet` function to check for the same. --- pymc3/distributions/multivariate.py | 1 - pymc3/tests/test_distributions.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py index a0ca1514a97..3fcdb8dbdaf 100755 --- a/pymc3/distributions/multivariate.py +++ b/pymc3/distributions/multivariate.py @@ -522,7 +522,6 @@ def logp(self, value): tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1) + gammaln(tt.sum(a, axis=-1)), tt.all(value >= 0), tt.all(value <= 1), - np.logical_not(a.broadcastable[-1]), tt.all(a > 0), broadcast_conditions=False, ) diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index c85969a82e1..fc63154f6b1 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -1692,7 +1692,7 @@ def test_lkj(self, x, eta, n, lp): decimals = select_by_precision(float64=6, float32=4) assert_almost_equal(model.fastlogp(pt), lp, decimal=decimals, err_msg=str(pt)) - @pytest.mark.parametrize("n", [2, 3]) + @pytest.mark.parametrize("n", [1, 2, 3]) def test_dirichlet(self, n): self.pymc3_matches_scipy(Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf) From 4d7c19221a1126f94ecaa2bd2acdbaf2c742c32b Mon Sep 17 00:00:00 2001 From: Sayam753 Date: Thu, 4 Feb 2021 21:37:45 +0530 Subject: [PATCH 4/5] Added a test to check Dirichlet.logp with different batch shapes. 
--- pymc3/tests/test_distributions.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index fc63154f6b1..eab628ecba0 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -1696,10 +1696,14 @@ def test_lkj(self, x, eta, n, lp): def test_dirichlet(self, n): self.pymc3_matches_scipy(Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf) - def test_dirichlet_with_unit_batch_shape(self): + @pytest.mark.parametrize("dist_shape", [1, (2, 1), (1, 2), (2, 4, 3)]) + def test_dirichlet_with_batch_shapes(self, dist_shape): + a = np.ones(dist_shape) with pm.Model() as model: - a = pm.Dirichlet("a", a=np.ones((1, 2))) - assert np.isfinite(model.check_test_point()[0]) + d = pm.Dirichlet("a", a=a) + + value = d.tag.test_value + assert_almost_equal(dirichlet_logpdf(value, a), d.distribution.logp(value).eval().sum()) def test_dirichlet_shape(self): a = tt.as_tensor_variable(np.r_[1, 2]) From 1333c232e2322614ea646372d78ec381e53dcabf Mon Sep 17 00:00:00 2001 From: Sayam753 Date: Thu, 4 Feb 2021 22:01:29 +0530 Subject: [PATCH 5/5] Tested exact Dirichlet.logp values against scipy implementation Given a mention in RELEASE-NOTES.md --- RELEASE-NOTES.md | 1 + pymc3/tests/test_distributions.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 1223b547044..fc85b4a86db 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -13,6 +13,7 @@ - `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)). - `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)). 
- `pm.sample` and `pm.find_MAP` no longer change the `start` argument (see [#4458](https://github.com/pymc-devs/pymc3/pull/4458)). +- Fixed `Dirichlet.logp` method to work with unit batch or event shapes (see [#4454](https://github.com/pymc-devs/pymc3/pull/4454)). ## PyMC3 3.11.0 (21 January 2021) diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index eab628ecba0..0b0dff2b82d 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -1702,8 +1702,10 @@ def test_dirichlet_with_batch_shapes(self, dist_shape): with pm.Model() as model: d = pm.Dirichlet("a", a=a) - value = d.tag.test_value - assert_almost_equal(dirichlet_logpdf(value, a), d.distribution.logp(value).eval().sum()) + pymc3_res = d.distribution.logp(d.tag.test_value).eval() + for idx in np.ndindex(a.shape[:-1]): + scipy_res = scipy.stats.dirichlet(a[idx]).logpdf(d.tag.test_value[idx]) + assert_almost_equal(pymc3_res[idx], scipy_res) def test_dirichlet_shape(self): a = tt.as_tensor_variable(np.r_[1, 2])