diff --git a/mmaction/evaluation/metrics/__init__.py b/mmaction/evaluation/metrics/__init__.py
index 46988d39c1..0493dae036 100644
--- a/mmaction/evaluation/metrics/__init__.py
+++ b/mmaction/evaluation/metrics/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .acc_metric import AccMetric
+from .acc_metric import AccMetric, ConfusionMatrix
 from .anet_metric import ANetMetric
 from .ava_metric import AVAMetric
 
-__all__ = ['AccMetric', 'AVAMetric', 'ANetMetric']
+__all__ = ['AccMetric', 'AVAMetric', 'ANetMetric', 'ConfusionMatrix']
diff --git a/mmaction/evaluation/metrics/acc_metric.py b/mmaction/evaluation/metrics/acc_metric.py
index 488e28aa14..6875f76336 100644
--- a/mmaction/evaluation/metrics/acc_metric.py
+++ b/mmaction/evaluation/metrics/acc_metric.py
@@ -1,9 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 from collections import OrderedDict
-from typing import Any, Optional, Sequence, Tuple, Union
+from itertools import product
+from typing import Any, List, Optional, Sequence, Tuple, Union
 
+import mmengine
 import numpy as np
+import torch
 from mmengine.evaluator import BaseMetric
 
 from mmaction.evaluation import (mean_average_precision, mean_class_accuracy,
@@ -11,6 +14,17 @@
 from mmaction.registry import METRICS
 
 
+def to_tensor(value):
+    """Convert value to torch.Tensor."""
+    if isinstance(value, np.ndarray):
+        value = torch.from_numpy(value)
+    elif isinstance(value, Sequence) and not mmengine.is_str(value):
+        value = torch.tensor(value)
+    elif not isinstance(value, torch.Tensor):
+        raise TypeError(f'{type(value)} is not an available argument.')
+    return value
+
+
 @METRICS.register_module()
 class AccMetric(BaseMetric):
     """Accuracy evaluation metric."""
@@ -136,3 +150,198 @@ def label2array(num, label):
     arr = np.zeros(num, dtype=np.float32)
     arr[label] = 1.
     return arr
+
+
+@METRICS.register_module()
+class ConfusionMatrix(BaseMetric):
+    r"""A metric to calculate the confusion matrix for single-label tasks.
+
+    Args:
+        num_classes (int, optional): The number of classes. Defaults to None.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+
+    Examples:
+
+        1. The basic usage.
+
+        >>> import torch
+        >>> from mmaction.evaluation import ConfusionMatrix
+        >>> y_pred = [0, 1, 1, 3]
+        >>> y_true = [0, 2, 1, 3]
+        >>> ConfusionMatrix.calculate(y_pred, y_true, num_classes=4)
+        tensor([[1, 0, 0, 0],
+                [0, 1, 0, 0],
+                [0, 1, 0, 0],
+                [0, 0, 0, 1]])
+        >>> # plot the confusion matrix
+        >>> import matplotlib.pyplot as plt
+        >>> y_score = torch.rand((1000, 10))
+        >>> y_true = torch.randint(10, (1000, ))
+        >>> matrix = ConfusionMatrix.calculate(y_score, y_true)
+        >>> ConfusionMatrix().plot(matrix)
+        >>> plt.show()
+
+        2. In the config file
+
+        .. code:: python
+
+            val_evaluator = dict(type='ConfusionMatrix')
+            test_evaluator = dict(type='ConfusionMatrix')
+    """  # noqa: E501
+    default_prefix = 'confusion_matrix'
+
+    def __init__(self,
+                 num_classes: Optional[int] = None,
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None) -> None:
+        super().__init__(collect_device, prefix)
+
+        self.num_classes = num_classes
+
+    def process(self, data_batch, data_samples: Sequence[dict]) -> None:
+        """Parse predictions and labels from one batch of data samples."""
+        for data_sample in data_samples:
+            pred_scores = data_sample.get('pred_scores')
+            gt_label = data_sample['gt_labels']['item']
+            if pred_scores is not None:
+                pred_label = pred_scores['item'].argmax(dim=0, keepdim=True)
+                self.num_classes = pred_scores['item'].size(0)
+            else:
+                pred_label = data_sample['pred_labels']['item']
+
+            self.results.append({
+                'pred_label': pred_label,
+                'gt_label': gt_label
+            })
+
+    def compute_metrics(self, results: list) -> dict:
+        """Compute the confusion matrix from the processed results."""
+        pred_labels = []
+        gt_labels = []
+        for result in results:
+            pred_labels.append(result['pred_label'])
+            gt_labels.append(result['gt_label'])
+        confusion_matrix = ConfusionMatrix.calculate(
+            torch.cat(pred_labels),
+            torch.cat(gt_labels),
+            num_classes=self.num_classes)
+        return {'result': confusion_matrix}
+
+    @staticmethod
+    def calculate(pred, target, num_classes=None) -> torch.Tensor:
+        """Calculate the confusion matrix for a single-label task.
+
+        Args:
+            pred (torch.Tensor | np.ndarray | Sequence): The prediction
+                results. It can be labels (N, ), or scores of every
+                class (N, C).
+            target (torch.Tensor | np.ndarray | Sequence): The target of
+                each prediction with shape (N, ).
+            num_classes (int, optional): The number of classes. If ``pred``
+                is labels instead of scores, this argument is required.
+                Defaults to None.
+
+        Returns:
+            torch.Tensor: The confusion matrix.
+        """
+        pred = to_tensor(pred)
+        target_label = to_tensor(target).int()
+
+        assert pred.size(0) == target_label.size(0), \
+            f"The size of pred ({pred.size(0)}) doesn't match "\
+            f'the target ({target_label.size(0)}).'
+        assert target_label.ndim == 1
+
+        if pred.ndim == 1:
+            assert num_classes is not None, \
+                'Please specify the `num_classes` if the `pred` is labels ' \
+                'instead of scores.'
+            pred_label = pred
+        else:
+            num_classes = num_classes or pred.size(1)
+            pred_label = torch.argmax(pred, dim=1).flatten()
+
+        with torch.no_grad():
+            indices = num_classes * target_label + pred_label
+            matrix = torch.bincount(indices, minlength=num_classes**2)
+            matrix = matrix.reshape(num_classes, num_classes)
+
+        return matrix
+
+    @staticmethod
+    def plot(confusion_matrix: torch.Tensor,
+             include_values: bool = False,
+             cmap: str = 'viridis',
+             classes: Optional[List[str]] = None,
+             colorbar: bool = True,
+             show: bool = True):
+        """Draw a confusion matrix with matplotlib.
+
+        Modified from Scikit-Learn.
+
+        Args:
+            confusion_matrix (torch.Tensor): The confusion matrix to draw.
+            include_values (bool): Whether to draw the values in the figure.
+                Defaults to False.
+            cmap (str): The color map to use. Defaults to "viridis".
+            classes (list[str], optional): The names of categories.
+                Defaults to None, which means to use the index number.
+            colorbar (bool): Whether to show the colorbar. Defaults to True.
+            show (bool): Whether to show the figure immediately.
+                Defaults to True.
+ """ # noqa: E501 + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(10, 10)) + + num_classes = confusion_matrix.size(0) + + im_ = ax.imshow(confusion_matrix, interpolation='nearest', cmap=cmap) + text_ = None + cmap_min, cmap_max = im_.cmap(0), im_.cmap(1.0) + + if include_values: + text_ = np.empty_like(confusion_matrix, dtype=object) + + # print text with appropriate color depending on background + thresh = (confusion_matrix.max() + confusion_matrix.min()) / 2.0 + + for i, j in product(range(num_classes), range(num_classes)): + color = cmap_max if confusion_matrix[i, + j] < thresh else cmap_min + + text_cm = format(confusion_matrix[i, j], '.2g') + text_d = format(confusion_matrix[i, j], 'd') + if len(text_d) < len(text_cm): + text_cm = text_d + + text_[i, j] = ax.text( + j, i, text_cm, ha='center', va='center', color=color) + + display_labels = classes or np.arange(num_classes) + + if colorbar: + fig.colorbar(im_, ax=ax) + ax.set( + xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=display_labels, + yticklabels=display_labels, + ylabel='True label', + xlabel='Predicted label', + ) + ax.invert_yaxis() + ax.xaxis.tick_top() + + ax.set_ylim((num_classes - 0.5, -0.5)) + # Automatically rotate the x labels. + fig.autofmt_xdate(ha='center') + + if show: + plt.show() + return fig diff --git a/mmaction/structures/action_data_sample.py b/mmaction/structures/action_data_sample.py index c75f6654a1..02e46bd151 100644 --- a/mmaction/structures/action_data_sample.py +++ b/mmaction/structures/action_data_sample.py @@ -1,25 +1,105 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Union +from numbers import Number +from typing import Sequence, Union import numpy as np import torch from mmengine.structures import BaseDataElement, InstanceData, LabelData +from mmengine.utils import is_str + + +def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, + int]) -> torch.Tensor: + """Convert various python types to label-format tensor. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int`. + + Args: + value (torch.Tensor | numpy.ndarray | Sequence | int): Label value. + + Returns: + :obj:`torch.Tensor`: The foramtted label tensor. + """ + + # Handle single number + if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0: + value = int(value.item()) + + if isinstance(value, np.ndarray): + value = torch.from_numpy(value).to(torch.long) + elif isinstance(value, Sequence) and not is_str(value): + value = torch.tensor(value).to(torch.long) + elif isinstance(value, int): + value = torch.LongTensor([value]) + elif not isinstance(value, torch.Tensor): + raise TypeError(f'Type {type(value)} is not an available label type.') + assert value.ndim == 1, \ + f'The dims of value should be 1, but got {value.ndim}.' + + return value + + +def format_score(value: Union[torch.Tensor, np.ndarray, + Sequence]) -> torch.Tensor: + """Convert various python types to score-format tensor. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`. + + Args: + value (torch.Tensor | numpy.ndarray | Sequence): Score values. + + Returns: + :obj:`torch.Tensor`: The foramtted score tensor. 
+ """ + + if isinstance(value, np.ndarray): + value = torch.from_numpy(value).float() + elif isinstance(value, Sequence) and not is_str(value): + value = torch.tensor(value).float() + elif not isinstance(value, torch.Tensor): + raise TypeError(f'Type {type(value)} is not an available label type.') + assert value.ndim == 1, \ + f'The dims of value should be 1, but got {value.ndim}.' + + return value class ActionDataSample(BaseDataElement): - def set_gt_labels(self, value: Union[int, - np.ndarray]) -> 'ActionDataSample': + def set_gt_label( + self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number] + ) -> 'ActionDataSample': """Set label of ``gt_labels``.""" - if isinstance(value, int): - value = torch.LongTensor([value]) - elif isinstance(value, np.ndarray): - value = torch.from_numpy(value) - else: - raise TypeError(f'Type {type(value)} is not an ' - f'available label type.') + label_data = getattr(self, '_gt_label', LabelData()) + label_data.item = format_label(value) + self.gt_labels = label_data + return self - self.gt_labels = LabelData(item=value) + def set_pred_label( + self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number] + ) -> 'ActionDataSample': + """Set label of ``pred_label``.""" + label_data = getattr(self, '_pred_label', LabelData()) + label_data.item = format_label(value) + self.pred_labels = label_data + return self + + def set_pred_score(self, value: torch.Tensor) -> 'ActionDataSample': + """Set score of ``pred_label``.""" + label_data = getattr(self, '_pred_label', LabelData()) + label_data.item = format_score(value) + if hasattr(self, 'num_classes'): + assert len(label_data.item) == self.num_classes, \ + f'The length of score {len(label_data.item)} should be '\ + f'equal to the num_classes {self.num_classes}.' + else: + self.set_field( + name='num_classes', + value=len(label_data.item), + field_type='metainfo') + self.pred_scores = label_data return self @property diff --git a/tests/evaluation/metrics/test_acc_metric.py b/tests/evaluation/metrics/test_acc_metric.py index 273155858c..5aa14b6f1d 100644 --- a/tests/evaluation/metrics/test_acc_metric.py +++ b/tests/evaluation/metrics/test_acc_metric.py @@ -1,7 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. 
diff --git a/tests/evaluation/metrics/test_acc_metric.py b/tests/evaluation/metrics/test_acc_metric.py
index 273155858c..5aa14b6f1d 100644
--- a/tests/evaluation/metrics/test_acc_metric.py
+++ b/tests/evaluation/metrics/test_acc_metric.py
@@ -1,7 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import numpy as np
 import torch
 
-from mmaction.evaluation import AccMetric
+from mmaction.evaluation import AccMetric, ConfusionMatrix
+from mmaction.registry import METRICS
+from mmaction.structures import ActionDataSample
 
 
 def generate_data(num_classes=5, random_label=False):
@@ -41,3 +46,113 @@ def test_accmetric():
     assert eval_results['mean1'] == 1.0
     assert eval_results['mmit_mean_average_precision'] == 1.0
     return
+
+
+class TestConfusionMatrix(TestCase):
+
+    def test_evaluate(self):
+        """Test using the metric in the same way as Evaluator."""
+        pred = [
+            ActionDataSample().set_pred_score(i).set_pred_label(
+                j).set_gt_label(k).to_dict() for i, j, k in zip([
+                    torch.tensor([0.7, 0.0, 0.3]),
+                    torch.tensor([0.5, 0.2, 0.3]),
+                    torch.tensor([0.4, 0.5, 0.1]),
+                    torch.tensor([0.0, 0.0, 1.0]),
+                    torch.tensor([0.0, 0.0, 1.0]),
+                    torch.tensor([0.0, 0.0, 1.0]),
+                ], [0, 0, 1, 2, 2, 2], [0, 0, 1, 2, 1, 0])
+        ]
+
+        # Test with score (use score instead of label if score exists)
+        metric = METRICS.build(dict(type='ConfusionMatrix'))
+        metric.process(None, pred)
+        res = metric.evaluate(6)
+        self.assertIsInstance(res, dict)
+        self.assertTensorEqual(
+            res['confusion_matrix/result'],
+            torch.tensor([
+                [2, 0, 1],
+                [0, 1, 1],
+                [0, 0, 1],
+            ]))
+
+        # Test with label
+        for sample in pred:
+            del sample['pred_scores']
+        metric = METRICS.build(dict(type='ConfusionMatrix'))
+        metric.process(None, pred)
+        with self.assertRaisesRegex(AssertionError,
+                                    'Please specify the `num_classes`'):
+            metric.evaluate(6)
+
+        metric = METRICS.build(dict(type='ConfusionMatrix', num_classes=3))
+        metric.process(None, pred)
+        res = metric.evaluate(6)
+        self.assertIsInstance(res, dict)
+        self.assertTensorEqual(
+            res['confusion_matrix/result'],
+            torch.tensor([
+                [2, 0, 1],
+                [0, 1, 1],
+                [0, 0, 1],
+            ]))
+
+    def test_calculate(self):
+        y_true = np.array([0, 0, 1, 2, 1, 0])
+        y_label = torch.tensor([0, 0, 1, 2, 2, 2])
+        y_score = [
+            [0.7, 0.0, 0.3],
+            [0.5, 0.2, 0.3],
+            [0.4, 0.5, 0.1],
+            [0.0, 0.0, 1.0],
+            [0.0, 0.0, 1.0],
+            [0.0, 0.0, 1.0],
+        ]
+
+        # Test with score
+        cm = ConfusionMatrix.calculate(y_score, y_true)
+        self.assertIsInstance(cm, torch.Tensor)
+        self.assertTensorEqual(
+            cm, torch.tensor([
+                [2, 0, 1],
+                [0, 1, 1],
+                [0, 0, 1],
+            ]))
+
+        # Test with label
+        with self.assertRaisesRegex(AssertionError,
+                                    'Please specify the `num_classes`'):
+            ConfusionMatrix.calculate(y_label, y_true)
+
+        cm = ConfusionMatrix.calculate(y_label, y_true, num_classes=3)
+        self.assertIsInstance(cm, torch.Tensor)
+        self.assertTensorEqual(
+            cm, torch.tensor([
+                [2, 0, 1],
+                [0, 1, 1],
+                [0, 0, 1],
+            ]))
+
+        # Test with invalid inputs
+        with self.assertRaisesRegex(TypeError, " is not"):
+            ConfusionMatrix.calculate(y_label, 'hi')
+
+    def test_plot(self):
+        import matplotlib.pyplot as plt
+
+        cm = torch.tensor([[2, 0, 1], [0, 1, 1], [0, 0, 1]])
+        fig = ConfusionMatrix.plot(cm, include_values=True, show=False)
+
+        self.assertIsInstance(fig, plt.Figure)
+
+    def assertTensorEqual(self,
+                          tensor: torch.Tensor,
+                          value: float,
+                          msg=None,
+                          **kwarg):
+        tensor = tensor.to(torch.float32)
+        value = torch.tensor(value).float()
+        try:
+            torch.testing.assert_allclose(tensor, value, **kwarg)
+        except AssertionError as e:
+            self.fail(self._formatMessage(msg, str(e)))
diff --git a/tools/analysis_tools/confusion_matrix.py b/tools/analysis_tools/confusion_matrix.py
new file mode 100644
index 0000000000..ff560af931
--- /dev/null
+++ b/tools/analysis_tools/confusion_matrix.py
@@ -0,0 +1,129 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import tempfile
+
+import torch
+from mmengine import dump, list_from_file, load
+from mmengine.config import Config, DictAction
+from mmengine.evaluator import Evaluator
+from mmengine.runner import Runner
+
+from mmaction.evaluation import ConfusionMatrix
+from mmaction.registry import DATASETS
+from mmaction.utils import register_all_modules
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Eval a checkpoint and draw the confusion matrix.')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument(
+        'ckpt_or_result',
+        type=str,
+        help='The checkpoint file (.pth) or '
+        'dumped predictions pickle file (.pkl).')
+    parser.add_argument('--out', help='the file to save the confusion matrix.')
+    parser.add_argument(
+        '--show',
+        action='store_true',
+        help='whether to display the result with matplotlib if supported.')
+    parser.add_argument(
+        '--show-path', type=str, help='Path to save the visualization image.')
+    parser.add_argument(
+        '--include-values',
+        action='store_true',
+        help='To draw the values in the figure.')
+    parser.add_argument('--label-file', default=None, help='Labelmap file')
+    parser.add_argument(
+        '--target-classes',
+        type=int,
+        nargs='+',
+        default=[],
+        help='Selected classes to evaluate; the remaining classes are ignored.')
+    parser.add_argument(
+        '--cmap',
+        type=str,
+        default='viridis',
+        help='The color map to use. Defaults to "viridis".')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    # register all modules in mmaction into the registries
+    # do not init the default scope here; it will be initialized in the runner
+    register_all_modules(init_default_scope=False)
+
+    # load config
+    cfg = Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+
+    if args.ckpt_or_result.endswith('.pth'):
+        # Set confusion matrix as the metric.
+        cfg.test_evaluator = dict(type='ConfusionMatrix')
+
+        cfg.load_from = str(args.ckpt_or_result)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            cfg.work_dir = tmpdir
+            runner = Runner.from_cfg(cfg)
+            classes = runner.test_loop.dataloader.dataset.metainfo.get(
+                'classes')
+            cm = runner.test()['confusion_matrix/result']
+    else:
+        predictions = load(args.ckpt_or_result)
+        evaluator = Evaluator(ConfusionMatrix())
+        metrics = evaluator.offline_evaluate(predictions, None)
+        cm = metrics['confusion_matrix/result']
+        try:
+            # Try to build the dataset.
+            dataset = DATASETS.build({
+                **cfg.test_dataloader.dataset, 'pipeline': []
+            })
+            classes = dataset.metainfo.get('classes')
+        except Exception:
+            classes = None
+
+    if args.label_file is not None:
+        classes = list_from_file(args.label_file)
+    if classes is None:
+        num_classes = cm.shape[0]
+        classes = list(range(num_classes))
+
+    if args.target_classes:
+        assert len(args.target_classes) > 1, \
+            'Please select more than one class.'
+        target_idx = torch.tensor(args.target_classes)
+        cm = cm[target_idx][:, target_idx]
+        classes = [classes[idx] for idx in target_idx]
+
+    if args.out is not None:
+        dump(cm, args.out)
+
+    if args.show or args.show_path is not None:
+        fig = ConfusionMatrix.plot(
+            cm,
+            show=args.show,
+            classes=classes,
+            include_values=args.include_values,
+            cmap=args.cmap)
+        if args.show_path is not None:
+            fig.savefig(args.show_path)
+            print(f'The confusion matrix is saved at {args.show_path}.')
+
+
+if __name__ == '__main__':
+    main()
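For completeness, a minimal sketch of what the dumped-prediction branch of the script above does when used programmatically; 'results.pkl' and 'confusion_matrix.png' are placeholder file names, not part of this change:

    from mmengine import load
    from mmengine.evaluator import Evaluator

    from mmaction.evaluation import ConfusionMatrix

    # Predictions previously dumped to a pickle file by a test run.
    predictions = load('results.pkl')
    evaluator = Evaluator(ConfusionMatrix())
    metrics = evaluator.offline_evaluate(predictions, None)

    fig = ConfusionMatrix.plot(
        metrics['confusion_matrix/result'], include_values=True, show=False)
    fig.savefig('confusion_matrix.png')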