Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pointwise elpd diagnostics (text formatting and plot) #678

Merged
merged 27 commits into from
Jun 9, 2019
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
8d2016e
Initial commit of pointwise elpd plot
OriolAbril May 23, 2019
f968777
Fix coords typo
OriolAbril May 23, 2019
dad9f2c
Merge branch 'master' into pointwise-elpd
OriolAbril May 24, 2019
9d0fdb1
Add xlabel option
OriolAbril May 27, 2019
a3f9d7a
black
OriolAbril May 28, 2019
38f0c91
Add print method for loo and waic
OriolAbril May 28, 2019
b0fbd2c
black
OriolAbril May 28, 2019
ae37877
Add coloring options
OriolAbril May 28, 2019
0e22de0
Improve legend handling and some fixes
OriolAbril May 28, 2019
0a50955
Add plot_pointwise_elpd to docs
OriolAbril May 28, 2019
6434e45
Rename to plot_elpd
OriolAbril May 30, 2019
ddc47db
>3d support and add title legend
OriolAbril Jun 2, 2019
3796e51
Uppercase format constants
OriolAbril Jun 2, 2019
2208fed
Return pointwise elpd as dataarray in waic and loo
OriolAbril Jun 4, 2019
76cb91b
Make ELPDData class public
OriolAbril Jun 4, 2019
24b72e2
Accept ELPDData as plot_elpd arguments
OriolAbril Jun 4, 2019
77a5272
black and some fixes
OriolAbril Jun 4, 2019
e22d658
Move functions to plot utils to use also in plot_khat
OriolAbril Jun 4, 2019
9775b2b
Fixes
OriolAbril Jun 4, 2019
755f7c2
Add threshold argument to highlight worse points
OriolAbril Jun 5, 2019
fd6edb7
black and fix tests
OriolAbril Jun 6, 2019
02e9433
Merge branch 'master' into pointwise-elpd
OriolAbril Jun 6, 2019
792506f
waic and loo to use xarray
OriolAbril Jun 6, 2019
48a69a5
psislw improvement
OriolAbril Jun 6, 2019
65ea945
Add tests
OriolAbril Jun 7, 2019
b4d292b
fix lint
OriolAbril Jun 7, 2019
745a19a
Extend tests
OriolAbril Jun 7, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions arviz/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .forestplot import plot_forest
from .kdeplot import plot_kde, _fast_kde, _fast_kde_2d
from .parallelplot import plot_parallel
from .elpdplot import plot_elpd
from .posteriorplot import plot_posterior
from .traceplot import plot_trace
from .pairplot import plot_pair
Expand All @@ -28,6 +29,7 @@
"_fast_kde",
"_fast_kde_2d",
"plot_parallel",
"plot_elpd",
"plot_posterior",
"plot_trace",
"plot_pair",
Expand Down
276 changes: 276 additions & 0 deletions arviz/plots/elpdplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
"""Plot pointwise elpd estimations of inference data."""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.ticker import NullFormatter
from matplotlib.lines import Line2D

from ..data import convert_to_inference_data
from .plot_utils import (
_scale_fig_size,
get_coords,
color_from_dim,
format_coords_as_labels,
set_xticklabels,
)
from ..stats import waic, loo, ELPDData


def plot_elpd(
compare_dict,
color=None,
xlabels=False,
figsize=None,
textsize=None,
coords=None,
legend=False,
threshold=None,
ax=None,
ic="waic",
scale="deviance",
plot_kwargs=None,
):
"""
Plot a scatter or hexbin matrix of the sampled parameters.

Parameters
----------
compare_dict : mapping, str -> ELPDData or InferenceData
A dictionary mapping the model name to the object containing its inference data or
the result of `waic`/`loo` functions.
Refer to az.convert_to_inference_data for details on possible dict items
color : str or array_like, optional
Colors of the scatter plot, if color is a str all dots will have the same color,
if it is the size of the observations, each dot will have the specified color,
otherwise, it will be interpreted as a list of the dims to be used for the color code
xlabels : bool, optional
Use coords as xticklabels
figsize : figure size tuple, optional
If None, size is (8 + numvars, 8 + numvars)
textsize: int, optional
Text size for labels. If None it will be autoscaled based on figsize.
coords : mapping, optional
Coordinates of points to plot. **All** values are used for computation, but only a
a subset can be plotted for convenience.
legend : bool, optional
Include a legend to the plot. Only taken into account when color argument is a dim name.
threshold : float
If some elpd difference is larger than `threshold * elpd.std()`, show its label. If
`None`, no observations will be highlighted.
ax: axes, optional
Matplotlib axes
ic : str, optional
Information Criterion (WAIC or LOO) used to compare models. Default WAIC. Only taken
into account when input is InferenceData.
scale : str, optional
scale argument passed to az.waic or az.loo, see their docs for details. Only taken
into account when input is InferenceData.
plot_kwargs : dicts, optional
Additional keywords passed to ax.plot

Returns
-------
ax : matplotlib axes

Examples
--------
Compare pointwise WAIC for centered and non centered models of the 8school problem

.. plot::
:context: close-figs

>>> import arviz as az
>>> idata1 = az.load_arviz_data("centered_eight")
>>> idata2 = az.load_arviz_data("non_centered_eight")
>>> az.plot_elpd(
>>> {"centered model": idata1, "non centered model": idata2},
>>> xlabels=True
>>> )

"""
valid_ics = ["waic", "loo"]
ic = ic.lower()
if ic not in valid_ics:
raise ValueError(
("Information Criteria type {} not recognized." "IC must be in {}").format(
ic, valid_ics
)
)
ic_fun = waic if ic == "waic" else loo

