Commit 2287de4

Improve Laplace computation and integration (#555)
* improve laplace computation and integration
* rename _posterior_to_idata and add docstring

1 parent d84080b

3 files changed, +72 -33 lines

bambi/backend/pymc.py (+67, -28)

@@ -1,5 +1,6 @@
 import logging
 import traceback
+from copy import deepcopy
 
 import numpy as np
 import pymc as pm
@@ -107,7 +108,7 @@ def run(
         elif inference_method == "vi":
             result = self._run_vi(**kwargs)
         elif inference_method == "laplace":
-            result = self._run_laplace()
+            result = self._run_laplace(draws)
         else:
             raise NotImplementedError(f"{inference_method} method has not been implemented")
 
@@ -437,44 +438,82 @@ def _run_vi(self, **kwargs):
         self.vi_approx = pm.fit(**kwargs)
         return self.vi_approx
 
-    def _run_laplace(self):
+    def _run_laplace(self, draws):
         """Fit a model using a Laplace approximation.
 
-        Mainly for pedagogical use. ``mcmc`` and ``vi`` are better approximations.
+        Mainly for pedagogical use, provides reasonable results for approximately
+        Gaussian posteriors. The approximation can be very poor for some models
+        like hierarchical ones. Use ``mcmc``, ``nuts_numpyro``, ``nuts_blackjax``
+        or ``vi`` for better approximations.
 
         Parameters
         ----------
         model: PyMC model
+        draws: int
+            The number of samples to draw from the posterior distribution.
 
         Returns
         -------
-        Dictionary, the keys are the names of the variables and the values tuples of modes and
-        standard deviations.
+        An ArviZ's InferenceData object.
         """
-        unobserved_rvs = self.model.unobserved_RVs
-        test_point = self.model.initial_point(seed=None)
         with self.model:
-            varis = [v for v in unobserved_rvs if not pm.util.is_transformed_name(v.name)]
-            maps = pm.find_MAP(start=test_point, vars=varis)
-            # Remove transform from the value variable associated with varis
-            for var in varis:
-                v_value = self.model.rvs_to_values[var]
-                v_value.tag.transform = None
-            hessian = pm.find_hessian(maps, vars=varis)
-            if np.linalg.det(hessian) == 0:
-                raise np.linalg.LinAlgError("Singular matrix. Use mcmc or vi method")
-            stds = np.diag(np.linalg.inv(hessian) ** 0.5)
-            maps = [v for (k, v) in maps.items() if not pm.util.is_transformed_name(k)]
-            modes = [v.item() if v.size == 1 else v for v in maps]
-            names = [v.name for v in varis]
-            shapes = [np.atleast_1d(mode).shape for mode in modes]
-            stds_reshaped = []
-            idx0 = 0
-            for shape in shapes:
-                idx1 = idx0 + sum(shape)
-                stds_reshaped.append(np.reshape(stds[idx0:idx1], shape))
-                idx0 = idx1
-        return dict(zip(names, zip(modes, stds_reshaped)))
+            maps = pm.find_MAP()
+            n_maps = deepcopy(maps)
+            for m in maps:
+                if pm.util.is_transformed_name(m):
+                    n_maps.pop(pm.util.get_untransformed_name(m))
+
+            hessian = pm.find_hessian(n_maps)
+
+            if np.linalg.det(hessian) == 0:
+                raise np.linalg.LinAlgError("Singular matrix. Use mcmc or vi method")
+
+            cov = np.linalg.inv(hessian)
+            modes = np.concatenate([np.atleast_1d(v) for v in n_maps.values()])
+
+        samples = np.random.multivariate_normal(modes, cov, size=draws)
+
+        return _posterior_samples_to_idata(samples, self.model)
+
+
+def _posterior_samples_to_idata(samples, model):
+    """Create InferenceData from samples.
+
+    Parameters
+    ----------
+    samples: array
+        Posterior samples
+    model: PyMC model
+
+    Returns
+    -------
+    An ArviZ's InferenceData object.
+    """
+    initial_point = model.initial_point(seed=None)
+    variables = model.value_vars
+
+    var_info = {}
+    for name, value in initial_point.items():
+        var_info[name] = (value.shape, value.size)
+
+    length_pos = len(samples)
+    varnames = [v.name for v in variables]
+
+    with model:
+        strace = pm.backends.ndarray.NDArray(name=model.name)  # pylint:disable=no-member
+        strace.setup(length_pos, 0)
+        for i in range(length_pos):
+            value = []
+            size = 0
+            for varname in varnames:
+                shape, new_size = var_info[varname]
+                var_samples = samples[i][size : size + new_size]
+                value.append(var_samples.reshape(shape))
+                size += new_size
+            strace.record(point=dict(zip(varnames, value)))
+
+    idata = pm.to_inference_data(pm.backends.base.MultiTrace([strace]), model=model)
+    return idata
 
 
 def add_lkj(backend, terms, eta=1):
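
Note: the new _run_laplace follows the standard Laplace recipe: find the posterior mode with pm.find_MAP, use the inverse Hessian of the negative log posterior at the mode as the covariance, draw from the resulting multivariate normal, and let _posterior_samples_to_idata unflatten each draw back into named variables for the InferenceData object. Below is a minimal standalone sketch of that recipe on a toy linear regression with flat priors and known noise sd; the data and variable names are illustrative and not part of this commit.

import numpy as np

rng = np.random.default_rng(1234)

# Toy regression with flat priors on (intercept, slope) and known noise sd (illustrative only).
sigma = 1.0
X = np.column_stack([np.ones(200), rng.normal(size=200)])
y = X @ np.array([1.0, 0.5]) + rng.normal(scale=sigma, size=200)

# Posterior mode: with flat priors this is just the least-squares solution.
mode = np.linalg.lstsq(X, y, rcond=None)[0]

# Hessian of the negative log posterior at the mode; for this model it is X.T @ X / sigma**2.
hessian = X.T @ X / sigma**2
cov = np.linalg.inv(hessian)

# Draws from the Gaussian (Laplace) approximation of the posterior.
samples = rng.multivariate_normal(mode, cov, size=1000)
print(mode, samples.mean(axis=0), samples.std(axis=0))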

bambi/models.py (+3, -3)

@@ -248,9 +248,9 @@ def fit(
 
         Returns
         -------
-        An ArviZ ``InferenceData`` instance if inference_method ``"mcmc"`` (default),
-        "nuts_numpyro" or "nuts_blackjax".
-        An ``Approximation`` object if ``"vi"`` and a dictionary if ``"laplace"``.
+        An ArviZ ``InferenceData`` instance if inference_method is ``"mcmc"`` (default),
+        "nuts_numpyro", "nuts_blackjax" or "laplace".
+        An ``Approximation`` object if ``"vi"``.
         """
         method = kwargs.pop("method", None)
         if method is not None:
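
A hedged usage sketch of what the updated docstring now promises: with this change, inference_method="laplace" returns a regular InferenceData object, so the usual ArviZ workflow applies. The dataset and formula below are made up for illustration and assume a Bambi version that includes this commit.

import arviz as az
import bambi as bmb
import numpy as np
import pandas as pd

# Illustrative data, not from the commit.
rng = np.random.default_rng(0)
x = rng.normal(size=200)
data = pd.DataFrame({"x": x, "y": 2 + 0.5 * x + rng.normal(size=200)})

model = bmb.Model("y ~ x", data)
idata = model.fit(inference_method="laplace", draws=1000)  # now an ArviZ InferenceData object
print(az.summary(idata))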

bambi/tests/test_built_models.py (+2, -2)

@@ -423,8 +423,8 @@ def test_laplace():
     priors = {"Intercept": Prior("Uniform", lower=0, upper=1)}
     model = Model("w ~ 1", data=data, family="bernoulli", priors=priors, link="identity")
     results = model.fit(inference_method="laplace")
-    mode_n = np.round(results["Intercept"][0], 2)
-    std_n = np.round(results["Intercept"][1][0], 2)
+    mode_n = results.posterior["Intercept"].mean().item()
+    std_n = results.posterior["Intercept"].std().item()
     mode_a = data.mean()
     std_a = data.std() / len(data) ** 0.5
     np.testing.assert_array_almost_equal((mode_n, std_n), (mode_a.item(), std_a.item()), decimal=2)
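
Why the analytic reference values in test_laplace hold: for 0/1 data modeled as Bernoulli with an identity link and a flat prior on the intercept, the posterior mode is the sample proportion, and the curvature-based (Laplace) standard deviation is sqrt(p_hat * (1 - p_hat) / n), which equals data.std() / sqrt(n) for binary data. A small illustrative check (the generated data is a stand-in, not the test fixture):

import numpy as np

data = np.random.default_rng(1234).binomial(1, 0.7, size=1000)  # stand-in for the test data
p_hat = data.mean()
laplace_sd = np.sqrt(p_hat * (1 - p_hat) / len(data))
assert np.isclose(laplace_sd, data.std() / len(data) ** 0.5)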
