diff --git a/CHANGELOG.md b/CHANGELOG.md index 396bb6becaec0..c76175858f42e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,6 +91,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764)) +- Fixed aggregation of metrics ([#3517](https://github.com/PyTorchLightning/pytorch-lightning/pull/3517)) + ## [0.9.0] - YYYY-MM-DD ### Added diff --git a/pytorch_lightning/metrics/classification.py b/pytorch_lightning/metrics/classification.py index a279cb3a4747e..d44d93f01a8ac 100644 --- a/pytorch_lightning/metrics/classification.py +++ b/pytorch_lightning/metrics/classification.py @@ -21,17 +21,18 @@ auroc, average_precision, confusion_matrix, + _confmat_normalize, dice_score, f1_score, fbeta_score, iou, multiclass_precision_recall_curve, multiclass_roc, - precision, precision_recall_curve, - recall, roc, + precision_recall ) +from pytorch_lightning.metrics.functional.reduction import class_reduce from pytorch_lightning.metrics.metric import TensorMetric @@ -44,8 +45,8 @@ class Accuracy(TensorMetric): >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 2, 2]) >>> metric = Accuracy() - >>> metric(pred, target).item() - 0.75 + >>> metric(pred, target) + tensor(0.7500) """ @@ -84,7 +85,14 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: A Tensor with the classification score. """ return accuracy(pred=pred, target=target, - num_classes=self.num_classes, class_reduction=self.class_reduction) + num_classes=self.num_classes, + class_reduction='none', + return_state=True) + + @staticmethod + def compute(self, data: Any, output: Any): + tps, sups = output['tps'], output['sups'] + return class_reduce(tps, sups, sups, class_reduction=self.class_reduction) class ConfusionMatrix(TensorMetric): @@ -135,16 +143,16 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: A Tensor with the confusion matrix. """ return confusion_matrix(pred=pred, target=target, - normalize=self.normalize, + normalize=False, # we normalize after ddp sync num_classes=self.num_classes) - def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor: - """Aggregates results by stacking them instead of concatenating before averaging. - - Returns: - the aggregated results - """ - return torch.stack(tensors).mean(0) + @staticmethod + def compute(self, data: Any, output: Any): + """ Confusion matrix normalization needs to happen after ddp sync """ + confmat = output + if self.normalize: + confmat = _confmat_normalize(confmat) + return confmat class PrecisionRecallCurve(TensorMetric): @@ -202,7 +210,8 @@ def forward( - recall values - threshold values """ - return precision_recall_curve(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label) + return precision_recall_curve(pred=pred, target=target, + sample_weight=sample_weight, pos_label=self.pos_label) class Precision(TensorMetric): @@ -256,9 +265,15 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: A Tensor with the classification score. 
""" - return precision(pred=pred, target=target, - num_classes=self.num_classes, - class_reduction=self.class_reduction) + return precision_recall(pred=pred, target=target, + num_classes=self.num_classes, + class_reduction='none', + return_state=True) + + @staticmethod + def compute(self, data: Any, output: Any): + tps, fps, sups = output['tps'], output['fps'], output['sups'] + return class_reduce(tps, tps + fps, sups, class_reduction=self.class_reduction) class Recall(TensorMetric): @@ -313,10 +328,15 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: A Tensor with the classification score. """ - return recall(pred=pred, - target=target, - num_classes=self.num_classes, - class_reduction=self.class_reduction) + return precision_recall(pred=pred, target=target, + num_classes=self.num_classes, + class_reduction='none', + return_state=True) + + @staticmethod + def compute(self, data: Any, output: Any): + tps, fns, sups = output['tps'], output['fns'], output['sups'] + return class_reduce(tps, tps + fns, sups, class_reduction=self.class_reduction) class AveragePrecision(TensorMetric): @@ -470,12 +490,28 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: torch.Tensor: classification score """ - return fbeta_score(pred=pred, target=target, - beta=self.beta, num_classes=self.num_classes, - class_reduction=self.class_reduction) + return precision_recall(pred=pred, target=target, + num_classes=self.num_classes, + class_reduction='none', + return_state=True) + + @staticmethod + def compute(self, data: Any, output: Any): + """ tps, fps, fns, sups needs to be synced before we do any calculations """ + tps, fps, fns, sups = output['tps'], output['fps'], output['fns'], output['sups'] + + intermidiate_reduction = 'none' if self.class_reduction != "micro" else 'micro' + precision = class_reduce(tps, tps + fps, sups, class_reduction=intermidiate_reduction) + recall = class_reduce(tps, tps + fns, sups, class_reduction=intermidiate_reduction) + num = (1 + self.beta ** 2) * precision * recall + denom = ((self.beta ** 2) * precision + recall) + if intermidiate_reduction == 'micro': + return torch.sum(num) / torch.sum(denom) + return class_reduce(num, denom, sups, class_reduction=self.class_reduction) -class F1(TensorMetric): + +class F1(FBeta): """ Computes the F1 score, which is the harmonic mean of the precision and recall. It ranges between 1 and 0, where 1 is perfect and the worst value is 0. 
@@ -507,29 +543,11 @@ def __init__( reduce_group: the process group to reduce metric results from DDP """ - super().__init__( - name="f1", - reduce_group=reduce_group, - ) - - self.num_classes = num_classes - assert class_reduction in ('micro', 'macro', 'weighted', 'none') - self.class_reduction = class_reduction - - def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Actual metric computation - - Args: - pred: predicted labels - target: groundtruth labels - - Return: - torch.Tensor: classification score - """ - return f1_score(pred=pred, target=target, - num_classes=self.num_classes, - class_reduction=self.class_reduction) + super().__init__(beta=1.0, + num_classes=num_classes, + class_reduction=class_reduction, + reduce_group=reduce_group) + self.name = "f1" class ROC(TensorMetric): diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 2a3a2f5dba133..20510c4c088f8 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -241,7 +241,8 @@ def accuracy( pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None, - class_reduction: str = 'micro' + class_reduction: str = 'micro', + return_state: bool = False ) -> torch.Tensor: """ Computes the accuracy classification score @@ -256,7 +257,8 @@ def accuracy( - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - ``'none'``: returns calculated metric per class - + return_state: returns a internal state that can be ddp reduced + before doing the final calculation Return: A Tensor with the accuracy score. 
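For reference, with return_state=True the functional accuracy hands back the raw statistics ({'tps': ..., 'sups': ...}, see the hunk below) instead of a score, and Accuracy.compute reduces the synced state. A hedged sketch of the micro reduction for the docstring example — the per-class numbers are worked out by hand rather than produced by calling the function:

import torch

pred   = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 1, 2, 2])

# State returned per process: per-class true positives and supports.
tps  = torch.tensor([1., 1., 1., 0.])   # correct predictions for classes 0..3
sups = torch.tensor([1., 1., 2., 0.])   # ground-truth occurrences for classes 0..3

# After the ddp sum, 'micro' class reduction is simply sum(tps) / sum(sups).
print(tps.sum() / sups.sum())            # tensor(0.7500)
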
@@ -270,10 +272,21 @@ def accuracy( """ tps, fps, tns, fns, sups = stat_scores_multiple_classes( pred=pred, target=target, num_classes=num_classes) - + if return_state: + return {'tps': tps, 'sups': sups} return class_reduce(tps, sups, sups, class_reduction=class_reduction) +def _confmat_normalize(cm): + """ Normalization function for confusion matrix """ + cm = cm / cm.sum(-1, keepdim=True) + nan_elements = cm[torch.isnan(cm)].nelement() + if nan_elements != 0: + cm[torch.isnan(cm)] = 0 + rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.') + return cm + + def confusion_matrix( pred: torch.Tensor, target: torch.Tensor, @@ -311,11 +324,7 @@ def confusion_matrix( cm = bins.reshape(num_classes, num_classes).squeeze().float() if normalize: - cm = cm / cm.sum(-1, keepdim=True) - nan_elements = cm[torch.isnan(cm)].nelement() - if nan_elements != 0: - cm[torch.isnan(cm)] = 0 - rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.') + cm = _confmat_normalize(cm) return cm @@ -325,7 +334,8 @@ def precision_recall( target: torch.Tensor, num_classes: Optional[int] = None, class_reduction: str = 'micro', - return_support: bool = False + return_support: bool = False, + return_state: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes precision and recall for different thresholds @@ -342,6 +352,8 @@ def precision_recall( - ``'none'``: returns calculated metric per class return_support: returns the support for each class, need for fbeta/f1 calculations + return_state: returns a internal state that can be ddp reduced + before doing the final calculation Return: Tensor with precision and recall @@ -358,6 +370,8 @@ def precision_recall( precision = class_reduce(tps, tps + fps, sups, class_reduction=class_reduction) recall = class_reduce(tps, tps + fns, sups, class_reduction=class_reduction) + if return_state: + return {'tps': tps, 'fps': fps, 'fns': fns, 'sups': sups} if return_support: return precision, recall, sups return precision, recall diff --git a/pytorch_lightning/metrics/functional/regression.py b/pytorch_lightning/metrics/functional/regression.py index 75d8f1adf9a86..63d4615cae1a1 100644 --- a/pytorch_lightning/metrics/functional/regression.py +++ b/pytorch_lightning/metrics/functional/regression.py @@ -9,7 +9,8 @@ def mse( pred: torch.Tensor, target: torch.Tensor, - reduction: str = 'elementwise_mean' + reduction: str = 'elementwise_mean', + return_state: bool = False ) -> torch.Tensor: """ Computes mean squared error @@ -22,6 +23,8 @@ def mse( - ``'elementwise_mean'``: takes the mean (default) - ``'sum'``: takes the sum - ``'none'``: no reduction will be applied + return_state: returns a internal state that can be ddp reduced + before doing the final calculation Return: Tensor with MSE @@ -35,6 +38,8 @@ def mse( """ mse = F.mse_loss(pred, target, reduction='none') + if return_state: + return {'squared_error': mse.sum(), 'n_observations': torch.tensor(mse.numel())} mse = reduce(mse, reduction=reduction) return mse @@ -42,7 +47,8 @@ def mse( def rmse( pred: torch.Tensor, target: torch.Tensor, - reduction: str = 'elementwise_mean' + reduction: str = 'elementwise_mean', + return_state: bool = False ) -> torch.Tensor: """ Computes root mean squared error @@ -55,6 +61,8 @@ def rmse( - ``'elementwise_mean'``: takes the mean (default) - ``'sum'``: takes the sum - ``'none'``: no reduction will be applied + return_state: returns a internal state that can be ddp reduced + before doing the final 
calculation Return: Tensor with RMSE @@ -66,14 +74,18 @@ def rmse( tensor(0.5000) """ - rmse = torch.sqrt(mse(pred, target, reduction=reduction)) - return rmse + mean_squared_error = mse(pred, target, reduction=reduction) + if return_state: + return {'squared_error': mean_squared_error.sum(), + 'n_observations': torch.tensor(mean_squared_error.numel())} + return torch.sqrt(mean_squared_error) def mae( pred: torch.Tensor, target: torch.Tensor, - reduction: str = 'elementwise_mean' + reduction: str = 'elementwise_mean', + return_state: bool = False ) -> torch.Tensor: """ Computes mean absolute error @@ -86,6 +98,8 @@ def mae( - ``'elementwise_mean'``: takes the mean (default) - ``'sum'``: takes the sum - ``'none'``: no reduction will be applied + return_state: returns a internal state that can be ddp reduced + before doing the final calculation Return: Tensor with MAE @@ -99,6 +113,8 @@ def mae( """ mae = F.l1_loss(pred, target, reduction='none') + if return_state: + return {'absolute_error': mae.sum(), 'n_observations': torch.tensor(mae.numel())} mae = reduce(mae, reduction=reduction) return mae @@ -140,7 +156,8 @@ def psnr( target: torch.Tensor, data_range: float = None, base: float = 10.0, - reduction: str = 'elementwise_mean' + reduction: str = 'elementwise_mean', + return_state: bool = False ) -> torch.Tensor: """ Computes the peak signal-to-noise ratio @@ -155,25 +172,30 @@ def psnr( - ``'elementwise_mean'``: takes the mean (default) - ``'sum'``: takes the sum - ``'none'``: no reduction will be applied + return_state: returns a internal state that can be ddp reduced + before doing the final calculation Return: Tensor with PSNR score Example: - >>> from pytorch_lightning.metrics.regression import PSNR >>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) - >>> metric = PSNR() - >>> metric(pred, target) + >>> psnr(pred, target) tensor(2.5527) """ if data_range is None: - data_range = max(target.max() - target.min(), pred.max() - pred.min()) + data_range = target.max() - target.min() else: data_range = torch.tensor(float(data_range)) + if return_state: + return {'data_range': data_range, + 'sum_squared_error': F.mse_loss(pred, target, reduction='none').sum(), + 'n_obs': torch.tensor(target.numel())} + mse_score = mse(pred.view(-1), target.view(-1), reduction=reduction) psnr_base_e = 2 * torch.log(data_range) - torch.log(mse_score) psnr = psnr_base_e * (10 / torch.log(torch.tensor(base))) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 45c50b084956f..e97f054f05f89 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -72,6 +72,7 @@ def __init__(self, name: str, reduce_group: Optional[Any] = None): self.reduce_group = reduce_group + # Buffer for holding aggregated state after each batch self._step_vals = [] # Register hooks @@ -131,9 +132,6 @@ def ddp_sync(self, tensor: Any): """ gathered_tensors = apply_to_collection(tensor, torch.Tensor, gather_all_tensors_if_available, self.reduce_group) - - self._step_vals.append(gathered_tensors) - return gathered_tensors @staticmethod @@ -150,7 +148,9 @@ def ddp_reduce(self, data: Any, output: Any): """ synced = self.ddp_sync(output) - return self.aggregate(synced) + agg_val = self.aggregate(synced) + self._step_vals.append(agg_val) + return agg_val def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor: """ @@ -163,17 +163,27 @@ def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor: aggregated values """ - 
try: - return torch.cat(tensors).mean(0) - except (ValueError, TypeError): - if isinstance(tensors[0], Mapping): - return {k: torch.stack([tensor[k] for tensor in tensors]).mean(0) for k in tensors[0].keys()} - elif isinstance(tensors[0], Sequence) and not isinstance(tensors[0], torch.Tensor): - return tuple([torch.stack(tmp).mean(0) for tmp in zip(*tensors)]) - elif isinstance(tensors[0], torch.Tensor): - return torch.stack(tensors).mean(0) - else: - raise TypeError("unknown metric value format to aggregate") + # single tensor + if len(tensors) == 1: + tensors = tensors[0] + if isinstance(tensors, Mapping): + return {k: _stack_and_agg(tensors[k]) for k in tensors.keys()} + if isinstance(tensors, list): + return _stack_and_agg(tensors) + if isinstance(tensors, tuple): + return tensors + if isinstance(tensors, torch.Tensor): + return _stack_and_agg(tensors) + + # multiple tensors (from aggregation over batches) + if isinstance(tensors[0], Mapping): + return {k: torch.stack([tensor[k] for tensor in tensors]).sum(0) for k in tensors[0].keys()} + if isinstance(tensors[0], Sequence): + return tuple([torch.stack(tmp).sum(0) for tmp in zip(*tensors)]) + if isinstance(tensors[0], torch.Tensor): + return torch.stack(tensors).sum(0) + + raise TypeError("unknown metric value format to aggregate") @staticmethod def compute(self, data: Any, output: Any): @@ -192,7 +202,7 @@ def compute(self, data: Any, output: Any): @property def aggregated(self) -> torch.Tensor: - aggr = self.aggregate(*self._step_vals) + aggr = self.aggregate(*self._step_vals if len(self._step_vals) > 1 else self._step_vals) self.reset() return self.compute(self, None, aggr) @@ -200,6 +210,13 @@ def reset(self): self._step_vals = [] +def _stack_and_agg(tensors): + """ Utility function for stacking and aggregating tensors """ + if isinstance(tensors, list): + return torch.sum(torch.stack([t for t in tensors]), 0) + return tensors.squeeze() if tensors.numel() == 1 else tensors + + class TensorMetric(Metric): """ Base class for metric implementation operating directly on tensors. diff --git a/pytorch_lightning/metrics/regression.py b/pytorch_lightning/metrics/regression.py index 66963f19bcc56..a152d684a43a8 100644 --- a/pytorch_lightning/metrics/regression.py +++ b/pytorch_lightning/metrics/regression.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence +from typing import Sequence, Any import torch @@ -67,7 +67,12 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: A Tensor with the mse loss. """ - return mse(pred, target, self.reduction) + return mse(pred, target, return_state=True) + + @staticmethod + def compute(self, data: Any, output: Any): + sse, n = output['squared_error'], output['n_observations'] + return sse / n class RMSE(Metric): @@ -110,7 +115,13 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: A Tensor with the rmse loss. """ - return rmse(pred, target, self.reduction) + return rmse(pred, target, reduction='none', return_state=True) + + @staticmethod + def compute(self, data: Any, output: Any): + """ Squaring needs to happend after ddp sync """ + sse, n = output['squared_error'], output['n_observations'] + return torch.sqrt(sse / n) class MAE(Metric): @@ -153,7 +164,12 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: A Tensor with the mae loss. 
""" - return mae(pred, target, self.reduction) + return mae(pred, target, return_state=True) + + @staticmethod + def compute(self, data: Any, output: Any): + sae, n = output['absolute_error'], output['n_observations'] + return sae / n class RMSLE(Metric): @@ -196,7 +212,14 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: A Tensor with the rmsle loss. """ - return rmsle(pred, target, self.reduction) + return mse(torch.log(pred + 1), torch.log(target + 1), + self.reduction, return_state=True) + + @staticmethod + def compute(self, data: Any, output: Any): + """ Squaring needs to happend after ddp sync """ + sse, n = output['squared_error'], output['n_observations'] + return torch.sqrt(sse / n) class PSNR(Metric): @@ -245,7 +268,38 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Return: A Tensor with psnr score. """ - return psnr(pred, target, self.data_range, self.base, self.reduction) + return psnr(pred, target, self.data_range, self.base, self.reduction, return_state=True) + + def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor: + """ Special aggregation function as the data range needs to be correctly synced """ + if len(tensors) == 1: + tensors = tensors[0] + output = {'data_range': torch.stack([t for t in tensors['data_range']]).max()} + output.update({k: torch.stack([t for t in tensors[k]]).sum(0) for k in tensors.keys() if k != 'data_range'}) + return output + + output = {'data_range': torch.stack([tensor['data_range'] for tensor in tensors]).max()} + output.update({k: torch.stack([tensor[k] for tensor in tensors]).sum(0) for k in tensors[0].keys() if k != 'data_range'}) + return output + + @staticmethod + def compute(self, data: Any, output: Any): + """ + Compute final value based on the synced data_range, sum of squared errors + and number of samples. 
+ + Args: + data: input to forward method + output: output from the `aggregate` hook + + Returns: + final metric value + + """ + sse, n, data_range = output['sum_squared_error'], output['n_obs'], output['data_range'] + psnr_base_e = 2 * torch.log(data_range) - torch.log(sse / n) + psnr = psnr_base_e * (10 / torch.log(torch.tensor(self.base))) + return psnr class SSIM(Metric): diff --git a/pytorch_lightning/metrics/self_supervised.py b/pytorch_lightning/metrics/self_supervised.py index b5e492662e854..9e57c15026fcc 100644 --- a/pytorch_lightning/metrics/self_supervised.py +++ b/pytorch_lightning/metrics/self_supervised.py @@ -18,6 +18,7 @@ from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity from pytorch_lightning.metrics.metric import TensorMetric +from pytorch_lightning.utilities import rank_zero_warn class EmbeddingSimilarity(TensorMetric): @@ -56,6 +57,8 @@ def __init__( assert reduction in ('none', 'sum', 'mean') self.reduction = reduction + rank_zero_warn('Please note that Metric `EmbeddingSimilarity` does not support aggregation.') + def forward(self, batch: torch.Tensor) -> torch.Tensor: """ Actual metric computation @@ -71,3 +74,12 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor: similarity=self.similarity, zero_diagonal=self.zero_diagonal, reduction=self.reduction) + + @staticmethod + def ddp_reduce(self, data: Any, output: Any): + """ reduction for this metric does not make sense """ + return output + + @property + def aggregated(self): + raise ValueError('Metric `EmbeddingSimilarity` does not support aggregation.') diff --git a/tests/metrics/functional/test_regression.py b/tests/metrics/functional/test_regression.py index bbba5a421ebfb..49a79f9424f13 100644 --- a/tests/metrics/functional/test_regression.py +++ b/tests/metrics/functional/test_regression.py @@ -90,7 +90,7 @@ def test_rmsle(pred, target, expected): ]) def test_psnr_with_skimage(pred, target): score = psnr(pred=torch.tensor(pred), - target=torch.tensor(target)) + target=torch.tensor(target), data_range=3) sk_score = ski_psnr(np.array(pred), np.array(target), data_range=3) assert torch.allclose(score, torch.tensor(sk_score, dtype=torch.float), atol=1e-3) diff --git a/tests/metrics/test_aggregation.py b/tests/metrics/test_aggregation.py new file mode 100644 index 0000000000000..73c1e05118554 --- /dev/null +++ b/tests/metrics/test_aggregation.py @@ -0,0 +1,297 @@ +import pytest +import sys +from collections import namedtuple +from functools import partial +import math + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import numpy as np + +from tests.base import EvalModelTemplate +from pytorch_lightning import Trainer +import tests.base.develop_utils as tutils +from pytorch_lightning.metrics import ( + Accuracy, + ConfusionMatrix, + PrecisionRecallCurve, + Precision, + Recall, + AveragePrecision, + AUROC, + FBeta, + F1, + ROC, + MulticlassROC, + MulticlassPrecisionRecallCurve, + DiceCoefficient, + IoU, + MAE, + MSE, + RMSE, + RMSLE, + PSNR, + SSIM, +) + +from sklearn.metrics import ( + accuracy_score, + confusion_matrix, + precision_recall_curve, + precision_score, + recall_score, + average_precision_score, + roc_auc_score, + fbeta_score, + f1_score, + roc_curve, + jaccard_score, + mean_squared_error, + mean_absolute_error, + mean_squared_log_error +) + +from skimage.metrics import ( + peak_signal_noise_ratio, + structural_similarity +) + +# example structure +TestCase = namedtuple('example', ['name', 'lightning_metric', 'comparing_metric', 
'test_input']) + +# setup some standard testcases +NB_SAMPLES = 200 +multiclass_example = [(torch.randint(10, (NB_SAMPLES,)), torch.randint(10, (NB_SAMPLES,)))] +binary_example = [(torch.randint(2, (NB_SAMPLES,)), torch.randint(2, (NB_SAMPLES,)))] +multiclass_and_binary_example = [*multiclass_example, *binary_example] +binary_example_logits = (torch.randint(2, (NB_SAMPLES,)), torch.randint(5, (NB_SAMPLES,))) +multiclass_example_probs = (torch.randint(10, (NB_SAMPLES,)), torch.randn((NB_SAMPLES, 10)).softmax(-1)) +regression_example = [(torch.rand((NB_SAMPLES,)), torch.rand((NB_SAMPLES,)))] + + +# construct additional test functions +def root_mean_squared_error(x, y): + return math.sqrt(mean_squared_error(x, y)) + + +def root_mean_squared_log_error(x, y): + return math.sqrt(mean_squared_log_error(x, y)) + + +# Define testcases +# TODO: update remaining metrics and uncomment the corresponding test cases +TESTS = [ + TestCase('accuracy', + Accuracy, + accuracy_score, + multiclass_and_binary_example), + TestCase('confusion matrix without normalize', + ConfusionMatrix, + confusion_matrix, + multiclass_and_binary_example), + TestCase('confusion matrix with normalize', + partial(ConfusionMatrix, normalize=True), + partial(confusion_matrix, normalize='true'), + multiclass_and_binary_example), + # TestCase('precision recall curve', + # PrecisionRecallCurve, + # precision_recall_curve, + # binary_example), + TestCase('precision', + Precision, + partial(precision_score, average='micro'), + multiclass_and_binary_example), + TestCase('recall', + Recall, + partial(recall_score, average='micro'), + multiclass_and_binary_example), + # TestCase('average_precision', + # AveragePrecision, + # average_precision_score, + # binary_example), + # TestCase('auroc', + # AUROC, + # roc_auc_score, + # binary_example), + TestCase('f beta', + partial(FBeta, beta=2), + partial(fbeta_score, average='micro', beta=2), + multiclass_and_binary_example), + TestCase('f1', + F1, + partial(f1_score, average='micro'), + multiclass_and_binary_example), + # TestCase('roc', + # ROC, + # roc_curve, + # binary_example), + # TestCase('multiclass roc', + # MulticlassROC, + # multiclass_roc, + # binary_example), + # TestCase('multiclass precision recall curve', + # MulticlassPrecisionRecallCurve, + # multiclass_precision_recall_curve, + # binary_example), + # TestCase('dice coefficient', + # DiceCoefficient, + # partial(f1_score, average='micro'), + # multiclass_and_binary_example), + # TestCase('intersection over union', + # IoU, + # partial(jaccard_score, average='macro'), + # binary_example), + TestCase('mean squared error', + MSE, + mean_squared_error, + regression_example), + TestCase('root mean squared error', + RMSE, + root_mean_squared_error, + regression_example), + TestCase('mean absolute error', + MAE, + mean_absolute_error, + regression_example), + TestCase('root mean squared log error', + RMSLE, + root_mean_squared_log_error, + regression_example), + TestCase('peak signal-to-noise ratio', + partial(PSNR, data_range=10), + partial(peak_signal_noise_ratio, data_range=10), + regression_example), + # TestCase('structual similarity index measure', + # SSIM, + # structural_similarity, + # regression_example) +] + + +# Utility test functions +def _idsfn(test): + """ Return id for current example being tested """ + return test.name + + +def _setup_ddp(rank, worldsize): + """ setup ddp enviroment for testing """ + import os + os.environ['MASTER_ADDR'] = 'localhost' + # initialize the process group + dist.init_process_group("gloo", 
rank=rank, world_size=worldsize) + + +def comparing_fn(lightning_val, comparing_val, rtol=1e-03, atol=1e-08): + """ function for comparing output, both multi and single output""" + # multi output + if isinstance(comparing_val, tuple): + for l_score, c_score in zip(lightning_val, comparing_val): + assert np.allclose(l_score.numpy(), c_score, rtol, atol) + else: # single output + assert np.allclose(lightning_val.numpy(), comparing_val, rtol, atol) + + +# ===== Tests start here ===== +def _test_ddp_single_batch(rank, worldsize, lightning_metric, comparing_metric, test_inputs): + """ ddp testing function, divide test_inputs equally between all processes """ + _setup_ddp(rank, worldsize) + + # Setup metric for ddp + lightning_metric = lightning_metric() + for test_input in test_inputs: + # rank 0 receives sample 0,2,4,... + # rank 1 receives sample 1,3,5,... + lightning_val = lightning_metric(*[ti[rank::2] for ti in test_input]) + + comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)]) + + comparing_fn(lightning_val, comparing_val) + + +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +@pytest.mark.parametrize("test", TESTS, ids=_idsfn) +def test_ddp(test): + """Make sure that metrics are correctly sync and reduced in DDP mode""" + tutils.reset_seed() + tutils.set_random_master_port() + + worldsize = 2 + mp.spawn(_test_ddp_single_batch, + args=(worldsize, + test.lightning_metric, + test.comparing_metric, + test.test_input), + nprocs=worldsize) + + +@pytest.mark.parametrize("test", TESTS, ids=_idsfn) +def test_multi_batch(test): + """ test that aggregation works for multiple batches """ + lightning_metric = test.lightning_metric() + comparing_metric = test.comparing_metric + + for test_input in test.test_input: + for i in range(2): # for lightning device in 2 artificially batches + # first batch consist of samples 0,2,4,... + # second batch consist of samples 1,3,5,... + _ = lightning_metric(*[ti[i::2] for ti in test_input]) + lightning_val = lightning_metric.aggregated + comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)]) + + comparing_fn(lightning_val, comparing_val) + + +@pytest.mark.parametrize("test", TESTS, ids=_idsfn) +def test_multi_batch_unequal_sizes(test): + """ test that aggregation works for multiple batches with uneven sizes """ + lightning_metric = test.lightning_metric() + comparing_metric = test.comparing_metric + + for test_input in test.test_input: + for i in range(2): # for lightning device in 2 artificially batches + if i == 0: # allocate 3/4 of data to the first batch + _ = lightning_metric(*[ti[:int(3 / 4 * len(ti))] for ti in test_input]) + else: + _ = lightning_metric(*[ti[int(3 / 4 * len(ti)):] for ti in test_input]) + lightning_val = lightning_metric.aggregated + comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)]) + + comparing_fn(lightning_val, comparing_val) + + +def _test_ddp_multi_batch(rank, worldsize, lightning_metric, comparing_metric, test_inputs): + """ ddp testing function, test that metric works with aggregation over multiple + devices and multiple batches """ + _setup_ddp(rank, worldsize) + + # Setup metric for ddp + lightning_metric = lightning_metric() + for test_input in test_inputs: + for i in range(2): # artificially divide samples between batches and processes + # rank 0, batch 0 consist of samples 0,4,8,... + # rank 0, batch 1 consist of samples 1,5,9,... + # rank 1, batch 0 consist of samples 2,6,10,... 
+ # rank 1, batch 1 consist of samples 3,7,11,... + _ = lightning_metric(*[ti[i + worldsize * rank::4] for ti in test_input]) + lightning_val = lightning_metric.aggregated + comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)]) + + comparing_fn(lightning_val, comparing_val) + + +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +@pytest.mark.parametrize("test", TESTS, ids=_idsfn) +def test_ddp_multi_batch(test): + """ test that aggregation works fine with in DDP mode and multiple batches """ + tutils.reset_seed() + tutils.set_random_master_port() + + worldsize = 2 + mp.spawn(_test_ddp_multi_batch, + args=(worldsize, + test.lightning_metric, + test.comparing_metric, + test.test_input), + nprocs=worldsize) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 5a8589f8f254a..f3395a551ef62 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -190,7 +190,9 @@ def test_saving_pickable(tmpdir, metric: Metric): assert results_before_save == results_after_load -def check_call_order(): +def test_correct_call_order(): + """ Check that hooks are called in the expected order """ + class DummyMetric(Metric): def __init__(self): super().__init__("dummy") @@ -249,6 +251,7 @@ def compute(self, data: Any, output: Any): "ddp_reduce", "ddp_sync", "aggregate", + "compute" ] aggr = metric.aggregated assert metric.call_history == [ @@ -259,9 +262,11 @@ def compute(self, data: Any, output: Any): "ddp_reduce", "ddp_sync", "aggregate", + "compute", "aggregated", "aggregate", "reset", + "compute" ] assert torch.allclose(aggr, result) _ = metric(torch.tensor(2.0), torch.tensor(1.0)) @@ -273,15 +278,18 @@ def compute(self, data: Any, output: Any): "ddp_reduce", "ddp_sync", "aggregate", + "compute", "aggregated", "aggregate", "reset", + "compute", "input_convert", "forward", "output_convert", "ddp_reduce", "ddp_sync", "aggregate", + "compute" ] metric = DummyMetric() @@ -290,7 +298,7 @@ def compute(self, data: Any, output: Any): aggregated = metric.aggregated - assert torch.allclose(aggregated, torch.tensor(2.0)) + assert torch.allclose(aggregated, torch.tensor(4.0)) assert metric.call_history == [ "init", @@ -300,13 +308,16 @@ def compute(self, data: Any, output: Any): "ddp_reduce", "ddp_sync", "aggregate", + "compute", "input_convert", "forward", "output_convert", "ddp_reduce", "ddp_sync", "aggregate", + "compute", "aggregated", "aggregate", "reset", + "compute", ]
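Finally, a short usage sketch of the multi-batch aggregation the new tests exercise: call the metric once per batch and read .aggregated for the value over everything seen so far. num_classes is passed explicitly here so the per-batch states have matching shapes; the expected output assumes the Accuracy docstring example above.

import torch
from pytorch_lightning.metrics import Accuracy

metric = Accuracy(num_classes=4)
pred   = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 1, 2, 2])

# Feed two artificial batches, as in test_multi_batch.
_ = metric(pred[0::2], target[0::2])   # samples 0, 2
_ = metric(pred[1::2], target[1::2])   # samples 1, 3

# Per-batch states are summed before the final reduction.
print(metric.aggregated)               # expected: tensor(0.7500), same as one pass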