From 02a5a36a8c90ca9c662c68603f89e39c015eb96f Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 19 Feb 2025 10:19:53 +0100
Subject: [PATCH] Make more distributions symbolic so they work in different backends

---
 .github/workflows/tests.yml                    |  14 +-
 ...l => environment-alternative-backends.yml} |   2 +
 pymc/distributions/continuous.py              |  29 ++-
 pymc/distributions/multivariate.py            | 210 ++++++++----------
 pymc/distributions/shape_utils.py             |   3 +-
 tests/distributions/test_multivariate.py      |  46 ++--
 .../test_random_alternative_backends.py       |  70 ++++++
 tests/sampling/test_jax.py                    |  27 +--
 8 files changed, 214 insertions(+), 187 deletions(-)
 rename conda-envs/{environment-jax.yml => environment-alternative-backends.yml} (96%)
 create mode 100644 tests/distributions/test_random_alternative_backends.py

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 97e50bef2a..fe6d723c97 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -281,7 +281,7 @@ jobs:
           name: ${{ matrix.os }} ${{ matrix.floatx }}
           fail_ci_if_error: false
 
-  external_samplers:
+  alternative_backends:
     needs: changes
     if: ${{ needs.changes.outputs.changes == 'true' }}
     strategy:
@@ -290,7 +290,11 @@
         floatx: [float64]
         python-version: ["3.13"]
         test-subset:
-          - tests/sampling/test_jax.py tests/sampling/test_mcmc_external.py
+          - |
+            tests/distributions/test_random_alternative_backends.py
+            tests/sampling/test_jax.py
+            tests/sampling/test_mcmc_external.py
+
       fail-fast: false
     runs-on: ${{ matrix.os }}
     env:
@@ -305,7 +309,7 @@
           persist-credentials: false
       - uses: mamba-org/setup-micromamba@v2
         with:
-          environment-file: conda-envs/environment-jax.yml
+          environment-file: conda-envs/environment-alternative-backends.yml
           create-args: >-
             python=${{matrix.python-version}}
           environment-name: pymc-test
@@ -324,7 +328,7 @@
         with:
           token: ${{ secrets.CODECOV_TOKEN }} # use token for more robust uploads
           env_vars: TEST_SUBSET
-          name: JAX tests - ${{ matrix.os }} ${{ matrix.floatx }}
+          name: Alternative backend tests - ${{ matrix.os }} ${{ matrix.floatx }}
           fail_ci_if_error: false
 
   float32:
@@ -378,7 +382,7 @@ jobs:
   all_tests:
     if: ${{ always() }}
     runs-on: ubuntu-latest
-    needs: [ changes, ubuntu, windows, macos, external_samplers, float32 ]
+    needs: [ changes, ubuntu, windows, macos, alternative_backends, float32 ]
     steps:
       - name: Check build matrix status
         if: ${{ needs.changes.outputs.changes == 'true' &&
diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-alternative-backends.yml
similarity index 96%
rename from conda-envs/environment-jax.yml
rename to conda-envs/environment-alternative-backends.yml
index ea630e806f..cc7a1b6033 100644
--- a/conda-envs/environment-jax.yml
+++ b/conda-envs/environment-alternative-backends.yml
@@ -10,6 +10,8 @@ dependencies:
 - cachetools>=4.2.1
 - cloudpickle
 - zarr>=2.5.0,<3
+- numba
+- nutpie >= 0.13.4
 # Jaxlib version must not be greater than jax version!
 - blackjax>=1.2.2
 - jax>=0.4.28
diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py
index 082be31d5c..13228794d3 100644
--- a/pymc/distributions/continuous.py
+++ b/pymc/distributions/continuous.py
@@ -2595,23 +2595,27 @@ def dist(cls, nu, **kwargs):
         return Gamma.dist(alpha=nu / 2, beta=1 / 2, **kwargs)
 
 
-class WeibullBetaRV(RandomVariable):
+class WeibullBetaRV(SymbolicRandomVariable):
     name = "weibull"
-    signature = "(),()->()"
-    dtype = "floatX"
+    extended_signature = "[rng],[size],(),()->[rng],()"
     _print_name = ("Weibull", "\\operatorname{Weibull}")
 
-    def __call__(self, alpha, beta, size=None, **kwargs):
-        return super().__call__(alpha, beta, size=size, **kwargs)
-
     @classmethod
-    def rng_fn(cls, rng, alpha, beta, size) -> np.ndarray:
-        if size is None:
-            size = np.broadcast_shapes(alpha.shape, beta.shape)
-        return np.asarray(beta * rng.weibull(alpha, size=size))
+    def rv_op(cls, alpha, beta, *, rng=None, size=None):
+        alpha = pt.as_tensor(alpha)
+        beta = pt.as_tensor(beta)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)
 
+        if rv_size_is_none(size):
+            size = implicit_size_from_params(alpha, beta, ndims_params=cls.ndims_params)
 
-weibull_beta = WeibullBetaRV()
+        next_rng, raw_weibull = pt.random.weibull(alpha, size=size, rng=rng).owner.outputs
+        draws = beta * raw_weibull
+        return cls(
+            inputs=[rng, size, alpha, beta],
+            outputs=[next_rng, draws],
+        )(rng, size, alpha, beta)
 
 
 class Weibull(PositiveContinuous):
@@ -2660,7 +2664,8 @@ class Weibull(PositiveContinuous):
         Scale parameter (beta > 0).
     """
 
-    rv_op = weibull_beta
+    rv_type = WeibullBetaRV
+    rv_op = WeibullBetaRV.rv_op
 
     @classmethod
     def dist(cls, alpha, beta, *args, **kwargs):
diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 9ff56027a8..24fa8bae56 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -37,10 +37,10 @@
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace
 from pytensor.tensor.linalg import inv as matrix_inverse
+from pytensor.tensor.random import chisquare
 from pytensor.tensor.random.basic import MvNormalRV, dirichlet, multinomial, multivariate_normal
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import (
-    broadcast_params,
     normalize_size_param,
 )
 from pytensor.tensor.type import TensorType
@@ -365,33 +365,37 @@ def mv_normal_to_precision_mv_normal(fgraph, node):
     )
 
 
-class MvStudentTRV(RandomVariable):
+class MvStudentTRV(SymbolicRandomVariable):
+    r"""A symbolic multivariate Student-T random variable.
+
+    Draws are built from multivariate normal and chi-squared draws, following the usual scale-mixture construction.
+    """
+
     name = "multivariate_studentt"
-    signature = "(),(n),(n,n)->(n)"
-    dtype = "floatX"
+    extended_signature = "[rng],[size],(),(n),(n,n)->[rng],(n)"
     _print_name = ("MvStudentT", "\\operatorname{MvStudentT}")
 
     @classmethod
-    def rng_fn(cls, rng, nu, mu, cov, size):
-        if size is None:
-            # When size is implicit, we need to broadcast parameters correctly,
-            # so that the MvNormal draws and the chisquare draws have the same number of batch dimensions.
-            # nu broadcasts mu and cov
-            if np.ndim(nu) > max(mu.ndim - 1, cov.ndim - 2):
-                _, mu, cov = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)
-            # nu is broadcasted by either mu or cov
-            elif np.ndim(nu) < max(mu.ndim - 1, cov.ndim - 2):
-                nu, _, _ = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)
-
-        mv_samples = multivariate_normal.rng_fn(rng=rng, mean=np.zeros_like(mu), cov=cov, size=size)
-
-        # Take chi2 draws and add an axis of length 1 to the right for correct broadcasting below
-        chi2_samples = np.sqrt(rng.chisquare(nu, size=size) / nu)[..., None]
+    def rv_op(cls, nu, mean, scale, *, rng=None, size=None):
+        nu = pt.as_tensor(nu)
+        mean = pt.as_tensor(mean)
+        scale = pt.as_tensor(scale)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)
 
-        return (mv_samples / chi2_samples) + mu
+        if rv_size_is_none(size):
+            size = implicit_size_from_params(nu, mean, scale, ndims_params=cls.ndims_params)
 
+        next_rng, mv_draws = multivariate_normal(
+            mean.zeros_like(), scale, size=size, rng=rng
+        ).owner.outputs
+        next_rng, chi2_draws = chisquare(nu, size=size, rng=next_rng).owner.outputs
+        draws = mean + (mv_draws / pt.sqrt(chi2_draws / nu)[..., None])
 
-mv_studentt = MvStudentTRV()
+        return cls(
+            inputs=[rng, size, nu, mean, scale],
+            outputs=[next_rng, draws],
+        )(rng, size, nu, mean, scale)
 
 
 class MvStudentT(Continuous):
@@ -435,7 +439,8 @@ class MvStudentT(Continuous):
         Whether the cholesky fatcor is given as a lower triangular matrix.
     """
 
-    rv_op = mv_studentt
+    rv_type = MvStudentTRV
+    rv_op = MvStudentTRV.rv_op
 
     @classmethod
     def dist(cls, nu, *, Sigma=None, mu=0, scale=None, tau=None, chol=None, lower=True, **kwargs):
@@ -1152,57 +1157,6 @@ def _lkj_normalizing_constant(eta, n):
     return result
 
 
-class _LKJCholeskyCovBaseRV(RandomVariable):
-    name = "_lkjcholeskycovbase"
-    signature = "(),(),(d)->(n)"
-    dtype = "floatX"
-    _print_name = ("_lkjcholeskycovbase", "\\operatorname{_lkjcholeskycovbase}")
-
-    def make_node(self, rng, size, n, eta, D):
-        n = pt.as_tensor_variable(n)
-        if not all(n.type.broadcastable):
-            raise ValueError("n must be a scalar.")
-
-        eta = pt.as_tensor_variable(eta)
-        if not all(eta.type.broadcastable):
-            raise ValueError("eta must be a scalar.")
-
-        D = pt.as_tensor_variable(D)
-
-        return super().make_node(rng, size, n, eta, D)
-
-    def _supp_shape_from_params(self, dist_params, param_shapes):
-        n = dist_params[0].squeeze()
-        return ((n * (n + 1)) // 2,)
-
-    def rng_fn(self, rng, n, eta, D, size):
-        # We flatten the size to make operations easier, and then rebuild it
-        if size is None:
-            size = D.shape[:-1]
-        flat_size = np.prod(size).astype(int)
-
-        n = n.squeeze()
-        eta = eta.squeeze()
-
-        C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
-        D = D.reshape(flat_size, n)
-        C *= D[..., :, np.newaxis] * D[..., np.newaxis, :]
-
-        tril_idx = np.tril_indices(n, k=0)
-        samples = np.linalg.cholesky(C)[..., tril_idx[0], tril_idx[1]]
-
-        if size is None:
-            samples = samples[0]
-        else:
-            dist_shape = (n * (n + 1)) // 2
-            samples = np.reshape(samples, (*size, dist_shape))
-
-        return samples
-
-
-_ljk_cholesky_cov_base = _LKJCholeskyCovBaseRV()
-
-
 # _LKJCholeskyCovBaseRV requires a properly shaped `D`, which means the variable can't
 # be safely resized. Because of this, we add the thin SymbolicRandomVariable wrapper
 class _LKJCholeskyCovRV(SymbolicRandomVariable):
@@ -1223,21 +1177,40 @@ def rv_op(cls, n, eta, sd_dist, *, size=None):
         # for each diagonal element.
         # Since `eta` and `n` are forced to be scalars we don't need to worry about
         # implied batched dimensions from those for the time being.
+
         if rv_size_is_none(size):
-            size = sd_dist.shape[:-1]
+            sd_dist_size = sd_dist.shape[:-1]
+        else:
+            sd_dist_size = size
 
-        shape = (*size, n)
         if sd_dist.owner.op.ndim_supp == 0:
-            sd_dist = change_dist_size(sd_dist, shape)
+            sd_dist = change_dist_size(sd_dist, (*sd_dist_size, n))
         else:
             # The support shape must be `n` but we have no way of controlling it
-            sd_dist = change_dist_size(sd_dist, shape[:-1])
+            sd_dist = change_dist_size(sd_dist, sd_dist_size)
+
+        D = sd_dist.type(name="D")  # Make sd_dist opaque to OpFromGraph
+        size = D.shape[:-1]
+
+        # We flatten the size to make operations easier, and then rebuild it
+        flat_size = pt.prod(size, dtype="int64")
+
+        next_rng, C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
+        D_matrix = D.reshape((flat_size, n))
+        C *= D_matrix[..., :, None] * D_matrix[..., None, :]
 
-        next_rng, lkjcov = _ljk_cholesky_cov_base(n, eta, sd_dist, rng=rng).owner.outputs
+        tril_idx = pt.tril_indices(n, k=0)
+        samples = pt.linalg.cholesky(C)[..., tril_idx[0], tril_idx[1]]
+
+        if rv_size_is_none(size):
+            samples = samples[0]
+        else:
+            dist_shape = (n * (n + 1)) // 2
+            samples = pt.reshape(samples, (*size, dist_shape))
 
         return _LKJCholeskyCovRV(
-            inputs=[rng, n, eta, sd_dist],
-            outputs=[next_rng, lkjcov],
+            inputs=[rng, n, eta, D],
+            outputs=[next_rng, samples],
         )(rng, n, eta, sd_dist)
 
     def update(self, node):
@@ -1508,10 +1481,9 @@ def helper_deterministics(cls, n, packed_chol):
         return chol, corr, stds
 
 
-class LKJCorrRV(RandomVariable):
+class LKJCorrRV(SymbolicRandomVariable):
     name = "lkjcorr"
-    signature = "(),()->(n)"
-    dtype = "floatX"
+    extended_signature = "[rng],[size],(),()->[rng],(n)"
     _print_name = ("LKJCorrRV", "\\operatorname{LKJCorrRV}")
 
     def make_node(self, rng, size, n, eta):
@@ -1525,55 +1497,66 @@ def make_node(self, rng, size, n, eta):
 
         return super().make_node(rng, size, n, eta)
 
-    def _supp_shape_from_params(self, dist_params, **kwargs):
-        n = dist_params[0].squeeze()
-        dist_shape = ((n * (n - 1)) // 2,)
-        return dist_shape
-
     @classmethod
-    def rng_fn(cls, rng, n, eta, size):
+    def rv_op(cls, n: int, eta, *, rng=None, size=None):
         # We flatten the size to make operations easier, and then rebuild it
-        if size is None:
+        n = pt.as_tensor(n, ndim=0, dtype=int)
+        eta = pt.as_tensor(eta, ndim=0)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)
+
+        if rv_size_is_none(size):
             flat_size = 1
         else:
-            flat_size = np.prod(size).astype(int)
+            flat_size = pt.prod(size, dtype="int64")
 
-        n = n.squeeze()
-        eta = eta.squeeze()
-        C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
+        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
 
-        triu_idx = np.triu_indices(n, k=1)
+        triu_idx = pt.triu_indices(n, k=1)
         samples = C[..., triu_idx[0], triu_idx[1]]
 
-        if size is None:
+        if rv_size_is_none(size):
             samples = samples[0]
         else:
             dist_shape = (n * (n - 1)) // 2
-            samples = np.reshape(samples, (*size, dist_shape))
+            samples = pt.reshape(samples, (*size, dist_shape))
 
-        return samples
+        return cls(
+            inputs=[rng, size, n, eta],
+            outputs=[next_rng, samples],
+        )(rng, size, n, eta)
 
     @classmethod
-    def _random_corr_matrix(cls, rng, n, eta, flat_size):
+    def _random_corr_matrix(
+        cls, rng: Variable, n: int, eta: TensorVariable, flat_size: TensorVariable
+    ) -> tuple[Variable, TensorVariable]:
         # original implementation in R see:
         # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
         beta = eta - 1.0 + n / 2.0
-        r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=flat_size, random_state=rng) - 1.0
-        P = np.full((flat_size, n, n), np.eye(n))
-        P[..., 0, 1] = r12
-        P[..., 1, 1] = np.sqrt(1.0 - r12**2)
+        next_rng, beta_rvs = pt.random.beta(
+            alpha=beta, beta=beta, size=flat_size, rng=rng
+        ).owner.outputs
+        r12 = 2.0 * beta_rvs - 1.0
+        P = pt.full((flat_size, n, n), pt.eye(n))
+        P = P[..., 0, 1].set(r12)
+        P = P[..., 1, 1].set(pt.sqrt(1.0 - r12**2))
+        n = get_underlying_scalar_constant_value(n)
         for mp1 in range(2, n):
             beta -= 0.5
-            y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=flat_size, random_state=rng)
-            z = stats.norm.rvs(loc=0, scale=1, size=(flat_size, mp1), random_state=rng)
-            z = z / np.sqrt(np.einsum("ij,ij->i", z, z))[..., np.newaxis]
-            P[..., 0:mp1, mp1] = np.sqrt(y[..., np.newaxis]) * z
-            P[..., mp1, mp1] = np.sqrt(1.0 - y)
-        C = np.einsum("...ji,...jk->...ik", P, P)
-        return C
-
-
-lkjcorr = LKJCorrRV()
+            next_rng, y = pt.random.beta(
+                alpha=mp1 / 2.0, beta=beta, size=flat_size, rng=next_rng
+            ).owner.outputs
+            next_rng, z = pt.random.normal(
+                loc=0, scale=1, size=(flat_size, mp1), rng=next_rng
+            ).owner.outputs
+            z = z / pt.sqrt(pt.einsum("ij,ij->i", z, z.copy()))[..., np.newaxis]
+            P = P[..., 0:mp1, mp1].set(pt.sqrt(y[..., np.newaxis]) * z)
+            P = P[..., mp1, mp1].set(pt.sqrt(1.0 - y))
+        C = pt.einsum("...ji,...jk->...ik", P, P.copy())
+        return next_rng, C
 
 
 class MultivariateIntervalTransform(Interval):
@@ -1585,7 +1568,8 @@ def log_jac_det(self, *args):
 
 # Returns list of upper triangular values
 class _LKJCorr(BoundedContinuous):
-    rv_op = lkjcorr
+    rv_type = LKJCorrRV
+    rv_op = LKJCorrRV.rv_op
 
     @classmethod
     def dist(cls, n, eta, **kwargs):
diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py
index 98c743b70e..6f54aba2d1 100644
--- a/pymc/distributions/shape_utils.py
+++ b/pymc/distributions/shape_utils.py
@@ -468,5 +468,6 @@ def implicit_size_from_params(
         pt.broadcast_shape(
             *batch_shapes,
             arrays_are_shapes=True,
-        )
+        ),
+        dtype="int64",  # In case it's empty, as_tensor will default to floatX
     )
diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py
index 2694230c32..b184e04afa 100644
--- a/tests/distributions/test_multivariate.py
+++ b/tests/distributions/test_multivariate.py
@@ -1310,32 +1310,15 @@ def test_kronecker_normal_support_point(self, mu, covs, size, expected):
         [
             (3, 1, None, np.zeros(3)),
             (5, 1, None, np.zeros(10)),
-            pytest.param(
-                3,
-                1,
-                1,
-                np.zeros((1, 3)),
-                marks=pytest.mark.xfail(
-                    raises=NotImplementedError,
-                    reason="LKJCorr logp is only implemented for vector values (ndim=1)",
-                ),
-            ),
-            pytest.param(
-                5,
-                1,
-                (2, 3),
-                np.zeros((2, 3, 10)),
-                marks=pytest.mark.xfail(
-                    raises=NotImplementedError,
-                    reason="LKJCorr logp is only implemented for vector values (ndim=1)",
-                ),
-            ),
+            pytest.param(3, 1, 1, np.zeros((1, 3))),
+            pytest.param(5, 1, (2, 3), np.zeros((2, 3, 10))),
         ],
     )
     def test_lkjcorr_support_point(self, n, eta, size, expected):
         with pm.Model() as model:
             pm.LKJCorr("x", n=n, eta=eta, size=size, return_matrix=False)
-        assert_support_point_is_expected(model, expected)
+        # LKJCorr logp is only implemented for vector values (size=None)
+        assert_support_point_is_expected(model, expected, check_finite_logp=size is None)
 
     @pytest.mark.parametrize(
         "n, eta, size, expected",
@@ -2190,15 +2173,18 @@ def ref_rand(size, n, eta):
             beta = eta - 1 + n / 2
             return (st.beta.rvs(size=(size, shape), a=beta, b=beta) - 0.5) * 2
 
-        continuous_random_tester(
-            _LKJCorr,
-            {
-                "n": Domain([2, 10, 50], edges=(None, None)),
-                "eta": Domain([1.0, 10.0, 100.0], edges=(None, None)),
-            },
-            ref_rand=ref_rand,
-            size=1000,
-        )
+        # If passed as a domain, continuous_random_tester would make `n` a shared variable
+        # But this RV needs it to be constant in order to define the inner graph
+        for n in (2, 10, 50):
+            continuous_random_tester(
+                _LKJCorr,
+                {
+                    "eta": Domain([1.0, 10.0, 100.0], edges=(None, None)),
+                },
+                extra_args={"n": n},
+                ref_rand=ft.partial(ref_rand, n=n),
+                size=1000,
+            )
 
     @pytest.mark.parametrize(
diff --git a/tests/distributions/test_random_alternative_backends.py b/tests/distributions/test_random_alternative_backends.py
new file mode 100644
index 0000000000..98214cdae9
--- /dev/null
+++ b/tests/distributions/test_random_alternative_backends.py
@@ -0,0 +1,70 @@
+#   Copyright 2025 - present The PyMC Developers
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+from contextlib import nullcontext
+
+import numpy as np
+import pytest
+
+import pymc as pm
+
+from pymc import DirichletMultinomial, MvStudentT
+from pymc.model.transform.optimization import freeze_dims_and_data
+
+
+@pytest.fixture(params=["FAST_RUN", "JAX", "NUMBA"])
+def mode(request):
+    mode_param = request.param
+    if mode_param != "FAST_RUN":
+        pytest.importorskip(mode_param.lower())
+    return mode_param
+
+
+def test_dirichlet_multinomial(mode):
+    """Test we can draw from a DM in the JAX backend if the shape is constant."""
+    dm = DirichletMultinomial.dist(n=5, a=np.eye(3) * 1e6 + 0.01)
+    dm_draws = pm.draw(dm, mode=mode)
+    np.testing.assert_equal(dm_draws, np.eye(3) * 5)
+
+
+def test_dirichlet_multinomial_dims(mode):
+    """Test we can draw from a DM with a shape defined by dims in the JAX backend,
+    after freezing those dims.
+    """
+    with pm.Model(coords={"trial": range(3), "item": range(3)}) as m:
+        dm = DirichletMultinomial("dm", n=5, a=np.eye(3) * 1e6 + 0.01, dims=("trial", "item"))
+
+    # JAX does not allow us to JIT a function with dynamic shape
+    expected_ctxt = pytest.raises(TypeError) if mode == "JAX" else nullcontext()
+    with expected_ctxt:
+        pm.draw(dm, mode=mode)
+
+    # Should be fine after freezing the dims that specify the shape
+    frozen_dm = freeze_dims_and_data(m)["dm"]
+    dm_draws = pm.draw(frozen_dm, mode=mode)
+    np.testing.assert_equal(dm_draws, np.eye(3) * 5)
+
+
+def test_mvstudentt(mode):
+    mvt = MvStudentT.dist(nu=100, mu=[1, 2, 3], scale=np.eye(3) * [0.01, 1, 100], shape=(10_000, 3))
+    draws = pm.draw(mvt, mode=mode)
+    np.testing.assert_allclose(draws.mean(0), [1, 2, 3], rtol=0.1)
+    np.testing.assert_allclose(draws.std(0), np.sqrt([0.01, 1, 100]), rtol=0.1)
+
+
+def test_repeated_arguments(mode):
+    # Regression test for a failure in Numba mode when a RV had repeated arguments
+    v = 0.5 * 1e5
+    x = pm.Beta.dist(v, v)
+    x_draw = pm.draw(x, mode=mode)
+    np.testing.assert_allclose(x_draw, 0.5, rtol=0.01)
diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py
index ddec60e539..0205c4ebf7 100644
--- a/tests/sampling/test_jax.py
+++ b/tests/sampling/test_jax.py
@@ -34,8 +34,7 @@
 import pymc as pm
 
 from pymc import ImputationWarning
-from pymc.distributions.multivariate import DirichletMultinomial, PosDefMatrix
-from pymc.model.transform.optimization import freeze_dims_and_data
+from pymc.distributions.multivariate import PosDefMatrix
 from pymc.sampling.jax import (
     _get_batched_jittered_initial_points,
     _get_log_likelihood,
@@ -505,27 +504,3 @@ def test_convergence_warnings(caplog, nuts_sampler):
 
     [record] = caplog.records
     assert re.match(r"There were \d+ divergences after tuning", record.message)
-
-
-def test_dirichlet_multinomial():
-    """Test we can draw from a DM in the JAX backend if the shape is constant."""
-    dm = DirichletMultinomial.dist(n=5, a=np.eye(3) * 1e6 + 0.01)
-    dm_draws = pm.draw(dm, mode="JAX")
-    np.testing.assert_equal(dm_draws, np.eye(3) * 5)
-
-
-def test_dirichlet_multinomial_dims():
-    """Test we can draw from a DM with a shape defined by dims in the JAX backend,
-    after freezing those dims.
-    """
-    with pm.Model(coords={"trial": range(3), "item": range(3)}) as m:
-        dm = DirichletMultinomial("dm", n=5, a=np.eye(3) * 1e6 + 0.01, dims=("trial", "item"))
-
-    # JAX does not allow us to JIT a function with dynamic shape
-    with pytest.raises(TypeError):
-        pm.draw(dm, mode="JAX")
-
-    # Should be fine after freezing the dims that specify the shape
-    frozen_dm = freeze_dims_and_data(m)["dm"]
-    dm_draws = pm.draw(frozen_dm, mode="JAX")
-    np.testing.assert_equal(dm_draws, np.eye(3) * 5)
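---

Usage sketch (not part of the diff above; assumes numba is installed, and the parameter values are purely illustrative): with Weibull now a SymbolicRandomVariable, the same draw graph can be compiled to an alternative backend by passing a compile mode to pm.draw, just as the new tests do.

    import pymc as pm

    # Weibull draws compiled with the Numba backend instead of the default C backend
    x = pm.Weibull.dist(alpha=2.0, beta=1.5, size=1_000)
    draws = pm.draw(x, mode="NUMBA", random_seed=1)
    print(draws.mean())  # roughly beta * gamma(1 + 1/alpha) ~= 1.33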