Skip to content

Commit

Permalink
Add plot_ppc_dist (#138)
Browse files Browse the repository at this point in the history
* add plot_ppc_dist

* remove comments

* add test and small fixes

* fix typo
  • Loading branch information
aloctavodia authored Feb 19, 2025
1 parent 4c76c7d commit 9ef8965
Show file tree
Hide file tree
Showing 5 changed files with 335 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/source/api/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ A complementary introduction and guide to ``plot_...`` functions is available at
plot_ess_evolution
plot_forest
plot_pava_calibration
plot_ppc_dist
plot_psense_dist
plot_psense_quantities
plot_ridge
Expand Down
25 changes: 25 additions & 0 deletions docs/source/gallery/model_criticism/plot_ppc_dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
# Posterior Predictive Checks
Plot of samples from the posterior predictive and observed data.
---
:::{seealso}
API Documentation: {func}`~arviz_plots.plot_ppc_dist`
:::
"""
from arviz_base import load_arviz_data

import arviz_plots as azp

# Apply the "arviz-variat" style before any figure is created.
azp.style.use("arviz-variat")

# Load the "rugby" example dataset.
rugby = load_arviz_data("rugby")

# map color to variable
color_per_variable = {"aes": {"color": ["__variable__"]}}
# also map color to title
titles_follow_color = {"title": ["color"]}

pc = azp.plot_ppc_dist(
    rugby,
    pc_kwargs=color_per_variable,
    aes_map=titles_follow_color,
    backend="none",
)
pc.show()
2 changes: 2 additions & 0 deletions src/arviz_plots/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .evolutionplot import plot_ess_evolution
from .forestplot import plot_forest
from .pavacalibrationplot import plot_pava_calibration
from .ppcdistplot import plot_ppc_dist
from .psensedistplot import plot_psense_dist
from .psensequantitiesplot import plot_psense_quantities
from .ridgeplot import plot_ridge
Expand All @@ -24,6 +25,7 @@
"plot_energy",
"plot_ess",
"plot_ess_evolution",
"plot_ppc_dist",
"plot_ridge",
"plot_pava_calibration",
"plot_psense_dist",
Expand Down
282 changes: 282 additions & 0 deletions src/arviz_plots/plots/ppcdistplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
"""Posterior/prior predictive check using densities."""
import warnings
from copy import copy
from importlib import import_module

import numpy as np
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller

from arviz_plots.plot_collection import PlotCollection, process_facet_dims
from arviz_plots.plots.distplot import plot_dist
from arviz_plots.plots.utils import filter_aes, process_group_variables_coords
from arviz_plots.visuals import ecdf_line, hist, line_xy


def plot_ppc_dist(
    dt,
    var_names=None,
    filter_vars=None,
    group="posterior_predictive",
    coords=None,
    sample_dims=None,
    kind=None,
    num_samples=50,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_map=None,
    plot_kwargs=None,
    stats_kwargs=None,
    pc_kwargs=None,
):
    """
    Plot 1D marginals for the posterior/prior predictive distribution and the observed data.

    A random subset of predictive samples is overlaid, one density per sample, and the
    density of the observed data is drawn on top (skipped by default when `group` is
    "prior_predictive").

    Parameters
    ----------
    dt : DataTree
        Input data
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, default=None
        If None, interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    group : str,
        Group to be plotted. Defaults to "posterior_predictive".
        It could also be "prior_predictive".
    coords : dict, optional
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    kind : {"kde", "hist", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    num_samples : int, optional
        Number of samples to plot. Defaults to 50.
        If it is larger than the number of available predictive samples,
        a warning is issued and all the samples are used.
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh"}, optional
    labeller : labeller, optional
    aes_map : mapping of {str : sequence of str}, optional
        Mapping of artists to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `plot_kwargs`.

        When no `plot_collection` is provided, the "overlay" aesthetic is mapped
        to the "sample" dimension by default. `aes_map` is forwarded to
        :func:`~arviz_plots.plot_dist`.
    plot_kwargs : mapping of {str : mapping or False}, optional
        Valid keys are:

        * predictive_density -> passed to a function that depends on the `kind` argument.
        * observed_density -> passed to a function that depends on the `kind` argument.

          * `kind="kde"` -> passed to :func:`~arviz_plots.visuals.line_xy`
          * `kind="ecdf"` -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * `kind="hist"` -> passed to :func:`~arviz_plots.visuals.hist`

        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * remove_axis -> not passed anywhere, can only be ``False`` to skip calling this function

    stats_kwargs : mapping, optional
        Keyword arguments forwarded to the statistical function (kde, histogram or ecdf)
        used to compute the density of the observed data.

    Returns
    -------
    PlotCollection

    See Also
    --------
    :ref:`plots_intro` :
        General introduction to batteries-included plotting functions, common use and logic overview

    Examples
    --------
    Make a plot of the posterior predictive distribution vs the observed data.
    We use an ECDF representation and customize the colors.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_ppc_dist, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> radon = load_arviz_data('radon')
        >>> pc = plot_ppc_dist(
        >>>     radon,
        >>>     kind="ecdf",
        >>>     plot_kwargs={"predictive_density": {"color":"C1"},
        >>>                  "observed_density": {"color":"C3"}},
        >>> )

    .. minigallery:: plot_ppc_dist
    """
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    sample_dims = list(sample_dims)
    if kind is None:
        kind = rcParams["plot.density_kind"]

    # Shallow-copy user supplied mappings so defaults set below do not leak to the caller.
    stats_kwargs = {} if stats_kwargs is None else stats_kwargs.copy()
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs.copy()
    pc_kwargs = {} if pc_kwargs is None else pc_kwargs.copy()

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")

    # Fixed seed: the same subset of predictive samples is drawn on every call.
    rng = np.random.default_rng(4214)

    # Sample dims followed by every other dim of the *requested* group, so that
    # "prior_predictive" is inspected when that is the group being plotted.
    pp_dims = list(sample_dims) + [dim for dim in dt[group].dims if dim not in sample_dims]

    distribution = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )

    # Select a random subset of samples
    n_pp_samples = np.prod(
        [distribution.sizes[dim] for dim in sample_dims if dim in distribution.dims]
    )
    if num_samples > n_pp_samples:
        num_samples = n_pp_samples
        warnings.warn(
            "num_samples is larger than the number of predictive samples.",
            stacklevel=2,
        )

    pp_sample_ix = rng.choice(n_pp_samples, size=num_samples, replace=False)
    distribution = distribution.stack(sample=sample_dims).isel(sample=pp_sample_ix)

    if plot_collection is None:
        pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy()

        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        # Each selected predictive sample becomes its own overlaid artist.
        pc_kwargs["aes"].setdefault("overlay", ["sample"])
        pc_kwargs.setdefault("col_wrap", 5)
        pc_kwargs.setdefault("cols", "__variable__")
        pc_kwargs.setdefault("rows", None)

        figsize = pc_kwargs["plot_grid_kws"].get("figsize", None)
        figsize_units = pc_kwargs["plot_grid_kws"].get("figsize_units", "inches")
        col_dims = pc_kwargs["cols"]
        row_dims = pc_kwargs["rows"]
        if figsize is None:
            figsize = plot_bknd.scale_fig_size(
                figsize,
                rows=process_facet_dims(distribution, row_dims)[0],
                cols=process_facet_dims(distribution, col_dims)[0],
                figsize_units=figsize_units,
            )
            # A figsize computed by scale_fig_size is handed over with "dots" units.
            figsize_units = "dots"
        pc_kwargs["plot_grid_kws"]["figsize"] = figsize
        pc_kwargs["plot_grid_kws"]["figsize_units"] = figsize_units

        plot_collection = PlotCollection.grid(
            distribution,
            backend=backend,
            **pc_kwargs,
        )

    aes_map = {} if aes_map is None else aes_map.copy()
    if labeller is None:
        labeller = BaseLabeller()

    # We don't want credible_interval or point_estimate to be mapped to the density representation
    plot_kwargs.setdefault("credible_interval", False)
    plot_kwargs.setdefault("point_estimate", False)
    plot_kwargs.setdefault("point_estimate_text", False)

    # Plot the predictive density
    pred_density_kwargs = copy(plot_kwargs.get("predictive_density", {}))
    if pred_density_kwargs is False:
        # Honour the documented "False to skip" contract for the predictive artists.
        plot_kwargs[kind] = False
    else:
        plot_kwargs.setdefault(kind, pred_density_kwargs)

    if plot_kwargs[kind] is not False:
        # Work on a copy so a dict supplied by the caller under `kind` is not mutated.
        density_kwargs = dict(plot_kwargs[kind])
        density_kwargs.setdefault("alpha", 0.3)
        if kind == "hist":
            density_kwargs.setdefault("edgecolor", None)
        plot_kwargs[kind] = density_kwargs

    plot_collection = plot_dist(
        distribution,
        group=group,
        sample_dims=pp_dims,
        kind=kind,
        labeller=labeller,
        plot_kwargs=plot_kwargs,
        aes_map=aes_map,
        pc_kwargs=pc_kwargs,
        plot_collection=plot_collection,
    )

    # Plot the observed density
    observed_density_kwargs = copy(
        plot_kwargs.get("observed_density", False if group == "prior_predictive" else {})
    )

    if observed_density_kwargs is not False:
        observed_density_kwargs.setdefault("color", "black")
        if kind == "hist":
            observed_density_kwargs.setdefault("alpha", 0.3)
            observed_density_kwargs.setdefault("edgecolor", None)
            stats_kwargs.setdefault("density", True)

        _, _, observed_ignore = filter_aes(
            plot_collection, aes_map, "observed_density", sample_dims
        )

        # NOTE(review): observed_data is used as-is; `var_names`/`coords` filtering is not
        # applied to it here -- confirm this is intended.
        observed = dt.observed_data.ds

        # NOTE(review): `stats_kwargs` is splatted directly into the stats call below;
        # confirm the expected structure of this argument with callers.
        if kind == "kde":
            dt_observed = observed.azstats.kde(dims=pp_dims, **stats_kwargs)
            plot_collection.map(
                line_xy,
                "observed_density",
                data=dt_observed,
                ignore_aes=observed_ignore,
                **observed_density_kwargs,
            )
        elif kind == "hist":
            dt_observed = observed.azstats.histogram(dims=pp_dims, **stats_kwargs)
            plot_collection.map(
                hist,
                "observed_density",
                data=dt_observed,
                ignore_aes=observed_ignore,
                **observed_density_kwargs,
            )
        elif kind == "ecdf":
            # NOTE(review): unlike kde/histogram no `dims` is passed here -- presumably the
            # stats default is relied upon; verify against arviz-stats.
            dt_observed = observed.azstats.ecdf(**stats_kwargs)
            plot_collection.map(
                ecdf_line,
                "observed_density",
                data=dt_observed,
                ignore_aes=observed_ignore,
                **observed_density_kwargs,
            )

    return plot_collection
Loading

0 comments on commit 9ef8965

Please sign in to comment.