Skip to content

Commit

Permalink
remove ess_mean and ess_sd from summary (arviz-devs#1539)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored and utkarsh-maheshwari committed May 27, 2021
1 parent a3d5d41 commit 83db3d4
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 28 deletions.
4 changes: 1 addition & 3 deletions arviz/stats/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ def _multichain_statistics(ary):
"""
ary = np.atleast_2d(ary)
if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
return np.nan, np.nan, np.nan, np.nan, np.nan
# ess mean
ess_mean_value = _ess_mean(ary)

Expand Down Expand Up @@ -1033,8 +1033,6 @@ def _multichain_statistics(ary):
return (
mcse_mean_value,
mcse_sd_value,
ess_mean_value,
ess_sd_value,
ess_bulk_value,
ess_tail_value,
rhat_value,
Expand Down
18 changes: 5 additions & 13 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,11 +1261,11 @@ def summary(
circ_hdi_higher = circ_hdi.sel(hdi="higher", drop=True)

if kind in ["all", "diagnostics"]:
mcse_mean, mcse_sd, ess_mean, ess_sd, ess_bulk, ess_tail, r_hat = xr.apply_ufunc(
_make_ufunc(_multichain_statistics, n_output=7, ravel=False),
mcse_mean, mcse_sd, ess_bulk, ess_tail, r_hat = xr.apply_ufunc(
_make_ufunc(_multichain_statistics, n_output=5, ravel=False),
dataset,
input_core_dims=(("chain", "draw"),),
output_core_dims=tuple([] for _ in range(7)),
output_core_dims=tuple([] for _ in range(5)),
)

# Combine metrics
Expand All @@ -1280,8 +1280,6 @@ def summary(
f"hdi_{100 * (1 - alpha / 2):g}%",
"mcse_mean",
"mcse_sd",
"ess_mean",
"ess_sd",
"ess_bulk",
"ess_tail",
"r_hat",
Expand All @@ -1294,8 +1292,6 @@ def summary(
hdi_higher,
mcse_mean,
mcse_sd,
ess_mean,
ess_sd,
ess_bulk,
ess_tail,
r_hat,
Expand All @@ -1304,7 +1300,7 @@ def summary(
metrics_ = (mean, sd, hdi_lower, hdi_higher)
metrics_names_ = metrics_names_[:4]
elif kind == "diagnostics":
metrics_ = (mcse_mean, mcse_sd, ess_mean, ess_sd, ess_bulk, ess_tail, r_hat)
metrics_ = (mcse_mean, mcse_sd, ess_bulk, ess_tail, r_hat)
metrics_names_ = metrics_names_[4:]
metrics.extend(metrics_)
metric_names.extend(metrics_names_)
Expand Down Expand Up @@ -1358,11 +1354,7 @@ def summary(
elif round_to not in ("None", "none") and (fmt.lower() in ("long", "wide")):
# Don't round xarray object by default (even with "none")
decimals = {
col: 3
if col not in {"ess_mean", "ess_sd", "ess_bulk", "ess_tail", "r_hat"}
else 2
if col == "r_hat"
else 0
col: 3 if col not in {"ess_bulk", "ess_tail", "r_hat"} else 2 if col == "r_hat" else 0
for col in summary_df.columns
}
summary_df = summary_df.round(decimals)
Expand Down
10 changes: 0 additions & 10 deletions arviz/tests/base_tests/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,16 +426,12 @@ def test_multichain_summary_array(self, draws, chains):

mcse_mean_hat = mcse(ary, method="mean")
mcse_sd_hat = mcse(ary, method="sd")
ess_mean_hat = ess(ary, method="mean")
ess_sd_hat = ess(ary, method="sd")
ess_bulk_hat = ess(ary, method="bulk")
ess_tail_hat = ess(ary, method="tail")
rhat_hat = _rhat_rank(ary)
(
mcse_mean_hat_,
mcse_sd_hat_,
ess_mean_hat_,
ess_sd_hat_,
ess_bulk_hat_,
ess_tail_hat_,
rhat_hat_,
Expand All @@ -445,8 +441,6 @@ def test_multichain_summary_array(self, draws, chains):
(
mcse_mean_hat,
mcse_sd_hat,
ess_mean_hat,
ess_sd_hat,
ess_bulk_hat,
ess_tail_hat,
rhat_hat,
Expand All @@ -456,8 +450,6 @@ def test_multichain_summary_array(self, draws, chains):
(
mcse_mean_hat_,
mcse_sd_hat_,
ess_mean_hat_,
ess_sd_hat_,
ess_bulk_hat_,
ess_tail_hat_,
rhat_hat_,
Expand All @@ -466,8 +458,6 @@ def test_multichain_summary_array(self, draws, chains):
else:
assert_almost_equal(mcse_mean_hat, mcse_mean_hat_)
assert_almost_equal(mcse_sd_hat, mcse_sd_hat_)
assert_almost_equal(ess_mean_hat, ess_mean_hat_)
assert_almost_equal(ess_sd_hat, ess_sd_hat_)
assert_almost_equal(ess_bulk_hat, ess_bulk_hat_)
assert_almost_equal(ess_tail_hat, ess_tail_hat_)
if chains in (None, 1):
Expand Down
2 changes: 0 additions & 2 deletions arviz/tests/base_tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,6 @@ def test_summary_wrong_group(centered_eight):
"hdi_97%",
"mcse_mean",
"mcse_sd",
"ess_mean",
"ess_sd",
"ess_bulk",
"ess_tail",
"r_hat",
Expand Down

0 comments on commit 83db3d4

Please sign in to comment.