Skip to content

Commit

Permalink
Refactor levels and hdi_levels to hdi_probs
Browse files Browse the repository at this point in the history
This commit will refactor the `levels` argument of `plot_kde()` and the
`hdi_levels` argument of `_find_hdi_contours()` to `hdi_probs` for both,
with the hope that this provides an interface consistent with that given
for `arviz.hdi()`.
  • Loading branch information
wm1995 committed Apr 21, 2021
1 parent a9745b3 commit cf10583
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 24 deletions.
16 changes: 7 additions & 9 deletions arviz/plots/kdeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def plot_kde(
quantiles=None,
rotated=False,
contour=True,
levels=None,
hdi_probs=None,
fill_last=False,
figsize=None,
textsize=None,
Expand Down Expand Up @@ -75,7 +75,7 @@ def plot_kde(
contour : bool
If True plot the 2D KDE using contours, otherwise plot a smooth 2D KDE.
Defaults to True.
levels : list
hdi_probs : list
Confidence levels for highest density (2-dimensional) interval contours of a 2D KDE.
fill_last : bool
If True fill the last contour of the 2D KDE plot. Defaults to False.
Expand Down Expand Up @@ -264,15 +264,13 @@ def plot_kde(
gridsize = (128, 128) if contour else (256, 256)
density, xmin, xmax, ymin, ymax = _fast_kde_2d(values, values2, gridsize=gridsize)

if levels is not None:
# Check hdi levels are within bounds (0, 1)
if min(levels) <= 0 or max(levels) >= 1:
raise ValueError(
"Highest density interval confidence levels must be between 0 and 1"
)
if hdi_probs is not None:
# Check hdi probs are within bounds (0, 1)
if min(hdi_probs) <= 0 or max(hdi_probs) >= 1:
raise ValueError("Highest density interval probabilities must be between 0 and 1")

# Calculate contour levels and sort for matplotlib
contour_levels = _find_hdi_contours(density, levels)
contour_levels = _find_hdi_contours(density, hdi_probs)
contour_levels.sort()

contour_level_list = [0] + list(contour_levels) + [density.max()]
Expand Down
14 changes: 7 additions & 7 deletions arviz/stats/density_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,31 +1038,31 @@ def histogram(data, bins, range_hist=None):
return hist, hist_dens, bin_edges


def _find_hdi_contours(density, hdi_levels):
def _find_hdi_contours(density, hdi_probs):
"""
Find contours enclosing regions of highest posterior density.
Parameters
----------
density : array-like
A 2D KDE on a grid with cells of equal area.
hdi_levels : array-like
An array of highest density interval confidence levels.
hdi_probs : array-like
An array of highest density interval confidence probabilities.
Returns
-------
contour_levels : array
The contour levels corresponding to the given HDI levels.
The contour levels corresponding to the given HDI probabilities.
"""
# Using the algorithm from corner.py
sorted_density = np.sort(density, axis=None)[::-1]
sm = sorted_density.cumsum()
sm /= sm[-1]

contours = np.empty_like(hdi_levels)
for idx, hdi_level in enumerate(hdi_levels):
contours = np.empty_like(hdi_probs)
for idx, hdi_prob in enumerate(hdi_probs):
try:
contours[idx] = sorted_density[sm <= hdi_level][-1]
contours[idx] = sorted_density[sm <= hdi_prob][-1]
except IndexError:
contours[idx] = sorted_density[0]

Expand Down
16 changes: 8 additions & 8 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,9 @@ def test_plot_joint_bad(models):
{"is_circular": "radians"},
{"is_circular": "degrees"},
{"adaptive": True},
{"levels": [0.3, 0.9, 0.6]},
{"levels": [0.3, 0.6, 0.9], "contourf_kwargs": {"cmap": "Blues"}},
{"levels": [0.9, 0.6, 0.3], "contour_kwargs": {"alpha": 0}},
{"hdi_probs": [0.3, 0.9, 0.6]},
{"hdi_probs": [0.3, 0.6, 0.9], "contourf_kwargs": {"cmap": "Blues"}},
{"hdi_probs": [0.9, 0.6, 0.3], "contour_kwargs": {"alpha": 0}},
],
)
def test_plot_kde(continuous_model, kwargs):
Expand All @@ -410,13 +410,13 @@ def test_plot_kde(continuous_model, kwargs):
@pytest.mark.parametrize(
"kwargs",
[
{"levels": [1, 2, 3]},
{"levels": [-0.3, 0.6, 0.9]},
{"levels": [0, 0.3, 0.6]},
{"levels": [0.3, 0.6, 1]},
{"hdi_probs": [1, 2, 3]},
{"hdi_probs": [-0.3, 0.6, 0.9]},
{"hdi_probs": [0, 0.3, 0.6]},
{"hdi_probs": [0.3, 0.6, 1]},
],
)
def test_plot_kde_levels_bad(continuous_model, kwargs):
def test_plot_kde_hdi_probs_bad(continuous_model, kwargs):
"""Ensure invalid hdi probabilities are rejected."""
with pytest.raises(ValueError):
plot_kde(continuous_model["x"], continuous_model["y"], **kwargs)
Expand Down

0 comments on commit cf10583

Please sign in to comment.