# Make sure all object are ELPDData
for k, item in compare_dict.items():
if not isinstance(item, ELPDData):
compare_dict[k] = ic_fun(convert_to_inference_data(item), pointwise=True, scale=scale)
ics = [elpd_data.index[0] for elpd_data in compare_dict.values()]
if not all(x == ics[0] for x in ics):
raise SyntaxError(
"All Information Criteria must be of the same kind, but both loo and waic data present"
)
ic = ics[0]
scales = [elpd_data["{}_scale".format(ic)] for elpd_data in compare_dict.values()]
if not all(x == scales[0] for x in scales):
raise SyntaxError(
"All Information Criteria must be on the same scale, but {} are present".format(
set(scales)
)
)
numvars = len(compare_dict)
models = list(compare_dict.keys())

if coords is None:
coords = {}

if plot_kwargs is None:
plot_kwargs = {}
plot_kwargs.setdefault("marker", "+")

pointwise_data = [
get_coords(compare_dict[model]["{}_i".format(ic)], coords) for model in models
]
xdata = np.arange(pointwise_data[0].size)

if isinstance(color, str):
if color in pointwise_data[0].dims:
colors, color_mapping = color_from_dim(pointwise_data[0], color)
if legend:
cmap_name = plot_kwargs.pop("cmap", plt.rcParams["image.cmap"])
markersize = plot_kwargs.pop("s", plt.rcParams["lines.markersize"])
cmap = getattr(cm, cmap_name)
handles = [
Line2D(
[],
[],
color=cmap(float_color),
label=coord,
ms=markersize,
lw=0,
**plot_kwargs
)
for coord, float_color in color_mapping.items()
]
plot_kwargs.setdefault("cmap", cmap_name)
plot_kwargs.setdefault("s", markersize ** 2)
plot_kwargs.setdefault("c", colors)
else:
plot_kwargs.setdefault("c", color)
legend = False
else:
legend = False
plot_kwargs.setdefault("c", color)

if xlabels:
coord_labels = format_coords_as_labels(pointwise_data[0])

if numvars < 2:
raise Exception("Number of models to compare must be 2 or greater.")

if numvars == 2:
(figsize, ax_labelsize, titlesize, xt_labelsize, _, markersize) = _scale_fig_size(
figsize, textsize, numvars - 1, numvars - 1
)
plot_kwargs.setdefault("s", markersize ** 2)

if ax is None:
fig, ax = plt.subplots(figsize=figsize, constrained_layout=(not xlabels and not legend))

ydata = pointwise_data[0] - pointwise_data[1]
ax.scatter(xdata, ydata, **plot_kwargs)
if threshold is not None:
ydata = ydata.values.flatten()
diff_abs = np.abs(ydata - ydata.mean())
bool_ary = diff_abs > threshold * ydata.std()
try:
coord_labels
except NameError:
coord_labels = xdata.astype(str)
outliers = np.argwhere(bool_ary).squeeze()
for outlier in outliers:
label = coord_labels[outlier]
ax.text(
outlier,
ydata[outlier],
label,
horizontalalignment="center",
verticalalignment="bottom" if ydata[outlier] > 0 else "top",
fontsize=0.8 * xt_labelsize,
)

ax.set_title("{} - {}".format(*models), fontsize=titlesize, wrap=True)
ax.set_ylabel("ELPD difference", fontsize=ax_labelsize, wrap=True)
ax.tick_params(labelsize=xt_labelsize)
if xlabels:
set_xticklabels(ax, coord_labels)
fig.autofmt_xdate()
if legend:
ncols = len(handles) // 6 + 1
ax.legend(handles=handles, ncol=ncols, title=color)

