diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 1223b547044..fc85b4a86db 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -13,6 +13,7 @@ - `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)). - `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)). - `pm.sample` and `pm.find_MAP` no longer change the `start` argument (see [#4458](https://github.com/pymc-devs/pymc3/pull/4458)). +- Fixed `Dirichlet.logp` method to work with unit batch or event shapes (see [#4454](https://github.com/pymc-devs/pymc3/pull/4454)). ## PyMC3 3.11.0 (21 January 2021) diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py index 8e99ccee228..3fcdb8dbdaf 100755 --- a/pymc3/distributions/multivariate.py +++ b/pymc3/distributions/multivariate.py @@ -522,7 +522,6 @@ def logp(self, value): tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1) + gammaln(tt.sum(a, axis=-1)), tt.all(value >= 0), tt.all(value <= 1), - np.logical_not(a.broadcastable), tt.all(a > 0), broadcast_conditions=False, ) diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 2648952bb32..0b0dff2b82d 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -1692,10 +1692,21 @@ def test_lkj(self, x, eta, n, lp): decimals = select_by_precision(float64=6, float32=4) assert_almost_equal(model.fastlogp(pt), lp, decimal=decimals, err_msg=str(pt)) - @pytest.mark.parametrize("n", [2, 3]) + @pytest.mark.parametrize("n", [1, 2, 3]) def test_dirichlet(self, n): self.pymc3_matches_scipy(Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf) + @pytest.mark.parametrize("dist_shape", [1, (2, 1), (1, 2), (2, 4, 3)]) + def test_dirichlet_with_batch_shapes(self, dist_shape): + a = np.ones(dist_shape) + with pm.Model() as model: + d = pm.Dirichlet("a", a=a) + + pymc3_res = d.distribution.logp(d.tag.test_value).eval() + for idx in np.ndindex(a.shape[:-1]): + scipy_res = scipy.stats.dirichlet(a[idx]).logpdf(d.tag.test_value[idx]) + assert_almost_equal(pymc3_res[idx], scipy_res) + def test_dirichlet_shape(self): a = tt.as_tensor_variable(np.r_[1, 2]) with pytest.warns(DeprecationWarning):