diff --git a/preliz/distributions/distributions.py b/preliz/distributions/distributions.py index 42715731..781e0ad1 100644 --- a/preliz/distributions/distributions.py +++ b/preliz/distributions/distributions.py @@ -315,8 +315,6 @@ def to_pymc(self, name=None, **kwargs): ------- PyMC distribution """ - pymc_dist = None - try: import pymc.distributions as pm_dists from pymc.model import Model @@ -369,10 +367,29 @@ def to_pymc(self, name=None, **kwargs): else: pymc_dist = pymc_class(name, **self.params_dict, **kwargs) + return pymc_dist + except ImportError: - pass + raise ImportError("This function requires PyMC") from None + + def to_bambi(self, **kwargs): + """ + Convert the PreliZ distribution to a Bambi Prior. + + kwargs : PyMC distributions properties + kwargs are used to specify properties such as shape or dims - return pymc_dist + Returns + ------- + Bambi Prior + """ + try: + from bambi import Prior + + return Prior(self.__class__.__name__, **self.params_dict, **kwargs) + + except ImportError: + raise ImportError("This function requires Bambi") from None def _check_endpoints(self, lower, upper, raise_error=True): """ diff --git a/preliz/tests/test_distributions.py b/preliz/tests/test_distributions.py index e82ba6cf..a7157622 100644 --- a/preliz/tests/test_distributions.py +++ b/preliz/tests/test_distributions.py @@ -313,3 +313,10 @@ def test_to_pymc(): assert model.basic_RVs[2].ndim == 0 assert Normal(0, 1).to_pymc(shape=2).ndim == 1 assert Censored(Normal(0, 1), lower=0).to_pymc().ndim == 0 + + +def test_to_bambi(): + bambi_prior = Gamma(mu=2, sigma=1).to_bambi() + assert bambi_prior.name == "Gamma" + assert bambi_prior.args["mu"] == 2 + assert bambi_prior.args["sigma"] == 1