|
1 | 1 | import logging
|
2 | 2 | import traceback
|
| 3 | +from copy import deepcopy |
3 | 4 |
|
4 | 5 | import numpy as np
|
5 | 6 | import pymc as pm
|
@@ -107,7 +108,7 @@ def run(
|
107 | 108 | elif inference_method == "vi":
|
108 | 109 | result = self._run_vi(**kwargs)
|
109 | 110 | elif inference_method == "laplace":
|
110 |
| - result = self._run_laplace() |
| 111 | + result = self._run_laplace(draws) |
111 | 112 | else:
|
112 | 113 | raise NotImplementedError(f"{inference_method} method has not been implemented")
|
113 | 114 |
|
@@ -437,44 +438,82 @@ def _run_vi(self, **kwargs):
|
437 | 438 | self.vi_approx = pm.fit(**kwargs)
|
438 | 439 | return self.vi_approx
|
439 | 440 |
|
440 |
| - def _run_laplace(self): |
| 441 | + def _run_laplace(self, draws): |
441 | 442 | """Fit a model using a Laplace approximation.
|
442 | 443 |
|
443 |
| - Mainly for pedagogical use. ``mcmc`` and ``vi`` are better approximations. |
| 444 | + Mainly for pedagogical use, provides reasonable results for approximately |
| 445 | + Gaussian posteriors. The approximation can be very poor for some models |
| 446 | + like hierarchical ones. Use ``mcmc``, ``nuts_numpyro``, ``nuts_blackjax`` |
| 447 | + or ``vi`` for better approximations. |
444 | 448 |
|
445 | 449 | Parameters
|
446 | 450 | ----------
|
447 | 451 | model: PyMC model
|
| 452 | + draws: int |
| 453 | + The number of samples to draw from the posterior distribution. |
448 | 454 |
|
449 | 455 | Returns
|
450 | 456 | -------
|
451 |
| - Dictionary, the keys are the names of the variables and the values tuples of modes and |
452 |
| - standard deviations. |
| 457 | + An ArviZ's InferenceData object. |
453 | 458 | """
|
454 |
| - unobserved_rvs = self.model.unobserved_RVs |
455 |
| - test_point = self.model.initial_point(seed=None) |
456 | 459 | with self.model:
|
457 |
| - varis = [v for v in unobserved_rvs if not pm.util.is_transformed_name(v.name)] |
458 |
| - maps = pm.find_MAP(start=test_point, vars=varis) |
459 |
| - # Remove transform from the value variable associated with varis |
460 |
| - for var in varis: |
461 |
| - v_value = self.model.rvs_to_values[var] |
462 |
| - v_value.tag.transform = None |
463 |
| - hessian = pm.find_hessian(maps, vars=varis) |
464 |
| - if np.linalg.det(hessian) == 0: |
465 |
| - raise np.linalg.LinAlgError("Singular matrix. Use mcmc or vi method") |
466 |
| - stds = np.diag(np.linalg.inv(hessian) ** 0.5) |
467 |
| - maps = [v for (k, v) in maps.items() if not pm.util.is_transformed_name(k)] |
468 |
| - modes = [v.item() if v.size == 1 else v for v in maps] |
469 |
| - names = [v.name for v in varis] |
470 |
| - shapes = [np.atleast_1d(mode).shape for mode in modes] |
471 |
| - stds_reshaped = [] |
472 |
| - idx0 = 0 |
473 |
| - for shape in shapes: |
474 |
| - idx1 = idx0 + sum(shape) |
475 |
| - stds_reshaped.append(np.reshape(stds[idx0:idx1], shape)) |
476 |
| - idx0 = idx1 |
477 |
| - return dict(zip(names, zip(modes, stds_reshaped))) |
| 460 | + maps = pm.find_MAP() |
| 461 | + n_maps = deepcopy(maps) |
| 462 | + for m in maps: |
| 463 | + if pm.util.is_transformed_name(m): |
| 464 | + n_maps.pop(pm.util.get_untransformed_name(m)) |
| 465 | + |
| 466 | + hessian = pm.find_hessian(n_maps) |
| 467 | + |
| 468 | + if np.linalg.det(hessian) == 0: |
| 469 | + raise np.linalg.LinAlgError("Singular matrix. Use mcmc or vi method") |
| 470 | + |
| 471 | + cov = np.linalg.inv(hessian) |
| 472 | + modes = np.concatenate([np.atleast_1d(v) for v in n_maps.values()]) |
| 473 | + |
| 474 | + samples = np.random.multivariate_normal(modes, cov, size=draws) |
| 475 | + |
| 476 | + return _posterior_samples_to_idata(samples, self.model) |
| 477 | + |
| 478 | + |
def _posterior_samples_to_idata(samples, model):
    """Create InferenceData from samples.

    Each row of ``samples`` is one flattened draw over all value variables;
    it is split and reshaped back into per-variable arrays before recording.

    Parameters
    ----------
    samples: array
        Posterior samples
    model: PyMC model

    Returns
    -------
    An ArviZ's InferenceData object.
    """
    point = model.initial_point(seed=None)
    # Shape and flat size of every value variable, used to unflatten each draw.
    layout = {name: (value.shape, value.size) for name, value in point.items()}
    names = [variable.name for variable in model.value_vars]
    n_samples = len(samples)

    with model:
        trace = pm.backends.ndarray.NDArray(name=model.name)  # pylint:disable=no-member
        trace.setup(n_samples, 0)
        for draw in samples:
            offset = 0
            values = []
            for name in names:
                shape, flat_size = layout[name]
                values.append(draw[offset : offset + flat_size].reshape(shape))
                offset += flat_size
            trace.record(point=dict(zip(names, values)))

    return pm.to_inference_data(pm.backends.base.MultiTrace([trace]), model=model)
478 | 517 |
|
479 | 518 |
|
480 | 519 | def add_lkj(backend, terms, eta=1):
|
|
0 commit comments