From 15026f02243302ed4f64eb9af258679ad332dc2a Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Sat, 27 Aug 2022 08:21:49 -0300 Subject: [PATCH 1/5] clean laplace results --- bambi/backend/pymc.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 36fb4760e..f14922a68 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -108,7 +108,7 @@ def run( elif inference_method == "vi": result = self._run_vi(**kwargs) elif inference_method == "laplace": - result = self._run_laplace(draws) + result = self._run_laplace(draws, omit_offsets, include_mean) else: raise NotImplementedError(f"{inference_method} method has not been implemented") @@ -349,10 +349,10 @@ def _run_mcmc( f"``mcmc``, ``nuts_numpyro`` or ``nuts_blackjax``" ) - idata = self._clean_mcmc_results(idata, omit_offsets, include_mean) + idata = self._clean_results(idata, omit_offsets, include_mean) return idata - def _clean_mcmc_results(self, idata, omit_offsets, include_mean): + def _clean_results(self, idata, omit_offsets, include_mean): for group in idata.groups(): getattr(idata, group).attrs["modeling_interface"] = "bambi" getattr(idata, group).attrs["modeling_interface_version"] = version.__version__ @@ -438,7 +438,7 @@ def _run_vi(self, **kwargs): self.vi_approx = pm.fit(**kwargs) return self.vi_approx - def _run_laplace(self, draws): + def _run_laplace(self, draws, omit_offsets, include_mean): """Fit a model using a Laplace approximation. Mainly for pedagogical use, provides reasonable results for approximately @@ -473,7 +473,9 @@ def _run_laplace(self, draws): samples = np.random.multivariate_normal(modes, cov, size=draws) - return _posterior_samples_to_idata(samples, self.model) + idata = _posterior_samples_to_idata(samples, self.model) + idata = self._clean_results(idata, omit_offsets, include_mean) + return idata def _posterior_samples_to_idata(samples, model): From 6bd9c622324ec7a0ae6305b7980f20cc5ac14e73 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Sat, 27 Aug 2022 08:45:07 -0300 Subject: [PATCH 2/5] disable pylint false alarm --- bambi/backend/links.py | 2 +- bambi/backend/pymc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bambi/backend/links.py b/bambi/backend/links.py index f347aeaf1..97259f3de 100644 --- a/bambi/backend/links.py +++ b/bambi/backend/links.py @@ -36,7 +36,7 @@ def identity(x): def inverse_squared(x): - return at.inv(at.sqrt(x)) + return at.inv(at.sqrt(x)) # pylint: disable=no-member def arctan_2(x): diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index f14922a68..ce735a0e1 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -23,7 +23,7 @@ class PyMCModel: "cloglog": cloglog, "identity": identity, "inverse_squared": inverse_squared, - "inverse": at.inv, + "inverse": at.inv, # pylint: disable=no-member "log": at.exp, "logit": logit, "probit": probit, From bba8562fc948693a1b9a4e439672d14fbcc59453 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Sat, 27 Aug 2022 09:09:30 -0300 Subject: [PATCH 3/5] revert pylint disable and change inv for reciprocal --- bambi/backend/links.py | 2 +- bambi/backend/pymc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bambi/backend/links.py b/bambi/backend/links.py index 97259f3de..759fae619 100644 --- a/bambi/backend/links.py +++ b/bambi/backend/links.py @@ -36,7 +36,7 @@ def identity(x): def inverse_squared(x): - return at.inv(at.sqrt(x)) # pylint: disable=no-member + return at.reciprocal(at.sqrt(x)) def arctan_2(x): diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index ce735a0e1..b66bcb767 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -23,7 +23,7 @@ class PyMCModel: "cloglog": cloglog, "identity": identity, "inverse_squared": inverse_squared, - "inverse": at.inv, # pylint: disable=no-member + "inverse": at.reciprocal, "log": at.exp, "logit": logit, "probit": probit, From 6e403c03d9fbf07957f159ff79ce510f1301b1a4 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Sat, 27 Aug 2022 11:21:00 -0300 Subject: [PATCH 4/5] fix docstring --- bambi/backend/pymc.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index b66bcb767..bf751c25f 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -451,6 +451,11 @@ def _run_laplace(self, draws, omit_offsets, include_mean): model: PyMC model draws: int The number of samples to draw from the posterior distribution. + omit_offsets: bool + Omits offset terms in the ``InferenceData`` object returned when the model includes + group specific effects. + include_mean: bool + Compute the posterior of the mean response. Returns ------- From be41dd5fbe4a0d1e8e6d593e4217e6a90e16f769 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Capretto?= Date: Sat, 27 Aug 2022 11:32:46 -0300 Subject: [PATCH 5/5] Remove `model` parameter from docstring. It's not part of the method signature. --- bambi/backend/pymc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index bf751c25f..bd7709d12 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -448,7 +448,6 @@ def _run_laplace(self, draws, omit_offsets, include_mean): Parameters ---------- - model: PyMC model draws: int The number of samples to draw from the posterior distribution. omit_offsets: bool