Skip to content

Commit

Permalink
redefine jitter so limit works as no jitter
Browse files Browse the repository at this point in the history
  • Loading branch information
ahartikainen committed Jan 13, 2019
1 parent f2396f6 commit 9b16165
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions arviz/plots/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def plot_ppc(
else:
alpha = 0.2

assert jitter >= 0

observed = data.observed_data
posterior_predictive = data.posterior_predictive

Expand Down Expand Up @@ -228,6 +230,7 @@ def plot_ppc(
plot_kwargs={"color": "k", "linewidth": linewidth, "zorder": 3},
fill_kwargs={"alpha": 0},
ax=ax,
legend=legend,
)
else:
nbins = round(len(obs_vals) ** 0.5)
Expand Down Expand Up @@ -356,10 +359,11 @@ def plot_ppc(
"color": "C0",
"linestyle": "--",
"linewidth": linewidth,
"zorder": 2,
"zorder": 3,
},
label="Posterior predictive mean {}".format(pp_var_name),
ax=ax,
legend=legend,
)
else:
vals = pp_vals.flatten()
Expand All @@ -378,22 +382,14 @@ def plot_ppc(
)

limit = ax.get_ylim()[1] * 1.05
if jitter:
y_rows = np.linspace(0, limit, num_pp_samples + 1)
else:
y_rows = np.r_[0.0, np.linspace(0, limit, num_pp_samples)]

jitter_scale = (y_rows[2] - y_rows[1]) / 2
scale_low = -jitter_scale
scale_high = jitter_scale
y_rows = np.linspace(0, limit, num_pp_samples + 1)
jitter_scale = y_rows[1] - y_rows[0]
scale_low = 0
scale_high = jitter_scale * jitter

obs_yvals = np.zeros_like(obs_vals)
if jitter:
obs_yvals = (
y_rows[1]
+ np.random.uniform(low=scale_low, high=scale_high, size=len(obs_vals)) * jitter
)
else:
obs_yvals = np.zeros_like(obs_vals)
obs_yvals += np.random.uniform(low=scale_low, high=scale_high, size=len(obs_vals))
ax.plot(
obs_vals,
obs_yvals,
Expand All @@ -402,16 +398,14 @@ def plot_ppc(
markersize=markersize,
alpha=alpha,
label="Observed {}".format(var_name),
zorder=3,
)

for vals, y in zip(pp_sampled_vals, y_rows[2:]):
for vals, y in zip(pp_sampled_vals, y_rows[1:]):
vals = np.array([vals]).flatten()
yvals = np.full(len(vals), y)
if jitter:
yvals = [y] * len(vals) + np.random.uniform(
low=scale_low, high=scale_high, size=len(vals)
) * jitter
else:
yvals = [y] * len(vals)
yvals += np.random.uniform(low=scale_low, high=scale_high, size=len(vals))
ax.plot(vals, yvals, "o", zorder=1, color="C5", markersize=markersize, alpha=alpha)
ax.scatter([], [], color="C5", label="Posterior predictive {}".format(pp_var_name))

Expand Down

0 comments on commit 9b16165

Please sign in to comment.