Skip to content

Commit

Permalink
use interpolation instead of randomization
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Sep 22, 2020
1 parent 46a24eb commit 0bba9a5
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 44 deletions.
32 changes: 10 additions & 22 deletions arviz/plots/backends/bokeh/bpvplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def plot_bpv(
bpv,
plot_mean,
reference,
mse,
n_ref,
hdi_prob,
color,
Expand Down Expand Up @@ -88,17 +89,9 @@ def plot_bpv(
pp_vals = pp_vals.reshape(total_pp_samples, -1)

if kind == "p_value":
if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
for i in range(n_ref):
obs_vals = obs_vals + np.random.uniform(-0.01, 0.01, size=obs_vals.shape)
pp_vals = pp_vals + np.random.uniform(-0.01, 0.01, size=pp_vals.shape)
tstat_pit = np.mean(pp_vals <= obs_vals, axis=1)
x_s, tstat_pit_dens = kde(tstat_pit)
ax_i.line(x_s, tstat_pit_dens, line_width=linewidth, line_color=color)
else:
tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
x_s, tstat_pit_dens = kde(tstat_pit)
ax_i.line(x_s, tstat_pit_dens, line_width=linewidth, line_color=color)
tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
x_s, tstat_pit_dens = kde(tstat_pit)
ax_i.line(x_s, tstat_pit_dens, line_width=linewidth, line_color=color)
if reference is not None:
dist = stats.beta(obs_vals.size / 2, obs_vals.size / 2)
if reference == "analytical":
Expand All @@ -120,17 +113,9 @@ def plot_bpv(
)

elif kind == "u_value":
if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
for i in range(n_ref):
obs_vals = obs_vals + np.random.uniform(-0.01, 0.01, size=obs_vals.shape)
pp_vals = pp_vals + np.random.uniform(-0.01, 0.01, size=pp_vals.shape)
tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
x_s, tstat_pit_dens = kde(tstat_pit)
ax_i.plot(x_s, tstat_pit_dens, color=color)
else:
tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
x_s, tstat_pit_dens = kde(tstat_pit)
ax_i.plot(x_s, tstat_pit_dens, color=color)
tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
x_s, tstat_pit_dens = kde(tstat_pit)
ax_i.line(x_s, tstat_pit_dens, color=color)
if reference is not None:
if reference == "analytical":
n_obs = obs_vals.size
Expand All @@ -151,6 +136,9 @@ def plot_bpv(
x_ss, u_dens = sample_reference_distribution(dist, (tstat_pit_dens.size, n_ref))
for x_ss_i, u_dens_i in zip(x_ss.T, u_dens.T):
ax_i.line(x_ss_i, u_dens_i, line_width=linewidth, **plot_ref_kwargs)
if mse:
ax_i.line(0, 0, legend_label=f"mse={np.mean((1 - tstat_pit_dens)**2) * 100:.2f}")

ax_i.line(0, 0)
else:
if t_stat in ["mean", "median", "std"]:
Expand Down
42 changes: 20 additions & 22 deletions arviz/plots/backends/matplotlib/bpvplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from scipy.interpolate import CubicSpline

from ....stats.density_utils import kde
from ...kdeplot import plot_kde
Expand All @@ -27,6 +28,7 @@ def plot_bpv(
bpv,
plot_mean,
reference,
mse,
n_ref,
hdi_prob,
color,
Expand Down Expand Up @@ -86,18 +88,19 @@ def plot_bpv(
obs_vals = obs_vals.flatten()
pp_vals = pp_vals.reshape(total_pp_samples, -1)

if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
x = np.linspace(0, 1, len(obs_vals))
csi = CubicSpline(x, obs_vals)
obs_vals = csi(np.linspace(0.001, 0.999, len(obs_vals)))

x = np.linspace(0, 1, len(pp_vals))
csi = CubicSpline(x, pp_vals)
pp_vals = csi(np.linspace(0.001, 0.999, len(pp_vals)))

if kind == "p_value":
if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
for i in range(n_ref):
obs_vals = obs_vals + np.random.uniform(-0.01, 0.01, size=obs_vals.shape)
pp_vals = pp_vals + np.random.uniform(-0.01, 0.01, size=pp_vals.shape)
tstat_pit = np.mean(pp_vals <= obs_vals, axis=1)
x_s, tstat_pit_dens = kde(tstat_pit)
ax_i.plot(x_s, tstat_pit_dens, linewidth=linewidth, color=color)
else:
tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
x_s, tstat_pit_dens = kde(tstat_pit)
ax_i.plot(x_s, tstat_pit_dens, linewidth=linewidth, color=color)
tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
x_s, tstat_pit_dens = kde(tstat_pit)
ax_i.plot(x_s, tstat_pit_dens, linewidth=linewidth, color=color)
ax_i.set_yticks([])
if reference is not None:
dist = stats.beta(obs_vals.size / 2, obs_vals.size / 2)
Expand All @@ -118,17 +121,9 @@ def plot_bpv(
ax_i.plot(x_ss, u_dens, linewidth=linewidth, **plot_ref_kwargs)

elif kind == "u_value":
if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
for i in range(n_ref):
obs_vals = obs_vals + np.random.uniform(-0.01, 0.01, size=obs_vals.shape)
pp_vals = pp_vals + np.random.uniform(-0.01, 0.01, size=pp_vals.shape)
tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
x_s, tstat_pit_dens = kde(tstat_pit)
ax_i.plot(x_s, tstat_pit_dens, color=color)
else:
tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
x_s, tstat_pit_dens = kde(tstat_pit)
ax_i.plot(x_s, tstat_pit_dens, color=color)
tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
x_s, tstat_pit_dens = kde(tstat_pit)
ax_i.plot(x_s, tstat_pit_dens, color=color)
if reference is not None:
if reference == "analytical":
n_obs = obs_vals.size
Expand All @@ -140,6 +135,9 @@ def plot_bpv(
dist = stats.uniform(0, 1)
x_ss, u_dens = sample_reference_distribution(dist, (tstat_pit_dens.size, n_ref))
ax_i.plot(x_ss, u_dens, linewidth=linewidth, **plot_ref_kwargs)
if mse:
ax_i.plot(0, 0, label=f"mse={np.mean((1 - tstat_pit_dens)**2) * 100:.2f}")
ax_i.legend()

ax_i.set_ylim(0, None)
ax_i.set_xlim(0, 1)
Expand Down
5 changes: 5 additions & 0 deletions arviz/plots/bpvplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def plot_bpv(
bpv=True,
plot_mean=True,
reference="analytical",
mse=False,
n_ref=100,
hdi_prob=0.94,
color="C0",
Expand Down Expand Up @@ -62,6 +63,9 @@ def plot_bpv(
How to compute the distributions used as reference for u_values or p_values. Allowed values
are "analytical" (default) and "samples". Use `None` to not plot any reference.
Defaults to "analytical".
mse : bool, optional
Show the scaled mean squared error between the uniform distribution and the marginal
p_value distribution. Defaults to False.
n_ref : int, optional
Number of reference distributions to sample when `reference=samples`. Defaults to 100.
hdi_prob: float, optional
Expand Down Expand Up @@ -245,6 +249,7 @@ def plot_bpv(
bpv=bpv,
t_stat=t_stat,
reference=reference,
mse=mse,
n_ref=n_ref,
hdi_prob=hdi_prob,
plot_mean=plot_mean,
Expand Down

0 comments on commit 0bba9a5

Please sign in to comment.