From 63759c2b12447967fa8bd1fe540e3cff5448cf34 Mon Sep 17 00:00:00 2001 From: Osvaldo Martin Date: Mon, 9 Nov 2020 08:21:37 -0300 Subject: [PATCH] Add support for discrete variables in rank plots (#1433) * fix rank plots for discrete variables * add test, add compute_ranks function and update changelog * fix docstring --- CHANGELOG.md | 1 + arviz/plots/backends/bokeh/rankplot.py | 6 +++--- arviz/plots/backends/matplotlib/rankplot.py | 5 ++--- arviz/plots/plot_utils.py | 18 +++++++++++++++++- arviz/tests/base_tests/test_plot_utils.py | 18 ++++++++++++++++++ 5 files changed, 41 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1206dcd6a7..982d08120d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ * Add `ref_line`, `bar`, `vlines` and `marker_vlines` kwargs to `plot_rank` ([1419](https://github.com/arviz-devs/arviz/pull/1419)) * Add observed argument to (un)plot observed data in `plot_ppc` ([1422](https://github.com/arviz-devs/arviz/pull/1422)) * Add support for named dims and coordinates with multivariate observations ([1429](https://github.com/arviz-devs/arviz/pull/1429)) +* Add support for discrete variables in rank plots ([1433](https://github.com/arviz-devs/arviz/pull/1433)) * Add skipna argument to `plot_posterior` ([1432](https://github.com/arviz-devs/arviz/pull/1432)) * Make stacking the default method to compute weights in `compare` ([1438](https://github.com/arviz-devs/arviz/pull/1438)) diff --git a/arviz/plots/backends/bokeh/rankplot.py b/arviz/plots/backends/bokeh/rankplot.py index 9b287bb214..bb7fb1dd8d 100644 --- a/arviz/plots/backends/bokeh/rankplot.py +++ b/arviz/plots/backends/bokeh/rankplot.py @@ -1,12 +1,12 @@ """Bokeh rankplot.""" import numpy as np -import scipy.stats + from bokeh.models import Span from bokeh.models.annotations import Title from bokeh.models.tickers import FixedTicker from ....stats.density_utils import histogram -from ...plot_utils import _scale_fig_size, make_label +from ...plot_utils 
import _scale_fig_size, make_label, compute_ranks from .. import show_layout from . import backend_kwarg_defaults, create_axes_grid @@ -74,7 +74,7 @@ def plot_rank( for ax, (var_name, selection, var_data) in zip( (item for item in axes.flatten() if item is not None), plotters ): - ranks = scipy.stats.rankdata(var_data, method="average").reshape(var_data.shape) + ranks = compute_ranks(var_data) bin_ary = np.histogram_bin_edges(ranks, bins=bins, range=(0, ranks.size)) all_counts = np.empty((len(ranks), len(bin_ary) - 1)) for idx, row in enumerate(ranks): diff --git a/arviz/plots/backends/matplotlib/rankplot.py b/arviz/plots/backends/matplotlib/rankplot.py index 363f94e0b1..00642d1e14 100644 --- a/arviz/plots/backends/matplotlib/rankplot.py +++ b/arviz/plots/backends/matplotlib/rankplot.py @@ -1,10 +1,9 @@ """Matplotlib rankplot.""" import matplotlib.pyplot as plt import numpy as np -import scipy.stats from ....stats.density_utils import histogram -from ...plot_utils import _scale_fig_size, make_label +from ...plot_utils import _scale_fig_size, make_label, compute_ranks from . 
def compute_ranks(ary):
    """Compute ranks for continuous and discrete variables.

    Integer-valued arrays are first smoothed: the flattened values are fitted
    with a cubic spline over an evenly spaced grid spanning ``[min, max]`` and
    re-evaluated on a grid shrunk by 0.001 at each end. The resulting float
    values are then ranked. Arrays of any other dtype are ranked as they are.

    Parameters
    ----------
    ary : numpy.ndarray
        Values to rank. Any shape is accepted.

    Returns
    -------
    numpy.ndarray
        Ranks computed over the flattened array with ``method="average"``,
        reshaped to the shape of ``ary``.
    """
    # Signed and unsigned integers are both discrete; empty input has no
    # minimum/maximum, so it goes straight to ``rankdata``.
    if ary.dtype.kind in "iu" and ary.size > 0:
        ary_shape = ary.shape
        flat = ary.ravel()
        min_ary, max_ary = flat.min(), flat.max()
        # A constant array would produce a non-increasing spline grid, which
        # CubicSpline rejects. Every value is tied anyway, so rank it directly.
        if min_ary != max_ary:
            x = np.linspace(min_ary, max_ary, flat.size)
            csi = CubicSpline(x, flat)
            # Evaluate slightly inside the fitted range to turn the integer
            # values into floats before ranking.
            ary = csi(np.linspace(min_ary + 0.001, max_ary - 0.001, flat.size)).reshape(
                ary_shape
            )
    ranks = rankdata(ary, method="average").reshape(ary.shape)

    return ranks
def test_compute_ranks():
    """Check ``compute_ranks`` against fixed ranks for integer and float input."""
    # Integer-valued input with repeated values.
    discrete_draws = np.array([[5, 4, 1, 4, 0], [2, 8, 2, 1, 1]])
    discrete_expected = np.array([[9.0, 7.0, 3.0, 8.0, 1.0], [5.0, 10.0, 6.0, 2.0, 4.0]])
    np.testing.assert_equal(compute_ranks(discrete_draws), discrete_expected)

    # Float-valued input with no repeated values.
    first_row = [0.2644187, -1.3004813, -0.80428456, 1.01319068, 0.62631143]
    second_row = [1.34498018, -0.13428933, -0.69855487, -0.9498981, -0.34074092]
    continuous_draws = np.array([first_row, second_row])
    continuous_expected = np.array([[7.0, 1.0, 3.0, 9.0, 8.0], [10.0, 6.0, 4.0, 2.0, 5.0]])
    np.testing.assert_equal(compute_ranks(continuous_draws), continuous_expected)