Skip to content

Commit

Permalink
drop duplicate metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Dec 8, 2020
1 parent aeaa6b2 commit aa0ee8e
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 123 deletions.
5 changes: 2 additions & 3 deletions pytorch_lightning/metrics/classification/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,8 @@ def __init__(
self.add_state("target", default=[], dist_reduce_fx=None)

rank_zero_warn(
'Metric `AveragePrecision` will save all targets and'
' predictions in buffer. For large datasets this may lead'
' to large memory footprint.'
'Metric `AveragePrecision` will save all targets and predictions in buffer.'
' For large datasets this may lead to large memory footprint.'
)

def update(self, preds: torch.Tensor, target: torch.Tensor):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@ def __init__(
self.add_state("target", default=[], dist_reduce_fx=None)

rank_zero_warn(
'Metric `PrecisionRecallCurve` will save all targets and'
' predictions in buffer. For large datasets this may lead'
' to large memory footprint.'
'Metric `PrecisionRecallCurve` will save all targets and predictions in buffer.'
' For large datasets this may lead to large memory footprint.'
)

def update(self, preds: torch.Tensor, target: torch.Tensor):
Expand Down
5 changes: 2 additions & 3 deletions pytorch_lightning/metrics/classification/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,8 @@ def __init__(
self.add_state("target", default=[], dist_reduce_fx=None)

rank_zero_warn(
'Metric `ROC` will save all targets and'
' predictions in buffer. For large datasets this may lead'
' to large memory footprint.'
'Metric `ROC` will save all targets and predictions in buffer.'
' For large datasets this may lead to large memory footprint.'
)

def update(self, preds: torch.Tensor, target: torch.Tensor):
Expand Down
109 changes: 5 additions & 104 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import torch
from torch.nn import functional as F

from pytorch_lightning.metrics.functional import roc
from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve
from pytorch_lightning.metrics.utils import to_categorical, get_num_classes, reduce, class_reduce
from pytorch_lightning.utilities import rank_zero_warn

Expand Down Expand Up @@ -332,107 +334,6 @@ def recall(
num_classes=num_classes, class_reduction=class_reduction)[1]


def _binary_clf_curve(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
"""
if sample_weight is not None and not isinstance(sample_weight, torch.Tensor):
sample_weight = torch.tensor(sample_weight, device=pred.device, dtype=torch.float)

# remove class dimension if necessary
if pred.ndim > target.ndim:
pred = pred[:, 0]
desc_score_indices = torch.argsort(pred, descending=True)

pred = pred[desc_score_indices]
target = target[desc_score_indices]

if sample_weight is not None:
weight = sample_weight[desc_score_indices]
else:
weight = 1.

# pred typically has many tied values. Here we extract
# the indices associated with the distinct values. We also
# concatenate a value for the end of the curve.
distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0]
threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1)

target = (target == pos_label).to(torch.long)
tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]

if sample_weight is not None:
# express fps as a cumsum to ensure fps is increasing even in
# the presence of floating point errors
fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs]
else:
fps = 1 + threshold_idxs - tps

return fps, tps, pred[threshold_idxs]


# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
def __roc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
.. warning:: Deprecated
Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class
Return:
false-positive rate (fpr), true-positive rate (tpr), thresholds
Example:
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 1, 1])
>>> fpr, tpr, thresholds = __roc(x, y)
>>> fpr
tensor([0., 0., 0., 0., 1.])
>>> tpr
tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
>>> thresholds
tensor([4, 3, 2, 1, 0])
"""
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=pos_label)

# Add an extra threshold position
# to make sure that the curve starts at (0, 0)
tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])

if fps[-1] <= 0:
raise ValueError("No negative samples in targets, false positive value should be meaningless")

fpr = fps / fps[-1]

if tps[-1] <= 0:
raise ValueError("No positive samples in targets, true positive value should be meaningless")

tpr = tps / tps[-1]

return fpr, tpr, thresholds


# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
def __multiclass_roc(
pred: torch.Tensor,
Expand Down Expand Up @@ -474,7 +375,7 @@ def __multiclass_roc(
for c in range(num_classes):
pred_c = pred[:, c]

class_roc_vals.append(__roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c))
class_roc_vals.append(roc(preds=pred_c, target=target, sample_weights=sample_weight, pos_label=c, num_classes=1))

return tuple(class_roc_vals)

Expand Down Expand Up @@ -589,7 +490,7 @@ def auroc(

@auc_decorator(reorder=True)
def _auroc(pred, target, sample_weight, pos_label):
return __roc(pred, target, sample_weight, pos_label)
return roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label, num_classes=1)

return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label)

Expand Down Expand Up @@ -642,7 +543,7 @@ def multiclass_auroc(

@multiclass_auc_decorator(reorder=False)
def _multiclass_auroc(pred, target, sample_weight, num_classes):
return __multiclass_roc(pred, target, sample_weight, num_classes)
return roc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes)

class_aurocs = _multiclass_auroc(pred=pred, target=target,
sample_weight=sample_weight,
Expand Down
18 changes: 10 additions & 8 deletions pytorch_lightning/metrics/functional/explained_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tup
return preds, target


def _explained_variance_compute(preds: torch.Tensor,
target: torch.Tensor,
multioutput: str = 'uniform_average',
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
def _explained_variance_compute(
preds: torch.Tensor,
target: torch.Tensor,
multioutput: str = 'uniform_average',
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
diff_avg = torch.mean(target - preds, dim=0)
numerator = torch.mean((target - preds - diff_avg) ** 2, dim=0)

Expand All @@ -52,10 +53,11 @@ def _explained_variance_compute(preds: torch.Tensor,
return torch.sum(denominator / denom_sum * output_scores)


def explained_variance(preds: torch.Tensor,
target: torch.Tensor,
multioutput: str = 'uniform_average',
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
def explained_variance(
preds: torch.Tensor,
target: torch.Tensor,
multioutput: str = 'uniform_average',
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
"""
Computes explained variance.
Expand Down
4 changes: 2 additions & 2 deletions tests/metrics/functional/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
accuracy,
precision,
recall,
_binary_clf_curve,
dice_score,
auroc,
multiclass_auroc,
auc,
iou,
)
from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve
from pytorch_lightning.metrics.utils import to_onehot, get_num_classes, to_categorical


Expand Down Expand Up @@ -222,7 +222,7 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
if sample_weight is not None:
sample_weight = torch.ones_like(pred) * sample_weight

fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label)
fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)

assert isinstance(tps, torch.Tensor)
assert isinstance(fps, torch.Tensor)
Expand Down

0 comments on commit aa0ee8e

Please sign in to comment.