-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmetrics.py
62 lines (53 loc) · 2.11 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from typing import Any, cast
import numpy as np
import scipy.special
import sklearn.metrics
from .util import PredictionType, TaskType
def _get_labels_and_probs(
prediction: np.ndarray,
task_type: TaskType,
prediction_type: PredictionType,
) -> tuple[np.ndarray, None | np.ndarray]:
assert task_type in (TaskType.BINCLASS, TaskType.MULTICLASS)
if prediction_type == PredictionType.LABELS:
return prediction, None
elif prediction_type == PredictionType.PROBS:
probs = prediction
elif prediction_type == PredictionType.LOGITS:
probs = (
scipy.special.expit(prediction)
if task_type == TaskType.BINCLASS
else scipy.special.softmax(prediction, axis=1)
)
else:
raise ValueError(f'Unknown prediction type: {prediction_type}')
assert probs is not None
labels = np.round(probs) if task_type == TaskType.BINCLASS else probs.argmax(axis=1)
return labels.astype(np.int64), probs
def calculate_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    task_type: str | TaskType,
    prediction_type: str | PredictionType,
) -> dict[str, Any]:
    """Compute evaluation metrics for regression or classification predictions.

    Args:
        y_true: ground-truth targets.
        y_pred: model predictions, interpreted according to ``prediction_type``.
        task_type: a ``TaskType`` member or its value (coerced via ``TaskType``).
        prediction_type: a ``PredictionType`` member or its value (coerced via
            ``PredictionType``).

    Returns:
        Regression: ``{'rmse', 'mae', 'r2'}`` as Python floats.
        Classification: the dict returned by
        ``sklearn.metrics.classification_report(..., output_dict=True)``, plus
        ``'cross-entropy'`` whenever probabilities are available (PROBS/LOGITS)
        and, for binary tasks, ``'roc-auc'``. The added scores are Python floats.

    Raises:
        ValueError: if ``task_type`` or ``prediction_type`` is not a valid
            enum value, or if ``_get_labels_and_probs`` rejects the prediction type.
        AssertionError: for regression with a prediction type other than LABELS.
    """
    task_type = TaskType(task_type)
    prediction_type = PredictionType(prediction_type)

    if task_type == TaskType.REGRESSION:
        assert prediction_type == PredictionType.LABELS
        return {
            'rmse': float(sklearn.metrics.mean_squared_error(y_true, y_pred) ** 0.5),
            'mae': float(sklearn.metrics.mean_absolute_error(y_true, y_pred)),
            'r2': float(sklearn.metrics.r2_score(y_true, y_pred)),
        }

    labels, probs = _get_labels_and_probs(y_pred, task_type, prediction_type)
    result = cast(
        dict[str, Any],
        sklearn.metrics.classification_report(y_true, labels, output_dict=True),
    )
    if probs is None:
        # Hard-label predictions: probability-based metrics are not computable.
        return result

    if task_type == TaskType.BINCLASS:
        result['cross-entropy'] = float(sklearn.metrics.log_loss(y_true, probs))
        result['roc-auc'] = float(sklearn.metrics.roc_auc_score(y_true, probs))
    else:
        # Predicted labels are argmax column indices, so y_true is expected to be
        # encoded as 0..n_classes-1. Passing the full class list stops log_loss
        # from inferring classes from y_true, which fails when an evaluation
        # subset happens to miss a class that probs still has a column for.
        class_labels = list(range(probs.shape[1]))
        result['cross-entropy'] = float(
            sklearn.metrics.log_loss(y_true, probs, labels=class_labels)
        )
    return result