From aa0ee8e4c4ae542cc13f6bfd971695ba6787de96 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 8 Dec 2020 13:09:47 +0100 Subject: [PATCH] drop duplicate metrics --- .../classification/average_precision.py | 5 +- .../classification/precision_recall_curve.py | 5 +- .../metrics/classification/roc.py | 5 +- .../metrics/functional/classification.py | 109 +----------------- .../metrics/functional/explained_variance.py | 18 +-- .../metrics/functional/test_classification.py | 4 +- 6 files changed, 23 insertions(+), 123 deletions(-) diff --git a/pytorch_lightning/metrics/classification/average_precision.py b/pytorch_lightning/metrics/classification/average_precision.py index 0a8a952470dbc..33878cb48965d 100644 --- a/pytorch_lightning/metrics/classification/average_precision.py +++ b/pytorch_lightning/metrics/classification/average_precision.py @@ -92,9 +92,8 @@ def __init__( self.add_state("target", default=[], dist_reduce_fx=None) rank_zero_warn( - 'Metric `AveragePrecision` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' + 'Metric `AveragePrecision` will save all targets and predictions in buffer.' + ' For large datasets this may lead to large memory footprint.' ) def update(self, preds: torch.Tensor, target: torch.Tensor): diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py index 052a25a7a977d..620904898535d 100644 --- a/pytorch_lightning/metrics/classification/precision_recall_curve.py +++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py @@ -102,9 +102,8 @@ def __init__( self.add_state("target", default=[], dist_reduce_fx=None) rank_zero_warn( - 'Metric `PrecisionRecallCurve` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' + 'Metric `PrecisionRecallCurve` will save all targets and predictions in buffer.' + ' For large datasets this may lead to large memory footprint.' ) def update(self, preds: torch.Tensor, target: torch.Tensor): diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py index 89e8265b19fc1..2b7d82488b491 100644 --- a/pytorch_lightning/metrics/classification/roc.py +++ b/pytorch_lightning/metrics/classification/roc.py @@ -105,9 +105,8 @@ def __init__( self.add_state("target", default=[], dist_reduce_fx=None) rank_zero_warn( - 'Metric `ROC` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' + 'Metric `ROC` will save all targets and predictions in buffer.' + ' For large datasets this may lead to large memory footprint.' ) def update(self, preds: torch.Tensor, target: torch.Tensor): diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 1b407f2a7ec9e..7e5584659076a 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -17,6 +17,8 @@ import torch from torch.nn import functional as F +from pytorch_lightning.metrics.functional import roc +from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve from pytorch_lightning.metrics.utils import to_categorical, get_num_classes, reduce, class_reduce from pytorch_lightning.utilities import rank_zero_warn @@ -332,107 +334,6 @@ def recall( num_classes=num_classes, class_reduction=class_reduction)[1] -def _binary_clf_curve( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py - """ - if sample_weight is not None and not isinstance(sample_weight, torch.Tensor): - sample_weight = torch.tensor(sample_weight, device=pred.device, dtype=torch.float) - - # remove class dimension if necessary - if pred.ndim > target.ndim: - pred = pred[:, 0] - desc_score_indices = torch.argsort(pred, descending=True) - - pred = pred[desc_score_indices] - target = target[desc_score_indices] - - if sample_weight is not None: - weight = sample_weight[desc_score_indices] - else: - weight = 1. - - # pred typically has many tied values. Here we extract - # the indices associated with the distinct values. We also - # concatenate a value for the end of the curve. - distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0] - threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1) - - target = (target == pos_label).to(torch.long) - tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] - - if sample_weight is not None: - # express fps as a cumsum to ensure fps is increasing even in - # the presence of floating point errors - fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs] - else: - fps = 1 + threshold_idxs - tps - - return fps, tps, pred[threshold_idxs] - - -# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py -def __roc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. - - .. warning:: Deprecated - - Args: - pred: estimated probabilities - target: ground-truth labels - sample_weight: sample weights - pos_label: the label for the positive class - - Return: - false-positive rate (fpr), true-positive rate (tpr), thresholds - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 1, 1]) - >>> fpr, tpr, thresholds = __roc(x, y) - >>> fpr - tensor([0., 0., 0., 0., 1.]) - >>> tpr - tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) - >>> thresholds - tensor([4, 3, 2, 1, 0]) - - """ - fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=pos_label) - - # Add an extra threshold position - # to make sure that the curve starts at (0, 0) - tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps]) - fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps]) - thresholds = torch.cat([thresholds[0][None] + 1, thresholds]) - - if fps[-1] <= 0: - raise ValueError("No negative samples in targets, false positive value should be meaningless") - - fpr = fps / fps[-1] - - if tps[-1] <= 0: - raise ValueError("No positive samples in targets, true positive value should be meaningless") - - tpr = tps / tps[-1] - - return fpr, tpr, thresholds - - # TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py def __multiclass_roc( pred: torch.Tensor, @@ -474,7 +375,7 @@ def __multiclass_roc( for c in range(num_classes): pred_c = pred[:, c] - class_roc_vals.append(__roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c)) + class_roc_vals.append(roc(preds=pred_c, target=target, sample_weights=sample_weight, pos_label=c, num_classes=1)) return tuple(class_roc_vals) @@ -589,7 +490,7 @@ def auroc( @auc_decorator(reorder=True) def _auroc(pred, target, sample_weight, pos_label): - return __roc(pred, target, sample_weight, pos_label) + return roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label, num_classes=1) return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) @@ -642,7 +543,7 @@ def multiclass_auroc( @multiclass_auc_decorator(reorder=False) def _multiclass_auroc(pred, target, sample_weight, num_classes): - return __multiclass_roc(pred, target, sample_weight, num_classes) + return roc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes) class_aurocs = _multiclass_auroc(pred=pred, target=target, sample_weight=sample_weight, diff --git a/pytorch_lightning/metrics/functional/explained_variance.py b/pytorch_lightning/metrics/functional/explained_variance.py index 012e1486ebb1f..20b38c58a2a6b 100644 --- a/pytorch_lightning/metrics/functional/explained_variance.py +++ b/pytorch_lightning/metrics/functional/explained_variance.py @@ -23,10 +23,11 @@ def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tup return preds, target -def _explained_variance_compute(preds: torch.Tensor, - target: torch.Tensor, - multioutput: str = 'uniform_average', - ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: +def _explained_variance_compute( + preds: torch.Tensor, + target: torch.Tensor, + multioutput: str = 'uniform_average', +) -> Union[torch.Tensor, Sequence[torch.Tensor]]: diff_avg = torch.mean(target - preds, dim=0) numerator = torch.mean((target - preds - diff_avg) ** 2, dim=0) @@ -52,10 +53,11 @@ def _explained_variance_compute(preds: torch.Tensor, return torch.sum(denominator / denom_sum * output_scores) -def explained_variance(preds: torch.Tensor, - target: torch.Tensor, - multioutput: str = 'uniform_average', - ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: +def explained_variance( + preds: torch.Tensor, + target: torch.Tensor, + multioutput: str = 'uniform_average', +) -> Union[torch.Tensor, Sequence[torch.Tensor]]: """ Computes explained variance. diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index f7bd7d558f5b4..a6fbe9e849785 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -17,13 +17,13 @@ accuracy, precision, recall, - _binary_clf_curve, dice_score, auroc, multiclass_auroc, auc, iou, ) +from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve from pytorch_lightning.metrics.utils import to_onehot, get_num_classes, to_categorical @@ -222,7 +222,7 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape): if sample_weight is not None: sample_weight = torch.ones_like(pred) * sample_weight - fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label) + fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) assert isinstance(tps, torch.Tensor) assert isinstance(fps, torch.Tensor)