Skip to content

Commit 9ba92e1

Browse files
plot_cap() can have different model parameters as output (#627)
* plot_cap gains a new argument 'target' * Add seed to test and update changelog * Update tests/test_plots.py Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com> * Update tests/test_plots.py Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com> * Update tests/test_plots.py Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com> * remove old 'print' * fix indentation Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com>
1 parent 6fec04f commit 9ba92e1

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

Changelog.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
* Refactored the codebase to support distributional models (#607)
88
* Added a default method to handle posterior predictive sampling for custom families (#625)
9+
* `plot_cap()` gains a new argument `target` that allows to plot different parameters of the response distribution (#627)
910

1011
### Maintenance and fixes
1112

bambi/plots/plot_cap.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def plot_cap(
128128
model,
129129
idata,
130130
covariates,
131+
target="mean",
131132
use_hdi=True,
132133
hdi_prob=None,
133134
transforms=None,
@@ -152,6 +153,8 @@ def plot_cap(
152153
and the third is mapped to different plot panels.
153154
If a dictionary, keys must be taken from ("horizontal", "color", "panel") and the values
154155
are the names of the variables.
156+
target : str
157+
Which model parameter to plot. Defaults to 'mean'.
155158
use_hdi : bool, optional
156159
Whether to compute the highest density interval (defaults to True) or the quantiles.
157160
hdi_prob : float, optional
@@ -203,11 +206,11 @@ def plot_cap(
203206
response_name = get_aliased_name(model.response_component.response_term)
204207
response_transform = transforms.get(response_name, identity)
205208

206-
y_hat = response_transform(idata.posterior[f"{response_name}_mean"])
209+
y_hat = response_transform(idata.posterior[f"{response_name}_{target}"])
207210
y_hat_mean = y_hat.mean(("chain", "draw"))
208211

209212
if use_hdi:
210-
y_hat_bounds = az.hdi(y_hat, hdi_prob)[f"{response_name}_mean"].T
213+
y_hat_bounds = az.hdi(y_hat, hdi_prob)[f"{response_name}_{target}"].T
211214
else:
212215
lower_bound = round((1 - hdi_prob) / 2, 4)
213216
upper_bound = 1 - lower_bound
@@ -222,7 +225,6 @@ def plot_cap(
222225
axes = np.atleast_1d(axes)
223226
else:
224227
axes = np.atleast_1d(ax)
225-
print(axes)
226228
if isinstance(axes[0], np.ndarray):
227229
fig = axes[0][0].get_figure()
228230
else:
@@ -238,8 +240,9 @@ def plot_cap(
238240
else:
239241
raise ValueError("Main covariate must be numeric or categoric.")
240242

243+
ylabel = response_name if target == "mean" else target
241244
for ax in axes.ravel(): # pylint: disable = redefined-argument-from-local
242-
ax.set(xlabel=main, ylabel=response_name)
245+
ax.set(xlabel=main, ylabel=ylabel)
243246

244247
return fig, axes
245248

tests/test_plots.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import matplotlib.pyplot as plt
66
import pytest
77

8-
from bambi.models import Model
8+
from bambi.models import Model, Formula
99
from bambi.plots import plot_cap
1010

1111

@@ -150,3 +150,23 @@ def test_transforms(mtcars):
150150

151151
transforms = {"mpg": np.log, "hp": np.log}
152152
plot_cap(model, idata, ["hp"], transforms=transforms)
153+
154+
155+
def test_multiple_outputs():
156+
"""Test plot cap default and specified values for target argument"""
157+
rng = np.random.default_rng(121195)
158+
N = 200
159+
a, b = 0.5, 1.1
160+
x = rng.uniform(-1.5, 1.5, N)
161+
shape = np.exp(0.3 + x * 0.5 + rng.normal(scale=0.1, size=N))
162+
y = rng.gamma(shape, np.exp(a + b * x) / shape, N)
163+
data_gamma = pd.DataFrame({"x": x, "y": y})
164+
165+
166+
formula = Formula("y ~ x", "alpha ~ x")
167+
model = Model(formula, data_gamma, family="gamma")
168+
idata = model.fit(tune=100, draws=100, random_seed=1234)
169+
# Test default target
170+
plot_cap(model, idata, "x")
171+
# Test user supplied target argument
172+
plot_cap(model, idata, "x", "alpha")

0 commit comments

Comments
 (0)