
Commit 90dd2cc

thibaultdvx, pre-commit-ci[bot], KumoLiu, and ericspod authored and committed
Add Average Precision to metrics (Project-MONAI#8089)
Fixes Project-MONAI#8085.

### Description

Average Precision is very similar to ROCAUC, so I was very much inspired by the ROCAUC implementation. More precisely, I created:

- `AveragePrecisionMetric` and `compute_average_precision` in `monai.metrics`,
- a handler called `AveragePrecision` in `monai.handlers`,
- three unittest modules: `test_compute_average_precision.py`, `test_handler_average_precision.py` and `test_handler_average_precision_dist.py`.

I also modified the docs to mention Average Precision.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: thibaultdvx <thibault.devarax@icm-institute.org>
Signed-off-by: Thibault de Varax <154365476+thibaultdvx@users.noreply.github.com>
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Signed-off-by: Can-Zhao <volcanofly@gmail.com>
1 parent 17440c8 commit 90dd2cc

10 files changed (+499, -1 lines)
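
Not part of the commit itself, but as a quick orientation: a minimal sketch of how the two new `monai.metrics` entry points are meant to be called, mirroring the existing ROCAUC usage pattern (the tensors below are made-up illustrative values):

```python
import torch

from monai.metrics import AveragePrecisionMetric, compute_average_precision

# functional form: 1-D confidence scores against 0/1 targets (illustrative values)
ap = compute_average_precision(y_pred=torch.tensor([0.1, 0.9, 0.4]), y=torch.tensor([0.0, 1.0, 1.0]))

# class form: accumulate (y_pred, y) per batch, then aggregate at the end of the epoch
metric = AveragePrecisionMetric(average="macro")
metric(y_pred=torch.tensor([[0.3, 0.7], [0.8, 0.2]]), y=torch.tensor([[0.0, 1.0], [1.0, 0.0]]))
result = metric.aggregate()
metric.reset()
```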

docs/source/handlers.rst

+6

@@ -53,6 +53,12 @@ ROC AUC metrics handler
     :members:


+Average Precision metric handler
+--------------------------------
+.. autoclass:: AveragePrecision
+    :members:
+
+
 Confusion matrix metrics handler
 --------------------------------
 .. autoclass:: ConfusionMatrix

docs/source/metrics.rst

+7

@@ -80,6 +80,13 @@ Metrics
 .. autoclass:: ROCAUCMetric
     :members:

+`Average Precision`
+-------------------
+.. autofunction:: compute_average_precision
+
+.. autoclass:: AveragePrecisionMetric
+    :members:
+
 `Confusion matrix`
 ------------------
 .. autofunction:: get_confusion_matrix

monai/handlers/__init__.py

+1

@@ -11,6 +11,7 @@

 from __future__ import annotations

+from .average_precision import AveragePrecision
 from .checkpoint_loader import CheckpointLoader
 from .checkpoint_saver import CheckpointSaver
 from .classification_saver import ClassificationSaver

monai/handlers/average_precision.py

+53 (new file)

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import AveragePrecisionMetric
from monai.utils import Average


class AveragePrecision(IgniteMetricHandler):
    """
    Computes Average Precision (AP), accumulating predictions and the ground-truth
    during an epoch and applying `compute_average_precision`.

    Args:
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification. Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
              This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
              weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
              indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

        output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` and then
            construct the `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
            lists of `channel-first` Tensors. The form `(y_pred, y)` is required by `update()`.
            `engine.state` and `output_transform` inherit from the ignite concept:
            https://pytorch.org/ignite/concepts.html#state; an explanation and usage example are in the tutorial:
            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.

    Note:
        Average Precision expects `y` to be comprised of 0's and 1's.
        `y_pred` must be either probability estimates or confidence values.

    """

    def __init__(self, average: Average | str = Average.MACRO, output_transform: Callable = lambda x: x) -> None:
        metric_fn = AveragePrecisionMetric(average=Average(average))
        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False)
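
For context, a hedged sketch of how this handler would typically be wired into an ignite engine; the engine, step function, output keys, and metric name below are assumptions for illustration and are not part of this commit:

```python
from ignite.engine import Engine

from monai.handlers import AveragePrecision, from_engine


def _eval_step(engine, batch):
    # a toy step returning an output dict with assumed keys "pred" and "label"
    return {"pred": batch["pred"], "label": batch["label"]}


evaluator = Engine(_eval_step)

# accumulate (y_pred, y) every iteration and compute AP when the epoch completes
AveragePrecision(output_transform=from_engine(["pred", "label"])).attach(evaluator, name="val_ap")
```

After `evaluator.run(...)`, the score would then appear under `evaluator.state.metrics["val_ap"]`, following the usual ignite metric behaviour.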

monai/metrics/__init__.py

+1

@@ -12,6 +12,7 @@
 from __future__ import annotations

 from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
+from .average_precision import AveragePrecisionMetric, compute_average_precision
 from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
 from .cumulative_average import CumulativeAverage
 from .f_beta_score import FBetaScore

monai/metrics/average_precision.py

+187 (new file)

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, cast

import numpy as np

if TYPE_CHECKING:
    import numpy.typing as npt

import torch

from monai.utils import Average, look_up_option

from .metric import CumulativeIterationMetric


class AveragePrecisionMetric(CumulativeIterationMetric):
    """
    Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
    imbalanced. It can take values between 0.0 and 1.0, 1.0 being the best possible score.
    It summarizes a Precision-Recall curve as the weighted mean of precisions achieved at each
    threshold, with the increase in recall from the previous threshold used as the weight:

    .. math::
        :label: ap

        \\text{AP} = \\sum_n (R_n - R_{n-1}) P_n

    where :math:`P_n` and :math:`R_n` are the precision and recall at the :math:`n^{th}` threshold.

    Referring to: `sklearn.metrics.average_precision_score
    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.

    The input `y_pred` and `y` can be a list of `channel-first` Tensors or a `batch-first` Tensor.

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
              This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
              weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
              indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

    """

    def __init__(self, average: Average | str = Average.MACRO) -> None:
        super().__init__()
        self.average = average

    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:  # type: ignore[override]
        return y_pred, y

    def aggregate(self, average: Average | str | None = None) -> np.ndarray | float | npt.ArrayLike:
        """
        Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration.
        This function reads the buffers and computes the Average Precision.

        Args:
            average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
                Type of averaging performed if not binary classification. Defaults to `self.average`.

        """
        y_pred, y = self.get_buffer()
        # compute final value and do metric reduction
        if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
            raise ValueError("y_pred and y must be PyTorch Tensor.")

        return compute_average_precision(y_pred=y_pred, y=y, average=average or self.average)


def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float:
    if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)):
        raise AssertionError("y and y_pred must be 1 dimension data with same length.")
    y_unique = y.unique()
    if len(y_unique) == 1:
        warnings.warn(f"y values can not be all {y_unique.item()}, skip AP computation and return `Nan`.")
        return float("nan")
    if not y_unique.equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)):
        warnings.warn(f"y values must be 0 or 1, but in {y_unique.tolist()}, skip AP computation and return `Nan`.")
        return float("nan")

    n = len(y)
    indices = y_pred.argsort(descending=True)
    y = y[indices].cpu().numpy()  # type: ignore[assignment]
    y_pred = y_pred[indices].cpu().numpy()  # type: ignore[assignment]
    npos = ap = tmp_pos = 0.0

    # walk predictions from highest to lowest score; positives with tied scores
    # are grouped so that they contribute at the same threshold
    for i in range(n):
        y_i = cast(float, y[i])
        if i + 1 < n and y_pred[i] == y_pred[i + 1]:
            tmp_pos += y_i
        else:
            tmp_pos += y_i
            npos += tmp_pos
            ap += tmp_pos * npos / (i + 1)
            tmp_pos = 0

    # normalize by the total number of positives (each positive carries recall weight 1/npos)
    return ap / npos


def compute_average_precision(
    y_pred: torch.Tensor, y: torch.Tensor, average: Average | str = Average.MACRO
) -> np.ndarray | float | npt.ArrayLike:
    """Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
    imbalanced. It summarizes a Precision-Recall curve according to equation :eq:`ap`.
    Referring to: `sklearn.metrics.average_precision_score
    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.

    Args:
        y_pred: input data to compute, typical classification model output.
            The first dim must be batch; if multi-class, it must be in One-Hot format.
            For example: shape `[16]` or `[16, 1]` for binary data, shape `[16, 2]` for 2-class data.
        y: ground truth to compute AP metric, the first dim must be batch.
            If multi-class, it must be in One-Hot format.
            For example: shape `[16]` or `[16, 1]` for binary data, shape `[16, 2]` for 2-class data.
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
              This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
              weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
              indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

    Raises:
        ValueError: When ``y_pred`` dimension is not one of [1, 2].
        ValueError: When ``y`` dimension is not one of [1, 2].
        ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].

    Note:
        Average Precision expects `y` to be comprised of 0's and 1's.
        `y_pred` must be either probability estimates or confidence values.

    """
    y_pred_ndim = y_pred.ndimension()
    y_ndim = y.ndimension()
    if y_pred_ndim not in (1, 2):
        raise ValueError(
            f"Predictions should be of shape (batch_size, num_classes) or (batch_size, ), got {y_pred.shape}."
        )
    if y_ndim not in (1, 2):
        raise ValueError(f"Targets should be of shape (batch_size, num_classes) or (batch_size, ), got {y.shape}.")
    if y_pred_ndim == 2 and y_pred.shape[1] == 1:
        y_pred = y_pred.squeeze(dim=-1)
        y_pred_ndim = 1
    if y_ndim == 2 and y.shape[1] == 1:
        y = y.squeeze(dim=-1)

    if y_pred_ndim == 1:
        return _calculate(y_pred, y)

    if y.shape != y_pred.shape:
        raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.")

    average = look_up_option(average, Average)
    if average == Average.MICRO:
        return _calculate(y_pred.flatten(), y.flatten())
    y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
    ap_values = [_calculate(y_pred_, y_) for y_pred_, y_ in zip(y_pred, y)]
    if average == Average.NONE:
        return ap_values
    if average == Average.MACRO:
        return np.mean(ap_values)
    if average == Average.WEIGHTED:
        weights = [sum(y_) for y_ in y]
        return np.average(ap_values, weights=weights)  # type: ignore[no-any-return]
    raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].')
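
To make the :eq:`ap` summation concrete, a small hand-traceable example (illustrative values only; the expected numbers follow from the ranking logic in `_calculate` above and are approximate):

```python
import torch

from monai.metrics import compute_average_precision

# binary case: sorting predictions descending gives labels [1, 0, 1, 0];
# positives are hit at ranks 1 and 3, so AP = 0.5 * (1/1) + 0.5 * (2/3) ~= 0.8333
y_pred = torch.tensor([0.1, 0.9, 0.5, 0.4])
y = torch.tensor([0.0, 1.0, 0.0, 1.0])
print(compute_average_precision(y_pred, y))  # ~0.8333

# 2-class one-hot case: each column is scored separately, then combined per `average`
y_pred_2c = torch.tensor([[0.9, 0.1], [0.1, 0.9], [0.5, 0.5], [0.6, 0.4]])
y_2c = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
print(compute_average_precision(y_pred_2c, y_2c, average="macro"))  # unweighted mean over classes
print(compute_average_precision(y_pred_2c, y_2c, average="none"))   # one score per class
```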

monai/utils/enums.py

+2 -1

@@ -213,7 +213,8 @@ class GridSamplePadMode(StrEnum):

 class Average(StrEnum):
     """
-    See also: :py:class:`monai.metrics.rocauc.compute_roc_auc`
+    See also: :py:class:`monai.metrics.rocauc.compute_roc_auc` or
+    :py:class:`monai.metrics.average_precision.compute_average_precision`
     """

     MACRO = "macro"
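
Since `Average` is a `StrEnum`, the new metric accepts either the enum member or its plain string; a brief sketch (illustrative only):

```python
from monai.metrics import AveragePrecisionMetric
from monai.utils import Average

# equivalent: strings are resolved against the enum (via `look_up_option`) when aggregating
m1 = AveragePrecisionMetric(average=Average.WEIGHTED)
m2 = AveragePrecisionMetric(average="weighted")
```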
@@ -0,0 +1,79 @@

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
import torch.distributed as dist

from monai.handlers import AveragePrecision
from monai.transforms import Activations, AsDiscrete
from tests.test_utils import DistCall, DistTestCase


class TestHandlerAveragePrecision(unittest.TestCase):

    def test_compute(self):
        ap_metric = AveragePrecision()
        act = Activations(softmax=True)
        to_onehot = AsDiscrete(to_onehot=2)

        y_pred = [torch.Tensor([0.1, 0.9]), torch.Tensor([0.3, 1.4])]
        y = [torch.Tensor([0]), torch.Tensor([1])]
        y_pred = [act(p) for p in y_pred]
        y = [to_onehot(y_) for y_ in y]
        ap_metric.update([y_pred, y])

        y_pred = [torch.Tensor([0.2, 0.1]), torch.Tensor([0.1, 0.5])]
        y = [torch.Tensor([0]), torch.Tensor([1])]
        y_pred = [act(p) for p in y_pred]
        y = [to_onehot(y_) for y_ in y]

        ap_metric.update([y_pred, y])

        ap = ap_metric.compute()
        np.testing.assert_allclose(0.8333333, ap)


class DistributedAveragePrecision(DistTestCase):

    @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
    def test_compute(self):
        ap_metric = AveragePrecision()
        act = Activations(softmax=True)
        to_onehot = AsDiscrete(to_onehot=2)

        device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
        if dist.get_rank() == 0:
            y_pred = [torch.tensor([0.1, 0.9], device=device), torch.tensor([0.3, 1.4], device=device)]
            y = [torch.tensor([0], device=device), torch.tensor([1], device=device)]

        if dist.get_rank() == 1:
            y_pred = [
                torch.tensor([0.2, 0.1], device=device),
                torch.tensor([0.1, 0.5], device=device),
                torch.tensor([0.3, 0.4], device=device),
            ]
            y = [torch.tensor([0], device=device), torch.tensor([1], device=device), torch.tensor([1], device=device)]

        y_pred = [act(p) for p in y_pred]
        y = [to_onehot(y_) for y_ in y]
        ap_metric.update([y_pred, y])

        result = ap_metric.compute()
        np.testing.assert_allclose(0.7778, result, rtol=1e-4)


if __name__ == "__main__":
    unittest.main()
