feat: deviance and Pearson residuals of LOO version (#103)
* feat: deviance and Pearson residuals of LOO version

* refactor: reduce MCMC steps to save memory
wcxve authored Aug 27, 2024
1 parent 2a4e3da commit 936f3f3
Showing 3 changed files with 124 additions and 56 deletions.
26 changes: 13 additions & 13 deletions src/elisa/infer/fit.py
@@ -524,7 +524,7 @@ class BayesFit(Fit):
     def nuts(
         self,
         warmup=2000,
-        samples=20000,
+        steps=5000,
         chains: int | None = None,
         init: dict[str, float] | None = None,
         chain_method: str = 'parallel',
@@ -541,8 +541,8 @@ def nuts(
         ----------
         warmup : int, optional
             Number of warmup steps.
-        samples : int, optional
-            Number of samples to generate from each chain.
+        steps : int, optional
+            Number of steps to run for each chain.
         chains : int, optional
             Number of MCMC chains to run. If there are not enough devices
             available, chains will run in sequence. Defaults to the number of
@@ -573,11 +573,11 @@
         else:
             chains = int(chains)
 
-        samples = int(samples)
+        steps = int(steps)
 
         # the total samples number should be multiple of the device number
-        if chains * samples % device_count != 0:
-            samples += device_count - samples % device_count
+        if chains * steps % device_count != 0:
+            steps += device_count - steps % device_count
 
         # TODO: option to let sampler starting from MLE
         if init is None:
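For context, the padding above rounds `steps` up to the next multiple of the device count; once `steps` divides evenly by `device_count`, so does `chains * steps`, and the total draws can be split across devices. A minimal sketch with hypothetical numbers:

```python
device_count = 4
chains, steps = 2, 1001
if chains * steps % device_count != 0:  # 2002 % 4 == 2, so pad
    steps += device_count - steps % device_count  # 1001 + (4 - 1)
print(steps)  # 1004; chains * steps == 2008, a multiple of 4
```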
@@ -597,7 +597,7 @@
         sampler = MCMC(
             NUTS(**nuts_kwargs),
             num_warmup=warmup,
-            num_samples=samples,
+            num_samples=steps,
             num_chains=chains,
             chain_method=chain_method,
             progress_bar=progress,
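A hypothetical call against the renamed keyword, assuming `fit` is an already-configured `BayesFit` instance (the return type is an assumption here, not confirmed by this diff):

```python
# `steps` replaces the old `samples` keyword; defaults shown explicitly
result = fit.nuts(warmup=2000, steps=5000, chains=4)
```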
@@ -704,7 +704,7 @@ def jaxns(
 
     def ultranest(
         self,
-        ess: int = 10000,
+        ess: int = 3000,
         ignore_nan: bool = True,
         *,
         constructor_kwargs: dict | None = None,
@@ -814,7 +814,7 @@ def transform_(samples):
 
     def nautilus(
         self,
-        ess: int = 10000,
+        ess: int = 3000,
         ignore_nan: bool = True,
         parallel: bool = True,
         n_batch: int = 5000,
@@ -923,7 +923,7 @@ def transform_(samples):
     def aies(
         self,
         warmup=2000,
-        samples=20000,
+        steps=5000,
         chains: int | None = None,
         init: dict[str, float] | None = None,
         chain_method: str = 'vectorized',
@@ -947,8 +947,8 @@ def aies(
         ----------
         warmup : int, optional
             Number of warmup steps.
-        samples : int, optional
-            Number of samples to generate from each chain.
+        steps : int, optional
+            Number of steps to run for each chain.
         chains : int, optional
             Number of MCMC chains to run. Defaults to 4 * `D`, where `D` is
             the dimension of model parameters.
@@ -1005,7 +1005,7 @@ def aies(
         sampler = MCMC(
             AIES(**aies_kwargs),
             num_warmup=warmup,
-            num_samples=samples,
+            num_samples=steps,
             num_chains=chains,
             chain_method=chain_method,
             progress_bar=progress,
78 changes: 56 additions & 22 deletions src/elisa/infer/results.py
@@ -37,6 +37,7 @@
     from astropy.cosmology.flrw.lambdacdm import LambdaCDM
     from astropy.units import Quantity as Q
     from iminuit.util import FMin
+    from xarray import DataArray
 
     from elisa.infer.fit import BayesFit
     from elisa.infer.helper import Helper
@@ -868,6 +869,7 @@ class PosteriorResult(FitResult):
     _deviance: dict | None = None
     _mle_result: dict | None = None
     _ppc: PPCResult | None = None
+    _psislw_: DataArray | None = None
     _loo: az.stats.stats_utils.ELPDData | None = None
     _waic: az.stats.stats_utils.ELPDData | None = None
     _rhat: dict[str, float] | None = None
@@ -942,10 +944,10 @@ def _tabs(self):
         rows = [
             [
                 k,
-                f'{mean[k]:.4g}',
-                f'{std[k]:.4g}',
-                f'{median[k]:.4g}',
-                f'[{ci[k][0]:.4g}, {ci[k][1]:.4g}]',
+                f'{mean[k]:.3g}',
+                f'{std[k]:.3g}',
+                f'{median[k]:.3g}',
+                f'[{ci[k][0]:.3g}, {ci[k][1]:.3g}]',
                 f'{ess[k]}',
                 f'{rhat[k]:.2f}' if not np.isnan(rhat[k]) else 'N/A',
             ]
@@ -954,9 +956,9 @@ def _tabs(self):
         names = [
             'Parameter',
             'Mean',
-            'Std',
+            'StdDev',
             'Median',
-            '68.3% Quantile CI',
+            '68.3% Quantile',
             'ESS',
             'Rhat',
         ]
@@ -1762,16 +1764,10 @@ def rhat(self) -> dict[str, float]:
         if self._rhat is None:
             params_names = self._helper.params_names['all']
             posterior = self.idata['posterior'][params_names]
-
-            if len(posterior['chain']) == 1:
-                rhat = {k: float('nan') for k in posterior.data_vars.keys()}
-            else:
-                rhat = {
-                    k: float(v.values)
-                    for k, v in az.rhat(posterior).data_vars.items()
-                }
-
-            self._rhat = rhat
+            self._rhat = {
+                k: float(v.values)
+                for k, v in az.rhat(posterior).data_vars.items()
+            }
 
         return self._rhat
 
@@ -1838,24 +1834,62 @@ def lnZ(self) -> tuple[float, float] | tuple[None, None]:
         """Log model evidence and uncertainty."""
         return self._lnZ
 
+    @property
+    def _psislw(self) -> DataArray:
+        if self._psislw_ is None:
+            idata = self.idata
+            reff = self.reff
+            stack_kwargs = {'__sample__': ('chain', 'draw')}
+            log_weights, kss = az.psislw(
+                -idata['log_likelihood']['channels'].stack(**stack_kwargs),
+                reff,
+            )
+            self._psislw_ = log_weights
+        return self._psislw_
+
+    def _loo_expectation(self, values: DataArray, data: str) -> DataArray:
+        """Computes weighted expectations using the PSIS weights.
+
+        Notes
+        -----
+        The expectations estimated assume that the PSIS approximation is
+        working well. A small Pareto k estimate is necessary, but not
+        sufficient to give reliable estimates.
+
+        Parameters
+        ----------
+        values : DataArray
+            Values to compute the expectation.
+        data : str
+            The data name.
+
+        Returns
+        -------
+        DataArray
+            The expectation of the values.
+        """
+        assert data in self._helper.data_names
+        channel = self._helper.channels[f'{data}_channel']
+        log_weights = self._psislw.sel(channel=channel)
+        log_weights = log_weights.rename({'channel': f'{data}_channel'})
+        log_expectation = log_weights + np.log(np.abs(values))
+        weighted = np.sign(values) * np.exp(log_expectation)
+        return weighted.sum(dim='__sample__')
+
     @property
     def _loo_pit(self) -> dict[str, tuple]:
         """Leave-one-out probability integral transform."""
         if self._pit is not None:
             return self._pit
 
         idata = self.idata
-        reff = self.reff
         helper = self._helper
         stack_kwargs = {'__sample__': ('chain', 'draw')}
-        log_weights, kss = az.psislw(
-            -idata['log_likelihood']['channels'].stack(**stack_kwargs), reff
-        )
         y_hat = idata['posterior_predictive']['channels'].stack(**stack_kwargs)
         loo_pit = az.loo_pit(
             y=idata['observed_data']['channels'],
             y_hat=y_hat,
-            log_weights=log_weights,
+            log_weights=self._psislw,
         )
 
         loo_pit = {
@@ -1892,7 +1926,7 @@ def _loo_pit(self) -> dict[str, tuple]:
             loo_pit_minus = az.loo_pit(
                 y=y_miuns,
                 y_hat=y_hat,
-                log_weights=log_weights,
+                log_weights=self._psislw,
             )
             loo_pit_minus = {
                 name: loo_pit_minus.sel(channel=data.channel).values
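For readers unfamiliar with the machinery behind `_psislw` and `_loo_expectation`, here is a self-contained sketch of the idea on synthetic single-channel Poisson data. The sign/log-space expectation mirrors the new `_loo_expectation`; the deviance and Pearson-residual conventions at the end are illustrative assumptions, not necessarily the exact definitions elisa adopts:

```python
import arviz as az
import numpy as np
from scipy.special import logsumexp
from scipy.stats import poisson

rng = np.random.default_rng(0)

# synthetic posterior draws of a Poisson rate for one spectral channel
lam = rng.gamma(5.0, 1.0, size=2000)  # draws of the model counts mu
y_obs = 4                             # observed counts in that channel

# pointwise log-likelihood of the observation under each draw
log_lik = poisson.logpmf(y_obs, lam)

# PSIS-smoothed LOO weights: w_s is proportional to 1 / p(y_i | theta_s);
# az.psislw returns normalized log weights and the Pareto k diagnostic
# (k < 0.7 is commonly taken as reliable)
log_weights, k = az.psislw(-log_lik)

# weighted expectation in sign/log space, as in _loo_expectation, which
# keeps the sum stable when values span many orders of magnitude
def loo_expectation(values, log_weights):
    return np.sum(np.sign(values) * np.exp(log_weights + np.log(np.abs(values))))

mu_loo = loo_expectation(lam, log_weights)

# LOO predictive density and this channel's contribution to a LOO deviance
elpd_i = logsumexp(log_weights + log_lik)  # log p(y_i | y_{-i})
deviance_i = -2.0 * elpd_i

# Pearson residual of the LOO version (Poisson variance equals the mean)
pearson_i = (y_obs - mu_loo) / np.sqrt(mu_loo)
```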