Skip to content

Commit

Permalink
clean, add test and example
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Nov 3, 2022
1 parent 7be2121 commit 241d586
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
## v0.x.x Unreleased

### New features
- Adds Savage-Dickey density ratio plot for Bayes factor approximation. ([2037](https://github.com/arviz-devs/arviz/pull/2037), [2152](https://github.com/arviz-devs/arviz/pull/2152))


### Maintenance and fixes
- Fix dimension ordering for `plot_trace` with divergences ([2151](https://github.com/arviz-devs/arviz/pull/2151))
Expand Down
3 changes: 2 additions & 1 deletion arviz/plots/backends/matplotlib/bfplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,10 @@ def plot_bf(

ax.plot(ref_val, posterior_at_ref_val, "ko", lw=1.5)
ax.plot(ref_val, prior_at_ref_val, "ko", lw=1.5)
ax.axvline(ref_val, color="k", ls="--")
ax.set_xlabel(var_name)
ax.set_ylabel("Density")
ax.set_title(f"The Bayes Factor 10 is {bf_10:.2f}\nThe Bayes Factor 01 is {bf_01:.2f}")
ax.set_title(f"The BF_10 is {bf_10:.2f}\nThe BF_01 is {bf_01:.2f}")
plt.legend()

if backend_show(show):
Expand Down
28 changes: 22 additions & 6 deletions arviz/plots/bfplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ def plot_bf(
):
"""
Bayes Factor approximated as the Savage-Dickey density ratio.
The Bayes factor is estimated by comparing a model
against a model in which the parameter of interest has been restricted to a point-null.
The Bayes factor is estimated by comparing a model (H1) against a model in which the
parameter of interest has been restricted to be a point-null (H0). This computation
assumes the models are nested and thus H0 is a special case of H1.
Parameters
-----------
Expand Down Expand Up @@ -70,18 +74,30 @@ def plot_bf(
dict : A dictionary with BF10 (Bayes Factor 10 (H1/H0 ratio), and BF01 (H0/H1 ratio).
axes : matplotlib axes or bokeh figures
Examples
--------
TBN
Moderate evidence indicating that the parameter "a" is different from zero
.. plot::
:context: close-figs
>>> import numpy as np
>>> import arviz as az
>>> idata = az.from_dict(posterior={"a":np.random.normal(1, 0.5, 5000)},
... prior={"a":np.random.normal(0, 1, 5000)})
>>> az.plot_bf(idata, var_name="a", ref_val=0)
"""
posterior = extract(idata, var_names=var_name)

if ref_val > posterior.max() or ref_val < posterior.min():
raise ValueError("Reference value is out of bounds of posterior")
_log.warning(
"The reference value is outside of the posterior. "
"This translate into infinite support for H1, which is most likely an overstatement."
)

if posterior.ndim > 1:
_log.info("Posterior distribution has {posterior.ndim} dimensions")
_log.warning("Posterior distribution has {posterior.ndim} dimensions")

if prior is None:
prior = extract(idata, var_names=var_name, group="prior")
Expand All @@ -102,7 +118,7 @@ def plot_bf(
posterior_at_ref_val = (posterior == ref_val).mean()
prior_at_ref_val = (prior == ref_val).mean()

bf_10 = posterior_at_ref_val / prior_at_ref_val
bf_10 = prior_at_ref_val / posterior_at_ref_val
bf_01 = 1 / bf_10

bfplot_kwargs = dict(
Expand Down
11 changes: 11 additions & 0 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ...plots import (
plot_autocorr,
plot_bpv,
plot_bf,
plot_compare,
plot_density,
plot_dist,
Expand Down Expand Up @@ -1956,3 +1957,13 @@ def test_plot_ts_valueerror(multidim_models, val_err_kwargs):
idata2 = multidim_models.model_1
with pytest.raises(ValueError):
plot_ts(idata=idata2, y="y", **val_err_kwargs)


def test_plot_bf():
idata = from_dict(
posterior={"a": np.random.normal(1, 0.5, 5000)}, prior={"a": np.random.normal(0, 1, 5000)}
)
bf_dict0, _ = plot_bf(idata, var_name="a", ref_val=0)
bf_dict1, _ = plot_bf(idata, prior=np.random.normal(0, 10, 5000), var_name="a", ref_val=0)
assert bf_dict0["BF10"] > bf_dict0["BF01"]
assert bf_dict1["BF10"] < bf_dict1["BF01"]

0 comments on commit 241d586

Please sign in to comment.