Skip to content

Commit

Permalink
[Feat] support calculate confusion matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
cir7 committed Mar 2, 2023
1 parent 52b871f commit 5676b0d
Show file tree
Hide file tree
Showing 5 changed files with 527 additions and 15 deletions.
4 changes: 2 additions & 2 deletions mmaction/evaluation/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .acc_metric import AccMetric
from .acc_metric import AccMetric, ConfusionMatrix
from .anet_metric import ANetMetric
from .ava_metric import AVAMetric

__all__ = ['AccMetric', 'AVAMetric', 'ANetMetric']
__all__ = ['AccMetric', 'AVAMetric', 'ANetMetric', 'ConfusionMatrix']
211 changes: 210 additions & 1 deletion mmaction/evaluation/metrics/acc_metric.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from collections import OrderedDict
from typing import Any, Optional, Sequence, Tuple, Union
from itertools import product
from typing import Any, List, Optional, Sequence, Tuple, Union

import mmengine
import numpy as np
import torch
from mmengine.evaluator import BaseMetric

from mmaction.evaluation import (mean_average_precision, mean_class_accuracy,
mmit_mean_average_precision, top_k_accuracy)
from mmaction.registry import METRICS


def to_tensor(value):
"""Convert value to torch.Tensor."""
if isinstance(value, np.ndarray):
value = torch.from_numpy(value)
elif isinstance(value, Sequence) and not mmengine.is_str(value):
value = torch.tensor(value)
elif not isinstance(value, torch.Tensor):
raise TypeError(f'{type(value)} is not an available argument.')
return value


@METRICS.register_module()
class AccMetric(BaseMetric):
"""Accuracy evaluation metric."""
Expand Down Expand Up @@ -136,3 +150,198 @@ def label2array(num, label):
arr = np.zeros(num, dtype=np.float32)
arr[label] = 1.
return arr


@METRICS.register_module()
class ConfusionMatrix(BaseMetric):
r"""A metric to calculate confusion matrix for single-label tasks.
Args:
num_classes (int, optional): The number of classes. Defaults to None.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
Examples:
1. The basic usage.
>>> import torch
>>> from mmcls.evaluation import ConfusionMatrix
>>> y_pred = [0, 1, 1, 3]
>>> y_true = [0, 2, 1, 3]
>>> ConfusionMatrix.calculate(y_pred, y_true, num_classes=4)
tensor([[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1]])
>>> # plot the confusion matrix
>>> import matplotlib.pyplot as plt
>>> y_score = torch.rand((1000, 10))
>>> y_true = torch.randint(10, (1000, ))
>>> matrix = ConfusionMatrix.calculate(y_score, y_true)
>>> ConfusionMatrix().plot(matrix)
>>> plt.show()
2. In the config file
.. code:: python
val_evaluator = dict(type='ConfusionMatrix')
test_evaluator = dict(type='ConfusionMatrix')
""" # noqa: E501
default_prefix = 'confusion_matrix'

def __init__(self,
num_classes: Optional[int] = None,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device, prefix)

self.num_classes = num_classes

def process(self, data_batch, data_samples: Sequence[dict]) -> None:
for data_sample in data_samples:
pred_scores = data_sample.get('pred_scores')
gt_label = data_sample['gt_labels']['item']
if pred_scores is not None:
pred_label = pred_scores['item'].argmax(dim=0, keepdim=True)
self.num_classes = pred_scores['item'].size(0)
else:
pred_label = data_sample['pred_labels']['item']

self.results.append({
'pred_label': pred_label,
'gt_label': gt_label
})

def compute_metrics(self, results: list) -> dict:
pred_labels = []
gt_labels = []
for result in results:
pred_labels.append(result['pred_label'])
gt_labels.append(result['gt_label'])
confusion_matrix = ConfusionMatrix.calculate(
torch.cat(pred_labels),
torch.cat(gt_labels),
num_classes=self.num_classes)
return {'result': confusion_matrix}

@staticmethod
def calculate(pred, target, num_classes=None) -> dict:
"""Calculate the confusion matrix for single-label task.
Args:
pred (torch.Tensor | np.ndarray | Sequence): The prediction
results. It can be labels (N, ), or scores of every
class (N, C).
target (torch.Tensor | np.ndarray | Sequence): The target of
each prediction with shape (N, ).
num_classes (Optional, int): The number of classes. If the ``pred``
is label instead of scores, this argument is required.
Defaults to None.
Returns:
torch.Tensor: The confusion matrix.
"""
pred = to_tensor(pred)
target_label = to_tensor(target).int()

assert pred.size(0) == target_label.size(0), \
f"The size of pred ({pred.size(0)}) doesn't match "\
f'the target ({target_label.size(0)}).'
assert target_label.ndim == 1

if pred.ndim == 1:
assert num_classes is not None, \
'Please specify the `num_classes` if the `pred` is labels ' \
'intead of scores.'
pred_label = pred
else:
num_classes = num_classes or pred.size(1)
pred_label = torch.argmax(pred, dim=1).flatten()

with torch.no_grad():
indices = num_classes * target_label + pred_label
matrix = torch.bincount(indices, minlength=num_classes**2)
matrix = matrix.reshape(num_classes, num_classes)

return matrix

