From 6959ea03c553a6753cc973abb58f240a096bc1c1 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 20:48:57 +0100 Subject: [PATCH 01/94] Add stuff --- .../metrics/classification/utils.py | 385 ++++++++++++++++++ pytorch_lightning/metrics/utils.py | 57 ++- tests/metrics/classification/inputs.py | 18 +- tests/metrics/classification/test_inputs.py | 301 ++++++++++++++ 4 files changed, 738 insertions(+), 23 deletions(-) create mode 100644 pytorch_lightning/metrics/classification/utils.py create mode 100644 tests/metrics/classification/test_inputs.py diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py new file mode 100644 index 0000000000000..b8f5af2e988d8 --- /dev/null +++ b/pytorch_lightning/metrics/classification/utils.py @@ -0,0 +1,385 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Optional + +import numpy as np +import torch + +from pytorch_lightning.metrics.utils import to_onehot, select_topk + + +def _check_classification_inputs( + preds: torch.Tensor, + target: torch.Tensor, + threshold: float, + num_classes: Optional[int] = None, + is_multiclass: bool = False, + top_k: int = 1, +) -> None: + """Performs error checking on inputs for classification. + + This ensures that preds and target take one of the shape/type combinations that are + specified in ``_input_format_classification`` docstring. It also checks the cases of + over-rides with ``is_multiclass`` by checking (for multi-class and multi-dim multi-class + cases) that there are only up to 2 distinct labels. + + In case where preds are floats (probabilities), it is checked whether they are in [0,1] interval. + + When ``num_classes`` is given, it is checked that it is consitent with input cases (binary, + multi-label, ...), and that, if availible, the implied number of classes in the ``C`` + dimension is consistent with it (as well as that max label in target is smaller than it). + + When ``num_classes`` is not specified in these cases, consistency of the highest target + value against ``C`` dimension is checked for (multi-dimensional) multi-class cases. + + If ``top_k`` is larger than one, then an error is raised if the inputs are not (multi-dim) + multi-class with probability predictions. + + Preds and target tensors are expected to be squeezed already - all dimensions should be + greater than 1, except perhaps the first one (N). + + Args: + preds: tensor with predictions + target: tensor with ground truth labels, always integers + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + num_classes: number of classes + is_multiclass: if True, treat binary and multi-label inputs as multi-class or multi-dim + multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim + multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. + Defaults to None, which treats inputs as they appear. + """ + + if target.is_floating_point(): + raise ValueError("target has to be an integer tensor") + elif target.min() < 0: + raise ValueError("target has to be a non-negative tensor") + + preds_float = preds.is_floating_point() + if not preds_float and preds.min() < 0: + raise ValueError("if preds are integers, they have to be non-negative") + + if not preds.shape[0] == target.shape[0]: + raise ValueError("preds and target should have the same first dimension.") + + if preds_float: + if preds.min() < 0 or preds.max() > 1: + raise ValueError( + "preds should be probabilities, but values were detected outside of [0,1] range" + ) + + if threshold > 1 or threshold < 0: + raise ValueError("Threshold should be a probability in [0,1]") + + if is_multiclass is False and target.max() > 1: + raise ValueError("If you set is_multiclass=False, then target should not exceed 1.") + + if is_multiclass is False and not preds_float and preds.max() > 1: + raise ValueError("If you set is_multiclass=False and preds are integers, then preds should not exceed 1.") + + # Check that shape/types fall into one of the cases + if len(preds.shape) == len(target.shape): + if preds.shape != target.shape: + raise ValueError("if preds and target have the same number of dimensions, they should have the same shape") + if preds_float and target.max() > 1: + raise ValueError("if preds and target are of shape (N, ...) and preds are floats, target should be binary") + + # Get the case + if len(preds.shape) == 1 and preds_float: + case = "binary" + elif len(preds.shape) == 1 and not preds_float: + case = "multi-class" + elif len(preds.shape) > 1 and preds_float: + case = "multi-label" + else: + case = "multi-dim multi-class" + + implied_classes = torch.prod(torch.Tensor(list(preds.shape[1:]))) + + elif len(preds.shape) == len(target.shape) + 1: + if not preds_float: + raise ValueError("if preds have one dimension more than target, preds should be a float tensor") + if not preds.shape[:-1] == target.shape: + if preds.shape[2:] != target.shape[1:]: + raise ValueError( + "if preds if preds have one dimension more than target, the shape of preds should be" + "either of shape (N, C, ...) or (N, ..., C), and of targets of shape (N, ...)" + ) + + extra_dim_size = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] + + if len(preds.shape) == 2: + case = "multi-class" + else: + case = "multi-dim multi-class" + else: + raise ValueError( + "preds and target should both have the (same) shape (N, ...), or target (N, ...)" + " and preds (N, C, ...) or (N, ..., C)" + ) + + if preds.shape != target.shape and is_multiclass is False and extra_dim_size != 2: + raise ValueError( + "You have set is_multiclass=False, but have more than 2 classes in your data," + " based on the C dimension of preds." + ) + + # Check that num_classes is consistent + if not num_classes: + if preds.shape != target.shape and target.max() >= extra_dim_size: + raise ValueError("The highest label in targets should be smaller than the size of C dimension") + else: + if case == "binary": + if num_classes > 2: + raise ValueError("Your data is binary, but num_classes is larger than 2.") + elif num_classes == 2 and not is_multiclass: + raise ValueError( + "Your data is binary and num_classes=2, but is_multiclass is not True." + "Set it to True if you want to transform binary data to multi-class format." + ) + elif num_classes == 1 and is_multiclass: + raise ValueError( + "You have binary data and have set is_multiclass=True, but num_classes is 1." + "Either leave is_multiclass unset or set it to 2 to transform binary data to multi-class format." + ) + elif "multi-class" in case: + if num_classes == 1 and is_multiclass is not False: + raise ValueError( + "You have set num_classes=1, but predictions are integers." + "If you want to convert (multi-dimensional) multi-class data with 2 classes" + "to binary/multi-label, set is_multiclass=False." + ) + elif num_classes > 1: + if is_multiclass is False: + if implied_classes != num_classes: + raise ValueError( + "You have set is_multiclass=False, but the implied number of classes " + "(from shape of inputs) does not match num_classes. If you are trying to" + "transform multi-dim multi-class data with 2 classes to multi-label, num_classes" + "should be either None or the product of the size of extra dimensions (...)." + "See Input Types in Metrics documentation." + ) + if num_classes <= target.max(): + raise ValueError("The highest label in targets should be smaller than num_classes") + if num_classes <= preds.max(): + raise ValueError("The highest label in preds should be smaller than num_classes") + if preds.shape != target.shape and num_classes != extra_dim_size: + raise ValueError("The size of C dimension of preds does not match num_classes") + + elif case == "multi-label": + if is_multiclass and num_classes != 2: + raise ValueError( + "Your have set is_multiclass=True, but num_classes is not equal to 2." + "If you are trying to transform multi-label data to 2 class multi-dimensional" + "multi-class, you should set num_classes to either 2 or None." + ) + if not is_multiclass and num_classes != implied_classes: + raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") + + # Check that if top_k > 1, we have (multi-class) multi-dim with probabilities + if top_k > 1: + if preds.shape == target.shape: + raise ValueError( + "You have set top_k above 1, but your data is not (multi-dimensional) multi-class" + "with probability predictions." + ) + + +def _input_format_classification( + preds: torch.Tensor, + target: torch.Tensor, + threshold: float = 0.5, + top_k: int = 1, + num_classes: Optional[int] = None, + is_multiclass: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor, str]: + """Convert preds and target tensors into common format. + + Preds and targets are supposed to fall into one of these categories (and are + validated to make sure this is the case): + + * Both preds and target are of shape ``(N,)``, and both are integers (multi-class) + * Both preds and target are of shape ``(N,)``, and target is binary, while preds + are a float (binary) + * preds are of shape ``(N, C)`` and are floats, and target is of shape ``(N,)`` and + is integer (multi-class) + * preds and target are of shape ``(N, ...)``, target is binary and preds is a float + (multi-label) + * preds are of shape ``(N, ..., C)`` or ``(N, C, ...)`` and are floats, target is of + shape ``(N, ...)`` and is integer (multi-dimensional multi-class) + * preds and target are of shape ``(N, ...)`` both are integers (multi-dimensional + multi-class) + + To avoid ambiguities, all dimensions of size 1, except the first one, are squeezed out. + + The returned output tensors will be binary tensors of the same shape, either ``(N, C)`` + of ``(N, C, X)``, the details for each case are described below. The function also returns + a ``mode`` string, which describes which of the above cases the inputs belonged to - regardless + of whether this was "overridden" by other settings (like ``is_multiclass``). + + In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed + into a binary tensor (elements become 1 if the probability is greater than or equal to + ``threshold`` or 0 otherwise). If ``is_multiclass=True``, then then both targets are preds + become ``(N, 2)`` tensors by a one-hot transformation; with the thresholding being applied to + preds first. + + In multi-class case, normally both preds and targets become ``(N, C)`` binary tensors; targets + by a one-hot transformation and preds by selecting ``top_k`` largest entries (if their original + shape was ``(N,C)``). However, if ``is_multiclass=False``, then targets and preds will be + returned as ``(N,1)`` tensor. + + In multi-label case, normally targets and preds are returned as ``(N, C)`` binary tensors, with + preds being binarized as in the binary case. Here the ``C`` dimension is obtained by flattening + all dimensions after the first one. However if ``is_multiclass=True``, then both are returned as + ``(N, 2, C)``, by an equivalent transformation as in the binary case. + + In multi-dimensional multi-class case, normally both target and preds are returned as + ``(N, C, X)`` tensors, with ``X`` resulting from flattening of all dimensions except ``N`` and + ``C``. The transformations performed here are equivalent to the multi-class case. However, if + ``is_multiclass=False`` (and there are up to two classes), then the data is returned as + ``(N, X)`` binary tensors (multi-label). + + Also, in multi-dimensional multi-class case, if the position of the ``C`` + dimension is ambiguous (e.g. if targets are a ``(7, 3)`` tensor, while predictions are a + ``(7, 3, 3)`` tensor), it will be assumed that the ``C`` dimension is the second dimension. + If this is not the case, you should move it from the last to second place using + ``torch.movedim(preds, -1, 1)``. + + Note that where a one-hot transformation needs to be performed and the number of classes + is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be + equal to ``num_classes``, if it is given, or the maximum label value in preds and + target. + + Args: + preds: tensor with predictions + target: tensor with ground truth labels, always integers + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + num_classes: number of classes + top_k: number of highest probability entries for each sample to convert to 1s, relevant + only for (multi-dimensional) multi-class cases. + is_multiclass: if True, treat binary and multi-label inputs as multi-class or multi-dim + multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim + multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. + Defaults to None, which treats inputs as they appear. + + Returns: + preds: binary tensor of shape (N, C) or (N, C, X) + target: binary tensor of shape (N, C) or (N, C, X) + """ + preds, target = preds.clone().detach(), target.clone().detach() + + # Remove excess dimensions + if preds.shape[0] == 1: + preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0) + else: + preds, target = preds.squeeze(), target.squeeze() + + _check_classification_inputs( + preds, + target, + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) + + preds_float = preds.is_floating_point() + + if len(preds.shape) == len(target.shape) == 1 and preds_float: + mode = "binary" + preds = (preds >= threshold).int() + + if is_multiclass: + target = to_onehot(target, 2) + preds = to_onehot(preds, 2) + else: + preds = preds.unsqueeze(-1) + target = target.unsqueeze(-1) + + elif len(preds.shape) == len(target.shape) and preds_float: + mode = "multi-label" + preds = (preds >= threshold).int() + + if is_multiclass: + preds = to_onehot(preds, 2).reshape(preds.shape[0], 2, -1) + target = to_onehot(target, 2).reshape(target.shape[0], 2, -1) + else: + preds = preds.reshape(preds.shape[0], -1) + target = target.reshape(target.shape[0], -1) + + elif len(preds.shape) == len(target.shape) + 1 == 2: + mode = "multi-class" + if not num_classes: + num_classes = preds.shape[1] + + target = to_onehot(target, num_classes) + preds = select_topk(preds, top_k) + + # If is_multiclass=False, force to binary + if is_multiclass is False: + target = target[:, [1]] + preds = preds[:, [1]] + + elif len(preds.shape) == len(target.shape) == 1 and not preds_float: + mode = "multi-class" + + if not num_classes: + num_classes = max(preds.max(), target.max()) + 1 + + # If is_multiclass=False, force to binary + if is_multiclass is False: + preds = preds.unsqueeze(1) + target = target.unsqueeze(1) + else: + preds = to_onehot(preds, num_classes) + target = to_onehot(target, num_classes) + + # Multi-dim multi-class (N, ...) with integers + elif preds.shape == target.shape and not preds_float: + mode = "multi-dim multi-class" + + if not num_classes: + num_classes = max(preds.max(), target.max()) + 1 + + # If is_multiclass=False, force to multi-label + if is_multiclass is False: + preds = preds.reshape(preds.shape[0], -1) + target = target.reshape(target.shape[0], -1) + else: + target = to_onehot(target, num_classes) + target = target.reshape(target.shape[0], target.shape[1], -1) + preds = to_onehot(preds, num_classes) + preds = preds.reshape(preds.shape[0], preds.shape[1], -1) + + # Multi-dim multi-class (N, C, ...) and (N, ..., C) + else: + mode = "multi-dim multi-class" + if preds.shape[:-1] == target.shape: + preds = torch.movedim(preds, -1, 1) + + num_classes = preds.shape[1] + + if is_multiclass is False: + target = target.reshape(target.shape[0], -1) + preds = select_topk(preds, 1)[:, 1, ...] + preds = preds.reshape(preds.shape[0], -1) + else: + target = to_onehot(target, num_classes) + target = target.reshape(target.shape[0], target.shape[1], -1) + preds = select_topk(preds, top_k).reshape(preds.shape[0], preds.shape[1], -1) + + return preds, target, mode diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index e1ff95b94f471..1ce56b30cf9e5 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -20,6 +20,7 @@ def dim_zero_cat(x): + x = x if isinstance(x, (list, tuple)) else [x] return torch.cat(x, dim=0) @@ -36,8 +37,8 @@ def _flatten(x): def to_onehot( - tensor: torch.Tensor, - num_classes: int, + tensor: torch.Tensor, + num_classes: int, ) -> torch.Tensor: """ Converts a dense label tensor to one-hot format @@ -57,24 +58,46 @@ def to_onehot( [0, 0, 0, 1]]) """ dtype, device, shape = tensor.dtype, tensor.device, tensor.shape - tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], - dtype=dtype, device=device) + tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], dtype=dtype, device=device) index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) return tensor_onehot.scatter_(1, index, 1.0) +def select_topk(tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: + """ + Convert a probability tensor to binary by selecting top-k highest entries. + + Args: + tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the + position defined by the ``dim`` argument + topk: number of highest entries to turn into 1s + dim: dimension on which to compare entries + + Output: + A binary tensor of the same shape as the input tensor of type torch.int32 + + Example: + >>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]]) + >>> select_topk(x, topk=2) + tensor([[0, 1, 1], + [1, 1, 0]], dtype=torch.int32) + """ + zeros = torch.zeros_like(tensor, device=tensor.device) + topk_tensor = zeros.scatter(1, tensor.topk(k=topk, dim=dim).indices, 1.0) + + return topk_tensor.int() + + def _check_same_shape(pred: torch.Tensor, target: torch.Tensor): """ Check that predictions and target have the same shape, else raise error """ if pred.shape != target.shape: - raise RuntimeError('Predictions and targets are expected to have the same shape') + raise RuntimeError("Predictions and targets are expected to have the same shape") def _input_format_classification( - preds: torch.Tensor, - target: torch.Tensor, - threshold: float = 0.5 + preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5 ) -> Tuple[torch.Tensor, torch.Tensor]: - """ Convert preds and target tensors into label tensors + """Convert preds and target tensors into label tensors Args: preds: either tensor with labels, tensor with probabilities/logits or @@ -87,9 +110,7 @@ def _input_format_classification( target: tensor with labels """ if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): - raise ValueError( - "preds and target must have same number of dimensions, or one additional dimension for preds" - ) + raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") if len(preds.shape) == len(target.shape) + 1: # multi class probabilites @@ -102,13 +123,9 @@ def _input_format_classification( def _input_format_classification_one_hot( - num_classes: int, - preds: torch.Tensor, - target: torch.Tensor, - threshold: float = 0.5, - multilabel: bool = False + num_classes: int, preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5, multilabel: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: - """ Convert preds and target tensors into one hot spare label tensors + """Convert preds and target tensors into one hot spare label tensors Args: num_classes: number of classes @@ -123,9 +140,7 @@ def _input_format_classification_one_hot( target: one hot tensors of shape [num_classes, -1] with true labels """ if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): - raise ValueError( - "preds and target must have same number of dimensions, or one additional dimension for preds" - ) + raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") if len(preds.shape) == len(target.shape) + 1: # multi class probabilites diff --git a/tests/metrics/classification/inputs.py b/tests/metrics/classification/inputs.py index 9613df3b6f8ca..e648aaf10093e 100644 --- a/tests/metrics/classification/inputs.py +++ b/tests/metrics/classification/inputs.py @@ -29,12 +29,21 @@ target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) ) +_multilabel_multidim_prob_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)) +) _multilabel_inputs = Input( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) ) +_multilabel_multidim_inputs = Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)) +) + # Generate edge multilabel edge case, where nothing matches (scores are undefined) __temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) __temp_target = abs(__temp_preds - 1) @@ -61,8 +70,13 @@ target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) +# Class dimension last +_multidim_multiclass_prob_inputs1 = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, NUM_CLASSES), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) +) _multidim_multiclass_inputs = Input( - preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, EXTRA_DIM, BATCH_SIZE)), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, EXTRA_DIM, BATCH_SIZE)) + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py new file mode 100644 index 0000000000000..8d17d5624fac0 --- /dev/null +++ b/tests/metrics/classification/test_inputs.py @@ -0,0 +1,301 @@ +import pytest +import torch +from torch import randint, rand + +from pytorch_lightning.metrics.utils import to_onehot, select_topk +from pytorch_lightning.metrics.classification.utils import _input_format_classification +from tests.metrics.classification.inputs import ( + Input, + _binary_inputs as _bin, + _binary_prob_inputs as _bin_prob, + _multiclass_inputs as _mc, + _multiclass_prob_inputs as _mc_prob, + _multidim_multiclass_inputs as _mdmc, + _multidim_multiclass_prob_inputs as _mdmc_prob, + _multidim_multiclass_prob_inputs1 as _mdmc_prob1, + _multilabel_inputs as _ml, + _multilabel_prob_inputs as _ml_prob, + _multilabel_multidim_inputs as _mlmd, + _multilabel_multidim_prob_inputs as _mlmd_prob, +) +from tests.metrics.utils import NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, THRESHOLD + +torch.manual_seed(42) + +# Some additional inputs to test on +_mc_prob_2cls = Input(rand(NUM_BATCHES, BATCH_SIZE, 2), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))) +_mdmc_prob_many_dims = Input( + rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, EXTRA_DIM), + randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)), +) +_mdmc_prob_many_dims1 = Input( + rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM, NUM_CLASSES), + randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)), +) +_mdmc_prob_2cls = Input( + rand(NUM_BATCHES, BATCH_SIZE, 2, EXTRA_DIM), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) +) +_mdmc_prob_2cls1 = Input( + rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, 2), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) +) + +# Some utils +T = torch.Tensor + + +def idn(x): + return x + + +def usq(x): + return x.unsqueeze(-1) + + +def toint(x): + return x.int() + + +def thrs(x): + return x >= THRESHOLD + + +def rshp1(x): + return x.reshape(x.shape[0], -1) + + +def rshp2(x): + return x.reshape(x.shape[0], x.shape[1], -1) + + +def onehot(x): + return to_onehot(x, NUM_CLASSES) + + +def onehot2(x): + return to_onehot(x, 2) + + +def top1(x): + return select_topk(x, 1) + + +def top2(x): + return select_topk(x, 2) + + +def mvdim(x): + return torch.movedim(x, -1, 1) + + +# To avoid ugly black line wrapping +def ml_preds_tr(x): + return rshp1(toint(thrs(x))) + + +def onehot_rshp1(x): + return onehot(rshp1(x)) + + +def onehot2_rshp1(x): + return onehot2(rshp1(x)) + + +def top1_rshp2(x): + return top1(rshp2(x)) + + +def top2_rshp2(x): + return top2(rshp2(x)) + + +def mdmc1_top1_tr(x): + return top1(rshp2(mvdim(x))) + + +def mdmc1_top2_tr(x): + return top2(rshp2(mvdim(x))) + + +def probs_to_mc_preds_tr(x): + return toint(onehot2(thrs(x))) + + +def mlmd_prob_to_mc_preds_tr(x): + return onehot2(rshp1(toint(thrs(x)))) + + +def mdmc_prob_to_ml_preds_tr(x): + return top1(mvdim(x))[:, 1] + + +######################## +# Test correct inputs +######################## + + +@pytest.mark.parametrize( + "inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target", + [ + ############################# + # Test usual expected cases + (_bin, THRESHOLD, None, False, 1, "multi-class", usq, usq), + (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: usq(toint(thrs(x))), usq), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", lambda x: toint(thrs(x)), idn), + (_ml, THRESHOLD, None, False, 1, "multi-dim multi-class", idn, idn), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", ml_preds_tr, rshp1), + (_mlmd, THRESHOLD, None, False, 1, "multi-dim multi-class", rshp1, rshp1), + (_mc, THRESHOLD, NUM_CLASSES, None, 1, "multi-class", onehot, onehot), + (_mc_prob, THRESHOLD, None, None, 1, "multi-class", top1, onehot), + (_mc_prob, THRESHOLD, None, None, 2, "multi-class", top2, onehot), + (_mdmc, THRESHOLD, NUM_CLASSES, None, 1, "multi-dim multi-class", onehot, onehot), + (_mdmc_prob, THRESHOLD, None, None, 1, "multi-dim multi-class", top1_rshp2, onehot), + (_mdmc_prob, THRESHOLD, None, None, 2, "multi-dim multi-class", top2_rshp2, onehot), + (_mdmc_prob_many_dims, THRESHOLD, None, None, 1, "multi-dim multi-class", top1_rshp2, onehot_rshp1), + (_mdmc_prob_many_dims, THRESHOLD, None, None, 2, "multi-dim multi-class", top2_rshp2, onehot_rshp1), + # Test with C dim in last place + (_mdmc_prob1, THRESHOLD, None, None, 1, "multi-dim multi-class", mdmc1_top1_tr, onehot), + (_mdmc_prob1, THRESHOLD, None, None, 2, "multi-dim multi-class", mdmc1_top2_tr, onehot), + (_mdmc_prob_many_dims1, THRESHOLD, None, None, 1, "multi-dim multi-class", mdmc1_top1_tr, onehot_rshp1), + (_mdmc_prob_many_dims1, THRESHOLD, None, None, 2, "multi-dim multi-class", mdmc1_top2_tr, onehot_rshp1), + ########################### + # Test some special cases + # Binary as multiclass + (_bin, THRESHOLD, None, None, 1, "multi-class", onehot2, onehot2), + # Binary probs as multiclass + (_bin_prob, THRESHOLD, None, True, 1, "binary", probs_to_mc_preds_tr, onehot2), + # Multilabel as multiclass + (_ml, THRESHOLD, None, True, 1, "multi-dim multi-class", onehot2, onehot2), + # Multilabel probs as multiclass + (_ml_prob, THRESHOLD, None, True, 1, "multi-label", probs_to_mc_preds_tr, onehot2), + # Multidim multilabel as multiclass + (_mlmd, THRESHOLD, None, True, 1, "multi-dim multi-class", onehot2_rshp1, onehot2_rshp1), + # Multidim multilabel probs as multiclass + (_mlmd_prob, THRESHOLD, None, True, 1, "multi-label", mlmd_prob_to_mc_preds_tr, onehot2_rshp1), + # Multiclass prob with 2 classes as binary + (_mc_prob_2cls, THRESHOLD, None, False, 1, "multi-class", lambda x: top1(x)[:, [1]], usq), + # Multi-dim multi-class with 2 classes as multi-label + (_mdmc_prob_2cls, THRESHOLD, None, False, 1, "multi-dim multi-class", lambda x: top1(x)[:, 1], idn), + (_mdmc_prob_2cls1, THRESHOLD, None, False, 1, "multi-dim multi-class", mdmc_prob_to_ml_preds_tr, idn), + ], +) +def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): + preds_out, target_out, mode = _input_format_classification( + preds=inputs.preds[0], + target=inputs.target[0], + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) + + assert mode == exp_mode + assert torch.equal(preds_out, post_preds(inputs.preds[0])) + assert torch.equal(target_out, post_target(inputs.target[0])) + + # Test that things work when batch_size = 1 + preds_out, target_out, mode = _input_format_classification( + preds=inputs.preds[0][[0], ...], + target=inputs.target[0][[0], ...], + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) + + assert mode == exp_mode + assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...])) + assert torch.equal(target_out, post_target(inputs.target[0][[0], ...])) + + +# Test that threshold is correctly applied +def test_threshold(): + target = T([1, 1, 1]).int() + preds_probs = T([0.5 - 1e-5, 0.5, 0.5 + 1e-5]) + + preds_probs_out, _, _ = _input_format_classification(preds_probs, target, threshold=0.5) + + assert torch.equal(torch.tensor([0, 1, 1]), preds_probs_out.squeeze().long()) + + +######################################################################## +# Test incorrect inputs +######################################################################## + + +@pytest.mark.parametrize( + "preds, target, threshold, num_classes, is_multiclass, top_k", + [ + # Target not integer + (randint(high=2, size=(7,)), randint(high=2, size=(7,)).float(), 0.5, None, None, 1), + # Target negative + (randint(high=2, size=(7,)), -randint(high=2, size=(7,)), 0.5, None, None, 1), + # Preds negative integers + (-randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, 1), + # Negative probabilities + (-rand(size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, 1), + # Threshold outside of [0,1] + (rand(size=(7,)), randint(high=2, size=(7,)), 1.5, None, None, 1), + # is_multiclass=False and target > 1 + (rand(size=(7,)), randint(low=2, high=4, size=(7,)), 0.5, None, False, 1), + # is_multiclass=False and preds integers with > 1 + (randint(low=2, high=4, size=(7,)), randint(high=2, size=(7,)), 0.5, None, False, 1), + # Wrong batch size + (randint(high=2, size=(8,)), randint(high=2, size=(7,)), 0.5, None, None, 1), + # Completely wrong shape + (randint(high=2, size=(7,)), randint(high=2, size=(7, 4)), 0.5, None, None, 1), + # Same #dims, different shape + (randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), 0.5, None, None, 1), + # Same shape and preds floats, target not binary + (rand(size=(7, 3)), randint(low=2, high=4, size=(7, 3)), 0.5, None, None, 1), + # #dims in preds = 1 + #dims in target, C shape not second or last + (rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, 1), + # #dims in preds = 1 + #dims in target, preds not float + (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, 1), + # is_multiclass=False, with C dimension > 2 + (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, False, 1), + # Max target larger or equal to C dimension + (rand(size=(7, 3)), randint(low=4, high=6, size=(7,)), 0.5, None, None, 1), + # C dimension not equal to num_classes + (rand(size=(7, 3, 4)), randint(high=4, size=(7, 3)), 0.5, 7, None, 1), + # Max target larger than num_classes (with #dim preds = 1 + #dims target) + (rand(size=(7, 3, 4)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, 1), + # Max target larger than num_classes (with #dim preds = #dims target) + (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, 1), + # Max preds larger than num_classes (with #dim preds = #dims target) + (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 0.5, 4, None, 1), + # Num_classes=1, but is_multiclass not false + (randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, 1, None, 1), + # is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes + (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, 1), + # Multilabel input with implied class dimension != num_classes + (rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, 1), + # Multilabel input with is_multiclass=True, but num_classes != 2 (or None) + (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 0.5, 4, True, 1), + # Binary input, num_classes > 2 + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 4, None, 1), + # Binary input, num_classes == 2 and is_multiclass not True + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, None, 1), + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, False, 1), + # Binary input, num_classes == 1 and is_multiclass=True + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 1, True, 1), + # Topk > 1 with non (md)mc prob data + (_bin.preds[0], _bin.target[0], 0.5, None, None, 2), + (_bin_prob.preds[0], _bin_prob.target[0], 0.5, None, None, 2), + (_mc.preds[0], _mc.target[0], 0.5, None, None, 2), + (_ml.preds[0], _ml.target[0], 0.5, None, None, 2), + (_mlmd.preds[0], _mlmd.target[0], 0.5, None, None, 2), + (_ml_prob.preds[0], _ml_prob.target[0], 0.5, None, None, 2), + (_mlmd_prob.preds[0], _mlmd_prob.target[0], 0.5, None, None, 2), + (_mdmc.preds[0], _mdmc.target[0], 0.5, None, None, 2), + ], +) +def test_incorrect_inputs(preds, target, threshold, num_classes, is_multiclass, top_k): + with pytest.raises(ValueError): + _input_format_classification( + preds=preds, + target=target, + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + top_k=top_k, + ) From 06790157ed025ba98950b9616ef43835f9b6c4e7 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 21:47:30 +0100 Subject: [PATCH 02/94] Change metrics documentation layout --- docs/source/metrics.rst | 188 ++++++++++++++++++++++++++-------------- 1 file changed, 121 insertions(+), 67 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index d47c872f35047..407b64d3d2948 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -196,12 +196,71 @@ Metric API .. autoclass:: pytorch_lightning.metrics.Metric :noindex: -************* -Class metrics -************* +*************************** +Class vs Functional Metrics +*************************** +The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs. + +Also the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. +If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also to use the class interface. + +********************** Classification Metrics ----------------------- +********************** + +Input types +----------- + +For the purposes of classification metrics, inputs (predictions and targets) are split +into these categories (``N`` stands for the batch size and ``C`` for number of classes): + +.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1 + :header: "Type", "preds shape", "preds dtype", "target shape", "target dtype" + :widths: 20, 10, 10, 10, 10 + + "Binary", "(N,)", "``float``", "(N,)", "``binary``\*" + "Multi-class", "(N,)", "``int``", "(N,)", "``int``" + "Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``" + "Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*" + "Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``" + "Multi-dimensional multi-class with probabilities", "(N, C, ...) or (N, ..., C)", "``float``", "(N, ...)", "``int``" + +.. note:: + All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so + that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``. + +When predictions or targets are integers, it is assumed that class labels start at , i.e. +the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types + +.. code-block:: python + + # Binary inputs + binary_preds = torch.tensor([0.6, 0.1, 0.9]) + binary_target = torch.tensor([1, 0, 2]) + + # Multi-class inputs + mc_preds = torch.tensor([0, 2, 1]) + mc_target = torch.tensor([0, 1, 2]) + + # Multi-class inputs with probabilities + mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]]) + mc_target_probs = torch.tensor([0, 1, 2]) + + # Multi-label inputs + ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]]) + ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]]) + +In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class, +but are actually binary/multi-label. For example, if both predictions and targets are 1d +binary tensors. Or it could be the other way around, you want to treat binary/multi-label +inputs as 2-class (multi-dimensional) multi-class inputs. + +For these cases, the metrics where this distinction would make a difference, expose the +``is_multiclass`` argument. + +Class Metrics (Classification) +------------------------------ Accuracy ~~~~~~~~ @@ -239,61 +298,8 @@ ConfusionMatrix .. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix :noindex: -Regression Metrics ------------------- - -MeanSquaredError -~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError - :noindex: - - -MeanAbsoluteError -~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError - :noindex: - - -MeanSquaredLogError -~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError - :noindex: - - -ExplainedVariance -~~~~~~~~~~~~~~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance - :noindex: - - -PSNR -~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.PSNR - :noindex: - - -SSIM -~~~~ - -.. autoclass:: pytorch_lightning.metrics.regression.SSIM - :noindex: - -****************** -Functional Metrics -****************** - -The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs. - -Also the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. -If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also to use the class interface. - -Classification --------------- +Functional Metrics (Classification) +----------------------------------- accuracy [func] ~~~~~~~~~~~~~~~ @@ -434,9 +440,57 @@ to_onehot [func] .. autofunction:: pytorch_lightning.metrics.functional.classification.to_onehot :noindex: +****************** +Regression Metrics +****************** + +Class Metrics (Regression) +-------------------------- + +MeanSquaredError +~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError + :noindex: + + +MeanAbsoluteError +~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError + :noindex: + + +MeanSquaredLogError +~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError + :noindex: + + +ExplainedVariance +~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance + :noindex: + + +PSNR +~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.PSNR + :noindex: + -Regression ----------- +SSIM +~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.SSIM + :noindex: + + +Functional Metrics (Regression) +------------------------------- explained_variance [func] ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -479,22 +533,22 @@ ssim [func] .. autofunction:: pytorch_lightning.metrics.functional.ssim :noindex: - +*** NLP ---- +*** bleu_score [func] -~~~~~~~~~~~~~~~~~ +----------------- .. autofunction:: pytorch_lightning.metrics.functional.nlp.bleu_score :noindex: - +******** Pairwise --------- +******** embedding_similarity [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~ +--------------------------- .. autofunction:: pytorch_lightning.metrics.functional.self_supervised.embedding_similarity :noindex: From 35627b50092df312ed71d9ab1c828a2451306d5d Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 22:07:40 +0100 Subject: [PATCH 03/94] Add stuff --- docs/source/metrics.rst | 13 +- pytorch_lightning/metrics/__init__.py | 1 + .../metrics/classification/__init__.py | 2 +- .../metrics/classification/accuracy.py | 162 +++++++++++++--- .../metrics/functional/__init__.py | 3 +- .../metrics/functional/accuracy.py | 155 +++++++++++++++ tests/metrics/classification/test_accuracy.py | 181 +++++++++++------- tests/metrics/utils.py | 153 +++++++-------- 8 files changed, 494 insertions(+), 176 deletions(-) create mode 100644 pytorch_lightning/metrics/functional/accuracy.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 407b64d3d2948..e7f89ba1b853b 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -268,6 +268,12 @@ Accuracy .. autoclass:: pytorch_lightning.metrics.classification.Accuracy :noindex: +Hamming Loss +~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.classification.HammingLoss + :noindex: + Precision ~~~~~~~~~ @@ -304,9 +310,14 @@ Functional Metrics (Classification) accuracy [func] ~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.accuracy +.. autofunction:: pytorch_lightning.metrics.functional.accuracy :noindex: +hamming_loss [func] +~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.hamming_loss + :noindex: auc [func] ~~~~~~~~~~ diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 653ad23c68f7e..d10aa2e5e995c 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -15,6 +15,7 @@ from pytorch_lightning.metrics.classification import ( Accuracy, + HammingLoss, Precision, Recall, FBeta, diff --git a/pytorch_lightning/metrics/classification/__init__.py b/pytorch_lightning/metrics/classification/__init__.py index db643c227abed..45d4dd03e430e 100644 --- a/pytorch_lightning/metrics/classification/__init__.py +++ b/pytorch_lightning/metrics/classification/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.metrics.classification.accuracy import Accuracy +from pytorch_lightning.metrics.classification.accuracy import Accuracy, HammingLoss from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall from pytorch_lightning.metrics.classification.f_beta import FBeta, F1 from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 0f01fb9813407..8297cc73f9540 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -11,38 +11,54 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -import functools -from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, Union -from collections.abc import Mapping, Sequence -from collections import namedtuple +from typing import Any, Callable, Optional import torch -from torch import nn from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.metrics.utils import _input_format_classification +from pytorch_lightning.metrics.functional.accuracy import ( + _accuracy_update, + _hamming_loss_update, + _accuracy_compute, + _hamming_loss_compute, +) class Accuracy(Metric): """ - Computes accuracy. Works with binary, multiclass, and multilabel data. - Accepts logits from a model output or integer class values in prediction. - Works with multi-dimensional preds and target. + Computes the share of entirely correctly predicted samples. - Forward accepts + This metric generalizes to subset accuracy for multilabel data, and similarly for + multi-dimensional multi-class data: for the sample to be counted as correct, the the + class has to be correctly predicted across all extra dimension for each sample in the + ``N`` dimension. Consider using :class:`~pytorch_lightning.metrics.classification.HammingLoss` + is this is not what you want. - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` + For multi-class and multi-dimensional multi-class data with probability predictions, the + parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the + top-K highest probability items are considered to find the correct label. - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. - This is the case for binary and multi-label logits. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + Accepts all input types listed in :ref:`metrics:Input types`. Args: + top_k: + Number of highest probability predictions considered to find the correct label, for + (multi-dimensional) multi-class inputs with probability predictions. Default 1 + + If your inputs are not (multi-dimensional) multi-class inputs with probability predictions, + an error will be raised if ``top_k`` is set to a value other than 1. + mdmc_accuracy: + Determines how should the extra dimension be handeled in case of multi-dimensional multi-class + inputs. Options are ``"global"`` or ``"subset"``. + + If ``"global"``, then the inputs are treated as if the sample (``N``) and the extra dimension + were unrolled into a new sample dimension. + + If ``"subset"``, than the equivalent of subset accuracy is performed for each sample on the + ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension + must be predicted correctly (the ``top_k`` option still applies here). threshold: - Threshold value for binary or multi-label logits. default: 0.5 + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -63,7 +79,97 @@ class Accuracy(Metric): >>> accuracy(preds, target) tensor(0.5000) + >>> target = torch.tensor([0, 1, 2]) + >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]) + >>> accuracy = Accuracy(top_k=2) + >>> accuracy(preds, target) + tensor(0.6667) + """ + + def __init__( + self, + top_k: int = 1, + mdmc_accuracy: str = "subset", + threshold: float = 0.5, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + self.threshold = threshold + self.top_k = top_k + self.mdmc_accuracy = mdmc_accuracy + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. See :ref:`metrics:Input types` for more information + on input types. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth values + """ + + correct, total = _accuracy_update(preds, target, self.threshold, self.top_k, self.mdmc_accuracy) + + self.correct += correct + self.total += total + + def compute(self) -> torch.Tensor: + """ + Computes accuracy based on inputs passed in to ``update`` previously. + """ + return _accuracy_compute(self.correct, self.total) + + +class HammingLoss(Metric): + """ + Computes the share of wrongly predicted labels. + + This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it + treats each possible label separately - meaning that, for example, multi-class data is + treated as if it were multi-label. If this is not what you want, consider using + :class:`~pytorch_lightning.metrics.classification.Accuracy`. + + Accepts all input types listed in :ref:`metrics:Input types`. + + Args: + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None + + Example: + + >>> from pytorch_lightning.metrics import HammingLoss + >>> target = torch.tensor([[0, 1], [1, 1]]) + >>> preds = torch.tensor([[0, 1], [0, 1]]) + >>> hamming_loss = HammingLoss() + >>> hamming_loss(preds, target) + tensor(0.2500) + + """ + def __init__( self, threshold: float = 0.5, @@ -86,20 +192,20 @@ def __init__( def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. + Update state with predictions and targets. See :ref:`metrics:Input types` for more information + on input types. Args: - preds: Predictions from model + preds: Predictions from model (probabilities, or labels) target: Ground truth values """ - preds, target = _input_format_classification(preds, target, self.threshold) - assert preds.shape == target.shape + correct, total = _hamming_loss_update(preds, target, self.threshold) - self.correct += torch.sum(preds == target) - self.total += target.numel() + self.correct += correct + self.total += total - def compute(self): + def compute(self) -> torch.Tensor: """ - Computes accuracy over state. + Computes hamming loss based on inputs passed in to ``update`` previously. """ - return self.correct.float() / self.total + return _hamming_loss_compute(self.correct, self.total) diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 3bb5313db7b27..42029335afe9f 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.metrics.functional.classification import ( - accuracy, auc, auroc, average_precision, @@ -42,5 +41,7 @@ from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error from pytorch_lightning.metrics.functional.psnr import psnr from pytorch_lightning.metrics.functional.ssim import ssim + +from pytorch_lightning.metrics.functional.accuracy import accuracy, hamming_loss from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py new file mode 100644 index 0000000000000..f9861388cceda --- /dev/null +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -0,0 +1,155 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Union + +import torch +from pytorch_lightning.metrics.classification.utils import _input_format_classification + +################################ +# Accuracy +################################ + + +def _accuracy_update( + preds: torch.Tensor, target: torch.Tensor, threshold: float, top_k: int, mdmc_accuracy: str +) -> Tuple[torch.Tensor, torch.Tensor]: + + preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k) + + if mode in ["binary", "multi-label"]: + correct = (preds == target).all(dim=1).sum() + total = target.shape[0] + elif mdmc_accuracy == "global": + correct = (preds * target).sum() + total = target.sum() + elif mdmc_accuracy == "subset": + extra_dims = list(range(1, len(preds.shape))) + sample_correct = (preds * target).sum(dim=extra_dims) + sample_total = target.sum(dim=extra_dims) + + correct = (sample_correct == sample_total).sum() + total = target.shape[0] + + return (torch.tensor(correct, device=preds.device), torch.tensor(total, device=preds.device)) + + +def _accuracy_compute(correct: torch.Tensor, total: torch.Tensor) -> torch.Tensor: + return correct / total + + +def accuracy( + preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5, top_k: int = 1, mdmc_accuracy: str = "subset" +) -> torch.Tensor: + """ + Computes the share of entirely correctly predicted samples. + + This metric generalizes to subset accuracy for multilabel data, and similarly for + multi-dimensional multi-class data: for the sample to be counted as correct, the the + class has to be correctly predicted across all extra dimension for each sample in the + ``N`` dimension. Consider using :class:`~pytorch_lightning.metrics.classification.HammingLoss` + is this is not what you want. + + For multi-class and multi-dimensional multi-class data with probability predictions, the + parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the + top-K highest probability items are considered to find the correct label. + + Accepts all input types listed in :ref:`metrics:Input types`. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth values + top_k: + Number of highest probability predictions considered to find the correct label, for + (multi-dimensional) multi-class inputs with probability predictions. Default 1 + + If your inputs are not (multi-dimensional) multi-class inputs with probability predictions, + an error will be raised if ``top_k`` is set to a value other than 1. + mdmc_accuracy: + Determines how should the extra dimension be handeled in case of multi-dimensional multi-class + inputs. Options are ``"global"`` or ``"subset"``. + + If ``"global"``, then the inputs are treated as if the sample (``N``) and the extra dimension + were unrolled into a new sample dimension. + + If ``"subset"``, than the equivalent of subset accuracy is performed for each sample on the + ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension + must be predicted correctly (the ``top_k`` option still applies here). + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + + Example: + + >>> from pytorch_lightning.metrics.functional import accuracy + >>> target = torch.tensor([0, 1, 2, 3]) + >>> preds = torch.tensor([0, 2, 1, 3]) + >>> accuracy(preds, target) + tensor(0.5000) + + >>> target = torch.tensor([0, 1, 2]) + >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]) + >>> accuracy(preds, target, top_k=2) + tensor(0.6667) + """ + + correct, total = _accuracy_update(preds, target, threshold, top_k, mdmc_accuracy) + return _accuracy_compute(correct, total) + + +################################ +# Hamming loss +################################ + + +def _hamming_loss_update(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> Tuple[torch.Tensor, int]: + preds, target, _ = _input_format_classification(preds, target, threshold=threshold) + + correct = (preds == target).sum() + total = preds.numel() + + return correct, total + + +def _hamming_loss_compute(correct: torch.Tensor, total: Union[int, torch.Tensor]) -> torch.Tensor: + return 1 - correct.float() / total + + +def hamming_loss(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: + """ + Computes the share of wrongly predicted labels. + + This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it + treats each possible label separately - meaning that, for example, multi-class data is + treated as if it were multi-label. If this is not what you want, consider using + :class:`~pytorch_lightning.metrics.classification.Accuracy`. + + Accepts all input types listed in :ref:`metrics:Input types`. + + Args: + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + + Example: + + >>> from pytorch_lightning.metrics.functional import hamming_loss + >>> target = torch.tensor([[0, 1], [1, 1]]) + >>> preds = torch.tensor([[0, 1], [0, 1]]) + >>> hamming_loss(preds, target) + tensor(0.2500) + + """ + + correct, total = _hamming_loss_update(preds, target, threshold) + return _hamming_loss_compute(correct, total) diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index 017438269bdbf..8d8386683748c 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -1,9 +1,11 @@ import numpy as np import pytest import torch -from sklearn.metrics import accuracy_score +from sklearn.metrics import accuracy_score as sk_accuracy, hamming_loss as sk_hamming_loss -from pytorch_lightning.metrics.classification.accuracy import Accuracy +from pytorch_lightning.metrics import Accuracy, HammingLoss +from pytorch_lightning.metrics.functional import accuracy, hamming_loss +from pytorch_lightning.metrics.classification.utils import _input_format_classification from tests.metrics.classification.inputs import ( _binary_inputs, _binary_prob_inputs, @@ -13,101 +15,140 @@ _multidim_multiclass_prob_inputs, _multilabel_inputs, _multilabel_prob_inputs, + _multilabel_multidim_prob_inputs, + _multilabel_multidim_inputs, ) from tests.metrics.utils import THRESHOLD, MetricTester torch.manual_seed(42) -def _sk_accuracy_binary_prob(preds, target): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() +def _sk_accuracy(preds, target): + sk_preds, sk_target, _ = _input_format_classification(preds, target, threshold=THRESHOLD) + sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() + sk_preds, sk_target = sk_preds.reshape(sk_preds.shape[0], -1), sk_target.reshape(sk_target.shape[0], -1) - return accuracy_score(y_true=sk_target, y_pred=sk_preds) + return sk_accuracy(y_true=sk_target, y_pred=sk_preds) -def _sk_accuracy_binary(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() +def _sk_hamming_loss(preds, target): + sk_preds, sk_target, _ = _input_format_classification(preds, target, threshold=THRESHOLD) + sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() + sk_preds, sk_target = sk_preds.reshape(sk_preds.shape[0], -1), sk_target.reshape(sk_target.shape[0], -1) - return accuracy_score(y_true=sk_target, y_pred=sk_preds) + return sk_hamming_loss(y_true=sk_target, y_pred=sk_preds) -def _sk_accuracy_multilabel_prob(preds, target): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return accuracy_score(y_true=sk_target, y_pred=sk_preds) - - -def _sk_accuracy_multilabel(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return accuracy_score(y_true=sk_target, y_pred=sk_preds) - - -def _sk_accuracy_multiclass_prob(preds, target): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() +@pytest.mark.parametrize( + "metric, fn_metric, sk_metric", [(Accuracy, accuracy, _sk_accuracy), (HammingLoss, hamming_loss, _sk_hamming_loss)] +) +@pytest.mark.parametrize( + "preds, target", + [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target), + (_binary_inputs.preds, _binary_inputs.target), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target), + (_multilabel_inputs.preds, _multilabel_inputs.target), + (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target), + (_multiclass_inputs.preds, _multiclass_inputs.target), + (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target), + (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target), + (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target), + (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target), + ], +) +class TestAccuracies(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [False, True]) + def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, metric, sk_metric, fn_metric): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=metric, + sk_metric=sk_metric, + dist_sync_on_step=dist_sync_on_step, + metric_args={"threshold": THRESHOLD}, + ) - return accuracy_score(y_true=sk_target, y_pred=sk_preds) + def test_accuracy_fn(self, preds, target, metric, sk_metric, fn_metric): + self.run_functional_metric_test( + preds, + target, + metric_functional=fn_metric, + sk_metric=sk_metric, + metric_args={"threshold": THRESHOLD}, + ) -def _sk_accuracy_multiclass(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() +l1to4 = [.1, .2, .3, .4] +l1to4t3 = np.array([l1to4, l1to4, l1to4]) +l1to4t3_mc = [l1to4t3.T, l1to4t3.T, l1to4t3.T] - return accuracy_score(y_true=sk_target, y_pred=sk_preds) +# The preds in these examples always put highest probability on class 3, second highest on class 2, +# third highest on class 1, and lowest on class 0 +topk_preds_mc = torch.tensor([l1to4t3, l1to4t3]).float() +topk_target_mc = torch.tensor([[1, 2, 3], [2, 1, 0]]) +# This is like for MC case, but one sample in each batch is sabotaged with 0 class prediction :) +topk_preds_mdmc = torch.tensor([l1to4t3_mc, l1to4t3_mc]).float() +topk_target_mdmc = torch.tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]]) -def _sk_accuracy_multidim_multiclass_prob(preds, target): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() - return accuracy_score(y_true=sk_target, y_pred=sk_preds) +# Replace with a proper sk_metric test once sklearn 0.24 hits :) +@pytest.mark.parametrize( + "preds, target, exp_result, k, mdmc_accuracy", + [ + (topk_preds_mc, topk_target_mc, 1 / 6, 1, "global"), + (topk_preds_mc, topk_target_mc, 3 / 6, 2, "global"), + (topk_preds_mc, topk_target_mc, 5 / 6, 3, "global"), + (topk_preds_mc, topk_target_mc, 1 / 6, 1, "subset"), + (topk_preds_mc, topk_target_mc, 3 / 6, 2, "subset"), + (topk_preds_mc, topk_target_mc, 5 / 6, 3, "subset"), + (topk_preds_mdmc, topk_target_mdmc, 1 / 6, 1, "global"), + (topk_preds_mdmc, topk_target_mdmc, 8 / 18, 2, "global"), + (topk_preds_mdmc, topk_target_mdmc, 13 / 18, 3, "global"), + (topk_preds_mdmc, topk_target_mdmc, 1 / 6, 1, "subset"), + (topk_preds_mdmc, topk_target_mdmc, 2 / 6, 2, "subset"), + (topk_preds_mdmc, topk_target_mdmc, 3 / 6, 3, "subset"), + ], +) +def test_topk_accuracy(preds, target, exp_result, k, mdmc_accuracy): + topk = Accuracy(top_k=k, mdmc_accuracy=mdmc_accuracy) + for batch in range(preds.shape[0]): + topk(preds[batch], target[batch]) -def _sk_accuracy_multidim_multiclass(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + assert topk.compute() == exp_result - return accuracy_score(y_true=sk_target, y_pred=sk_preds) + # Test functional + total_samples = target.shape[0] * target.shape[1] + preds = preds.view(total_samples, 4, -1) + target = target.view(total_samples, -1) -def test_accuracy_invalid_shape(): - with pytest.raises(ValueError): - acc = Accuracy() - acc.update(preds=torch.rand(1), target=torch.rand(1, 2, 3)) + assert accuracy(preds, target, top_k=k, mdmc_accuracy=mdmc_accuracy) == exp_result -@pytest.mark.parametrize("ddp", [True, False]) -@pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# Only MC and MDMC with probs input type should be accepted @pytest.mark.parametrize( - "preds, target, sk_metric", + "preds, target", [ - (_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_accuracy_binary_prob), - (_binary_inputs.preds, _binary_inputs.target, _sk_accuracy_binary), - (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_accuracy_multilabel_prob), - (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_accuracy_multilabel), - (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_accuracy_multiclass_prob), - (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_accuracy_multiclass), - ( - _multidim_multiclass_prob_inputs.preds, - _multidim_multiclass_prob_inputs.target, - _sk_accuracy_multidim_multiclass_prob, - ), - (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, _sk_accuracy_multidim_multiclass), + (_binary_prob_inputs.preds, _binary_prob_inputs.target), + (_binary_inputs.preds, _binary_inputs.target), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target), + (_multilabel_inputs.preds, _multilabel_inputs.target), + (_multiclass_inputs.preds, _multiclass_inputs.target), + (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target), + (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target), + (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target), ], ) -class TestAccuracy(MetricTester): - def test_accuracy(self, ddp, dist_sync_on_step, preds, target, sk_metric): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=Accuracy, - sk_metric=sk_metric, - dist_sync_on_step=dist_sync_on_step, - metric_args={"threshold": THRESHOLD}, - ) +def test_topk_accuracy_wrong_input_types(preds, target): + topk = Accuracy(top_k=2) + + with pytest.raises(ValueError): + topk(preds[0], target[0]) + + with pytest.raises(ValueError): + accuracy(preds[0], target[0], top_k=2) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 34abee8473863..b0010916b6476 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -21,10 +21,10 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ - os.environ["MASTER_ADDR"] = 'localhost' - os.environ['MASTER_PORT'] = '8088' + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8088" - if torch.distributed.is_available() and sys.platform not in ('win32', 'cygwin'): + if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) @@ -41,23 +41,23 @@ def _class_test( check_batch: bool = True, atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning class metric - and reference metric. - - Args: - rank: rank of current process - worldsize: number of processes - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Utility function doing the actual comparison between lightning class metric + and reference metric. + + Args: + rank: rank of current process + worldsize: number of processes + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ # Instanciate lightning metric metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args) @@ -71,8 +71,8 @@ def _class_test( if metric.dist_sync_on_step: if rank == 0: - ddp_preds = torch.stack([preds[i + r] for r in range(worldsize)]) - ddp_target = torch.stack([target[i + r] for r in range(worldsize)]) + ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]) + ddp_target = torch.cat([target[i + r] for r in range(worldsize)]) sk_batch_result = sk_metric(ddp_preds, ddp_target) # assert for dist_sync_on_step if check_dist_sync_on_step: @@ -87,8 +87,8 @@ def _class_test( result = metric.compute() assert isinstance(result, torch.Tensor) - total_preds = torch.stack([preds[i] for i in range(NUM_BATCHES)]) - total_target = torch.stack([target[i] for i in range(NUM_BATCHES)]) + total_preds = torch.cat([preds[i] for i in range(NUM_BATCHES)]) + total_target = torch.cat([target[i] for i in range(NUM_BATCHES)]) sk_result = sk_metric(total_preds, total_target) # assert after aggregation @@ -101,17 +101,17 @@ def _functional_test( metric_functional: Callable, sk_metric: Callable, metric_args: dict = {}, - atol: float = 1e-8 + atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning functional metric - and reference metric. - - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric functional that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Utility function doing the actual comparison between lightning functional metric + and reference metric. + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric functional that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ metric = partial(metric_functional, **metric_args) @@ -124,22 +124,23 @@ def _functional_test( class MetricTester: - """ Class used for efficiently run alot of parametrized tests in ddp mode. - Makes sure that ddp is only setup once and that pool of processes are - used for all tests. + """Class used for efficiently run alot of parametrized tests in ddp mode. + Makes sure that ddp is only setup once and that pool of processes are + used for all tests. - All tests should subclass from this and implement a new method called - `test_metric_name` - where the method `self.run_metric_test` is called inside. + All tests should subclass from this and implement a new method called + `test_metric_name` + where the method `self.run_metric_test` is called inside. """ + atol = 1e-8 def setup_class(self): - """ Setup the metric class. This will spawn the pool of workers that are - used for metric testing and setup_ddp + """Setup the metric class. This will spawn the pool of workers that are + used for metric testing and setup_ddp """ try: - set_start_method('spawn') + set_start_method("spawn") except RuntimeError: pass self.poolSize = NUM_PROCESSES @@ -157,24 +158,26 @@ def run_functional_metric_test( target: torch.Tensor, metric_functional: Callable, sk_metric: Callable, - metric_args: dict = {} + metric_args: dict = {}, ): - """ Main method that should be used for testing functions. Call this inside - testing method - - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Main method that should be used for testing functions. Call this inside + testing method + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ - _functional_test(preds=preds, - target=target, - metric_functional=metric_functional, - sk_metric=sk_metric, - metric_args=metric_args, - atol=self.atol) + _functional_test( + preds=preds, + target=target, + metric_functional=metric_functional, + sk_metric=sk_metric, + metric_args=metric_args, + atol=self.atol, + ) def run_class_metric_test( self, @@ -188,22 +191,22 @@ def run_class_metric_test( check_dist_sync_on_step: bool = True, check_batch: bool = True, ): - """ Main method that should be used for testing class. Call this inside testing - methods. - - Args: - ddp: bool, if running in ddp mode or not - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Main method that should be used for testing class. Call this inside testing + methods. + + Args: + ddp: bool, if running in ddp mode or not + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ if ddp: if sys.platform == "win32": From 55fdaaf16a185d5ddf2cf37871f389658b8de2c4 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 22:41:43 +0100 Subject: [PATCH 04/94] Change testing utils --- tests/metrics/utils.py | 161 +++++++++++++++++++++-------------------- 1 file changed, 82 insertions(+), 79 deletions(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 34abee8473863..8ec14c41b1360 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -21,10 +21,10 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ - os.environ["MASTER_ADDR"] = 'localhost' - os.environ['MASTER_PORT'] = '8088' + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8088" - if torch.distributed.is_available() and sys.platform not in ('win32', 'cygwin'): + if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) @@ -41,23 +41,23 @@ def _class_test( check_batch: bool = True, atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning class metric - and reference metric. - - Args: - rank: rank of current process - worldsize: number of processes - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Utility function doing the actual comparison between lightning class metric + and reference metric. + + Args: + rank: rank of current process + worldsize: number of processes + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ # Instanciate lightning metric metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args) @@ -71,28 +71,28 @@ def _class_test( if metric.dist_sync_on_step: if rank == 0: - ddp_preds = torch.stack([preds[i + r] for r in range(worldsize)]) - ddp_target = torch.stack([target[i + r] for r in range(worldsize)]) + ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]) + ddp_target = torch.cat([target[i + r] for r in range(worldsize)]) sk_batch_result = sk_metric(ddp_preds, ddp_target) # assert for dist_sync_on_step if check_dist_sync_on_step: - assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol) + assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol, equal_nan=True) else: sk_batch_result = sk_metric(preds[i], target[i]) # assert for batch if check_batch: - assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol) + assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol, equal_nan=True) # check on all batches on all ranks result = metric.compute() assert isinstance(result, torch.Tensor) - total_preds = torch.stack([preds[i] for i in range(NUM_BATCHES)]) - total_target = torch.stack([target[i] for i in range(NUM_BATCHES)]) + total_preds = torch.cat([preds[i] for i in range(NUM_BATCHES)]) + total_target = torch.cat([target[i] for i in range(NUM_BATCHES)]) sk_result = sk_metric(total_preds, total_target) # assert after aggregation - assert np.allclose(result.numpy(), sk_result, atol=atol) + assert np.allclose(result.numpy(), sk_result, atol=atol, equal_nan=True) def _functional_test( @@ -101,17 +101,17 @@ def _functional_test( metric_functional: Callable, sk_metric: Callable, metric_args: dict = {}, - atol: float = 1e-8 + atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning functional metric - and reference metric. - - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric functional that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Utility function doing the actual comparison between lightning functional metric + and reference metric. + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric functional that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ metric = partial(metric_functional, **metric_args) @@ -120,26 +120,27 @@ def _functional_test( sk_result = sk_metric(preds[i], target[i]) # assert its the same - assert np.allclose(lightning_result.numpy(), sk_result, atol=atol) + assert np.allclose(lightning_result.numpy(), sk_result, atol=atol, equal_nan=True) class MetricTester: - """ Class used for efficiently run alot of parametrized tests in ddp mode. - Makes sure that ddp is only setup once and that pool of processes are - used for all tests. + """Class used for efficiently run alot of parametrized tests in ddp mode. + Makes sure that ddp is only setup once and that pool of processes are + used for all tests. - All tests should subclass from this and implement a new method called - `test_metric_name` - where the method `self.run_metric_test` is called inside. + All tests should subclass from this and implement a new method called + `test_metric_name` + where the method `self.run_metric_test` is called inside. """ + atol = 1e-8 def setup_class(self): - """ Setup the metric class. This will spawn the pool of workers that are - used for metric testing and setup_ddp + """Setup the metric class. This will spawn the pool of workers that are + used for metric testing and setup_ddp """ try: - set_start_method('spawn') + set_start_method("spawn") except RuntimeError: pass self.poolSize = NUM_PROCESSES @@ -157,24 +158,26 @@ def run_functional_metric_test( target: torch.Tensor, metric_functional: Callable, sk_metric: Callable, - metric_args: dict = {} + metric_args: dict = {}, ): - """ Main method that should be used for testing functions. Call this inside - testing method - - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Main method that should be used for testing functions. Call this inside + testing method + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ - _functional_test(preds=preds, - target=target, - metric_functional=metric_functional, - sk_metric=sk_metric, - metric_args=metric_args, - atol=self.atol) + _functional_test( + preds=preds, + target=target, + metric_functional=metric_functional, + sk_metric=sk_metric, + metric_args=metric_args, + atol=self.atol, + ) def run_class_metric_test( self, @@ -188,22 +191,22 @@ def run_class_metric_test( check_dist_sync_on_step: bool = True, check_batch: bool = True, ): - """ Main method that should be used for testing class. Call this inside testing - methods. - - Args: - ddp: bool, if running in ddp mode or not - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Main method that should be used for testing class. Call this inside testing + methods. + + Args: + ddp: bool, if running in ddp mode or not + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ if ddp: if sys.platform == "win32": From 5cbf56a5422d0efc29ad84d85d5768863692496b Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 22:50:50 +0100 Subject: [PATCH 05/94] Replace len(*.shape) with *.ndim --- .../metrics/classification/utils.py | 20 +++++++++---------- pytorch_lightning/metrics/utils.py | 16 +++++++-------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index b8f5af2e988d8..62fe3e2095a78 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -90,25 +90,25 @@ def _check_classification_inputs( raise ValueError("If you set is_multiclass=False and preds are integers, then preds should not exceed 1.") # Check that shape/types fall into one of the cases - if len(preds.shape) == len(target.shape): + if preds.ndim == target.ndim: if preds.shape != target.shape: raise ValueError("if preds and target have the same number of dimensions, they should have the same shape") if preds_float and target.max() > 1: raise ValueError("if preds and target are of shape (N, ...) and preds are floats, target should be binary") # Get the case - if len(preds.shape) == 1 and preds_float: + if preds.ndim == 1 and preds_float: case = "binary" - elif len(preds.shape) == 1 and not preds_float: + elif preds.ndim == 1 and not preds_float: case = "multi-class" - elif len(preds.shape) > 1 and preds_float: + elif preds.ndim > 1 and preds_float: case = "multi-label" else: case = "multi-dim multi-class" implied_classes = torch.prod(torch.Tensor(list(preds.shape[1:]))) - elif len(preds.shape) == len(target.shape) + 1: + elif preds.ndim == target.ndim + 1: if not preds_float: raise ValueError("if preds have one dimension more than target, preds should be a float tensor") if not preds.shape[:-1] == target.shape: @@ -120,7 +120,7 @@ def _check_classification_inputs( extra_dim_size = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] - if len(preds.shape) == 2: + if preds.ndim == 2: case = "multi-class" else: case = "multi-dim multi-class" @@ -299,7 +299,7 @@ def _input_format_classification( preds_float = preds.is_floating_point() - if len(preds.shape) == len(target.shape) == 1 and preds_float: + if preds.ndim == target.ndim == 1 and preds_float: mode = "binary" preds = (preds >= threshold).int() @@ -310,7 +310,7 @@ def _input_format_classification( preds = preds.unsqueeze(-1) target = target.unsqueeze(-1) - elif len(preds.shape) == len(target.shape) and preds_float: + elif preds.ndim == target.ndim and preds_float: mode = "multi-label" preds = (preds >= threshold).int() @@ -321,7 +321,7 @@ def _input_format_classification( preds = preds.reshape(preds.shape[0], -1) target = target.reshape(target.shape[0], -1) - elif len(preds.shape) == len(target.shape) + 1 == 2: + elif preds.ndim == target.ndim + 1 == 2: mode = "multi-class" if not num_classes: num_classes = preds.shape[1] @@ -334,7 +334,7 @@ def _input_format_classification( target = target[:, [1]] preds = preds[:, [1]] - elif len(preds.shape) == len(target.shape) == 1 and not preds_float: + elif preds.ndim == target.ndim == 1 and not preds_float: mode = "multi-class" if not num_classes: diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 1ce56b30cf9e5..0f71c531fe6ac 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -109,14 +109,14 @@ def _input_format_classification( preds: tensor with labels target: tensor with labels """ - if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): + if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1): raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") - if len(preds.shape) == len(target.shape) + 1: + if preds.ndim == target.ndim + 1: # multi class probabilites preds = torch.argmax(preds, dim=1) - if len(preds.shape) == len(target.shape) and preds.dtype == torch.float: + if preds.ndim == target.ndim and preds.dtype == torch.float: # binary or multilabel probablities preds = (preds >= threshold).long() return preds, target @@ -139,24 +139,24 @@ def _input_format_classification_one_hot( preds: one hot tensor of shape [num_classes, -1] with predicted labels target: one hot tensors of shape [num_classes, -1] with true labels """ - if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): + if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1): raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") - if len(preds.shape) == len(target.shape) + 1: + if preds.ndim == target.ndim + 1: # multi class probabilites preds = torch.argmax(preds, dim=1) - if len(preds.shape) == len(target.shape) and preds.dtype == torch.long and num_classes > 1 and not multilabel: + if preds.ndim == target.ndim and preds.dtype == torch.long and num_classes > 1 and not multilabel: # multi-class preds = to_onehot(preds, num_classes=num_classes) target = to_onehot(target, num_classes=num_classes) - elif len(preds.shape) == len(target.shape) and preds.dtype == torch.float: + elif preds.ndim == target.ndim and preds.dtype == torch.float: # binary or multilabel probablities preds = (preds >= threshold).long() # transpose class as first dim and reshape - if len(preds.shape) > 1: + if preds.ndim > 1: preds = preds.transpose(1, 0) target = target.transpose(1, 0) From 9c33d0b3c14e50a984c59596eff7f6a513f19e2c Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 22:55:56 +0100 Subject: [PATCH 06/94] More descriptive error message for input formatting --- pytorch_lightning/metrics/classification/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 62fe3e2095a78..6a47c7adbbae1 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -76,9 +76,7 @@ def _check_classification_inputs( if preds_float: if preds.min() < 0 or preds.max() > 1: - raise ValueError( - "preds should be probabilities, but values were detected outside of [0,1] range" - ) + raise ValueError("preds should be probabilities, but values were detected outside of [0,1] range") if threshold > 1 or threshold < 0: raise ValueError("Threshold should be a probability in [0,1]") @@ -92,7 +90,10 @@ def _check_classification_inputs( # Check that shape/types fall into one of the cases if preds.ndim == target.ndim: if preds.shape != target.shape: - raise ValueError("if preds and target have the same number of dimensions, they should have the same shape") + raise ValueError( + "preds and targets should have the same shape", + f" got preds shape = {preds.shape} and target shape = {target.shape}.", + ) if preds_float and target.max() > 1: raise ValueError("if preds and target are of shape (N, ...) and preds are floats, target should be binary") From 65622058778a8186bdc8b7f8852fc4690405de09 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 23:24:22 +0100 Subject: [PATCH 07/94] Replace movedim with permute --- pytorch_lightning/metrics/classification/utils.py | 10 +++++++--- tests/metrics/classification/test_inputs.py | 7 ++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 6a47c7adbbae1..f58e61dcd3127 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Tuple, Optional -import numpy as np import torch from pytorch_lightning.metrics.utils import to_onehot, select_topk @@ -256,7 +255,8 @@ def _input_format_classification( dimension is ambiguous (e.g. if targets are a ``(7, 3)`` tensor, while predictions are a ``(7, 3, 3)`` tensor), it will be assumed that the ``C`` dimension is the second dimension. If this is not the case, you should move it from the last to second place using - ``torch.movedim(preds, -1, 1)``. + ``torch.movedim(preds, -1, 1)``, or using ``preds.permute``, if you are using an older + version of Pytorch. Note that where a one-hot transformation needs to be performed and the number of classes is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be @@ -370,7 +370,11 @@ def _input_format_classification( else: mode = "multi-dim multi-class" if preds.shape[:-1] == target.shape: - preds = torch.movedim(preds, -1, 1) + shape_permute = list(range(preds.ndim)) + shape_permute[1] = shape_permute[-1] + shape_permute[2:] = range(1, len(shape_permute) - 1) + + preds = preds.permute(*shape_permute) num_classes = preds.shape[1] diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 8d17d5624fac0..19828dc07eb34 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -84,7 +84,12 @@ def top2(x): def mvdim(x): - return torch.movedim(x, -1, 1) + """ Equivalent of torch.movedim(x, -1, 1) """ + shape_permute = list(range(x.ndim)) + shape_permute[1] = shape_permute[-1] + shape_permute[2:] = range(1, len(shape_permute) - 1) + + return x.permute(*shape_permute) # To avoid ugly black line wrapping From cbbc769cf615220fba7b1422c657c35c9d9a507a Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 23:28:05 +0100 Subject: [PATCH 08/94] PEP 8 compliance --- pytorch_lightning/metrics/functional/accuracy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index f9861388cceda..e56b4b9f76eb3 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -68,7 +68,7 @@ class has to be correctly predicted across all extra dimension for each sample i Args: preds: Predictions from model (probabilities, or labels) - target: Ground truth values + target: Ground truth values top_k: Number of highest probability predictions considered to find the correct label, for (multi-dimensional) multi-class inputs with probability predictions. Default 1 From f45fc817e6b0cc47f2f95d6a79e582d719b90d4c Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 24 Nov 2020 23:59:55 +0100 Subject: [PATCH 09/94] Division with float --- pytorch_lightning/metrics/functional/accuracy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index e56b4b9f76eb3..fa423c611a450 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -45,7 +45,7 @@ def _accuracy_update( def _accuracy_compute(correct: torch.Tensor, total: torch.Tensor) -> torch.Tensor: - return correct / total + return correct.float() / total def accuracy( From a04a71ea195c601d59b22c7295c6d1389d7155fe Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 09:39:18 +0100 Subject: [PATCH 10/94] Style changes in error messages --- .../metrics/classification/utils.py | 79 ++++++++++--------- 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index f58e61dcd3127..398acbf2bc5fc 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -62,39 +62,40 @@ def _check_classification_inputs( """ if target.is_floating_point(): - raise ValueError("target has to be an integer tensor") + raise ValueError("`target` has to be an integer tensor") elif target.min() < 0: - raise ValueError("target has to be a non-negative tensor") + raise ValueError("`target` has to be a non-negative tensor") preds_float = preds.is_floating_point() if not preds_float and preds.min() < 0: - raise ValueError("if preds are integers, they have to be non-negative") + raise ValueError("if `preds` are integers, they have to be non-negative") if not preds.shape[0] == target.shape[0]: - raise ValueError("preds and target should have the same first dimension.") + raise ValueError("`preds` and `target` should have the same first dimension.") if preds_float: if preds.min() < 0 or preds.max() > 1: - raise ValueError("preds should be probabilities, but values were detected outside of [0,1] range") + raise ValueError("`preds` should be probabilities, but values were detected outside of [0,1] range") if threshold > 1 or threshold < 0: raise ValueError("Threshold should be a probability in [0,1]") if is_multiclass is False and target.max() > 1: - raise ValueError("If you set is_multiclass=False, then target should not exceed 1.") + raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") if is_multiclass is False and not preds_float and preds.max() > 1: - raise ValueError("If you set is_multiclass=False and preds are integers, then preds should not exceed 1.") + raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") # Check that shape/types fall into one of the cases if preds.ndim == target.ndim: if preds.shape != target.shape: raise ValueError( - "preds and targets should have the same shape", - f" got preds shape = {preds.shape} and target shape = {target.shape}.", + "`preds` and `target` should have the same shape", + f" got `preds shape = {preds.shape} and `target` shape = {target.shape}.", ) if preds_float and target.max() > 1: - raise ValueError("if preds and target are of shape (N, ...) and preds are floats, target should be binary") + raise ValueError("if `preds` and `target` are of shape (N, ...)" + " and `preds` are floats, `target` should be binary") # Get the case if preds.ndim == 1 and preds_float: @@ -110,12 +111,12 @@ def _check_classification_inputs( elif preds.ndim == target.ndim + 1: if not preds_float: - raise ValueError("if preds have one dimension more than target, preds should be a float tensor") + raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor") if not preds.shape[:-1] == target.shape: if preds.shape[2:] != target.shape[1:]: raise ValueError( - "if preds if preds have one dimension more than target, the shape of preds should be" - "either of shape (N, C, ...) or (N, ..., C), and of targets of shape (N, ...)" + "if `preds` have one dimension more than `target`, the shape of `preds` should be" + " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)" ) extra_dim_size = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] @@ -126,64 +127,64 @@ def _check_classification_inputs( case = "multi-dim multi-class" else: raise ValueError( - "preds and target should both have the (same) shape (N, ...), or target (N, ...)" - " and preds (N, C, ...) or (N, ..., C)" + "`preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" + " and `preds` (N, C, ...) or (N, ..., C)" ) if preds.shape != target.shape and is_multiclass is False and extra_dim_size != 2: raise ValueError( - "You have set is_multiclass=False, but have more than 2 classes in your data," - " based on the C dimension of preds." + "You have set `is_multiclass=False`, but have more than 2 classes in your data," + " based on the C dimension of `preds`." ) # Check that num_classes is consistent if not num_classes: if preds.shape != target.shape and target.max() >= extra_dim_size: - raise ValueError("The highest label in targets should be smaller than the size of C dimension") + raise ValueError("The highest label in `target` should be smaller than the size of C dimension") else: if case == "binary": if num_classes > 2: - raise ValueError("Your data is binary, but num_classes is larger than 2.") + raise ValueError("Your data is binary, but `num_classes` is larger than 2.") elif num_classes == 2 and not is_multiclass: raise ValueError( - "Your data is binary and num_classes=2, but is_multiclass is not True." - "Set it to True if you want to transform binary data to multi-class format." + "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." + " Set it to True if you want to transform binary data to multi-class format." ) elif num_classes == 1 and is_multiclass: raise ValueError( - "You have binary data and have set is_multiclass=True, but num_classes is 1." - "Either leave is_multiclass unset or set it to 2 to transform binary data to multi-class format." + "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." + " Either leave `is_multiclass` unset or set it to 2 to transform binary data to multi-class format." ) elif "multi-class" in case: if num_classes == 1 and is_multiclass is not False: raise ValueError( - "You have set num_classes=1, but predictions are integers." - "If you want to convert (multi-dimensional) multi-class data with 2 classes" - "to binary/multi-label, set is_multiclass=False." + "You have set `num_classes=1`, but predictions are integers." + " If you want to convert (multi-dimensional) multi-class data with 2 classes" + " to binary/multi-label, set `is_multiclass=False`." ) elif num_classes > 1: if is_multiclass is False: if implied_classes != num_classes: raise ValueError( - "You have set is_multiclass=False, but the implied number of classes " - "(from shape of inputs) does not match num_classes. If you are trying to" - "transform multi-dim multi-class data with 2 classes to multi-label, num_classes" - "should be either None or the product of the size of extra dimensions (...)." - "See Input Types in Metrics documentation." + "You have set `is_multiclass=False`, but the implied number of classes " + " (from shape of inputs) does not match `num_classes`. If you are trying to" + " transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`" + " should be either None or the product of the size of extra dimensions (...)." + " See Input Types in Metrics documentation." ) if num_classes <= target.max(): - raise ValueError("The highest label in targets should be smaller than num_classes") + raise ValueError("The highest label in `target` should be smaller than `num_classes`") if num_classes <= preds.max(): - raise ValueError("The highest label in preds should be smaller than num_classes") + raise ValueError("The highest label in `preds` should be smaller than `num_classes`") if preds.shape != target.shape and num_classes != extra_dim_size: - raise ValueError("The size of C dimension of preds does not match num_classes") + raise ValueError("The size of C dimension of `preds` does not match `num_classes`") elif case == "multi-label": if is_multiclass and num_classes != 2: raise ValueError( - "Your have set is_multiclass=True, but num_classes is not equal to 2." - "If you are trying to transform multi-label data to 2 class multi-dimensional" - "multi-class, you should set num_classes to either 2 or None." + "Your have set `is_multiclass=True`, but `num_classes` is not equal to 2." + " If you are trying to transform multi-label data to 2 class multi-dimensional" + " multi-class, you should set `num_classes` to either 2 or None." ) if not is_multiclass and num_classes != implied_classes: raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") @@ -192,8 +193,8 @@ def _check_classification_inputs( if top_k > 1: if preds.shape == target.shape: raise ValueError( - "You have set top_k above 1, but your data is not (multi-dimensional) multi-class" - "with probability predictions." + "You have set `top_k` above 1, but your data is not (multi-dimensional) multi-class" + " with probability predictions." ) From eaac5d74cffb2980450a95e15e9ce13f13802605 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 09:42:52 +0100 Subject: [PATCH 11/94] More error message style improvements --- .../metrics/classification/utils.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 398acbf2bc5fc..54bcd840f3621 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -62,23 +62,23 @@ def _check_classification_inputs( """ if target.is_floating_point(): - raise ValueError("`target` has to be an integer tensor") + raise ValueError("`target` has to be an integer tensor.") elif target.min() < 0: - raise ValueError("`target` has to be a non-negative tensor") + raise ValueError("`target` has to be a non-negative tensor.") preds_float = preds.is_floating_point() if not preds_float and preds.min() < 0: - raise ValueError("if `preds` are integers, they have to be non-negative") + raise ValueError("if `preds` are integers, they have to be non-negative.") if not preds.shape[0] == target.shape[0]: raise ValueError("`preds` and `target` should have the same first dimension.") if preds_float: if preds.min() < 0 or preds.max() > 1: - raise ValueError("`preds` should be probabilities, but values were detected outside of [0,1] range") + raise ValueError("`preds` should be probabilities, but values were detected outside of [0,1] range.") if threshold > 1 or threshold < 0: - raise ValueError("Threshold should be a probability in [0,1]") + raise ValueError("`threshold` should be a probability in [0,1].") if is_multiclass is False and target.max() > 1: raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") @@ -94,8 +94,9 @@ def _check_classification_inputs( f" got `preds shape = {preds.shape} and `target` shape = {target.shape}.", ) if preds_float and target.max() > 1: - raise ValueError("if `preds` and `target` are of shape (N, ...)" - " and `preds` are floats, `target` should be binary") + raise ValueError( + "if `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." + ) # Get the case if preds.ndim == 1 and preds_float: @@ -111,12 +112,12 @@ def _check_classification_inputs( elif preds.ndim == target.ndim + 1: if not preds_float: - raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor") + raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor.") if not preds.shape[:-1] == target.shape: if preds.shape[2:] != target.shape[1:]: raise ValueError( "if `preds` have one dimension more than `target`, the shape of `preds` should be" - " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)" + " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)." ) extra_dim_size = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] @@ -128,7 +129,7 @@ def _check_classification_inputs( else: raise ValueError( "`preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" - " and `preds` (N, C, ...) or (N, ..., C)" + " and `preds` (N, C, ...) or (N, ..., C)." ) if preds.shape != target.shape and is_multiclass is False and extra_dim_size != 2: @@ -140,7 +141,7 @@ def _check_classification_inputs( # Check that num_classes is consistent if not num_classes: if preds.shape != target.shape and target.max() >= extra_dim_size: - raise ValueError("The highest label in `target` should be smaller than the size of C dimension") + raise ValueError("The highest label in `target` should be smaller than the size of C dimension.") else: if case == "binary": if num_classes > 2: @@ -173,11 +174,11 @@ def _check_classification_inputs( " See Input Types in Metrics documentation." ) if num_classes <= target.max(): - raise ValueError("The highest label in `target` should be smaller than `num_classes`") + raise ValueError("The highest label in `target` should be smaller than `num_classes`.") if num_classes <= preds.max(): - raise ValueError("The highest label in `preds` should be smaller than `num_classes`") + raise ValueError("The highest label in `preds` should be smaller than `num_classes`.") if preds.shape != target.shape and num_classes != extra_dim_size: - raise ValueError("The size of C dimension of `preds` does not match `num_classes`") + raise ValueError("The size of C dimension of `preds` does not match `num_classes`.") elif case == "multi-label": if is_multiclass and num_classes != 2: From c1108f0cf81359008f5b462c3ba0bd851fd50a0d Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 09:44:28 +0100 Subject: [PATCH 12/94] Fix typo in docs --- docs/source/metrics.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 407b64d3d2948..831ac67922114 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -230,7 +230,7 @@ into these categories (``N`` stands for the batch size and ``C`` for number of c All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``. -When predictions or targets are integers, it is assumed that class labels start at , i.e. +When predictions or targets are integers, it is assumed that class labels start at 0, i.e. the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types .. code-block:: python From 277769b7c4e3aa84a83b3e12eb724ea3318e44c0 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 09:50:02 +0100 Subject: [PATCH 13/94] Add more descriptive variable names in utils --- pytorch_lightning/metrics/utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 0f71c531fe6ac..b7f8d492ce01d 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -37,14 +37,14 @@ def _flatten(x): def to_onehot( - tensor: torch.Tensor, + label_tensor: torch.Tensor, num_classes: int, ) -> torch.Tensor: """ Converts a dense label tensor to one-hot format Args: - tensor: dense label tensor, with shape [N, d1, d2, ...] + label_tensor: dense label tensor, with shape [N, d1, d2, ...] num_classes: number of classes C Output: @@ -57,18 +57,18 @@ def to_onehot( [0, 0, 1, 0], [0, 0, 0, 1]]) """ - dtype, device, shape = tensor.dtype, tensor.device, tensor.shape + dtype, device, shape = label_tensor.dtype, label_tensor.device, label_tensor.shape tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], dtype=dtype, device=device) - index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) + index = label_tensor.long().unsqueeze(1).expand_as(tensor_onehot) return tensor_onehot.scatter_(1, index, 1.0) -def select_topk(tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: +def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: """ Convert a probability tensor to binary by selecting top-k highest entries. Args: - tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the + prob_tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the position defined by the ``dim`` argument topk: number of highest entries to turn into 1s dim: dimension on which to compare entries @@ -82,8 +82,8 @@ def select_topk(tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tens tensor([[0, 1, 1], [1, 1, 0]], dtype=torch.int32) """ - zeros = torch.zeros_like(tensor, device=tensor.device) - topk_tensor = zeros.scatter(1, tensor.topk(k=topk, dim=dim).indices, 1.0) + zeros = torch.zeros_like(prob_tensor, device=prob_tensor.device) + topk_tensor = zeros.scatter(1, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) return topk_tensor.int() From 484929861c54922cf8f748eaf677efae4cff5fcb Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 09:59:31 +0100 Subject: [PATCH 14/94] Change internal var names --- tests/metrics/classification/test_inputs.py | 118 ++++++++++---------- 1 file changed, 57 insertions(+), 61 deletions(-) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 19828dc07eb34..6ec6c8dbbc498 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -43,47 +43,43 @@ T = torch.Tensor -def idn(x): +def _idn(x): return x -def usq(x): +def _usq(x): return x.unsqueeze(-1) -def toint(x): - return x.int() - - -def thrs(x): +def _thrs(x): return x >= THRESHOLD -def rshp1(x): +def _rshp1(x): return x.reshape(x.shape[0], -1) -def rshp2(x): +def _rshp2(x): return x.reshape(x.shape[0], x.shape[1], -1) -def onehot(x): +def _onehot(x): return to_onehot(x, NUM_CLASSES) -def onehot2(x): +def _onehot2(x): return to_onehot(x, 2) -def top1(x): +def _top1(x): return select_topk(x, 1) -def top2(x): +def _top2(x): return select_topk(x, 2) -def mvdim(x): +def _mvdim(x): """ Equivalent of torch.movedim(x, -1, 1) """ shape_permute = list(range(x.ndim)) shape_permute[1] = shape_permute[-1] @@ -93,44 +89,44 @@ def mvdim(x): # To avoid ugly black line wrapping -def ml_preds_tr(x): - return rshp1(toint(thrs(x))) +def _ml_preds_tr(x): + return _rshp1(_thrs(x).int()) -def onehot_rshp1(x): - return onehot(rshp1(x)) +def _onehot_rshp1(x): + return _onehot(_rshp1(x)) -def onehot2_rshp1(x): - return onehot2(rshp1(x)) +def _onehot2_rshp1(x): + return _onehot2(_rshp1(x)) -def top1_rshp2(x): - return top1(rshp2(x)) +def _top1_rshp2(x): + return _top1(_rshp2(x)) -def top2_rshp2(x): - return top2(rshp2(x)) +def _top2_rshp2(x): + return _top2(_rshp2(x)) -def mdmc1_top1_tr(x): - return top1(rshp2(mvdim(x))) +def _mdmc1_top1_tr(x): + return _top1(_rshp2(_mvdim(x))) -def mdmc1_top2_tr(x): - return top2(rshp2(mvdim(x))) +def _mdmc1_top2_tr(x): + return _top2(_rshp2(_mvdim(x))) -def probs_to_mc_preds_tr(x): - return toint(onehot2(thrs(x))) +def _probs_to_mc_preds_tr(x): + return _onehot2(_thrs(x)).int() -def mlmd_prob_to_mc_preds_tr(x): - return onehot2(rshp1(toint(thrs(x)))) +def _mlmd_prob_to_mc_preds_tr(x): + return _onehot2(_rshp1(_thrs(x).int())) -def mdmc_prob_to_ml_preds_tr(x): - return top1(mvdim(x))[:, 1] +def _mdmc_prob_to__ml_preds_tr(x): + return _top1(_mvdim(x))[:, 1] ######################## @@ -143,44 +139,44 @@ def mdmc_prob_to_ml_preds_tr(x): [ ############################# # Test usual expected cases - (_bin, THRESHOLD, None, False, 1, "multi-class", usq, usq), - (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: usq(toint(thrs(x))), usq), - (_ml_prob, THRESHOLD, None, None, 1, "multi-label", lambda x: toint(thrs(x)), idn), - (_ml, THRESHOLD, None, False, 1, "multi-dim multi-class", idn, idn), - (_ml_prob, THRESHOLD, None, None, 1, "multi-label", ml_preds_tr, rshp1), - (_mlmd, THRESHOLD, None, False, 1, "multi-dim multi-class", rshp1, rshp1), - (_mc, THRESHOLD, NUM_CLASSES, None, 1, "multi-class", onehot, onehot), - (_mc_prob, THRESHOLD, None, None, 1, "multi-class", top1, onehot), - (_mc_prob, THRESHOLD, None, None, 2, "multi-class", top2, onehot), - (_mdmc, THRESHOLD, NUM_CLASSES, None, 1, "multi-dim multi-class", onehot, onehot), - (_mdmc_prob, THRESHOLD, None, None, 1, "multi-dim multi-class", top1_rshp2, onehot), - (_mdmc_prob, THRESHOLD, None, None, 2, "multi-dim multi-class", top2_rshp2, onehot), - (_mdmc_prob_many_dims, THRESHOLD, None, None, 1, "multi-dim multi-class", top1_rshp2, onehot_rshp1), - (_mdmc_prob_many_dims, THRESHOLD, None, None, 2, "multi-dim multi-class", top2_rshp2, onehot_rshp1), + (_bin, THRESHOLD, None, False, 1, "multi-class", _usq, _usq), + (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: _usq(_thrs(x).int()), _usq), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", lambda x: _thrs(x).int(), _idn), + (_ml, THRESHOLD, None, False, 1, "multi-dim multi-class", _idn, _idn), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", _ml_preds_tr, _rshp1), + (_mlmd, THRESHOLD, None, False, 1, "multi-dim multi-class", _rshp1, _rshp1), + (_mc, THRESHOLD, NUM_CLASSES, None, 1, "multi-class", _onehot, _onehot), + (_mc_prob, THRESHOLD, None, None, 1, "multi-class", _top1, _onehot), + (_mc_prob, THRESHOLD, None, None, 2, "multi-class", _top2, _onehot), + (_mdmc, THRESHOLD, NUM_CLASSES, None, 1, "multi-dim multi-class", _onehot, _onehot), + (_mdmc_prob, THRESHOLD, None, None, 1, "multi-dim multi-class", _top1_rshp2, _onehot), + (_mdmc_prob, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot), + (_mdmc_prob_many_dims, THRESHOLD, None, None, 1, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), + (_mdmc_prob_many_dims, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), # Test with C dim in last place - (_mdmc_prob1, THRESHOLD, None, None, 1, "multi-dim multi-class", mdmc1_top1_tr, onehot), - (_mdmc_prob1, THRESHOLD, None, None, 2, "multi-dim multi-class", mdmc1_top2_tr, onehot), - (_mdmc_prob_many_dims1, THRESHOLD, None, None, 1, "multi-dim multi-class", mdmc1_top1_tr, onehot_rshp1), - (_mdmc_prob_many_dims1, THRESHOLD, None, None, 2, "multi-dim multi-class", mdmc1_top2_tr, onehot_rshp1), + (_mdmc_prob1, THRESHOLD, None, None, 1, "multi-dim multi-class", _mdmc1_top1_tr, _onehot), + (_mdmc_prob1, THRESHOLD, None, None, 2, "multi-dim multi-class", _mdmc1_top2_tr, _onehot), + (_mdmc_prob_many_dims1, THRESHOLD, None, None, 1, "multi-dim multi-class", _mdmc1_top1_tr, _onehot_rshp1), + (_mdmc_prob_many_dims1, THRESHOLD, None, None, 2, "multi-dim multi-class", _mdmc1_top2_tr, _onehot_rshp1), ########################### # Test some special cases # Binary as multiclass - (_bin, THRESHOLD, None, None, 1, "multi-class", onehot2, onehot2), + (_bin, THRESHOLD, None, None, 1, "multi-class", _onehot2, _onehot2), # Binary probs as multiclass - (_bin_prob, THRESHOLD, None, True, 1, "binary", probs_to_mc_preds_tr, onehot2), + (_bin_prob, THRESHOLD, None, True, 1, "binary", _probs_to_mc_preds_tr, _onehot2), # Multilabel as multiclass - (_ml, THRESHOLD, None, True, 1, "multi-dim multi-class", onehot2, onehot2), + (_ml, THRESHOLD, None, True, 1, "multi-dim multi-class", _onehot2, _onehot2), # Multilabel probs as multiclass - (_ml_prob, THRESHOLD, None, True, 1, "multi-label", probs_to_mc_preds_tr, onehot2), + (_ml_prob, THRESHOLD, None, True, 1, "multi-label", _probs_to_mc_preds_tr, _onehot2), # Multidim multilabel as multiclass - (_mlmd, THRESHOLD, None, True, 1, "multi-dim multi-class", onehot2_rshp1, onehot2_rshp1), + (_mlmd, THRESHOLD, None, True, 1, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1), # Multidim multilabel probs as multiclass - (_mlmd_prob, THRESHOLD, None, True, 1, "multi-label", mlmd_prob_to_mc_preds_tr, onehot2_rshp1), + (_mlmd_prob, THRESHOLD, None, True, 1, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), # Multiclass prob with 2 classes as binary - (_mc_prob_2cls, THRESHOLD, None, False, 1, "multi-class", lambda x: top1(x)[:, [1]], usq), + (_mc_prob_2cls, THRESHOLD, None, False, 1, "multi-class", lambda x: _top1(x)[:, [1]], _usq), # Multi-dim multi-class with 2 classes as multi-label - (_mdmc_prob_2cls, THRESHOLD, None, False, 1, "multi-dim multi-class", lambda x: top1(x)[:, 1], idn), - (_mdmc_prob_2cls1, THRESHOLD, None, False, 1, "multi-dim multi-class", mdmc_prob_to_ml_preds_tr, idn), + (_mdmc_prob_2cls, THRESHOLD, None, False, 1, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), + (_mdmc_prob_2cls1, THRESHOLD, None, False, 1, "multi-dim multi-class", _mdmc_prob_to__ml_preds_tr, _idn), ], ) def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): From 02bd636b35306e01b427876fb2e80a3291154ebe Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 10:28:23 +0100 Subject: [PATCH 15/94] Break down error checking for inputs into separate functions --- .../metrics/classification/utils.py | 225 +++++++++++------- 1 file changed, 137 insertions(+), 88 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 54bcd840f3621..50892af13ba9e 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -18,6 +18,137 @@ from pytorch_lightning.metrics.utils import to_onehot, select_topk +def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]: + """ + This checks that the shape and type of inputs are consistent with + each other and fall into one of the allowed input types (see the + documentation of docstring of _input_format_classification). It does + not check for consistency of number of classes, other functions take + care of that. + + It returns the name of the case in which the inputs fall, and the implied + number of classes (from the C dim for multi-class data, or extra dim(s) for + multi-label data). + """ + + preds_float = preds.is_floating_point() + + if preds.ndim == target.ndim: + if preds.shape != target.shape: + raise ValueError( + "`preds` and `target` should have the same shape", + f" got `preds shape = {preds.shape} and `target` shape = {target.shape}.", + ) + if preds_float and target.max() > 1: + raise ValueError( + "if `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." + ) + + # Get the case + if preds.ndim == 1 and preds_float: + case = "binary" + elif preds.ndim == 1 and not preds_float: + case = "multi-class" + elif preds.ndim > 1 and preds_float: + case = "multi-label" + else: + case = "multi-dim multi-class" + + implied_classes = torch.prod(torch.Tensor(list(preds.shape[1:]))) + + elif preds.ndim == target.ndim + 1: + if not preds_float: + raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor.") + if not preds.shape[:-1] == target.shape: + if preds.shape[2:] != target.shape[1:]: + raise ValueError( + "if `preds` have one dimension more than `target`, the shape of `preds` should be" + " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)." + ) + + implied_classes = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] + + if preds.ndim == 2: + case = "multi-class" + else: + case = "multi-dim multi-class" + else: + raise ValueError( + "`preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" + " and `preds` (N, C, ...) or (N, ..., C)." + ) + + return case, implied_classes + + +def _check_num_classes_binary(num_classes: int, is_multiclass: bool): + """ + This checks that the consistency of `num_classes` with the data + and `is_multiclass` param for binary data. + """ + + if num_classes > 2: + raise ValueError("Your data is binary, but `num_classes` is larger than 2.") + elif num_classes == 2 and not is_multiclass: + raise ValueError( + "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." + " Set it to True if you want to transform binary data to multi-class format." + ) + elif num_classes == 1 and is_multiclass: + raise ValueError( + "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." + " Either leave `is_multiclass` unset or set it to 2 to transform binary data to multi-class format." + ) + + +def _check_num_classes_mc( + preds: torch.Tensor, target: torch.Tensor, num_classes: int, is_multiclass: bool, implied_classes: int +): + """ + This checks that the consistency of `num_classes` with the data + and `is_multiclass` param for (multi-dimensional) multi-class data. + """ + + if num_classes == 1 and is_multiclass is not False: + raise ValueError( + "You have set `num_classes=1`, but predictions are integers." + " If you want to convert (multi-dimensional) multi-class data with 2 classes" + " to binary/multi-label, set `is_multiclass=False`." + ) + elif num_classes > 1: + if is_multiclass is False: + if implied_classes != num_classes: + raise ValueError( + "You have set `is_multiclass=False`, but the implied number of classes " + " (from shape of inputs) does not match `num_classes`. If you are trying to" + " transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`" + " should be either None or the product of the size of extra dimensions (...)." + " See Input Types in Metrics documentation." + ) + if num_classes <= target.max(): + raise ValueError("The highest label in `target` should be smaller than `num_classes`.") + if num_classes <= preds.max(): + raise ValueError("The highest label in `preds` should be smaller than `num_classes`.") + if preds.shape != target.shape and num_classes != implied_classes: + raise ValueError("The size of C dimension of `preds` does not match `num_classes`.") + + +def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes: int): + """ + This checks that the consistency of `num_classes` with the data + and `is_multiclass` param for multi-label data. + """ + + if is_multiclass and num_classes != 2: + raise ValueError( + "Your have set `is_multiclass=True`, but `num_classes` is not equal to 2." + " If you are trying to transform multi-label data to 2 class multi-dimensional" + " multi-class, you should set `num_classes` to either 2 or None." + ) + if not is_multiclass and num_classes != implied_classes: + raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") + + def _check_classification_inputs( preds: torch.Tensor, target: torch.Tensor, @@ -87,52 +218,9 @@ def _check_classification_inputs( raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") # Check that shape/types fall into one of the cases - if preds.ndim == target.ndim: - if preds.shape != target.shape: - raise ValueError( - "`preds` and `target` should have the same shape", - f" got `preds shape = {preds.shape} and `target` shape = {target.shape}.", - ) - if preds_float and target.max() > 1: - raise ValueError( - "if `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." - ) - - # Get the case - if preds.ndim == 1 and preds_float: - case = "binary" - elif preds.ndim == 1 and not preds_float: - case = "multi-class" - elif preds.ndim > 1 and preds_float: - case = "multi-label" - else: - case = "multi-dim multi-class" - - implied_classes = torch.prod(torch.Tensor(list(preds.shape[1:]))) - - elif preds.ndim == target.ndim + 1: - if not preds_float: - raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor.") - if not preds.shape[:-1] == target.shape: - if preds.shape[2:] != target.shape[1:]: - raise ValueError( - "if `preds` have one dimension more than `target`, the shape of `preds` should be" - " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)." - ) - - extra_dim_size = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] - - if preds.ndim == 2: - case = "multi-class" - else: - case = "multi-dim multi-class" - else: - raise ValueError( - "`preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" - " and `preds` (N, C, ...) or (N, ..., C)." - ) + case, implied_classes = _check_shape_and_type_consistency(preds, target) - if preds.shape != target.shape and is_multiclass is False and extra_dim_size != 2: + if preds.shape != target.shape and is_multiclass is False and implied_classes != 2: raise ValueError( "You have set `is_multiclass=False`, but have more than 2 classes in your data," " based on the C dimension of `preds`." @@ -140,55 +228,16 @@ def _check_classification_inputs( # Check that num_classes is consistent if not num_classes: - if preds.shape != target.shape and target.max() >= extra_dim_size: + if preds.shape != target.shape and target.max() >= implied_classes: raise ValueError("The highest label in `target` should be smaller than the size of C dimension.") else: if case == "binary": - if num_classes > 2: - raise ValueError("Your data is binary, but `num_classes` is larger than 2.") - elif num_classes == 2 and not is_multiclass: - raise ValueError( - "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." - " Set it to True if you want to transform binary data to multi-class format." - ) - elif num_classes == 1 and is_multiclass: - raise ValueError( - "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." - " Either leave `is_multiclass` unset or set it to 2 to transform binary data to multi-class format." - ) + _check_num_classes_binary(num_classes, is_multiclass) elif "multi-class" in case: - if num_classes == 1 and is_multiclass is not False: - raise ValueError( - "You have set `num_classes=1`, but predictions are integers." - " If you want to convert (multi-dimensional) multi-class data with 2 classes" - " to binary/multi-label, set `is_multiclass=False`." - ) - elif num_classes > 1: - if is_multiclass is False: - if implied_classes != num_classes: - raise ValueError( - "You have set `is_multiclass=False`, but the implied number of classes " - " (from shape of inputs) does not match `num_classes`. If you are trying to" - " transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`" - " should be either None or the product of the size of extra dimensions (...)." - " See Input Types in Metrics documentation." - ) - if num_classes <= target.max(): - raise ValueError("The highest label in `target` should be smaller than `num_classes`.") - if num_classes <= preds.max(): - raise ValueError("The highest label in `preds` should be smaller than `num_classes`.") - if preds.shape != target.shape and num_classes != extra_dim_size: - raise ValueError("The size of C dimension of `preds` does not match `num_classes`.") + _check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes) elif case == "multi-label": - if is_multiclass and num_classes != 2: - raise ValueError( - "Your have set `is_multiclass=True`, but `num_classes` is not equal to 2." - " If you are trying to transform multi-label data to 2 class multi-dimensional" - " multi-class, you should set `num_classes` to either 2 or None." - ) - if not is_multiclass and num_classes != implied_classes: - raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") + _check_num_classes_ml(num_classes, is_multiclass, implied_classes) # Check that if top_k > 1, we have (multi-class) multi-dim with probabilities if top_k > 1: From f97145bbc599aa5bf75ba08a3567153fb2508086 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 10:53:37 +0100 Subject: [PATCH 16/94] Remove the (N, ..., C) option in MD-MC --- docs/source/metrics.rst | 2 +- .../metrics/classification/utils.py | 39 ++++++------------- tests/metrics/classification/inputs.py | 6 --- tests/metrics/classification/test_inputs.py | 35 ----------------- 4 files changed, 13 insertions(+), 69 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 831ac67922114..082e84e8c79f4 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -224,7 +224,7 @@ into these categories (``N`` stands for the batch size and ``C`` for number of c "Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``" "Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*" "Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``" - "Multi-dimensional multi-class with probabilities", "(N, C, ...) or (N, ..., C)", "``float``", "(N, ...)", "``int``" + "Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``" .. note:: All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 50892af13ba9e..d665bd88266b8 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -59,14 +59,13 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) elif preds.ndim == target.ndim + 1: if not preds_float: raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor.") - if not preds.shape[:-1] == target.shape: - if preds.shape[2:] != target.shape[1:]: - raise ValueError( - "if `preds` have one dimension more than `target`, the shape of `preds` should be" - " either of shape (N, C, ...) or (N, ..., C), and of `target` of shape (N, ...)." - ) + if preds.shape[2:] != target.shape[1:]: + raise ValueError( + "if `preds` have one dimension more than `target`, the shape of `preds` should be" + " of shape (N, C, ...), and `target` of shape (N, ...)." + ) - implied_classes = preds.shape[-1 if preds.shape[:-1] == target.shape else 1] + implied_classes = preds.shape[1] if preds.ndim == 2: case = "multi-class" @@ -263,15 +262,15 @@ def _input_format_classification( * Both preds and target are of shape ``(N,)``, and both are integers (multi-class) * Both preds and target are of shape ``(N,)``, and target is binary, while preds - are a float (binary) + are a float (binary) * preds are of shape ``(N, C)`` and are floats, and target is of shape ``(N,)`` and - is integer (multi-class) + is integer (multi-class) * preds and target are of shape ``(N, ...)``, target is binary and preds is a float - (multi-label) - * preds are of shape ``(N, ..., C)`` or ``(N, C, ...)`` and are floats, target is of - shape ``(N, ...)`` and is integer (multi-dimensional multi-class) + (multi-label) + * preds are of shape ``(N, C, ...)`` and are floats, target is of shape ``(N, ...)`` + and is integer (multi-dimensional multi-class) * preds and target are of shape ``(N, ...)`` both are integers (multi-dimensional - multi-class) + multi-class) To avoid ambiguities, all dimensions of size 1, except the first one, are squeezed out. @@ -302,13 +301,6 @@ def _input_format_classification( ``is_multiclass=False`` (and there are up to two classes), then the data is returned as ``(N, X)`` binary tensors (multi-label). - Also, in multi-dimensional multi-class case, if the position of the ``C`` - dimension is ambiguous (e.g. if targets are a ``(7, 3)`` tensor, while predictions are a - ``(7, 3, 3)`` tensor), it will be assumed that the ``C`` dimension is the second dimension. - If this is not the case, you should move it from the last to second place using - ``torch.movedim(preds, -1, 1)``, or using ``preds.permute``, if you are using an older - version of Pytorch. - Note that where a one-hot transformation needs to be performed and the number of classes is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be equal to ``num_classes``, if it is given, or the maximum label value in preds and @@ -420,13 +412,6 @@ def _input_format_classification( # Multi-dim multi-class (N, C, ...) and (N, ..., C) else: mode = "multi-dim multi-class" - if preds.shape[:-1] == target.shape: - shape_permute = list(range(preds.ndim)) - shape_permute[1] = shape_permute[-1] - shape_permute[2:] = range(1, len(shape_permute) - 1) - - preds = preds.permute(*shape_permute) - num_classes = preds.shape[1] if is_multiclass is False: diff --git a/tests/metrics/classification/inputs.py b/tests/metrics/classification/inputs.py index e648aaf10093e..48d3e85e3afeb 100644 --- a/tests/metrics/classification/inputs.py +++ b/tests/metrics/classification/inputs.py @@ -70,12 +70,6 @@ target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) -# Class dimension last -_multidim_multiclass_prob_inputs1 = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, NUM_CLASSES), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) -) - _multidim_multiclass_inputs = Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 6ec6c8dbbc498..edce1232863f1 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -12,7 +12,6 @@ _multiclass_prob_inputs as _mc_prob, _multidim_multiclass_inputs as _mdmc, _multidim_multiclass_prob_inputs as _mdmc_prob, - _multidim_multiclass_prob_inputs1 as _mdmc_prob1, _multilabel_inputs as _ml, _multilabel_prob_inputs as _ml_prob, _multilabel_multidim_inputs as _mlmd, @@ -28,16 +27,9 @@ rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, EXTRA_DIM), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)), ) -_mdmc_prob_many_dims1 = Input( - rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM, NUM_CLASSES), - randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)), -) _mdmc_prob_2cls = Input( rand(NUM_BATCHES, BATCH_SIZE, 2, EXTRA_DIM), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) -_mdmc_prob_2cls1 = Input( - rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, 2), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) -) # Some utils T = torch.Tensor @@ -79,15 +71,6 @@ def _top2(x): return select_topk(x, 2) -def _mvdim(x): - """ Equivalent of torch.movedim(x, -1, 1) """ - shape_permute = list(range(x.ndim)) - shape_permute[1] = shape_permute[-1] - shape_permute[2:] = range(1, len(shape_permute) - 1) - - return x.permute(*shape_permute) - - # To avoid ugly black line wrapping def _ml_preds_tr(x): return _rshp1(_thrs(x).int()) @@ -109,14 +92,6 @@ def _top2_rshp2(x): return _top2(_rshp2(x)) -def _mdmc1_top1_tr(x): - return _top1(_rshp2(_mvdim(x))) - - -def _mdmc1_top2_tr(x): - return _top2(_rshp2(_mvdim(x))) - - def _probs_to_mc_preds_tr(x): return _onehot2(_thrs(x)).int() @@ -125,10 +100,6 @@ def _mlmd_prob_to_mc_preds_tr(x): return _onehot2(_rshp1(_thrs(x).int())) -def _mdmc_prob_to__ml_preds_tr(x): - return _top1(_mvdim(x))[:, 1] - - ######################## # Test correct inputs ######################## @@ -153,11 +124,6 @@ def _mdmc_prob_to__ml_preds_tr(x): (_mdmc_prob, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot), (_mdmc_prob_many_dims, THRESHOLD, None, None, 1, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), (_mdmc_prob_many_dims, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), - # Test with C dim in last place - (_mdmc_prob1, THRESHOLD, None, None, 1, "multi-dim multi-class", _mdmc1_top1_tr, _onehot), - (_mdmc_prob1, THRESHOLD, None, None, 2, "multi-dim multi-class", _mdmc1_top2_tr, _onehot), - (_mdmc_prob_many_dims1, THRESHOLD, None, None, 1, "multi-dim multi-class", _mdmc1_top1_tr, _onehot_rshp1), - (_mdmc_prob_many_dims1, THRESHOLD, None, None, 2, "multi-dim multi-class", _mdmc1_top2_tr, _onehot_rshp1), ########################### # Test some special cases # Binary as multiclass @@ -176,7 +142,6 @@ def _mdmc_prob_to__ml_preds_tr(x): (_mc_prob_2cls, THRESHOLD, None, False, 1, "multi-class", lambda x: _top1(x)[:, [1]], _usq), # Multi-dim multi-class with 2 classes as multi-label (_mdmc_prob_2cls, THRESHOLD, None, False, 1, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), - (_mdmc_prob_2cls1, THRESHOLD, None, False, 1, "multi-dim multi-class", _mdmc_prob_to__ml_preds_tr, _idn), ], ) def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): From 536feafd8945b78b3c781d9f5c96f3433fa8be8c Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 10:55:12 +0100 Subject: [PATCH 17/94] Simplify select_topk --- pytorch_lightning/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index b7f8d492ce01d..e78ba055665d8 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -82,7 +82,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch tensor([[0, 1, 1], [1, 1, 0]], dtype=torch.int32) """ - zeros = torch.zeros_like(prob_tensor, device=prob_tensor.device) + zeros = torch.zeros_like(prob_tensor) topk_tensor = zeros.scatter(1, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) return topk_tensor.int() From 4241d7c281a160c10178a34375781f019348657a Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 11:50:55 +0100 Subject: [PATCH 18/94] Remove detach for inputs --- pytorch_lightning/metrics/classification/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index d665bd88266b8..16d51d5c35683 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -324,8 +324,6 @@ def _input_format_classification( preds: binary tensor of shape (N, C) or (N, C, X) target: binary tensor of shape (N, C) or (N, C, X) """ - preds, target = preds.clone().detach(), target.clone().detach() - # Remove excess dimensions if preds.shape[0] == 1: preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0) From 86d6c4d976cd6201652feaca8a5cb91554cb3603 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Wed, 25 Nov 2020 14:09:43 +0100 Subject: [PATCH 19/94] Fix typos --- pytorch_lightning/metrics/classification/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 16d51d5c35683..eb73987187820 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -36,8 +36,8 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) if preds.ndim == target.ndim: if preds.shape != target.shape: raise ValueError( - "`preds` and `target` should have the same shape", - f" got `preds shape = {preds.shape} and `target` shape = {target.shape}.", + "`preds` and `target` should have the same shape,", + f" got `preds` shape = {preds.shape} and `target` shape = {target.shape}.", ) if preds_float and target.max() > 1: raise ValueError( @@ -62,7 +62,7 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) if preds.shape[2:] != target.shape[1:]: raise ValueError( "if `preds` have one dimension more than `target`, the shape of `preds` should be" - " of shape (N, C, ...), and `target` of shape (N, ...)." + " (N, C, ...), and the shape of `target` should be (N, ...)." ) implied_classes = preds.shape[1] @@ -74,7 +74,7 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) else: raise ValueError( "`preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" - " and `preds` (N, C, ...) or (N, ..., C)." + " and `preds` (N, C, ...)." ) return case, implied_classes From cde39970fc79af4dc6adba981957ddcdd5b2628b Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Thu, 26 Nov 2020 21:34:33 +0100 Subject: [PATCH 20/94] Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec --- pytorch_lightning/metrics/classification/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index eb73987187820..9a49c7c4e8cdb 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -22,7 +22,7 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) """ This checks that the shape and type of inputs are consistent with each other and fall into one of the allowed input types (see the - documentation of docstring of _input_format_classification). It does + documentation of docstring of ``_input_format_classification``). It does not check for consistency of number of classes, other functions take care of that. From 05a54da4df1ab72d1ec0424d4bd3af87f8e6bee5 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Thu, 26 Nov 2020 21:40:54 +0100 Subject: [PATCH 21/94] Update docs/source/metrics.rst Co-authored-by: Jirka Borovec --- docs/source/metrics.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 082e84e8c79f4..b59fdc6c73009 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -233,7 +233,7 @@ into these categories (``N`` stands for the batch size and ``C`` for number of c When predictions or targets are integers, it is assumed that class labels start at 0, i.e. the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types -.. code-block:: python +.. testcode:: # Binary inputs binary_preds = torch.tensor([0.6, 0.1, 0.9]) From 9a43a5eafe106db04652b2e80bbcf64f0f0308f7 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Thu, 26 Nov 2020 23:18:41 +0100 Subject: [PATCH 22/94] Minor error message changes --- .../metrics/classification/utils.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 9a49c7c4e8cdb..9ddcfb9132c2c 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -36,12 +36,12 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) if preds.ndim == target.ndim: if preds.shape != target.shape: raise ValueError( - "`preds` and `target` should have the same shape,", + "The `preds` and `target` should have the same shape,", f" got `preds` shape = {preds.shape} and `target` shape = {target.shape}.", ) if preds_float and target.max() > 1: raise ValueError( - "if `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." + "If `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." ) # Get the case @@ -58,10 +58,10 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) elif preds.ndim == target.ndim + 1: if not preds_float: - raise ValueError("if `preds` have one dimension more than `target`, `preds` should be a float tensor.") + raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") if preds.shape[2:] != target.shape[1:]: raise ValueError( - "if `preds` have one dimension more than `target`, the shape of `preds` should be" + "If `preds` have one dimension more than `target`, the shape of `preds` should be" " (N, C, ...), and the shape of `target` should be (N, ...)." ) @@ -73,7 +73,7 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) case = "multi-dim multi-class" else: raise ValueError( - "`preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" + "The `preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" " and `preds` (N, C, ...)." ) @@ -192,23 +192,23 @@ def _check_classification_inputs( """ if target.is_floating_point(): - raise ValueError("`target` has to be an integer tensor.") + raise ValueError("The `target` has to be an integer tensor.") elif target.min() < 0: - raise ValueError("`target` has to be a non-negative tensor.") + raise ValueError("The `target` has to be a non-negative tensor.") preds_float = preds.is_floating_point() if not preds_float and preds.min() < 0: - raise ValueError("if `preds` are integers, they have to be non-negative.") + raise ValueError("If `preds` are integers, they have to be non-negative.") if not preds.shape[0] == target.shape[0]: - raise ValueError("`preds` and `target` should have the same first dimension.") + raise ValueError("The `preds` and `target` should have the same first dimension.") if preds_float: if preds.min() < 0 or preds.max() > 1: - raise ValueError("`preds` should be probabilities, but values were detected outside of [0,1] range.") + raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") if threshold > 1 or threshold < 0: - raise ValueError("`threshold` should be a probability in [0,1].") + raise ValueError("The `threshold` should be a probability in [0,1].") if is_multiclass is False and target.max() > 1: raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") From 3f4ad3c5a25bc82ad61a671701202f4b269d3a6c Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Thu, 26 Nov 2020 23:19:10 +0100 Subject: [PATCH 23/94] Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec --- pytorch_lightning/metrics/utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index e78ba055665d8..170315aa22236 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -57,8 +57,13 @@ def to_onehot( [0, 0, 1, 0], [0, 0, 0, 1]]) """ - dtype, device, shape = label_tensor.dtype, label_tensor.device, label_tensor.shape - tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], dtype=dtype, device=device) + tensor_onehot = torch.zeros( + label_tensor.shape[0], + num_classes, + *label_tensor.shape[1:], + dtype=label_tensor.dtype, + device=label_tensor.device, + ) index = label_tensor.long().unsqueeze(1).expand_as(tensor_onehot) return tensor_onehot.scatter_(1, index, 1.0) From a654e6a9b70322c70232b810f8c156ff3d055138 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Thu, 26 Nov 2020 23:27:10 +0100 Subject: [PATCH 24/94] Reuse case from validation in formatting --- .../metrics/classification/utils.py | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 9ddcfb9132c2c..4183130704e8f 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -155,7 +155,7 @@ def _check_classification_inputs( num_classes: Optional[int] = None, is_multiclass: bool = False, top_k: int = 1, -) -> None: +) -> str: """Performs error checking on inputs for classification. This ensures that preds and target take one of the shape/type combinations that are @@ -189,6 +189,10 @@ def _check_classification_inputs( multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. Defaults to None, which treats inputs as they appear. + + Return: + case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or + 'multi-dim multi-class' """ if target.is_floating_point(): @@ -246,6 +250,8 @@ def _check_classification_inputs( " with probability predictions." ) + return case + def _input_format_classification( preds: torch.Tensor, @@ -276,7 +282,7 @@ def _input_format_classification( The returned output tensors will be binary tensors of the same shape, either ``(N, C)`` of ``(N, C, X)``, the details for each case are described below. The function also returns - a ``mode`` string, which describes which of the above cases the inputs belonged to - regardless + a ``case`` string, which describes which of the above cases the inputs belonged to - regardless of whether this was "overridden" by other settings (like ``is_multiclass``). In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed @@ -323,6 +329,8 @@ def _input_format_classification( Returns: preds: binary tensor of shape (N, C) or (N, C, X) target: binary tensor of shape (N, C) or (N, C, X) + case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or + 'multi-dim multi-class' """ # Remove excess dimensions if preds.shape[0] == 1: @@ -330,7 +338,7 @@ def _input_format_classification( else: preds, target = preds.squeeze(), target.squeeze() - _check_classification_inputs( + case = _check_classification_inputs( preds, target, threshold=threshold, @@ -341,8 +349,7 @@ def _input_format_classification( preds_float = preds.is_floating_point() - if preds.ndim == target.ndim == 1 and preds_float: - mode = "binary" + if case == "binary": preds = (preds >= threshold).int() if is_multiclass: @@ -352,8 +359,7 @@ def _input_format_classification( preds = preds.unsqueeze(-1) target = target.unsqueeze(-1) - elif preds.ndim == target.ndim and preds_float: - mode = "multi-label" + elif case == "multi-label": preds = (preds >= threshold).int() if is_multiclass: @@ -363,8 +369,8 @@ def _input_format_classification( preds = preds.reshape(preds.shape[0], -1) target = target.reshape(target.shape[0], -1) + # Multi-class with probabilities elif preds.ndim == target.ndim + 1 == 2: - mode = "multi-class" if not num_classes: num_classes = preds.shape[1] @@ -376,9 +382,8 @@ def _input_format_classification( target = target[:, [1]] preds = preds[:, [1]] + # Multi-class with labels elif preds.ndim == target.ndim == 1 and not preds_float: - mode = "multi-class" - if not num_classes: num_classes = max(preds.max(), target.max()) + 1 @@ -392,8 +397,6 @@ def _input_format_classification( # Multi-dim multi-class (N, ...) with integers elif preds.shape == target.shape and not preds_float: - mode = "multi-dim multi-class" - if not num_classes: num_classes = max(preds.max(), target.max()) + 1 @@ -407,9 +410,8 @@ def _input_format_classification( preds = to_onehot(preds, num_classes) preds = preds.reshape(preds.shape[0], preds.shape[1], -1) - # Multi-dim multi-class (N, C, ...) and (N, ..., C) + # Multi-dim multi-class (N, C, ...) else: - mode = "multi-dim multi-class" num_classes = preds.shape[1] if is_multiclass is False: @@ -421,4 +423,4 @@ def _input_format_classification( target = target.reshape(target.shape[0], target.shape[1], -1) preds = select_topk(preds, top_k).reshape(preds.shape[0], preds.shape[1], -1) - return preds, target, mode + return preds, target, case From 16ab8f784b0bdbc6c2b278678c44eacd62093c4e Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 01:01:27 +0100 Subject: [PATCH 25/94] Refactor code in _input_format_classification --- .../metrics/classification/utils.py | 90 +++++-------------- tests/metrics/classification/test_inputs.py | 24 ++--- 2 files changed, 36 insertions(+), 78 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 4183130704e8f..c224f3f1caa18 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -249,6 +249,8 @@ def _check_classification_inputs( "You have set `top_k` above 1, but your data is not (multi-dimensional) multi-class" " with probability predictions." ) + if is_multiclass is False: + raise ValueError("If you set `is_multiclass` to False, you can not set `top_k` above 1.") return case @@ -330,7 +332,7 @@ def _input_format_classification( preds: binary tensor of shape (N, C) or (N, C, X) target: binary tensor of shape (N, C) or (N, C, X) case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or - 'multi-dim multi-class' + 'multi-dim multi-class' """ # Remove excess dimensions if preds.shape[0] == 1: @@ -347,80 +349,34 @@ def _input_format_classification( top_k=top_k, ) - preds_float = preds.is_floating_point() - - if case == "binary": - preds = (preds >= threshold).int() - - if is_multiclass: - target = to_onehot(target, 2) - preds = to_onehot(preds, 2) - else: - preds = preds.unsqueeze(-1) - target = target.unsqueeze(-1) - - elif case == "multi-label": + if case in ["binary", "multi-label"]: preds = (preds >= threshold).int() + num_classes = num_classes if not is_multiclass else 2 - if is_multiclass: - preds = to_onehot(preds, 2).reshape(preds.shape[0], 2, -1) - target = to_onehot(target, 2).reshape(target.shape[0], 2, -1) - else: - preds = preds.reshape(preds.shape[0], -1) - target = target.reshape(target.shape[0], -1) - - # Multi-class with probabilities - elif preds.ndim == target.ndim + 1 == 2: - if not num_classes: + if "multi-class" in case or is_multiclass: + if preds.is_floating_point(): num_classes = preds.shape[1] - - target = to_onehot(target, num_classes) - preds = select_topk(preds, top_k) - - # If is_multiclass=False, force to binary - if is_multiclass is False: - target = target[:, [1]] - preds = preds[:, [1]] - - # Multi-class with labels - elif preds.ndim == target.ndim == 1 and not preds_float: - if not num_classes: - num_classes = max(preds.max(), target.max()) + 1 - - # If is_multiclass=False, force to binary - if is_multiclass is False: - preds = preds.unsqueeze(1) - target = target.unsqueeze(1) + preds = select_topk(preds, top_k) else: + num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1 preds = to_onehot(preds, num_classes) - target = to_onehot(target, num_classes) - # Multi-dim multi-class (N, ...) with integers - elif preds.shape == target.shape and not preds_float: - if not num_classes: - num_classes = max(preds.max(), target.max()) + 1 + target = to_onehot(target, num_classes) - # If is_multiclass=False, force to multi-label if is_multiclass is False: - preds = preds.reshape(preds.shape[0], -1) - target = target.reshape(target.shape[0], -1) - else: - target = to_onehot(target, num_classes) - target = target.reshape(target.shape[0], target.shape[1], -1) - preds = to_onehot(preds, num_classes) - preds = preds.reshape(preds.shape[0], preds.shape[1], -1) + preds, target = preds[:, 1, ...], target[:, 1, ...] - # Multi-dim multi-class (N, C, ...) - else: - num_classes = preds.shape[1] + if (case in ["binary", "multi-label"] and not is_multiclass) or is_multiclass is False: + preds = preds.reshape(preds.shape[0], -1) + target = target.reshape(target.shape[0], -1) - if is_multiclass is False: - target = target.reshape(target.shape[0], -1) - preds = select_topk(preds, 1)[:, 1, ...] - preds = preds.reshape(preds.shape[0], -1) - else: - target = to_onehot(target, num_classes) - target = target.reshape(target.shape[0], target.shape[1], -1) - preds = select_topk(preds, top_k).reshape(preds.shape[0], preds.shape[1], -1) + elif "multi-class" in case or is_multiclass: + target = target.reshape(target.shape[0], target.shape[1], -1) + preds = preds.reshape(preds.shape[0], preds.shape[1], -1) + + # Some operatins above create an extra dimension for MC/binary case - this removes it + if preds.ndim > 2: + preds = preds.squeeze(-1) + target = target.squeeze(-1) - return preds, target, case + return preds.int(), target.int(), case diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index edce1232863f1..79788a1c2ecf4 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -73,7 +73,7 @@ def _top2(x): # To avoid ugly black line wrapping def _ml_preds_tr(x): - return _rshp1(_thrs(x).int()) + return _rshp1(_thrs(x)) def _onehot_rshp1(x): @@ -93,11 +93,11 @@ def _top2_rshp2(x): def _probs_to_mc_preds_tr(x): - return _onehot2(_thrs(x)).int() + return _onehot2(_thrs(x)) def _mlmd_prob_to_mc_preds_tr(x): - return _onehot2(_rshp1(_thrs(x).int())) + return _onehot2(_rshp1(_thrs(x))) ######################## @@ -111,8 +111,8 @@ def _mlmd_prob_to_mc_preds_tr(x): ############################# # Test usual expected cases (_bin, THRESHOLD, None, False, 1, "multi-class", _usq, _usq), - (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: _usq(_thrs(x).int()), _usq), - (_ml_prob, THRESHOLD, None, None, 1, "multi-label", lambda x: _thrs(x).int(), _idn), + (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: _usq(_thrs(x)), _usq), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", lambda x: _thrs(x), _idn), (_ml, THRESHOLD, None, False, 1, "multi-dim multi-class", _idn, _idn), (_ml_prob, THRESHOLD, None, None, 1, "multi-label", _ml_preds_tr, _rshp1), (_mlmd, THRESHOLD, None, False, 1, "multi-dim multi-class", _rshp1, _rshp1), @@ -155,8 +155,8 @@ def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_m ) assert mode == exp_mode - assert torch.equal(preds_out, post_preds(inputs.preds[0])) - assert torch.equal(target_out, post_target(inputs.target[0])) + assert torch.equal(preds_out, post_preds(inputs.preds[0]).int()) + assert torch.equal(target_out, post_target(inputs.target[0]).int()) # Test that things work when batch_size = 1 preds_out, target_out, mode = _input_format_classification( @@ -169,8 +169,8 @@ def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_m ) assert mode == exp_mode - assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...])) - assert torch.equal(target_out, post_target(inputs.target[0][[0], ...])) + assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int()) + assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int()) # Test that threshold is correctly applied @@ -180,7 +180,7 @@ def test_threshold(): preds_probs_out, _, _ = _input_format_classification(preds_probs, target, threshold=0.5) - assert torch.equal(torch.tensor([0, 1, 1]), preds_probs_out.squeeze().long()) + assert torch.equal(torch.tensor([0, 1, 1], dtype=torch.int), preds_probs_out.squeeze().int()) ######################################################################## @@ -222,7 +222,7 @@ def test_threshold(): # Max target larger or equal to C dimension (rand(size=(7, 3)), randint(low=4, high=6, size=(7,)), 0.5, None, None, 1), # C dimension not equal to num_classes - (rand(size=(7, 3, 4)), randint(high=4, size=(7, 3)), 0.5, 7, None, 1), + (rand(size=(7, 4, 3)), randint(high=4, size=(7, 3)), 0.5, 7, None, 1), # Max target larger than num_classes (with #dim preds = 1 + #dims target) (rand(size=(7, 3, 4)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, 1), # Max target larger than num_classes (with #dim preds = #dims target) @@ -253,6 +253,8 @@ def test_threshold(): (_ml_prob.preds[0], _ml_prob.target[0], 0.5, None, None, 2), (_mlmd_prob.preds[0], _mlmd_prob.target[0], 0.5, None, None, 2), (_mdmc.preds[0], _mdmc.target[0], 0.5, None, None, 2), + # Topk =2 with 2 classes, is_multiclass=False + (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], 0.5, None, False, 2), ], ) def test_incorrect_inputs(preds, target, threshold, num_classes, is_multiclass, top_k): From ecffe18d3b1a9bf6d4c173af29d673274aaaa426 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 22:48:24 +0100 Subject: [PATCH 26/94] Small improvements --- pytorch_lightning/metrics/classification/utils.py | 9 ++++----- tests/metrics/classification/test_inputs.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index c224f3f1caa18..68b4a2da789da 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -366,13 +366,12 @@ def _input_format_classification( if is_multiclass is False: preds, target = preds[:, 1, ...], target[:, 1, ...] - if (case in ["binary", "multi-label"] and not is_multiclass) or is_multiclass is False: - preds = preds.reshape(preds.shape[0], -1) - target = target.reshape(target.shape[0], -1) - - elif "multi-class" in case or is_multiclass: + if ("multi-class" in case and is_multiclass is not False) or is_multiclass: target = target.reshape(target.shape[0], target.shape[1], -1) preds = preds.reshape(preds.shape[0], preds.shape[1], -1) + else: + preds = preds.reshape(preds.shape[0], -1) + target = target.reshape(target.shape[0], -1) # Some operatins above create an extra dimension for MC/binary case - this removes it if preds.ndim > 2: diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 79788a1c2ecf4..430d844217cb5 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -112,7 +112,7 @@ def _mlmd_prob_to_mc_preds_tr(x): # Test usual expected cases (_bin, THRESHOLD, None, False, 1, "multi-class", _usq, _usq), (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: _usq(_thrs(x)), _usq), - (_ml_prob, THRESHOLD, None, None, 1, "multi-label", lambda x: _thrs(x), _idn), + (_ml_prob, THRESHOLD, None, None, 1, "multi-label", _thrs, _idn), (_ml, THRESHOLD, None, False, 1, "multi-dim multi-class", _idn, _idn), (_ml_prob, THRESHOLD, None, None, 1, "multi-label", _ml_preds_tr, _rshp1), (_mlmd, THRESHOLD, None, False, 1, "multi-dim multi-class", _rshp1, _rshp1), From 725c7dd717199b8d35becd77ea1490b7ee3fbddb Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 23:01:01 +0100 Subject: [PATCH 27/94] PEP 8 --- pytorch_lightning/metrics/classification/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 68b4a2da789da..5335083e36b95 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -364,7 +364,7 @@ def _input_format_classification( target = to_onehot(target, num_classes) if is_multiclass is False: - preds, target = preds[:, 1, ...], target[:, 1, ...] + preds, target = preds[:, 1, ...], target[:, 1, ...] if ("multi-class" in case and is_multiclass is not False) or is_multiclass: target = target.reshape(target.shape[0], target.shape[1], -1) From 41ad0b7163229b7397daeababd1924a10d2db372 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 23:04:18 +0100 Subject: [PATCH 28/94] Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta --- pytorch_lightning/metrics/classification/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 5335083e36b95..79a382f5104a8 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -54,7 +54,7 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) else: case = "multi-dim multi-class" - implied_classes = torch.prod(torch.Tensor(list(preds.shape[1:]))) + implied_classes = preds[0].numel() elif preds.ndim == target.ndim + 1: if not preds_float: From ca13e76713c9f5b5bfc5f315e2380040b9b726fb Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 23:06:57 +0100 Subject: [PATCH 29/94] Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta --- pytorch_lightning/metrics/classification/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 79a382f5104a8..2dac8b7b7eae9 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -73,8 +73,8 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) case = "multi-dim multi-class" else: raise ValueError( - "The `preds` and `target` should both have the (same) shape (N, ...), or `target` (N, ...)" - " and `preds` (N, C, ...)." + "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)" + " and `preds` should be (N, C, ...)." ) return case, implied_classes From ede2c7fa2ff98aba446591da11b9dac28df2824e Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 23:07:14 +0100 Subject: [PATCH 30/94] Update docs/source/metrics.rst Co-authored-by: Rohit Gupta --- docs/source/metrics.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index b59fdc6c73009..4be5e67b7c447 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -251,7 +251,7 @@ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of differ ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]]) ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]]) -In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class, +In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class but are actually binary/multi-label. For example, if both predictions and targets are 1d binary tensors. Or it could be the other way around, you want to treat binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs. From c6e4de44ff93160fcc30467a708946de7185fa10 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 23:07:35 +0100 Subject: [PATCH 31/94] Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta --- pytorch_lightning/metrics/classification/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 2dac8b7b7eae9..76d4348978cc0 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -37,7 +37,7 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) if preds.shape != target.shape: raise ValueError( "The `preds` and `target` should have the same shape,", - f" got `preds` shape = {preds.shape} and `target` shape = {target.shape}.", + f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", ) if preds_float and target.max() > 1: raise ValueError( From 201d0debf8026681cf8756f52cbb3bd447935aa5 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 23:17:24 +0100 Subject: [PATCH 32/94] Apply suggestions from code review Co-authored-by: Rohit Gupta --- docs/source/metrics.rst | 2 +- pytorch_lightning/metrics/classification/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 4be5e67b7c447..6d70b92ca8a9f 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -203,7 +203,7 @@ Class vs Functional Metrics The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs. Also the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. -If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also to use the class interface. +If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the class interface. ********************** Classification Metrics diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 76d4348978cc0..ac6ec711a2654 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -96,7 +96,7 @@ def _check_num_classes_binary(num_classes: int, is_multiclass: bool): elif num_classes == 1 and is_multiclass: raise ValueError( "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." - " Either leave `is_multiclass` unset or set it to 2 to transform binary data to multi-class format." + " Either set `is_multiclass=None`(default) or set `num_classes=2` to transform binary data to multi-class format." ) From f08edbcc6e8eaaa2eec655155db82ade05c0230e Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Fri, 27 Nov 2020 23:25:09 +0100 Subject: [PATCH 33/94] Alphabetical reordering of regression metrics --- docs/source/metrics.rst | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index b59fdc6c73009..50dc36fac6b25 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -202,7 +202,7 @@ Class vs Functional Metrics The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs. -Also the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. +Also, the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface. If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also to use the class interface. ********************** @@ -447,10 +447,10 @@ Regression Metrics Class Metrics (Regression) -------------------------- -MeanSquaredError -~~~~~~~~~~~~~~~~ +ExplainedVariance +~~~~~~~~~~~~~~~~~ -.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError +.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance :noindex: @@ -461,17 +461,17 @@ MeanAbsoluteError :noindex: -MeanSquaredLogError -~~~~~~~~~~~~~~~~~~~ +MeanSquaredError +~~~~~~~~~~~~~~~~ -.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError +.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError :noindex: -ExplainedVariance -~~~~~~~~~~~~~~~~~ +MeanSquaredLogError +~~~~~~~~~~~~~~~~~~~ -.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance +.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError :noindex: @@ -513,17 +513,17 @@ mean_squared_error [func] :noindex: -psnr [func] -~~~~~~~~~~~ +mean_squared_log_error [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.psnr +.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error :noindex: -mean_squared_log_error [func] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +psnr [func] +~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error +.. autofunction:: pytorch_lightning.metrics.functional.psnr :noindex: From 35e3eff9e4f5dcf9bc4e32722becf9fbfba3fc53 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sat, 28 Nov 2020 16:14:16 +0100 Subject: [PATCH 34/94] Change default value of top_k and add error checking --- .../metrics/classification/utils.py | 80 +++++++++------- tests/metrics/classification/test_inputs.py | 96 ++++++++++--------- 2 files changed, 96 insertions(+), 80 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index ac6ec711a2654..af4319152942f 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -148,13 +148,25 @@ def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") +def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool): + if "multi-class" not in case or not preds_float: + raise ValueError( + "You have set `top_k` above 1, but your data is not (multi-dimensional) multi-class" + " with probability predictions." + ) + if is_multiclass is False: + raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.") + if top_k >= implied_classes: + raise ValueError("The `top_k` has to be strictly smaller than the `C` dimension of `preds`.") + + def _check_classification_inputs( preds: torch.Tensor, target: torch.Tensor, threshold: float, - num_classes: Optional[int] = None, - is_multiclass: bool = False, - top_k: int = 1, + num_classes: Optional[int], + is_multiclass: bool, + top_k: Optional[int], ) -> str: """Performs error checking on inputs for classification. @@ -172,8 +184,9 @@ def _check_classification_inputs( When ``num_classes`` is not specified in these cases, consistency of the highest target value against ``C`` dimension is checked for (multi-dimensional) multi-class cases. - If ``top_k`` is larger than one, then an error is raised if the inputs are not (multi-dim) - multi-class with probability predictions. + If ``top_k`` is set (not None) for inputs which are not (multi-dimensional) multi class + with probabilities, then an error is raised. Similarly if ``top_k`` is set to a number + that is higher than or equal to the ``C`` dimension of ``preds``. Preds and target tensors are expected to be squeezed already - all dimensions should be greater than 1, except perhaps the first one (N). @@ -189,6 +202,8 @@ def _check_classification_inputs( multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. Defaults to None, which treats inputs as they appear. + top_k: number of highest probability entries for each sample to convert to 1s, relevant + only for (multi-dimensional) multi-class cases. Return: case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or @@ -197,7 +212,7 @@ def _check_classification_inputs( if target.is_floating_point(): raise ValueError("The `target` has to be an integer tensor.") - elif target.min() < 0: + if target.min() < 0: raise ValueError("The `target` has to be a non-negative tensor.") preds_float = preds.is_floating_point() @@ -207,9 +222,8 @@ def _check_classification_inputs( if not preds.shape[0] == target.shape[0]: raise ValueError("The `preds` and `target` should have the same first dimension.") - if preds_float: - if preds.min() < 0 or preds.max() > 1: - raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") + if preds_float and (preds.min() < 0 or preds.max() > 1): + raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") if threshold > 1 or threshold < 0: raise ValueError("The `threshold` should be a probability in [0,1].") @@ -223,34 +237,30 @@ def _check_classification_inputs( # Check that shape/types fall into one of the cases case, implied_classes = _check_shape_and_type_consistency(preds, target) - if preds.shape != target.shape and is_multiclass is False and implied_classes != 2: - raise ValueError( - "You have set `is_multiclass=False`, but have more than 2 classes in your data," - " based on the C dimension of `preds`." - ) + # Check consistency with the `C` dimension in case of multi-class data + if preds.shape != target.shape: + if is_multiclass is False and implied_classes != 2: + raise ValueError( + "You have set `is_multiclass=False`, but have more than 2 classes in your data," + " based on the C dimension of `preds`." + ) + if target.max() >= implied_classes: + raise ValueError( + "The highest label in `target` should be smaller than the size of the `C` dimension of `preds`." + ) # Check that num_classes is consistent - if not num_classes: - if preds.shape != target.shape and target.max() >= implied_classes: - raise ValueError("The highest label in `target` should be smaller than the size of C dimension.") - else: + if num_classes: if case == "binary": _check_num_classes_binary(num_classes, is_multiclass) elif "multi-class" in case: _check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes) - elif case == "multi-label": _check_num_classes_ml(num_classes, is_multiclass, implied_classes) - # Check that if top_k > 1, we have (multi-class) multi-dim with probabilities - if top_k > 1: - if preds.shape == target.shape: - raise ValueError( - "You have set `top_k` above 1, but your data is not (multi-dimensional) multi-class" - " with probability predictions." - ) - if is_multiclass is False: - raise ValueError("If you set `is_multiclass` to False, you can not set `top_k` above 1.") + # Check that top_k is consistent + if top_k: + _check_top_k(top_k, case, implied_classes, is_multiclass, preds_float) return case @@ -259,7 +269,7 @@ def _input_format_classification( preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5, - top_k: int = 1, + top_k: Optional[int] = None, num_classes: Optional[int] = None, is_multiclass: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor, str]: @@ -322,7 +332,10 @@ def _input_format_classification( (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 num_classes: number of classes top_k: number of highest probability entries for each sample to convert to 1s, relevant - only for (multi-dimensional) multi-class cases. + only for (multi-dimensional) multi-class inputs with probability predictions. The + default value (``None``) will be interepreted as one for these inputs. + + Should be left unset (``None``) for all other types of inputs. is_multiclass: if True, treat binary and multi-label inputs as multi-class or multi-dim multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. @@ -349,6 +362,8 @@ def _input_format_classification( top_k=top_k, ) + top_k = top_k if top_k else 1 + if case in ["binary", "multi-label"]: preds = (preds >= threshold).int() num_classes = num_classes if not is_multiclass else 2 @@ -370,12 +385,11 @@ def _input_format_classification( target = target.reshape(target.shape[0], target.shape[1], -1) preds = preds.reshape(preds.shape[0], preds.shape[1], -1) else: - preds = preds.reshape(preds.shape[0], -1) target = target.reshape(target.shape[0], -1) + preds = preds.reshape(preds.shape[0], -1) # Some operatins above create an extra dimension for MC/binary case - this removes it if preds.ndim > 2: - preds = preds.squeeze(-1) - target = target.squeeze(-1) + preds, target = preds.squeeze(-1), target.squeeze(-1) return preds.int(), target.int(), case diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 430d844217cb5..058ec66c10ed6 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -110,38 +110,38 @@ def _mlmd_prob_to_mc_preds_tr(x): [ ############################# # Test usual expected cases - (_bin, THRESHOLD, None, False, 1, "multi-class", _usq, _usq), - (_bin_prob, THRESHOLD, None, None, 1, "binary", lambda x: _usq(_thrs(x)), _usq), - (_ml_prob, THRESHOLD, None, None, 1, "multi-label", _thrs, _idn), - (_ml, THRESHOLD, None, False, 1, "multi-dim multi-class", _idn, _idn), - (_ml_prob, THRESHOLD, None, None, 1, "multi-label", _ml_preds_tr, _rshp1), - (_mlmd, THRESHOLD, None, False, 1, "multi-dim multi-class", _rshp1, _rshp1), - (_mc, THRESHOLD, NUM_CLASSES, None, 1, "multi-class", _onehot, _onehot), - (_mc_prob, THRESHOLD, None, None, 1, "multi-class", _top1, _onehot), + (_bin, THRESHOLD, None, False, None, "multi-class", _usq, _usq), + (_bin_prob, THRESHOLD, None, None, None, "binary", lambda x: _usq(_thrs(x)), _usq), + (_ml_prob, THRESHOLD, None, None, None, "multi-label", _thrs, _idn), + (_ml, THRESHOLD, None, False, None, "multi-dim multi-class", _idn, _idn), + (_ml_prob, THRESHOLD, None, None, None, "multi-label", _ml_preds_tr, _rshp1), + (_mlmd, THRESHOLD, None, False, None, "multi-dim multi-class", _rshp1, _rshp1), + (_mc, THRESHOLD, NUM_CLASSES, None, None, "multi-class", _onehot, _onehot), + (_mc_prob, THRESHOLD, None, None, None, "multi-class", _top1, _onehot), (_mc_prob, THRESHOLD, None, None, 2, "multi-class", _top2, _onehot), - (_mdmc, THRESHOLD, NUM_CLASSES, None, 1, "multi-dim multi-class", _onehot, _onehot), - (_mdmc_prob, THRESHOLD, None, None, 1, "multi-dim multi-class", _top1_rshp2, _onehot), + (_mdmc, THRESHOLD, NUM_CLASSES, None, None, "multi-dim multi-class", _onehot, _onehot), + (_mdmc_prob, THRESHOLD, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot), (_mdmc_prob, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot), - (_mdmc_prob_many_dims, THRESHOLD, None, None, 1, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), + (_mdmc_prob_many_dims, THRESHOLD, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), (_mdmc_prob_many_dims, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), ########################### # Test some special cases # Binary as multiclass - (_bin, THRESHOLD, None, None, 1, "multi-class", _onehot2, _onehot2), + (_bin, THRESHOLD, None, None, None, "multi-class", _onehot2, _onehot2), # Binary probs as multiclass - (_bin_prob, THRESHOLD, None, True, 1, "binary", _probs_to_mc_preds_tr, _onehot2), + (_bin_prob, THRESHOLD, None, True, None, "binary", _probs_to_mc_preds_tr, _onehot2), # Multilabel as multiclass - (_ml, THRESHOLD, None, True, 1, "multi-dim multi-class", _onehot2, _onehot2), + (_ml, THRESHOLD, None, True, None, "multi-dim multi-class", _onehot2, _onehot2), # Multilabel probs as multiclass - (_ml_prob, THRESHOLD, None, True, 1, "multi-label", _probs_to_mc_preds_tr, _onehot2), + (_ml_prob, THRESHOLD, None, True, None, "multi-label", _probs_to_mc_preds_tr, _onehot2), # Multidim multilabel as multiclass - (_mlmd, THRESHOLD, None, True, 1, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1), + (_mlmd, THRESHOLD, None, True, None, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1), # Multidim multilabel probs as multiclass - (_mlmd_prob, THRESHOLD, None, True, 1, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), + (_mlmd_prob, THRESHOLD, None, True, None, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), # Multiclass prob with 2 classes as binary - (_mc_prob_2cls, THRESHOLD, None, False, 1, "multi-class", lambda x: _top1(x)[:, [1]], _usq), + (_mc_prob_2cls, THRESHOLD, None, False, None, "multi-class", lambda x: _top1(x)[:, [1]], _usq), # Multi-dim multi-class with 2 classes as multi-label - (_mdmc_prob_2cls, THRESHOLD, None, False, 1, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), + (_mdmc_prob_2cls, THRESHOLD, None, False, None, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), ], ) def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): @@ -192,59 +192,59 @@ def test_threshold(): "preds, target, threshold, num_classes, is_multiclass, top_k", [ # Target not integer - (randint(high=2, size=(7,)), randint(high=2, size=(7,)).float(), 0.5, None, None, 1), + (randint(high=2, size=(7,)), randint(high=2, size=(7,)).float(), 0.5, None, None, None), # Target negative - (randint(high=2, size=(7,)), -randint(high=2, size=(7,)), 0.5, None, None, 1), + (randint(high=2, size=(7,)), -randint(high=2, size=(7,)), 0.5, None, None, None), # Preds negative integers - (-randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, 1), + (-randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, None), # Negative probabilities - (-rand(size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, 1), + (-rand(size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, None), # Threshold outside of [0,1] - (rand(size=(7,)), randint(high=2, size=(7,)), 1.5, None, None, 1), + (rand(size=(7,)), randint(high=2, size=(7,)), 1.5, None, None, None), # is_multiclass=False and target > 1 - (rand(size=(7,)), randint(low=2, high=4, size=(7,)), 0.5, None, False, 1), + (rand(size=(7,)), randint(low=2, high=4, size=(7,)), 0.5, None, False, None), # is_multiclass=False and preds integers with > 1 - (randint(low=2, high=4, size=(7,)), randint(high=2, size=(7,)), 0.5, None, False, 1), + (randint(low=2, high=4, size=(7,)), randint(high=2, size=(7,)), 0.5, None, False, None), # Wrong batch size - (randint(high=2, size=(8,)), randint(high=2, size=(7,)), 0.5, None, None, 1), + (randint(high=2, size=(8,)), randint(high=2, size=(7,)), 0.5, None, None, None), # Completely wrong shape - (randint(high=2, size=(7,)), randint(high=2, size=(7, 4)), 0.5, None, None, 1), + (randint(high=2, size=(7,)), randint(high=2, size=(7, 4)), 0.5, None, None, None), # Same #dims, different shape - (randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), 0.5, None, None, 1), + (randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), 0.5, None, None, None), # Same shape and preds floats, target not binary - (rand(size=(7, 3)), randint(low=2, high=4, size=(7, 3)), 0.5, None, None, 1), + (rand(size=(7, 3)), randint(low=2, high=4, size=(7, 3)), 0.5, None, None, None), # #dims in preds = 1 + #dims in target, C shape not second or last - (rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, 1), + (rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, None), # #dims in preds = 1 + #dims in target, preds not float - (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, 1), + (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, None), # is_multiclass=False, with C dimension > 2 - (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, False, 1), + (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, False, None), # Max target larger or equal to C dimension - (rand(size=(7, 3)), randint(low=4, high=6, size=(7,)), 0.5, None, None, 1), + (rand(size=(7, 3)), randint(low=4, high=6, size=(7,)), 0.5, None, None, None), # C dimension not equal to num_classes - (rand(size=(7, 4, 3)), randint(high=4, size=(7, 3)), 0.5, 7, None, 1), + (rand(size=(7, 4, 3)), randint(high=4, size=(7, 3)), 0.5, 7, None, None), # Max target larger than num_classes (with #dim preds = 1 + #dims target) - (rand(size=(7, 3, 4)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, 1), + (rand(size=(7, 3, 4)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, None), # Max target larger than num_classes (with #dim preds = #dims target) - (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, 1), + (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, None), # Max preds larger than num_classes (with #dim preds = #dims target) - (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 0.5, 4, None, 1), + (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 0.5, 4, None, None), # Num_classes=1, but is_multiclass not false (randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, 1, None, 1), # is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes - (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, 1), + (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, None), # Multilabel input with implied class dimension != num_classes - (rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, 1), + (rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, None), # Multilabel input with is_multiclass=True, but num_classes != 2 (or None) - (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 0.5, 4, True, 1), + (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 0.5, 4, True, None), # Binary input, num_classes > 2 - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 4, None, 1), + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 4, None, None), # Binary input, num_classes == 2 and is_multiclass not True - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, None, 1), - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, False, 1), + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, None, None), + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, False, None), # Binary input, num_classes == 1 and is_multiclass=True - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 1, True, 1), - # Topk > 1 with non (md)mc prob data + (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 1, True, None), + # Topk set with non (md)mc prob data (_bin.preds[0], _bin.target[0], 0.5, None, None, 2), (_bin_prob.preds[0], _bin_prob.target[0], 0.5, None, None, 2), (_mc.preds[0], _mc.target[0], 0.5, None, None, 2), @@ -253,8 +253,10 @@ def test_threshold(): (_ml_prob.preds[0], _ml_prob.target[0], 0.5, None, None, 2), (_mlmd_prob.preds[0], _mlmd_prob.target[0], 0.5, None, None, 2), (_mdmc.preds[0], _mdmc.target[0], 0.5, None, None, 2), - # Topk =2 with 2 classes, is_multiclass=False + # top_k =2 with 2 classes, is_multiclass=False (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], 0.5, None, False, 2), + # top_k = number of classes (C dimension) + (_mc_prob.preds[0], _mc_prob.target[0], 0.5, None, None, NUM_CLASSES), ], ) def test_incorrect_inputs(preds, target, threshold, num_classes, is_multiclass, top_k): From c28aadf7e9774419326462949d5f371812c6bed4 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sat, 28 Nov 2020 16:23:27 +0100 Subject: [PATCH 35/94] Extract basic validation into separate function --- .../metrics/classification/utils.py | 64 +++++++++++-------- 1 file changed, 37 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index af4319152942f..eda929fdd32ed 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -18,6 +18,37 @@ from pytorch_lightning.metrics.utils import to_onehot, select_topk +def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool): + """ + Perform basic validation of inputs that does not require deducing any information + of the type of inputs. + """ + + if target.is_floating_point(): + raise ValueError("The `target` has to be an integer tensor.") + if target.min() < 0: + raise ValueError("The `target` has to be a non-negative tensor.") + + preds_float = preds.is_floating_point() + if not preds_float and preds.min() < 0: + raise ValueError("If `preds` are integers, they have to be non-negative.") + + if not preds.shape[0] == target.shape[0]: + raise ValueError("The `preds` and `target` should have the same first dimension.") + + if preds_float and (preds.min() < 0 or preds.max() > 1): + raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") + + if threshold > 1 or threshold < 0: + raise ValueError("The `threshold` should be a probability in [0,1].") + + if is_multiclass is False and target.max() > 1: + raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") + + if is_multiclass is False and not preds_float and preds.max() > 1: + raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") + + def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]: """ This checks that the shape and type of inputs are consistent with @@ -88,12 +119,12 @@ def _check_num_classes_binary(num_classes: int, is_multiclass: bool): if num_classes > 2: raise ValueError("Your data is binary, but `num_classes` is larger than 2.") - elif num_classes == 2 and not is_multiclass: + if num_classes == 2 and not is_multiclass: raise ValueError( "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." " Set it to True if you want to transform binary data to multi-class format." ) - elif num_classes == 1 and is_multiclass: + if num_classes == 1 and is_multiclass: raise ValueError( "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." " Either set `is_multiclass=None`(default) or set `num_classes=2` to transform binary data to multi-class format." @@ -114,7 +145,7 @@ def _check_num_classes_mc( " If you want to convert (multi-dimensional) multi-class data with 2 classes" " to binary/multi-label, set `is_multiclass=False`." ) - elif num_classes > 1: + if num_classes > 1: if is_multiclass is False: if implied_classes != num_classes: raise ValueError( @@ -210,29 +241,8 @@ def _check_classification_inputs( 'multi-dim multi-class' """ - if target.is_floating_point(): - raise ValueError("The `target` has to be an integer tensor.") - if target.min() < 0: - raise ValueError("The `target` has to be a non-negative tensor.") - - preds_float = preds.is_floating_point() - if not preds_float and preds.min() < 0: - raise ValueError("If `preds` are integers, they have to be non-negative.") - - if not preds.shape[0] == target.shape[0]: - raise ValueError("The `preds` and `target` should have the same first dimension.") - - if preds_float and (preds.min() < 0 or preds.max() > 1): - raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") - - if threshold > 1 or threshold < 0: - raise ValueError("The `threshold` should be a probability in [0,1].") - - if is_multiclass is False and target.max() > 1: - raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") - - if is_multiclass is False and not preds_float and preds.max() > 1: - raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") + # Baisc validation (that does not need case/type information) + _basic_input_validation(preds, target, threshold, is_multiclass) # Check that shape/types fall into one of the cases case, implied_classes = _check_shape_and_type_consistency(preds, target) @@ -260,7 +270,7 @@ def _check_classification_inputs( # Check that top_k is consistent if top_k: - _check_top_k(top_k, case, implied_classes, is_multiclass, preds_float) + _check_top_k(top_k, case, implied_classes, is_multiclass, preds.is_floating_point()) return case From 323285e0899893049772bbff1fc339fba213cf4e Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sat, 28 Nov 2020 16:38:10 +0100 Subject: [PATCH 36/94] Update to new top_k default --- .../metrics/classification/accuracy.py | 26 +++++++++-------- .../metrics/functional/accuracy.py | 28 +++++++++++-------- tests/metrics/classification/test_accuracy.py | 8 +++--- 3 files changed, 34 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 28b6e26a1a986..49faa1d62bdf9 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -45,12 +45,9 @@ class has to be correctly predicted across all extra dimension for each sample i Accepts all input types listed in :ref:`metrics:Input types`. Args: - top_k: - Number of highest probability predictions considered to find the correct label, for - (multi-dimensional) multi-class inputs with probability predictions. Default 1 - - If your inputs are not (multi-dimensional) multi-class inputs with probability predictions, - an error will be raised if ``top_k`` is set to a value other than 1. + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 mdmc_accuracy: Determines how should the extra dimension be handeled in case of multi-dimensional multi-class inputs. Options are ``"global"`` or ``"subset"``. @@ -61,9 +58,12 @@ class has to be correctly predicted across all extra dimension for each sample i If ``"subset"``, than the equivalent of subset accuracy is performed for each sample on the ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension must be predicted correctly (the ``top_k`` option still applies here). - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + top_k: + Number of highest probability entries for each sample to convert to 1s, relevant + only for (multi-dimensional) multi-class inputs with probability predictions. The + default value (``None``) will be interpreted as 1 for these inputs. + + Should be left at default (``None``) for all other types of inputs. compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -94,9 +94,9 @@ class has to be correctly predicted across all extra dimension for each sample i def __init__( self, - top_k: int = 1, - mdmc_accuracy: str = "subset", threshold: float = 0.5, + mdmc_accuracy: str = "subset", + top_k: Optional[int] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -126,7 +126,9 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): target: Ground truth values """ - correct, total = _accuracy_update(preds, target, self.threshold, self.top_k, self.mdmc_accuracy) + correct, total = _accuracy_update( + preds, target, threshold=self.threshold, top_k=self.top_k, mdmc_accuracy=self.mdmc_accuracy + ) self.correct += correct self.total += total diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index fa423c611a450..47bc863a58c41 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Tuple, Union, Optional import torch from pytorch_lightning.metrics.classification.utils import _input_format_classification @@ -22,7 +22,7 @@ def _accuracy_update( - preds: torch.Tensor, target: torch.Tensor, threshold: float, top_k: int, mdmc_accuracy: str + preds: torch.Tensor, target: torch.Tensor, threshold: float, top_k: Optional[int], mdmc_accuracy: str ) -> Tuple[torch.Tensor, torch.Tensor]: preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k) @@ -49,7 +49,11 @@ def _accuracy_compute(correct: torch.Tensor, total: torch.Tensor) -> torch.Tenso def accuracy( - preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5, top_k: int = 1, mdmc_accuracy: str = "subset" + preds: torch.Tensor, + target: torch.Tensor, + threshold: float = 0.5, + mdmc_accuracy: str = "subset", + top_k: Optional[int] = None, ) -> torch.Tensor: """ Computes the share of entirely correctly predicted samples. @@ -69,12 +73,9 @@ class has to be correctly predicted across all extra dimension for each sample i Args: preds: Predictions from model (probabilities, or labels) target: Ground truth values - top_k: - Number of highest probability predictions considered to find the correct label, for - (multi-dimensional) multi-class inputs with probability predictions. Default 1 - - If your inputs are not (multi-dimensional) multi-class inputs with probability predictions, - an error will be raised if ``top_k`` is set to a value other than 1. + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 mdmc_accuracy: Determines how should the extra dimension be handeled in case of multi-dimensional multi-class inputs. Options are ``"global"`` or ``"subset"``. @@ -85,9 +86,12 @@ class has to be correctly predicted across all extra dimension for each sample i If ``"subset"``, than the equivalent of subset accuracy is performed for each sample on the ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension must be predicted correctly (the ``top_k`` option still applies here). - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + top_k: + Number of highest probability entries for each sample to convert to 1s, relevant + only for (multi-dimensional) multi-class inputs with probability predictions. The + default value (``None``) will be interpreted as 1 for these inputs. + + Should be left at default (``None``) for all other types of inputs. Example: diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index 8d8386683748c..c301f2503f0f3 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -81,7 +81,7 @@ def test_accuracy_fn(self, preds, target, metric, sk_metric, fn_metric): ) -l1to4 = [.1, .2, .3, .4] +l1to4 = [0.1, 0.2, 0.3, 0.4] l1to4t3 = np.array([l1to4, l1to4, l1to4]) l1to4t3_mc = [l1to4t3.T, l1to4t3.T, l1to4t3.T] @@ -130,7 +130,7 @@ def test_topk_accuracy(preds, target, exp_result, k, mdmc_accuracy): assert accuracy(preds, target, top_k=k, mdmc_accuracy=mdmc_accuracy) == exp_result -# Only MC and MDMC with probs input type should be accepted +# Only MC and MDMC with probs input type should be accepted for top_k @pytest.mark.parametrize( "preds, target", [ @@ -145,10 +145,10 @@ def test_topk_accuracy(preds, target, exp_result, k, mdmc_accuracy): ], ) def test_topk_accuracy_wrong_input_types(preds, target): - topk = Accuracy(top_k=2) + topk = Accuracy(top_k=1) with pytest.raises(ValueError): topk(preds[0], target[0]) with pytest.raises(ValueError): - accuracy(preds[0], target[0], top_k=2) + accuracy(preds[0], target[0], top_k=1) From 0cb0eac3b1602ae9cacf094aebb10c20ec01bc90 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sun, 29 Nov 2020 18:33:11 +0100 Subject: [PATCH 37/94] Update desciption of parameters in input formatting --- .../metrics/classification/utils.py | 72 ++++++++++++++----- 1 file changed, 54 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index eda929fdd32ed..8c4c6b6eb94a9 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -223,18 +223,37 @@ def _check_classification_inputs( greater than 1, except perhaps the first one (N). Args: - preds: tensor with predictions - target: tensor with ground truth labels, always integers + preds: Tensor with predictions (labels or probabilities) + target: Tensor with ground truth labels, always integers (labels) threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 - num_classes: number of classes - is_multiclass: if True, treat binary and multi-label inputs as multi-class or multi-dim - multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim - multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. - Defaults to None, which treats inputs as they appear. - top_k: number of highest probability entries for each sample to convert to 1s, relevant - only for (multi-dimensional) multi-class cases. + num_classes: + Number of classes. If not explicitly set, the number of classes will be infered + either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` + tensor, where applicable. + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for (multi-dimensional) multi-class inputs with probability predictions. The + default value (``None``) will be interepreted as 1 for these inputs. + + Should be left unset (``None``) for all other types of inputs. + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be (see :ref:`metrics: Input types` documentation section for + input classification and examples of the use of this parameter). Should be left at default + value (``None``) in most cases. + + The special cases where this parameter should be set are: + + - When you want to treat binary or multi-label inputs as multi-class or multi-dimensional + multi-class with 2 classes, respectively. The probabilities are interpreted as the + probability of the "1" class, and thresholding still applies as usual. In this case + the parameter should be set to ``True``. + - When you want to treat multi-class or multi-dimensional mulit-class inputs with 2 classes + as binary or multi-label inputs, respectively. This is mainly meant for the case when + inputs are labels, but will work if they are probabilities as well. For this case the + parameter should be set to ``False``. Return: case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or @@ -335,21 +354,38 @@ def _input_format_classification( target. Args: - preds: tensor with predictions - target: tensor with ground truth labels, always integers + preds: Tensor with predictions (labels or probabilities) + target: Tensor with ground truth labels, always integers (labels) threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 - num_classes: number of classes - top_k: number of highest probability entries for each sample to convert to 1s, relevant + num_classes: + Number of classes. If not explicitly set, the number of classes will be infered + either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` + tensor, where applicable. + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant only for (multi-dimensional) multi-class inputs with probability predictions. The - default value (``None``) will be interepreted as one for these inputs. + default value (``None``) will be interepreted as 1 for these inputs. Should be left unset (``None``) for all other types of inputs. - is_multiclass: if True, treat binary and multi-label inputs as multi-class or multi-dim - multi-class with 2 classes, respectively. If False, treat multi-class and multi-dim - multi-class inputs with 1 or 2 classes as binary and multi-label, respectively. - Defaults to None, which treats inputs as they appear. + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be (see :ref:`metrics: Input types` documentation section for + input classification and examples of the use of this parameter). Should be left at default + value (``None``) in most cases. + + The special cases where this parameter should be set are: + + - When you want to treat binary or multi-label inputs as multi-class or multi-dimensional + multi-class with 2 classes, respectively. The probabilities are interpreted as the + probability of the "1" class, and thresholding still applies as usual. In this case + the parameter should be set to ``True``. + - When you want to treat multi-class or multi-dimensional mulit-class inputs with 2 classes + as binary or multi-label inputs, respectively. This is mainly meant for the case when + inputs are labels, but will work if they are probabilities as well. For this case the + parameter should be set to ``False``. + Returns: preds: binary tensor of shape (N, C) or (N, C, X) From 8e7a85a78568160f8c6e94bf01dbd76c8c064e5c Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 30 Nov 2020 23:27:03 +0100 Subject: [PATCH 38/94] Apply suggestions from code review Co-authored-by: Nicki Skafte --- pytorch_lightning/metrics/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 170315aa22236..c8505bba58e3b 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -121,7 +121,7 @@ def _input_format_classification( # multi class probabilites preds = torch.argmax(preds, dim=1) - if preds.ndim == target.ndim and preds.dtype == torch.float: + if preds.ndim == target.ndim and preds.is_floating_point(): # binary or multilabel probablities preds = (preds >= threshold).long() return preds, target @@ -151,12 +151,12 @@ def _input_format_classification_one_hot( # multi class probabilites preds = torch.argmax(preds, dim=1) - if preds.ndim == target.ndim and preds.dtype == torch.long and num_classes > 1 and not multilabel: + if preds.ndim == target.ndim and preds.dtype in (torch.long, torch.int) and num_classes > 1 and not multilabel: # multi-class preds = to_onehot(preds, num_classes=num_classes) target = to_onehot(target, num_classes=num_classes) - elif preds.ndim == target.ndim and preds.dtype == torch.float: + elif preds.ndim == target.ndim and preds.is_floating_point(): # binary or multilabel probablities preds = (preds >= threshold).long() From 829155efcf09e542cded5075cf2b543e588bc62f Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 1 Dec 2020 00:06:02 +0100 Subject: [PATCH 39/94] Check that probabilities in preds sum to 1 (for MC) --- .../metrics/classification/utils.py | 5 +++++ tests/metrics/classification/inputs.py | 9 +++++++-- tests/metrics/classification/test_inputs.py | 18 +++++++++++++----- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/utils.py index 8c4c6b6eb94a9..7cfc505e6c673 100644 --- a/pytorch_lightning/metrics/classification/utils.py +++ b/pytorch_lightning/metrics/classification/utils.py @@ -266,6 +266,11 @@ def _check_classification_inputs( # Check that shape/types fall into one of the cases case, implied_classes = _check_shape_and_type_consistency(preds, target) + # For (multi-dim) multi-class case with prob preds, check that preds sum up to 1 + if "multi-class" in case and preds.is_floating_point(): + if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all(): + raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.") + # Check consistency with the `C` dimension in case of multi-class data if preds.shape != target.shape: if is_multiclass is False and implied_classes != 2: diff --git a/tests/metrics/classification/inputs.py b/tests/metrics/classification/inputs.py index 48d3e85e3afeb..9f70a80cd31a4 100644 --- a/tests/metrics/classification/inputs.py +++ b/tests/metrics/classification/inputs.py @@ -53,8 +53,11 @@ target=__temp_target ) +__mc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) +__mc_prob_preds = __mc_prob_preds / __mc_prob_preds.sum(dim=2, keepdim=True) + _multiclass_prob_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), + preds=__mc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) ) @@ -64,9 +67,11 @@ target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) ) +__mdmc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM) +__mdmc_prob_preds = __mdmc_prob_preds / __mdmc_prob_preds.sum(dim=2, keepdim=True) _multidim_multiclass_prob_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), + preds=__mdmc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 058ec66c10ed6..8fe3c01fe77c3 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -22,14 +22,20 @@ torch.manual_seed(42) # Some additional inputs to test on -_mc_prob_2cls = Input(rand(NUM_BATCHES, BATCH_SIZE, 2), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))) +_mc_prob_2cls_preds = rand(NUM_BATCHES, BATCH_SIZE, 2) +_mc_prob_2cls_preds /= _mc_prob_2cls_preds.sum(dim=2, keepdim=True) +_mc_prob_2cls = Input(_mc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))) + +_mdmc_prob_many_dims_preds = rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, EXTRA_DIM) +_mdmc_prob_many_dims_preds /= _mdmc_prob_many_dims_preds.sum(dim=2, keepdim=True) _mdmc_prob_many_dims = Input( - rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, EXTRA_DIM), + _mdmc_prob_many_dims_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)), ) -_mdmc_prob_2cls = Input( - rand(NUM_BATCHES, BATCH_SIZE, 2, EXTRA_DIM), randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) -) + +_mdmc_prob_2cls_preds = rand(NUM_BATCHES, BATCH_SIZE, 2, EXTRA_DIM) +_mdmc_prob_2cls_preds /= _mdmc_prob_2cls_preds.sum(dim=2, keepdim=True) +_mdmc_prob_2cls = Input(_mdmc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM))) # Some utils T = torch.Tensor @@ -219,6 +225,8 @@ def test_threshold(): (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, None), # is_multiclass=False, with C dimension > 2 (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, False, None), + # Probs of multiclass preds do not sum up to 1 + (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, None, None), # Max target larger or equal to C dimension (rand(size=(7, 3)), randint(low=4, high=6, size=(7,)), 0.5, None, None, None), # C dimension not equal to num_classes From 768879db9b13c911b39f9f6ddc9e2f208d2279a7 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 1 Dec 2020 00:21:31 +0100 Subject: [PATCH 40/94] Fix coverage --- tests/metrics/classification/test_inputs.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 8fe3c01fe77c3..6b5a03fcf1ea6 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -224,15 +224,22 @@ def test_threshold(): # #dims in preds = 1 + #dims in target, preds not float (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, None), # is_multiclass=False, with C dimension > 2 - (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, False, None), + (_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE,)), 0.5, None, False, None), # Probs of multiclass preds do not sum up to 1 (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, None, None), # Max target larger or equal to C dimension - (rand(size=(7, 3)), randint(low=4, high=6, size=(7,)), 0.5, None, None, None), + (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE,)), 0.5, None, None, None), # C dimension not equal to num_classes - (rand(size=(7, 4, 3)), randint(high=4, size=(7, 3)), 0.5, 7, None, None), + (_mc_prob.preds[0], _mc_prob.target[0], 0.5, NUM_CLASSES + 1, None, None), # Max target larger than num_classes (with #dim preds = 1 + #dims target) - (rand(size=(7, 3, 4)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, None), + ( + _mc_prob.preds[0], + randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, NUM_CLASSES)), + 0.5, + 4, + None, + None, + ), # Max target larger than num_classes (with #dim preds = #dims target) (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, None), # Max preds larger than num_classes (with #dim preds = #dims target) From eeded458b94f477686d70341e84f8e904df7a7cf Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 1 Dec 2020 12:45:10 +0100 Subject: [PATCH 41/94] Split accuracy and hamming loss --- .../metrics/classification/__init__.py | 3 +- .../metrics/classification/accuracy.py | 85 +--------------- .../metrics/classification/hamming_loss.py | 96 +++++++++++++++++++ .../metrics/functional/__init__.py | 3 +- .../metrics/functional/accuracy.py | 54 +---------- .../metrics/functional/hamming_loss.py | 60 ++++++++++++ tests/metrics/classification/test_accuracy.py | 29 ++---- .../classification/test_hamming_loss.py | 69 +++++++++++++ 8 files changed, 240 insertions(+), 159 deletions(-) create mode 100644 pytorch_lightning/metrics/classification/hamming_loss.py create mode 100644 pytorch_lightning/metrics/functional/hamming_loss.py create mode 100644 tests/metrics/classification/test_hamming_loss.py diff --git a/pytorch_lightning/metrics/classification/__init__.py b/pytorch_lightning/metrics/classification/__init__.py index 45d4dd03e430e..b52e6ab21823a 100644 --- a/pytorch_lightning/metrics/classification/__init__.py +++ b/pytorch_lightning/metrics/classification/__init__.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.metrics.classification.accuracy import Accuracy, HammingLoss +from pytorch_lightning.metrics.classification.accuracy import Accuracy +from pytorch_lightning.metrics.classification.hamming_loss import HammingLoss from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall from pytorch_lightning.metrics.classification.f_beta import FBeta, F1 from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 49faa1d62bdf9..1c772aec4d88f 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -15,12 +15,7 @@ import torch from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.metrics.functional.accuracy import ( - _accuracy_update, - _hamming_loss_update, - _accuracy_compute, - _hamming_loss_compute, -) +from pytorch_lightning.metrics.functional.accuracy import _accuracy_update, _accuracy_compute class Accuracy(Metric): @@ -138,81 +133,3 @@ def compute(self) -> torch.Tensor: Computes accuracy based on inputs passed in to ``update`` previously. """ return _accuracy_compute(self.correct, self.total) - - -class HammingLoss(Metric): - """ - Computes the share of wrongly predicted labels. - - This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it - treats each possible label separately - meaning that, for example, multi-class data is - treated as if it were multi-label. If this is not what you want, consider using - :class:`~pytorch_lightning.metrics.classification.Accuracy`. - - Accepts all input types listed in :ref:`metrics:Input types`. - - Args: - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. default: None - - Example: - - >>> from pytorch_lightning.metrics import HammingLoss - >>> target = torch.tensor([[0, 1], [1, 1]]) - >>> preds = torch.tensor([[0, 1], [0, 1]]) - >>> hamming_loss = HammingLoss() - >>> hamming_loss(preds, target) - tensor(0.2500) - - """ - - def __init__( - self, - threshold: float = 0.5, - compute_on_step: bool = True, - dist_sync_on_step: bool = False, - process_group: Optional[Any] = None, - dist_sync_fn: Callable = None, - ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - self.threshold = threshold - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. See :ref:`metrics:Input types` for more information - on input types. - - Args: - preds: Predictions from model (probabilities, or labels) - target: Ground truth values - """ - correct, total = _hamming_loss_update(preds, target, self.threshold) - - self.correct += correct - self.total += total - - def compute(self) -> torch.Tensor: - """ - Computes hamming loss based on inputs passed in to ``update`` previously. - """ - return _hamming_loss_compute(self.correct, self.total) diff --git a/pytorch_lightning/metrics/classification/hamming_loss.py b/pytorch_lightning/metrics/classification/hamming_loss.py new file mode 100644 index 0000000000000..5432a03c8403b --- /dev/null +++ b/pytorch_lightning/metrics/classification/hamming_loss.py @@ -0,0 +1,96 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +import torch +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.functional.hamming_loss import _hamming_loss_update, _hamming_loss_compute + + +class HammingLoss(Metric): + """ + Computes the share of wrongly predicted labels. + + This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it + treats each possible label separately - meaning that, for example, multi-class data is + treated as if it were multi-label. If this is not what you want, consider using + :class:`~pytorch_lightning.metrics.classification.Accuracy`. + + Accepts all input types listed in :ref:`metrics:Input types`. + + Args: + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None + + Example: + + >>> from pytorch_lightning.metrics import HammingLoss + >>> target = torch.tensor([[0, 1], [1, 1]]) + >>> preds = torch.tensor([[0, 1], [0, 1]]) + >>> hamming_loss = HammingLoss() + >>> hamming_loss(preds, target) + tensor(0.2500) + + """ + + def __init__( + self, + threshold: float = 0.5, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + self.threshold = threshold + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. See :ref:`metrics:Input types` for more information + on input types. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth values + """ + correct, total = _hamming_loss_update(preds, target, self.threshold) + + self.correct += correct + self.total += total + + def compute(self) -> torch.Tensor: + """ + Computes hamming loss based on inputs passed in to ``update`` previously. + """ + return _hamming_loss_compute(self.correct, self.total) diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 42029335afe9f..fc0d2f00b985a 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -42,6 +42,7 @@ from pytorch_lightning.metrics.functional.psnr import psnr from pytorch_lightning.metrics.functional.ssim import ssim -from pytorch_lightning.metrics.functional.accuracy import accuracy, hamming_loss +from pytorch_lightning.metrics.functional.accuracy import accuracy +from pytorch_lightning.metrics.functional.hamming_loss import hamming_loss from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 47bc863a58c41..f029f1a399870 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -11,15 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union, Optional +from typing import Tuple, Optional import torch from pytorch_lightning.metrics.classification.utils import _input_format_classification -################################ -# Accuracy -################################ - def _accuracy_update( preds: torch.Tensor, target: torch.Tensor, threshold: float, top_k: Optional[int], mdmc_accuracy: str @@ -109,51 +105,3 @@ class has to be correctly predicted across all extra dimension for each sample i correct, total = _accuracy_update(preds, target, threshold, top_k, mdmc_accuracy) return _accuracy_compute(correct, total) - - -################################ -# Hamming loss -################################ - - -def _hamming_loss_update(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> Tuple[torch.Tensor, int]: - preds, target, _ = _input_format_classification(preds, target, threshold=threshold) - - correct = (preds == target).sum() - total = preds.numel() - - return correct, total - - -def _hamming_loss_compute(correct: torch.Tensor, total: Union[int, torch.Tensor]) -> torch.Tensor: - return 1 - correct.float() / total - - -def hamming_loss(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: - """ - Computes the share of wrongly predicted labels. - - This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it - treats each possible label separately - meaning that, for example, multi-class data is - treated as if it were multi-label. If this is not what you want, consider using - :class:`~pytorch_lightning.metrics.classification.Accuracy`. - - Accepts all input types listed in :ref:`metrics:Input types`. - - Args: - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 - - Example: - - >>> from pytorch_lightning.metrics.functional import hamming_loss - >>> target = torch.tensor([[0, 1], [1, 1]]) - >>> preds = torch.tensor([[0, 1], [0, 1]]) - >>> hamming_loss(preds, target) - tensor(0.2500) - - """ - - correct, total = _hamming_loss_update(preds, target, threshold) - return _hamming_loss_compute(correct, total) diff --git a/pytorch_lightning/metrics/functional/hamming_loss.py b/pytorch_lightning/metrics/functional/hamming_loss.py new file mode 100644 index 0000000000000..08c7c9056a960 --- /dev/null +++ b/pytorch_lightning/metrics/functional/hamming_loss.py @@ -0,0 +1,60 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Union + +import torch +from pytorch_lightning.metrics.classification.utils import _input_format_classification + + +def _hamming_loss_update(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> Tuple[torch.Tensor, int]: + preds, target, _ = _input_format_classification(preds, target, threshold=threshold) + + correct = (preds == target).sum() + total = preds.numel() + + return correct, total + + +def _hamming_loss_compute(correct: torch.Tensor, total: Union[int, torch.Tensor]) -> torch.Tensor: + return 1 - correct.float() / total + + +def hamming_loss(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: + """ + Computes the share of wrongly predicted labels. + + This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it + treats each possible label separately - meaning that, for example, multi-class data is + treated as if it were multi-label. If this is not what you want, consider using + :class:`~pytorch_lightning.metrics.classification.Accuracy`. + + Accepts all input types listed in :ref:`metrics:Input types`. + + Args: + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + + Example: + + >>> from pytorch_lightning.metrics.functional import hamming_loss + >>> target = torch.tensor([[0, 1], [1, 1]]) + >>> preds = torch.tensor([[0, 1], [0, 1]]) + >>> hamming_loss(preds, target) + tensor(0.2500) + + """ + + correct, total = _hamming_loss_update(preds, target, threshold) + return _hamming_loss_compute(correct, total) diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index c301f2503f0f3..9cbf5ec855885 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -1,10 +1,10 @@ import numpy as np import pytest import torch -from sklearn.metrics import accuracy_score as sk_accuracy, hamming_loss as sk_hamming_loss +from sklearn.metrics import accuracy_score as sk_accuracy -from pytorch_lightning.metrics import Accuracy, HammingLoss -from pytorch_lightning.metrics.functional import accuracy, hamming_loss +from pytorch_lightning.metrics import Accuracy +from pytorch_lightning.metrics.functional import accuracy from pytorch_lightning.metrics.classification.utils import _input_format_classification from tests.metrics.classification.inputs import ( _binary_inputs, @@ -31,17 +31,6 @@ def _sk_accuracy(preds, target): return sk_accuracy(y_true=sk_target, y_pred=sk_preds) -def _sk_hamming_loss(preds, target): - sk_preds, sk_target, _ = _input_format_classification(preds, target, threshold=THRESHOLD) - sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - sk_preds, sk_target = sk_preds.reshape(sk_preds.shape[0], -1), sk_target.reshape(sk_target.shape[0], -1) - - return sk_hamming_loss(y_true=sk_target, y_pred=sk_preds) - - -@pytest.mark.parametrize( - "metric, fn_metric, sk_metric", [(Accuracy, accuracy, _sk_accuracy), (HammingLoss, hamming_loss, _sk_hamming_loss)] -) @pytest.mark.parametrize( "preds, target", [ @@ -60,23 +49,23 @@ def _sk_hamming_loss(preds, target): class TestAccuracies(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, metric, sk_metric, fn_metric): + def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target): self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=metric, - sk_metric=sk_metric, + metric_class=Accuracy, + sk_metric=_sk_accuracy, dist_sync_on_step=dist_sync_on_step, metric_args={"threshold": THRESHOLD}, ) - def test_accuracy_fn(self, preds, target, metric, sk_metric, fn_metric): + def test_accuracy_fn(self, preds, target): self.run_functional_metric_test( preds, target, - metric_functional=fn_metric, - sk_metric=sk_metric, + metric_functional=accuracy, + sk_metric=_sk_accuracy, metric_args={"threshold": THRESHOLD}, ) diff --git a/tests/metrics/classification/test_hamming_loss.py b/tests/metrics/classification/test_hamming_loss.py new file mode 100644 index 0000000000000..1d30a7cdabc12 --- /dev/null +++ b/tests/metrics/classification/test_hamming_loss.py @@ -0,0 +1,69 @@ +import pytest +import torch +from sklearn.metrics import hamming_loss as sk_hamming_loss + +from pytorch_lightning.metrics import HammingLoss +from pytorch_lightning.metrics.functional import hamming_loss +from pytorch_lightning.metrics.classification.utils import _input_format_classification +from tests.metrics.classification.inputs import ( + _binary_inputs, + _binary_prob_inputs, + _multiclass_inputs, + _multiclass_prob_inputs, + _multidim_multiclass_inputs, + _multidim_multiclass_prob_inputs, + _multilabel_inputs, + _multilabel_prob_inputs, + _multilabel_multidim_prob_inputs, + _multilabel_multidim_inputs, +) +from tests.metrics.utils import THRESHOLD, MetricTester + +torch.manual_seed(42) + + +def _sk_hamming_loss(preds, target): + sk_preds, sk_target, _ = _input_format_classification(preds, target, threshold=THRESHOLD) + sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() + sk_preds, sk_target = sk_preds.reshape(sk_preds.shape[0], -1), sk_target.reshape(sk_target.shape[0], -1) + + return sk_hamming_loss(y_true=sk_target, y_pred=sk_preds) + + +@pytest.mark.parametrize( + "preds, target", + [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target), + (_binary_inputs.preds, _binary_inputs.target), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target), + (_multilabel_inputs.preds, _multilabel_inputs.target), + (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target), + (_multiclass_inputs.preds, _multiclass_inputs.target), + (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target), + (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target), + (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target), + (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target), + ], +) +class TestAccuracies(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [False, True]) + def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=HammingLoss, + sk_metric=_sk_hamming_loss, + dist_sync_on_step=dist_sync_on_step, + metric_args={"threshold": THRESHOLD}, + ) + + def test_accuracy_fn(self, preds, target): + self.run_functional_metric_test( + preds, + target, + metric_functional=hamming_loss, + sk_metric=_sk_hamming_loss, + metric_args={"threshold": THRESHOLD}, + ) From b49cfdc664fa4d826f3b31521e99f93c94297703 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 1 Dec 2020 12:56:33 +0100 Subject: [PATCH 42/94] Remove old redundant accuracy --- .../metrics/functional/classification.py | 318 +++++------- .../metrics/functional/test_classification.py | 471 ++++++++++-------- 2 files changed, 387 insertions(+), 402 deletions(-) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 75eeeca3b8e17..cefea07c67f15 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -22,8 +22,8 @@ def to_onehot( - tensor: torch.Tensor, - num_classes: Optional[int] = None, + tensor: torch.Tensor, + num_classes: Optional[int] = None, ) -> torch.Tensor: """ Converts a dense label tensor to one-hot format @@ -47,16 +47,12 @@ def to_onehot( if num_classes is None: num_classes = int(tensor.max().detach().item() + 1) dtype, device, shape = tensor.dtype, tensor.device, tensor.shape - tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], - dtype=dtype, device=device) + tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], dtype=dtype, device=device) index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) return tensor_onehot.scatter_(1, index, 1.0) -def to_categorical( - tensor: torch.Tensor, - argmax_dim: int = 1 -) -> torch.Tensor: +def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: """ Converts a tensor of probabilities to a dense label tensor @@ -78,9 +74,9 @@ def to_categorical( def get_num_classes( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, ) -> int: """ Calculates the number of classes for a given prediction and target tensor. @@ -100,17 +96,20 @@ def get_num_classes( if num_classes is None: num_classes = num_all_classes elif num_classes != num_all_classes: - rank_zero_warn(f'You have set {num_classes} number of classes which is' - f' different from predicted ({num_pred_classes}) and' - f' target ({num_target_classes}) number of classes', - RuntimeWarning) + rank_zero_warn( + f"You have set {num_classes} number of classes which is" + f" different from predicted ({num_pred_classes}) and" + f" target ({num_target_classes}) number of classes", + RuntimeWarning, + ) return num_classes def stat_scores( - pred: torch.Tensor, - target: torch.Tensor, - class_index: int, argmax_dim: int = 1, + pred: torch.Tensor, + target: torch.Tensor, + class_index: int, + argmax_dim: int = 1, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Calculates the number of true positive, false positive, true negative @@ -148,11 +147,11 @@ def stat_scores( def stat_scores_multiple_classes( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - argmax_dim: int = 1, - reduction: str = 'none', + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + argmax_dim: int = 1, + reduction: str = "none", ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Calculates the number of true positive, false positive, true negative @@ -201,13 +200,13 @@ def stat_scores_multiple_classes( if target.dtype != torch.bool: target = target.clamp_max(max=num_classes) - possible_reductions = ('none', 'sum', 'elementwise_mean') + possible_reductions = ("none", "sum", "elementwise_mean") if reduction not in possible_reductions: raise ValueError("reduction type %s not supported" % reduction) - if reduction == 'none': - pred = pred.view((-1, )).long() - target = target.view((-1, )).long() + if reduction == "none": + pred = pred.view((-1,)).long() + target = target.view((-1,)).long() tps = torch.zeros((num_classes + 1,), device=pred.device) fps = torch.zeros((num_classes + 1,), device=pred.device) @@ -230,7 +229,7 @@ def stat_scores_multiple_classes( fns = fns[:num_classes] sups = sups[:num_classes] - elif reduction == 'sum' or reduction == 'elementwise_mean': + elif reduction == "sum" or reduction == "elementwise_mean": count_match_true = (pred == target).sum().float() oob_tp, oob_fp, oob_tn, oob_fn, oob_sup = stat_scores(pred, target, num_classes, argmax_dim) @@ -240,7 +239,7 @@ def stat_scores_multiple_classes( tns = pred.nelement() * (num_classes + 1) - (tps + fps + fns + oob_tn) sups = pred.nelement() - oob_sup.float() - if reduction == 'elementwise_mean': + if reduction == "elementwise_mean": tps /= num_classes fps /= num_classes fns /= num_classes @@ -250,64 +249,23 @@ def stat_scores_multiple_classes( return tps.float(), fps.float(), tns.float(), fns.float(), sups.float() -def accuracy( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', - return_state: bool = False -) -> torch.Tensor: - """ - Computes the accuracy classification score - - Args: - pred: predicted labels - target: ground truth labels - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'``: returns calculated metric per class - return_state: returns a internal state that can be ddp reduced - before doing the final calculation - - Return: - A Tensor with the accuracy score. - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> accuracy(x, y) - tensor(0.7500) - - """ - tps, fps, tns, fns, sups = stat_scores_multiple_classes( - pred=pred, target=target, num_classes=num_classes) - if return_state: - return {'tps': tps, 'sups': sups} - return class_reduce(tps, sups, sups, class_reduction=class_reduction) - - def _confmat_normalize(cm): """ Normalization function for confusion matrix """ cm = cm / cm.sum(-1, keepdim=True) nan_elements = cm[torch.isnan(cm)].nelement() if nan_elements != 0: cm[torch.isnan(cm)] = 0 - rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.') + rank_zero_warn(f"{nan_elements} nan values found in confusion matrix have been replaced with zeros.") return cm def precision_recall( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', - return_support: bool = False, - return_state: bool = False + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + class_reduction: str = "micro", + return_support: bool = False, + return_state: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes precision and recall for different thresholds @@ -343,17 +301,17 @@ def precision_recall( precision = class_reduce(tps, tps + fps, sups, class_reduction=class_reduction) recall = class_reduce(tps, tps + fns, sups, class_reduction=class_reduction) if return_state: - return {'tps': tps, 'fps': fps, 'fns': fns, 'sups': sups} + return {"tps": tps, "fps": fps, "fns": fns, "sups": sups} if return_support: return precision, recall, sups return precision, recall def precision( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + class_reduction: str = "micro", ) -> torch.Tensor: """ Computes precision score. @@ -380,15 +338,14 @@ def precision( tensor(0.7500) """ - return precision_recall(pred=pred, target=target, - num_classes=num_classes, class_reduction=class_reduction)[0] + return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[0] def recall( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + class_reduction: str = "micro", ) -> torch.Tensor: """ Computes recall score. @@ -414,15 +371,14 @@ def recall( >>> recall(x, y) tensor(0.7500) """ - return precision_recall(pred=pred, target=target, - num_classes=num_classes, class_reduction=class_reduction)[1] + return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[1] def _binary_clf_curve( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py @@ -441,7 +397,7 @@ def _binary_clf_curve( if sample_weight is not None: weight = sample_weight[desc_score_indices] else: - weight = 1. + weight = 1.0 # pred typically has many tied values. Here we extract # the indices associated with the distinct values. We also @@ -463,10 +419,10 @@ def _binary_clf_curve( def roc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. @@ -493,9 +449,7 @@ def roc( tensor([4, 3, 2, 1, 0]) """ - fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=pos_label) + fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) # Add an extra threshold position # to make sure that the curve starts at (0, 0) @@ -517,10 +471,10 @@ def roc( def multiclass_roc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - num_classes: Optional[int] = None, + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + num_classes: Optional[int] = None, ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """ Computes the Receiver Operating Characteristic (ROC) for multiclass predictors. @@ -554,17 +508,16 @@ def multiclass_roc( for c in range(num_classes): pred_c = pred[:, c] - class_roc_vals.append(roc(pred=pred_c, target=target, - sample_weight=sample_weight, pos_label=c)) + class_roc_vals.append(roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c)) return tuple(class_roc_vals) def precision_recall_curve( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Computes precision-recall pairs for different thresholds. @@ -591,9 +544,7 @@ def precision_recall_curve( tensor([1, 2, 3]) """ - fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=pos_label) + fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) precision = tps / (tps + fps) recall = tps / tps[-1] @@ -605,13 +556,9 @@ def precision_recall_curve( # need to call reversed explicitly, since including that to slice would # introduce negative strides that are not yet supported in pytorch - precision = torch.cat([reversed(precision[sl]), - torch.ones(1, dtype=precision.dtype, - device=precision.device)]) + precision = torch.cat([reversed(precision[sl]), torch.ones(1, dtype=precision.dtype, device=precision.device)]) - recall = torch.cat([reversed(recall[sl]), - torch.zeros(1, dtype=recall.dtype, - device=recall.device)]) + recall = torch.cat([reversed(recall[sl]), torch.zeros(1, dtype=recall.dtype, device=recall.device)]) thresholds = torch.tensor(reversed(thresholds[sl])) @@ -619,10 +566,10 @@ def precision_recall_curve( def multiclass_precision_recall_curve( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - num_classes: Optional[int] = None, + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + num_classes: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Computes precision-recall pairs for different thresholds given a multiclass scores. @@ -659,19 +606,14 @@ def multiclass_precision_recall_curve( for c in range(num_classes): pred_c = pred[:, c] - class_pr_vals.append(precision_recall_curve( - pred=pred_c, - target=target, - sample_weight=sample_weight, pos_label=c)) + class_pr_vals.append( + precision_recall_curve(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c) + ) return tuple(class_pr_vals) -def auc( - x: torch.Tensor, - y: torch.Tensor, - reorder: bool = True -) -> torch.Tensor: +def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True) -> torch.Tensor: """ Computes Area Under the Curve (AUC) using the trapezoidal rule @@ -692,14 +634,16 @@ def auc( >>> auc(x, y) tensor(4.) """ - direction = 1. + direction = 1.0 if reorder: - rank_zero_warn("The `reorder` parameter to `auc` has been deprecated and will be removed in v1.1" - " Note that when `reorder` is True, the unstable algorithm of torch.argsort is" - " used internally to sort 'x' which may in some cases cause inaccuracies" - " in the result.", - DeprecationWarning) + rank_zero_warn( + "The `reorder` parameter to `auc` has been deprecated and will be removed in v1.1" + " Note that when `reorder` is True, the unstable algorithm of torch.argsort is" + " used internally to sort 'x' which may in some cases cause inaccuracies" + " in the result.", + DeprecationWarning, + ) # can't use lexsort here since it is not implemented for torch order = torch.argsort(x) x, y = x[order], y[order] @@ -707,11 +651,12 @@ def auc( dx = x[1:] - x[:-1] if (dx < 0).any(): if (dx, 0).all(): - direction = -1. + direction = -1.0 else: # TODO: Update message on removing reorder - raise ValueError("Reorder is not turned on, and the 'x' array is" - f" neither increasing or decreasing: {x}") + raise ValueError( + "Reorder is not turned on, and the 'x' array is" f" neither increasing or decreasing: {x}" + ) return direction * torch.trapz(y, x) @@ -746,10 +691,10 @@ def new_func(*args, **kwargs) -> torch.Tensor: def auroc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1.0, ) -> torch.Tensor: """ Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores @@ -771,9 +716,11 @@ def auroc( tensor(0.5000) """ if any(target > 1): - raise ValueError('AUROC metric is meant for binary classification, but' - ' target tensor contains value different from 0 and 1.' - ' Use `multiclass_auroc` for multi class classification.') + raise ValueError( + "AUROC metric is meant for binary classification, but" + " target tensor contains value different from 0 and 1." + " Use `multiclass_auroc` for multi class classification." + ) @auc_decorator(reorder=True) def _auroc(pred, target, sample_weight, pos_label): @@ -783,10 +730,10 @@ def _auroc(pred, target, sample_weight, pos_label): def multiclass_auroc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - num_classes: Optional[int] = None, + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + num_classes: Optional[int] = None, ) -> torch.Tensor: """ Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from multiclass @@ -814,35 +761,36 @@ def multiclass_auroc( if not torch.allclose(pred.sum(dim=1), torch.tensor(1.0)): raise ValueError( "Multiclass AUROC metric expects the target scores to be" - " probabilities, i.e. they should sum up to 1.0 over classes") + " probabilities, i.e. they should sum up to 1.0 over classes" + ) if torch.unique(target).size(0) != pred.size(1): raise ValueError( f"Number of classes found in in 'target' ({torch.unique(target).size(0)})" f" does not equal the number of columns in 'pred' ({pred.size(1)})." " Multiclass AUROC is not defined when all of the classes do not" - " occur in the target labels.") + " occur in the target labels." + ) if num_classes is not None and num_classes != pred.size(1): raise ValueError( f"Number of classes deduced from 'pred' ({pred.size(1)}) does not equal" - f" the number of classes passed in 'num_classes' ({num_classes}).") + f" the number of classes passed in 'num_classes' ({num_classes})." + ) @multiclass_auc_decorator(reorder=False) def _multiclass_auroc(pred, target, sample_weight, num_classes): return multiclass_roc(pred, target, sample_weight, num_classes) - class_aurocs = _multiclass_auroc(pred=pred, target=target, - sample_weight=sample_weight, - num_classes=num_classes) + class_aurocs = _multiclass_auroc(pred=pred, target=target, sample_weight=sample_weight, num_classes=num_classes) return torch.mean(class_aurocs) def average_precision( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1.0, ) -> torch.Tensor: """ Compute average precision from prediction scores @@ -863,9 +811,9 @@ def average_precision( >>> average_precision(x, y) tensor(0.3333) """ - precision, recall, _ = precision_recall_curve(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=pos_label) + precision, recall, _ = precision_recall_curve( + pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label + ) # Return the step function integral # The following works because the last entry of precision is # guaranteed to be 1, as returned by precision_recall_curve @@ -873,12 +821,12 @@ def average_precision( def dice_score( - pred: torch.Tensor, - target: torch.Tensor, - bg: bool = False, - nan_score: float = 0.0, - no_fg_score: float = 0.0, - reduction: str = 'elementwise_mean', + pred: torch.Tensor, + target: torch.Tensor, + bg: bool = False, + nan_score: float = 0.0, + no_fg_score: float = 0.0, + reduction: str = "elementwise_mean", ) -> torch.Tensor: """ Compute dice score from prediction scores @@ -910,7 +858,7 @@ def dice_score( """ num_classes = pred.shape[1] - bg = (1 - int(bool(bg))) + bg = 1 - int(bool(bg)) scores = torch.zeros(num_classes - bg, device=pred.device, dtype=torch.float32) for i in range(bg, num_classes): if not (target == i).any(): @@ -928,12 +876,12 @@ def dice_score( def iou( - pred: torch.Tensor, - target: torch.Tensor, - ignore_index: Optional[int] = None, - absent_score: float = 0.0, - num_classes: Optional[int] = None, - reduction: str = 'elementwise_mean', + pred: torch.Tensor, + target: torch.Tensor, + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + num_classes: Optional[int] = None, + reduction: str = "elementwise_mean", ) -> torch.Tensor: """ Intersection over union, or Jaccard index calculation. @@ -1005,9 +953,11 @@ def iou( # Remove the ignored class index from the scores. if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes: - scores = torch.cat([ - scores[:ignore_index], - scores[ignore_index + 1:], - ]) + scores = torch.cat( + [ + scores[:ignore_index], + scores[ignore_index + 1 :], + ] + ) return reduce(scores, reduction=reduction) diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 307aeea1f9ac1..4016cbeabb724 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -3,7 +3,6 @@ import pytest import torch from sklearn.metrics import ( - accuracy_score as sk_accuracy, jaccard_score as sk_jaccard_score, precision_score as sk_precision, recall_score as sk_recall, @@ -11,7 +10,7 @@ fbeta_score as sk_fbeta_score, roc_curve as sk_roc_curve, roc_auc_score as sk_roc_auc_score, - precision_recall_curve as sk_precision_recall_curve + precision_recall_curve as sk_precision_recall_curve, ) from pytorch_lightning import seed_everything @@ -21,7 +20,6 @@ get_num_classes, stat_scores, stat_scores_multiple_classes, - accuracy, precision, recall, _binary_clf_curve, @@ -36,18 +34,20 @@ ) -@pytest.mark.parametrize(['sklearn_metric', 'torch_metric', 'only_binary'], [ - pytest.param(sk_accuracy, accuracy, False, id='accuracy'), - pytest.param(partial(sk_jaccard_score, average='macro'), iou, False, id='iou'), - pytest.param(partial(sk_precision, average='micro'), precision, False, id='precision'), - pytest.param(partial(sk_recall, average='micro'), recall, False, id='recall'), - pytest.param(sk_roc_curve, roc, True, id='roc'), - pytest.param(sk_precision_recall_curve, precision_recall_curve, True, id='precision_recall_curve'), - pytest.param(sk_roc_auc_score, auroc, True, id='auroc') -]) +@pytest.mark.parametrize( + ["sklearn_metric", "torch_metric", "only_binary"], + [ + pytest.param(partial(sk_jaccard_score, average="macro"), iou, False, id="iou"), + pytest.param(partial(sk_precision, average="micro"), precision, False, id="precision"), + pytest.param(partial(sk_recall, average="micro"), recall, False, id="recall"), + pytest.param(sk_roc_curve, roc, True, id="roc"), + pytest.param(sk_precision_recall_curve, precision_recall_curve, True, id="precision_recall_curve"), + pytest.param(sk_roc_auc_score, auroc, True, id="auroc"), + ], +) def test_against_sklearn(sklearn_metric, torch_metric, only_binary): """Compare PL metrics to sklearn version. """ - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" # for metrics with only_binary=False, we try out different combinations of number # of labels in pred and target (also test binary) @@ -58,8 +58,7 @@ def test_against_sklearn(sklearn_metric, torch_metric, only_binary): pred = torch.randint(n_cls_pred, (300,), device=device) target = torch.randint(n_cls_target, (300,), device=device) - sk_score = sklearn_metric(target.cpu().detach().numpy(), - pred.cpu().detach().numpy()) + sk_score = sklearn_metric(target.cpu().detach().numpy(), pred.cpu().detach().numpy()) pl_score = torch_metric(pred, target) # if multi output @@ -72,20 +71,21 @@ def test_against_sklearn(sklearn_metric, torch_metric, only_binary): assert torch.allclose(sk_score, pl_score) -@pytest.mark.parametrize('class_reduction', ['micro', 'macro', 'weighted']) -@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ - pytest.param(sk_precision, precision, id='precision'), - pytest.param(sk_recall, recall, id='recall'), -]) +@pytest.mark.parametrize("class_reduction", ["micro", "macro", "weighted"]) +@pytest.mark.parametrize( + ["sklearn_metric", "torch_metric"], + [ + pytest.param(sk_precision, precision, id="precision"), + pytest.param(sk_recall, recall, id="recall"), + ], +) def test_different_reduction_against_sklearn(class_reduction, sklearn_metric, torch_metric): - """ Test metrics where the class_reduction parameter have a correponding - value in sklearn """ - device = 'cuda' if torch.cuda.is_available() else 'cpu' + """Test metrics where the class_reduction parameter have a correponding + value in sklearn""" + device = "cuda" if torch.cuda.is_available() else "cpu" pred = torch.randint(10, (300,), device=device) target = torch.randint(10, (300,), device=device) - sk_score = sklearn_metric(target.cpu().detach().numpy(), - pred.cpu().detach().numpy(), - average=class_reduction) + sk_score = sklearn_metric(target.cpu().detach().numpy(), pred.cpu().detach().numpy(), average=class_reduction) sk_score = torch.tensor(sk_score, dtype=torch.float, device=device) pl_score = torch_metric(pred, target, class_reduction=class_reduction) assert torch.allclose(sk_score, pl_score) @@ -93,10 +93,12 @@ def test_different_reduction_against_sklearn(class_reduction, sklearn_metric, to def test_onehot(): test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) - expected = torch.stack([ - torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), - torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]) - ]) + expected = torch.stack( + [ + torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), + torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]), + ] + ) assert test_tensor.shape == (2, 5) assert expected.shape == (2, 10, 5) @@ -114,10 +116,12 @@ def test_onehot(): def test_to_categorical(): - test_tensor = torch.stack([ - torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), - torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]) - ]).to(torch.float) + test_tensor = torch.stack( + [ + torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), + torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]), + ] + ).to(torch.float) expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) assert expected.shape == (2, 5) @@ -129,20 +133,25 @@ def test_to_categorical(): assert torch.allclose(result, expected.to(result.dtype)) -@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [ - pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10), - pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), - pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), -]) +@pytest.mark.parametrize( + ["pred", "target", "num_classes", "expected_num_classes"], + [ + pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10), + pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), + pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), + ], +) def test_get_num_classes(pred, target, num_classes, expected_num_classes): assert get_num_classes(pred, target, num_classes) == expected_num_classes -@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', - 'expected_tn', 'expected_fn', 'expected_support'], [ - pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2), - pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2) -]) +@pytest.mark.parametrize( + ["pred", "target", "expected_tp", "expected_fp", "expected_tn", "expected_fn", "expected_support"], + [ + pytest.param(torch.tensor([0.0, 2.0, 4.0, 4.0]), torch.tensor([0.0, 4.0, 3.0, 4.0]), 1, 1, 1, 1, 2), + pytest.param(to_onehot(torch.tensor([0.0, 2.0, 4.0, 4.0])), torch.tensor([0.0, 4.0, 3.0, 4.0]), 1, 1, 1, 1, 2), + ], +) def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support): tp, fp, tn, fn, sup = stat_scores(pred, target, class_index=4) @@ -153,18 +162,54 @@ def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expect assert sup.item() == expected_support -@pytest.mark.parametrize(['pred', 'target', 'reduction', 'expected_tp', 'expected_fp', - 'expected_tn', 'expected_fn', 'expected_support'], [ - pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 'none', - [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]), - pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'none', - [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]), - pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'sum', - torch.tensor(2), torch.tensor(2), torch.tensor(14), torch.tensor(2), torch.tensor(4)), - pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'elementwise_mean', - torch.tensor(0.4), torch.tensor(0.4), torch.tensor(2.8), torch.tensor(0.4), torch.tensor(0.8)) -]) -def test_stat_scores_multiclass(pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support): +@pytest.mark.parametrize( + ["pred", "target", "reduction", "expected_tp", "expected_fp", "expected_tn", "expected_fn", "expected_support"], + [ + pytest.param( + torch.tensor([0.0, 2.0, 4.0, 4.0]), + torch.tensor([0.0, 4.0, 3.0, 4.0]), + "none", + [1, 0, 0, 0, 1], + [0, 0, 1, 0, 1], + [3, 4, 3, 3, 1], + [0, 0, 0, 1, 1], + [1, 0, 0, 1, 2], + ), + pytest.param( + to_onehot(torch.tensor([0.0, 2.0, 4.0, 4.0])), + torch.tensor([0.0, 4.0, 3.0, 4.0]), + "none", + [1, 0, 0, 0, 1], + [0, 0, 1, 0, 1], + [3, 4, 3, 3, 1], + [0, 0, 0, 1, 1], + [1, 0, 0, 1, 2], + ), + pytest.param( + to_onehot(torch.tensor([0.0, 2.0, 4.0, 4.0])), + torch.tensor([0.0, 4.0, 3.0, 4.0]), + "sum", + torch.tensor(2), + torch.tensor(2), + torch.tensor(14), + torch.tensor(2), + torch.tensor(4), + ), + pytest.param( + to_onehot(torch.tensor([0.0, 2.0, 4.0, 4.0])), + torch.tensor([0.0, 4.0, 3.0, 4.0]), + "elementwise_mean", + torch.tensor(0.4), + torch.tensor(0.4), + torch.tensor(2.8), + torch.tensor(0.4), + torch.tensor(0.8), + ), + ], +) +def test_stat_scores_multiclass( + pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support +): tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction=reduction) assert torch.allclose(torch.tensor(expected_tp).to(tp), tp) @@ -174,55 +219,30 @@ def test_stat_scores_multiclass(pred, target, reduction, expected_tp, expected_f assert torch.allclose(torch.tensor(expected_support).to(sup), sup) -def test_multilabel_accuracy(): - # Dense label indicator matrix format - y1 = torch.tensor([[0, 1, 1], [1, 0, 1]]) - y2 = torch.tensor([[0, 0, 1], [1, 0, 1]]) - - assert torch.allclose(accuracy(y1, y2, class_reduction='none'), torch.tensor([2 / 3, 1.])) - assert torch.allclose(accuracy(y1, y1, class_reduction='none'), torch.tensor([1., 1.])) - assert torch.allclose(accuracy(y2, y2, class_reduction='none'), torch.tensor([1., 1.])) - assert torch.allclose(accuracy(y2, torch.logical_not(y2), class_reduction='none'), torch.tensor([0., 0.])) - assert torch.allclose(accuracy(y1, torch.logical_not(y1), class_reduction='none'), torch.tensor([0., 0.])) - - # num_classes does not match extracted number from input we expect a warning - with pytest.warns(RuntimeWarning, - match=r'You have set .* number of classes which is' - r' different from predicted (.*) and' - r' target (.*) number of classes'): - _ = accuracy(y2, torch.zeros_like(y2), num_classes=3) - - -def test_accuracy(): - pred = torch.tensor([0, 1, 2, 3]) - target = torch.tensor([0, 1, 2, 2]) - acc = accuracy(pred, target) - - assert acc.item() == 0.75 - - pred = torch.tensor([0, 1, 2, 2]) - target = torch.tensor([0, 1, 1, 3]) - acc = accuracy(pred, target) - - assert acc.item() == 0.50 - - -@pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [ - pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]), - pytest.param(to_onehot(torch.tensor([1., 0., 1., 0.])), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]) -]) +@pytest.mark.parametrize( + ["pred", "target", "expected_prec", "expected_rec"], + [ + pytest.param(torch.tensor([1.0, 0.0, 1.0, 0.0]), torch.tensor([0.0, 1.0, 1.0, 0.0]), [0.5, 0.5], [0.5, 0.5]), + pytest.param( + to_onehot(torch.tensor([1.0, 0.0, 1.0, 0.0])), torch.tensor([0.0, 1.0, 1.0, 0.0]), [0.5, 0.5], [0.5, 0.5] + ), + ], +) def test_precision_recall(pred, target, expected_prec, expected_rec): - prec = precision(pred, target, class_reduction='none') - rec = recall(pred, target, class_reduction='none') + prec = precision(pred, target, class_reduction="none") + rec = recall(pred, target, class_reduction="none") assert torch.allclose(torch.tensor(expected_prec).to(prec), prec) assert torch.allclose(torch.tensor(expected_rec).to(rec), rec) -@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [ - pytest.param(1, 1., 42), - pytest.param(None, 1., 42), -]) +@pytest.mark.parametrize( + ["sample_weight", "pos_label", "exp_shape"], + [ + pytest.param(1, 1.0, 42), + pytest.param(None, 1.0, 42), + ], +) def test_binary_clf_curve(sample_weight, pos_label, exp_shape): # TODO: move back the pred and target to test func arguments # if you fix the array inside the function, you'd also have fix the shape, @@ -243,9 +263,10 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape): assert thresh.shape == (exp_shape,) -@pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [ - pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4]) -]) +@pytest.mark.parametrize( + ["pred", "target", "expected_p", "expected_r", "expected_t"], + [pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1.0, 1.0], [1, 0.5, 0.5, 0.5, 0.0], [1, 2, 3, 4])], +) def test_pr_curve(pred, target, expected_p, expected_r, expected_t): p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target)) assert p.size() == r.size() @@ -256,13 +277,16 @@ def test_pr_curve(pred, target, expected_p, expected_r, expected_t): assert torch.allclose(t, torch.tensor(expected_t).to(t)) -@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [ - pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]), - pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]), - pytest.param([1, 1], [1, 0], [0, 1], [0, 1]), - pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]), - pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]), -]) +@pytest.mark.parametrize( + ["pred", "target", "expected_tpr", "expected_fpr"], + [ + pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]), + pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]), + pytest.param([1, 1], [1, 0], [0, 1], [0, 1]), + pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]), + pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]), + ], +) def test_roc_curve(pred, target, expected_tpr, expected_fpr): fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target)) @@ -272,105 +296,112 @@ def test_roc_curve(pred, target, expected_tpr, expected_fpr): assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr)) -@pytest.mark.parametrize(['pred', 'target', 'expected'], [ - pytest.param([0, 1, 0, 1], [0, 1, 0, 1], 1.), - pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.), - pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5), - pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.), - pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.5), -]) +@pytest.mark.parametrize( + ["pred", "target", "expected"], + [ + pytest.param([0, 1, 0, 1], [0, 1, 0, 1], 1.0), + pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.0), + pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5), + pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.0), + pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.5), + ], +) def test_auroc(pred, target, expected): score = auroc(torch.tensor(pred), torch.tensor(target)).item() assert score == expected def test_multiclass_auroc(): - with pytest.raises(ValueError, - match=r".*probabilities, i.e. they should sum up to 1.0 over classes"): - _ = multiclass_auroc(pred=torch.tensor([[0.9, 0.9], - [1.0, 0]]), - target=torch.tensor([0, 1])) + with pytest.raises(ValueError, match=r".*probabilities, i.e. they should sum up to 1.0 over classes"): + _ = multiclass_auroc(pred=torch.tensor([[0.9, 0.9], [1.0, 0]]), target=torch.tensor([0, 1])) - with pytest.raises(ValueError, - match=r".*not defined when all of the classes do not occur in the target.*"): - _ = multiclass_auroc(pred=torch.rand((4, 3)).softmax(dim=1), - target=torch.tensor([1, 0, 1, 0])) + with pytest.raises(ValueError, match=r".*not defined when all of the classes do not occur in the target.*"): + _ = multiclass_auroc(pred=torch.rand((4, 3)).softmax(dim=1), target=torch.tensor([1, 0, 1, 0])) - with pytest.raises(ValueError, - match=r".*does not equal the number of classes passed in 'num_classes'.*"): - _ = multiclass_auroc(pred=torch.rand((5, 4)).softmax(dim=1), - target=torch.tensor([0, 1, 2, 2, 3]), - num_classes=6) + with pytest.raises(ValueError, match=r".*does not equal the number of classes passed in 'num_classes'.*"): + _ = multiclass_auroc( + pred=torch.rand((5, 4)).softmax(dim=1), target=torch.tensor([0, 1, 2, 2, 3]), num_classes=6 + ) -@pytest.mark.parametrize('n_cls', [2, 5, 10, 50]) +@pytest.mark.parametrize("n_cls", [2, 5, 10, 50]) def test_multiclass_auroc_against_sklearn(n_cls): - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" n_samples = 300 pred = torch.rand(n_samples, n_cls, device=device).softmax(dim=1) target = torch.randint(n_cls, (n_samples,), device=device) # Make sure target includes all class labels so that multiclass AUROC is defined - target[10:10 + n_cls] = torch.arange(n_cls) + target[10 : 10 + n_cls] = torch.arange(n_cls) pl_score = multiclass_auroc(pred, target) # For the binary case, sklearn expects an (n_samples,) array of probabilities of # the positive class pred = pred[:, 1] if n_cls == 2 else pred - sk_score = sk_roc_auc_score(target.cpu().detach().numpy(), - pred.cpu().detach().numpy(), - multi_class="ovr") + sk_score = sk_roc_auc_score(target.cpu().detach().numpy(), pred.cpu().detach().numpy(), multi_class="ovr") sk_score = torch.tensor(sk_score, dtype=torch.float, device=device) assert torch.allclose(sk_score, pl_score) -@pytest.mark.parametrize(['x', 'y', 'expected'], [ - pytest.param([0, 1], [0, 1], 0.5), - pytest.param([1, 0], [0, 1], 0.5), - pytest.param([1, 0, 0], [0, 1, 1], 0.5), - pytest.param([0, 1], [1, 1], 1), - pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5), -]) +@pytest.mark.parametrize( + ["x", "y", "expected"], + [ + pytest.param([0, 1], [0, 1], 0.5), + pytest.param([1, 0], [0, 1], 0.5), + pytest.param([1, 0, 0], [0, 1, 1], 0.5), + pytest.param([0, 1], [1, 1], 1), + pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5), + ], +) def test_auc(x, y, expected): # Test Area Under Curve (AUC) computation assert auc(torch.tensor(x), torch.tensor(y)) == expected -@pytest.mark.parametrize(['scores', 'target', 'expected_score'], [ - # Check the average_precision_score of a constant predictor is - # the TPR - # Generate a dataset with 25% of positives - # And a constant score - # The precision is then the fraction of positive whatever the recall - # is, as there is only one threshold: - pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25), - # With threshold 0.8 : 1 TP and 2 TN and one FN - pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75), -]) +@pytest.mark.parametrize( + ["scores", "target", "expected_score"], + [ + # Check the average_precision_score of a constant predictor is + # the TPR + # Generate a dataset with 25% of positives + # And a constant score + # The precision is then the fraction of positive whatever the recall + # is, as there is only one threshold: + pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), 0.25), + # With threshold 0.8 : 1 TP and 2 TN and one FN + pytest.param(torch.tensor([0.6, 0.7, 0.8, 9]), torch.tensor([1, 0, 0, 1]), 0.75), + ], +) def test_average_precision(scores, target, expected_score): assert average_precision(scores, target) == expected_score -@pytest.mark.parametrize(['pred', 'target', 'expected'], [ - pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.), - pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.), - pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3), - pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.), -]) +@pytest.mark.parametrize( + ["pred", "target", "expected"], + [ + pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.0), + pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.0), + pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3), + pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.0), + ], +) def test_dice_score(pred, target, expected): score = dice_score(torch.tensor(pred), torch.tensor(target)) assert score == expected -@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [ - pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])), - pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])), - pytest.param(False, 'none', 0, torch.Tensor([1, 1])), - pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])), - pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])), - pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])), -]) +@pytest.mark.parametrize( + ["half_ones", "reduction", "ignore_index", "expected"], + [ + pytest.param(False, "none", None, torch.Tensor([1, 1, 1])), + pytest.param(False, "elementwise_mean", None, torch.Tensor([1])), + pytest.param(False, "none", 0, torch.Tensor([1, 1])), + pytest.param(True, "none", None, torch.Tensor([0.5, 0.5, 0.5])), + pytest.param(True, "elementwise_mean", None, torch.Tensor([0.5])), + pytest.param(True, "none", 0, torch.Tensor([0.5, 0.5])), + ], +) def test_iou(half_ones, reduction, ignore_index, expected): pred = (torch.arange(120) % 3).view(-1, 1) target = (torch.arange(120) % 3).view(-1, 1) @@ -387,19 +418,17 @@ def test_iou(half_ones, reduction, ignore_index, expected): def test_iou_input_check(): with pytest.raises(ValueError, match=r"'pred' shape (.*) must equal 'target' shape (.*)"): - _ = iou(pred=torch.randint(0, 2, (3, 4, 3)), - target=torch.randint(0, 2, (3, 3))) + _ = iou(pred=torch.randint(0, 2, (3, 4, 3)), target=torch.randint(0, 2, (3, 3))) with pytest.raises(ValueError, match="'pred' must contain integer targets."): - _ = iou(pred=torch.rand((3, 3)), - target=torch.randint(0, 2, (3, 3))) + _ = iou(pred=torch.rand((3, 3)), target=torch.randint(0, 2, (3, 3))) -@pytest.mark.parametrize('metric', [auroc]) +@pytest.mark.parametrize("metric", [auroc]) def test_error_on_multiclass_input(metric): """ check that these metrics raise an error if they are used for multiclass problems """ - pred = torch.randint(0, 10, (100, )) - target = torch.randint(0, 10, (100, )) + pred = torch.randint(0, 10, (100,)) + target = torch.randint(0, 10, (100,)) with pytest.raises(ValueError, match="AUROC metric is meant for binary classification"): _ = metric(pred, target) @@ -407,35 +436,38 @@ def test_error_on_multiclass_input(metric): # TODO: When the jaccard_score of the sklearn version we use accepts `zero_division` (see # https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our # `absent_score`. -@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], [ - # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid - # scores the function can return ([0., 1.] range, inclusive). - # 2 classes, class 0 is correct everywhere, class 1 is absent. - pytest.param([0], [0], None, -1., 2, [1., -1.]), - pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]), - # absent_score not applied if only class 0 is present and it's the only class. - pytest.param([0], [0], None, -1., 1, [1.]), - # 2 classes, class 1 is correct everywhere, class 0 is absent. - pytest.param([1], [1], None, -1., 2, [-1., 1.]), - pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]), - # When 0 index ignored, class 0 does not get a score (not even the absent_score). - pytest.param([1], [1], 0, -1., 2, [1.0]), - # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score. - pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]), - pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]), - # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score. - pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]), - pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]), - # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class - # 2 is absent. - pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]), - # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class - # 2 is absent. - pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]), - # Sanity checks with absent_score of 1.0. - pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]), - pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]), -]) +@pytest.mark.parametrize( + ["pred", "target", "ignore_index", "absent_score", "num_classes", "expected"], + [ + # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid + # scores the function can return ([0., 1.] range, inclusive). + # 2 classes, class 0 is correct everywhere, class 1 is absent. + pytest.param([0], [0], None, -1.0, 2, [1.0, -1.0]), + pytest.param([0, 0], [0, 0], None, -1.0, 2, [1.0, -1.0]), + # absent_score not applied if only class 0 is present and it's the only class. + pytest.param([0], [0], None, -1.0, 1, [1.0]), + # 2 classes, class 1 is correct everywhere, class 0 is absent. + pytest.param([1], [1], None, -1.0, 2, [-1.0, 1.0]), + pytest.param([1, 1], [1, 1], None, -1.0, 2, [-1.0, 1.0]), + # When 0 index ignored, class 0 does not get a score (not even the absent_score). + pytest.param([1], [1], 0, -1.0, 2, [1.0]), + # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score. + pytest.param([0, 2], [0, 2], None, -1.0, 3, [1.0, -1.0, 1.0]), + pytest.param([2, 0], [2, 0], None, -1.0, 3, [1.0, -1.0, 1.0]), + # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score. + pytest.param([0, 1], [0, 1], None, -1.0, 3, [1.0, 1.0, -1.0]), + pytest.param([1, 0], [1, 0], None, -1.0, 3, [1.0, 1.0, -1.0]), + # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class + # 2 is absent. + pytest.param([0, 1], [0, 0], None, -1.0, 3, [0.5, 0.0, -1.0]), + # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class + # 2 is absent. + pytest.param([0, 0], [0, 1], None, -1.0, 3, [0.5, 0.0, -1.0]), + # Sanity checks with absent_score of 1.0. + pytest.param([0, 2], [0, 2], None, 1.0, 3, [1.0, 1.0, 1.0]), + pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1.0, 1.0]), + ], +) def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected): iou_val = iou( pred=torch.tensor(pred), @@ -443,26 +475,29 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, ignore_index=ignore_index, absent_score=absent_score, num_classes=num_classes, - reduction='none', + reduction="none", ) assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) # example data taken from # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py -@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [ - # Ignoring an index outside of [0, num_classes-1] should have no effect. - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]), - # Ignoring a valid index drops only that index from the result. - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]), - # When reducing to mean or sum, the ignored index does not contribute to the output. - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]), -]) +@pytest.mark.parametrize( + ["pred", "target", "ignore_index", "num_classes", "reduction", "expected"], + [ + # Ignoring an index outside of [0, num_classes-1] should have no effect. + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, "none", [1, 1 / 2, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, "none", [1, 1 / 2, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, "none", [1, 1 / 2, 2 / 3]), + # Ignoring a valid index drops only that index from the result. + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "none", [1 / 2, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, "none", [1, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, "none", [1, 1 / 2]), + # When reducing to mean or sum, the ignored index does not contribute to the output. + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "elementwise_mean", [7 / 12]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "sum", [7 / 6]), + ], +) def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected): iou_val = iou( pred=torch.tensor(pred), From 96d40c87d7de9d30af8a130328efdaac002bedff Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sun, 6 Dec 2020 12:27:41 +0100 Subject: [PATCH 43/94] Minor changes --- docs/source/metrics.rst | 6 ++++++ .../metrics/classification/{utils.py => helpers.py} | 0 tests/metrics/classification/test_inputs.py | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) rename pytorch_lightning/metrics/classification/{utils.py => helpers.py} (100%) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 6b9dd8307a457..ee141dc74a679 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -424,6 +424,12 @@ recall [func] .. autofunction:: pytorch_lightning.metrics.functional.classification.recall :noindex: +select_topk [func] +~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.utils.select_topk + :noindex: + stat_scores [func] ~~~~~~~~~~~~~~~~~~ diff --git a/pytorch_lightning/metrics/classification/utils.py b/pytorch_lightning/metrics/classification/helpers.py similarity index 100% rename from pytorch_lightning/metrics/classification/utils.py rename to pytorch_lightning/metrics/classification/helpers.py diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 6b5a03fcf1ea6..8ad3dd99240c1 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -3,7 +3,7 @@ from torch import randint, rand from pytorch_lightning.metrics.utils import to_onehot, select_topk -from pytorch_lightning.metrics.classification.utils import _input_format_classification +from pytorch_lightning.metrics.classification.helpers import _input_format_classification from tests.metrics.classification.inputs import ( Input, _binary_inputs as _bin, From 627d99ad42b80eeb881f735f368b0b8cd5be4251 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sun, 6 Dec 2020 12:40:39 +0100 Subject: [PATCH 44/94] Fix imports --- pytorch_lightning/metrics/functional/accuracy.py | 2 +- pytorch_lightning/metrics/functional/hamming_loss.py | 2 +- tests/metrics/classification/test_accuracy.py | 2 +- tests/metrics/classification/test_hamming_loss.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index f029f1a399870..56c0412044156 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -14,7 +14,7 @@ from typing import Tuple, Optional import torch -from pytorch_lightning.metrics.classification.utils import _input_format_classification +from pytorch_lightning.metrics.classification.helpers import _input_format_classification def _accuracy_update( diff --git a/pytorch_lightning/metrics/functional/hamming_loss.py b/pytorch_lightning/metrics/functional/hamming_loss.py index 08c7c9056a960..ecd6d77011ec0 100644 --- a/pytorch_lightning/metrics/functional/hamming_loss.py +++ b/pytorch_lightning/metrics/functional/hamming_loss.py @@ -14,7 +14,7 @@ from typing import Tuple, Union import torch -from pytorch_lightning.metrics.classification.utils import _input_format_classification +from pytorch_lightning.metrics.classification.helpers import _input_format_classification def _hamming_loss_update(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> Tuple[torch.Tensor, int]: diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index 9cbf5ec855885..9af19474064cf 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -5,7 +5,7 @@ from pytorch_lightning.metrics import Accuracy from pytorch_lightning.metrics.functional import accuracy -from pytorch_lightning.metrics.classification.utils import _input_format_classification +from pytorch_lightning.metrics.classification.helpers import _input_format_classification from tests.metrics.classification.inputs import ( _binary_inputs, _binary_prob_inputs, diff --git a/tests/metrics/classification/test_hamming_loss.py b/tests/metrics/classification/test_hamming_loss.py index 1d30a7cdabc12..aab6fa353b4bc 100644 --- a/tests/metrics/classification/test_hamming_loss.py +++ b/tests/metrics/classification/test_hamming_loss.py @@ -4,7 +4,7 @@ from pytorch_lightning.metrics import HammingLoss from pytorch_lightning.metrics.functional import hamming_loss -from pytorch_lightning.metrics.classification.utils import _input_format_classification +from pytorch_lightning.metrics.classification.helpers import _input_format_classification from tests.metrics.classification.inputs import ( _binary_inputs, _binary_prob_inputs, From de3defb83af688f98d38fce6010cb603e3ded2b8 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sun, 6 Dec 2020 13:06:09 +0100 Subject: [PATCH 45/94] Improve docstring descriptions --- pytorch_lightning/metrics/classification/hamming_loss.py | 8 +++++++- pytorch_lightning/metrics/functional/accuracy.py | 7 ++++++- pytorch_lightning/metrics/functional/hamming_loss.py | 8 +++++++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/metrics/classification/hamming_loss.py b/pytorch_lightning/metrics/classification/hamming_loss.py index 5432a03c8403b..42f7f56439f26 100644 --- a/pytorch_lightning/metrics/classification/hamming_loss.py +++ b/pytorch_lightning/metrics/classification/hamming_loss.py @@ -20,7 +20,13 @@ class HammingLoss(Metric): """ - Computes the share of wrongly predicted labels. + Computes the average Hamming loss or <`Hamming distance`> between targets and predictions: + + .. math:: \text{Hamming loss} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y_{il}}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`th label of the :math:`i`th sample of that + tensor. This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it treats each possible label separately - meaning that, for example, multi-class data is diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 56c0412044156..67ffdf40dd69f 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -52,7 +52,12 @@ def accuracy( top_k: Optional[int] = None, ) -> torch.Tensor: """ - Computes the share of entirely correctly predicted samples. + Computes `Accuracy `_: + + .. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y_i}) + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. This metric generalizes to subset accuracy for multilabel data, and similarly for multi-dimensional multi-class data: for the sample to be counted as correct, the the diff --git a/pytorch_lightning/metrics/functional/hamming_loss.py b/pytorch_lightning/metrics/functional/hamming_loss.py index ecd6d77011ec0..d2d90e217bdfa 100644 --- a/pytorch_lightning/metrics/functional/hamming_loss.py +++ b/pytorch_lightning/metrics/functional/hamming_loss.py @@ -32,7 +32,13 @@ def _hamming_loss_compute(correct: torch.Tensor, total: Union[int, torch.Tensor] def hamming_loss(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: """ - Computes the share of wrongly predicted labels. + Computes the average Hamming loss or <`Hamming distance`> between targets and predictions: + + .. math:: \text{Hamming loss} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y_{il}}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`th label of the :math:`i`th sample of that + tensor. This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it treats each possible label separately - meaning that, for example, multi-class data is From f3c47f980dbf00c5014c45a26e7a39233b039b03 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sun, 6 Dec 2020 13:47:44 +0100 Subject: [PATCH 46/94] Fix edge case and simplify testing --- .../metrics/classification/helpers.py | 4 +- tests/metrics/classification/test_inputs.py | 159 +++++++++--------- 2 files changed, 86 insertions(+), 77 deletions(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index 7cfc505e6c673..afb97e6e0a74f 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -425,9 +425,9 @@ def _input_format_classification( preds = select_topk(preds, top_k) else: num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1 - preds = to_onehot(preds, num_classes) + preds = to_onehot(preds, max(2,num_classes)) - target = to_onehot(target, num_classes) + target = to_onehot(target, max(2,num_classes)) if is_multiclass is False: preds, target = preds[:, 1, ...], target[:, 1, ...] diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 8ad3dd99240c1..c4d01d282fa57 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -112,49 +112,50 @@ def _mlmd_prob_to_mc_preds_tr(x): @pytest.mark.parametrize( - "inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target", + "inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target", [ ############################# # Test usual expected cases - (_bin, THRESHOLD, None, False, None, "multi-class", _usq, _usq), - (_bin_prob, THRESHOLD, None, None, None, "binary", lambda x: _usq(_thrs(x)), _usq), - (_ml_prob, THRESHOLD, None, None, None, "multi-label", _thrs, _idn), - (_ml, THRESHOLD, None, False, None, "multi-dim multi-class", _idn, _idn), - (_ml_prob, THRESHOLD, None, None, None, "multi-label", _ml_preds_tr, _rshp1), - (_mlmd, THRESHOLD, None, False, None, "multi-dim multi-class", _rshp1, _rshp1), - (_mc, THRESHOLD, NUM_CLASSES, None, None, "multi-class", _onehot, _onehot), - (_mc_prob, THRESHOLD, None, None, None, "multi-class", _top1, _onehot), - (_mc_prob, THRESHOLD, None, None, 2, "multi-class", _top2, _onehot), - (_mdmc, THRESHOLD, NUM_CLASSES, None, None, "multi-dim multi-class", _onehot, _onehot), - (_mdmc_prob, THRESHOLD, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot), - (_mdmc_prob, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot), - (_mdmc_prob_many_dims, THRESHOLD, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), - (_mdmc_prob_many_dims, THRESHOLD, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), + (_bin, None, False, None, "multi-class", _usq, _usq), + (_bin, 1, False, None, "multi-class", _usq, _usq), + (_bin_prob, None, None, None, "binary", lambda x: _usq(_thrs(x)), _usq), + (_ml_prob, None, None, None, "multi-label", _thrs, _idn), + (_ml, None, False, None, "multi-dim multi-class", _idn, _idn), + (_ml_prob, None, None, None, "multi-label", _ml_preds_tr, _rshp1), + (_mlmd, None, False, None, "multi-dim multi-class", _rshp1, _rshp1), + (_mc, NUM_CLASSES, None, None, "multi-class", _onehot, _onehot), + (_mc_prob, None, None, None, "multi-class", _top1, _onehot), + (_mc_prob, None, None, 2, "multi-class", _top2, _onehot), + (_mdmc, NUM_CLASSES, None, None, "multi-dim multi-class", _onehot, _onehot), + (_mdmc_prob, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot), + (_mdmc_prob, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot), + (_mdmc_prob_many_dims, None, None, None, "multi-dim multi-class", _top1_rshp2, _onehot_rshp1), + (_mdmc_prob_many_dims, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), ########################### # Test some special cases # Binary as multiclass - (_bin, THRESHOLD, None, None, None, "multi-class", _onehot2, _onehot2), + (_bin, None, None, None, "multi-class", _onehot2, _onehot2), # Binary probs as multiclass - (_bin_prob, THRESHOLD, None, True, None, "binary", _probs_to_mc_preds_tr, _onehot2), + (_bin_prob, None, True, None, "binary", _probs_to_mc_preds_tr, _onehot2), # Multilabel as multiclass - (_ml, THRESHOLD, None, True, None, "multi-dim multi-class", _onehot2, _onehot2), + (_ml, None, True, None, "multi-dim multi-class", _onehot2, _onehot2), # Multilabel probs as multiclass - (_ml_prob, THRESHOLD, None, True, None, "multi-label", _probs_to_mc_preds_tr, _onehot2), + (_ml_prob, None, True, None, "multi-label", _probs_to_mc_preds_tr, _onehot2), # Multidim multilabel as multiclass - (_mlmd, THRESHOLD, None, True, None, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1), + (_mlmd, None, True, None, "multi-dim multi-class", _onehot2_rshp1, _onehot2_rshp1), # Multidim multilabel probs as multiclass - (_mlmd_prob, THRESHOLD, None, True, None, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), + (_mlmd_prob, None, True, None, "multi-label", _mlmd_prob_to_mc_preds_tr, _onehot2_rshp1), # Multiclass prob with 2 classes as binary - (_mc_prob_2cls, THRESHOLD, None, False, None, "multi-class", lambda x: _top1(x)[:, [1]], _usq), + (_mc_prob_2cls, None, False, None, "multi-class", lambda x: _top1(x)[:, [1]], _usq), # Multi-dim multi-class with 2 classes as multi-label - (_mdmc_prob_2cls, THRESHOLD, None, False, None, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), + (_mdmc_prob_2cls, None, False, None, "multi-dim multi-class", lambda x: _top1(x)[:, 1], _idn), ], ) -def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): +def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): preds_out, target_out, mode = _input_format_classification( preds=inputs.preds[0], target=inputs.target[0], - threshold=threshold, + threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k, @@ -168,7 +169,7 @@ def test_usual_cases(inputs, threshold, num_classes, is_multiclass, top_k, exp_m preds_out, target_out, mode = _input_format_classification( preds=inputs.preds[0][[0], ...], target=inputs.target[0][[0], ...], - threshold=threshold, + threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k, @@ -194,92 +195,100 @@ def test_threshold(): ######################################################################## +def test_incorrect_threshold(): + with pytest.raises(ValueError): + _input_format_classification(preds=rand(size=(7,)), target=randint(high=2, size=(7,)), threshold=1.5) + + @pytest.mark.parametrize( - "preds, target, threshold, num_classes, is_multiclass, top_k", + "preds, target, num_classes, is_multiclass", [ # Target not integer - (randint(high=2, size=(7,)), randint(high=2, size=(7,)).float(), 0.5, None, None, None), + (randint(high=2, size=(7,)), randint(high=2, size=(7,)).float(), None, None), # Target negative - (randint(high=2, size=(7,)), -randint(high=2, size=(7,)), 0.5, None, None, None), + (randint(high=2, size=(7,)), -randint(high=2, size=(7,)), None, None), # Preds negative integers - (-randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, None), + (-randint(high=2, size=(7,)), randint(high=2, size=(7,)), None, None), # Negative probabilities - (-rand(size=(7,)), randint(high=2, size=(7,)), 0.5, None, None, None), - # Threshold outside of [0,1] - (rand(size=(7,)), randint(high=2, size=(7,)), 1.5, None, None, None), + (-rand(size=(7,)), randint(high=2, size=(7,)), None, None), # is_multiclass=False and target > 1 - (rand(size=(7,)), randint(low=2, high=4, size=(7,)), 0.5, None, False, None), + (rand(size=(7,)), randint(low=2, high=4, size=(7,)), None, False), # is_multiclass=False and preds integers with > 1 - (randint(low=2, high=4, size=(7,)), randint(high=2, size=(7,)), 0.5, None, False, None), + (randint(low=2, high=4, size=(7,)), randint(high=2, size=(7,)), None, False), # Wrong batch size - (randint(high=2, size=(8,)), randint(high=2, size=(7,)), 0.5, None, None, None), + (randint(high=2, size=(8,)), randint(high=2, size=(7,)), None, None), # Completely wrong shape - (randint(high=2, size=(7,)), randint(high=2, size=(7, 4)), 0.5, None, None, None), + (randint(high=2, size=(7,)), randint(high=2, size=(7, 4)), None, None), # Same #dims, different shape - (randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), 0.5, None, None, None), + (randint(high=2, size=(7, 3)), randint(high=2, size=(7, 4)), None, None), # Same shape and preds floats, target not binary - (rand(size=(7, 3)), randint(low=2, high=4, size=(7, 3)), 0.5, None, None, None), + (rand(size=(7, 3)), randint(low=2, high=4, size=(7, 3)), None, None), # #dims in preds = 1 + #dims in target, C shape not second or last - (rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, None), + (rand(size=(7, 3, 4, 3)), randint(high=4, size=(7, 3, 3)), None, None), # #dims in preds = 1 + #dims in target, preds not float - (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), 0.5, None, None, None), + (randint(high=2, size=(7, 3, 3, 4)), randint(high=4, size=(7, 3, 3)), None, None), # is_multiclass=False, with C dimension > 2 - (_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE,)), 0.5, None, False, None), + (_mc_prob.preds[0], randint(high=2, size=(BATCH_SIZE,)), None, False), # Probs of multiclass preds do not sum up to 1 - (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), 0.5, None, None, None), + (rand(size=(7, 3, 5)), randint(high=2, size=(7, 5)), None, None), # Max target larger or equal to C dimension - (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE,)), 0.5, None, None, None), + (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE,)), None, None), # C dimension not equal to num_classes - (_mc_prob.preds[0], _mc_prob.target[0], 0.5, NUM_CLASSES + 1, None, None), + (_mc_prob.preds[0], _mc_prob.target[0], NUM_CLASSES + 1, None), # Max target larger than num_classes (with #dim preds = 1 + #dims target) - ( - _mc_prob.preds[0], - randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, NUM_CLASSES)), - 0.5, - 4, - None, - None, - ), + (_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, NUM_CLASSES)), 4, None), # Max target larger than num_classes (with #dim preds = #dims target) - (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 0.5, 4, None, None), + (randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 4, None), # Max preds larger than num_classes (with #dim preds = #dims target) - (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 0.5, 4, None, None), + (randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 4, None), # Num_classes=1, but is_multiclass not false - (randint(high=2, size=(7,)), randint(high=2, size=(7,)), 0.5, 1, None, 1), + (randint(high=2, size=(7,)), randint(high=2, size=(7,)), 1, None), # is_multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes - (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, None), + (randint(high=2, size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False), # Multilabel input with implied class dimension != num_classes - (rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 0.5, 4, False, None), + (rand(size=(7, 3, 3)), randint(high=2, size=(7, 3, 3)), 4, False), # Multilabel input with is_multiclass=True, but num_classes != 2 (or None) - (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 0.5, 4, True, None), + (rand(size=(7, 3)), randint(high=2, size=(7, 3)), 4, True), # Binary input, num_classes > 2 - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 4, None, None), + (rand(size=(7,)), randint(high=2, size=(7,)), 4, None), # Binary input, num_classes == 2 and is_multiclass not True - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, None, None), - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 2, False, None), + (rand(size=(7,)), randint(high=2, size=(7,)), 2, None), + (rand(size=(7,)), randint(high=2, size=(7,)), 2, False), # Binary input, num_classes == 1 and is_multiclass=True - (rand(size=(7,)), randint(high=2, size=(7,)), 0.5, 1, True, None), + (rand(size=(7,)), randint(high=2, size=(7,)), 1, True), + ], +) +def test_incorrect_inputs(preds, target, num_classes, is_multiclass): + with pytest.raises(ValueError): + _input_format_classification( + preds=preds, target=target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass + ) + + +@pytest.mark.parametrize( + "preds, target, num_classes, is_multiclass, top_k", + [ # Topk set with non (md)mc prob data - (_bin.preds[0], _bin.target[0], 0.5, None, None, 2), - (_bin_prob.preds[0], _bin_prob.target[0], 0.5, None, None, 2), - (_mc.preds[0], _mc.target[0], 0.5, None, None, 2), - (_ml.preds[0], _ml.target[0], 0.5, None, None, 2), - (_mlmd.preds[0], _mlmd.target[0], 0.5, None, None, 2), - (_ml_prob.preds[0], _ml_prob.target[0], 0.5, None, None, 2), - (_mlmd_prob.preds[0], _mlmd_prob.target[0], 0.5, None, None, 2), - (_mdmc.preds[0], _mdmc.target[0], 0.5, None, None, 2), + (_bin.preds[0], _bin.target[0], None, None, 2), + (_bin_prob.preds[0], _bin_prob.target[0], None, None, 2), + (_mc.preds[0], _mc.target[0], None, None, 2), + (_ml.preds[0], _ml.target[0], None, None, 2), + (_mlmd.preds[0], _mlmd.target[0], None, None, 2), + (_ml_prob.preds[0], _ml_prob.target[0], None, None, 2), + (_mlmd_prob.preds[0], _mlmd_prob.target[0], None, None, 2), + (_mdmc.preds[0], _mdmc.target[0], None, None, 2), # top_k =2 with 2 classes, is_multiclass=False - (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], 0.5, None, False, 2), + (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, False, 2), # top_k = number of classes (C dimension) - (_mc_prob.preds[0], _mc_prob.target[0], 0.5, None, None, NUM_CLASSES), + (_mc_prob.preds[0], _mc_prob.target[0], None, None, NUM_CLASSES), ], ) -def test_incorrect_inputs(preds, target, threshold, num_classes, is_multiclass, top_k): +def test_incorrect_inputs_topk(preds, target, num_classes, is_multiclass, top_k): with pytest.raises(ValueError): _input_format_classification( preds=preds, target=target, - threshold=threshold, + threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k, From b7ced6e7469732fef0f503953e815ad758461c3f Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sun, 6 Dec 2020 15:54:17 +0100 Subject: [PATCH 47/94] Fix docs --- pytorch_lightning/metrics/classification/accuracy.py | 3 ++- .../metrics/classification/hamming_loss.py | 10 ++++++---- pytorch_lightning/metrics/functional/accuracy.py | 5 +++-- pytorch_lightning/metrics/functional/hamming_loss.py | 10 ++++++---- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index a215eb8644e98..e2c49225c8798 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -23,7 +23,8 @@ class Accuracy(Metric): r""" Computes `Accuracy `_: - .. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y_i}) + .. math:: + \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. diff --git a/pytorch_lightning/metrics/classification/hamming_loss.py b/pytorch_lightning/metrics/classification/hamming_loss.py index 42f7f56439f26..ac2e04ec88612 100644 --- a/pytorch_lightning/metrics/classification/hamming_loss.py +++ b/pytorch_lightning/metrics/classification/hamming_loss.py @@ -19,13 +19,15 @@ class HammingLoss(Metric): - """ - Computes the average Hamming loss or <`Hamming distance`> between targets and predictions: + r""" + Computes the average Hamming loss or `Hamming distance `_ + between targets and predictions: - .. math:: \text{Hamming loss} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y_{il}}) + .. math:: + \text{Hamming loss} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y_{il}}) Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, - and :math:`\bullet_{il}` refers to the :math:`l`th label of the :math:`i`th sample of that + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that tensor. This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 67ffdf40dd69f..a90def06555ba 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -51,10 +51,11 @@ def accuracy( mdmc_accuracy: str = "subset", top_k: Optional[int] = None, ) -> torch.Tensor: - """ + r""" Computes `Accuracy `_: - .. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y_i}) + .. math:: + \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. diff --git a/pytorch_lightning/metrics/functional/hamming_loss.py b/pytorch_lightning/metrics/functional/hamming_loss.py index d2d90e217bdfa..206e6eb7061f4 100644 --- a/pytorch_lightning/metrics/functional/hamming_loss.py +++ b/pytorch_lightning/metrics/functional/hamming_loss.py @@ -31,13 +31,15 @@ def _hamming_loss_compute(correct: torch.Tensor, total: Union[int, torch.Tensor] def hamming_loss(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: - """ - Computes the average Hamming loss or <`Hamming distance`> between targets and predictions: + r""" + Computes the average Hamming loss or `Hamming distance `_ + between targets and predictions: - .. math:: \text{Hamming loss} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y_{il}}) + .. math:: + \text{Hamming loss} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, - and :math:`\bullet_{il}` refers to the :math:`l`th label of the :math:`i`th sample of that + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that tensor. This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it From e91e56448cc247b3386a37a813cf90997380e172 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sun, 6 Dec 2020 16:03:45 +0100 Subject: [PATCH 48/94] PEP8 --- pytorch_lightning/metrics/classification/accuracy.py | 2 +- pytorch_lightning/metrics/classification/hamming_loss.py | 2 +- pytorch_lightning/metrics/classification/helpers.py | 4 ++-- pytorch_lightning/metrics/functional/accuracy.py | 2 +- pytorch_lightning/metrics/functional/hamming_loss.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index e2c49225c8798..e2b3f1a520258 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -23,7 +23,7 @@ class Accuracy(Metric): r""" Computes `Accuracy `_: - .. math:: + .. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a diff --git a/pytorch_lightning/metrics/classification/hamming_loss.py b/pytorch_lightning/metrics/classification/hamming_loss.py index ac2e04ec88612..c4fc74eeb6a0e 100644 --- a/pytorch_lightning/metrics/classification/hamming_loss.py +++ b/pytorch_lightning/metrics/classification/hamming_loss.py @@ -23,7 +23,7 @@ class HammingLoss(Metric): Computes the average Hamming loss or `Hamming distance `_ between targets and predictions: - .. math:: + .. math:: \text{Hamming loss} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y_{il}}) Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index afb97e6e0a74f..0f010a63af505 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -425,9 +425,9 @@ def _input_format_classification( preds = select_topk(preds, top_k) else: num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1 - preds = to_onehot(preds, max(2,num_classes)) + preds = to_onehot(preds, max(2, num_classes)) - target = to_onehot(target, max(2,num_classes)) + target = to_onehot(target, max(2, num_classes)) if is_multiclass is False: preds, target = preds[:, 1, ...], target[:, 1, ...] diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index a90def06555ba..707e76b480d9c 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -54,7 +54,7 @@ def accuracy( r""" Computes `Accuracy `_: - .. math:: + .. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a diff --git a/pytorch_lightning/metrics/functional/hamming_loss.py b/pytorch_lightning/metrics/functional/hamming_loss.py index 206e6eb7061f4..9ef322aa841e4 100644 --- a/pytorch_lightning/metrics/functional/hamming_loss.py +++ b/pytorch_lightning/metrics/functional/hamming_loss.py @@ -35,7 +35,7 @@ def hamming_loss(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0 Computes the average Hamming loss or `Hamming distance `_ between targets and predictions: - .. math:: + .. math:: \text{Hamming loss} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, From 798ec0353a68fa34aaf51064089c9fa0caf7e8b2 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Sun, 6 Dec 2020 16:07:34 +0100 Subject: [PATCH 49/94] Reorder imports --- pytorch_lightning/metrics/classification/__init__.py | 3 +-- pytorch_lightning/metrics/functional/__init__.py | 7 ++----- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/metrics/classification/__init__.py b/pytorch_lightning/metrics/classification/__init__.py index 8a69e0f1c8c8b..15eb3a8b2ad91 100644 --- a/pytorch_lightning/metrics/classification/__init__.py +++ b/pytorch_lightning/metrics/classification/__init__.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.metrics.classification.accuracy import Accuracy -from pytorch_lightning.metrics.classification.hamming_loss import HammingLoss -from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall from pytorch_lightning.metrics.classification.average_precision import AveragePrecision from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix from pytorch_lightning.metrics.classification.f_beta import FBeta, F1 +from pytorch_lightning.metrics.classification.hamming_loss import HammingLoss from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve from pytorch_lightning.metrics.classification.roc import ROC diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index fddaeb71f6e38..f7d5bf7189353 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -24,10 +24,12 @@ stat_scores_multiple_classes, iou, ) +from pytorch_lightning.metrics.functional.accuracy import accuracy from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # TODO: unify metrics between class and functional, add below from pytorch_lightning.metrics.functional.explained_variance import explained_variance from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 +from pytorch_lightning.metrics.functional.hamming_loss import hamming_loss from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error @@ -37,8 +39,3 @@ from pytorch_lightning.metrics.functional.roc import roc from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity from pytorch_lightning.metrics.functional.ssim import ssim - -from pytorch_lightning.metrics.functional.accuracy import accuracy -from pytorch_lightning.metrics.functional.hamming_loss import hamming_loss -from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix -from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 From a7c143e98073a4d95d7ca6f7e151380d2d4acf49 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 7 Dec 2020 19:06:57 +0100 Subject: [PATCH 50/94] Update changelog --- CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ba46ebdc8520..d0442253a4a89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## Unreleased +### Added + +- `Accuracy` metric now generalizes to Top-k accuracy for (multi-dimensional) multi-class inputs using the `top_k` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) + +- `HammingLoss` metric to compute the hamming loss (distance) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) + +### Changed + +- `Accuracy` metrics now computes subset accuracy for multi-label inputs (consistent with scikit-learn's `accuracy_score`) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) + ### Fixed - Fixed `LoggerConnector` to have logged metrics on root device in DP ([#4138](https://github.com/PyTorchLightning/pytorch-lightning/pull/4138)) From 531ae33a98bc1496dc63145b3b937c9b6ca2dbc1 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 7 Dec 2020 19:08:46 +0100 Subject: [PATCH 51/94] Update docstring --- pytorch_lightning/metrics/classification/accuracy.py | 3 ++- pytorch_lightning/metrics/functional/accuracy.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index e2b3f1a520258..029b035194333 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -54,7 +54,8 @@ class has to be correctly predicted across all extra dimension for each sample i If ``"subset"``, than the equivalent of subset accuracy is performed for each sample on the ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension - must be predicted correctly (the ``top_k`` option still applies here). + must be predicted correctly (the ``top_k`` option still applies here). The final score is then + simply the number of totally correctly predicted samples. top_k: Number of highest probability entries for each sample to convert to 1s, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 707e76b480d9c..ee30eeb70fa69 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -87,7 +87,8 @@ class has to be correctly predicted across all extra dimension for each sample i If ``"subset"``, than the equivalent of subset accuracy is performed for each sample on the ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension - must be predicted correctly (the ``top_k`` option still applies here). + must be predicted correctly (the ``top_k`` option still applies here). The final score is then + simply the number of totally correctly predicted samples. top_k: Number of highest probability entries for each sample to convert to 1s, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The From a66cf310a57723486e69a7c5baf6a6abc599d295 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 7 Dec 2020 19:22:40 +0100 Subject: [PATCH 52/94] Update docstring --- pytorch_lightning/metrics/classification/accuracy.py | 12 ++++++------ pytorch_lightning/metrics/functional/accuracy.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 029b035194333..944e874fa5df8 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -29,16 +29,16 @@ class Accuracy(Metric): Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - This metric generalizes to subset accuracy for multilabel data, and similarly for - multi-dimensional multi-class data: for the sample to be counted as correct, the the - class has to be correctly predicted across all extra dimension for each sample in the - ``N`` dimension. Consider using :class:`~pytorch_lightning.metrics.classification.HammingLoss` - is this is not what you want. - For multi-class and multi-dimensional multi-class data with probability predictions, the parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the top-K highest probability items are considered to find the correct label. + This metric generalizes to subset accuracy for multilabel data: for the sample to be counted as + correct, all labels in that sample have to be correctly predicted. Consider using :class:`~pytorch_lightning.metrics.classification.HammingLoss` + is this is not what you want. In multi-dimensional multi-class case the `mdmc_accuracy` parameters + gives you a choice between computing the subset accuracy, or counting each sample on the extra + axis separately. + Accepts all input types listed in :ref:`metrics:Input types`. Args: diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index ee30eeb70fa69..34383c7064350 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -60,16 +60,16 @@ def accuracy( Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - This metric generalizes to subset accuracy for multilabel data, and similarly for - multi-dimensional multi-class data: for the sample to be counted as correct, the the - class has to be correctly predicted across all extra dimension for each sample in the - ``N`` dimension. Consider using :class:`~pytorch_lightning.metrics.classification.HammingLoss` - is this is not what you want. - For multi-class and multi-dimensional multi-class data with probability predictions, the parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the top-K highest probability items are considered to find the correct label. + This metric generalizes to subset accuracy for multilabel data: for the sample to be counted as + correct, all labels in that sample have to be correctly predicted. Consider using :class:`~pytorch_lightning.metrics.classification.HammingLoss` + is this is not what you want. In multi-dimensional multi-class case the `mdmc_accuracy` parameters + gives you a choice between computing the subset accuracy, or counting each sample on the extra + axis separately. + Accepts all input types listed in :ref:`metrics:Input types`. Args: From 89b09f8a2f06af3cd6c4573d17bf4b7c3f4c1d4f Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 7 Dec 2020 19:49:23 +0100 Subject: [PATCH 53/94] Reverse formatting changes for tests --- .../metrics/functional/test_classification.py | 382 ++++++++---------- 1 file changed, 159 insertions(+), 223 deletions(-) diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 7262fbfd728c9..7bb3df9d8e392 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -25,18 +25,15 @@ from pytorch_lightning.metrics.utils import to_onehot, get_num_classes, to_categorical -@pytest.mark.parametrize( - ["sklearn_metric", "torch_metric", "only_binary"], - [ - pytest.param(partial(sk_jaccard_score, average="macro"), iou, False, id="iou"), - pytest.param(partial(sk_precision, average="micro"), precision, False, id="precision"), - pytest.param(partial(sk_recall, average="micro"), recall, False, id="recall"), - pytest.param(sk_roc_auc_score, auroc, True, id="auroc"), - ], -) +@pytest.mark.parametrize(['sklearn_metric', 'torch_metric', 'only_binary'], [ + pytest.param(partial(sk_jaccard_score, average='macro'), iou, False, id='iou'), + pytest.param(partial(sk_precision, average='micro'), precision, False, id='precision'), + pytest.param(partial(sk_recall, average='micro'), recall, False, id='recall'), + pytest.param(sk_roc_auc_score, auroc, True, id='auroc') +]) def test_against_sklearn(sklearn_metric, torch_metric, only_binary): """Compare PL metrics to sklearn version. """ - device = "cuda" if torch.cuda.is_available() else "cpu" + device = 'cuda' if torch.cuda.is_available() else 'cpu' # for metrics with only_binary=False, we try out different combinations of number # of labels in pred and target (also test binary) @@ -47,7 +44,8 @@ def test_against_sklearn(sklearn_metric, torch_metric, only_binary): pred = torch.randint(n_cls_pred, (300,), device=device) target = torch.randint(n_cls_target, (300,), device=device) - sk_score = sklearn_metric(target.cpu().detach().numpy(), pred.cpu().detach().numpy()) + sk_score = sklearn_metric(target.cpu().detach().numpy(), + pred.cpu().detach().numpy()) pl_score = torch_metric(pred, target) # if multi output @@ -60,21 +58,20 @@ def test_against_sklearn(sklearn_metric, torch_metric, only_binary): assert torch.allclose(sk_score, pl_score) -@pytest.mark.parametrize("class_reduction", ["micro", "macro", "weighted"]) -@pytest.mark.parametrize( - ["sklearn_metric", "torch_metric"], - [ - pytest.param(sk_precision, precision, id="precision"), - pytest.param(sk_recall, recall, id="recall"), - ], -) +@pytest.mark.parametrize('class_reduction', ['micro', 'macro', 'weighted']) +@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ + pytest.param(sk_precision, precision, id='precision'), + pytest.param(sk_recall, recall, id='recall'), +]) def test_different_reduction_against_sklearn(class_reduction, sklearn_metric, torch_metric): - """Test metrics where the class_reduction parameter have a correponding - value in sklearn""" - device = "cuda" if torch.cuda.is_available() else "cpu" + """ Test metrics where the class_reduction parameter have a correponding + value in sklearn """ + device = 'cuda' if torch.cuda.is_available() else 'cpu' pred = torch.randint(10, (300,), device=device) target = torch.randint(10, (300,), device=device) - sk_score = sklearn_metric(target.cpu().detach().numpy(), pred.cpu().detach().numpy(), average=class_reduction) + sk_score = sklearn_metric(target.cpu().detach().numpy(), + pred.cpu().detach().numpy(), + average=class_reduction) sk_score = torch.tensor(sk_score, dtype=torch.float, device=device) pl_score = torch_metric(pred, target, class_reduction=class_reduction) assert torch.allclose(sk_score, pl_score) @@ -82,12 +79,10 @@ def test_different_reduction_against_sklearn(class_reduction, sklearn_metric, to def test_onehot(): test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) - expected = torch.stack( - [ - torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), - torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]), - ] - ) + expected = torch.stack([ + torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), + torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]) + ]) assert test_tensor.shape == (2, 5) assert expected.shape == (2, 10, 5) @@ -105,12 +100,10 @@ def test_onehot(): def test_to_categorical(): - test_tensor = torch.stack( - [ - torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), - torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]), - ] - ).to(torch.float) + test_tensor = torch.stack([ + torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), + torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]) + ]).to(torch.float) expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) assert expected.shape == (2, 5) @@ -122,25 +115,20 @@ def test_to_categorical(): assert torch.allclose(result, expected.to(result.dtype)) -@pytest.mark.parametrize( - ["pred", "target", "num_classes", "expected_num_classes"], - [ - pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10), - pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), - pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), - ], -) +@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [ + pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10), + pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), + pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), +]) def test_get_num_classes(pred, target, num_classes, expected_num_classes): assert get_num_classes(pred, target, num_classes) == expected_num_classes -@pytest.mark.parametrize( - ["pred", "target", "expected_tp", "expected_fp", "expected_tn", "expected_fn", "expected_support"], - [ - pytest.param(torch.tensor([0.0, 2.0, 4.0, 4.0]), torch.tensor([0.0, 4.0, 3.0, 4.0]), 1, 1, 1, 1, 2), - pytest.param(to_onehot(torch.tensor([0.0, 2.0, 4.0, 4.0])), torch.tensor([0.0, 4.0, 3.0, 4.0]), 1, 1, 1, 1, 2), - ], -) +@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', + 'expected_tn', 'expected_fn', 'expected_support'], [ + pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2), + pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2) +]) def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support): tp, fp, tn, fn, sup = stat_scores(pred, target, class_index=4) @@ -151,54 +139,18 @@ def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expect assert sup.item() == expected_support -@pytest.mark.parametrize( - ["pred", "target", "reduction", "expected_tp", "expected_fp", "expected_tn", "expected_fn", "expected_support"], - [ - pytest.param( - torch.tensor([0.0, 2.0, 4.0, 4.0]), - torch.tensor([0.0, 4.0, 3.0, 4.0]), - "none", - [1, 0, 0, 0, 1], - [0, 0, 1, 0, 1], - [3, 4, 3, 3, 1], - [0, 0, 0, 1, 1], - [1, 0, 0, 1, 2], - ), - pytest.param( - to_onehot(torch.tensor([0.0, 2.0, 4.0, 4.0])), - torch.tensor([0.0, 4.0, 3.0, 4.0]), - "none", - [1, 0, 0, 0, 1], - [0, 0, 1, 0, 1], - [3, 4, 3, 3, 1], - [0, 0, 0, 1, 1], - [1, 0, 0, 1, 2], - ), - pytest.param( - to_onehot(torch.tensor([0.0, 2.0, 4.0, 4.0])), - torch.tensor([0.0, 4.0, 3.0, 4.0]), - "sum", - torch.tensor(2), - torch.tensor(2), - torch.tensor(14), - torch.tensor(2), - torch.tensor(4), - ), - pytest.param( - to_onehot(torch.tensor([0.0, 2.0, 4.0, 4.0])), - torch.tensor([0.0, 4.0, 3.0, 4.0]), - "elementwise_mean", - torch.tensor(0.4), - torch.tensor(0.4), - torch.tensor(2.8), - torch.tensor(0.4), - torch.tensor(0.8), - ), - ], -) -def test_stat_scores_multiclass( - pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support -): +@pytest.mark.parametrize(['pred', 'target', 'reduction', 'expected_tp', 'expected_fp', + 'expected_tn', 'expected_fn', 'expected_support'], [ + pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 'none', + [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]), + pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'none', + [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]), + pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'sum', + torch.tensor(2), torch.tensor(2), torch.tensor(14), torch.tensor(2), torch.tensor(4)), + pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'elementwise_mean', + torch.tensor(0.4), torch.tensor(0.4), torch.tensor(2.8), torch.tensor(0.4), torch.tensor(0.8)) +]) +def test_stat_scores_multiclass(pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support): tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction=reduction) assert torch.allclose(torch.tensor(expected_tp).to(tp), tp) @@ -208,30 +160,22 @@ def test_stat_scores_multiclass( assert torch.allclose(torch.tensor(expected_support).to(sup), sup) -@pytest.mark.parametrize( - ["pred", "target", "expected_prec", "expected_rec"], - [ - pytest.param(torch.tensor([1.0, 0.0, 1.0, 0.0]), torch.tensor([0.0, 1.0, 1.0, 0.0]), [0.5, 0.5], [0.5, 0.5]), - pytest.param( - to_onehot(torch.tensor([1.0, 0.0, 1.0, 0.0])), torch.tensor([0.0, 1.0, 1.0, 0.0]), [0.5, 0.5], [0.5, 0.5] - ), - ], -) +@pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [ + pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]), + pytest.param(to_onehot(torch.tensor([1., 0., 1., 0.])), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]) +]) def test_precision_recall(pred, target, expected_prec, expected_rec): - prec = precision(pred, target, class_reduction="none") - rec = recall(pred, target, class_reduction="none") + prec = precision(pred, target, class_reduction='none') + rec = recall(pred, target, class_reduction='none') assert torch.allclose(torch.tensor(expected_prec).to(prec), prec) assert torch.allclose(torch.tensor(expected_rec).to(rec), rec) -@pytest.mark.parametrize( - ["sample_weight", "pos_label", "exp_shape"], - [ - pytest.param(1, 1.0, 42), - pytest.param(None, 1.0, 42), - ], -) +@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [ + pytest.param(1, 1., 42), + pytest.param(None, 1., 42), +]) def test_binary_clf_curve(sample_weight, pos_label, exp_shape): # TODO: move back the pred and target to test func arguments # if you fix the array inside the function, you'd also have fix the shape, @@ -252,94 +196,90 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape): assert thresh.shape == (exp_shape,) -@pytest.mark.parametrize( - ["pred", "target", "expected"], - [ - pytest.param([0, 1, 0, 1], [0, 1, 0, 1], 1.0), - pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.0), - pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5), - pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.0), - pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.5), - ], -) +@pytest.mark.parametrize(['pred', 'target', 'expected'], [ + pytest.param([0, 1, 0, 1], [0, 1, 0, 1], 1.), + pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.), + pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5), + pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.), + pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.5), +]) def test_auroc(pred, target, expected): score = auroc(torch.tensor(pred), torch.tensor(target)).item() assert score == expected def test_multiclass_auroc(): - with pytest.raises(ValueError, match=r".*probabilities, i.e. they should sum up to 1.0 over classes"): - _ = multiclass_auroc(pred=torch.tensor([[0.9, 0.9], [1.0, 0]]), target=torch.tensor([0, 1])) + with pytest.raises(ValueError, + match=r".*probabilities, i.e. they should sum up to 1.0 over classes"): + _ = multiclass_auroc(pred=torch.tensor([[0.9, 0.9], + [1.0, 0]]), + target=torch.tensor([0, 1])) - with pytest.raises(ValueError, match=r".*not defined when all of the classes do not occur in the target.*"): - _ = multiclass_auroc(pred=torch.rand((4, 3)).softmax(dim=1), target=torch.tensor([1, 0, 1, 0])) + with pytest.raises(ValueError, + match=r".*not defined when all of the classes do not occur in the target.*"): + _ = multiclass_auroc(pred=torch.rand((4, 3)).softmax(dim=1), + target=torch.tensor([1, 0, 1, 0])) - with pytest.raises(ValueError, match=r".*does not equal the number of classes passed in 'num_classes'.*"): - _ = multiclass_auroc( - pred=torch.rand((5, 4)).softmax(dim=1), target=torch.tensor([0, 1, 2, 2, 3]), num_classes=6 - ) + with pytest.raises(ValueError, + match=r".*does not equal the number of classes passed in 'num_classes'.*"): + _ = multiclass_auroc(pred=torch.rand((5, 4)).softmax(dim=1), + target=torch.tensor([0, 1, 2, 2, 3]), + num_classes=6) -@pytest.mark.parametrize("n_cls", [2, 5, 10, 50]) +@pytest.mark.parametrize('n_cls', [2, 5, 10, 50]) def test_multiclass_auroc_against_sklearn(n_cls): - device = "cuda" if torch.cuda.is_available() else "cpu" + device = 'cuda' if torch.cuda.is_available() else 'cpu' n_samples = 300 pred = torch.rand(n_samples, n_cls, device=device).softmax(dim=1) target = torch.randint(n_cls, (n_samples,), device=device) # Make sure target includes all class labels so that multiclass AUROC is defined - target[10 : 10 + n_cls] = torch.arange(n_cls) + target[10:10 + n_cls] = torch.arange(n_cls) pl_score = multiclass_auroc(pred, target) # For the binary case, sklearn expects an (n_samples,) array of probabilities of # the positive class pred = pred[:, 1] if n_cls == 2 else pred - sk_score = sk_roc_auc_score(target.cpu().detach().numpy(), pred.cpu().detach().numpy(), multi_class="ovr") + sk_score = sk_roc_auc_score(target.cpu().detach().numpy(), + pred.cpu().detach().numpy(), + multi_class="ovr") sk_score = torch.tensor(sk_score, dtype=torch.float, device=device) assert torch.allclose(sk_score, pl_score) -@pytest.mark.parametrize( - ["x", "y", "expected"], - [ - pytest.param([0, 1], [0, 1], 0.5), - pytest.param([1, 0], [0, 1], 0.5), - pytest.param([1, 0, 0], [0, 1, 1], 0.5), - pytest.param([0, 1], [1, 1], 1), - pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5), - ], -) +@pytest.mark.parametrize(['x', 'y', 'expected'], [ + pytest.param([0, 1], [0, 1], 0.5), + pytest.param([1, 0], [0, 1], 0.5), + pytest.param([1, 0, 0], [0, 1, 1], 0.5), + pytest.param([0, 1], [1, 1], 1), + pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5), +]) def test_auc(x, y, expected): # Test Area Under Curve (AUC) computation assert auc(torch.tensor(x), torch.tensor(y)) == expected -@pytest.mark.parametrize( - ["pred", "target", "expected"], - [ - pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.0), - pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.0), - pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3), - pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.0), - ], -) +@pytest.mark.parametrize(['pred', 'target', 'expected'], [ + pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.), + pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.), + pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3), + pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.), +]) def test_dice_score(pred, target, expected): score = dice_score(torch.tensor(pred), torch.tensor(target)) assert score == expected -@pytest.mark.parametrize( - ["half_ones", "reduction", "ignore_index", "expected"], - [ - pytest.param(False, "none", None, torch.Tensor([1, 1, 1])), - pytest.param(False, "elementwise_mean", None, torch.Tensor([1])), - pytest.param(False, "none", 0, torch.Tensor([1, 1])), - pytest.param(True, "none", None, torch.Tensor([0.5, 0.5, 0.5])), - pytest.param(True, "elementwise_mean", None, torch.Tensor([0.5])), - pytest.param(True, "none", 0, torch.Tensor([0.5, 0.5])), - ], -) +@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [ + pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])), + pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])), + pytest.param(False, 'none', 0, torch.Tensor([1, 1])), + pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])), + pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])), + pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])), +]) def test_iou(half_ones, reduction, ignore_index, expected): pred = (torch.arange(120) % 3).view(-1, 1) target = (torch.arange(120) % 3).view(-1, 1) @@ -356,17 +296,19 @@ def test_iou(half_ones, reduction, ignore_index, expected): def test_iou_input_check(): with pytest.raises(ValueError, match=r"'pred' shape (.*) must equal 'target' shape (.*)"): - _ = iou(pred=torch.randint(0, 2, (3, 4, 3)), target=torch.randint(0, 2, (3, 3))) + _ = iou(pred=torch.randint(0, 2, (3, 4, 3)), + target=torch.randint(0, 2, (3, 3))) with pytest.raises(ValueError, match="'pred' must contain integer targets."): - _ = iou(pred=torch.rand((3, 3)), target=torch.randint(0, 2, (3, 3))) + _ = iou(pred=torch.rand((3, 3)), + target=torch.randint(0, 2, (3, 3))) -@pytest.mark.parametrize("metric", [auroc]) +@pytest.mark.parametrize('metric', [auroc]) def test_error_on_multiclass_input(metric): """ check that these metrics raise an error if they are used for multiclass problems """ - pred = torch.randint(0, 10, (100,)) - target = torch.randint(0, 10, (100,)) + pred = torch.randint(0, 10, (100, )) + target = torch.randint(0, 10, (100, )) with pytest.raises(ValueError, match="AUROC metric is meant for binary classification"): _ = metric(pred, target) @@ -374,38 +316,35 @@ def test_error_on_multiclass_input(metric): # TODO: When the jaccard_score of the sklearn version we use accepts `zero_division` (see # https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our # `absent_score`. -@pytest.mark.parametrize( - ["pred", "target", "ignore_index", "absent_score", "num_classes", "expected"], - [ - # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid - # scores the function can return ([0., 1.] range, inclusive). - # 2 classes, class 0 is correct everywhere, class 1 is absent. - pytest.param([0], [0], None, -1.0, 2, [1.0, -1.0]), - pytest.param([0, 0], [0, 0], None, -1.0, 2, [1.0, -1.0]), - # absent_score not applied if only class 0 is present and it's the only class. - pytest.param([0], [0], None, -1.0, 1, [1.0]), - # 2 classes, class 1 is correct everywhere, class 0 is absent. - pytest.param([1], [1], None, -1.0, 2, [-1.0, 1.0]), - pytest.param([1, 1], [1, 1], None, -1.0, 2, [-1.0, 1.0]), - # When 0 index ignored, class 0 does not get a score (not even the absent_score). - pytest.param([1], [1], 0, -1.0, 2, [1.0]), - # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score. - pytest.param([0, 2], [0, 2], None, -1.0, 3, [1.0, -1.0, 1.0]), - pytest.param([2, 0], [2, 0], None, -1.0, 3, [1.0, -1.0, 1.0]), - # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score. - pytest.param([0, 1], [0, 1], None, -1.0, 3, [1.0, 1.0, -1.0]), - pytest.param([1, 0], [1, 0], None, -1.0, 3, [1.0, 1.0, -1.0]), - # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class - # 2 is absent. - pytest.param([0, 1], [0, 0], None, -1.0, 3, [0.5, 0.0, -1.0]), - # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class - # 2 is absent. - pytest.param([0, 0], [0, 1], None, -1.0, 3, [0.5, 0.0, -1.0]), - # Sanity checks with absent_score of 1.0. - pytest.param([0, 2], [0, 2], None, 1.0, 3, [1.0, 1.0, 1.0]), - pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1.0, 1.0]), - ], -) +@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], [ + # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid + # scores the function can return ([0., 1.] range, inclusive). + # 2 classes, class 0 is correct everywhere, class 1 is absent. + pytest.param([0], [0], None, -1., 2, [1., -1.]), + pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]), + # absent_score not applied if only class 0 is present and it's the only class. + pytest.param([0], [0], None, -1., 1, [1.]), + # 2 classes, class 1 is correct everywhere, class 0 is absent. + pytest.param([1], [1], None, -1., 2, [-1., 1.]), + pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]), + # When 0 index ignored, class 0 does not get a score (not even the absent_score). + pytest.param([1], [1], 0, -1., 2, [1.0]), + # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score. + pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]), + pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]), + # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score. + pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]), + pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]), + # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class + # 2 is absent. + pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]), + # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class + # 2 is absent. + pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]), + # Sanity checks with absent_score of 1.0. + pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]), + pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]), +]) def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected): iou_val = iou( pred=torch.tensor(pred), @@ -413,29 +352,26 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, ignore_index=ignore_index, absent_score=absent_score, num_classes=num_classes, - reduction="none", + reduction='none', ) assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) # example data taken from # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py -@pytest.mark.parametrize( - ["pred", "target", "ignore_index", "num_classes", "reduction", "expected"], - [ - # Ignoring an index outside of [0, num_classes-1] should have no effect. - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, "none", [1, 1 / 2, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, "none", [1, 1 / 2, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, "none", [1, 1 / 2, 2 / 3]), - # Ignoring a valid index drops only that index from the result. - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "none", [1 / 2, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, "none", [1, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, "none", [1, 1 / 2]), - # When reducing to mean or sum, the ignored index does not contribute to the output. - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "elementwise_mean", [7 / 12]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "sum", [7 / 6]), - ], -) +@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [ + # Ignoring an index outside of [0, num_classes-1] should have no effect. + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]), + # Ignoring a valid index drops only that index from the result. + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]), + # When reducing to mean or sum, the ignored index does not contribute to the output. + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]), +]) def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected): iou_val = iou( pred=torch.tensor(pred), @@ -444,4 +380,4 @@ def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, ex num_classes=num_classes, reduction=reduction, ) - assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) + assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) \ No newline at end of file From e7154377be8adbfcf1b682ae2930a555827b88d4 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 7 Dec 2020 19:50:48 +0100 Subject: [PATCH 54/94] Change parameter order --- .../metrics/classification/accuracy.py | 14 +++++++------- pytorch_lightning/metrics/functional/accuracy.py | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 944e874fa5df8..7c4bcb247c766 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -45,6 +45,12 @@ class Accuracy(Metric): threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + top_k: + Number of highest probability entries for each sample to convert to 1s, relevant + only for (multi-dimensional) multi-class inputs with probability predictions. The + default value (``None``) will be interpreted as 1 for these inputs. + + Should be left at default (``None``) for all other types of inputs. mdmc_accuracy: Determines how should the extra dimension be handeled in case of multi-dimensional multi-class inputs. Options are ``"global"`` or ``"subset"``. @@ -56,12 +62,6 @@ class Accuracy(Metric): ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension must be predicted correctly (the ``top_k`` option still applies here). The final score is then simply the number of totally correctly predicted samples. - top_k: - Number of highest probability entries for each sample to convert to 1s, relevant - only for (multi-dimensional) multi-class inputs with probability predictions. The - default value (``None``) will be interpreted as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -93,8 +93,8 @@ class Accuracy(Metric): def __init__( self, threshold: float = 0.5, - mdmc_accuracy: str = "subset", top_k: Optional[int] = None, + mdmc_accuracy: str = "subset", compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 34383c7064350..ca88213706c4f 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -48,8 +48,8 @@ def accuracy( preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5, - mdmc_accuracy: str = "subset", top_k: Optional[int] = None, + mdmc_accuracy: str = "subset", ) -> torch.Tensor: r""" Computes `Accuracy `_: @@ -78,6 +78,12 @@ def accuracy( threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + top_k: + Number of highest probability entries for each sample to convert to 1s, relevant + only for (multi-dimensional) multi-class inputs with probability predictions. The + default value (``None``) will be interpreted as 1 for these inputs. + + Should be left at default (``None``) for all other types of inputs. mdmc_accuracy: Determines how should the extra dimension be handeled in case of multi-dimensional multi-class inputs. Options are ``"global"`` or ``"subset"``. @@ -89,12 +95,6 @@ def accuracy( ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension must be predicted correctly (the ``top_k`` option still applies here). The final score is then simply the number of totally correctly predicted samples. - top_k: - Number of highest probability entries for each sample to convert to 1s, relevant - only for (multi-dimensional) multi-class inputs with probability predictions. The - default value (``None``) will be interpreted as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. Example: From d5daec8013168a5276aaaaa72e01ed6c2b94588e Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 7 Dec 2020 19:53:49 +0100 Subject: [PATCH 55/94] Remove formatting changes 2/2 --- .../metrics/functional/classification.py | 449 +++++++++++++----- 1 file changed, 335 insertions(+), 114 deletions(-) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index afb3b0f90803d..a346a75be0ce8 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -15,17 +15,102 @@ from typing import Callable, Optional, Sequence, Tuple import torch +from pytorch_lightning.metrics.functional.reduction import class_reduce, reduce from torch.nn import functional as F -from pytorch_lightning.metrics.utils import to_categorical, get_num_classes, reduce, class_reduce from pytorch_lightning.utilities import rank_zero_warn +def to_onehot( + tensor: torch.Tensor, + num_classes: Optional[int] = None, +) -> torch.Tensor: + """ + Converts a dense label tensor to one-hot format + + Args: + tensor: dense label tensor, with shape [N, d1, d2, ...] + num_classes: number of classes C + + Output: + A sparse label tensor with shape [N, C, d1, d2, ...] + + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> to_onehot(x) + tensor([[0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + + """ + if num_classes is None: + num_classes = int(tensor.max().detach().item() + 1) + dtype, device, shape = tensor.dtype, tensor.device, tensor.shape + tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], + dtype=dtype, device=device) + index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) + return tensor_onehot.scatter_(1, index, 1.0) + + +def to_categorical( + tensor: torch.Tensor, + argmax_dim: int = 1 +) -> torch.Tensor: + """ + Converts a tensor of probabilities to a dense label tensor + + Args: + tensor: probabilities to get the categorical label [N, d1, d2, ...] + argmax_dim: dimension to apply + + Return: + A tensor with categorical labels [N, d2, ...] + + Example: + + >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) + >>> to_categorical(x) + tensor([1, 0]) + + """ + return torch.argmax(tensor, dim=argmax_dim) + + +def get_num_classes( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, +) -> int: + """ + Calculates the number of classes for a given prediction and target tensor. + + Args: + pred: predicted values + target: true labels + num_classes: number of classes if known + + Return: + An integer that represents the number of classes. + """ + num_target_classes = int(target.max().detach().item() + 1) + num_pred_classes = int(pred.max().detach().item() + 1) + num_all_classes = max(num_target_classes, num_pred_classes) + + if num_classes is None: + num_classes = num_all_classes + elif num_classes != num_all_classes: + rank_zero_warn(f'You have set {num_classes} number of classes which is' + f' different from predicted ({num_pred_classes}) and' + f' target ({num_target_classes}) number of classes', + RuntimeWarning) + return num_classes + + def stat_scores( - pred: torch.Tensor, - target: torch.Tensor, - class_index: int, - argmax_dim: int = 1, + pred: torch.Tensor, + target: torch.Tensor, + class_index: int, argmax_dim: int = 1, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Calculates the number of true positive, false positive, true negative @@ -63,11 +148,11 @@ def stat_scores( def stat_scores_multiple_classes( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - argmax_dim: int = 1, - reduction: str = "none", + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + argmax_dim: int = 1, + reduction: str = 'none', ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Calculates the number of true positive, false positive, true negative @@ -116,13 +201,13 @@ def stat_scores_multiple_classes( if target.dtype != torch.bool: target = target.clamp_max(max=num_classes) - possible_reductions = ("none", "sum", "elementwise_mean") + possible_reductions = ('none', 'sum', 'elementwise_mean') if reduction not in possible_reductions: raise ValueError("reduction type %s not supported" % reduction) - if reduction == "none": - pred = pred.view((-1,)).long() - target = target.view((-1,)).long() + if reduction == 'none': + pred = pred.view((-1, )).long() + target = target.view((-1, )).long() tps = torch.zeros((num_classes + 1,), device=pred.device) fps = torch.zeros((num_classes + 1,), device=pred.device) @@ -145,7 +230,7 @@ def stat_scores_multiple_classes( fns = fns[:num_classes] sups = sups[:num_classes] - elif reduction == "sum" or reduction == "elementwise_mean": + elif reduction == 'sum' or reduction == 'elementwise_mean': count_match_true = (pred == target).sum().float() oob_tp, oob_fp, oob_tn, oob_fn, oob_sup = stat_scores(pred, target, num_classes, argmax_dim) @@ -155,7 +240,7 @@ def stat_scores_multiple_classes( tns = pred.nelement() * (num_classes + 1) - (tps + fps + fns + oob_tn) sups = pred.nelement() - oob_sup.float() - if reduction == "elementwise_mean": + if reduction == 'elementwise_mean': tps /= num_classes fps /= num_classes fns /= num_classes @@ -171,17 +256,17 @@ def _confmat_normalize(cm): nan_elements = cm[torch.isnan(cm)].nelement() if nan_elements != 0: cm[torch.isnan(cm)] = 0 - rank_zero_warn(f"{nan_elements} nan values found in confusion matrix have been replaced with zeros.") + rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.') return cm def precision_recall( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = "micro", - return_support: bool = False, - return_state: bool = False, + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + class_reduction: str = 'micro', + return_support: bool = False, + return_state: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes precision and recall for different thresholds @@ -217,17 +302,17 @@ def precision_recall( precision = class_reduce(tps, tps + fps, sups, class_reduction=class_reduction) recall = class_reduce(tps, tps + fns, sups, class_reduction=class_reduction) if return_state: - return {"tps": tps, "fps": fps, "fns": fns, "sups": sups} + return {'tps': tps, 'fps': fps, 'fns': fns, 'sups': sups} if return_support: return precision, recall, sups return precision, recall def precision( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = "micro", + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + class_reduction: str = 'micro', ) -> torch.Tensor: """ Computes precision score. @@ -254,14 +339,15 @@ def precision( tensor(0.7500) """ - return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[0] + return precision_recall(pred=pred, target=target, + num_classes=num_classes, class_reduction=class_reduction)[0] def recall( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = "micro", + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + class_reduction: str = 'micro', ) -> torch.Tensor: """ Computes recall score. @@ -287,14 +373,15 @@ def recall( >>> recall(x, y) tensor(0.7500) """ - return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[1] + return precision_recall(pred=pred, target=target, + num_classes=num_classes, class_reduction=class_reduction)[1] def _binary_clf_curve( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1.0, + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py @@ -313,7 +400,7 @@ def _binary_clf_curve( if sample_weight is not None: weight = sample_weight[desc_score_indices] else: - weight = 1.0 + weight = 1. # pred typically has many tied values. Here we extract # the indices associated with the distinct values. We also @@ -334,18 +421,15 @@ def _binary_clf_curve( return fps, tps, pred[threshold_idxs] -# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py -def __roc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1.0, +def roc( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. - .. warning:: Deprecated - Args: pred: estimated probabilities target: ground-truth labels @@ -359,7 +443,7 @@ def __roc( >>> x = torch.tensor([0, 1, 2, 3]) >>> y = torch.tensor([0, 1, 1, 1]) - >>> fpr, tpr, thresholds = __roc(x, y) + >>> fpr, tpr, thresholds = roc(x, y) >>> fpr tensor([0., 0., 0., 0., 1.]) >>> tpr @@ -368,7 +452,9 @@ def __roc( tensor([4, 3, 2, 1, 0]) """ - fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) + fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, + sample_weight=sample_weight, + pos_label=pos_label) # Add an extra threshold position # to make sure that the curve starts at (0, 0) @@ -389,18 +475,15 @@ def __roc( return fpr, tpr, thresholds -# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py -def __multiclass_roc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - num_classes: Optional[int] = None, +def multiclass_roc( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + num_classes: Optional[int] = None, ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """ Computes the Receiver Operating Characteristic (ROC) for multiclass predictors. - .. warning:: Deprecated - Args: pred: estimated probabilities target: ground-truth labels @@ -418,7 +501,7 @@ def __multiclass_roc( ... [0.05, 0.05, 0.85, 0.05], ... [0.05, 0.05, 0.05, 0.85]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> __multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE + >>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])), @@ -430,12 +513,124 @@ def __multiclass_roc( for c in range(num_classes): pred_c = pred[:, c] - class_roc_vals.append(__roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c)) + class_roc_vals.append(roc(pred=pred_c, target=target, + sample_weight=sample_weight, pos_label=c)) return tuple(class_roc_vals) -def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True) -> torch.Tensor: +def precision_recall_curve( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes precision-recall pairs for different thresholds. + + Args: + pred: estimated probabilities + target: ground-truth labels + sample_weight: sample weights + pos_label: the label for the positive class + + Return: + precision, recall, thresholds + + Example: + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> precision, recall, thresholds = precision_recall_curve(pred, target) + >>> precision + tensor([0.6667, 0.5000, 0.0000, 1.0000]) + >>> recall + tensor([1.0000, 0.5000, 0.0000, 0.0000]) + >>> thresholds + tensor([1, 2, 3]) + + """ + fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, + sample_weight=sample_weight, + pos_label=pos_label) + + precision = tps / (tps + fps) + recall = tps / tps[-1] + + # stop when full recall attained + # and reverse the outputs so recall is decreasing + last_ind = torch.where(tps == tps[-1])[0][0] + sl = slice(0, last_ind.item() + 1) + + # need to call reversed explicitly, since including that to slice would + # introduce negative strides that are not yet supported in pytorch + precision = torch.cat([reversed(precision[sl]), + torch.ones(1, dtype=precision.dtype, + device=precision.device)]) + + recall = torch.cat([reversed(recall[sl]), + torch.zeros(1, dtype=recall.dtype, + device=recall.device)]) + + thresholds = torch.tensor(reversed(thresholds[sl])) + + return precision, recall, thresholds + + +def multiclass_precision_recall_curve( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + num_classes: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes precision-recall pairs for different thresholds given a multiclass scores. + + Args: + pred: estimated probabilities + target: ground-truth labels + sample_weight: sample weight + num_classes: number of classes + + Return: + number of classes, precision, recall, thresholds + + Example: + + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> nb_classes, precision, recall, thresholds = multiclass_precision_recall_curve(pred, target) + >>> nb_classes + (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])) + >>> precision + (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])) + >>> recall + (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])) + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])) + """ + num_classes = get_num_classes(pred, target, num_classes) + + class_pr_vals = [] + for c in range(num_classes): + pred_c = pred[:, c] + + class_pr_vals.append(precision_recall_curve( + pred=pred_c, + target=target, + sample_weight=sample_weight, pos_label=c)) + + return tuple(class_pr_vals) + + +def auc( + x: torch.Tensor, + y: torch.Tensor, + reorder: bool = True +) -> torch.Tensor: """ Computes Area Under the Curve (AUC) using the trapezoidal rule @@ -456,16 +651,14 @@ def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True) -> torch.Tensor: >>> auc(x, y) tensor(4.) """ - direction = 1.0 + direction = 1. if reorder: - rank_zero_warn( - "The `reorder` parameter to `auc` has been deprecated and will be removed in v1.1" - " Note that when `reorder` is True, the unstable algorithm of torch.argsort is" - " used internally to sort 'x' which may in some cases cause inaccuracies" - " in the result.", - DeprecationWarning, - ) + rank_zero_warn("The `reorder` parameter to `auc` has been deprecated and will be removed in v1.1" + " Note that when `reorder` is True, the unstable algorithm of torch.argsort is" + " used internally to sort 'x' which may in some cases cause inaccuracies" + " in the result.", + DeprecationWarning) # can't use lexsort here since it is not implemented for torch order = torch.argsort(x) x, y = x[order], y[order] @@ -473,12 +666,11 @@ def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True) -> torch.Tensor: dx = x[1:] - x[:-1] if (dx < 0).any(): if (dx, 0).all(): - direction = -1.0 + direction = -1. else: # TODO: Update message on removing reorder - raise ValueError( - "Reorder is not turned on, and the 'x' array is" f" neither increasing or decreasing: {x}" - ) + raise ValueError("Reorder is not turned on, and the 'x' array is" + f" neither increasing or decreasing: {x}") return direction * torch.trapz(y, x) @@ -513,10 +705,10 @@ def new_func(*args, **kwargs) -> torch.Tensor: def auroc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1.0, + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., ) -> torch.Tensor: """ Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores @@ -538,24 +730,22 @@ def auroc( tensor(0.5000) """ if any(target > 1): - raise ValueError( - "AUROC metric is meant for binary classification, but" - " target tensor contains value different from 0 and 1." - " Use `multiclass_auroc` for multi class classification." - ) + raise ValueError('AUROC metric is meant for binary classification, but' + ' target tensor contains value different from 0 and 1.' + ' Use `multiclass_auroc` for multi class classification.') @auc_decorator(reorder=True) def _auroc(pred, target, sample_weight, pos_label): - return __roc(pred, target, sample_weight, pos_label) + return roc(pred, target, sample_weight, pos_label) return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) def multiclass_auroc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - num_classes: Optional[int] = None, + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + num_classes: Optional[int] = None, ) -> torch.Tensor: """ Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from multiclass @@ -577,44 +767,77 @@ def multiclass_auroc( ... [0.05, 0.05, 0.85, 0.05], ... [0.05, 0.05, 0.05, 0.85]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> multiclass_auroc(pred, target, num_classes=4) + >>> multiclass_auroc(pred, target) # doctest: +NORMALIZE_WHITESPACE tensor(0.6667) """ if not torch.allclose(pred.sum(dim=1), torch.tensor(1.0)): raise ValueError( "Multiclass AUROC metric expects the target scores to be" - " probabilities, i.e. they should sum up to 1.0 over classes" - ) + " probabilities, i.e. they should sum up to 1.0 over classes") if torch.unique(target).size(0) != pred.size(1): raise ValueError( f"Number of classes found in in 'target' ({torch.unique(target).size(0)})" f" does not equal the number of columns in 'pred' ({pred.size(1)})." " Multiclass AUROC is not defined when all of the classes do not" - " occur in the target labels." - ) + " occur in the target labels.") if num_classes is not None and num_classes != pred.size(1): raise ValueError( f"Number of classes deduced from 'pred' ({pred.size(1)}) does not equal" - f" the number of classes passed in 'num_classes' ({num_classes})." - ) + f" the number of classes passed in 'num_classes' ({num_classes}).") @multiclass_auc_decorator(reorder=False) def _multiclass_auroc(pred, target, sample_weight, num_classes): - return __multiclass_roc(pred, target, sample_weight, num_classes) + return multiclass_roc(pred, target, sample_weight, num_classes) - class_aurocs = _multiclass_auroc(pred=pred, target=target, sample_weight=sample_weight, num_classes=num_classes) + class_aurocs = _multiclass_auroc(pred=pred, target=target, + sample_weight=sample_weight, + num_classes=num_classes) return torch.mean(class_aurocs) +def average_precision( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., +) -> torch.Tensor: + """ + Compute average precision from prediction scores + + Args: + pred: estimated probabilities + target: ground-truth labels + sample_weight: sample weights + pos_label: the label for the positive class + + Return: + Tensor containing average precision score + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> average_precision(x, y) + tensor(0.3333) + """ + precision, recall, _ = precision_recall_curve(pred=pred, target=target, + sample_weight=sample_weight, + pos_label=pos_label) + # Return the step function integral + # The following works because the last entry of precision is + # guaranteed to be 1, as returned by precision_recall_curve + return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) + + def dice_score( - pred: torch.Tensor, - target: torch.Tensor, - bg: bool = False, - nan_score: float = 0.0, - no_fg_score: float = 0.0, - reduction: str = "elementwise_mean", + pred: torch.Tensor, + target: torch.Tensor, + bg: bool = False, + nan_score: float = 0.0, + no_fg_score: float = 0.0, + reduction: str = 'elementwise_mean', ) -> torch.Tensor: """ Compute dice score from prediction scores @@ -646,7 +869,7 @@ def dice_score( """ num_classes = pred.shape[1] - bg = 1 - int(bool(bg)) + bg = (1 - int(bool(bg))) scores = torch.zeros(num_classes - bg, device=pred.device, dtype=torch.float32) for i in range(bg, num_classes): if not (target == i).any(): @@ -664,12 +887,12 @@ def dice_score( def iou( - pred: torch.Tensor, - target: torch.Tensor, - ignore_index: Optional[int] = None, - absent_score: float = 0.0, - num_classes: Optional[int] = None, - reduction: str = "elementwise_mean", + pred: torch.Tensor, + target: torch.Tensor, + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + num_classes: Optional[int] = None, + reduction: str = 'elementwise_mean', ) -> torch.Tensor: """ Intersection over union, or Jaccard index calculation. @@ -741,11 +964,9 @@ def iou( # Remove the ignored class index from the scores. if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes: - scores = torch.cat( - [ - scores[:ignore_index], - scores[ignore_index + 1 :], - ] - ) + scores = torch.cat([ + scores[:ignore_index], + scores[ignore_index + 1:], + ]) return reduce(scores, reduction=reduction) From c820060b2a42698be8859f0bd3bebfdb3e8a66c0 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 7 Dec 2020 19:56:03 +0100 Subject: [PATCH 56/94] Remove formatting 3/3 --- .../metrics/functional/classification.py | 252 ++---------------- 1 file changed, 15 insertions(+), 237 deletions(-) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index a346a75be0ce8..1ffeb14fdee82 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -15,98 +15,12 @@ from typing import Callable, Optional, Sequence, Tuple import torch -from pytorch_lightning.metrics.functional.reduction import class_reduce, reduce from torch.nn import functional as F +from pytorch_lightning.metrics.utils import to_categorical, get_num_classes, reduce, class_reduce from pytorch_lightning.utilities import rank_zero_warn -def to_onehot( - tensor: torch.Tensor, - num_classes: Optional[int] = None, -) -> torch.Tensor: - """ - Converts a dense label tensor to one-hot format - - Args: - tensor: dense label tensor, with shape [N, d1, d2, ...] - num_classes: number of classes C - - Output: - A sparse label tensor with shape [N, C, d1, d2, ...] - - Example: - - >>> x = torch.tensor([1, 2, 3]) - >>> to_onehot(x) - tensor([[0, 1, 0, 0], - [0, 0, 1, 0], - [0, 0, 0, 1]]) - - """ - if num_classes is None: - num_classes = int(tensor.max().detach().item() + 1) - dtype, device, shape = tensor.dtype, tensor.device, tensor.shape - tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], - dtype=dtype, device=device) - index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) - return tensor_onehot.scatter_(1, index, 1.0) - - -def to_categorical( - tensor: torch.Tensor, - argmax_dim: int = 1 -) -> torch.Tensor: - """ - Converts a tensor of probabilities to a dense label tensor - - Args: - tensor: probabilities to get the categorical label [N, d1, d2, ...] - argmax_dim: dimension to apply - - Return: - A tensor with categorical labels [N, d2, ...] - - Example: - - >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) - >>> to_categorical(x) - tensor([1, 0]) - - """ - return torch.argmax(tensor, dim=argmax_dim) - - -def get_num_classes( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, -) -> int: - """ - Calculates the number of classes for a given prediction and target tensor. - - Args: - pred: predicted values - target: true labels - num_classes: number of classes if known - - Return: - An integer that represents the number of classes. - """ - num_target_classes = int(target.max().detach().item() + 1) - num_pred_classes = int(pred.max().detach().item() + 1) - num_all_classes = max(num_target_classes, num_pred_classes) - - if num_classes is None: - num_classes = num_all_classes - elif num_classes != num_all_classes: - rank_zero_warn(f'You have set {num_classes} number of classes which is' - f' different from predicted ({num_pred_classes}) and' - f' target ({num_target_classes}) number of classes', - RuntimeWarning) - return num_classes - - def stat_scores( pred: torch.Tensor, target: torch.Tensor, @@ -421,7 +335,8 @@ def _binary_clf_curve( return fps, tps, pred[threshold_idxs] -def roc( +# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py +def __roc( pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None, @@ -430,6 +345,8 @@ def roc( """ Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. + .. warning:: Deprecated + Args: pred: estimated probabilities target: ground-truth labels @@ -443,7 +360,7 @@ def roc( >>> x = torch.tensor([0, 1, 2, 3]) >>> y = torch.tensor([0, 1, 1, 1]) - >>> fpr, tpr, thresholds = roc(x, y) + >>> fpr, tpr, thresholds = __roc(x, y) >>> fpr tensor([0., 0., 0., 0., 1.]) >>> tpr @@ -475,7 +392,8 @@ def roc( return fpr, tpr, thresholds -def multiclass_roc( +# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py +def __multiclass_roc( pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None, @@ -484,6 +402,8 @@ def multiclass_roc( """ Computes the Receiver Operating Characteristic (ROC) for multiclass predictors. + .. warning:: Deprecated + Args: pred: estimated probabilities target: ground-truth labels @@ -501,7 +421,7 @@ def multiclass_roc( ... [0.05, 0.05, 0.85, 0.05], ... [0.05, 0.05, 0.05, 0.85]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE + >>> __multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])), @@ -513,119 +433,11 @@ def multiclass_roc( for c in range(num_classes): pred_c = pred[:, c] - class_roc_vals.append(roc(pred=pred_c, target=target, - sample_weight=sample_weight, pos_label=c)) + class_roc_vals.append(__roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c)) return tuple(class_roc_vals) -def precision_recall_curve( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Computes precision-recall pairs for different thresholds. - - Args: - pred: estimated probabilities - target: ground-truth labels - sample_weight: sample weights - pos_label: the label for the positive class - - Return: - precision, recall, thresholds - - Example: - - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 0]) - >>> precision, recall, thresholds = precision_recall_curve(pred, target) - >>> precision - tensor([0.6667, 0.5000, 0.0000, 1.0000]) - >>> recall - tensor([1.0000, 0.5000, 0.0000, 0.0000]) - >>> thresholds - tensor([1, 2, 3]) - - """ - fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=pos_label) - - precision = tps / (tps + fps) - recall = tps / tps[-1] - - # stop when full recall attained - # and reverse the outputs so recall is decreasing - last_ind = torch.where(tps == tps[-1])[0][0] - sl = slice(0, last_ind.item() + 1) - - # need to call reversed explicitly, since including that to slice would - # introduce negative strides that are not yet supported in pytorch - precision = torch.cat([reversed(precision[sl]), - torch.ones(1, dtype=precision.dtype, - device=precision.device)]) - - recall = torch.cat([reversed(recall[sl]), - torch.zeros(1, dtype=recall.dtype, - device=recall.device)]) - - thresholds = torch.tensor(reversed(thresholds[sl])) - - return precision, recall, thresholds - - -def multiclass_precision_recall_curve( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - num_classes: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Computes precision-recall pairs for different thresholds given a multiclass scores. - - Args: - pred: estimated probabilities - target: ground-truth labels - sample_weight: sample weight - num_classes: number of classes - - Return: - number of classes, precision, recall, thresholds - - Example: - - >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], - ... [0.05, 0.85, 0.05, 0.05], - ... [0.05, 0.05, 0.85, 0.05], - ... [0.05, 0.05, 0.05, 0.85]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> nb_classes, precision, recall, thresholds = multiclass_precision_recall_curve(pred, target) - >>> nb_classes - (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])) - >>> precision - (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])) - >>> recall - (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])) - >>> thresholds # doctest: +NORMALIZE_WHITESPACE - (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])) - """ - num_classes = get_num_classes(pred, target, num_classes) - - class_pr_vals = [] - for c in range(num_classes): - pred_c = pred[:, c] - - class_pr_vals.append(precision_recall_curve( - pred=pred_c, - target=target, - sample_weight=sample_weight, pos_label=c)) - - return tuple(class_pr_vals) - - def auc( x: torch.Tensor, y: torch.Tensor, @@ -736,7 +548,7 @@ def auroc( @auc_decorator(reorder=True) def _auroc(pred, target, sample_weight, pos_label): - return roc(pred, target, sample_weight, pos_label) + return __roc(pred, target, sample_weight, pos_label) return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) @@ -767,7 +579,7 @@ def multiclass_auroc( ... [0.05, 0.05, 0.85, 0.05], ... [0.05, 0.05, 0.05, 0.85]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> multiclass_auroc(pred, target) # doctest: +NORMALIZE_WHITESPACE + >>> multiclass_auroc(pred, target, num_classes=4) tensor(0.6667) """ if not torch.allclose(pred.sum(dim=1), torch.tensor(1.0)): @@ -789,7 +601,7 @@ def multiclass_auroc( @multiclass_auc_decorator(reorder=False) def _multiclass_auroc(pred, target, sample_weight, num_classes): - return multiclass_roc(pred, target, sample_weight, num_classes) + return __multiclass_roc(pred, target, sample_weight, num_classes) class_aurocs = _multiclass_auroc(pred=pred, target=target, sample_weight=sample_weight, @@ -797,40 +609,6 @@ def _multiclass_auroc(pred, target, sample_weight, num_classes): return torch.mean(class_aurocs) -def average_precision( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., -) -> torch.Tensor: - """ - Compute average precision from prediction scores - - Args: - pred: estimated probabilities - target: ground-truth labels - sample_weight: sample weights - pos_label: the label for the positive class - - Return: - Tensor containing average precision score - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> average_precision(x, y) - tensor(0.3333) - """ - precision, recall, _ = precision_recall_curve(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=pos_label) - # Return the step function integral - # The following works because the last entry of precision is - # guaranteed to be 1, as returned by precision_recall_curve - return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) - - def dice_score( pred: torch.Tensor, target: torch.Tensor, From b576de011b62aaa3a64e97e6138aa7b08c0d208c Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 7 Dec 2020 19:57:39 +0100 Subject: [PATCH 57/94] . --- tests/metrics/functional/test_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 7bb3df9d8e392..9a410dc7636e0 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -380,4 +380,4 @@ def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, ex num_classes=num_classes, reduction=reduction, ) - assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) \ No newline at end of file + assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) From dae341b9585a89de197f37f927cea78d9c3e8b24 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 7 Dec 2020 20:06:37 +0100 Subject: [PATCH 58/94] Improve description of top_k parameter --- pytorch_lightning/metrics/classification/accuracy.py | 2 +- pytorch_lightning/metrics/functional/accuracy.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 7c4bcb247c766..6045c9fba3065 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -46,7 +46,7 @@ class Accuracy(Metric): Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 top_k: - Number of highest probability entries for each sample to convert to 1s, relevant + Number of highest probability predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The default value (``None``) will be interpreted as 1 for these inputs. diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index ca88213706c4f..0664b00be44ae 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -79,7 +79,7 @@ def accuracy( Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 top_k: - Number of highest probability entries for each sample to convert to 1s, relevant + Number of highest probability predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The default value (``None``) will be interpreted as 1 for these inputs. From b2d2b715593d58f54ab1aa55dc0fff3483545881 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 7 Dec 2020 20:13:04 +0100 Subject: [PATCH 59/94] Apply suggestions from code review --- pytorch_lightning/metrics/classification/accuracy.py | 3 ++- pytorch_lightning/metrics/classification/hamming_loss.py | 3 ++- pytorch_lightning/metrics/functional/accuracy.py | 2 +- pytorch_lightning/metrics/functional/hamming_loss.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 6045c9fba3065..435bc4f788809 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -44,7 +44,7 @@ class Accuracy(Metric): Args: threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + `(0,1)` predictions, in the case of binary or multi-label inputs. Default: `0.5` top_k: Number of highest probability predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The @@ -110,6 +110,7 @@ def __init__( self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + assert 0 <= threshold <= 1, f"threshold: {threshold} is out of range" self.threshold = threshold self.top_k = top_k self.mdmc_accuracy = mdmc_accuracy diff --git a/pytorch_lightning/metrics/classification/hamming_loss.py b/pytorch_lightning/metrics/classification/hamming_loss.py index c4fc74eeb6a0e..92e3c98a6ea70 100644 --- a/pytorch_lightning/metrics/classification/hamming_loss.py +++ b/pytorch_lightning/metrics/classification/hamming_loss.py @@ -40,7 +40,7 @@ class HammingLoss(Metric): Args: threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + `(0,1)` predictions, in the case of binary or multi-label inputs. Default: `0.5` compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -81,6 +81,7 @@ def __init__( self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + assert 0 <= threshold <= 1, f"threshold: {threshold} is out of range" self.threshold = threshold def update(self, preds: torch.Tensor, target: torch.Tensor): diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 0664b00be44ae..87d88131a0575 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -109,6 +109,6 @@ def accuracy( >>> accuracy(preds, target, top_k=2) tensor(0.6667) """ - + assert 0 <= threshold <= 1, f"threshold: {threshold} is out of range" correct, total = _accuracy_update(preds, target, threshold, top_k, mdmc_accuracy) return _accuracy_compute(correct, total) diff --git a/pytorch_lightning/metrics/functional/hamming_loss.py b/pytorch_lightning/metrics/functional/hamming_loss.py index 9ef322aa841e4..d13aac6bcf45b 100644 --- a/pytorch_lightning/metrics/functional/hamming_loss.py +++ b/pytorch_lightning/metrics/functional/hamming_loss.py @@ -63,6 +63,6 @@ def hamming_loss(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0 tensor(0.2500) """ - + assert 0 <= threshold <= 1, f"threshold: {threshold} is out of range" correct, total = _hamming_loss_update(preds, target, threshold) return _hamming_loss_compute(correct, total) From 9b2a399732630aa17f821d57487205aadf53bcf5 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 7 Dec 2020 20:20:18 +0100 Subject: [PATCH 60/94] Apply suggestions from code review Co-authored-by: Rohit Gupta --- pytorch_lightning/metrics/classification/accuracy.py | 8 ++++---- pytorch_lightning/metrics/functional/accuracy.py | 2 +- pytorch_lightning/metrics/functional/hamming_loss.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 435bc4f788809..02b0767db67ad 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -35,8 +35,8 @@ class Accuracy(Metric): This metric generalizes to subset accuracy for multilabel data: for the sample to be counted as correct, all labels in that sample have to be correctly predicted. Consider using :class:`~pytorch_lightning.metrics.classification.HammingLoss` - is this is not what you want. In multi-dimensional multi-class case the `mdmc_accuracy` parameters - gives you a choice between computing the subset accuracy, or counting each sample on the extra + is this is not what you want. In a multi-dimensional multi-class case, the `mdmc_accuracy` parameters + gives you a choice between computing the subset accuracy or counting each sample on the extra axis separately. Accepts all input types listed in :ref:`metrics:Input types`. @@ -52,13 +52,13 @@ class Accuracy(Metric): Should be left at default (``None``) for all other types of inputs. mdmc_accuracy: - Determines how should the extra dimension be handeled in case of multi-dimensional multi-class + Determines how should the extra dimension be handled in case of multi-dimensional multi-class inputs. Options are ``"global"`` or ``"subset"``. If ``"global"``, then the inputs are treated as if the sample (``N``) and the extra dimension were unrolled into a new sample dimension. - If ``"subset"``, than the equivalent of subset accuracy is performed for each sample on the + If ``"subset"``, then the equivalent of subset accuracy is performed for each sample on the ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension must be predicted correctly (the ``top_k`` option still applies here). The final score is then simply the number of totally correctly predicted samples. diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 87d88131a0575..d2fa6df175db8 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -91,7 +91,7 @@ def accuracy( If ``"global"``, then the inputs are treated as if the sample (``N``) and the extra dimension were unrolled into a new sample dimension. - If ``"subset"``, than the equivalent of subset accuracy is performed for each sample on the + If ``"subset"``, then the equivalent of subset accuracy is performed for each sample on the ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension must be predicted correctly (the ``top_k`` option still applies here). The final score is then simply the number of totally correctly predicted samples. diff --git a/pytorch_lightning/metrics/functional/hamming_loss.py b/pytorch_lightning/metrics/functional/hamming_loss.py index d13aac6bcf45b..b708428f59085 100644 --- a/pytorch_lightning/metrics/functional/hamming_loss.py +++ b/pytorch_lightning/metrics/functional/hamming_loss.py @@ -63,6 +63,6 @@ def hamming_loss(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0 tensor(0.2500) """ - assert 0 <= threshold <= 1, f"threshold: {threshold} is out of range" + assert 0 <= threshold <= 1, f"threshold: {threshold} is out of range" correct, total = _hamming_loss_update(preds, target, threshold) return _hamming_loss_compute(correct, total) From 0952df21b51155a2d6b373faa1861cbf2c2c549b Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 7 Dec 2020 20:21:22 +0100 Subject: [PATCH 61/94] Remove unneeded assert --- pytorch_lightning/metrics/functional/accuracy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index d2fa6df175db8..4fc7c184d523e 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -109,6 +109,6 @@ def accuracy( >>> accuracy(preds, target, top_k=2) tensor(0.6667) """ - assert 0 <= threshold <= 1, f"threshold: {threshold} is out of range" + correct, total = _accuracy_update(preds, target, threshold, top_k, mdmc_accuracy) return _accuracy_compute(correct, total) From c7fe698044b30bc1413d11bc890aeae711122661 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 7 Dec 2020 20:22:37 +0100 Subject: [PATCH 62/94] Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta --- pytorch_lightning/metrics/functional/accuracy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 4fc7c184d523e..6df325917e085 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -85,7 +85,7 @@ def accuracy( Should be left at default (``None``) for all other types of inputs. mdmc_accuracy: - Determines how should the extra dimension be handeled in case of multi-dimensional multi-class + Determines how should the extra dimension be handled in case of multi-dimensional multi-class inputs. Options are ``"global"`` or ``"subset"``. If ``"global"``, then the inputs are treated as if the sample (``N``) and the extra dimension From e2bc0abee62fa854897cefcb5df88433964659d0 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 7 Dec 2020 20:23:35 +0100 Subject: [PATCH 63/94] Remove unneeded assert --- pytorch_lightning/metrics/functional/hamming_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/functional/hamming_loss.py b/pytorch_lightning/metrics/functional/hamming_loss.py index b708428f59085..9ef322aa841e4 100644 --- a/pytorch_lightning/metrics/functional/hamming_loss.py +++ b/pytorch_lightning/metrics/functional/hamming_loss.py @@ -63,6 +63,6 @@ def hamming_loss(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0 tensor(0.2500) """ - assert 0 <= threshold <= 1, f"threshold: {threshold} is out of range" + correct, total = _hamming_loss_update(preds, target, threshold) return _hamming_loss_compute(correct, total) From 8801f8a6cb435907e309d68ddb746565df13549c Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 7 Dec 2020 20:32:12 +0100 Subject: [PATCH 64/94] Explicit checking of parameter values --- pytorch_lightning/metrics/classification/accuracy.py | 8 +++++++- pytorch_lightning/metrics/functional/accuracy.py | 3 +++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 02b0767db67ad..ad953264935ca 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -110,9 +110,15 @@ def __init__( self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - assert 0 <= threshold <= 1, f"threshold: {threshold} is out of range" + if not 0 <= threshold <= 1: + raise ValueError("The `threshold` should lie in the [0,1] interval.") + self.threshold = threshold self.top_k = top_k + + if mdmc_accuracy not in ["global", "subset"]: + raise ValueError("The `mdmc_accuracy` should be either 'subset' or 'global'.") + self.mdmc_accuracy = mdmc_accuracy def update(self, preds: torch.Tensor, target: torch.Tensor): diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 6df325917e085..d9c30dca241cd 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -110,5 +110,8 @@ def accuracy( tensor(0.6667) """ + if mdmc_accuracy not in ["global", "subset"]: + raise ValueError("The `mdmc_accuracy` should be either 'subset' or 'global'.") + correct, total = _accuracy_update(preds, target, threshold, top_k, mdmc_accuracy) return _accuracy_compute(correct, total) From c32b36e8024cd0011426b729f44c16fc8d1fccf4 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 7 Dec 2020 21:19:35 +0100 Subject: [PATCH 65/94] Apply suggestions from code review Co-authored-by: Nicki Skafte --- pytorch_lightning/metrics/classification/accuracy.py | 2 +- pytorch_lightning/metrics/functional/accuracy.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index ad953264935ca..0bef35480dae1 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -94,7 +94,7 @@ def __init__( self, threshold: float = 0.5, top_k: Optional[int] = None, - mdmc_accuracy: str = "subset", + mdmc_accuracy: str = "global", compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index d9c30dca241cd..9b6715e5f0bb7 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -37,7 +37,7 @@ def _accuracy_update( correct = (sample_correct == sample_total).sum() total = target.shape[0] - return (torch.tensor(correct, device=preds.device), torch.tensor(total, device=preds.device)) + return correct, total def _accuracy_compute(correct: torch.Tensor, total: torch.Tensor) -> torch.Tensor: From 0314c7d39cd250ff2f1e8761b58a2f0367e17d24 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 7 Dec 2020 23:47:38 +0100 Subject: [PATCH 66/94] Apply suggestions from code review --- .../metrics/classification/accuracy.py | 7 ++- .../metrics/classification/hamming_loss.py | 3 +- .../metrics/functional/accuracy.py | 10 +++- tests/metrics/classification/test_accuracy.py | 56 ++++++++++++------- 4 files changed, 49 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 0bef35480dae1..06b2ab8338db6 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -56,7 +56,8 @@ class Accuracy(Metric): inputs. Options are ``"global"`` or ``"subset"``. If ``"global"``, then the inputs are treated as if the sample (``N``) and the extra dimension - were unrolled into a new sample dimension. + were unrolled into a new sample dimension. If predictions are labels, this option is equivalent + to first flattening ``preds`` and ``target``, and then computing accuracy. If ``"subset"``, then the equivalent of subset accuracy is performed for each sample on the ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension @@ -112,8 +113,10 @@ def __init__( if not 0 <= threshold <= 1: raise ValueError("The `threshold` should lie in the [0,1] interval.") - self.threshold = threshold + + if top_k <= 0: + raise ValueError("The `top_k` should be an integer larger than 1.") self.top_k = top_k if mdmc_accuracy not in ["global", "subset"]: diff --git a/pytorch_lightning/metrics/classification/hamming_loss.py b/pytorch_lightning/metrics/classification/hamming_loss.py index 92e3c98a6ea70..bcc4fb8c5e668 100644 --- a/pytorch_lightning/metrics/classification/hamming_loss.py +++ b/pytorch_lightning/metrics/classification/hamming_loss.py @@ -81,7 +81,8 @@ def __init__( self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - assert 0 <= threshold <= 1, f"threshold: {threshold} is out of range" + if not 0 <= threshold <= 1: + raise ValueError("The `threshold` should lie in the [0,1] interval.") self.threshold = threshold def update(self, preds: torch.Tensor, target: torch.Tensor): diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 9b6715e5f0bb7..05e0f216a4fbe 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -25,7 +25,7 @@ def _accuracy_update( if mode in ["binary", "multi-label"]: correct = (preds == target).all(dim=1).sum() - total = target.shape[0] + total = torch.tensor(target.shape[0], device=target.device) elif mdmc_accuracy == "global": correct = (preds * target).sum() total = target.sum() @@ -35,7 +35,7 @@ def _accuracy_update( sample_total = target.sum(dim=extra_dims) correct = (sample_correct == sample_total).sum() - total = target.shape[0] + total = torch.tensor(target.shape[0], device=target.device) return correct, total @@ -89,7 +89,8 @@ def accuracy( inputs. Options are ``"global"`` or ``"subset"``. If ``"global"``, then the inputs are treated as if the sample (``N``) and the extra dimension - were unrolled into a new sample dimension. + were unrolled into a new sample dimension. If predictions are labels, this option is equivalent + to first flattening ``preds`` and ``target``, and then computing accuracy. If ``"subset"``, then the equivalent of subset accuracy is performed for each sample on the ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension @@ -113,5 +114,8 @@ def accuracy( if mdmc_accuracy not in ["global", "subset"]: raise ValueError("The `mdmc_accuracy` should be either 'subset' or 'global'.") + if top_k <= 0: + raise ValueError("The `top_k` should be an integer larger than 1.") + correct, total = _accuracy_update(preds, target, threshold, top_k, mdmc_accuracy) return _accuracy_compute(correct, total) diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index 9af19474064cf..e29939f779e70 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -1,3 +1,5 @@ +from functools import partial + import numpy as np import pytest import torch @@ -23,50 +25,62 @@ torch.manual_seed(42) -def _sk_accuracy(preds, target): - sk_preds, sk_target, _ = _input_format_classification(preds, target, threshold=THRESHOLD) +def _sk_accuracy(preds, target, mdmc_accuracy): + sk_preds, sk_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD) sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - sk_preds, sk_target = sk_preds.reshape(sk_preds.shape[0], -1), sk_target.reshape(sk_target.shape[0], -1) + if mode == "multi-dim multi-class": + if mdmc_accuracy == "global": + sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1)) + sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2]) + elif mdmc_accuracy == "subset": + return np.all(sk_preds == sk_target, axis=(1, 2)).mean() + return sk_accuracy(y_true=sk_target, y_pred=sk_preds) @pytest.mark.parametrize( - "preds, target", + "preds, target, mdmc_accuracy", [ - (_binary_prob_inputs.preds, _binary_prob_inputs.target), - (_binary_inputs.preds, _binary_inputs.target), - (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target), - (_multilabel_inputs.preds, _multilabel_inputs.target), - (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target), - (_multiclass_inputs.preds, _multiclass_inputs.target), - (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target), - (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target), - (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target), - (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target), + (_binary_prob_inputs.preds, _binary_prob_inputs.target, "global"), + (_binary_inputs.preds, _binary_inputs.target, "global"), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, "global"), + (_multilabel_inputs.preds, _multilabel_inputs.target, "subset"), # As this is treated as MDMC by default + (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, "global"), + (_multiclass_inputs.preds, _multiclass_inputs.target, "global"), + (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target, "global"), + (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target, "subset"), + (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, "global"), + (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, "subset"), + (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target, "global"), + ( + _multilabel_multidim_inputs.preds, + _multilabel_multidim_inputs.target, + "subset", # As this is treated as MDMC by default + ), ], ) class TestAccuracies(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ddp", [False, True]) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target): + def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, mdmc_accuracy): self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=Accuracy, - sk_metric=_sk_accuracy, + sk_metric=partial(_sk_accuracy, mdmc_accuracy=mdmc_accuracy), dist_sync_on_step=dist_sync_on_step, - metric_args={"threshold": THRESHOLD}, + metric_args={"threshold": THRESHOLD, "mdmc_accuracy": mdmc_accuracy}, ) - def test_accuracy_fn(self, preds, target): + def test_accuracy_fn(self, preds, target, mdmc_accuracy): self.run_functional_metric_test( preds, target, metric_functional=accuracy, - sk_metric=_sk_accuracy, - metric_args={"threshold": THRESHOLD}, + sk_metric=partial(_sk_accuracy, mdmc_accuracy=mdmc_accuracy), + metric_args={"threshold": THRESHOLD, "mdmc_accuracy": mdmc_accuracy}, ) From 152cadfe27f55ced79100215c90192778df460a0 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 8 Dec 2020 00:30:09 +0100 Subject: [PATCH 67/94] Fix top_k checking --- pytorch_lightning/metrics/classification/accuracy.py | 2 +- pytorch_lightning/metrics/functional/accuracy.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 06b2ab8338db6..722b0bcf15dca 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -115,7 +115,7 @@ def __init__( raise ValueError("The `threshold` should lie in the [0,1] interval.") self.threshold = threshold - if top_k <= 0: + if top_k is not None and top_k <= 0: raise ValueError("The `top_k` should be an integer larger than 1.") self.top_k = top_k diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 05e0f216a4fbe..bef192f82f625 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -114,7 +114,7 @@ def accuracy( if mdmc_accuracy not in ["global", "subset"]: raise ValueError("The `mdmc_accuracy` should be either 'subset' or 'global'.") - if top_k <= 0: + if top_k is not None and top_k <= 0: raise ValueError("The `top_k` should be an integer larger than 1.") correct, total = _accuracy_update(preds, target, threshold, top_k, mdmc_accuracy) From 022d6a604f6f0cfe657d8a96b0804e61d5e920b1 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 8 Dec 2020 00:35:59 +0100 Subject: [PATCH 68/94] PEP8 --- tests/metrics/classification/test_accuracy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index e29939f779e70..93043ca514ad3 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -35,7 +35,7 @@ def _sk_accuracy(preds, target, mdmc_accuracy): sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2]) elif mdmc_accuracy == "subset": return np.all(sk_preds == sk_target, axis=(1, 2)).mean() - + return sk_accuracy(y_true=sk_target, y_pred=sk_preds) From 9efc9634ba8a2bd694fa309ff20195526bebed8b Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 8 Dec 2020 01:24:59 +0100 Subject: [PATCH 69/94] Don't check dist_sync in test --- tests/metrics/classification/test_accuracy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index 93043ca514ad3..4b263b68ddd9c 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -72,6 +72,8 @@ def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, mdmc_accura sk_metric=partial(_sk_accuracy, mdmc_accuracy=mdmc_accuracy), dist_sync_on_step=dist_sync_on_step, metric_args={"threshold": THRESHOLD, "mdmc_accuracy": mdmc_accuracy}, + check_dist_sync_on_step=False, + check_batch=False ) def test_accuracy_fn(self, preds, target, mdmc_accuracy): From d992f7df0b0b02836a7fbde72616f2b02deb34b7 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 8 Dec 2020 01:41:45 +0100 Subject: [PATCH 70/94] add back check_dist_sync_on_step --- tests/metrics/classification/test_accuracy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index 4b263b68ddd9c..93043ca514ad3 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -72,8 +72,6 @@ def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, mdmc_accura sk_metric=partial(_sk_accuracy, mdmc_accuracy=mdmc_accuracy), dist_sync_on_step=dist_sync_on_step, metric_args={"threshold": THRESHOLD, "mdmc_accuracy": mdmc_accuracy}, - check_dist_sync_on_step=False, - check_batch=False ) def test_accuracy_fn(self, preds, target, mdmc_accuracy): From a7260603903f6ee19190b0940e54285324f73531 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 8 Dec 2020 12:04:28 +0100 Subject: [PATCH 71/94] Make sure half-precision inputs are transformed (#5013) --- pytorch_lightning/metrics/classification/helpers.py | 5 +++++ tests/metrics/classification/test_inputs.py | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index afb97e6e0a74f..71fba4a12ffb1 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -404,6 +404,11 @@ def _input_format_classification( else: preds, target = preds.squeeze(), target.squeeze() + # Convert half precision tensors to full precision, as not all ops are supported + # print(acc(preds.half(), target)) - for example, min() is not supported + if preds.dtype == torch.float16: + preds = preds.float() + case = _check_classification_inputs( preds, target, diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index c4d01d282fa57..87a2352dbed33 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -22,6 +22,8 @@ torch.manual_seed(42) # Some additional inputs to test on +_ml_prob_half = Input(_ml_prob.preds.half(), _ml_prob.target) + _mc_prob_2cls_preds = rand(NUM_BATCHES, BATCH_SIZE, 2) _mc_prob_2cls_preds /= _mc_prob_2cls_preds.sum(dim=2, keepdim=True) _mc_prob_2cls = Input(_mc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))) @@ -133,6 +135,8 @@ def _mlmd_prob_to_mc_preds_tr(x): (_mdmc_prob_many_dims, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), ########################### # Test some special cases + # Make sure that half precision works, i.e. is converted to full precision + (_ml_prob_half, None, None, None, "multi-label", _ml_preds_tr, _rshp1), # Binary as multiclass (_bin, None, None, None, "multi-class", _onehot2, _onehot2), # Binary probs as multiclass From 93c5d02eac973cb02114b61eee91eac6b04e8e2a Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 8 Dec 2020 12:06:48 +0100 Subject: [PATCH 72/94] Fix typo --- pytorch_lightning/metrics/classification/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index 71fba4a12ffb1..69594dc8477f4 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -405,7 +405,7 @@ def _input_format_classification( preds, target = preds.squeeze(), target.squeeze() # Convert half precision tensors to full precision, as not all ops are supported - # print(acc(preds.half(), target)) - for example, min() is not supported + # for example, min() is not supported if preds.dtype == torch.float16: preds = preds.float() From 0813055877cacec43feb3e85f1fbef46547b7a67 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 8 Dec 2020 12:14:04 +0100 Subject: [PATCH 73/94] Rename hamming loss to hamming distance --- docs/source/metrics.rst | 8 +++---- pytorch_lightning/metrics/__init__.py | 2 +- .../metrics/classification/__init__.py | 2 +- .../{hamming_loss.py => hamming_distance.py} | 22 +++++++++---------- .../metrics/functional/__init__.py | 2 +- .../{hamming_loss.py => hamming_distance.py} | 20 ++++++++--------- ...mming_loss.py => test_hamming_distance.py} | 14 ++++++------ 7 files changed, 35 insertions(+), 35 deletions(-) rename pytorch_lightning/metrics/classification/{hamming_loss.py => hamming_distance.py} (81%) rename pytorch_lightning/metrics/functional/{hamming_loss.py => hamming_distance.py} (72%) rename tests/metrics/classification/{test_hamming_loss.py => test_hamming_distance.py} (85%) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 2d4aeff1b087f..366a76d6916ba 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -292,10 +292,10 @@ FBeta .. autoclass:: pytorch_lightning.metrics.classification.FBeta :noindex: -Hamming Loss +Hamming Distance ~~~~~~~~~~~~ -.. autoclass:: pytorch_lightning.metrics.classification.HammingLoss +.. autoclass:: pytorch_lightning.metrics.classification.HammingDistance :noindex: Precision @@ -387,10 +387,10 @@ fbeta [func] .. autofunction:: pytorch_lightning.metrics.functional.fbeta :noindex: -hamming_loss [func] +hamming_distance [func] ~~~~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.hamming_loss +.. autofunction:: pytorch_lightning.metrics.functional.hamming_distance :noindex: iou [func] diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 9c4bf80ae0e51..91bf23ce69d30 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -15,7 +15,7 @@ from pytorch_lightning.metrics.classification import ( Accuracy, - HammingLoss, + HammingDistance, Precision, Recall, ConfusionMatrix, diff --git a/pytorch_lightning/metrics/classification/__init__.py b/pytorch_lightning/metrics/classification/__init__.py index 15eb3a8b2ad91..b4689c6d4f9ab 100644 --- a/pytorch_lightning/metrics/classification/__init__.py +++ b/pytorch_lightning/metrics/classification/__init__.py @@ -15,7 +15,7 @@ from pytorch_lightning.metrics.classification.average_precision import AveragePrecision from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix from pytorch_lightning.metrics.classification.f_beta import FBeta, F1 -from pytorch_lightning.metrics.classification.hamming_loss import HammingLoss +from pytorch_lightning.metrics.classification.hamming_distance import HammingDistance from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve from pytorch_lightning.metrics.classification.roc import ROC diff --git a/pytorch_lightning/metrics/classification/hamming_loss.py b/pytorch_lightning/metrics/classification/hamming_distance.py similarity index 81% rename from pytorch_lightning/metrics/classification/hamming_loss.py rename to pytorch_lightning/metrics/classification/hamming_distance.py index bcc4fb8c5e668..1cb8561806976 100644 --- a/pytorch_lightning/metrics/classification/hamming_loss.py +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -15,16 +15,16 @@ import torch from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.metrics.functional.hamming_loss import _hamming_loss_update, _hamming_loss_compute +from pytorch_lightning.metrics.functional.hamming_distance import _hamming_distance_update, _hamming_distance_compute -class HammingLoss(Metric): +class HammingDistance(Metric): r""" - Computes the average Hamming loss or `Hamming distance `_ - between targets and predictions: + Computes the average `Hamming distance `_ (also + known as Hamming loss) between targets and predictions: .. math:: - \text{Hamming loss} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y_{il}}) + \text{Hamming distance} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y_{il}}) Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that @@ -54,11 +54,11 @@ class HammingLoss(Metric): Example: - >>> from pytorch_lightning.metrics import HammingLoss + >>> from pytorch_lightning.metrics import HammingDistance >>> target = torch.tensor([[0, 1], [1, 1]]) >>> preds = torch.tensor([[0, 1], [0, 1]]) - >>> hamming_loss = HammingLoss() - >>> hamming_loss(preds, target) + >>> hamming_distance = HammingDistance() + >>> hamming_distance(preds, target) tensor(0.2500) """ @@ -94,13 +94,13 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): preds: Predictions from model (probabilities, or labels) target: Ground truth values """ - correct, total = _hamming_loss_update(preds, target, self.threshold) + correct, total = _hamming_distance_update(preds, target, self.threshold) self.correct += correct self.total += total def compute(self) -> torch.Tensor: """ - Computes hamming loss based on inputs passed in to ``update`` previously. + Computes hamming distance based on inputs passed in to ``update`` previously. """ - return _hamming_loss_compute(self.correct, self.total) + return _hamming_distance_compute(self.correct, self.total) diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index f7d5bf7189353..f203b3502b381 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -29,7 +29,7 @@ # TODO: unify metrics between class and functional, add below from pytorch_lightning.metrics.functional.explained_variance import explained_variance from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 -from pytorch_lightning.metrics.functional.hamming_loss import hamming_loss +from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error diff --git a/pytorch_lightning/metrics/functional/hamming_loss.py b/pytorch_lightning/metrics/functional/hamming_distance.py similarity index 72% rename from pytorch_lightning/metrics/functional/hamming_loss.py rename to pytorch_lightning/metrics/functional/hamming_distance.py index 9ef322aa841e4..1f8a28842c907 100644 --- a/pytorch_lightning/metrics/functional/hamming_loss.py +++ b/pytorch_lightning/metrics/functional/hamming_distance.py @@ -17,7 +17,7 @@ from pytorch_lightning.metrics.classification.helpers import _input_format_classification -def _hamming_loss_update(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> Tuple[torch.Tensor, int]: +def _hamming_distance_update(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> Tuple[torch.Tensor, int]: preds, target, _ = _input_format_classification(preds, target, threshold=threshold) correct = (preds == target).sum() @@ -26,17 +26,17 @@ def _hamming_loss_update(preds: torch.Tensor, target: torch.Tensor, threshold: f return correct, total -def _hamming_loss_compute(correct: torch.Tensor, total: Union[int, torch.Tensor]) -> torch.Tensor: +def _hamming_distance_compute(correct: torch.Tensor, total: Union[int, torch.Tensor]) -> torch.Tensor: return 1 - correct.float() / total -def hamming_loss(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: +def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: r""" - Computes the average Hamming loss or `Hamming distance `_ - between targets and predictions: + Computes the average `Hamming distance `_ (also + known as Hamming loss) between targets and predictions: .. math:: - \text{Hamming loss} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) + \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that @@ -56,13 +56,13 @@ def hamming_loss(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0 Example: - >>> from pytorch_lightning.metrics.functional import hamming_loss + >>> from pytorch_lightning.metrics.functional import hamming_distance >>> target = torch.tensor([[0, 1], [1, 1]]) >>> preds = torch.tensor([[0, 1], [0, 1]]) - >>> hamming_loss(preds, target) + >>> hamming_distance(preds, target) tensor(0.2500) """ - correct, total = _hamming_loss_update(preds, target, threshold) - return _hamming_loss_compute(correct, total) + correct, total = _hamming_distance_update(preds, target, threshold) + return _hamming_distance_compute(correct, total) diff --git a/tests/metrics/classification/test_hamming_loss.py b/tests/metrics/classification/test_hamming_distance.py similarity index 85% rename from tests/metrics/classification/test_hamming_loss.py rename to tests/metrics/classification/test_hamming_distance.py index aab6fa353b4bc..f08ec27f565a1 100644 --- a/tests/metrics/classification/test_hamming_loss.py +++ b/tests/metrics/classification/test_hamming_distance.py @@ -2,8 +2,8 @@ import torch from sklearn.metrics import hamming_loss as sk_hamming_loss -from pytorch_lightning.metrics import HammingLoss -from pytorch_lightning.metrics.functional import hamming_loss +from pytorch_lightning.metrics import HammingDistance +from pytorch_lightning.metrics.functional import hamming_distance from pytorch_lightning.metrics.classification.helpers import _input_format_classification from tests.metrics.classification.inputs import ( _binary_inputs, @@ -45,25 +45,25 @@ def _sk_hamming_loss(preds, target): (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target), ], ) -class TestAccuracies(MetricTester): +class TestHammingDistance(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target): + def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target): self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=HammingLoss, + metric_class=HammingDistance, sk_metric=_sk_hamming_loss, dist_sync_on_step=dist_sync_on_step, metric_args={"threshold": THRESHOLD}, ) - def test_accuracy_fn(self, preds, target): + def test_hamming_distance_fn(self, preds, target): self.run_functional_metric_test( preds, target, - metric_functional=hamming_loss, + metric_functional=hamming_distance, sk_metric=_sk_hamming_loss, metric_args={"threshold": THRESHOLD}, ) From 6bf714bf10972d82d9b370575f367bea0d7e36cd Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 8 Dec 2020 12:28:59 +0100 Subject: [PATCH 74/94] Fix tests for half precision --- tests/metrics/classification/test_inputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 87a2352dbed33..30d3f06707301 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -136,7 +136,7 @@ def _mlmd_prob_to_mc_preds_tr(x): ########################### # Test some special cases # Make sure that half precision works, i.e. is converted to full precision - (_ml_prob_half, None, None, None, "multi-label", _ml_preds_tr, _rshp1), + (_ml_prob_half, None, None, None, "multi-label", lambda x: _ml_preds_tr(x.float()), _rshp1), # Binary as multiclass (_bin, None, None, None, "multi-class", _onehot2, _onehot2), # Binary probs as multiclass From d12f1d67f5d7c01fc5886b56c0b24669c49d5daf Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 8 Dec 2020 12:37:44 +0100 Subject: [PATCH 75/94] Fix docs underline length --- docs/source/metrics.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 366a76d6916ba..0057091460820 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -293,7 +293,7 @@ FBeta :noindex: Hamming Distance -~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~ .. autoclass:: pytorch_lightning.metrics.classification.HammingDistance :noindex: From a55cb466d8a31378ae97f4454a23a9aa1bfdcda4 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 8 Dec 2020 12:38:07 +0100 Subject: [PATCH 76/94] Fix doc undeline length --- docs/source/metrics.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 0057091460820..a2f7a8ad142e9 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -388,7 +388,7 @@ fbeta [func] :noindex: hamming_distance [func] -~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: pytorch_lightning.metrics.functional.hamming_distance :noindex: From 6b3b05775b6f5253507ac867ea6f9900d483d427 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 8 Dec 2020 16:49:14 +0100 Subject: [PATCH 77/94] Replace mdmc_accuracy parameter with subset_accuracy --- CHANGELOG.md | 6 +- .../metrics/classification/accuracy.py | 45 +++++---- .../classification/hamming_distance.py | 3 +- .../metrics/functional/accuracy.py | 61 ++++++------ .../metrics/functional/hamming_distance.py | 3 +- tests/metrics/classification/test_accuracy.py | 93 ++++++++++--------- 6 files changed, 103 insertions(+), 108 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d0442253a4a89..b4a754ba5249d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,11 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `Accuracy` metric now generalizes to Top-k accuracy for (multi-dimensional) multi-class inputs using the `top_k` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) -- `HammingLoss` metric to compute the hamming loss (distance) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) - -### Changed +- `Accuracy` metric now enables the computation of subset accuracy for multi-label or multi-dimensional multi-class inputs with the `subset_accuracy` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) -- `Accuracy` metrics now computes subset accuracy for multi-label inputs (consistent with scikit-learn's `accuracy_score`) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) +- `HammingLoss` metric to compute the hamming loss (distance) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) ### Fixed diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 722b0bcf15dca..e3c1d1ac7d211 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -33,11 +33,10 @@ class Accuracy(Metric): parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the top-K highest probability items are considered to find the correct label. - This metric generalizes to subset accuracy for multilabel data: for the sample to be counted as - correct, all labels in that sample have to be correctly predicted. Consider using :class:`~pytorch_lightning.metrics.classification.HammingLoss` - is this is not what you want. In a multi-dimensional multi-class case, the `mdmc_accuracy` parameters - gives you a choice between computing the subset accuracy or counting each sample on the extra - axis separately. + For multi-label and multi-dimensional multi-class inputs, this metric computes the "global" + accuracy by default, which counts all labels or sub-samples separately. This can be + changed to subset accuracy (which requires all labels or sub-samples in the sample to + be correctly predicted) by setting `subset_accuracy=True`. Accepts all input types listed in :ref:`metrics:Input types`. @@ -51,18 +50,21 @@ class Accuracy(Metric): default value (``None``) will be interpreted as 1 for these inputs. Should be left at default (``None``) for all other types of inputs. - mdmc_accuracy: - Determines how should the extra dimension be handled in case of multi-dimensional multi-class - inputs. Options are ``"global"`` or ``"subset"``. - - If ``"global"``, then the inputs are treated as if the sample (``N``) and the extra dimension - were unrolled into a new sample dimension. If predictions are labels, this option is equivalent - to first flattening ``preds`` and ``target``, and then computing accuracy. - - If ``"subset"``, then the equivalent of subset accuracy is performed for each sample on the - ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension - must be predicted correctly (the ``top_k`` option still applies here). The final score is then - simply the number of totally correctly predicted samples. + subset_accuracy: + Whether to compute subset accuracy for multi-label and multi-dimensional + multi-class inputs (has no effect for other input types). Default: `False` + + For multi-label inputs, if the parameter is set to `True`, then all labels for + each sample must be correctly predicted for the sample to count as correct. If it + is set to `False`, then all labels are counted separately - this is equivalent to + flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``). + + For multi-dimensional multi-class inputs, if the parameter is set to `True`, then all + sub-sample (on the extra axis) must be correct for the sample to be counted as correct. + If it is set to `False`, then all sub-samples are counter separately - this is equivalent, + in the case of label predictions, to flattening the inputs beforehand (i.e. + ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter + still applies in both cases, if set. compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -95,7 +97,7 @@ def __init__( self, threshold: float = 0.5, top_k: Optional[int] = None, - mdmc_accuracy: str = "global", + subset_accuracy: bool = False, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -119,10 +121,7 @@ def __init__( raise ValueError("The `top_k` should be an integer larger than 1.") self.top_k = top_k - if mdmc_accuracy not in ["global", "subset"]: - raise ValueError("The `mdmc_accuracy` should be either 'subset' or 'global'.") - - self.mdmc_accuracy = mdmc_accuracy + self.subset_accuracy = subset_accuracy def update(self, preds: torch.Tensor, target: torch.Tensor): """ @@ -135,7 +134,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): """ correct, total = _accuracy_update( - preds, target, threshold=self.threshold, top_k=self.top_k, mdmc_accuracy=self.mdmc_accuracy + preds, target, threshold=self.threshold, top_k=self.top_k, subset_accuracy=self.subset_accuracy ) self.correct += correct diff --git a/pytorch_lightning/metrics/classification/hamming_distance.py b/pytorch_lightning/metrics/classification/hamming_distance.py index 1cb8561806976..34654bd53a320 100644 --- a/pytorch_lightning/metrics/classification/hamming_distance.py +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -32,8 +32,7 @@ class HammingDistance(Metric): This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it treats each possible label separately - meaning that, for example, multi-class data is - treated as if it were multi-label. If this is not what you want, consider using - :class:`~pytorch_lightning.metrics.classification.Accuracy`. + treated as if it were multi-label. Accepts all input types listed in :ref:`metrics:Input types`. diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index bef192f82f625..06aaed34faf1f 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -18,23 +18,23 @@ def _accuracy_update( - preds: torch.Tensor, target: torch.Tensor, threshold: float, top_k: Optional[int], mdmc_accuracy: str + preds: torch.Tensor, target: torch.Tensor, threshold: float, top_k: Optional[int], subset_accuracy: bool ) -> Tuple[torch.Tensor, torch.Tensor]: preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k) - if mode in ["binary", "multi-label"]: + if mode == "binary" or (mode == "multi-label" and subset_accuracy): correct = (preds == target).all(dim=1).sum() total = torch.tensor(target.shape[0], device=target.device) - elif mdmc_accuracy == "global": + elif mode == "multi-label" and not subset_accuracy: + correct = (preds == target).sum() + total = torch.tensor(target.numel(), device=target.device) + elif mode == "multi-class" or (mode == "multi-dim multi-class" and not subset_accuracy): correct = (preds * target).sum() total = target.sum() - elif mdmc_accuracy == "subset": - extra_dims = list(range(1, len(preds.shape))) - sample_correct = (preds * target).sum(dim=extra_dims) - sample_total = target.sum(dim=extra_dims) - - correct = (sample_correct == sample_total).sum() + elif mode == "multi-dim multi-class" and subset_accuracy: + sample_correct = (preds * target).sum(dim=(1, 2)) + correct = (sample_correct == target.shape[2]).sum() total = torch.tensor(target.shape[0], device=target.device) return correct, total @@ -49,7 +49,7 @@ def accuracy( target: torch.Tensor, threshold: float = 0.5, top_k: Optional[int] = None, - mdmc_accuracy: str = "subset", + subset_accuracy: bool = False, ) -> torch.Tensor: r""" Computes `Accuracy `_: @@ -64,11 +64,10 @@ def accuracy( parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the top-K highest probability items are considered to find the correct label. - This metric generalizes to subset accuracy for multilabel data: for the sample to be counted as - correct, all labels in that sample have to be correctly predicted. Consider using :class:`~pytorch_lightning.metrics.classification.HammingLoss` - is this is not what you want. In multi-dimensional multi-class case the `mdmc_accuracy` parameters - gives you a choice between computing the subset accuracy, or counting each sample on the extra - axis separately. + For multi-label and multi-dimensional multi-class inputs, this metric computes the "global" + accuracy by default, which counts all labels or sub-samples separately. This can be + changed to subset accuracy (which requires all labels or sub-samples in the sample to + be correctly predicted) by setting `subset_accuracy=True`. Accepts all input types listed in :ref:`metrics:Input types`. @@ -84,18 +83,21 @@ def accuracy( default value (``None``) will be interpreted as 1 for these inputs. Should be left at default (``None``) for all other types of inputs. - mdmc_accuracy: - Determines how should the extra dimension be handled in case of multi-dimensional multi-class - inputs. Options are ``"global"`` or ``"subset"``. - - If ``"global"``, then the inputs are treated as if the sample (``N``) and the extra dimension - were unrolled into a new sample dimension. If predictions are labels, this option is equivalent - to first flattening ``preds`` and ``target``, and then computing accuracy. - - If ``"subset"``, then the equivalent of subset accuracy is performed for each sample on the - ``N`` dimension - that is, for the sample to count as correct, all labels on its extra dimension - must be predicted correctly (the ``top_k`` option still applies here). The final score is then - simply the number of totally correctly predicted samples. + subset_accuracy: + Whether to compute subset accuracy for multi-label and multi-dimensional + multi-class inputs (has no effect for other input types). Default: `False` + + For multi-label inputs, if the parameter is set to `True`, then all labels for + each sample must be correctly predicted for the sample to count as correct. If it + is set to `False`, then all labels are counted separately - this is equivalent to + flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``). + + For multi-dimensional multi-class inputs, if the parameter is set to `True`, then all + sub-sample (on the extra axis) must be correct for the sample to be counted as correct. + If it is set to `False`, then all sub-samples are counter separately - this is equivalent, + in the case of label predictions, to flattening the inputs beforehand (i.e. + ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter + still applies in both cases, if set. Example: @@ -111,11 +113,8 @@ def accuracy( tensor(0.6667) """ - if mdmc_accuracy not in ["global", "subset"]: - raise ValueError("The `mdmc_accuracy` should be either 'subset' or 'global'.") - if top_k is not None and top_k <= 0: raise ValueError("The `top_k` should be an integer larger than 1.") - correct, total = _accuracy_update(preds, target, threshold, top_k, mdmc_accuracy) + correct, total = _accuracy_update(preds, target, threshold, top_k, subset_accuracy) return _accuracy_compute(correct, total) diff --git a/pytorch_lightning/metrics/functional/hamming_distance.py b/pytorch_lightning/metrics/functional/hamming_distance.py index 1f8a28842c907..daca66163ccf3 100644 --- a/pytorch_lightning/metrics/functional/hamming_distance.py +++ b/pytorch_lightning/metrics/functional/hamming_distance.py @@ -44,8 +44,7 @@ def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it treats each possible label separately - meaning that, for example, multi-class data is - treated as if it were multi-label. If this is not what you want, consider using - :class:`~pytorch_lightning.metrics.classification.Accuracy`. + treated as if it were multi-label. Accepts all input types listed in :ref:`metrics:Input types`. diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index 93043ca514ad3..61313119814b3 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -25,62 +25,63 @@ torch.manual_seed(42) -def _sk_accuracy(preds, target, mdmc_accuracy): +def _sk_accuracy(preds, target, subset_accuracy): sk_preds, sk_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD) sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - if mode == "multi-dim multi-class": - if mdmc_accuracy == "global": - sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1)) - sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2]) - elif mdmc_accuracy == "subset": - return np.all(sk_preds == sk_target, axis=(1, 2)).mean() + if mode == "multi-dim multi-class" and not subset_accuracy: + sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1)) + sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2]) + elif mode == mode == "multi-dim multi-class" and subset_accuracy: + return np.all(sk_preds == sk_target, axis=(1, 2)).mean() + elif mode == "multi-label" and not subset_accuracy: + sk_preds, sk_target = sk_preds.reshape(-1), sk_target.reshape(-1) return sk_accuracy(y_true=sk_target, y_pred=sk_preds) @pytest.mark.parametrize( - "preds, target, mdmc_accuracy", + "preds, target, subset_accuracy", [ - (_binary_prob_inputs.preds, _binary_prob_inputs.target, "global"), - (_binary_inputs.preds, _binary_inputs.target, "global"), - (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, "global"), - (_multilabel_inputs.preds, _multilabel_inputs.target, "subset"), # As this is treated as MDMC by default - (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, "global"), - (_multiclass_inputs.preds, _multiclass_inputs.target, "global"), - (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target, "global"), - (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target, "subset"), - (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, "global"), - (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, "subset"), - (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target, "global"), - ( - _multilabel_multidim_inputs.preds, - _multilabel_multidim_inputs.target, - "subset", # As this is treated as MDMC by default - ), + (_binary_prob_inputs.preds, _binary_prob_inputs.target, False), + (_binary_inputs.preds, _binary_inputs.target, False), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, True), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, False), + (_multilabel_inputs.preds, _multilabel_inputs.target, True), + (_multilabel_inputs.preds, _multilabel_inputs.target, False), + (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, False), + (_multiclass_inputs.preds, _multiclass_inputs.target, False), + (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target, False), + (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target, True), + (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, False), + (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, True), + (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target, True), + (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target, False), + (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target, True), + (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target, False), ], ) class TestAccuracies(MetricTester): @pytest.mark.parametrize("ddp", [False, True]) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, mdmc_accuracy): + def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accuracy): self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=Accuracy, - sk_metric=partial(_sk_accuracy, mdmc_accuracy=mdmc_accuracy), + sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy), dist_sync_on_step=dist_sync_on_step, - metric_args={"threshold": THRESHOLD, "mdmc_accuracy": mdmc_accuracy}, + metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy}, ) - def test_accuracy_fn(self, preds, target, mdmc_accuracy): + def test_accuracy_fn(self, preds, target, subset_accuracy): self.run_functional_metric_test( preds, target, metric_functional=accuracy, - sk_metric=partial(_sk_accuracy, mdmc_accuracy=mdmc_accuracy), - metric_args={"threshold": THRESHOLD, "mdmc_accuracy": mdmc_accuracy}, + sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy), + metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy}, ) @@ -100,24 +101,24 @@ def test_accuracy_fn(self, preds, target, mdmc_accuracy): # Replace with a proper sk_metric test once sklearn 0.24 hits :) @pytest.mark.parametrize( - "preds, target, exp_result, k, mdmc_accuracy", + "preds, target, exp_result, k, subset_accuracy", [ - (topk_preds_mc, topk_target_mc, 1 / 6, 1, "global"), - (topk_preds_mc, topk_target_mc, 3 / 6, 2, "global"), - (topk_preds_mc, topk_target_mc, 5 / 6, 3, "global"), - (topk_preds_mc, topk_target_mc, 1 / 6, 1, "subset"), - (topk_preds_mc, topk_target_mc, 3 / 6, 2, "subset"), - (topk_preds_mc, topk_target_mc, 5 / 6, 3, "subset"), - (topk_preds_mdmc, topk_target_mdmc, 1 / 6, 1, "global"), - (topk_preds_mdmc, topk_target_mdmc, 8 / 18, 2, "global"), - (topk_preds_mdmc, topk_target_mdmc, 13 / 18, 3, "global"), - (topk_preds_mdmc, topk_target_mdmc, 1 / 6, 1, "subset"), - (topk_preds_mdmc, topk_target_mdmc, 2 / 6, 2, "subset"), - (topk_preds_mdmc, topk_target_mdmc, 3 / 6, 3, "subset"), + (topk_preds_mc, topk_target_mc, 1 / 6, 1, False), + (topk_preds_mc, topk_target_mc, 3 / 6, 2, False), + (topk_preds_mc, topk_target_mc, 5 / 6, 3, False), + (topk_preds_mc, topk_target_mc, 1 / 6, 1, True), + (topk_preds_mc, topk_target_mc, 3 / 6, 2, True), + (topk_preds_mc, topk_target_mc, 5 / 6, 3, True), + (topk_preds_mdmc, topk_target_mdmc, 1 / 6, 1, False), + (topk_preds_mdmc, topk_target_mdmc, 8 / 18, 2, False), + (topk_preds_mdmc, topk_target_mdmc, 13 / 18, 3, False), + (topk_preds_mdmc, topk_target_mdmc, 1 / 6, 1, True), + (topk_preds_mdmc, topk_target_mdmc, 2 / 6, 2, True), + (topk_preds_mdmc, topk_target_mdmc, 3 / 6, 3, True), ], ) -def test_topk_accuracy(preds, target, exp_result, k, mdmc_accuracy): - topk = Accuracy(top_k=k, mdmc_accuracy=mdmc_accuracy) +def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy): + topk = Accuracy(top_k=k, subset_accuracy=subset_accuracy) for batch in range(preds.shape[0]): topk(preds[batch], target[batch]) @@ -130,7 +131,7 @@ def test_topk_accuracy(preds, target, exp_result, k, mdmc_accuracy): preds = preds.view(total_samples, 4, -1) target = target.view(total_samples, -1) - assert accuracy(preds, target, top_k=k, mdmc_accuracy=mdmc_accuracy) == exp_result + assert accuracy(preds, target, top_k=k, subset_accuracy=subset_accuracy) == exp_result # Only MC and MDMC with probs input type should be accepted for top_k From 98cb5f46d2ff27d583e0008a8497ea69f355c299 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 8 Dec 2020 16:52:24 +0100 Subject: [PATCH 78/94] Update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b4a754ba5249d..c5703d642f8b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `Accuracy` metric now enables the computation of subset accuracy for multi-label or multi-dimensional multi-class inputs with the `subset_accuracy` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) -- `HammingLoss` metric to compute the hamming loss (distance) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) +- `HammingDistance` metric to compute the hamming distance (loss) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) ### Fixed From 474fbd09f85acbb316e6e6e43462354c5adcdadc Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 8 Dec 2020 20:33:34 +0100 Subject: [PATCH 79/94] Apply suggestions from code review Co-authored-by: Rohit Gupta --- pytorch_lightning/metrics/classification/accuracy.py | 8 ++++---- .../metrics/classification/hamming_distance.py | 6 +++--- pytorch_lightning/metrics/functional/accuracy.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index e3c1d1ac7d211..ac2856af15e3c 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -36,7 +36,7 @@ class Accuracy(Metric): For multi-label and multi-dimensional multi-class inputs, this metric computes the "global" accuracy by default, which counts all labels or sub-samples separately. This can be changed to subset accuracy (which requires all labels or sub-samples in the sample to - be correctly predicted) by setting `subset_accuracy=True`. + be correctly predicted) by setting ``subset_accuracy=True``. Accepts all input types listed in :ref:`metrics:Input types`. @@ -115,12 +115,12 @@ def __init__( if not 0 <= threshold <= 1: raise ValueError("The `threshold` should lie in the [0,1] interval.") - self.threshold = threshold if top_k is not None and top_k <= 0: raise ValueError("The `top_k` should be an integer larger than 1.") - self.top_k = top_k + self.threshold = threshold + self.top_k = top_k self.subset_accuracy = subset_accuracy def update(self, preds: torch.Tensor, target: torch.Tensor): @@ -130,7 +130,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): Args: preds: Predictions from model (probabilities, or labels) - target: Ground truth values + target: Ground truth labels """ correct, total = _accuracy_update( diff --git a/pytorch_lightning/metrics/classification/hamming_distance.py b/pytorch_lightning/metrics/classification/hamming_distance.py index 34654bd53a320..a51f12e87d61d 100644 --- a/pytorch_lightning/metrics/classification/hamming_distance.py +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -39,7 +39,7 @@ class HammingDistance(Metric): Args: threshold: Threshold probability value for transforming probability predictions to binary - `(0,1)` predictions, in the case of binary or multi-label inputs. Default: `0.5` + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: @@ -48,7 +48,7 @@ class HammingDistance(Metric): process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP + Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the allgather. default: None Example: @@ -91,7 +91,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): Args: preds: Predictions from model (probabilities, or labels) - target: Ground truth values + target: Ground truth labels """ correct, total = _hamming_distance_update(preds, target, self.threshold) diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 06aaed34faf1f..cdff3a7c439c1 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -73,7 +73,7 @@ def accuracy( Args: preds: Predictions from model (probabilities, or labels) - target: Ground truth values + target: Ground truth labels threshold: Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 From 03cccc32757f4d4f64e4558eb38e14bddbb62332 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 8 Dec 2020 20:34:54 +0100 Subject: [PATCH 80/94] Suggestions from code review --- pytorch_lightning/metrics/classification/hamming_distance.py | 2 +- pytorch_lightning/metrics/functional/accuracy.py | 4 ++-- pytorch_lightning/metrics/functional/hamming_distance.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/classification/hamming_distance.py b/pytorch_lightning/metrics/classification/hamming_distance.py index a51f12e87d61d..0424272ab83c2 100644 --- a/pytorch_lightning/metrics/classification/hamming_distance.py +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -39,7 +39,7 @@ class HammingDistance(Metric): Args: threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + (0,1) predictions, in the case of binary or multi-label inputs. Default: `0.5` compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index cdff3a7c439c1..ab7213b8572ac 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -67,7 +67,7 @@ def accuracy( For multi-label and multi-dimensional multi-class inputs, this metric computes the "global" accuracy by default, which counts all labels or sub-samples separately. This can be changed to subset accuracy (which requires all labels or sub-samples in the sample to - be correctly predicted) by setting `subset_accuracy=True`. + be correctly predicted) by setting ``subset_accuracy=True``. Accepts all input types listed in :ref:`metrics:Input types`. @@ -76,7 +76,7 @@ def accuracy( target: Ground truth labels threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + (0,1) predictions, in the case of binary or multi-label inputs. Default: `0.5` top_k: Number of highest probability predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The diff --git a/pytorch_lightning/metrics/functional/hamming_distance.py b/pytorch_lightning/metrics/functional/hamming_distance.py index daca66163ccf3..5372a63804257 100644 --- a/pytorch_lightning/metrics/functional/hamming_distance.py +++ b/pytorch_lightning/metrics/functional/hamming_distance.py @@ -51,7 +51,7 @@ def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float Args: threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + (0,1) predictions, in the case of binary or multi-label inputs. Default: `0.5` Example: From de0213efbe82d9743fea24324ddc18e3fa3f3678 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Tue, 8 Dec 2020 20:40:02 +0100 Subject: [PATCH 81/94] Fix number in docs --- pytorch_lightning/metrics/classification/accuracy.py | 2 +- pytorch_lightning/metrics/classification/hamming_distance.py | 2 +- pytorch_lightning/metrics/functional/accuracy.py | 2 +- pytorch_lightning/metrics/functional/hamming_distance.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index ac2856af15e3c..af7266263b90f 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -43,7 +43,7 @@ class Accuracy(Metric): Args: threshold: Threshold probability value for transforming probability predictions to binary - `(0,1)` predictions, in the case of binary or multi-label inputs. Default: `0.5` + `(0,1)` predictions, in the case of binary or multi-label inputs. Default: 0.5 top_k: Number of highest probability predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The diff --git a/pytorch_lightning/metrics/classification/hamming_distance.py b/pytorch_lightning/metrics/classification/hamming_distance.py index 0424272ab83c2..a51f12e87d61d 100644 --- a/pytorch_lightning/metrics/classification/hamming_distance.py +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -39,7 +39,7 @@ class HammingDistance(Metric): Args: threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: `0.5` + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True dist_sync_on_step: diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index ab7213b8572ac..1d245ff35ea79 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -76,7 +76,7 @@ def accuracy( target: Ground truth labels threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: `0.5` + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 top_k: Number of highest probability predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The diff --git a/pytorch_lightning/metrics/functional/hamming_distance.py b/pytorch_lightning/metrics/functional/hamming_distance.py index 5372a63804257..daca66163ccf3 100644 --- a/pytorch_lightning/metrics/functional/hamming_distance.py +++ b/pytorch_lightning/metrics/functional/hamming_distance.py @@ -51,7 +51,7 @@ def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float Args: threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: `0.5` + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 Example: From 0fbf93cb9f9fc662c20a13a2758350bb55d15814 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Wed, 9 Dec 2020 01:12:33 +0530 Subject: [PATCH 82/94] Update pytorch_lightning/metrics/classification/accuracy.py --- pytorch_lightning/metrics/classification/accuracy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index af7266263b90f..254f4f7e21d68 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -43,7 +43,7 @@ class Accuracy(Metric): Args: threshold: Threshold probability value for transforming probability predictions to binary - `(0,1)` predictions, in the case of binary or multi-label inputs. Default: 0.5 + (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 top_k: Number of highest probability predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The From 1b8af657f00b7b8bfaee4c9eb287f780e4a8258c Mon Sep 17 00:00:00 2001 From: Tadej Date: Fri, 11 Dec 2020 18:59:52 +0100 Subject: [PATCH 83/94] Replace topk by argsort in select_topk --- pytorch_lightning/metrics/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 92faca200d0aa..b74558cb5fe76 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -169,7 +169,9 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch [1, 1, 0]], dtype=torch.int32) """ zeros = torch.zeros_like(prob_tensor) - topk_tensor = zeros.scatter(1, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) + argsort = prob_tensor.argsort(dim, descending=True) + selected = torch.index_select(argsort, dim, torch.tensor(list(range(topk)))) + topk_tensor = zeros.scatter(dim, selected, 1.0) return topk_tensor.int() From 3c4f20056a84feaf4da11dc75cd74417ef4a25bf Mon Sep 17 00:00:00 2001 From: Tadej Date: Fri, 11 Dec 2020 19:59:10 +0100 Subject: [PATCH 84/94] Fix changelog --- CHANGELOG.md | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8337eb1d1be91..1d5ddaf406e13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- `Accuracy` metric now generalizes to Top-k accuracy for (multi-dimensional) multi-class inputs using the `top_k` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) + +- `Accuracy` metric now enables the computation of subset accuracy for multi-label or multi-dimensional multi-class inputs with the `subset_accuracy` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) + +- `HammingDistance` metric to compute the hamming distance (loss) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) ### Changed @@ -19,13 +24,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed -### Added - -- `Accuracy` metric now generalizes to Top-k accuracy for (multi-dimensional) multi-class inputs using the `top_k` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) - -- `Accuracy` metric now enables the computation of subset accuracy for multi-label or multi-dimensional multi-class inputs with the `subset_accuracy` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) - -- `HammingDistance` metric to compute the hamming distance (loss) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) ### Fixed From 82d550ec9868b3368c488dbca91771d2f252e44e Mon Sep 17 00:00:00 2001 From: Tadej Date: Sat, 12 Dec 2020 10:25:21 +0100 Subject: [PATCH 85/94] Add test for wrong params --- tests/metrics/classification/test_accuracy.py | 13 +++++++++++++ .../metrics/classification/test_hamming_distance.py | 13 +++++++++++++ 2 files changed, 26 insertions(+) diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index 61313119814b3..6e44244e3ca91 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -156,3 +156,16 @@ def test_topk_accuracy_wrong_input_types(preds, target): with pytest.raises(ValueError): accuracy(preds[0], target[0], top_k=1) + + +@pytest.mark.parametrize("top_k, threshold", [(0, 0.5), (None, 1.5)]) +def test_wrong_params(top_k, threshold): + preds, target = _multiclass_prob_inputs.preds, _multiclass_prob_inputs.target + + with pytest.raises(ValueError): + acc = Accuracy(threshold=threshold, top_k=top_k) + acc(preds, target) + acc.compute() + + with pytest.raises(ValueError): + accuracy(preds, target, threshold=threshold, top_k=top_k) diff --git a/tests/metrics/classification/test_hamming_distance.py b/tests/metrics/classification/test_hamming_distance.py index f08ec27f565a1..73c2abe771fae 100644 --- a/tests/metrics/classification/test_hamming_distance.py +++ b/tests/metrics/classification/test_hamming_distance.py @@ -67,3 +67,16 @@ def test_hamming_distance_fn(self, preds, target): sk_metric=_sk_hamming_loss, metric_args={"threshold": THRESHOLD}, ) + + +@pytest.mark.parametrize("threshold", [1.5]) +def test_wrong_params(threshold): + preds, target = _multiclass_prob_inputs.preds, _multiclass_prob_inputs.target + + with pytest.raises(ValueError): + ham_dist = HammingDistance(threshold=threshold) + ham_dist(preds, target) + ham_dist.compute() + + with pytest.raises(ValueError): + hamming_distance(preds, target, threshold=threshold) From eb9cb3c8017b1f207ac1d2c7168327fbe0b63cb5 Mon Sep 17 00:00:00 2001 From: Shachar Mirkin Date: Mon, 14 Dec 2020 13:39:29 +0100 Subject: [PATCH 86/94] Add Google Colab badges (#5111) * Add colab badges to notebook Add colab badges to notebook to notebooks 4 & 5 * Add colab badges Co-authored-by: chaton --- notebooks/04-transformers-text-classification.ipynb | 7 +++++++ notebooks/05-trainer-flags-overview.ipynb | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/notebooks/04-transformers-text-classification.ipynb b/notebooks/04-transformers-text-classification.ipynb index 037b24e4ddd9d..d52af84a76d97 100644 --- a/notebooks/04-transformers-text-classification.ipynb +++ b/notebooks/04-transformers-text-classification.ipynb @@ -1,5 +1,12 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Open" + ] + }, { "cell_type": "markdown", "metadata": { diff --git a/notebooks/05-trainer-flags-overview.ipynb b/notebooks/05-trainer-flags-overview.ipynb index 6413e8239bb2e..da044a9c9b5c6 100644 --- a/notebooks/05-trainer-flags-overview.ipynb +++ b/notebooks/05-trainer-flags-overview.ipynb @@ -1,5 +1,12 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Open" + ] + }, { "cell_type": "markdown", "metadata": { From 69123af3ea651a5e1cc25014da6f1c0dee433916 Mon Sep 17 00:00:00 2001 From: Tadej Svetina Date: Mon, 14 Dec 2020 20:13:58 +0100 Subject: [PATCH 87/94] Fix hanging metrics tests (#5134) --- tests/metrics/regression/test_ssim.py | 4 +--- tests/metrics/utils.py | 10 ++++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/metrics/regression/test_ssim.py b/tests/metrics/regression/test_ssim.py index f581188e89fce..8bb304850e3f2 100644 --- a/tests/metrics/regression/test_ssim.py +++ b/tests/metrics/regression/test_ssim.py @@ -53,9 +53,7 @@ def _sk_metric(preds, target, data_range, multichannel): class TestSSIM(MetricTester): atol = 6e-5 - # TODO: for some reason this test hangs with ddp=True - # @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("ddp", [False]) + @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_ssim(self, preds, target, multichannel, ddp, dist_sync_on_step): self.run_class_metric_test( diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index c607a466b2068..4bd6608ce3fcf 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -11,6 +11,11 @@ from pytorch_lightning.metrics import Metric +try: + set_start_method("spawn") +except RuntimeError: + pass + NUM_PROCESSES = 2 NUM_BATCHES = 10 BATCH_SIZE = 32 @@ -165,10 +170,7 @@ def setup_class(self): """Setup the metric class. This will spawn the pool of workers that are used for metric testing and setup_ddp """ - try: - set_start_method("spawn") - except RuntimeError: - pass + self.poolSize = NUM_PROCESSES self.pool = Pool(processes=self.poolSize) self.pool.starmap(setup_ddp, [(rank, self.poolSize) for rank in range(self.poolSize)]) From f68acc07c7e8722ee88d7a2f808c1c962fdd74c4 Mon Sep 17 00:00:00 2001 From: Tadej Date: Mon, 14 Dec 2020 20:22:26 +0100 Subject: [PATCH 88/94] Use torch.topk again as ddp hanging tests fixed in #5134 --- pytorch_lightning/metrics/utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index b74558cb5fe76..8c59bb4991cab 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -169,10 +169,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch [1, 1, 0]], dtype=torch.int32) """ zeros = torch.zeros_like(prob_tensor) - argsort = prob_tensor.argsort(dim, descending=True) - selected = torch.index_select(argsort, dim, torch.tensor(list(range(topk)))) - topk_tensor = zeros.scatter(dim, selected, 1.0) - + topk_tensor = zeros.scatter(dim, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) return topk_tensor.int() From 3bf3c3ac3e12e7b57f43de8f565e3a466a37f050 Mon Sep 17 00:00:00 2001 From: Tadej Date: Thu, 17 Dec 2020 11:39:04 +0100 Subject: [PATCH 89/94] Fix unwanted notebooks change --- notebooks/04-transformers-text-classification.ipynb | 7 ------- notebooks/05-trainer-flags-overview.ipynb | 7 ------- 2 files changed, 14 deletions(-) diff --git a/notebooks/04-transformers-text-classification.ipynb b/notebooks/04-transformers-text-classification.ipynb index d52af84a76d97..037b24e4ddd9d 100644 --- a/notebooks/04-transformers-text-classification.ipynb +++ b/notebooks/04-transformers-text-classification.ipynb @@ -1,12 +1,5 @@ { "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\"Open" - ] - }, { "cell_type": "markdown", "metadata": { diff --git a/notebooks/05-trainer-flags-overview.ipynb b/notebooks/05-trainer-flags-overview.ipynb index da044a9c9b5c6..6413e8239bb2e 100644 --- a/notebooks/05-trainer-flags-overview.ipynb +++ b/notebooks/05-trainer-flags-overview.ipynb @@ -1,12 +1,5 @@ { "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\"Open" - ] - }, { "cell_type": "markdown", "metadata": { From 44135c7c776cfa688e958547993051ebcd59070a Mon Sep 17 00:00:00 2001 From: Tadej Date: Mon, 21 Dec 2020 12:35:48 +0100 Subject: [PATCH 90/94] Fix too long line in hamming_distance --- pytorch_lightning/metrics/functional/hamming_distance.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/functional/hamming_distance.py b/pytorch_lightning/metrics/functional/hamming_distance.py index daca66163ccf3..2e7125354ee02 100644 --- a/pytorch_lightning/metrics/functional/hamming_distance.py +++ b/pytorch_lightning/metrics/functional/hamming_distance.py @@ -17,7 +17,9 @@ from pytorch_lightning.metrics.classification.helpers import _input_format_classification -def _hamming_distance_update(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> Tuple[torch.Tensor, int]: +def _hamming_distance_update( + preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5 +) -> Tuple[torch.Tensor, int]: preds, target, _ = _input_format_classification(preds, target, threshold=threshold) correct = (preds == target).sum() From 908e60f84f3f041df746f599c58ef4a523169637 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 21 Dec 2020 13:43:43 +0100 Subject: [PATCH 91/94] Apply suggestions from code review --- pytorch_lightning/metrics/classification/accuracy.py | 4 ++-- .../metrics/classification/hamming_distance.py | 8 ++++---- pytorch_lightning/metrics/functional/accuracy.py | 4 ++-- pytorch_lightning/metrics/functional/hamming_distance.py | 4 +++- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 254f4f7e21d68..747b339f7c0f8 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -52,7 +52,7 @@ class Accuracy(Metric): Should be left at default (``None``) for all other types of inputs. subset_accuracy: Whether to compute subset accuracy for multi-label and multi-dimensional - multi-class inputs (has no effect for other input types). Default: `False` + multi-class inputs (has no effect for other input types). For multi-label inputs, if the parameter is set to `True`, then all labels for each sample must be correctly predicted for the sample to count as correct. If it @@ -66,7 +66,7 @@ class Accuracy(Metric): ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter still applies in both cases, if set. compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True + Forward only calls ``update()`` and return None if this is set to False. dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step. default: False diff --git a/pytorch_lightning/metrics/classification/hamming_distance.py b/pytorch_lightning/metrics/classification/hamming_distance.py index a51f12e87d61d..b3281cd60987c 100644 --- a/pytorch_lightning/metrics/classification/hamming_distance.py +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -39,17 +39,17 @@ class HammingDistance(Metric): Args: threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + `(0,1)` predictions, in the case of binary or multi-label inputs. compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True + Forward only calls ``update()`` and return None if this is set to False. dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False + before returning the value at the step. process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) dist_sync_fn: Callback that performs the allgather operation on the metric state. When ``None``, DDP - will be used to perform the allgather. default: None + will be used to perform the all gather. Example: diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 1d245ff35ea79..8ba0e49b881b8 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -76,7 +76,7 @@ def accuracy( target: Ground truth labels threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + `(0,1)` predictions, in the case of binary or multi-label inputs. top_k: Number of highest probability predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The @@ -85,7 +85,7 @@ def accuracy( Should be left at default (``None``) for all other types of inputs. subset_accuracy: Whether to compute subset accuracy for multi-label and multi-dimensional - multi-class inputs (has no effect for other input types). Default: `False` + multi-class inputs (has no effect for other input types). For multi-label inputs, if the parameter is set to `True`, then all labels for each sample must be correctly predicted for the sample to count as correct. If it diff --git a/pytorch_lightning/metrics/functional/hamming_distance.py b/pytorch_lightning/metrics/functional/hamming_distance.py index 2e7125354ee02..7d8ecafd08b00 100644 --- a/pytorch_lightning/metrics/functional/hamming_distance.py +++ b/pytorch_lightning/metrics/functional/hamming_distance.py @@ -51,9 +51,11 @@ def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float Accepts all input types listed in :ref:`metrics:Input types`. Args: + preds: Predictions from model + target: Ground truth threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + (0,1) predictions, in the case of binary or multi-label inputs. Example: From 23a997e3a388f200f095912419214e883f5e1d77 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 21 Dec 2020 13:44:16 +0100 Subject: [PATCH 92/94] Apply suggestions from code review --- pytorch_lightning/metrics/classification/accuracy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 747b339f7c0f8..e248c132026a4 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -43,7 +43,7 @@ class Accuracy(Metric): Args: threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + `(0,1)` predictions, in the case of binary or multi-label inputs. top_k: Number of highest probability predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The From b3e458d57cd8b498a42c311cd3df3d43c3d1c4e1 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 21 Dec 2020 13:46:45 +0100 Subject: [PATCH 93/94] protect --- tests/metrics/classification/test_accuracy.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index 6e44244e3ca91..7b28e07c894dd 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -85,36 +85,36 @@ def test_accuracy_fn(self, preds, target, subset_accuracy): ) -l1to4 = [0.1, 0.2, 0.3, 0.4] -l1to4t3 = np.array([l1to4, l1to4, l1to4]) -l1to4t3_mc = [l1to4t3.T, l1to4t3.T, l1to4t3.T] +_l1to4 = [0.1, 0.2, 0.3, 0.4] +_l1to4t3 = np.array([_l1to4, _l1to4, _l1to4]) +_l1to4t3_mc = [_l1to4t3.T, _l1to4t3.T, _l1to4t3.T] # The preds in these examples always put highest probability on class 3, second highest on class 2, # third highest on class 1, and lowest on class 0 -topk_preds_mc = torch.tensor([l1to4t3, l1to4t3]).float() -topk_target_mc = torch.tensor([[1, 2, 3], [2, 1, 0]]) +_topk_preds_mc = torch.tensor([_l1to4t3, _l1to4t3]).float() +_topk_target_mc = torch.tensor([[1, 2, 3], [2, 1, 0]]) # This is like for MC case, but one sample in each batch is sabotaged with 0 class prediction :) -topk_preds_mdmc = torch.tensor([l1to4t3_mc, l1to4t3_mc]).float() -topk_target_mdmc = torch.tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]]) +_topk_preds_mdmc = torch.tensor([_l1to4t3_mc, _l1to4t3_mc]).float() +_topk_target_mdmc = torch.tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]]) # Replace with a proper sk_metric test once sklearn 0.24 hits :) @pytest.mark.parametrize( "preds, target, exp_result, k, subset_accuracy", [ - (topk_preds_mc, topk_target_mc, 1 / 6, 1, False), - (topk_preds_mc, topk_target_mc, 3 / 6, 2, False), - (topk_preds_mc, topk_target_mc, 5 / 6, 3, False), - (topk_preds_mc, topk_target_mc, 1 / 6, 1, True), - (topk_preds_mc, topk_target_mc, 3 / 6, 2, True), - (topk_preds_mc, topk_target_mc, 5 / 6, 3, True), - (topk_preds_mdmc, topk_target_mdmc, 1 / 6, 1, False), - (topk_preds_mdmc, topk_target_mdmc, 8 / 18, 2, False), - (topk_preds_mdmc, topk_target_mdmc, 13 / 18, 3, False), - (topk_preds_mdmc, topk_target_mdmc, 1 / 6, 1, True), - (topk_preds_mdmc, topk_target_mdmc, 2 / 6, 2, True), - (topk_preds_mdmc, topk_target_mdmc, 3 / 6, 3, True), + (_topk_preds_mc, _topk_target_mc, 1 / 6, 1, False), + (_topk_preds_mc, _topk_target_mc, 3 / 6, 2, False), + (_topk_preds_mc, _topk_target_mc, 5 / 6, 3, False), + (_topk_preds_mc, _topk_target_mc, 1 / 6, 1, True), + (_topk_preds_mc, _topk_target_mc, 3 / 6, 2, True), + (_topk_preds_mc, _topk_target_mc, 5 / 6, 3, True), + (_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, False), + (_topk_preds_mdmc, _topk_target_mdmc, 8 / 18, 2, False), + (_topk_preds_mdmc, _topk_target_mdmc, 13 / 18, 3, False), + (_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, True), + (_topk_preds_mdmc, _topk_target_mdmc, 2 / 6, 2, True), + (_topk_preds_mdmc, _topk_target_mdmc, 3 / 6, 3, True), ], ) def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy): From 92f5f8374be4ac234c7959553a5a8c771a94cd54 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 21 Dec 2020 18:37:01 +0530 Subject: [PATCH 94/94] Update CHANGELOG.md --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7cd8243ff829e..2a46f49211268 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed - ### Fixed