Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add contours of highest posterior density for 2D KDE plot #1665

Merged
merged 33 commits into from
May 13, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
d49e526
Add a function to calculate HPD regions on 2D KDEs
wm1995 Apr 15, 2021
435d746
Add option to plot HPD regions to `plot_kde`
wm1995 Apr 15, 2021
9257f2b
Change `hpd_levels` to `levels` in `plot_kde`
wm1995 Apr 15, 2021
dcab869
Refactor `hpd_levels` to `hdi_levels` in helper fn
wm1995 Apr 15, 2021
4ff80d1
Refactor HDI contour fn to `_find_hdi_contours`
wm1995 Apr 15, 2021
cc8e72e
Update contour-finding to use corner.py algorithm
wm1995 Apr 16, 2021
23320e5
Increase clarity in `plot_kde` contour-finding
wm1995 Apr 16, 2021
19e8275
Add unit test for `_find_hdi_contours`
wm1995 Apr 16, 2021
45ef939
Add unit tests for `levels` keyword of `plot_kde`
wm1995 Apr 16, 2021
23757a8
Make minor stylistic changes for pylint compliance
wm1995 Apr 21, 2021
633899d
Update `_find_hdi_contours` docstring
wm1995 Apr 21, 2021
9a9239a
Update `plot_kde` docstring
wm1995 Apr 21, 2021
a9745b3
Update new unit tests for mypy compliance
wm1995 Apr 21, 2021
cf10583
Refactor `levels` and `hdi_levels` to `hdi_probs`
wm1995 Apr 21, 2021
7b5a53e
Add an example of the HDI contour plot interface
wm1995 Apr 21, 2021
1ddcacb
Add example to `plot_kde` docstring
wm1995 Apr 21, 2021
a5bbd61
Update CHANGELOG.md to add HDI contour feature
wm1995 Apr 22, 2021
3731a30
Update `plot_pair` to pass `kde_kwargs` by value
wm1995 Apr 22, 2021
d352b2f
Add example of KDE pair plot with HDI contours
wm1995 Apr 22, 2021
0fc524e
Update `kde_plot` HDI contour example docstring
wm1995 Apr 22, 2021
5a2d2df
Update docstring for `plot_kde`
wm1995 Apr 22, 2021
4aef438
Merge branch 'upstream/main' into hpd-contours
wm1995 Apr 22, 2021
9352c12
Add tests to catch bug in `plot_kde` with Bokeh
wm1995 Apr 22, 2021
6f2d912
Update docstring for `hdi_probs` arg in `plot_kde`
wm1995 Apr 23, 2021
014e0b7
Add unit tests for HDI contours with Bokeh backend
wm1995 Apr 23, 2021
9526274
Update title of existing Bokeh 2d KDE example
wm1995 Apr 23, 2021
14b126d
Add new examples for 2d KDE plots with Bokeh
wm1995 Apr 23, 2021
4b25e52
Update Bokeh `plot_pair` to deepcopy `kde_kwargs`
wm1995 Apr 24, 2021
3b56be1
Update example 2d KDE plot titles for clarity
wm1995 Apr 25, 2021
5595173
Update `plot_kde` to use `hdi_probs` over `levels`
wm1995 Apr 25, 2021
cf9be3f
Add unit tests to test `plot_kde` warnings
wm1995 Apr 25, 2021
2d70070
Update code to be Pylint-2.8.1 compliant
wm1995 Apr 25, 2021
1c496ac
Merge branch 'main' into hpd-contours
wm1995 Apr 27, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion arviz/plots/kdeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ..data import InferenceData
from ..rcparams import rcParams
from ..stats.density_utils import _fast_kde_2d, kde
from ..stats.density_utils import _fast_kde_2d, kde, _find_hdi_contours
from .plot_utils import get_plotting_function


