Skip to content

Commit

Permalink
IoU: remove_bg -> ignore_index
Browse files Browse the repository at this point in the history
Fixes #2736

- Rename IoU metric argument from `remove_bg` -> `ignore_index`.
- Accept an optional int class index to ignore, instead of a bool and
  instead of always assuming the background class has index 0.
- If given, ignore the class index when computing the IoU output,
  regardless of reduction method.
  • Loading branch information
abrahambotros committed Sep 3, 2020
1 parent b0057d1 commit 0d5516c
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 50 deletions.
14 changes: 7 additions & 7 deletions pytorch_lightning/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,20 +797,20 @@ class IoU(TensorMetric):

def __init__(
self,
ignore_index: Optional[int] = None,
not_present_score: float = 1.0,
num_classes: Optional[int] = None,
remove_bg: bool = False,
reduction: str = 'elementwise_mean'
):
"""
Args:
ignore_index: optional int specifying a target class to ignore. If given, this class index does not
contribute to the returned score, regardless of reduction method. Has no effect if given an int that is
not in the range [0, num_classes-1], where num_classes is either given or derived from pred and target.
By default, no index is ignored, and all classes are used.
not_present_score: score to use for a class, if no instance of that class was present in either pred or
target
num_classes: Optionally specify the number of classes
remove_bg: Flag to state whether a background class has been included
within input parameters. If true, will remove background class. If
false, return IoU over all classes.
Assumes that background is '0' class in input tensor
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
Expand All @@ -819,9 +819,9 @@ def __init__(
- sum: add elements
"""
super().__init__(name='iou')
self.ignore_index = ignore_index
self.not_present_score = not_present_score
self.num_classes = num_classes
self.remove_bg = remove_bg
self.reduction = reduction

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor,
Expand All @@ -832,8 +832,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor,
return iou(
pred=y_pred,
target=y_true,
ignore_index=self.ignore_index,
not_present_score=self.not_present_score,
num_classes=self.num_classes,
remove_bg=self.remove_bg,
reduction=self.reduction,
)
33 changes: 20 additions & 13 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,9 +963,9 @@ def dice_score(
def iou(
pred: torch.Tensor,
target: torch.Tensor,
ignore_index: Optional[int] = None,
not_present_score: float = 1.0,
num_classes: Optional[int] = None,
remove_bg: bool = False,
reduction: str = 'elementwise_mean',
) -> torch.Tensor:
"""
Expand All @@ -974,12 +974,12 @@ def iou(
Args:
pred: Tensor containing predictions
target: Tensor containing targets
ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method. Has no effect if given an int that is not in the
range [0, num_classes-1], where num_classes is either given or derived from pred and target. By default, no
index is ignored, and all classes are used.
not_present_score: score to use for a class, if no instance of that class was present in either pred or target
num_classes: Optionally specify the number of classes
remove_bg: Flag to state whether a background class has been included
within input parameters. If true, will remove background class. If
false, return IoU over all classes
Assumes that background is '0' class in input tensor
reduction: a method to reduce metric score over labels (default: takes the mean)
Available reduction methods:
Expand All @@ -1002,15 +1002,15 @@ def iou(
"""
num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)

# Determine minimum class index we will be evaluating. If using the background, then this is 0; otherwise, if
# removing background, use 1.
min_class_idx = 1 if remove_bg else 0

tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes)

scores = torch.zeros(num_classes - min_class_idx, device=pred.device, dtype=torch.float32)
scores = torch.zeros(num_classes, device=pred.device, dtype=torch.float32)

for class_idx in range(num_classes):
# Skip this class if its index is being ignored.
if class_idx == ignore_index:
continue

