Skip to content

Commit b5aefcf

Browse files
Convenient function to access inference methods and kwargs (#795)
* add inference_methods class to obtain names of methods and kwargs * re-run notebook * update notebook to include new methods * convienent methods for getting inference names and kwargs * Fix `get_model_covariates()` utility function (#801) * Support PyMC 5.13 and fix bayeux related issues (#803) * run black to fix formatting * add test to check for inference method names * test get_kwargs method of InferenceMethods class --------- Co-authored-by: Tomás Capretto <tomicapretto@gmail.com>
1 parent 793be6a commit b5aefcf

File tree

7 files changed

+1286
-996
lines changed

7 files changed

+1286
-996
lines changed

bambi/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from pymc import math
66

7-
from .backend import PyMCModel
7+
from .backend import inference_methods, PyMCModel
88
from .config import config
99
from .data import clear_data_home, load_data
1010
from .families import Family, Likelihood, Link
@@ -25,6 +25,7 @@
2525
"Formula",
2626
"clear_data_home",
2727
"config",
28+
"inference_methods",
2829
"load_data",
2930
"math",
3031
]

bambi/backend/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .pymc import PyMCModel
2+
from .inference_methods import inference_methods
23

3-
__all__ = ["PyMCModel"]
4+
__all__ = ["inference_methods", "PyMCModel"]

bambi/backend/inference_methods.py

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import importlib
2+
import inspect
3+
import operator
4+
5+
import pymc as pm
6+
7+
8+
class InferenceMethods:
9+
"""Obtain a dictionary of available inference methods for Bambi
10+
models and or the default kwargs of each inference method.
11+
"""
12+
13+
def __init__(self):
14+
# In order to access inference methods, a bayeux model must be initialized
15+
self.bayeux_model = bayeux_model()
16+
self.bayeux_methods = self._get_bayeux_methods(bayeux_model())
17+
self.pymc_methods = self._pymc_methods()
18+
19+
def _get_bayeux_methods(self, model):
20+
# Bambi only supports bayeux MCMC methods
21+
mcmc_methods = model.methods.get("mcmc")
22+
return {"mcmc": mcmc_methods}
23+
24+
def _pymc_methods(self):
25+
return {"mcmc": ["mcmc"], "vi": ["vi"]}
26+
27+
def _remove_parameters(self, fn_signature_dict):
28+
# Remove 'pm.sample' parameters that are irrelevant for Bambi users
29+
params_to_remove = [
30+
"progressbar",
31+
"progressbar_theme",
32+
"var_names",
33+
"nuts_sampler",
34+
"return_inferencedata",
35+
"idata_kwargs",
36+
"callback",
37+
"mp_ctx",
38+
"model",
39+
]
40+
return {k: v for k, v in fn_signature_dict.items() if k not in params_to_remove}
41+
42+
def get_kwargs(self, method):
43+
"""Get the default kwargs for a given inference method.
44+
45+
Parameters
46+
----------
47+
method : str
48+
The name of the inference method.
49+
50+
Returns
51+
-------
52+
dict
53+
The default kwargs for the inference method.
54+
"""
55+
if method in self.bayeux_methods.get("mcmc"):
56+
bx_method = operator.attrgetter(method)(
57+
self.bayeux_model.mcmc # pylint: disable=no-member
58+
)
59+
return bx_method.get_kwargs()
60+
elif method in self.pymc_methods.get("mcmc"):
61+
return self._remove_parameters(get_default_signature(pm.sample))
62+
elif method in self.pymc_methods.get("vi"):
63+
return get_default_signature(pm.ADVI.fit)
64+
else:
65+
raise ValueError(
66+
f"Inference method '{method}' not found in the list of available"
67+
" methods. Use `bmb.inference_methods.names` to list the available methods."
68+
)
69+
70+
@property
71+
def names(self):
72+
return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods}
73+
74+
75+
def bayeux_model():
76+
"""Dummy bayeux model for obtaining inference methods.
77+
78+
A dummy model is needed because algorithms are dynamically determined at
79+
runtime, based on the libraries that are installed. A model can give
80+
programmatic access to the available algorithms via the `methods` attribute.
81+
82+
Returns
83+
-------
84+
bayeux.Model
85+
A dummy model with a simple quadratic likelihood function.
86+
"""
87+
if importlib.util.find_spec("bayeux") is None:
88+
return {"mcmc": []}
89+
90+
import bayeux as bx # pylint: disable=import-outside-toplevel
91+
92+
return bx.Model(lambda x: -(x**2), 0.0)
93+
94+
95+
def get_default_signature(fn):
96+
"""Get the default parameter values of a function.
97+
98+
This function inspects the signature of the provided function and returns
99+
a dictionary containing the default values of its parameters.
100+
101+
Parameters
102+
----------
103+
fn : callable
104+
The function for which default argument values are to be retrieved.
105+
106+
Returns
107+
-------
108+
dict
109+
A dictionary mapping argument names to their default values.
110+
111+
"""
112+
defaults = {}
113+
for key, val in inspect.signature(fn).parameters.items():
114+
if val.default is not inspect.Signature.empty:
115+
defaults[key] = val.default
116+
return defaults
117+
118+
119+
inference_methods = InferenceMethods()