Expand All @@ -20,6 +20,7 @@ def plot_kde(
quantiles=None,
rotated=False,
contour=True,
levels=None,
fill_last=False,
figsize=None,
textsize=None,
Expand Down Expand Up @@ -74,6 +75,8 @@ def plot_kde(
contour : bool
If True plot the 2D KDE using contours, otherwise plot a smooth 2D KDE.
Defaults to True.
levels : list
Confidence levels for highest density interval contours of a 2D KDE.
wm1995 marked this conversation as resolved.
Show resolved Hide resolved
fill_last : bool
If True fill the last contour of the 2D KDE plot. Defaults to False.
figsize : tuple
Expand Down Expand Up @@ -261,6 +264,29 @@ def plot_kde(
gridsize = (128, 128) if contour else (256, 256)
density, xmin, xmax, ymin, ymax = _fast_kde_2d(values, values2, gridsize=gridsize)

if levels is not None:
# Check hdi levels are within bounds (0, 1)
if min(levels) <= 0 or max(levels) >= 1:
raise ValueError(
"Highest density interval confidence levels must be between 0 and 1"
)

# Calculate contour levels and sort for matplotlib
contour_levels = _find_hdi_contours(density, levels)
contour_levels.sort()

contour_level_list = [0] + list(contour_levels) + [density.max()]

# Add keyword arguments to contour, contourf
if contour_kwargs is None:
contour_kwargs = {"levels": contour_level_list}
else:
contour_kwargs.setdefault("levels", contour_level_list)
OriolAbril marked this conversation as resolved.
Show resolved Hide resolved
if contourf_kwargs is None:
contourf_kwargs = {"levels": contour_level_list}
else:
contourf_kwargs.setdefault("levels", contour_level_list)

lower, upper, density_q = [None] * 3

kde_plot_args = dict(
Expand Down
31 changes: 31 additions & 0 deletions arviz/stats/density_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,3 +1036,34 @@ def histogram(data, bins, range_hist=None):
hist, bin_edges = np.histogram(data, bins=bins, range=range_hist)
hist_dens = hist / (hist.sum() * np.diff(bin_edges))
return hist, hist_dens, bin_edges


def _find_hdi_contours(density, hdi_levels):
"""
Find contours enclosing regions of highest posterior density.

Parameters
----------
density : array-like
A gridded 2D KDE.
wm1995 marked this conversation as resolved.
Show resolved Hide resolved
hdi_levels : array-like
An array of highest density interval confidence levels.

Returns
-------
contour_levels : array
The contour levels corresponding to the given HDI levels.
"""
# Using the algorithm from corner.py
sorted_density = np.sort(density, axis=None)[::-1]
sm = sorted_density.cumsum()
sm /= sm[-1]

contours = np.empty_like(hdi_levels)
for idx, hdi_level in enumerate(hdi_levels):
try:
contours[idx] = sorted_density[sm <= hdi_level][-1]
except IndexError:
contours[idx] = sorted_density[0]

return contours
17 changes: 17 additions & 0 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,9 @@ def test_plot_joint_bad(models):
{"is_circular": "radians"},
{"is_circular": "degrees"},
{"adaptive": True},
{"levels": [0.3, 0.9, 0.6]},
{"levels": [0.3, 0.6, 0.9], "contourf_kwargs": {"cmap": "Blues"}},
{"levels": [0.9, 0.6, 0.3], "contour_kwargs": {"alpha": 0}},
],
)
def test_plot_kde(continuous_model, kwargs):
Expand All @@ -404,6 +407,20 @@ def test_plot_kde(continuous_model, kwargs):
assert axes is axes1


@pytest.mark.parametrize(
"kwargs",
[
{"levels": [1, 2, 3]},
{"levels": [-0.3, 0.6, 0.9]},
{"levels": [0, 0.3, 0.6]},
{"levels": [0.3, 0.6, 1]},
],
)
def test_plot_kde_levels_bad(continuous_model, kwargs):
with pytest.raises(ValueError):
plot_kde(continuous_model["x"], continuous_model["y"], **kwargs)


@pytest.mark.parametrize("shape", [(8,), (8, 8), (8, 8, 8)])
def test_cov(shape):
x = np.random.randn(*shape)
Expand Down
47 changes: 46 additions & 1 deletion arviz/tests/base_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import scipy.stats as st

from ...data import dict_to_dataset, from_dict, load_arviz_data
from ...stats.density_utils import _circular_mean, _normalize_angle
from ...stats.density_utils import _circular_mean, _normalize_angle, _find_hdi_contours
from ...utils import (
_stack,
_subset_list,
Expand Down Expand Up @@ -291,3 +291,48 @@ def test_normalize_angle(mean):

values = _normalize_angle(rvs, zero_centered=False)
assert ((values >= 0) & (values <= 2 * np.pi)).all()


@pytest.mark.parametrize("mean", [[0, 0], [1, 1]])
@pytest.mark.parametrize(
"cov",
[
np.diag([1, 1]),
np.diag([0.5, 0.5]),
np.diag([0.25, 1]),
np.array([[0.4, 0.2], [0.2, 0.8]]),
],
)
@pytest.mark.parametrize("contour_sigma", [np.array([1, 2, 3])])
def test_find_hdi_contours(mean, cov, contour_sigma):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit worried that this unit test is overly complicated, but I wasn't sure how best to approach it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think is ok.

"""Test `_find_hdi_contours()` against SciPy's multivariate normal distribution."""
# Set up scipy distribution
rv = st.multivariate_normal(mean, cov)

# Find standard deviations and eigenvectors
eigenvals, eigenvecs = np.linalg.eig(cov)
eigenvecs = eigenvecs.T
stdevs = np.sqrt(eigenvals)

# Find min and max for grid at 7-sigma contour
extremes = np.empty((4, 2))
for i in range(4):
extremes[i] = mean + (-1) ** i * 7 * stdevs[i // 2] * eigenvecs[i // 2]
x_min, y_min = np.amin(extremes, axis=0)
x_max, y_max = np.amax(extremes, axis=0)

# Create 256x256 grid
x = np.linspace(x_min, x_max, 256)
y = np.linspace(y_min, y_max, 256)
grid = np.dstack(np.meshgrid(x, y))

density = rv.pdf(grid)

contour_sp = np.empty(contour_sigma.shape)
for idx, sigma in enumerate(contour_sigma):
contour_sp[idx] = rv.pdf(mean + sigma * stdevs[0] * eigenvecs[0])

hdi_probs = 1 - np.exp(-0.5 * contour_sigma ** 2)
contour_az = _find_hdi_contours(density, hdi_probs)

np.testing.assert_allclose(contour_sp, contour_az, rtol=1e-2, atol=1e-4)