for class_idx in range(min_class_idx, num_classes):
tp = tps[class_idx]
fp = fps[class_idx]
fn = fns[class_idx]
Expand All @@ -1019,11 +1019,18 @@ def iou(
# If this class is not present in either the target (no support) or the pred (no true or false positives), then
# use the not_present_score for this class.
if sup + tp + fp == 0:
scores[class_idx - min_class_idx] = not_present_score
scores[class_idx] = not_present_score
continue

denom = tp + fp + fn
score = tp.to(torch.float) / denom
scores[class_idx - min_class_idx] = score
scores[class_idx] = score

# Remove the ignored class index from the scores.
if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes:
scores = torch.cat([
scores[:ignore_index],
scores[ignore_index + 1:],
])

return reduce(scores, reduction=reduction)
78 changes: 51 additions & 27 deletions tests/metrics/functional/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,23 +326,23 @@ def test_dice_score(pred, target, expected):
assert score == expected


@pytest.mark.parametrize(['half_ones', 'reduction', 'remove_bg', 'expected'], [
pytest.param(False, 'none', False, torch.Tensor([1, 1, 1])),
pytest.param(False, 'elementwise_mean', False, torch.Tensor([1])),
pytest.param(False, 'none', True, torch.Tensor([1, 1])),
pytest.param(True, 'none', False, torch.Tensor([0.5, 0.5, 0.5])),
pytest.param(True, 'elementwise_mean', False, torch.Tensor([0.5])),
pytest.param(True, 'none', True, torch.Tensor([0.5, 0.5])),
@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [
pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])),
pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])),
pytest.param(False, 'none', 0, torch.Tensor([1, 1])),
pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])),
pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])),
pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])),
])
def test_iou(half_ones, reduction, remove_bg, expected):
def test_iou(half_ones, reduction, ignore_index, expected):
pred = (torch.arange(120) % 3).view(-1, 1)
target = (torch.arange(120) % 3).view(-1, 1)
if half_ones:
pred[:60] = 1
iou_val = iou(
pred=pred,
target=target,
remove_bg=remove_bg,
ignore_index=ignore_index,
reduction=reduction,
)
assert torch.allclose(iou_val, expected, atol=1e-9)
Expand All @@ -351,46 +351,70 @@ def test_iou(half_ones, reduction, remove_bg, expected):
# TODO: When the jaccard_score of the sklearn version we use accepts `zero_division` (see
# https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our
# `not_present_score`.
@pytest.mark.parametrize(['pred', 'target', 'not_present_score', 'num_classes', 'remove_bg', 'expected'], [
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'not_present_score', 'num_classes', 'expected'], [
# Note that -1 is used as the not_present_score in almost all tests here to distinguish it from the range of valid
# scores the function can return ([0., 1.] range, inclusive).
# 2 classes, class 0 is correct everywhere, class 1 is not present.
pytest.param([0], [0], -1., 2, False, [1., -1.]),
pytest.param([0, 0], [0, 0], -1., 2, False, [1., -1.]),
pytest.param([0], [0], None, -1., 2, [1., -1.]),
pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]),
# not_present_score not applied if only class 0 is present and it's the only class.
pytest.param([0], [0], -1., 1, False, [1.]),
pytest.param([0], [0], None, -1., 1, [1.]),
# 2 classes, class 1 is correct everywhere, class 0 is not present.
pytest.param([1], [1], -1., 2, False, [-1., 1.]),
pytest.param([1, 1], [1, 1], -1., 2, False, [-1., 1.]),
# When background removed, class 0 does not get a score (not even the not_present_score).
pytest.param([1], [1], -1., 2, True, [1.0]),
pytest.param([1], [1], None, -1., 2, [-1., 1.]),
pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]),
# When 0 index ignored, class 0 does not get a score (not even the not_present_score).
pytest.param([1], [1], 0, -1., 2, [1.0]),
# 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get not_present_score.
pytest.param([0, 2], [0, 2], -1., 3, False, [1., -1., 1.]),
pytest.param([2, 0], [2, 0], -1., 3, False, [1., -1., 1.]),
pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]),
pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]),
# 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get not_present_score.
pytest.param([0, 1], [0, 1], -1., 3, False, [1., 1., -1.]),
pytest.param([1, 0], [1, 0], -1., 3, False, [1., 1., -1.]),
pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]),
pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get not_present_score), class
# 2 is not present.
pytest.param([0, 1], [0, 0], -1., 3, False, [0.5, 0., -1.]),
pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get not_present_score), class
# 2 is not present.
pytest.param([0, 0], [0, 1], -1., 3, False, [0.5, 0., -1.]),
pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]),
# Sanity checks with not_present_score of 1.0.
pytest.param([0, 2], [0, 2], 1.0, 3, False, [1., 1., 1.]),
pytest.param([0, 2], [0, 2], 1.0, 3, True, [1., 1.]),
pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]),
pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]),
])
def test_iou_not_present_score(pred, target, not_present_score, num_classes, remove_bg, expected):
def test_iou_not_present_score(pred, target, ignore_index, not_present_score, num_classes, expected):
iou_val = iou(
pred=torch.tensor(pred),
target=torch.tensor(target),
ignore_index=ignore_index,
not_present_score=not_present_score,
num_classes=num_classes,
remove_bg=remove_bg,
reduction='none',
)
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))


@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [
# Ignoring an index outside of [0, num_classes-1] should have no effect.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]),
# Ignoring a valid index drops only that index from the result.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]),
# When reducing to mean or sum, the ignored index does not contribute to the output.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]),
])
def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
iou_val = iou(
pred=torch.tensor(pred),
target=torch.tensor(target),
ignore_index=ignore_index,
num_classes=num_classes,
reduction=reduction,
)
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))


# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
6 changes: 3 additions & 3 deletions tests/metrics/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,9 @@ def test_dice_coefficient(include_background):
assert isinstance(dice, torch.Tensor)


@pytest.mark.parametrize('remove_bg', [True, False])
def test_iou(remove_bg):
iou = IoU(remove_bg=remove_bg)
@pytest.mark.parametrize('ignore_index', [0, 1, None])
def test_iou(ignore_index):
iou = IoU(ignore_index=ignore_index)
assert iou.name == 'iou'

score = iou(torch.randint(0, 1, (10, 25, 25)),
Expand Down

0 comments on commit 0d5516c

Please sign in to comment.