diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 403d5be77..574bcd7fd 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -11,7 +11,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.9", "3.10", "3.11"]
+        python-version: ["3.10", "3.11", "3.12"]
     name: Set up Python ${{ matrix.python-version }}
     steps:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7f7f7bb22..566ead47a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,8 +11,12 @@
 ### Maintenance and fixes

+* Fix bug in predictions with models using HSGP (#780)
+
 ### Documentation

+* Our Code of Conduct now includes how to send a report (#783)
+
 ### Deprecation

 ## 0.13.0
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index fb45a77ff..930baed28 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -1,5 +1,9 @@
 # Bambi Community Code of Conduct

+Bambi adopts the NumFOCUS Code of Conduct directly. In other words, we expect our community to treat others with kindness and understanding.
+
+# The short version
+
 Be kind to others. Do not insult or put down others. Behave professionally. Remember that harassment and sexist, racist, or exclusionary jokes are not appropriate.
@@ -15,3 +19,31 @@
 or religion. We do not tolerate harassment of community members in any form.

 Thank you for helping make this a welcoming, friendly community for all.
+
+# How to Submit a Report
+
+If you feel that there has been a Code of Conduct violation, an anonymous
+reporting form is available.
+
+**If you feel your safety is in jeopardy or the situation is an
+emergency, we urge you to contact local law enforcement before making
+a report. (In the U.S., dial 911.)**
+
+We are committed to promptly addressing any reported issues.
+If you have experienced or witnessed behavior that violates this
+Code of Conduct, please complete the form below to
+make a report.
+
+**REPORTING FORM:** https://numfocus.typeform.com/to/ynjGdT
+
+Reports are sent to the NumFOCUS Code of Conduct Enforcement Team
+(see below).
+
+You can view the Privacy Policy and Terms of Service for TypeForm here.
+The NumFOCUS Privacy Policy is here:
+https://www.numfocus.org/privacy-policy
+
+# Full Code of Conduct
+
+The full text of the NumFOCUS/Bambi Code of Conduct can be found on
+NumFOCUS's website https://numfocus.org/code-of-conduct
\ No newline at end of file
diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py
index 354f08ce3..82b646ebe 100644
--- a/bambi/backend/pymc.py
+++ b/bambi/backend/pymc.py
@@ -1,7 +1,9 @@
 import functools
+import importlib
 import logging
+import operator
 import traceback
-
+import warnings
 from copy import deepcopy
 from importlib.metadata import version
@@ -12,7 +14,6 @@
 import pytensor.tensor as pt
 from pytensor.tensor.special import softmax
-
 from bambi.backend.links import cloglog, identity, inverse_squared, logit, probit, arctan_2
 from bambi.backend.model_components import ConstantComponent, DistributionalComponent
 from bambi.utils import get_aliased_name
@@ -46,6 +47,8 @@ def __init__(self):
         self.model = None
         self.spec = None
         self.components = {}
+        self.bayeux_methods = _get_bayeux_methods()
+        self.pymc_methods = {"mcmc": ["mcmc"], "vi": ["vi"]}

     def build(self, spec):
         """Compile the PyMC model from an abstract model specification.
@@ -94,8 +97,24 @@ def run( ): """Run PyMC sampler.""" inference_method = inference_method.lower() + + if inference_method == "nuts_numpyro": + inference_method = "numpyro_nuts" + warnings.warn( + "'nuts_numpyro' has been replaced by 'numpyro_nuts' and will be " + "removed in a future release", + category=FutureWarning, + ) + elif inference_method == "nuts_blackjax": + inference_method = "blackjax_nuts" + warnings.warn( + "'nuts_blackjax' has been replaced by 'blackjax_nuts' and will " + "be removed in a future release", + category=FutureWarning, + ) + # NOTE: Methods return different types of objects (idata, approximation, and dictionary) - if inference_method in ["mcmc", "nuts_numpyro", "nuts_blackjax"]: + if inference_method in (self.pymc_methods["mcmc"] + self.bayeux_methods["mcmc"]): result = self._run_mcmc( draws, tune, @@ -110,7 +129,7 @@ def run( inference_method, **kwargs, ) - elif inference_method == "vi": + elif inference_method in self.pymc_methods["vi"]: result = self._run_vi(**kwargs) elif inference_method == "laplace": result = self._run_laplace(draws, omit_offsets, include_mean) @@ -169,8 +188,8 @@ def _run_mcmc( sampler_backend="mcmc", **kwargs, ): - with self.model: - if sampler_backend == "mcmc": + if sampler_backend in self.pymc_methods["mcmc"]: + with self.model: try: idata = pm.sample( draws=draws, @@ -205,41 +224,35 @@ def _run_mcmc( ) else: raise - elif sampler_backend == "nuts_numpyro": - import pymc.sampling_jax # pylint: disable=import-outside-toplevel - - if not chains: - # sample_numpyro_nuts does not handle chains = None like pm.sample does - chains = 4 - idata = pymc.sampling_jax.sample_numpyro_nuts( - draws=draws, - tune=tune, - chains=chains, - random_seed=random_seed, - **kwargs, - ) - elif sampler_backend == "nuts_blackjax": - import pymc.sampling_jax # pylint: disable=import-outside-toplevel - - # sample_blackjax_nuts does not handle chains = None like pm.sample does - if not chains: - chains = 4 - idata = pymc.sampling_jax.sample_blackjax_nuts( - draws=draws, - tune=tune, - chains=chains, - random_seed=random_seed, - **kwargs, - ) - else: - raise ValueError( - f"sampler_backend value {sampler_backend} is not valid. Please choose one of" - f"'mcmc', 'nuts_numpyro' or 'nuts_blackjax'" - ) - idata = self._clean_results(idata, omit_offsets, include_mean) + idata_from = "pymc" + elif sampler_backend in self.bayeux_methods["mcmc"]: + import bayeux as bx # pylint: disable=import-outside-toplevel + import jax # pylint: disable=import-outside-toplevel + + # Set the seed for reproducibility if provided + if random_seed is not None: + if not isinstance(random_seed, int): + random_seed = random_seed[0] + np.random.seed(random_seed) + + jax_seed = jax.random.PRNGKey(np.random.randint(2**32 - 1)) + + bx_model = bx.Model.from_pymc(self.model) + bx_sampler = operator.attrgetter(sampler_backend)( + bx_model.mcmc # pylint: disable=no-member + ) + idata = bx_sampler(seed=jax_seed, **kwargs) + idata_from = "bayeux" + else: + raise ValueError( + f"sampler_backend value {sampler_backend} is not valid. 
Please choose one of" + f" {self.pymc_methods['mcmc'] + self.bayeux_methods['mcmc']}" + ) + + idata = self._clean_results(idata, omit_offsets, include_mean, idata_from) return idata - def _clean_results(self, idata, omit_offsets, include_mean): + def _clean_results(self, idata, omit_offsets, include_mean, idata_from): for group in idata.groups(): getattr(idata, group).attrs["modeling_interface"] = "bambi" @@ -258,6 +271,15 @@ def _clean_results(self, idata, omit_offsets, include_mean): dims_original = list(self.model.coords) + # Identify bayeux idata and rename dims and coordinates to match PyMC model + if idata_from == "bayeux": + pymc_model_dims = [dim for dim in dims_original if "_obs" not in dim] + bayeux_dims = [ + dim for dim in idata.posterior.dims if not dim.startswith(("chain", "draw")) + ] + cleaned_dims = dict(zip(bayeux_dims, pymc_model_dims)) + idata = idata.rename(cleaned_dims) + # Discard dims that are in the model but unused in the posterior dims_original = [dim for dim in dims_original if dim in idata.posterior.dims] @@ -272,7 +294,6 @@ def _clean_results(self, idata, omit_offsets, include_mean): idata.posterior = idata.posterior.transpose(*dims_new) # Compute the actual intercept in all distributional components that have an intercept - for pymc_component in self.distributional_components.values(): bambi_component = pymc_component.component if ( @@ -317,8 +338,8 @@ def _run_laplace(self, draws, omit_offsets, include_mean): Mainly for pedagogical use, provides reasonable results for approximately Gaussian posteriors. The approximation can be very poor for some models - like hierarchical ones. Use ``mcmc``, ``nuts_numpyro``, ``nuts_blackjax`` - or ``vi`` for better approximations. + like hierarchical ones. Use ``mcmc``, ``vi``, or JAX based MCMC methods + for better approximations. Parameters ---------- @@ -352,7 +373,7 @@ def _run_laplace(self, draws, omit_offsets, include_mean): samples = np.random.multivariate_normal(modes, cov, size=draws) idata = _posterior_samples_to_idata(samples, self.model) - idata = self._clean_results(idata, omit_offsets, include_mean) + idata = self._clean_results(idata, omit_offsets, include_mean, idata_from="pymc") return idata @property @@ -367,6 +388,10 @@ def constant_components(self): def distributional_components(self): return {k: v for k, v in self.components.items() if isinstance(v, DistributionalComponent)} + @property + def inference_methods(self): + return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods} + def _posterior_samples_to_idata(samples, model): """Create InferenceData from samples. @@ -406,3 +431,22 @@ def _posterior_samples_to_idata(samples, model): idata = pm.to_inference_data(pm.backends.base.MultiTrace([strace]), model=model) return idata + + +def _get_bayeux_methods(): + """Gets a dictionary of usable bayeux methods if the bayeux package is installed + within the user's environment. + + Returns + ------- + dict + A dict where the keys are the module names and the values are the methods + available in that module. 
+ """ + if importlib.util.find_spec("bayeux") is None: + return {"mcmc": []} + + import bayeux as bx # pylint: disable=import-outside-toplevel + + # Dummy log density to get access to all methods + return bx.Model(lambda x: -(x**2), 0.0).methods diff --git a/bambi/data/__init__.py b/bambi/data/__init__.py index 38adba619..1f6fb200c 100644 --- a/bambi/data/__init__.py +++ b/bambi/data/__init__.py @@ -1,4 +1,5 @@ """Code for loading datasets.""" + from .datasets import clear_data_home, load_data __all__ = ["clear_data_home", "load_data"] diff --git a/bambi/data/datasets.py b/bambi/data/datasets.py index dc063cb18..0be57e9ba 100644 --- a/bambi/data/datasets.py +++ b/bambi/data/datasets.py @@ -1,4 +1,5 @@ """Base IO code for datasets. Heavily influenced by Arviz's (and scikit-learn's) implementation.""" + import hashlib import itertools import os diff --git a/bambi/defaults/__init__.py b/bambi/defaults/__init__.py index 3dedec422..53ef82498 100644 --- a/bambi/defaults/__init__.py +++ b/bambi/defaults/__init__.py @@ -1,4 +1,5 @@ """Settings for default priors, families, etc. in Bambi.""" + from bambi.defaults.utils import get_default_prior from bambi.defaults.families import get_builtin_family diff --git a/bambi/families/__init__.py b/bambi/families/__init__.py index df1a84a5a..645d27bee 100644 --- a/bambi/families/__init__.py +++ b/bambi/families/__init__.py @@ -1,4 +1,5 @@ """Classes to construct model families.""" + from bambi.families.family import Family from bambi.families.likelihood import Likelihood from bambi.families.link import Link diff --git a/bambi/interpret/utils.py b/bambi/interpret/utils.py index a56e23560..cbb7bde19 100644 --- a/bambi/interpret/utils.py +++ b/bambi/interpret/utils.py @@ -102,6 +102,7 @@ def set_default_variable_values(self) -> np.ndarray: If categoric dtype the returned value is the unique levels of `variable'. 
""" + values = None # Otherwise pylint complains terms = get_model_terms(self.model) # get default values for each variable in the model for term in terms.values(): @@ -236,11 +237,20 @@ def get_model_covariates(model: Model) -> np.ndarray: for term in terms.values(): if hasattr(term, "components"): for component in term.components: - # if the component is a function call, use the argument names + # if the component is a function call, look for relevant argument names if isinstance(component, Call): + # Add variable names passed as unnamed arguments covariates.append( [arg.name for arg in component.call.args if isinstance(arg, LazyVariable)] ) + # Add variable names passed as named arguments + covariates.append( + [ + kwarg_value.name + for kwarg_value in component.call.kwargs.values() + if isinstance(kwarg_value, LazyVariable) + ] + ) else: covariates.append([component.name]) elif hasattr(term, "factor"): diff --git a/bambi/model_components.py b/bambi/model_components.py index f4691e5e2..44c781127 100644 --- a/bambi/model_components.py +++ b/bambi/model_components.py @@ -239,11 +239,12 @@ def predict_common( X = np.delete(X, term_slice, axis=1) # Add HSGP components contribution to the linear predictor + hsgp_slices = [] for term_name, term in self.hsgp_terms.items(): # Extract data for the HSGP component from the design matrix term_slice = self.design.common.slices[term_name] x_slice = X[:, term_slice] - X = np.delete(X, term_slice, axis=1) + hsgp_slices.append(term_slice) term_aliased_name = get_aliased_name(term) hsgp_to_stack_dims = (f"{term_aliased_name}_weights_dim",) @@ -288,6 +289,12 @@ def predict_common( # Add contribution to the linear predictor linear_predictor += hsgp_contribution + # Remove columns of X that are associated with HSGP contributions + # All the slices _must be_ deleted at the same time. Otherwise the slice objects don't + # reflect the right columns of X at the time they're used + if hsgp_slices: + X = np.delete(X, np.r_[tuple(hsgp_slices)], axis=1) + if self.common_terms or self.intercept_term: # Create DataArray X_terms = [get_aliased_name(term) for term in self.common_terms.values()] diff --git a/bambi/models.py b/bambi/models.py index 5d5eafe09..9b5421d51 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -266,9 +266,9 @@ def fit( using the ``fit`` function. Finally, ``"laplace"``, in which case a Laplace approximation is used and is not recommended other than for pedagogical use. - To use the PyMC numpyro and blackjax samplers, use ``nuts_numpyro`` or ``nuts_blackjax`` - respectively. Both methods will only work if you can use NUTS sampling, so your model - must be differentiable. + To get a list of JAX based inference methods, call + ``model.backend.inference_methods['bayeux']``. This will return a dictionary of the + available methods such as ``blackjax_nuts``, ``numpyro_nuts``, among others. init : str Initialization method. Defaults to ``"auto"``. The available methods are: * auto: Use ``"jitter+adapt_diag"`` and if this method fails it uses ``"adapt_diag"``. @@ -306,7 +306,8 @@ def fit( Returns ------- An ArviZ ``InferenceData`` instance if inference_method is ``"mcmc"`` (default), - "nuts_numpyro", "nuts_blackjax" or "laplace". + "laplace", or one of the MCMC methods in + ``model.backend.inference_methods['bayeux']['mcmc]``. An ``Approximation`` object if ``"vi"``. 
""" method = kwargs.pop("method", None) diff --git a/bambi/priors/__init__.py b/bambi/priors/__init__.py index c90e68945..6884486a6 100644 --- a/bambi/priors/__init__.py +++ b/bambi/priors/__init__.py @@ -1,4 +1,5 @@ """Classes to represent prior distributions and methods to set automatic priors""" + from .prior import Prior from .scaler import PriorScaler diff --git a/bambi/terms/base.py b/bambi/terms/base.py index 81fb77a2a..c11f55bc6 100644 --- a/bambi/terms/base.py +++ b/bambi/terms/base.py @@ -13,33 +13,27 @@ class BaseTerm(ABC): @property @abstractmethod - def term(self): - ... + def term(self): ... @property @abstractmethod - def data(self): - ... + def data(self): ... @property @abstractmethod - def name(self): - ... + def name(self): ... @property @abstractmethod - def shape(self): - ... + def shape(self): ... @property @abstractmethod - def levels(self): - ... + def levels(self): ... @property @abstractmethod - def categorical(self): - ... + def categorical(self): ... @property def alias(self): diff --git a/bambi/transformations.py b/bambi/transformations.py index 9442b0a15..eb226ed43 100644 --- a/bambi/transformations.py +++ b/bambi/transformations.py @@ -175,6 +175,7 @@ def weighted(x, weights): weighted.__metadata__ = {"kind": "weighted"} + # pylint: disable = invalid-name @register_stateful_transform class HSGP: # pylint: disable = too-many-instance-attributes diff --git a/docs/_quarto.yml b/docs/_quarto.yml index 0d141baf0..87ba502de 100644 --- a/docs/_quarto.yml +++ b/docs/_quarto.yml @@ -89,6 +89,9 @@ website: - notebooks/plot_comparisons.ipynb - notebooks/plot_slopes.ipynb - notebooks/interpret_advanced_usage.ipynb + - section: Alternative sampling backends + contents: + - notebooks/alternative_samplers.ipynb quartodoc: style: pkgdown diff --git a/docs/notebooks/alternative_samplers.ipynb b/docs/notebooks/alternative_samplers.ipynb new file mode 100644 index 000000000..24d610d96 --- /dev/null +++ b/docs/notebooks/alternative_samplers.ipynb @@ -0,0 +1,6963 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Alternative sampling backends\n", + "\n", + "In Bambi, the sampler used is automatically selected given the type of variables used in the model. For inference, Bambi supports both MCMC and variational inference. By default, Bambi uses PyMC's implementation of the adaptive Hamiltonian Monte Carlo (HMC) algorithm for sampling. Also known as the No-U-Turn Sampler (NUTS). This sampler is a good choice for many models. However, it is not the only sampling method, nor is PyMC the only library implementing NUTS. \n", + "\n", + "To this extent, Bambi supports multiple backends for MCMC sampling such as NumPyro and Blackjax. This notebook will cover how to use such alternatives in Bambi.\n", + "\n", + "_Note_: Bambi utilizes [bayeux](https://github.com/jax-ml/bayeux) to access a variety of sampling backends. Thus, you will need to install the optional dependencies in the Bambi [pyproject.toml](https://github.com/bambinos/bambi/blob/main/pyproject.toml) file to use these backends." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import arviz as az\n", + "import bambi as bmb\n", + "import bayeux as bx\n", + "import numpy as np\n", + "import pandas as pd" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## bayeux\n", + "\n", + "Bambi leverages `bayeux` to access different sampling backends. 
In short, `bayeux` lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods. \n",
+    "\n",
+    "Since the underlying Bambi model is a PyMC model, this PyMC model can be \"given\" to `bayeux`. Then, we can choose from a variety of MCMC methods to perform inference. \n",
+    "\n",
+    "To demonstrate the available backends, we will first simulate data and build a model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "num_samples = 100\n",
+    "num_features = 1\n",
+    "noise_std = 1.0\n",
+    "random_seed = 42\n",
+    "\n",
+    "np.random.seed(random_seed)\n",
+    "\n",
+    "coefficients = np.random.randn(num_features)\n",
+    "X = np.random.randn(num_samples, num_features)\n",
+    "error = np.random.normal(scale=noise_std, size=num_samples)\n",
+    "y = X @ coefficients + error\n",
+    "\n",
+    "data = pd.DataFrame({\"y\": y, \"x\": X.flatten()})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = bmb.Model(\"y ~ x\", data)\n",
+    "model.build()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can call `model.backend.inference_methods`, which returns a nested dictionary of the backends and the list of inference methods."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'pymc': {'mcmc': ['mcmc'], 'vi': ['vi']},\n",
+       " 'bayeux': {'mcmc': ['tfp_hmc',\n",
+       "   'tfp_nuts',\n",
+       "   'tfp_snaper_hmc',\n",
+       "   'blackjax_hmc',\n",
+       "   'blackjax_chees_hmc',\n",
+       "   'blackjax_meads_hmc',\n",
+       "   'blackjax_nuts',\n",
+       "   'blackjax_hmc_pathfinder',\n",
+       "   'blackjax_nuts_pathfinder',\n",
+       "   'flowmc_rqspline_hmc',\n",
+       "   'flowmc_rqspline_mala',\n",
+       "   'flowmc_realnvp_hmc',\n",
+       "   'flowmc_realnvp_mala',\n",
+       "   'numpyro_hmc',\n",
+       "   'numpyro_nuts'],\n",
+       "  'optimize': ['jaxopt_bfgs',\n",
+       "   'jaxopt_gradient_descent',\n",
+       "   'jaxopt_lbfgs',\n",
+       "   'jaxopt_nonlinear_cg',\n",
+       "   'optimistix_bfgs',\n",
+       "   'optimistix_chord',\n",
+       "   'optimistix_dogleg',\n",
+       "   'optimistix_gauss_newton',\n",
+       "   'optimistix_indirect_levenberg_marquardt',\n",
+       "   'optimistix_levenberg_marquardt',\n",
+       "   'optimistix_nelder_mead',\n",
+       "   'optimistix_newton',\n",
+       "   'optimistix_nonlinear_cg',\n",
+       "   'optax_adabelief',\n",
+       "   'optax_adafactor',\n",
+       "   'optax_adagrad',\n",
+       "   'optax_adam',\n",
+       "   'optax_adamw',\n",
+       "   'optax_adamax',\n",
+       "   'optax_amsgrad',\n",
+       "   'optax_fromage',\n",
+       "   'optax_lamb',\n",
+       "   'optax_lion',\n",
+       "   'optax_noisy_sgd',\n",
+       "   'optax_novograd',\n",
+       "   'optax_radam',\n",
+       "   'optax_rmsprop',\n",
+       "   'optax_sgd',\n",
+       "   'optax_sm3',\n",
+       "   'optax_yogi'],\n",
+       "  'vi': ['tfp_factored_surrogate_posterior']}}"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "methods = model.backend.inference_methods\n",
+    "methods"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "With the PyMC backend, we have access to its implementation of the NUTS sampler and mean-field variational inference."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'mcmc': ['mcmc'], 'vi': ['vi']}"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "methods[\"pymc\"]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`bayeux` gives us access to the TensorFlow Probability, Blackjax, FlowMC, and NumPyro backends."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['tfp_hmc',\n",
+       " 'tfp_nuts',\n",
+       " 'tfp_snaper_hmc',\n",
+       " 'blackjax_hmc',\n",
+       " 'blackjax_chees_hmc',\n",
+       " 'blackjax_meads_hmc',\n",
+       " 'blackjax_nuts',\n",
+       " 'blackjax_hmc_pathfinder',\n",
+       " 'blackjax_nuts_pathfinder',\n",
+       " 'flowmc_rqspline_hmc',\n",
+       " 'flowmc_rqspline_mala',\n",
+       " 'flowmc_realnvp_hmc',\n",
+       " 'flowmc_realnvp_mala',\n",
+       " 'numpyro_hmc',\n",
+       " 'numpyro_nuts']"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "methods[\"bayeux\"][\"mcmc\"]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The values under the MCMC and VI keys are the names you can pass to the `inference_method` argument of `model.fit`. This is shown in the section below."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Specifying an `inference_method`"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "By default, Bambi uses the PyMC NUTS implementation. To use a different backend, pass the name of the `bayeux` MCMC method to the `inference_method` parameter of the `fit` method."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Blackjax"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       "(arviz.InferenceData HTML representation omitted)
\n", + "
\n", + "
arviz.InferenceData
\n", + "
\n", + " \n", + "
\n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> sample_stats" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "blackjax_nuts_idata = model.fit(inference_method=\"blackjax_nuts\")\n", + "blackjax_nuts_idata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Different backends have different naming conventions for the parameters specific to that MCMC method. Thus, to specify backend-specific parameters, pass your own `kwargs` to the `fit` method.\n", + "\n", + "Each algorithm has a `.get_kwargs()` method that tells you how it will be called, and what functions are being called." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{ blackjax.base.AdaptationAlgorithm>: {'logdensity_fn': .wrap_log_density..wrapped(args)>,\n", + " 'is_mass_matrix_diagonal': True,\n", + " 'initial_step_size': 1.0,\n", + " 'target_acceptance_rate': 0.8,\n", + " 'progress_bar': False,\n", + " 'algorithm': blackjax.mcmc.nuts.nuts},\n", + " 'adapt.run': {'num_steps': 500},\n", + " blackjax.mcmc.nuts.nuts: {'max_num_doublings': 10,\n", + " 'divergence_threshold': 1000,\n", + " 'integrator': .euclidean_integrator(logdensity_fn: Callable, kinetic_energy_fn: blackjax.mcmc.metrics.KineticEnergy) -> Callable[[blackjax.mcmc.integrators.IntegratorState, float], blackjax.mcmc.integrators.IntegratorState]>,\n", + " 'logdensity_fn': .wrap_log_density..wrapped(args)>,\n", + " 'step_size': 0.5},\n", + " 'extra_parameters': {'chain_method': 'vectorized',\n", + " 'num_chains': 8,\n", + " 'num_draws': 500,\n", + " 'num_adapt_draws': 500,\n", + " 'return_pytree': False}}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bx.Model.from_pymc(model.backend.model).mcmc.blackjax_nuts.get_kwargs()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can identify the kwargs we would like to change and pass to the `fit` method." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
arviz.InferenceData
\n", + "
\n", + "
    \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:    (chain: 4, draw: 250)\n",
      +       "Coordinates:\n",
      +       "  * chain      (chain) int64 0 1 2 3\n",
      +       "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 242 243 244 245 246 247 248 249\n",
      +       "Data variables:\n",
      +       "    y_sigma    (chain, draw) float64 1.078 1.05 0.8647 ... 0.856 0.9391 0.9165\n",
      +       "    Intercept  (chain, draw) float64 -0.1116 -0.1474 ... -0.04961 0.0266\n",
      +       "    x          (chain, draw) float64 0.4042 0.3106 0.4226 ... 0.2611 0.3592\n",
      +       "Attributes:\n",
      +       "    created_at:                  2024-03-01T14:57:03.782531\n",
      +       "    arviz_version:               0.17.0\n",
      +       "    modeling_interface:          bambi\n",
      +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:          (chain: 4, draw: 250)\n",
      +       "Coordinates:\n",
      +       "  * chain            (chain) int64 0 1 2 3\n",
      +       "  * draw             (draw) int64 0 1 2 3 4 5 6 ... 243 244 245 246 247 248 249\n",
      +       "Data variables:\n",
      +       "    lp               (chain, draw) float64 -142.2 -141.9 ... -139.9 -139.3\n",
      +       "    step_size        (chain, draw) float64 0.9072 0.9072 ... 0.7606 0.7606\n",
      +       "    diverging        (chain, draw) bool False False False ... False False False\n",
      +       "    energy           (chain, draw) float64 144.6 143.1 142.2 ... 141.1 140.1\n",
      +       "    tree_depth       (chain, draw) int64 3 3 2 3 2 1 2 2 2 ... 3 3 2 2 3 3 2 3 2\n",
      +       "    n_steps          (chain, draw) int64 7 7 3 7 3 1 3 3 3 ... 7 7 3 3 7 7 3 7 3\n",
      +       "    acceptance_rate  (chain, draw) float64 1.0 0.9854 0.9968 ... 0.9882 0.9931\n",
      +       "Attributes:\n",
      +       "    created_at:                  2024-03-01T14:57:03.784254\n",
      +       "    arviz_version:               0.17.0\n",
      +       "    modeling_interface:          bambi\n",
      +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> sample_stats" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "kwargs = {\n", + " \"adapt.run\": {\"num_steps\": 500},\n", + " \"num_chains\": 4,\n", + " \"num_draws\": 250,\n", + " \"num_adapt_draws\": 250\n", + "}\n", + "\n", + "blackjax_nuts_idata = model.fit(inference_method=\"blackjax_nuts\", **kwargs)\n", + "blackjax_nuts_idata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tensorflow probability" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
arviz.InferenceData
\n", + "
\n", + "
    \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:    (chain: 8, draw: 1000)\n",
      +       "Coordinates:\n",
      +       "  * chain      (chain) int64 0 1 2 3 4 5 6 7\n",
      +       "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
      +       "Data variables:\n",
      +       "    y_sigma    (chain, draw) float64 0.9946 0.8708 0.8651 ... 0.9908 0.9958\n",
      +       "    Intercept  (chain, draw) float64 -0.09685 -0.01575 0.0419 ... 0.1091 0.1152\n",
      +       "    x          (chain, draw) float64 0.4584 0.399 0.4485 ... 0.5167 0.4703\n",
      +       "Attributes:\n",
      +       "    created_at:                  2024-03-01T15:57:23.746257\n",
      +       "    arviz_version:               0.17.0\n",
      +       "    modeling_interface:          bambi\n",
      +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:          (chain: 8, draw: 1000)\n",
      +       "Coordinates:\n",
      +       "  * chain            (chain) int64 0 1 2 3 4 5 6 7\n",
      +       "  * draw             (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999\n",
      +       "Data variables:\n",
      +       "    accept_ratio     (chain, draw) float64 1.0 0.9889 0.9769 ... 1.0 0.9975 1.0\n",
      +       "    diverging        (chain, draw) bool False False False ... False False False\n",
      +       "    is_accepted      (chain, draw) bool True True True True ... True True True\n",
      +       "    n_steps          (chain, draw) int32 7 3 7 7 3 7 3 7 7 ... 3 3 7 3 7 1 1 3 7\n",
      +       "    step_size        (chain, draw) float64 0.5332 0.5332 0.5332 ... nan nan nan\n",
      +       "    target_log_prob  (chain, draw) float64 -141.0 -139.9 ... -140.9 -140.5\n",
      +       "    tune             (chain, draw) float64 0.0 0.0 0.0 0.0 ... nan nan nan nan\n",
      +       "Attributes:\n",
      +       "    created_at:                  2024-03-01T15:57:23.747950\n",
      +       "    arviz_version:               0.17.0\n",
      +       "    modeling_interface:          bambi\n",
      +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> sample_stats" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tfp_nuts_idata = model.fit(inference_method=\"tfp_nuts\")\n", + "tfp_nuts_idata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### NumPyro" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "sample: 100%|██████████| 1500/1500 [00:02<00:00, 551.76it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
arviz.InferenceData
\n", + "
\n", + "
    \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:    (chain: 8, draw: 1000)\n",
      +       "Coordinates:\n",
      +       "  * chain      (chain) int64 0 1 2 3 4 5 6 7\n",
      +       "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
      +       "Data variables:\n",
      +       "    Intercept  (chain, draw) float64 -0.02485 0.1376 -0.00766 ... 0.01202 0.0375\n",
      +       "    x          (chain, draw) float64 0.4336 0.4907 0.4996 ... 0.4032 0.3964\n",
      +       "    y_sigma    (chain, draw) float64 0.9225 1.015 0.9409 ... 0.8574 0.9083 0.822\n",
      +       "Attributes:\n",
      +       "    created_at:                  2024-03-01T14:57:07.292211\n",
      +       "    arviz_version:               0.17.0\n",
      +       "    inference_library:           numpyro\n",
      +       "    inference_library_version:   0.13.2\n",
      +       "    modeling_interface:          bambi\n",
      +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:          (chain: 8, draw: 1000)\n",
      +       "Coordinates:\n",
      +       "  * chain            (chain) int64 0 1 2 3 4 5 6 7\n",
      +       "  * draw             (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999\n",
      +       "Data variables:\n",
      +       "    acceptance_rate  (chain, draw) float64 0.9973 0.6392 0.987 ... 0.9744 0.8087\n",
      +       "    step_size        (chain, draw) float64 0.7525 0.7525 ... 0.8295 0.8295\n",
      +       "    diverging        (chain, draw) bool False False False ... False False False\n",
      +       "    energy           (chain, draw) float64 140.6 143.8 141.9 ... 140.8 141.4\n",
      +       "    n_steps          (chain, draw) int64 3 3 3 3 1 1 3 7 7 ... 7 7 3 7 7 7 15 3\n",
      +       "    tree_depth       (chain, draw) int64 2 2 2 2 1 1 2 3 3 ... 2 3 3 2 3 3 3 4 2\n",
      +       "    lp               (chain, draw) float64 139.7 141.1 140.3 ... 139.4 141.0\n",
      +       "Attributes:\n",
      +       "    created_at:                  2024-03-01T14:57:07.316723\n",
      +       "    arviz_version:               0.17.0\n",
      +       "    inference_library:           numpyro\n",
      +       "    inference_library_version:   0.13.2\n",
      +       "    modeling_interface:          bambi\n",
      +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> sample_stats" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "numpyro_nuts_idata = model.fit(inference_method=\"numpyro_nuts\")\n", + "numpyro_nuts_idata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### flowMC" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No autotune found, use input sampler_params\n", + "Training normalizing flow\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Tuning global sampler: 100%|██████████| 5/5 [00:51<00:00, 10.23s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting Production run\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Production run: 100%|██████████| 5/5 [00:00<00:00, 9.38it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
arviz.InferenceData
\n", + "
\n", + "
    \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:    (chain: 20, draw: 500)\n",
      +       "Coordinates:\n",
      +       "  * chain      (chain) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19\n",
      +       "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      +       "Data variables:\n",
      +       "    y_sigma    (chain, draw) float64 0.8082 1.024 1.024 ... 0.971 0.971 0.971\n",
      +       "    Intercept  (chain, draw) float64 0.09035 0.06867 0.06867 ... -0.1322 -0.1322\n",
      +       "    x          (chain, draw) float64 0.4452 0.503 0.503 ... 0.3238 0.3238 0.3238\n",
      +       "Attributes:\n",
      +       "    created_at:                  2024-03-01T14:57:59.802971\n",
      +       "    arviz_version:               0.17.0\n",
      +       "    modeling_interface:          bambi\n",
      +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "flowmc_idata = model.fit(inference_method=\"flowmc_realnvp_hmc\")\n", + "flowmc_idata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sampler comparisons\n", + "\n", + "With ArviZ, we can compare the inference result summaries of the samplers. _Note:_ We can't use `az.compare` as not each inference data object returns the pointwise log-probabilities. Thus, an error would be raised." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
y_sigma0.9450.0700.8191.0800.0020.0021044.0667.01.0
Intercept0.0180.089-0.1560.1850.0030.002844.0733.01.0
x0.3580.1050.1630.5540.0040.003829.0767.01.0
\n", + "
" + ], + "text/plain": [ + " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", + "y_sigma 0.945 0.070 0.819 1.080 0.002 0.002 1044.0 \n", + "Intercept 0.018 0.089 -0.156 0.185 0.003 0.002 844.0 \n", + "x 0.358 0.105 0.163 0.554 0.004 0.003 829.0 \n", + "\n", + " ess_tail r_hat \n", + "y_sigma 667.0 1.0 \n", + "Intercept 733.0 1.0 \n", + "x 767.0 1.0 " + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "az.summary(blackjax_nuts_idata)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
y_sigma0.9480.0670.8241.0730.0010.0018107.05585.01.0
Intercept0.0250.095-0.1520.2000.0010.0016772.05624.01.0
x0.3610.1040.1570.5510.0010.0016682.05414.01.0
\n", + "
" + ], + "text/plain": [ + " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", + "y_sigma 0.948 0.067 0.824 1.073 0.001 0.001 8107.0 \n", + "Intercept 0.025 0.095 -0.152 0.200 0.001 0.001 6772.0 \n", + "x 0.361 0.104 0.157 0.551 0.001 0.001 6682.0 \n", + "\n", + " ess_tail r_hat \n", + "y_sigma 5585.0 1.0 \n", + "Intercept 5624.0 1.0 \n", + "x 5414.0 1.0 " + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "az.summary(tfp_nuts_idata)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
Intercept0.0220.097-0.1490.2170.0010.0017412.05758.01.0
x0.3590.1050.1590.5550.0010.0017406.05967.01.0
y_sigma0.9470.0690.8221.0790.0010.0017371.05405.01.0
\n", + "
" + ], + "text/plain": [ + " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", + "Intercept 0.022 0.097 -0.149 0.217 0.001 0.001 7412.0 \n", + "x 0.359 0.105 0.159 0.555 0.001 0.001 7406.0 \n", + "y_sigma 0.947 0.069 0.822 1.079 0.001 0.001 7371.0 \n", + "\n", + " ess_tail r_hat \n", + "Intercept 5758.0 1.0 \n", + "x 5967.0 1.0 \n", + "y_sigma 5405.0 1.0 " + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "az.summary(numpyro_nuts_idata)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
y_sigma0.9460.0670.8251.0760.0010.0016260.05213.01.00
Intercept0.0130.093-0.1650.1900.0030.002924.01302.01.02
x0.3590.1030.1660.5560.0010.0015132.05790.01.00
\n", + "
" + ], + "text/plain": [ + " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", + "y_sigma 0.946 0.067 0.825 1.076 0.001 0.001 6260.0 \n", + "Intercept 0.013 0.093 -0.165 0.190 0.003 0.002 924.0 \n", + "x 0.359 0.103 0.166 0.556 0.001 0.001 5132.0 \n", + "\n", + " ess_tail r_hat \n", + "y_sigma 5213.0 1.00 \n", + "Intercept 1302.0 1.02 \n", + "x 5790.0 1.00 " + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "az.summary(flowmc_idata)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "Thanks to `bayeux`, we can use three different sampling backends and 10+ alternative MCMC methods in Bambi. Using these methods is as simple as passing the inference name to the `inference_method` of the `fit` method." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Last updated: Fri Mar 01 2024\n", + "\n", + "Python implementation: CPython\n", + "Python version : 3.11.7\n", + "IPython version : 8.21.0\n", + "\n", + "arviz : 0.17.0\n", + "bambi : 0.13.1.dev16+g9a1387a7.d20240204\n", + "numpy : 1.26.3\n", + "pandas : 2.2.0\n", + "bayeux : 0.1.9\n", + "matplotlib: 3.8.2\n", + "\n", + "Watermark: 2.4.3\n", + "\n" + ] + } + ], + "source": [ + "%load_ext watermark\n", + "%watermark -n -u -v -iv -w" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "bayeux_bambi", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/notebooks/gallery.yml b/docs/notebooks/gallery.yml index e4bb954a8..cbce285e2 100644 --- a/docs/notebooks/gallery.yml +++ b/docs/notebooks/gallery.yml @@ -133,4 +133,10 @@ - title: Advanced interpret usage subtitle: Create data grids and compute complex quantities of interest href: interpret_advanced_usage.ipynb - thumbnail: thumbnails/advanced_interpret.png \ No newline at end of file + thumbnail: thumbnails/advanced_interpret.png +- category: Alternative sampling backends + description: "" + tiles: + - title: Using other samplers + subtitle: JAX based samplers + href: alternative_samplers.ipynb \ No newline at end of file diff --git a/docs/notebooks/getting_started.ipynb b/docs/notebooks/getting_started.ipynb index ce5954227..adf071611 100644 --- a/docs/notebooks/getting_started.ipynb +++ b/docs/notebooks/getting_started.ipynb @@ -970,35 +970,35 @@ "\n", "
\n", "\n", - "|Family name |Response distribution | Default link |\n", - "|:----------------------------- |:------------------------------- |:--------------- |\n", - "asymmetriclaplace | AsymmetricLaplace | identity |\n", - "bernoulli | Bernoulli | logit |\n", - "beta | Beta | logit |\n", - "beta_binomial | BetaBinomial | logit |\n", - "binomial | Binomial | logit | \n", - "categorical | Categorical | softmax | \n", - "cumulative | Cumulative | logit | \n", - "dirichlet_multinomial | DirichletMultinomial | logit |\n", - "exponential | Exponential | log | \n", - "gamma | Gamma | inverse |\n", - "gaussian | Normal | identity |\n", - "hurdle_gamma | HurdleGamma | log |\n", - "hurdle_lognormal | HurdleLogNormal | identity |\n", - "hurdle_negativebinomial | HurdleNegativeBinomial | log |\n", - "hurdle_poisson | HurdlePoisson | log |\n", - "multinomial | Multinomial | softmax |\n", - "negativebinomial | NegativeBinomial | log |\n", - "laplace | Laplace | identity |\n", - "poisson | Poisson | log |\n", - "sratio | StoppingRatio | logit |\n", - "t | StudentT | identity |\n", - "vonmises | VonMises | tan(x / 2) |\n", - "wald | InverseGaussian | inverse squared |\n", - "weibull | Weibull | log |\n", - "zero_inflated_binomial | ZeroInflatedBinomial | logit |\n", - "zero_inflated_negativebinomial | ZeroInflatedNegativeBinomial | log |\n", - "zero_inflated_poisson | ZeroInflatedPoisson | log |\n", + "|Family name |Response distribution | Default link | Example notebook |\n", + "|:----------------------------- |:------------------------------- |:--------------- |:-----------------|\n", + "asymmetriclaplace | AsymmetricLaplace | identity | [Quantile Regression](https://bambinos.github.io/bambi/notebooks/quantile_regression.html#quantile-regression) |\n", + "bernoulli | Bernoulli | logit | [Logistic Regression](https://bambinos.github.io/bambi/notebooks/logistic_regression.html) |\n", + "beta | Beta | logit | [Beta Regression](https://bambinos.github.io/bambi/notebooks/beta_regression.html) |\n", + "beta_binomial | BetaBinomial | logit | _To be added_ |\n", + "binomial | Binomial | logit | [Hierarchical Logistic Regression](https://bambinos.github.io/bambi/notebooks/hierarchical_binomial_bambi.html) | \n", + "categorical | Categorical | softmax | [Categorical Regression](https://bambinos.github.io/bambi/notebooks/categorical_regression.html) | \n", + "cumulative | Cumulative | logit | [Ordinal Models](https://bambinos.github.io/bambi/notebooks/ordinal_regression.html#cumulative-model) | \n", + "dirichlet_multinomial | DirichletMultinomial | logit | _To be added_ |\n", + "exponential | Exponential | log | [Survival Models](https://bambinos.github.io/bambi/notebooks/survival_model.html#survival-models) | \n", + "gamma | Gamma | inverse | [Gamma Regression](https://bambinos.github.io/bambi/notebooks/wald_gamma_glm.html) |\n", + "gaussian | Normal | identity | [Multiple Linear Regression](https://bambinos.github.io/bambi/notebooks/ESCS_multiple_regression.html) |\n", + "hurdle_gamma | HurdleGamma | log | _To be added_ |\n", + "hurdle_lognormal | HurdleLogNormal | identity | _To be added_ |\n", + "hurdle_negativebinomial | HurdleNegativeBinomial | log | _To be added_ |\n", + "hurdle_poisson | HurdlePoisson | log | [Hurdle Poisson Regression](https://bambinos.github.io/bambi/notebooks/zero_inflated_regression.html#hurdle-poisson) |\n", + "multinomial | Multinomial | softmax | _To be added_ |\n", + "negativebinomial | NegativeBinomial | log | [Negative Binomial 
Regression](https://bambinos.github.io/bambi/notebooks/negative_binomial.html) |\n", + "laplace | Laplace | identity | _To be added_ |\n", + "poisson | Poisson | log | [Gaussian Processes with a Poisson likelihood](https://bambinos.github.io/bambi/notebooks/hsgp_2d.html#a-more-complex-example-poisson-likelihood-with-group-specific-effects) |\n", + "sratio | StoppingRatio | logit | [Ordinal Models](https://bambinos.github.io/bambi/notebooks/ordinal_regression.html#sequential-model) |\n", + "t | StudentT | identity | [Robust Linear Regression](https://bambinos.github.io/bambi/notebooks/t_regression.html) |\n", + "vonmises | VonMises | tan(x / 2) | [Circular Regression](https://bambinos.github.io/bambi/notebooks/circular_regression.html#circular-regression) |\n", + "wald | InverseGaussian | inverse squared | [Wald Regression](https://bambinos.github.io/bambi/notebooks/wald_gamma_glm.html) |\n", + "weibull | Weibull | log | _To be added_ |\n", + "zero_inflated_binomial | ZeroInflatedBinomial | logit | _To be added_ |\n", + "zero_inflated_negativebinomial | ZeroInflatedNegativeBinomial | log | _To be added_ |\n", + "zero_inflated_poisson | ZeroInflatedPoisson | log | [Zero Inflated Poisson Regression](https://bambinos.github.io/bambi/notebooks/zero_inflated_regression.html#zero-inflated-poisson)|\n", "\n", "\n", "
\n", diff --git a/pyproject.toml b/pyproject.toml index 058d94b04..262482de1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires = ["setuptools>=61.0", "setuptools_scm>=8"] [project] name = "bambi" description = "BAyesian Model Building Interface in Python" -requires-python = ">=3.8" +requires-python = ">=3.10,<3.13" readme = "README.md" license = {file = "LICENSE"} dynamic = ["version"] @@ -18,24 +18,25 @@ maintainers = [ dependencies = [ "arviz>=0.12.0", - "formulae>=0.5.0", + "formulae>=0.5.3", "graphviz", "pandas>=1.0.0", - "pymc>=5.5.0", + "pymc>=5.12.0", ] [project.optional-dependencies] dev = [ - "black==22.3.0", + "black==24.3.0", "ipython>=5.8.0,!=8.7.0", "pre-commit>=2.19", - "pylint==2.17.5", + "pylint==3.1.0", "pytest-cov>=2.6.1", "pytest>=4.4.0", "quartodoc==0.6.1", "seaborn>=0.9.0", ] jax = [ + "bayeux-ml>=0.1.9", "blackjax>=1.0.0", "jax>=0.3.1", "jaxlib>=0.3.1", @@ -62,4 +63,4 @@ packages = [ [tool.black] line-length = 100 -target-version = ["py39", "py310"] \ No newline at end of file +target-version = ["py310"] \ No newline at end of file diff --git a/tests/test_alternative_samplers.py b/tests/test_alternative_samplers.py index a16134762..6222f3df3 100644 --- a/tests/test_alternative_samplers.py +++ b/tests/test_alternative_samplers.py @@ -1,10 +1,15 @@ import bambi as bmb +import bayeux as bx import numpy as np import pandas as pd import pytest +MCMC_METHODS = [getattr(bx.mcmc, k).name for k in bx.mcmc.__all__] +MCMC_METHODS_FILTERED = [i for i in MCMC_METHODS if not any(x in i for x in ("flowmc", "chees", "meads"))] + + @pytest.fixture(scope="module") def data_n100(): size = 100 @@ -51,28 +56,14 @@ def test_vi(): (mode_n.item(), std_n.item()), (mode_a.item(), std_a.item()), decimal=2 ) - -@pytest.mark.parametrize( - "args", - [ - ("mcmc", {}), - ("nuts_numpyro", {"chain_method": "vectorized"}), - ("nuts_blackjax", {"chain_method": "vectorized"}), - ], -) -def test_logistic_regression_categoric_alternative_samplers(data_n100, args): +# +@pytest.mark.parametrize("sampler", MCMC_METHODS_FILTERED) +def test_logistic_regression_categoric_alternative_samplers(data_n100, sampler): model = bmb.Model("b1 ~ n1", data_n100, family="bernoulli") - model.fit(tune=50, draws=50, inference_method=args[0], **args[1]) + model.fit(inference_method=sampler) -@pytest.mark.parametrize( - "args", - [ - ("mcmc", {}), - ("nuts_numpyro", {"chain_method": "vectorized"}), - ("nuts_blackjax", {"chain_method": "vectorized"}), - ], -) -def test_regression_alternative_samplers(data_n100, args): +@pytest.mark.parametrize("sampler", MCMC_METHODS) +def test_regression_alternative_samplers(data_n100, sampler): model = bmb.Model("n1 ~ n2", data_n100) - model.fit(tune=50, draws=50, inference_method=args[0], **args[1]) + model.fit(inference_method=sampler) diff --git a/tests/test_hsgp.py b/tests/test_hsgp.py index 30bf5ce1c..770c70cc5 100644 --- a/tests/test_hsgp.py +++ b/tests/test_hsgp.py @@ -300,3 +300,35 @@ def test_minimal_1d_predicts(data_1d_single_group): new_idata = model.predict(idata, data=new_data, kind="pps", inplace=False) assert new_idata.posterior_predictive["y"].dims == ("chain", "draw", "y_obs") assert new_idata.posterior_predictive["y"].to_numpy().shape == (2, 500, 10) + + +def test_multiple_hsgp_and_by(data_1d_multiple_groups): + rng = np.random.default_rng(1234) + df = data_1d_multiple_groups.copy() + df["fac2"] = rng.choice(["a", "b", "c"], size=df.shape[0]) + + formula = "y ~ 1 + x0 + hsgp(x1, by=fac, m=10, c=2) + hsgp(x1, by=fac2, m=10, c=2)" + model = bmb.Model( + 
formula=formula,
+        data=df,
+        categorical=["fac"],
+    )
+    idata = model.fit(tune=400, draws=200, target_accept=0.9)
+
+    bmb.interpret.plot_predictions(
+        model,
+        idata,
+        conditional="x1",
+        subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
+    )
+
+    bmb.interpret.plot_predictions(
+        model,
+        idata,
+        conditional={
+            "x1": np.linspace(0, 1, num=100),
+            "fac2": ["a", "b", "c"]
+        },
+        legend=False,
+        subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
+    )
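Taken together, the changes above boil down to a small user-facing workflow. Below is a minimal sketch, distilled from the `alternative_samplers.ipynb` notebook added in this diff; the toy data is illustrative, and it assumes the optional JAX dependencies (including `bayeux-ml`) are installed.

```python
import bambi as bmb
import numpy as np
import pandas as pd

# Toy data, mirroring the simulation used in the notebook
rng = np.random.default_rng(42)
x = rng.normal(size=100)
data = pd.DataFrame({"x": x, "y": 0.5 * x + rng.normal(size=100)})

model = bmb.Model("y ~ x", data)
model.build()

# New in this PR: a nested dict of available inference methods, e.g.
# {"pymc": {"mcmc": ["mcmc"], "vi": ["vi"]}, "bayeux": {"mcmc": [...], ...}}
print(model.backend.inference_methods)

# Any bayeux MCMC method name is now a valid `inference_method`
idata = model.fit(inference_method="blackjax_nuts")
```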
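The renamed samplers keep working through the alias block added to `PyMCModel.run`. Continuing the sketch above, a legacy spelling is mapped to the new name and emits a `FutureWarning`:

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # "nuts_numpyro" is rewritten to "numpyro_nuts" before dispatch
    model.fit(inference_method="nuts_numpyro")

assert any(issubclass(w.category, FutureWarning) for w in caught)
```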
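The `model_components.py` change is subtler: all HSGP column slices must be removed from the design matrix in a single `np.delete` call, because deleting one slice shifts the columns that the remaining slice objects point to. A self-contained toy illustration of why the fix uses `np.r_` (a throwaway array, not Bambi's actual design matrix):

```python
import numpy as np

X = np.arange(12).reshape(3, 4)
s1, s2 = slice(0, 1), slice(2, 3)  # pretend columns 0 and 2 belong to two HSGP terms

# Sequential deletion is wrong: after column 0 is removed, s2 points at
# what was originally column 3.
wrong = np.delete(np.delete(X, s1, axis=1), s2, axis=1)

# Deleting both slices at once, as the fix does, keeps the indices consistent.
right = np.delete(X, np.r_[tuple((s1, s2))], axis=1)

print(wrong[0])  # [1 2] -> dropped columns 0 and 3 (incorrect)
print(right[0])  # [1 3] -> dropped columns 0 and 2 (correct)
```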