From 84d301ff952237820dc44a46c102da399fa9f0ab Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Thu, 31 Oct 2024 22:50:27 -0300 Subject: [PATCH] Add to_bambi method (#578) * add to_bambi method * add to_bambi method --- preliz/distributions/distributions.py | 25 +++++++++++++++++++++---- preliz/tests/test_distributions.py | 7 +++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/preliz/distributions/distributions.py b/preliz/distributions/distributions.py index 42715731..781e0ad1 100644 --- a/preliz/distributions/distributions.py +++ b/preliz/distributions/distributions.py @@ -315,8 +315,6 @@ def to_pymc(self, name=None, **kwargs): ------- PyMC distribution """ - pymc_dist = None - try: import pymc.distributions as pm_dists from pymc.model import Model @@ -369,10 +367,29 @@ def to_pymc(self, name=None, **kwargs): else: pymc_dist = pymc_class(name, **self.params_dict, **kwargs) + return pymc_dist + except ImportError: - pass + raise ImportError("This function requires PyMC") from None + + def to_bambi(self, **kwargs): + """ + Convert the PreliZ distribution to a Bambi Prior. + + kwargs : PyMC distributions properties + kwargs are used to specify properties such as shape or dims - return pymc_dist + Returns + ------- + Bambi Prior + """ + try: + from bambi import Prior + + return Prior(self.__class__.__name__, **self.params_dict, **kwargs) + + except ImportError: + raise ImportError("This function requires Bambi") from None def _check_endpoints(self, lower, upper, raise_error=True): """ diff --git a/preliz/tests/test_distributions.py b/preliz/tests/test_distributions.py index e82ba6cf..a7157622 100644 --- a/preliz/tests/test_distributions.py +++ b/preliz/tests/test_distributions.py @@ -313,3 +313,10 @@ def test_to_pymc(): assert model.basic_RVs[2].ndim == 0 assert Normal(0, 1).to_pymc(shape=2).ndim == 1 assert Censored(Normal(0, 1), lower=0).to_pymc().ndim == 0 + + +def test_to_bambi(): + bambi_prior = Gamma(mu=2, sigma=1).to_bambi() + assert bambi_prior.name == "Gamma" + assert bambi_prior.args["mu"] == 2 + assert bambi_prior.args["sigma"] == 1