pointwise elpd diagnostics (text formatting and plot) #678

Merged
merged 27 commits into from
Jun 9, 2019

Conversation

OriolAbril
Member

@OriolAbril OriolAbril commented May 23, 2019

PR to improve the output of waic and loo to make it more verbose and to create the pointwise elpd comparison plot. Will fix #496 and will fix #660.

This first version of the plot already works, but there is still much work needed.

Plots

  • coloring?
    image

Data from the toy model in this notebook. This was generated with color="river_distance" and legend=True.

  • tick labels?
    image

  • highlight worst points?
    image
    Added a threshold argument to show the labels of points that lie further away than threshold times elpd_i.std(). In the example, threshold=1 to force the labels to show.

  • move common coloring/labeling data functions to plot_utils.py to have the same functionality in plot_elpd and plot_khat
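The threshold idea from the list above can be sketched with NumPy. This is only an illustration: the data are made up, `flag_outliers` is a hypothetical helper, and measuring the distance from the mean of the pointwise difference is an assumption about what "further away" means here.

```python
import numpy as np


def flag_outliers(elpd_diff, threshold):
    """Boolean mask of points whose pointwise elpd difference lies more than
    `threshold` standard deviations from the mean difference.
    (Hypothetical helper; the centering on the mean is an assumption.)"""
    elpd_diff = np.asarray(elpd_diff, dtype=float)
    return np.abs(elpd_diff - elpd_diff.mean()) > threshold * elpd_diff.std()


# Made-up pointwise elpd differences between two models.
diff = np.array([0.1, -0.2, 0.0, 0.15, -3.0, 0.05])
mask = flag_outliers(diff, threshold=2)
labels_to_show = np.flatnonzero(mask)  # indices of the points that would get a label
```

A smaller threshold flags more points, which is why a low value forces labels to appear in a small example.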

Stats functions

Both

  • check behaviour on more than 3D items (not only chain, draw and one extra dim).

Everything works after some corrections. I am still debating between reshaping the pointwise loo/waic back to the original shape and working with multi-index objects from there on (the scatter plot is eventually flattened, so they are needed at some point).

  • Optimize API: find a way to call waic and loo only once per InferenceData object ideally.

Idea: Store pointwise loo, waic and pareto_k as dataarrays in the ELPDData object instead of as flattened arrays, then call plot_elpd with a dict of ELPDData objects and plot_khat with an ELPDData object instead of the array of Pareto shape values. Using dataarrays will allow coloring, tick labels, selection of a subset of observations and so on. Thanks to the overwritten __str__ method, including this extra information in the ELPDData object won't clutter the relevant info when printed.

Contributor

@ahartikainen ahartikainen left a comment

Added some comments

doc/api.rst (outdated)
arviz/plots/__init__.py (outdated)
arviz/plots/pointwiseelpdplot.py (outdated)
arviz/stats/stats.py (outdated)
log_likelihood = log_likelihood.stack(samples=("chain", "draw"))
shape = log_likelihood.shape
n_samples = shape[-1]
n_data_points = np.product(shape[:-1])
Contributor

What if loglikelihood is nD?

Member Author

It still works (I checked locally with the same model as in the first plot, and the test for this passed too). I decided not to stack the observation dimensions in order to keep this information in the pointwise loo and waic as well.
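To illustrate why the snippet above does not depend on the number of observation dimensions, here is a NumPy-only sketch with made-up data. The dimension names in the comments are hypothetical, and `np.prod` is used instead of the `np.product` alias from the snippet, since that alias has been removed from recent NumPy releases.

```python
import numpy as np

# Made-up log-likelihood with dims (chain, draw, school, student).
n_chain, n_draw = 4, 50
log_like = np.random.default_rng(0).normal(size=(n_chain, n_draw, 3, 5))

# Merge chain and draw into a single "samples" axis and move it to the end,
# leaving the observation dims (3, 5) untouched.
stacked = np.moveaxis(log_like.reshape(n_chain * n_draw, 3, 5), 0, -1)

shape = stacked.shape                      # observation dims first, samples last
n_samples = shape[-1]                      # chain * draw
n_data_points = int(np.prod(shape[:-1]))   # product over however many obs dims exist
```

Only the last axis is treated as samples, so adding or removing observation dimensions changes `shape[:-1]` but not the logic.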

Member Author

I think the only breaking change is the shape of the psislw input, which in the 2D case (samples and observations) is now transposed from what it used to be. For dataarray inputs this should not matter (as long as the dimension is called "samples"), but for array inputs it does. It could be changed back to the original shape, calling np.rollaxis to move the n_samples axis to the last position (as expected by xarray's apply_ufunc).

@OriolAbril
Member Author

Ready for review!

I can't decide on the y-axis ticks in the case of more than 2 models. Right now the axes are not shared, so the ticks on the left do not correspond to the values in the other plots. I came up with two options, but neither really convinces me. The first is to share the axes (or force the ylims to be common throughout each row) and the second is to put yticklabels on all plots.

@OriolAbril
Member Author

PR summary

This PR addresses the elpd statistics, loo and waic, and creates a plot to compare several models according to their pointwise elpd statistics. It adds a completely new plot to ArviZ and modifies several stats functions. The main goal of the changes to the stats functions is to use xarray for the computation, which makes it easy to work with nd objects while keeping the coords information. Eventually, plots like plot_elpd or plot_khat (on which I will work once this PR is merged) can use this coords information to automatically get relevant labels, use a color code generated from the coord values, or plot only a subgroup of the observations.

It also adds the class ELPDData to ArviZ, but as it is a subclass of pd.Series, its only effect is when printing the result of loo or waic. Everything else stays the same.

Modified functions

loo and waic

Their arguments stay the same. The computation now uses wrap_xarray_ufunc, so the result of the pointwise calculations is a dataarray (instead of an array) that keeps the original shape (instead of being flattened). Therefore, the output of loo().loo_i is now a dataarray instead of an array. In addition, the output is now an ELPDData object, so that it is more verbose, organized and relevant when printed (for example, in the case of waic, waic_i is never printed, but it is stored if pointwise is true).
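As a rough illustration of pointwise values that keep the observation shape, the sketch below computes a pointwise WAIC with plain NumPy. It uses the textbook definition (log pointwise predictive density minus the variance of the log-likelihood over samples) and is not the ArviZ implementation; `pointwise_waic` and the data are made up for the example.

```python
import numpy as np


def pointwise_waic(log_like):
    """Pointwise elpd_waic for an array shaped (..., n_samples).

    Textbook definition lppd_i - p_waic_i; a sketch, not ArviZ's code.
    """
    n_samples = log_like.shape[-1]
    # Numerically stable log-mean-exp over the samples axis.
    max_ll = log_like.max(axis=-1, keepdims=True)
    lppd_i = (
        max_ll[..., 0]
        + np.log(np.exp(log_like - max_ll).sum(axis=-1))
        - np.log(n_samples)
    )
    # Penalty term: variance of the log-likelihood over samples
    # (population variance here; some references use the sample variance).
    p_waic_i = log_like.var(axis=-1)
    return lppd_i - p_waic_i


rng = np.random.default_rng(1)
log_like = rng.normal(loc=-1.0, size=(3, 5, 200))  # made-up (obs dims..., samples)
waic_i = pointwise_waic(log_like)  # one value per observation, obs dims preserved
elpd_waic = waic_i.sum()           # scalar summary over all observations
```

Because only the last axis is reduced, the pointwise result has the same layout as the observed data, which is what makes coords-based labeling possible later on.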

psislw

It has also been modified to work by wrapping ufuncs on xarray or array objects. The output should have the same type and shape as the input.

However, for convenience I have changed the dimension order of the input data. It currently supports only 2D inputs of shape (n_samples, n_observations); this PR proposes that it work on nd arguments of shape (..., n_samples), which in the 2D case is the transpose.
For dataarray objects this has little relevance, because the "samples" dimension is automatically moved to the last position by apply_ufunc, but for arrays it matters.
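For array inputs, converting from the old layout to the proposed one is a single axis move. The sketch below uses `np.moveaxis` on made-up arrays; the array names are illustrative only.

```python
import numpy as np

# Old convention: 2D array shaped (n_samples, n_observations).
old_style = np.arange(12.0).reshape(4, 3)  # 4 samples, 3 observations

# Proposed convention: samples on the last axis.
new_style = np.moveaxis(old_style, 0, -1)

# The same call handles nd inputs with the samples axis first.
old_nd = np.zeros((4, 3, 5))               # (n_samples, obs dim 1, obs dim 2)
new_nd = np.moveaxis(old_nd, 0, -1)
```

In the 2D case the result equals `old_style.T`, matching the "transposed" remark above.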

Created

ELPDData

Class to contain the result of loo and waic. It is a subclass of pd.Series whose only difference is that the __str__ and __repr__ methods are overwritten.
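The pattern described here can be sketched as follows. `ELPDSummary`, its field names and its output format are hypothetical stand-ins chosen for the example; they do not reproduce the actual ELPDData class.

```python
import pandas as pd


class ELPDSummary(pd.Series):
    """Hypothetical stand-in for ELPDData: a Series with custom printing."""

    def __str__(self):
        # Print only the headline numbers; bulky pointwise values stay hidden.
        return "Computed from {} samples\nelpd: {:.2f} (SE {:.2f})".format(
            self["n_samples"], self["elpd"], self["se"]
        )

    def __repr__(self):
        return self.__str__()


result = ELPDSummary(
    [-30.5, 1.2, 2000, [-0.4, -0.6]],
    index=["elpd", "se", "n_samples", "elpd_i"],
)
text = str(result)         # short summary without the pointwise values
pointwise = result["elpd_i"]  # still accessible like any Series entry
```

Everything except printing behaves like a regular Series, so code that indexes the result by name keeps working.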

plot_elpd

Plot to compare different models according to waic_i or loo_i differences. It accepts a dictionary of InferenceData objects (for which the ic is computed before plotting) or a dictionary of ELPDData objects (which avoids recalculating the ic every time plot_elpd is called). For more examples of the various parameters, see the first comment.
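The quantity being plotted can be sketched without any plotting code: one pointwise difference per pair of models in the dictionary. The model names and values below are made up, and the sign convention and panel layout used by plot_elpd are not stated here, so treat both as assumptions.

```python
from itertools import combinations

import numpy as np

# Made-up pointwise elpd values for three models over the same 4 observations.
pointwise = {
    "pooled": np.array([-1.0, -2.0, -1.5, -0.5]),
    "hierarchical": np.array([-0.8, -1.0, -1.6, -0.5]),
    "unpooled": np.array([-1.2, -1.1, -1.4, -0.9]),
}

# One pointwise difference per unordered pair of models.
pairwise_diff = {
    (a, b): pointwise[a] - pointwise[b]
    for a, b in combinations(pointwise, 2)
}
```

With more than two models the number of pairs grows quickly, which is where the question about shared y-axis ticks comes from.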

@OriolAbril
Member Author

On the testing side, tests have already been included, covering nearly 100% of the lines and taking into account multidimensional objects.

ufunc_kwargs=ufunc_kwargs,
**kwargs
).values
)
Contributor

Any speedup obtained via this ufunc?

Member Author

None. The computations performed are the same, I have not modified how logsumexp works. I did this so that there is no need to convert to arrays and the named dims and coords info is not lost.
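A minimal sketch of the point about keeping named dims and coords, using `xarray.apply_ufunc` directly rather than ArviZ's wrapper. The data, the `school` coordinate and the `log_mean_exp` helper are made up for the example.

```python
import numpy as np
import xarray as xr

# Made-up log-likelihood with named dims and coords.
rng = np.random.default_rng(2)
log_like = xr.DataArray(
    rng.normal(size=(2, 100, 3)),
    dims=("chain", "draw", "school"),
    coords={"school": ["a", "b", "c"]},
)

stacked = log_like.stack(samples=("chain", "draw"))


def log_mean_exp(values, axis=-1):
    """Numerically stable log of the mean of exp(values) along `axis`."""
    max_val = values.max(axis=axis, keepdims=True)
    summed = np.exp(values - max_val).sum(axis=axis)
    return np.squeeze(max_val, axis=axis) + np.log(summed) - np.log(values.shape[axis])


# apply_ufunc moves the core dim "samples" to the last axis before calling the function.
lppd_i = xr.apply_ufunc(log_mean_exp, stacked, input_core_dims=[["samples"]])
```

The result is still a DataArray indexed by `school`, so no conversion to a bare array (and no loss of labels) happens along the way.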

@ahartikainen ahartikainen changed the title [WIP] pointwise elpd diagnostics (text formatting and plot) pointwise elpd diagnostics (text formatting and plot) Jun 9, 2019
@ahartikainen
Contributor

LGTM

I checked the code, and everything looks good.

@aloctavodia aloctavodia merged commit a3b1c78 into arviz-devs:master Jun 9, 2019
@aloctavodia
Contributor

Thanks @OriolAbril, this looks great. Looking forward to trying these changes!

@OriolAbril OriolAbril deleted the pointwise-elpd branch June 10, 2019 08:24
Development

Successfully merging this pull request may close these issues.

plot pointwise WAIC/LOO
Improve loo functionality
5 participants