Skip to content

Commit

Permalink
Add support for discrete variables in rank plots (#1433)
Browse files Browse the repository at this point in the history
* fix rank plots for discrete variables

* add test, add compute_ranks function and update changelog

* fix docstring
  • Loading branch information
aloctavodia committed Nov 9, 2020
1 parent 44ceb08 commit 63759c2
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* Add `ref_line`, `bar`, `vlines` and `marker_vlines` kwargs to `plot_rank` ([1419](https://github.com/arviz-devs/arviz/pull/1419))
* Add observed argument to (un)plot observed data in `plot_ppc` ([1422](https://github.com/arviz-devs/arviz/pull/1422))
* Add support for named dims and coordinates with multivariate observations ([1429](https://github.com/arviz-devs/arviz/pull/1429))
* Add support for discrete variables in rank plots ([1433](https://github.com/arviz-devs/arviz/pull/1433))
* Add skipna argument to `plot_posterior` ([1432](https://github.com/arviz-devs/arviz/pull/1432))
* Make stacking the default method to compute weights in `compare` ([1438](https://github.com/arviz-devs/arviz/pull/1438))

Expand Down
6 changes: 3 additions & 3 deletions arviz/plots/backends/bokeh/rankplot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Bokeh rankplot."""
import numpy as np
import scipy.stats

from bokeh.models import Span
from bokeh.models.annotations import Title
from bokeh.models.tickers import FixedTicker

from ....stats.density_utils import histogram
from ...plot_utils import _scale_fig_size, make_label
from ...plot_utils import _scale_fig_size, make_label, compute_ranks
from .. import show_layout
from . import backend_kwarg_defaults, create_axes_grid

Expand Down Expand Up @@ -74,7 +74,7 @@ def plot_rank(
for ax, (var_name, selection, var_data) in zip(
(item for item in axes.flatten() if item is not None), plotters
):
ranks = scipy.stats.rankdata(var_data, method="average").reshape(var_data.shape)
ranks = compute_ranks(var_data)
bin_ary = np.histogram_bin_edges(ranks, bins=bins, range=(0, ranks.size))
all_counts = np.empty((len(ranks), len(bin_ary) - 1))
for idx, row in enumerate(ranks):
Expand Down
5 changes: 2 additions & 3 deletions arviz/plots/backends/matplotlib/rankplot.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""Matplotlib rankplot."""
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats

from ....stats.density_utils import histogram
from ...plot_utils import _scale_fig_size, make_label
from ...plot_utils import _scale_fig_size, make_label, compute_ranks
from . import backend_kwarg_defaults, backend_show, create_axes_grid


Expand Down Expand Up @@ -66,7 +65,7 @@ def plot_rank(
)

for ax, (var_name, selection, var_data) in zip(np.ravel(axes), plotters):
ranks = scipy.stats.rankdata(var_data, method="average").reshape(var_data.shape)
ranks = compute_ranks(var_data)
bin_ary = np.histogram_bin_edges(ranks, bins=bins, range=(0, ranks.size))
all_counts = np.empty((len(ranks), len(bin_ary) - 1))
for idx, row in enumerate(ranks):
Expand Down
18 changes: 17 additions & 1 deletion arviz/plots/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import packaging
import xarray as xr
from matplotlib.colors import to_hex
from scipy.stats import mode
from scipy.stats import mode, rankdata
from scipy.interpolate import CubicSpline


from ..rcparams import rcParams
from ..stats.density_utils import kde
Expand Down Expand Up @@ -659,3 +661,17 @@ def set_bokeh_circular_ticks_labels(ax, hist, labels):
)

return ax


def compute_ranks(ary):
    """Compute ranks for continuous and discrete variables.

    Integer-valued (discrete) data typically contains many repeated values.
    To spread those ties over distinct ranks, integer input is first passed
    through a cubic spline evaluated at slightly shifted positions, and the
    resulting floating point values are ranked. Non-integer input is ranked
    directly.

    Parameters
    ----------
    ary : numpy.ndarray
        Values to rank. Ranks are computed over the flattened array.

    Returns
    -------
    numpy.ndarray
        Ranks (starting at 1, ties given their average rank), with the same
        shape as ``ary``.
    """
    # "i" and "u" are the signed/unsigned integer dtype kinds; both are discrete.
    # Empty and single-element arrays have nothing to interpolate.
    if ary.dtype.kind in "iu" and ary.size > 1:
        flat = ary.flatten()
        low, high = float(flat.min()), float(flat.max())
        # A constant array gives a degenerate (non-increasing) grid, which
        # CubicSpline rejects; in that case fall through and rank the ties as is.
        if low < high:
            n_values = flat.size
            knots = np.linspace(low, high, n_values)
            spline = CubicSpline(knots, flat.astype(float))
            # Evaluate just inside the knot range so repeated values get
            # slightly different interpolated values.
            shifted = np.linspace(low + 0.001, high - 0.001, n_values)
            ary = spline(shifted).reshape(ary.shape)
    ranks = rankdata(ary, method="average").reshape(ary.shape)

    return ranks
18 changes: 18 additions & 0 deletions arviz/tests/base_tests/test_plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
vectorized_to_hex,
xarray_to_ndarray,
xarray_var_iter,
compute_ranks,
)
from ...rcparams import rc_context
from ...stats.density_utils import get_bins
Expand Down Expand Up @@ -322,3 +323,20 @@ def test_set_bokeh_circular_ticks_labels():
assert len(renderers) == 3
assert renderers[2].data_source.data["text"] == labels
assert len(renderers[0].data_source.data["start_angle"]) == len(labels)


def test_compute_ranks():
    """Check ``compute_ranks`` against known ranks for discrete and continuous data."""
    # Discrete (Poisson-like) draws containing repeated values.
    discrete_draws = np.array([[5, 4, 1, 4, 0], [2, 8, 2, 1, 1]])
    discrete_expected = np.array([[9.0, 7.0, 3.0, 8.0, 1.0], [5.0, 10.0, 6.0, 2.0, 4.0]])
    np.testing.assert_equal(compute_ranks(discrete_draws), discrete_expected)

    # Continuous (normal-like) draws with no ties.
    continuous_draws = np.array(
        [
            [0.2644187, -1.3004813, -0.80428456, 1.01319068, 0.62631143],
            [1.34498018, -0.13428933, -0.69855487, -0.9498981, -0.34074092],
        ]
    )
    continuous_expected = np.array([[7.0, 1.0, 3.0, 9.0, 8.0], [10.0, 6.0, 4.0, 2.0, 5.0]])
    np.testing.assert_equal(compute_ranks(continuous_draws), continuous_expected)

0 comments on commit 63759c2

Please sign in to comment.