-
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add plot_ppc_dist * remove comments * add test and small fixes * fix typo
- Loading branch information
1 parent
4c76c7d
commit 9ef8965
Showing
5 changed files
with
335 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
""" | ||
# Posterior Predictive Checks | ||
Plot of samples from the posterior predictive and observed data. | ||
--- | ||
:::{seealso} | ||
API Documentation: {func}`~arviz_plots.plot_ppc_dist` | ||
::: | ||
""" | ||
from arviz_base import load_arviz_data | ||
|
||
import arviz_plots as azp | ||
|
||
# Apply the "arviz-variat" style before any figure is created.
azp.style.use("arviz-variat")

# Example dataset shipped with arviz_base.
rugby = load_arviz_data("rugby")

# Aesthetic mapping handed to the plot collection: color follows the variable.
collection_kwargs = {"aes": {"color": ["__variable__"]}}  # map color to variable

ppc_collection = azp.plot_ppc_dist(
    rugby,
    pc_kwargs=collection_kwargs,
    aes_map={"title": ["color"]},  # also map color to title
    backend="none",
)
ppc_collection.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,282 @@ | ||
"""Posterior/prior predictive check using densities.""" | ||
import warnings | ||
from copy import copy | ||
from importlib import import_module | ||
|
||
import numpy as np | ||
from arviz_base import rcParams | ||
from arviz_base.labels import BaseLabeller | ||
|
||
from arviz_plots.plot_collection import PlotCollection, process_facet_dims | ||
from arviz_plots.plots.distplot import plot_dist | ||
from arviz_plots.plots.utils import filter_aes, process_group_variables_coords | ||
from arviz_plots.visuals import ecdf_line, hist, line_xy | ||
|
||
|
||
def plot_ppc_dist(
    dt,
    var_names=None,
    filter_vars=None,
    group="posterior_predictive",
    coords=None,
    sample_dims=None,
    kind=None,
    num_samples=50,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_map=None,
    plot_kwargs=None,
    stats_kwargs=None,
    pc_kwargs=None,
):
    """
    Plot 1D marginals for the posterior/prior predictive distribution and the observed data.

    Parameters
    ----------
    dt : DataTree
        Input data
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, default=None
        If None, interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    group : str,
        Group to be plotted. Defaults to "posterior_predictive".
        It could also be "prior_predictive".
    coords : dict, optional
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    kind : {"kde", "hist", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    num_samples : int, optional
        Number of samples to plot. Defaults to 50. If it is larger than the number
        of available predictive samples, all of them are used and a warning is emitted.
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh"}, optional
    labeller : labeller, optional
    aes_map : mapping of {str : sequence of str}, optional
        Mapping of artists to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `plot_kwargs`.

        When `plot_collection` is not provided, an "overlay" aesthetic is mapped to the
        "sample" dimension unless `pc_kwargs` already defines one.
    plot_kwargs : mapping of {str : mapping or False}, optional
        Valid keys are:

        * predictive_density -> passed to a function that depends on the `kind` argument.
        * observed_density -> passed to a function that depends on the `kind` argument.

          * `kind="kde"` -> passed to :func:`~arviz_plots.visuals.line_xy`
          * `kind="ecdf"` -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * `kind="hist"` -> passed to :func:`~arviz_plots.visuals.hist`

        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * remove_axis -> not passed anywhere, can only be ``False`` to skip calling this function

        The "credible_interval", "point_estimate" and "point_estimate_text" keys
        default to ``False`` in this function.
    stats_kwargs : mapping, optional
        Keyword arguments forwarded to the statistical function (kde, histogram or ecdf)
        applied to the observed data. With ``kind="hist"``, ``density`` defaults to ``True``.
    pc_kwargs : mapping, optional
        Used to build the :class:`~arviz_plots.PlotCollection` when `plot_collection`
        is not provided.

    Returns
    -------
    PlotCollection

    See Also
    --------
    :ref:`plots_intro` :
        General introduction to batteries-included plotting functions, common use and logic overview

    Examples
    --------
    Make a plot of the posterior predictive distribution vs the observed data.
    We use an ECDF representation and customize the colors.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_ppc_dist, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> radon = load_arviz_data('radon')
        >>> pc = plot_ppc_dist(
        >>>     radon,
        >>>     kind="ecdf",
        >>>     plot_kwargs={"predictive_density": {"color":"C1"},
        >>>                  "observed_density": {"color":"C3"}},
        >>> )

    .. minigallery:: plot_ppc_dist
    """
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    sample_dims = list(sample_dims)
    if kind is None:
        kind = rcParams["plot.density_kind"]

    # Shallow-copy the user mappings so the defaults set below never leak back to the caller.
    stats_kwargs = {} if stats_kwargs is None else stats_kwargs.copy()
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs.copy()
    pc_kwargs = {} if pc_kwargs is None else pc_kwargs.copy()

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")

    # Fixed seed: the same subset of predictive samples is selected on every call.
    rng = np.random.default_rng(4214)

    # Reduce over the sample dims plus every other dim of the *requested* group
    # (the group may be "prior_predictive", so it must not be hard-coded).
    pp_dims = list(sample_dims) + [dim for dim in dt[group].dims if dim not in sample_dims]

    distribution = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )

    # Select a random subset of samples
    n_pp_samples = int(
        np.prod([distribution.sizes[dim] for dim in sample_dims if dim in distribution.dims])
    )
    if num_samples > n_pp_samples:
        num_samples = n_pp_samples
        warnings.warn(
            "num_samples is larger than the number of predictive samples.",
            stacklevel=2,
        )

    pp_sample_ix = rng.choice(n_pp_samples, size=num_samples, replace=False)
    distribution = distribution.stack(sample=sample_dims).isel(sample=pp_sample_ix)

    if plot_collection is None:
        pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy()

        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        pc_kwargs["aes"].setdefault("overlay", ["sample"])
        pc_kwargs.setdefault("col_wrap", 5)
        pc_kwargs.setdefault("cols", "__variable__")
        pc_kwargs.setdefault("rows", None)

        figsize = pc_kwargs["plot_grid_kws"].get("figsize", None)
        figsize_units = pc_kwargs["plot_grid_kws"].get("figsize_units", "inches")
        col_dims = pc_kwargs["cols"]
        row_dims = pc_kwargs["rows"]
        if figsize is None:
            figsize = plot_bknd.scale_fig_size(
                figsize,
                rows=process_facet_dims(distribution, row_dims)[0],
                cols=process_facet_dims(distribution, col_dims)[0],
                figsize_units=figsize_units,
            )
            # The computed size is stored with "dots" as its unit.
            figsize_units = "dots"
        pc_kwargs["plot_grid_kws"]["figsize"] = figsize
        pc_kwargs["plot_grid_kws"]["figsize_units"] = figsize_units

        plot_collection = PlotCollection.grid(
            distribution,
            backend=backend,
            **pc_kwargs,
        )

    aes_map = {} if aes_map is None else aes_map.copy()
    if labeller is None:
        labeller = BaseLabeller()

    # We don't want credible_interval or point_estimate to be mapped to the density representation
    plot_kwargs.setdefault("credible_interval", False)
    plot_kwargs.setdefault("point_estimate", False)
    plot_kwargs.setdefault("point_estimate_text", False)

    # Plot the predictive density
    pred_density_kwargs = copy(plot_kwargs.get("predictive_density", {}))
    if pred_density_kwargs is not False:
        # An explicit entry under the `kind` key wins over "predictive_density".
        density_kwargs = plot_kwargs.setdefault(kind, pred_density_kwargs)
        # The entry may be ``False`` (artist disabled); only add defaults to mappings.
        if density_kwargs is not False:
            density_kwargs.setdefault("alpha", 0.3)
            if kind == "hist":
                density_kwargs.setdefault("edgecolor", None)
        if kind == "hist":
            stats_kwargs.setdefault("density", True)

        plot_collection = plot_dist(
            distribution,
            group=group,
            sample_dims=pp_dims,
            kind=kind,
            labeller=labeller,
            plot_kwargs=plot_kwargs,
            aes_map=aes_map,
            pc_kwargs=pc_kwargs,
            plot_collection=plot_collection,
        )

    # Plot the observed density (skipped by default for the prior predictive group)
    observed_density_kwargs = copy(
        plot_kwargs.get("observed_density", False if group == "prior_predictive" else {})
    )

    if observed_density_kwargs is not False:
        observed_density_kwargs.setdefault("color", "black")
        if kind == "hist":
            observed_density_kwargs.setdefault("alpha", 0.3)
            observed_density_kwargs.setdefault("edgecolor", None)
            stats_kwargs.setdefault("density", True)

        _, _, observed_ignore = filter_aes(
            plot_collection, aes_map, "observed_density", sample_dims
        )

        # The artist is registered as "observed_density", the same key used by
        # `plot_kwargs`, `aes_map` and the `filter_aes` call above.
        if kind == "kde":
            dt_observed = dt.observed_data.ds.azstats.kde(dims=pp_dims, **stats_kwargs)
            plot_collection.map(
                line_xy,
                "observed_density",
                data=dt_observed,
                ignore_aes=observed_ignore,
                **observed_density_kwargs,
            )
        elif kind == "hist":
            dt_observed = dt.observed_data.ds.azstats.histogram(dims=pp_dims, **stats_kwargs)
            plot_collection.map(
                hist,
                "observed_density",
                data=dt_observed,
                ignore_aes=observed_ignore,
                **observed_density_kwargs,
            )
        elif kind == "ecdf":
            dt_observed = dt.observed_data.ds.azstats.ecdf(**stats_kwargs)
            plot_collection.map(
                ecdf_line,
                "observed_density",
                data=dt_observed,
                ignore_aes=observed_ignore,
                **observed_density_kwargs,
            )

    return plot_collection
Oops, something went wrong.