Skip to content

Commit

Permalink
pointwise elpd diagnostics (text formatting and plot) (#678)
Browse files Browse the repository at this point in the history
* Initial commit of pointwise elpd plot

* Fix coords typo

* Add xlabel option

* black

* Add print method for loo and waic

* black

* Add coloring options

* Improve legend handling and some fixes

* Add plot_pointwise_elpd to docs

* Rename to plot_elpd

* >3d support and add title legend

* Uppercase format constants

* Return pointwise elpd as dataarray in waic and loo

* Make ELPDData class public

* Accept ELPDData as plot_elpd arguments

* black and some fixes

* Move functions to plot utils to use also in plot_khat

* Fixes

* Add threshold argument to highlight worse points

* black and fix tests

* waic and loo to use xarray

* psislw improvement

Do not require "samples" dim to be last.

* Add tests

* fix lint

* Extend tests
  • Loading branch information
OriolAbril authored and aloctavodia committed Jun 9, 2019
1 parent 7e34e8d commit a3b1c78
Show file tree
Hide file tree
Showing 10 changed files with 809 additions and 102 deletions.
2 changes: 2 additions & 0 deletions arviz/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .forestplot import plot_forest
from .kdeplot import plot_kde, _fast_kde, _fast_kde_2d
from .parallelplot import plot_parallel
from .elpdplot import plot_elpd
from .posteriorplot import plot_posterior
from .traceplot import plot_trace
from .pairplot import plot_pair
Expand All @@ -28,6 +29,7 @@
"_fast_kde",
"_fast_kde_2d",
"plot_parallel",
"plot_elpd",
"plot_posterior",
"plot_trace",
"plot_pair",
Expand Down
276 changes: 276 additions & 0 deletions arviz/plots/elpdplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
"""Plot pointwise elpd estimations of inference data."""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.ticker import NullFormatter
from matplotlib.lines import Line2D

from ..data import convert_to_inference_data
from .plot_utils import (
_scale_fig_size,
get_coords,
color_from_dim,
format_coords_as_labels,
set_xticklabels,
)
from ..stats import waic, loo, ELPDData


def plot_elpd(
compare_dict,
color=None,
xlabels=False,
figsize=None,
textsize=None,
coords=None,
legend=False,
threshold=None,
ax=None,
ic="waic",
scale="deviance",
plot_kwargs=None,
):
"""
Plot a scatter or hexbin matrix of the sampled parameters.
Parameters
----------
compare_dict : mapping, str -> ELPDData or InferenceData
A dictionary mapping the model name to the object containing its inference data or
the result of `waic`/`loo` functions.
Refer to az.convert_to_inference_data for details on possible dict items
color : str or array_like, optional
Colors of the scatter plot, if color is a str all dots will have the same color,
if it is the size of the observations, each dot will have the specified color,
otherwise, it will be interpreted as a list of the dims to be used for the color code
xlabels : bool, optional
Use coords as xticklabels
figsize : figure size tuple, optional
If None, size is (8 + numvars, 8 + numvars)
textsize: int, optional
Text size for labels. If None it will be autoscaled based on figsize.
coords : mapping, optional
Coordinates of points to plot. **All** values are used for computation, but only a
a subset can be plotted for convenience.
legend : bool, optional
Include a legend to the plot. Only taken into account when color argument is a dim name.
threshold : float
If some elpd difference is larger than `threshold * elpd.std()`, show its label. If
`None`, no observations will be highlighted.
ax: axes, optional
Matplotlib axes
ic : str, optional
Information Criterion (WAIC or LOO) used to compare models. Default WAIC. Only taken
into account when input is InferenceData.
scale : str, optional
scale argument passed to az.waic or az.loo, see their docs for details. Only taken
into account when input is InferenceData.
plot_kwargs : dicts, optional
Additional keywords passed to ax.plot
Returns
-------
ax : matplotlib axes
Examples
--------
Compare pointwise WAIC for centered and non centered models of the 8school problem
.. plot::
:context: close-figs
>>> import arviz as az
>>> idata1 = az.load_arviz_data("centered_eight")
>>> idata2 = az.load_arviz_data("non_centered_eight")
>>> az.plot_elpd(
>>> {"centered model": idata1, "non centered model": idata2},
>>> xlabels=True
>>> )
"""
valid_ics = ["waic", "loo"]
ic = ic.lower()
if ic not in valid_ics:
raise ValueError(
("Information Criteria type {} not recognized." "IC must be in {}").format(
ic, valid_ics
)
)
ic_fun = waic if ic == "waic" else loo

# Make sure all object are ELPDData
for k, item in compare_dict.items():
if not isinstance(item, ELPDData):
compare_dict[k] = ic_fun(convert_to_inference_data(item), pointwise=True, scale=scale)
ics = [elpd_data.index[0] for elpd_data in compare_dict.values()]
if not all(x == ics[0] for x in ics):
raise SyntaxError(
"All Information Criteria must be of the same kind, but both loo and waic data present"
)
ic = ics[0]
scales = [elpd_data["{}_scale".format(ic)] for elpd_data in compare_dict.values()]
if not all(x == scales[0] for x in scales):
raise SyntaxError(
"All Information Criteria must be on the same scale, but {} are present".format(
set(scales)
)
)
numvars = len(compare_dict)
models = list(compare_dict.keys())

if coords is None:
coords = {}

if plot_kwargs is None:
plot_kwargs = {}
plot_kwargs.setdefault("marker", "+")

pointwise_data = [
get_coords(compare_dict[model]["{}_i".format(ic)], coords) for model in models
]
xdata = np.arange(pointwise_data[0].size)

if isinstance(color, str):
if color in pointwise_data[0].dims:
colors, color_mapping = color_from_dim(pointwise_data[0], color)
if legend:
cmap_name = plot_kwargs.pop("cmap", plt.rcParams["image.cmap"])
markersize = plot_kwargs.pop("s", plt.rcParams["lines.markersize"])
cmap = getattr(cm, cmap_name)
handles = [
Line2D(
[],
[],
color=cmap(float_color),
label=coord,
ms=markersize,
lw=0,
**plot_kwargs
)
for coord, float_color in color_mapping.items()
]
plot_kwargs.setdefault("cmap", cmap_name)
plot_kwargs.setdefault("s", markersize ** 2)
plot_kwargs.setdefault("c", colors)
else:
plot_kwargs.setdefault("c", color)
legend = False
else:
legend = False
plot_kwargs.setdefault("c", color)

if xlabels:
coord_labels = format_coords_as_labels(pointwise_data[0])

if numvars < 2:
raise Exception("Number of models to compare must be 2 or greater.")

if numvars == 2:
(figsize, ax_labelsize, titlesize, xt_labelsize, _, markersize) = _scale_fig_size(
figsize, textsize, numvars - 1, numvars - 1
)
plot_kwargs.setdefault("s", markersize ** 2)

if ax is None:
fig, ax = plt.subplots(figsize=figsize, constrained_layout=(not xlabels and not legend))

ydata = pointwise_data[0] - pointwise_data[1]
ax.scatter(xdata, ydata, **plot_kwargs)
if threshold is not None:
ydata = ydata.values.flatten()
diff_abs = np.abs(ydata - ydata.mean())
bool_ary = diff_abs > threshold * ydata.std()
try:
coord_labels
except NameError:
coord_labels = xdata.astype(str)
outliers = np.argwhere(bool_ary).squeeze()
for outlier in outliers:
label = coord_labels[outlier]
ax.text(
outlier,
ydata[outlier],
label,
horizontalalignment="center",
verticalalignment="bottom" if ydata[outlier] > 0 else "top",
fontsize=0.8 * xt_labelsize,
)

ax.set_title("{} - {}".format(*models), fontsize=titlesize, wrap=True)
ax.set_ylabel("ELPD difference", fontsize=ax_labelsize, wrap=True)
ax.tick_params(labelsize=xt_labelsize)
if xlabels:
set_xticklabels(ax, coord_labels)
fig.autofmt_xdate()
if legend:
ncols = len(handles) // 6 + 1
ax.legend(handles=handles, ncol=ncols, title=color)

else:
(figsize, ax_labelsize, titlesize, xt_labelsize, _, markersize) = _scale_fig_size(
figsize, textsize, numvars - 2, numvars - 2
)
plot_kwargs.setdefault("s", markersize ** 2)

if ax is None:
fig, ax = plt.subplots(
numvars - 1,
numvars - 1,
figsize=figsize,
constrained_layout=(not xlabels and not legend),
)

for i in range(0, numvars - 1):
var1 = pointwise_data[i]

for j in range(0, numvars - 1):
if j < i:
ax[j, i].axis("off")
continue

var2 = pointwise_data[j + 1]
ax[j, i].scatter(xdata, var1 - var2, **plot_kwargs)
if threshold is not None:
ydata = (var1 - var2).values.flatten()
diff_abs = np.abs(ydata - ydata.mean())
bool_ary = diff_abs > threshold * ydata.std()
try:
coord_labels
except NameError:
coord_labels = xdata.astype(str)
outliers = np.argwhere(bool_ary).squeeze()
for outlier in outliers:
label = coord_labels[outlier]
ax[j, i].text(
outlier,
ydata[outlier],
label,
horizontalalignment="center",
verticalalignment="bottom" if ydata[outlier] > 0 else "top",
fontsize=0.8 * xt_labelsize,
)

if j + 1 != numvars - 1:
ax[j, i].axes.get_xaxis().set_major_formatter(NullFormatter())
ax[j, i].set_xticks([])
elif xlabels:
set_xticklabels(ax[j, i], coord_labels)

if i != 0:
ax[j, i].axes.get_yaxis().set_major_formatter(NullFormatter())
ax[j, i].set_yticks([])
else:
ax[j, i].set_ylabel("ELPD difference", fontsize=ax_labelsize, wrap=True)

ax[j, i].tick_params(labelsize=xt_labelsize)
ax[j, i].set_title(
"{} - {}".format(models[i], models[j + 1]), fontsize=titlesize, wrap=True
)
if xlabels:
fig.autofmt_xdate()
if legend:
ncols = len(handles) // 6 + 1
ax[0, 1].legend(
handles=handles, ncol=ncols, title=color, bbox_to_anchor=(0, 1), loc="upper left"
)
return ax
53 changes: 53 additions & 0 deletions arviz/plots/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,56 @@ def get_coords(data, coords):
" dimensions are valid. {}"
).format(err)
)


