-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6231cf9
commit 522baa7
Showing
5 changed files
with
261 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
import numbers | ||
from typing import Optional, Union | ||
|
||
import matplotlib as mpl | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import numpy.typing as npt | ||
|
||
from model_diagnostics._utils._array import ( | ||
get_array_min_max, | ||
get_second_dimension, | ||
get_sorted_array_names, | ||
length_of_second_dimension, | ||
) | ||
|
||
from .scoring import ElementaryScore | ||
|
||
|
||
def plot_murphy_diagram( | ||
y_obs: npt.ArrayLike, | ||
y_pred: npt.ArrayLike, | ||
weights: Optional[npt.ArrayLike] = None, | ||
*, | ||
etas: Union[int, npt.ArrayLike] = 100, | ||
functional: str = "mean", | ||
level: float = 0.5, | ||
ax: Optional[mpl.axes.Axes] = None, | ||
): | ||
r"""Plot a Murphy diagram. | ||
A reliability diagram or calibration curve assesses auto-calibration. It plots the | ||
conditional expectation given the predictions `E(y_obs|y_pred)` (y-axis) vs the | ||
predictions `y_pred` (x-axis). | ||
The conditional expectation is estimated via isotonic regression (PAV algorithm) | ||
of `y_obs` on `y_pred`. | ||
See Notes for further details. | ||
Parameters | ||
---------- | ||
y_obs : array-like of shape (n_obs) | ||
Observed values of the response variable. | ||
For binary classification, y_obs is expected to be in the interval [0, 1]. | ||
y_pred : array-like of shape (n_obs) or (n_obs, n_models) | ||
Predicted values of the conditional expectation of Y, `E(Y|X)`. | ||
weights : array-like of shape (n_obs) or None | ||
Case weights. | ||
etas : int or array-like | ||
If an integer is given, equidistant points between min and max y values are | ||
generater. If an array-like is given, those points are used. | ||
functional : str | ||
The functional that is induced by the identification function `V`. Options are: | ||
- `"mean"`. Argument `level` is neglected. | ||
- `"median"`. Argument `level` is neglected. | ||
- `"expectile"` | ||
- `"quantile"` | ||
level : float | ||
The level of the expectile of quantile. (Often called \(\alpha\).) | ||
It must be `0 < level < 1`. | ||
`level=0.5` and `functional="expectile"` gives the mean. | ||
`level=0.5` and `functional="quantile"` gives the median. | ||
ax : matplotlib.axes.Axes | ||
Axes object to draw the plot onto, otherwise uses the current Axes. | ||
Returns | ||
------- | ||
ax | ||
Notes | ||
----- | ||
The expectation conditional on the predictions is \(E(Y|y_{pred})\). This object is | ||
estimated by the pool-adjacent violator (PAV) algorithm, which has very desirable | ||
properties: | ||
- It is non-parametric without any tuning parameter. Thus, the results are | ||
easily reproducible. | ||
- Optimal selection of bins | ||
- Statistical consistent estimator | ||
For details, refer to [Dimitriadis2021]. | ||
References | ||
---------- | ||
`[Ehm2015]` | ||
: W. Ehm, T. Gneiting, A. Jordan, F. Krüger. | ||
"Of Quantiles and Expectiles: Consistent Scoring Functions, Choquet | ||
Representations, and Forecast Rankings". | ||
[arxiv:1503.08195](https://arxiv.org/abs/1503.08195). | ||
""" | ||
if ax is None: | ||
ax = plt.gca() | ||
|
||
if (n_cols := length_of_second_dimension(y_obs)) > 0: | ||
if n_cols == 1: | ||
y_obs = get_second_dimension(y_obs, 0) | ||
else: | ||
msg = ( | ||
f"Array-like y_obs has more than 2 dimensions, y_obs.shape[1]={n_cols}" | ||
) | ||
raise ValueError(msg) | ||
|
||
y_pred_min, y_pred_max = get_array_min_max(y_pred) | ||
y_obs_min, y_obs_max = get_array_min_max(y_obs) | ||
y_min, y_max = min(y_pred_min, y_obs_min), max(y_pred_max, y_obs_max) | ||
|
||
if y_min == y_max: | ||
msg = "All values y_obs and y_pred are one single and same value." | ||
raise ValueError(msg) | ||
elif isinstance(etas, numbers.Integral): | ||
etas = np.linspace(y_min, y_max, num=etas, endpoint=True) | ||
else: | ||
etas = np.asarray(etas).astype(float) | ||
if etas.ndim > 1: | ||
etas = etas.reshape(max(etas.shape)) | ||
|
||
def elementary_score(y_obs, y_pred, weights, eta): | ||
sf = ElementaryScore(eta, functional=functional, level=level) | ||
return sf(y_obs=y_obs, y_pred=y_pred, weights=weights) | ||
|
||
n_pred = length_of_second_dimension(y_pred) | ||
pred_names, _ = get_sorted_array_names(y_pred) | ||
|
||
for i in range(len(pred_names)): | ||
y_pred_i = y_pred if n_pred == 0 else get_second_dimension(y_pred, i) | ||
|
||
y_plot = [ | ||
elementary_score(y_obs=y_obs, y_pred=y_pred_i, weights=weights, eta=eta) | ||
for eta in etas | ||
] | ||
label = pred_names[i] if n_pred >= 2 else None | ||
ax.plot(etas, y_plot, label=label) | ||
|
||
title = "Murphy Diagram" | ||
ax.set(xlabel="eta", ylabel="score") | ||
|
||
if n_pred >= 2: | ||
ax.set_title(title) | ||
ax.legend() | ||
else: | ||
y_pred_i = y_pred if n_pred == 0 else get_second_dimension(y_pred, i) | ||
if len(pred_names[0]) > 0: | ||
ax.set_title(title + " " + pred_names[0]) | ||
else: | ||
ax.set_title(title) | ||
|
||
return ax |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
from sklearn.datasets import make_classification | ||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.model_selection import train_test_split | ||
|
||
from model_diagnostics.scoring import plot_murphy_diagram | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("param", "value", "msg"), | ||
[ | ||
("etas", [[1, 2], [3, 4]], "cannot reshape array of size 4 into shape"), | ||
( | ||
"y", | ||
[[1, 1], [1, 1]], | ||
"All values y_obs and y_pred are one single and same value", | ||
), | ||
], | ||
) | ||
def test_plot_murphy_diagram_raises(param, value, msg): | ||
"""Test that plot_murphy_diagram raises errors.""" | ||
if param == "y": | ||
y_obs, y_pred = value[0], value[1] | ||
kwargs = {} | ||
else: | ||
y_obs = [0, 1] | ||
y_pred = [-1, 1] | ||
kwargs = {param: value} | ||
with pytest.raises(ValueError, match=msg): | ||
plot_murphy_diagram(y_obs=y_obs, y_pred=y_pred, **kwargs) | ||
|
||
|
||
def test_plot_murphy_diagram_raises_y_obs_multdim(): | ||
"""Test that plot_murphy_diagram raises errors for y_obs.ndim > 1.""" | ||
y_obs = [[0], [1]] | ||
y_pred = [-1, 1] | ||
plot_murphy_diagram(y_obs=y_obs, y_pred=y_pred) | ||
y_obs = [[0, 1], [1, 2]] | ||
with pytest.raises(ValueError, match="Array-like y_obs has more than 2 dimensions"): | ||
plot_murphy_diagram(y_obs=y_obs, y_pred=y_pred) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("functional", "level"), [("expectile", 0.5), ("quantile", 0.8)] | ||
) | ||
@pytest.mark.parametrize("etas", [10, np.arange(10)]) | ||
@pytest.mark.parametrize("weights", [None, True]) | ||
@pytest.mark.parametrize("ax", [None, plt.subplots()[1]]) | ||
def test_plot_murphy_diagram(functional, level, etas, weights, ax): | ||
"""Test that plot_murphy_diagram works.""" | ||
X, y = make_classification(random_state=42, n_classes=2) | ||
if weights is None: | ||
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) | ||
w_train, w_test = None, None | ||
else: | ||
weights = np.random.default_rng(42).integers(low=0, high=10, size=y.shape) | ||
X_train, X_test, y_train, y_test, w_train, w_test = train_test_split( | ||
X, y, weights, random_state=0 | ||
) | ||
clf = LogisticRegression(solver="newton-cholesky") | ||
clf.fit(X_train, y_train, w_train) | ||
y_pred = clf.predict_proba(X_test)[:, 1] | ||
plt_ax = plot_murphy_diagram( | ||
y_obs=y_test, | ||
y_pred=y_pred, | ||
weights=w_test, | ||
etas=etas, | ||
functional=functional, | ||
level=level, | ||
ax=ax, | ||
) | ||
|
||
if ax is not None: | ||
assert ax is plt_ax | ||
assert plt_ax.get_xlabel() == "eta" | ||
assert plt_ax.get_ylabel() == "score" | ||
assert plt_ax.get_title() == "Murphy Diagram" | ||
|
||
plt_ax = plot_murphy_diagram( | ||
y_obs=y_test, | ||
y_pred=pd.Series(y_pred, name="simple"), | ||
weights=w_test, | ||
etas=etas, | ||
functional=functional, | ||
level=level, | ||
ax=ax, | ||
) | ||
assert plt_ax.get_title() == "Murphy Diagram simple" | ||
|
||
|
||
def test_plot_murphy_diagram_multiple_predictions(): | ||
"""Test that plot_murphy_diagram works for multiple predictions.""" | ||
n_obs = 10 | ||
y_obs = np.arange(n_obs) | ||
y_obs[::2] = 0 | ||
y_pred = pd.DataFrame({"model_2": np.ones(n_obs), "model_1": 3 * np.ones(n_obs)}) | ||
fig, ax = plt.subplots() | ||
plt_ax = plot_murphy_diagram( | ||
y_obs=y_obs, | ||
y_pred=y_pred, | ||
ax=ax, | ||
) | ||
assert plt_ax.get_title() == "Murphy Diagram" | ||
legend_text = plt_ax.get_legend().get_texts() | ||
assert len(legend_text) == 2 | ||
assert legend_text[0].get_text() == "model_2" | ||
assert legend_text[1].get_text() == "model_1" |