Skip to content

Commit

Permalink
Metric aggregation testing (#3517)
Browse files Browse the repository at this point in the history
* aggregation testing

* add more tests

* mse

* more tests

* fix tests

* fix doctest

* fix codefactor

* fix import error

* fix doctest

* revert docfix

* test for model integration

* fix integration test

* added test cases

* fix rmsle

* aggregation testing

* add more tests

* mse

* more tests

* fix tests

* fix doctest

* fix codefactor

* fix import error

* fix doctest

* revert docfix

* test for model integration

* fix integration test

* fix psnr

* add warning/valueerror to embedding similarity

* fixed f scores

* disable some test

* fix tests

* fixing codefactor

* fix pep8

* changelog

* fix doctest

* cleaning test

* fix pickle error

* pickle fix

* fix pickle error

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* code cleanup + changes based on suggestions

* update based on suggestion

* update based on suggestions

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
4 people authored Oct 1, 2020
1 parent ac2b0f0 commit fe29028
Show file tree
Hide file tree
Showing 10 changed files with 539 additions and 92 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764))

- Fixed aggregation of metrics ([#3517](https://github.com/PyTorchLightning/pytorch-lightning/pull/3517))

## [0.9.0] - YYYY-MM-DD

### Added
Expand Down
114 changes: 66 additions & 48 deletions pytorch_lightning/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,18 @@
auroc,
average_precision,
confusion_matrix,
_confmat_normalize,
dice_score,
f1_score,
fbeta_score,
iou,
multiclass_precision_recall_curve,
multiclass_roc,
precision,
precision_recall_curve,
recall,
roc,
precision_recall
)
from pytorch_lightning.metrics.functional.reduction import class_reduce
from pytorch_lightning.metrics.metric import TensorMetric


Expand All @@ -44,8 +45,8 @@ class Accuracy(TensorMetric):
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = Accuracy()
>>> metric(pred, target).item()
0.75
>>> metric(pred, target)
tensor(0.7500)
"""

Expand Down Expand Up @@ -84,7 +85,14 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
A Tensor with the classification score.
"""
return accuracy(pred=pred, target=target,
num_classes=self.num_classes, class_reduction=self.class_reduction)
num_classes=self.num_classes,
class_reduction='none',
return_state=True)

@staticmethod
def compute(self, data: Any, output: Any):
tps, sups = output['tps'], output['sups']
return class_reduce(tps, sups, sups, class_reduction=self.class_reduction)


class ConfusionMatrix(TensorMetric):
Expand Down Expand Up @@ -135,16 +143,16 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
A Tensor with the confusion matrix.
"""
return confusion_matrix(pred=pred, target=target,
normalize=self.normalize,
normalize=False, # we normalize after ddp sync
num_classes=self.num_classes)

def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
"""Aggregates results by stacking them instead of concatenating before averaging.
Returns:
the aggregated results
"""
return torch.stack(tensors).mean(0)
@staticmethod
def compute(self, data: Any, output: Any):
""" Confusion matrix normalization needs to happen after ddp sync """
confmat = output
if self.normalize:
confmat = _confmat_normalize(confmat)
return confmat


class PrecisionRecallCurve(TensorMetric):
Expand Down Expand Up @@ -202,7 +210,8 @@ def forward(
- recall values
- threshold values
"""
return precision_recall_curve(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)
return precision_recall_curve(pred=pred, target=target,
sample_weight=sample_weight, pos_label=self.pos_label)


class Precision(TensorMetric):
Expand Down Expand Up @@ -256,9 +265,15 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Return:
A Tensor with the classification score.
"""
return precision(pred=pred, target=target,
num_classes=self.num_classes,
class_reduction=self.class_reduction)
return precision_recall(pred=pred, target=target,
num_classes=self.num_classes,
class_reduction='none',
return_state=True)

@staticmethod
def compute(self, data: Any, output: Any):
tps, fps, sups = output['tps'], output['fps'], output['sups']
return class_reduce(tps, tps + fps, sups, class_reduction=self.class_reduction)


class Recall(TensorMetric):
Expand Down Expand Up @@ -313,10 +328,15 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Return:
A Tensor with the classification score.
"""
return recall(pred=pred,
target=target,
num_classes=self.num_classes,
class_reduction=self.class_reduction)
return precision_recall(pred=pred, target=target,
num_classes=self.num_classes,
class_reduction='none',
return_state=True)

@staticmethod
def compute(self, data: Any, output: Any):
tps, fns, sups = output['tps'], output['fns'], output['sups']
return class_reduce(tps, tps + fns, sups, class_reduction=self.class_reduction)


