Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add weight_predictions function #2147

Merged
merged 9 commits into from
Nov 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Change Log


## v0.x.x Unreleased

### New features
- Adds Savage-Dickey density ratio plot for Bayes factor approximation. ([2037](https://github.com/arviz-devs/arviz/pull/2037), [2152](https://github.com/arviz-devs/arviz/pull/2152))

* Add `weight_predictions` function to allow generation of weighted predictions from two or more InferenceData with `posterior_predictive` groups and a set of weights ([2147](https://github.com/arviz-devs/arviz/pull/2147))
- Add Savage-Dickey density ratio plot for Bayes factor approximation. ([2037](https://github.com/arviz-devs/arviz/pull/2037), [2152](https://github.com/arviz-devs/arviz/pull/2152))

### Maintenance and fixes
- Fix dimension ordering for `plot_trace` with divergences ([2151](https://github.com/arviz-devs/arviz/pull/2151))
Expand Down
1 change: 1 addition & 0 deletions arviz/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"r2_score",
"summary",
"waic",
"weight_predictions",
"ELPDData",
"ess",
"rhat",
Expand Down
72 changes: 71 additions & 1 deletion arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
NO_GET_ARGS = True

from .. import _log
from ..data import InferenceData, convert_to_dataset, convert_to_inference_data
from ..data import InferenceData, convert_to_dataset, convert_to_inference_data, extract
from ..rcparams import rcParams, ScaleKeyword, ICKeyword
from ..utils import Numba, _numba_var, _var_names, get_coords
from .density_utils import get_bins as _get_bins
Expand Down Expand Up @@ -49,6 +49,7 @@
"r2_score",
"summary",
"waic",
"weight_predictions",
"_calculate_ics",
]

Expand Down Expand Up @@ -2043,3 +2044,72 @@ def apply_test_function(
setattr(out, grp, out_group)

return out


def weight_predictions(idatas, weights=None):
    """
    Generate weighted posterior predictive samples from a list of InferenceData
    and a set of weights.

    Parameters
    ----------
    idatas : list[InferenceData]
        List of :class:`arviz.InferenceData` objects containing the groups `posterior_predictive`
        and `observed_data`. Observations should be the same for all InferenceData objects.
    weights : array-like, optional
        Individual weights for each model. Weights should be positive. If they do not sum up to 1,
        they will be normalized. Default, same weight for each model.
        Weights can be computed using many different methods including those in
        :func:`arviz.compare`.

    Returns
    -------
    idata : InferenceData
        Output InferenceData object with the groups `posterior_predictive` and `observed_data`.

    Raises
    ------
    ValueError
        If fewer than two InferenceData objects are given, a `posterior_predictive` group is
        missing, observed data differ between objects, the number of weights does not match the
        number of InferenceData objects, weights are invalid (negative or summing to zero), or
        any `posterior_predictive` group is empty.

    See Also
    --------
    compare : Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation
    """
    if len(idatas) < 2:
        raise ValueError("You should provide a list with at least two InferenceData objects")

    if not all("posterior_predictive" in idata.groups() for idata in idatas):
        raise ValueError(
            "All the InferenceData objects must contain the `posterior_predictive` group"
        )

    # Mixing predictions only makes sense for models conditioned on the same data.
    if not all(idatas[0].observed_data.equals(idata.observed_data) for idata in idatas[1:]):
        raise ValueError("The observed data should be the same for all InferenceData objects")

    if weights is None:
        # Equal weight for each model by default.
        weights = np.ones(len(idatas)) / len(idatas)
    elif len(idatas) != len(weights):
        raise ValueError(
            "The number of weights should be the same as the number of InferenceData objects"
        )

    weights = np.array(weights, dtype=float)
    # Reject invalid weights before normalizing: negative weights would produce
    # meaningless sample counts, and a zero sum would divide by zero below.
    if np.any(weights < 0):
        raise ValueError("Weights should be positive")
    if weights.sum() == 0:
        raise ValueError("At least one weight should be nonzero")
    # Normalize so the weights sum to 1, as documented.
    weights /= weights.sum()

    # Total number of posterior predictive samples (chains * draws) per model.
    len_idatas = [
        idata.posterior_predictive.dims["chain"] * idata.posterior_predictive.dims["draw"]
        for idata in idatas
    ]

    if not all(len_idatas):
        raise ValueError("At least one of your idatas has 0 samples")

    # Subsample each model proportionally to its weight; sizes are bounded by the
    # smallest model so no model is asked for more samples than it has.
    new_samples = (np.min(len_idatas) * weights).astype(int)

    new_idatas = [
        extract(idata, group="posterior_predictive", num_samples=samples).reset_coords()
        for samples, idata in zip(new_samples, idatas)
    ]

    weighted_samples = InferenceData(
        posterior_predictive=xr.concat(new_idatas, dim="sample"),
        observed_data=idatas[0].observed_data,
    )

    return weighted_samples
31 changes: 30 additions & 1 deletion arviz/tests/base_tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal
from numpy.testing import (
assert_allclose,
assert_array_almost_equal,
assert_almost_equal,
assert_array_equal,
)
from scipy.special import logsumexp
from scipy.stats import linregress
from xarray import DataArray, Dataset
Expand All @@ -21,6 +26,7 @@
r2_score,
summary,
waic,
weight_predictions,
_calculate_ics,
)
from ...stats.stats import _gpinv
Expand Down Expand Up @@ -800,3 +806,26 @@ def test_apply_test_function_should_overwrite_error(centered_eight):
"""Test error when overwrite=False but out_name is already a present variable."""
with pytest.raises(ValueError, match="Should overwrite"):
apply_test_function(centered_eight, lambda y, theta: y, out_name_data="obs")


def test_weight_predictions():
    """Check that weight_predictions mixes two models according to the given weights."""
    # Two models with the same observations but posterior predictive draws
    # centered at -1 and +1 respectively.
    idata_low, idata_high = (
        from_dict(
            posterior_predictive={"a": np.random.normal(loc, 1, 1000)},
            observed_data={"a": [1]},
        )
        for loc in (-1, 1)
    )

    mixed = weight_predictions([idata_low, idata_high])
    # With default (equal) weights the pooled mean lies between the two models.
    assert (
        idata_high.posterior_predictive.mean()
        > mixed.posterior_predictive.mean()
        > idata_low.posterior_predictive.mean()
    )
    assert "posterior_predictive" in mixed
    assert "observed_data" in mixed

    mixed = weight_predictions([idata_low, idata_high], weights=[0.5, 0.5])
    assert_almost_equal(mixed.posterior_predictive["a"].mean(), 0, decimal=1)
    mixed = weight_predictions([idata_low, idata_high], weights=[0.9, 0.1])
    assert_almost_equal(mixed.posterior_predictive["a"].mean(), -0.8, decimal=1)