bambi/backend/pymc.py

+4-28
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import functools
2-
import importlib
32
import logging
43
import operator
54
import traceback
@@ -14,6 +13,7 @@
1413
import pytensor.tensor as pt
1514
from pytensor.tensor.special import softmax
1615

16+
from bambi.backend.inference_methods import inference_methods
1717
from bambi.backend.links import cloglog, identity, inverse_squared, logit, probit, arctan_2
1818
from bambi.backend.model_components import ConstantComponent, DistributionalComponent
1919
from bambi.utils import get_aliased_name
@@ -47,8 +47,8 @@ def __init__(self):
4747
self.model = None
4848
self.spec = None
4949
self.components = {}
50-
self.bayeux_methods = _get_bayeux_methods()
51-
self.pymc_methods = {"mcmc": ["mcmc"], "vi": ["vi"]}
50+
self.bayeux_methods = inference_methods.names["bayeux"]
51+
self.pymc_methods = inference_methods.names["pymc"]
5252

5353
def build(self, spec):
5454
"""Compile the PyMC model from an abstract model specification.
@@ -348,8 +348,7 @@ def _run_laplace(self, draws, omit_offsets, include_mean):
348348
349349
Mainly for pedagogical use, provides reasonable results for approximately
350350
Gaussian posteriors. The approximation can be very poor for some models
351-
like hierarchical ones. Use ``mcmc``, ``vi``, or JAX based MCMC methods
352-
for better approximations.
351+
like hierarchical ones. Use MCMC or VI methods for better approximations.
353352
354353
Parameters
355354
----------
@@ -398,10 +397,6 @@ def constant_components(self):
398397
def distributional_components(self):
399398
return {k: v for k, v in self.components.items() if isinstance(v, DistributionalComponent)}
400399

401-
@property
402-
def inference_methods(self):
403-
return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods}
404-
405400

406401
def _posterior_samples_to_idata(samples, model):
407402
"""Create InferenceData from samples.
@@ -441,22 +436,3 @@ def _posterior_samples_to_idata(samples, model):
441436

442437
idata = pm.to_inference_data(pm.backends.base.MultiTrace([strace]), model=model)
443438
return idata
444-
445-
446-
def _get_bayeux_methods():
447-
"""Gets a dictionary of usable bayeux methods if the bayeux package is installed
448-
within the user's environment.
449-
450-
Returns
451-
-------
452-
dict
453-
A dict where the keys are the module names and the values are the methods
454-
available in that module.
455-
"""
456-
if importlib.util.find_spec("bayeux") is None:
457-
return {"mcmc": []}
458-
459-
import bayeux as bx # pylint: disable=import-outside-toplevel
460-
461-
# Dummy log density to get access to all methods
462-
return bx.Model(lambda x: -(x**2), 0.0).methods

bambi/models.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def fit(
267267
Finally, ``"laplace"``, in which case a Laplace approximation is used and is not
268268
recommended other than for pedagogical use.
269269
To get a list of JAX based inference methods, call
270-
``model.backend.inference_methods['bayeux']``. This will return a dictionary of the
270+
``bmb.inference_methods.names['bayeux']``. This will return a dictionary of the
271271
available methods such as ``blackjax_nuts``, ``numpyro_nuts``, among others.
272272
init : str
273273
Initialization method. Defaults to ``"auto"``. The available methods are:
@@ -307,7 +307,7 @@ def fit(
307307
-------
308308
An ArviZ ``InferenceData`` instance if inference_method is ``"mcmc"`` (default),
309309
"laplace", or one of the MCMC methods in
310-
``model.backend.inference_methods['bayeux']['mcmc]``.
310+
``bmb.inference_methods.names['bayeux']['mcmc]``.
311311
An ``Approximation`` object if ``"vi"``.
312312
"""
313313
method = kwargs.pop("method", None)

0 commit comments

Comments
 (0)