diff --git a/anomalib/core/metrics/auroc.py b/anomalib/core/metrics/auroc.py index 35165a9090..5ce844b989 100644 --- a/anomalib/core/metrics/auroc.py +++ b/anomalib/core/metrics/auroc.py @@ -1,4 +1,5 @@ """Implementation of AUROC metric based on TorchMetrics.""" +import torch from torch import Tensor from torchmetrics import ROC from torchmetrics.functional import auc @@ -13,5 +14,11 @@ def compute(self) -> Tensor: Returns: Value of the AUROC metric """ + tpr: Tensor + fpr: Tensor + fpr, tpr, _thresholds = super().compute() + # TODO: use stable sort after upgrading to pytorch 1.9.x (https://github.com/openvinotoolkit/anomalib/issues/92) + if not (torch.all(fpr.diff() <= 0) or torch.all(fpr.diff() >= 0)): + return auc(fpr, tpr, reorder=True) # only reorder if fpr is not increasing or decreasing return auc(fpr, tpr) diff --git a/tox.ini b/tox.ini index e80af51c8f..605b7f223a 100644 --- a/tox.ini +++ b/tox.ini @@ -12,7 +12,7 @@ envlist = [testenv:black] basepython = python3 -deps = black +deps = black==20.8b1 commands = black --check --diff anomalib -l 120 [testenv:isort]