def color_from_dim(dataarray, dim_name):
"""Return colors and color mapping of a DataArray using coord values as color code.
Parameters
----------
dataarray : xarray.DataArray
dim_name : str
dimension whose coordinates will be used as color code.
Returns
-------
colors : array of floats
Array of colors (as floats for use with a cmap) for each element in the dataarray.
color_mapping : mapping coord_value -> float
Mapping from coord values to corresponding color
"""
present_dims = dataarray.dims
coord_values = dataarray[dim_name].values
unique_coords = set(coord_values)
color_mapping = {coord: num / len(unique_coords) for num, coord in enumerate(unique_coords)}
if len(present_dims) > 1:
multi_coords = dataarray.coords.to_index()
coord_idx = present_dims.index(dim_name)
colors = [color_mapping[coord[coord_idx]] for coord in multi_coords]
else:
colors = [color_mapping[coord] for coord in coord_values]
return colors, color_mapping


def format_coords_as_labels(dataarray):
"""Format 1d or multi-d dataarray coords as strings."""
coord_labels = dataarray.coords.to_index().values
if isinstance(coord_labels[0], tuple):
fmt = ", ".join(["{}" for _ in coord_labels[0]])
coord_labels[:] = [fmt.format(*x) for x in coord_labels]
else:
coord_labels[:] = ["{}".format(s) for s in coord_labels]
return coord_labels


def set_xticklabels(ax, coord_labels):
"""Set xticklabels to label list using Matplotlib default formatter."""
ax.xaxis.get_major_locator().set_params(nbins=9, steps=[1, 2, 5, 10])
xticks = ax.get_xticks().astype(np.int64)
xticks = xticks[(xticks >= 0) & (xticks < len(coord_labels))]
if len(xticks) > len(coord_labels):
ax.set_xticks(np.arange(len(coord_labels)))
ax.set_xticklabels(coord_labels)
else:
ax.set_xticks(xticks)
ax.set_xticklabels(coord_labels[xticks])
1 change: 1 addition & 0 deletions arviz/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"summary",
"waic",
"effective_sample_size",
"ELPDData",
"ess",
"rhat",
"mcse",
Expand Down
Loading

0 comments on commit a3b1c78

Please sign in to comment.