class AveragePrecision(TensorMetric):
Expand Down Expand Up @@ -470,12 +490,28 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Return:
torch.Tensor: classification score
"""
return fbeta_score(pred=pred, target=target,
beta=self.beta, num_classes=self.num_classes,
class_reduction=self.class_reduction)
return precision_recall(pred=pred, target=target,
num_classes=self.num_classes,
class_reduction='none',
return_state=True)

@staticmethod
def compute(self, data: Any, output: Any):
""" tps, fps, fns, sups needs to be synced before we do any calculations """
tps, fps, fns, sups = output['tps'], output['fps'], output['fns'], output['sups']

intermidiate_reduction = 'none' if self.class_reduction != "micro" else 'micro'
precision = class_reduce(tps, tps + fps, sups, class_reduction=intermidiate_reduction)
recall = class_reduce(tps, tps + fns, sups, class_reduction=intermidiate_reduction)

num = (1 + self.beta ** 2) * precision * recall
denom = ((self.beta ** 2) * precision + recall)
if intermidiate_reduction == 'micro':
return torch.sum(num) / torch.sum(denom)
return class_reduce(num, denom, sups, class_reduction=self.class_reduction)

class F1(TensorMetric):

class F1(FBeta):
"""
Computes the F1 score, which is the harmonic mean of the precision and recall.
It ranges between 1 and 0, where 1 is perfect and the worst value is 0.
Expand Down Expand Up @@ -507,29 +543,11 @@ def __init__(
reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(
name="f1",
reduce_group=reduce_group,
)

self.num_classes = num_classes
assert class_reduction in ('micro', 'macro', 'weighted', 'none')
self.class_reduction = class_reduction

def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: groundtruth labels
Return:
torch.Tensor: classification score
"""
return f1_score(pred=pred, target=target,
num_classes=self.num_classes,
class_reduction=self.class_reduction)
super().__init__(beta=1.0,
num_classes=num_classes,
class_reduction=class_reduction,
reduce_group=reduce_group)
self.name = "f1"


class ROC(TensorMetric):
Expand Down
32 changes: 23 additions & 9 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ def accuracy(
pred: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
class_reduction: str = 'micro'
class_reduction: str = 'micro',
return_state: bool = False
) -> torch.Tensor:
"""
Computes the accuracy classification score
Expand All @@ -256,7 +257,8 @@ def accuracy(
- ``'macro'``: calculate metrics for each label, and find their unweighted mean.
- ``'weighted'``: calculate metrics for each label, and find their weighted mean.
- ``'none'``: returns calculated metric per class
return_state: returns a internal state that can be ddp reduced
before doing the final calculation
Return:
A Tensor with the accuracy score.
Expand All @@ -270,10 +272,21 @@ def accuracy(
"""
tps, fps, tns, fns, sups = stat_scores_multiple_classes(
pred=pred, target=target, num_classes=num_classes)

if return_state:
return {'tps': tps, 'sups': sups}
return class_reduce(tps, sups, sups, class_reduction=class_reduction)


def _confmat_normalize(cm):
""" Normalization function for confusion matrix """
cm = cm / cm.sum(-1, keepdim=True)
nan_elements = cm[torch.isnan(cm)].nelement()
if nan_elements != 0:
cm[torch.isnan(cm)] = 0
rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.')
return cm


def confusion_matrix(
pred: torch.Tensor,
target: torch.Tensor,
Expand Down Expand Up @@ -311,11 +324,7 @@ def confusion_matrix(
cm = bins.reshape(num_classes, num_classes).squeeze().float()

if normalize:
cm = cm / cm.sum(-1, keepdim=True)
nan_elements = cm[torch.isnan(cm)].nelement()
if nan_elements != 0:
cm[torch.isnan(cm)] = 0
rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.')
cm = _confmat_normalize(cm)

return cm

Expand All @@ -325,7 +334,8 @@ def precision_recall(
target: torch.Tensor,
num_classes: Optional[int] = None,
class_reduction: str = 'micro',
return_support: bool = False
return_support: bool = False,
return_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes precision and recall for different thresholds
Expand All @@ -342,6 +352,8 @@ def precision_recall(
- ``'none'``: returns calculated metric per class
return_support: returns the support for each class, need for fbeta/f1 calculations
return_state: returns a internal state that can be ddp reduced
before doing the final calculation
Return:
Tensor with precision and recall
Expand All @@ -358,6 +370,8 @@ def precision_recall(

precision = class_reduce(tps, tps + fps, sups, class_reduction=class_reduction)
recall = class_reduce(tps, tps + fns, sups, class_reduction=class_reduction)
if return_state:
return {'tps': tps, 'fps': fps, 'fns': fns, 'sups': sups}
if return_support:
return precision, recall, sups
return precision, recall
Expand Down
Loading

0 comments on commit fe29028

Please sign in to comment.