@staticmethod
def plot(confusion_matrix: torch.Tensor,
include_values: bool = False,
cmap: str = 'viridis',
classes: Optional[List[str]] = None,
colorbar: bool = True,
show: bool = True):
"""Draw a confusion matrix by matplotlib.
Modified from `Scikit-Learn
<https://github.com/scikit-learn/scikit-learn/blob/dc580a8ef/sklearn/metrics/_plot/confusion_matrix.py#L81>`_
Args:
confusion_matrix (torch.Tensor): The confusion matrix to draw.
include_values (bool): Whether to draw the values in the figure.
Defaults to False.
cmap (str): The color map to use. Defaults to use "viridis".
classes (list[str], optional): The names of categories.
Defaults to None, which means to use index number.
colorbar (bool): Whether to show the colorbar. Defaults to True.
show (bool): Whether to show the figure immediately.
Defaults to True.
""" # noqa: E501
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 10))

num_classes = confusion_matrix.size(0)

im_ = ax.imshow(confusion_matrix, interpolation='nearest', cmap=cmap)
text_ = None
cmap_min, cmap_max = im_.cmap(0), im_.cmap(1.0)

if include_values:
text_ = np.empty_like(confusion_matrix, dtype=object)

# print text with appropriate color depending on background
thresh = (confusion_matrix.max() + confusion_matrix.min()) / 2.0

for i, j in product(range(num_classes), range(num_classes)):
color = cmap_max if confusion_matrix[i,
j] < thresh else cmap_min

text_cm = format(confusion_matrix[i, j], '.2g')
text_d = format(confusion_matrix[i, j], 'd')
if len(text_d) < len(text_cm):
text_cm = text_d

text_[i, j] = ax.text(
j, i, text_cm, ha='center', va='center', color=color)

display_labels = classes or np.arange(num_classes)

if colorbar:
fig.colorbar(im_, ax=ax)
ax.set(
xticks=np.arange(num_classes),
yticks=np.arange(num_classes),
xticklabels=display_labels,
yticklabels=display_labels,
ylabel='True label',
xlabel='Predicted label',
)
ax.invert_yaxis()
ax.xaxis.tick_top()

ax.set_ylim((num_classes - 0.5, -0.5))
# Automatically rotate the x labels.
fig.autofmt_xdate(ha='center')

if show:
plt.show()
return fig
102 changes: 91 additions & 11 deletions mmaction/structures/action_data_sample.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,105 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
from numbers import Number
from typing import Sequence, Union

import numpy as np
import torch
from mmengine.structures import BaseDataElement, InstanceData, LabelData
from mmengine.utils import is_str


def format_label(value: Union[torch.Tensor, np.ndarray, Sequence,
int]) -> torch.Tensor:
"""Convert various python types to label-format tensor.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int`.
Args:
value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.
Returns:
:obj:`torch.Tensor`: The foramtted label tensor.
"""

# Handle single number
if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0:
value = int(value.item())

if isinstance(value, np.ndarray):
value = torch.from_numpy(value).to(torch.long)
elif isinstance(value, Sequence) and not is_str(value):
value = torch.tensor(value).to(torch.long)
elif isinstance(value, int):
value = torch.LongTensor([value])
elif not isinstance(value, torch.Tensor):
raise TypeError(f'Type {type(value)} is not an available label type.')
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'

return value


def format_score(value: Union[torch.Tensor, np.ndarray,
Sequence]) -> torch.Tensor:
"""Convert various python types to score-format tensor.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`.
Args:
value (torch.Tensor | numpy.ndarray | Sequence): Score values.
Returns:
:obj:`torch.Tensor`: The foramtted score tensor.
"""

if isinstance(value, np.ndarray):
value = torch.from_numpy(value).float()
elif isinstance(value, Sequence) and not is_str(value):
value = torch.tensor(value).float()
elif not isinstance(value, torch.Tensor):
raise TypeError(f'Type {type(value)} is not an available label type.')
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'

return value


class ActionDataSample(BaseDataElement):

def set_gt_labels(self, value: Union[int,
np.ndarray]) -> 'ActionDataSample':
def set_gt_label(
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
) -> 'ActionDataSample':
"""Set label of ``gt_labels``."""
if isinstance(value, int):
value = torch.LongTensor([value])
elif isinstance(value, np.ndarray):
value = torch.from_numpy(value)
else:
raise TypeError(f'Type {type(value)} is not an '
f'available label type.')
label_data = getattr(self, '_gt_label', LabelData())
label_data.item = format_label(value)
self.gt_labels = label_data
return self

self.gt_labels = LabelData(item=value)
def set_pred_label(
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
) -> 'ActionDataSample':
"""Set label of ``pred_label``."""
label_data = getattr(self, '_pred_label', LabelData())
label_data.item = format_label(value)
self.pred_labels = label_data
return self

def set_pred_score(self, value: torch.Tensor) -> 'ActionDataSample':
"""Set score of ``pred_label``."""
label_data = getattr(self, '_pred_label', LabelData())
label_data.item = format_score(value)
if hasattr(self, 'num_classes'):
assert len(label_data.item) == self.num_classes, \
f'The length of score {len(label_data.item)} should be '\
f'equal to the num_classes {self.num_classes}.'
else:
self.set_field(
name='num_classes',
value=len(label_data.item),
field_type='metainfo')
self.pred_scores = label_data
return self

@property
Expand Down
Loading

0 comments on commit 5676b0d

Please sign in to comment.