else:
(figsize, ax_labelsize, titlesize, xt_labelsize, _, markersize) = _scale_fig_size(
figsize, textsize, numvars - 2, numvars - 2
)
plot_kwargs.setdefault("s", markersize ** 2)

if ax is None:
fig, ax = plt.subplots(
numvars - 1,
numvars - 1,
figsize=figsize,
constrained_layout=(not xlabels and not legend),
)

for i in range(0, numvars - 1):
var1 = pointwise_data[i]

for j in range(0, numvars - 1):
if j < i:
ax[j, i].axis("off")
continue

var2 = pointwise_data[j + 1]
ax[j, i].scatter(xdata, var1 - var2, **plot_kwargs)
if threshold is not None:
ydata = (var1 - var2).values.flatten()
diff_abs = np.abs(ydata - ydata.mean())
bool_ary = diff_abs > threshold * ydata.std()
try:
coord_labels
except NameError:
coord_labels = xdata.astype(str)
outliers = np.argwhere(bool_ary).squeeze()
for outlier in outliers:
label = coord_labels[outlier]
ax[j, i].text(
outlier,
ydata[outlier],
label,
horizontalalignment="center",
verticalalignment="bottom" if ydata[outlier] > 0 else "top",
fontsize=0.8 * xt_labelsize,
)

if j + 1 != numvars - 1:
ax[j, i].axes.get_xaxis().set_major_formatter(NullFormatter())
ax[j, i].set_xticks([])
elif xlabels:
set_xticklabels(ax[j, i], coord_labels)

if i != 0:
ax[j, i].axes.get_yaxis().set_major_formatter(NullFormatter())
ax[j, i].set_yticks([])
else:
ax[j, i].set_ylabel("ELPD difference", fontsize=ax_labelsize, wrap=True)

ax[j, i].tick_params(labelsize=xt_labelsize)
ax[j, i].set_title(
"{} - {}".format(models[i], models[j + 1]), fontsize=titlesize, wrap=True
)
if xlabels:
fig.autofmt_xdate()
if legend:
ncols = len(handles) // 6 + 1
ax[0, 1].legend(
handles=handles, ncol=ncols, title=color, bbox_to_anchor=(0, 1), loc="upper left"
)
return ax
54 changes: 54 additions & 0 deletions arviz/plots/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,57 @@ def get_coords(data, coords):
" dimensions are valid. {}"
).format(err)
)


def color_from_dim(dataarray, dim_name):
"""Return colors and color mapping of a DataArray using coord values as color code.

Parameters
----------
dataarray : xarray.DataArray
dim_name : str
dimension whose coordinates will be used as color code.

Returns
-------
colors : array of floats
Array of colors (as floats for use with a cmap) for each element in the dataarray.
color_mapping : mapping coord_value -> float
Mapping from coord values to corresponding color
"""
present_dims = dataarray.dims
coord_values = dataarray[dim_name].values
unique_coords = set(coord_values)
color_mapping = {coord: num / len(unique_coords) for num, coord in enumerate(unique_coords)}
if len(present_dims) > 1:
multi_coords = dataarray.coords.to_index()
coord_idx = present_dims.index(dim_name)
colors = [color_mapping[coord[coord_idx]] for coord in multi_coords]
else:
colors = [color_mapping[coord] for coord in coord_values]
return colors, color_mapping


def format_coords_as_labels(dataarray):
"""Format 1d or multi-d dataarray coords as strings."""
coord_labels = dataarray.coords.to_index().values
if isinstance(coord_labels[0], tuple):
fmt = ", ".join(["{}" for _ in coord_labels[0]])
coord_labels[:] = [fmt.format(*x) for x in coord_labels]
else:
coord_labels[:] = ["{}".format(s) for s in coord_labels]
return coord_labels


def set_xticklabels(ax, coord_labels):
"""Set xticklabels to label list using Matplotlib default formatter."""
xlim = ax.get_xlim()
ax.xaxis.get_major_locator().set_params(nbins=9, steps=[1, 2, 5, 10])
xticks = ax.get_xticks().astype(np.int64)
xticks = xticks[(xticks > xlim[0]) & (xticks < xlim[1])]
if len(xticks) > len(coord_labels):
ax.set_xticks(np.arange(len(coord_labels)))
ax.set_xticklabels(coord_labels)
else:
ax.set_xticks(xticks)
ax.set_xticklabels(coord_labels[xticks])
1 change: 1 addition & 0 deletions arviz/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"summary",
"waic",
"effective_sample_size",
"ELPDData",
"ess",
"rhat",
"mcse",
Expand Down
Loading