@@ -128,6 +128,7 @@ def plot_cap(
128
128
model ,
129
129
idata ,
130
130
covariates ,
131
+ target = "mean" ,
131
132
use_hdi = True ,
132
133
hdi_prob = None ,
133
134
transforms = None ,
@@ -152,6 +153,8 @@ def plot_cap(
152
153
and the third is mapped to different plot panels.
153
154
If a dictionary, keys must be taken from ("horizontal", "color", "panel") and the values
154
155
are the names of the variables.
156
+ target : str
157
+ Which model parameter to plot. Defaults to 'mean'.
155
158
use_hdi : bool, optional
156
159
Whether to compute the highest density interval (defaults to True) or the quantiles.
157
160
hdi_prob : float, optional
@@ -203,11 +206,11 @@ def plot_cap(
203
206
response_name = get_aliased_name (model .response_component .response_term )
204
207
response_transform = transforms .get (response_name , identity )
205
208
206
- y_hat = response_transform (idata .posterior [f"{ response_name } _mean " ])
209
+ y_hat = response_transform (idata .posterior [f"{ response_name } _ { target } " ])
207
210
y_hat_mean = y_hat .mean (("chain" , "draw" ))
208
211
209
212
if use_hdi :
210
- y_hat_bounds = az .hdi (y_hat , hdi_prob )[f"{ response_name } _mean " ].T
213
+ y_hat_bounds = az .hdi (y_hat , hdi_prob )[f"{ response_name } _ { target } " ].T
211
214
else :
212
215
lower_bound = round ((1 - hdi_prob ) / 2 , 4 )
213
216
upper_bound = 1 - lower_bound
@@ -222,7 +225,6 @@ def plot_cap(
222
225
axes = np .atleast_1d (axes )
223
226
else :
224
227
axes = np .atleast_1d (ax )
225
- print (axes )
226
228
if isinstance (axes [0 ], np .ndarray ):
227
229
fig = axes [0 ][0 ].get_figure ()
228
230
else :
@@ -238,8 +240,9 @@ def plot_cap(
238
240
else :
239
241
raise ValueError ("Main covariate must be numeric or categoric." )
240
242
243
+ ylabel = response_name if target == "mean" else target
241
244
for ax in axes .ravel (): # pylint: disable = redefined-argument-from-local
242
- ax .set (xlabel = main , ylabel = response_name )
245
+ ax .set (xlabel = main , ylabel = ylabel )
243
246
244
247
return fig , axes
245
248
0 commit comments