[Enhance] Improve efficiency of precision, recall, f1_score and support. (#595)

* [Enhance] Improve efficiency of precision, recall, f1_score and support.

* Fix bugs

* Use np.maximum since torch doesn't have torch.maximum before torch 1.7 (see the sketch below this message)

* Fix bug
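
A minimal sketch, not part of the commit, of the compatibility point in the third bullet: torch.maximum(a, b) only exists from torch 1.7, so the metric code floors its per-class denominators with np.maximum instead, which still gives an element-wise result that mixes with tensor arithmetic on CPU tensors. The counts below are made up for illustration.

import numpy as np
import torch

# Hypothetical per-class counts: true positives and predicted positives.
class_correct = torch.tensor([3, 0, 4])
pred_positive_count = torch.tensor([4, 0, 5])

# Floor the denominator at 1 so classes that were never predicted do not
# trigger a division by zero; torch.maximum would need torch >= 1.7.
precision = class_correct / np.maximum(pred_positive_count, 1) * 100
# -> per-class precision of roughly 75%, 0% and 80% for this toy data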
mzr1996 authored Dec 13, 2021
1 parent 851b438 commit e3cf188
Showing 4 changed files with 67 additions and 45 deletions.
mmcls/core/evaluation/eval_metrics.py (38 additions, 27 deletions)
@@ -3,6 +3,7 @@

import numpy as np
import torch
from torch.nn.functional import one_hot


def calculate_confusion_matrix(pred, target):
@@ -27,16 +28,17 @@ def calculate_confusion_matrix(pred, target):
(f'pred and target should be torch.Tensor or np.ndarray, '
f'but got {type(pred)} and {type(target)}.')

# Modified from PyTorch-Ignite
num_classes = pred.size(1)
_, pred_label = pred.topk(1, dim=1)
pred_label = pred_label.view(-1)
target_label = target.view(-1)
pred_label = torch.argmax(pred, dim=1).flatten()
target_label = target.flatten()
assert len(pred_label) == len(target_label)
confusion_matrix = torch.zeros(num_classes, num_classes)

with torch.no_grad():
for t, p in zip(target_label, pred_label):
confusion_matrix[t.long(), p.long()] += 1
return confusion_matrix
indices = num_classes * target_label + pred_label
matrix = torch.bincount(indices, minlength=num_classes**2)
matrix = matrix.reshape(num_classes, num_classes)
return matrix
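
For reference, a self-contained sketch (separate from the diff, with an illustrative helper name) of the bincount trick the rewritten calculate_confusion_matrix uses: each (target, prediction) pair is encoded as one flat index, all indices are counted in a single torch.bincount call, and the counts are reshaped into a num_classes x num_classes matrix, replacing the per-sample Python loop.

import torch

def confusion_matrix_bincount(pred_label, target_label, num_classes):
    # Encode every (target, prediction) pair as a unique index in [0, num_classes**2).
    indices = num_classes * target_label + pred_label
    # Count how often each index occurs, then fold the counts back into a matrix
    # whose rows are ground-truth classes and columns are predicted classes.
    matrix = torch.bincount(indices, minlength=num_classes**2)
    return matrix.reshape(num_classes, num_classes)

pred_label = torch.tensor([0, 2, 1, 1])
target_label = torch.tensor([0, 2, 2, 1])
print(confusion_matrix_bincount(pred_label, target_label, 3))
# one sample of class 2 was mispredicted as class 1; everything else is on the diagonal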


def precision_recall_f1(pred, target, average_mode='macro', thrs=0.):
@@ -73,13 +75,15 @@ class are returned. If 'macro', calculate metrics for each class,
if average_mode not in allowed_average_mode:
raise ValueError(f'Unsupport type of averaging {average_mode}.')

if isinstance(pred, torch.Tensor):
pred = pred.numpy()
if isinstance(target, torch.Tensor):
target = target.numpy()
assert (isinstance(pred, np.ndarray) and isinstance(target, np.ndarray)),\
(f'pred and target should be torch.Tensor or np.ndarray, '
f'but got {type(pred)} and {type(target)}.')
if isinstance(pred, np.ndarray):
pred = torch.from_numpy(pred)
assert isinstance(pred, torch.Tensor), \
(f'pred should be torch.Tensor or np.ndarray, but got {type(pred)}.')
if isinstance(target, np.ndarray):
target = torch.from_numpy(target)
assert isinstance(target, torch.Tensor), \
f'target should be torch.Tensor or np.ndarray, ' \
f'but got {type(target)}.'

if isinstance(thrs, Number):
thrs = (thrs, )
@@ -90,30 +94,37 @@ class are returned. If 'macro', calculate metrics for each class,
raise TypeError(
f'thrs should be a number or tuple, but got {type(thrs)}.')

label = np.indices(pred.shape)[1]
pred_label = np.argsort(pred, axis=1)[:, -1]
pred_score = np.sort(pred, axis=1)[:, -1]
num_classes = pred.size(1)
pred_score, pred_label = torch.topk(pred, k=1)
pred_score = pred_score.flatten()
pred_label = pred_label.flatten()

gt_positive = one_hot(target.flatten(), num_classes)

precisions = []
recalls = []
f1_scores = []
for thr in thrs:
# Only prediction values larger than thr are counted as positive
_pred_label = pred_label.copy()
pred_positive = one_hot(pred_label, num_classes)
if thr is not None:
_pred_label[pred_score <= thr] = -1
pred_positive = label == _pred_label.reshape(-1, 1)
gt_positive = label == target.reshape(-1, 1)
precision = (pred_positive & gt_positive).sum(0) / np.maximum(
pred_positive.sum(0), 1) * 100
recall = (pred_positive & gt_positive).sum(0) / np.maximum(
gt_positive.sum(0), 1) * 100
f1_score = 2 * precision * recall / np.maximum(precision + recall,
1e-20)
pred_positive[pred_score <= thr] = 0
class_correct = (pred_positive & gt_positive).sum(0)
precision = class_correct / np.maximum(pred_positive.sum(0), 1.) * 100
recall = class_correct / np.maximum(gt_positive.sum(0), 1.) * 100
f1_score = 2 * precision * recall / np.maximum(
precision + recall,
torch.finfo(torch.float32).eps)
if average_mode == 'macro':
precision = float(precision.mean())
recall = float(recall.mean())
f1_score = float(f1_score.mean())
elif average_mode == 'none':
precision = precision.detach().cpu().numpy()
recall = recall.detach().cpu().numpy()
f1_score = f1_score.detach().cpu().numpy()
else:
raise ValueError(f'Unsupport type of averaging {average_mode}.')
precisions.append(precision)
recalls.append(recall)
f1_scores.append(f1_score)
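
A compact sketch, also outside the diff, of how the one_hot formulation above yields per-class precision, recall and F1 in a handful of tensor operations; thresholding is left out and the scores are invented for the example.

import numpy as np
import torch
from torch.nn.functional import one_hot

pred = torch.tensor([[0.8, 0.1, 0.1],
                     [0.2, 0.5, 0.3],
                     [0.1, 0.2, 0.7],
                     [0.3, 0.4, 0.3]])
target = torch.tensor([0, 1, 2, 2])

num_classes = pred.size(1)
pred_label = pred.argmax(dim=1)
pred_positive = one_hot(pred_label, num_classes)   # one row per sample, 1 at the predicted class
gt_positive = one_hot(target, num_classes)         # 1 at the ground-truth class

class_correct = (pred_positive & gt_positive).sum(0)   # true positives per class
precision = class_correct / np.maximum(pred_positive.sum(0), 1.) * 100
recall = class_correct / np.maximum(gt_positive.sum(0), 1.) * 100
f1_score = 2 * precision * recall / np.maximum(
    precision + recall, torch.finfo(torch.float32).eps)
# macro averages would then be precision.mean(), recall.mean() and f1_score.mean()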
mmcls/datasets/base_dataset.py (2 additions, 4 deletions)
@@ -172,13 +172,11 @@ def evaluate(self,
if isinstance(thrs, tuple):
for key, values in eval_results_.items():
eval_results.update({
f'{key}_thr_{thr:.2f}': value.item()
f'{key}_thr_{thr:.2f}': value
for thr, value in zip(thrs, values)
})
else:
eval_results.update(
{k: v.item()
for k, v in eval_results_.items()})
eval_results.update(eval_results_)

if 'support' in metrics:
support_value = support(
mmcls/models/losses/accuracy.py (16 additions, 10 deletions)
@@ -35,7 +35,7 @@ def accuracy_numpy(pred, target, topk=(1, ), thrs=0.):
# Only prediction values larger than thr are counted as correct
_correct_k = correct_k & (pred_score[:, :k] > thr)
_correct_k = np.logical_or.reduce(_correct_k, axis=1)
res_thr.append(_correct_k.sum() * 100. / num)
res_thr.append((_correct_k.sum() * 100. / num).item())
if res_single:
res.append(res_thr[0])
else:
@@ -65,7 +65,7 @@ def accuracy_torch(pred, target, topk=(1, ), thrs=0.):
# Only prediction values larger than thr are counted as correct
_correct = correct & (pred_score.t() > thr)
correct_k = _correct[:k].reshape(-1).float().sum(0, keepdim=True)
res_thr.append(correct_k.mul_(100. / num))
res_thr.append((correct_k.mul_(100. / num)).item())
if res_single:
res.append(res_thr[0])
else:
@@ -99,14 +99,20 @@ def accuracy(pred, target, topk=1, thrs=0.):
else:
return_single = False

if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor):
res = accuracy_torch(pred, target, topk, thrs)
elif isinstance(pred, np.ndarray) and isinstance(target, np.ndarray):
res = accuracy_numpy(pred, target, topk, thrs)
else:
raise TypeError(
f'pred and target should both be torch.Tensor or np.ndarray, '
f'but got {type(pred)} and {type(target)}.')
assert isinstance(pred, (torch.Tensor, np.ndarray)), \
f'The pred should be torch.Tensor or np.ndarray ' \
f'instead of {type(pred)}.'
assert isinstance(target, (torch.Tensor, np.ndarray)), \
f'The target should be torch.Tensor or np.ndarray ' \
f'instead of {type(target)}.'

# torch version is faster in most situations.
to_tensor = (lambda x: torch.from_numpy(x)
if isinstance(x, np.ndarray) else x)
pred = to_tensor(pred)
target = to_tensor(target)

res = accuracy_torch(pred, target, topk, thrs)

return res[0] if return_single else res
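
An illustrative call, not taken from the repository's tests, of what the relaxed checks above permit: pred and target may now independently be a torch.Tensor or an np.ndarray, since both are converted and routed through the (generally faster) torch implementation.

import numpy as np
import torch
from mmcls.models.losses.accuracy import accuracy

pred = torch.tensor([[0.1, 0.9],
                     [0.8, 0.2],
                     [0.3, 0.7]])
target = np.array([1, 0, 0])   # a numpy target is converted to a tensor internally

top1 = accuracy(pred, target, topk=1)   # roughly 66.67 for this toy batch (2 of 3 correct)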

tests/test_metrics/test_metrics.py (11 additions, 4 deletions)
@@ -1,9 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial

import pytest
import torch

from mmcls.core import average_performance, mAP
from mmcls.models.losses.accuracy import Accuracy
from mmcls.models.losses.accuracy import Accuracy, accuracy_numpy


def test_mAP():
@@ -77,10 +79,15 @@ def test_accuracy():
assert compute_acc(pred_array, target_array)[0] == acc_top1

compute_acc = Accuracy(topk=(1, 2))
assert compute_acc(pred_tensor, target_tensor)[0] == acc_top1
assert compute_acc(pred_tensor, target_array)[0] == acc_top1
assert compute_acc(pred_tensor, target_tensor)[1] == acc_top2
assert compute_acc(pred_array, target_array)[0] == acc_top1
assert compute_acc(pred_array, target_array)[1] == acc_top2

with pytest.raises(TypeError):
compute_acc(pred_tensor, target_array)
with pytest.raises(AssertionError):
compute_acc(pred_tensor, 'other_type')

# test accuracy_numpy
compute_acc = partial(accuracy_numpy, topk=(1, 2))
assert compute_acc(pred_array, target_array)[0] == acc_top1
assert compute_acc(pred_array, target_array)[1] == acc_top2
