diff --git a/arviz/stats/diagnostics.py b/arviz/stats/diagnostics.py index 83e8e14d47..a7778953ec 100644 --- a/arviz/stats/diagnostics.py +++ b/arviz/stats/diagnostics.py @@ -994,7 +994,7 @@ def _multichain_statistics(ary): """ ary = np.atleast_2d(ary) if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)): - return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan + return np.nan, np.nan, np.nan, np.nan, np.nan # ess mean ess_mean_value = _ess_mean(ary) @@ -1033,8 +1033,6 @@ def _multichain_statistics(ary): return ( mcse_mean_value, mcse_sd_value, - ess_mean_value, - ess_sd_value, ess_bulk_value, ess_tail_value, rhat_value, diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 1df477f98b..f9494a34a4 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -1261,11 +1261,11 @@ def summary( circ_hdi_higher = circ_hdi.sel(hdi="higher", drop=True) if kind in ["all", "diagnostics"]: - mcse_mean, mcse_sd, ess_mean, ess_sd, ess_bulk, ess_tail, r_hat = xr.apply_ufunc( - _make_ufunc(_multichain_statistics, n_output=7, ravel=False), + mcse_mean, mcse_sd, ess_bulk, ess_tail, r_hat = xr.apply_ufunc( + _make_ufunc(_multichain_statistics, n_output=5, ravel=False), dataset, input_core_dims=(("chain", "draw"),), - output_core_dims=tuple([] for _ in range(7)), + output_core_dims=tuple([] for _ in range(5)), ) # Combine metrics @@ -1280,8 +1280,6 @@ def summary( f"hdi_{100 * (1 - alpha / 2):g}%", "mcse_mean", "mcse_sd", - "ess_mean", - "ess_sd", "ess_bulk", "ess_tail", "r_hat", @@ -1294,8 +1292,6 @@ def summary( hdi_higher, mcse_mean, mcse_sd, - ess_mean, - ess_sd, ess_bulk, ess_tail, r_hat, @@ -1304,7 +1300,7 @@ def summary( metrics_ = (mean, sd, hdi_lower, hdi_higher) metrics_names_ = metrics_names_[:4] elif kind == "diagnostics": - metrics_ = (mcse_mean, mcse_sd, ess_mean, ess_sd, ess_bulk, ess_tail, r_hat) + metrics_ = (mcse_mean, mcse_sd, ess_bulk, ess_tail, r_hat) metrics_names_ = metrics_names_[4:] metrics.extend(metrics_) metric_names.extend(metrics_names_) @@ -1358,11 +1354,7 @@ def summary( elif round_to not in ("None", "none") and (fmt.lower() in ("long", "wide")): # Don't round xarray object by default (even with "none") decimals = { - col: 3 - if col not in {"ess_mean", "ess_sd", "ess_bulk", "ess_tail", "r_hat"} - else 2 - if col == "r_hat" - else 0 + col: 3 if col not in {"ess_bulk", "ess_tail", "r_hat"} else 2 if col == "r_hat" else 0 for col in summary_df.columns } summary_df = summary_df.round(decimals) diff --git a/arviz/tests/base_tests/test_diagnostics.py b/arviz/tests/base_tests/test_diagnostics.py index a843fd8356..135cd02e4b 100644 --- a/arviz/tests/base_tests/test_diagnostics.py +++ b/arviz/tests/base_tests/test_diagnostics.py @@ -426,16 +426,12 @@ def test_multichain_summary_array(self, draws, chains): mcse_mean_hat = mcse(ary, method="mean") mcse_sd_hat = mcse(ary, method="sd") - ess_mean_hat = ess(ary, method="mean") - ess_sd_hat = ess(ary, method="sd") ess_bulk_hat = ess(ary, method="bulk") ess_tail_hat = ess(ary, method="tail") rhat_hat = _rhat_rank(ary) ( mcse_mean_hat_, mcse_sd_hat_, - ess_mean_hat_, - ess_sd_hat_, ess_bulk_hat_, ess_tail_hat_, rhat_hat_, @@ -445,8 +441,6 @@ def test_multichain_summary_array(self, draws, chains): ( mcse_mean_hat, mcse_sd_hat, - ess_mean_hat, - ess_sd_hat, ess_bulk_hat, ess_tail_hat, rhat_hat, @@ -456,8 +450,6 @@ def test_multichain_summary_array(self, draws, chains): ( mcse_mean_hat_, mcse_sd_hat_, - ess_mean_hat_, - ess_sd_hat_, ess_bulk_hat_, ess_tail_hat_, rhat_hat_, @@ -466,8 +458,6 @@ def test_multichain_summary_array(self, draws, chains): else: assert_almost_equal(mcse_mean_hat, mcse_mean_hat_) assert_almost_equal(mcse_sd_hat, mcse_sd_hat_) - assert_almost_equal(ess_mean_hat, ess_mean_hat_) - assert_almost_equal(ess_sd_hat, ess_sd_hat_) assert_almost_equal(ess_bulk_hat, ess_bulk_hat_) assert_almost_equal(ess_tail_hat, ess_tail_hat_) if chains in (None, 1): diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index 34a1c9a7a2..4416ed7a70 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -240,8 +240,6 @@ def test_summary_wrong_group(centered_eight): "hdi_97%", "mcse_mean", "mcse_sd", - "ess_mean", - "ess_sd", "ess_bulk", "ess_tail", "r_hat",