From 9f15b01afdc1dc381021ae69fdadd658be598886 Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Wed, 9 Nov 2022 00:18:30 +0800 Subject: [PATCH] 5344 implement pq metric (#5377) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #5344 . ### Description This PR implements the metric of Panoptic Quality. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Yiheng Wang Co-authored-by: Behrooz Hashemian <3968947+drbeh@users.noreply.github.com> Co-authored-by: Dženan Zukić Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com> --- docs/source/handlers.rst | 6 + docs/source/metrics.rst | 7 + monai/handlers/__init__.py | 1 + monai/handlers/panoptic_quality.py | 67 ++++++ monai/metrics/__init__.py | 1 + monai/metrics/panoptic_quality.py | 292 +++++++++++++++++++++++++ monai/metrics/utils.py | 37 ++++ tests/min_tests.py | 2 + tests/test_compute_panoptic_quality.py | 111 ++++++++++ tests/test_handler_panoptic_quality.py | 86 ++++++++ tests/test_hovernet.py | 4 +- 11 files changed, 612 insertions(+), 2 deletions(-) create mode 100644 monai/handlers/panoptic_quality.py create mode 100644 monai/metrics/panoptic_quality.py create mode 100644 tests/test_compute_panoptic_quality.py create mode 100644 tests/test_handler_panoptic_quality.py diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 06a1255d15..5b408cfa71 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -71,6 +71,12 @@ Surface distance metrics handler :members: +Panoptic Quality metrics handler +-------------------------------- +.. autoclass:: PanopticQuality + :members: + + Mean squared error metrics handler ---------------------------------- .. autoclass:: MeanSquaredError diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 9ba5fa0607..d8da890276 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -104,6 +104,13 @@ Metrics .. autoclass:: SurfaceDiceMetric :members: +`PanopticQualityMetric` +----------------------- +.. autofunction:: compute_panoptic_quality + +.. autoclass:: PanopticQualityMetric + :members: + `Mean squared error` -------------------- .. autoclass:: MSEMetric diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 81e1d8eb3f..9880e39817 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -26,6 +26,7 @@ from .metrics_saver import MetricsSaver from .mlflow_handler import MLFlowHandler from .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler +from .panoptic_quality import PanopticQuality from .parameter_scheduler import ParamSchedulerHandler from .postprocessing import PostProcessing from .probability_maps import ProbMapProducer diff --git a/monai/handlers/panoptic_quality.py b/monai/handlers/panoptic_quality.py new file mode 100644 index 0000000000..d9e5beec59 --- /dev/null +++ b/monai/handlers/panoptic_quality.py @@ -0,0 +1,67 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Union + +from monai.handlers.ignite_metric import IgniteMetric +from monai.metrics import PanopticQualityMetric +from monai.utils import MetricReduction + + +class PanopticQuality(IgniteMetric): + """ + Computes Panoptic quality from full size Tensor and collects average over batch, class-channels, iterations. + """ + + def __init__( + self, + num_classes: int, + metric_name: str = "pq", + reduction: Union[MetricReduction, str] = MetricReduction.MEAN_BATCH, + match_iou_threshold: float = 0.5, + smooth_numerator: float = 1e-6, + output_transform: Callable = lambda x: x, + save_details: bool = True, + ) -> None: + """ + + Args: + num_classes: number of classes. The number should not count the background. + metric_name: output metric. The value can be "pq", "sq" or "rq". + reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction. + match_iou_threshold: IOU threshould to determine the pairing between `y_pred` and `y`. Usually, + it should >= 0.5, the pairing between instances of `y_pred` and `y` are identical. + If set `match_iou_threshold` < 0.5, this function uses Munkres assignment to find the + maximal amout of unique pairing. + smooth_numerator: a small constant added to the numerator to avoid zero. + output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then + construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or + lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + save_details: whether to save metric computation details per image, for example: panoptic quality of + every image. + default to True, will save to `engine.state.metric_details` dict with the metric name as key. + + See also: + :py:meth:`monai.metrics.panoptic_quality.compute_panoptic_quality` + """ + metric_fn = PanopticQualityMetric( + num_classes=num_classes, + metric_name=metric_name, + reduction=reduction, + match_iou_threshold=match_iou_threshold, + smooth_numerator=smooth_numerator, + ) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index 4e0acbe603..ff5eb4881a 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -18,6 +18,7 @@ from .meandice import DiceMetric, compute_dice, compute_meandice from .meaniou import MeanIoU, compute_iou, compute_meaniou from .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric +from .panoptic_quality import PanopticQualityMetric, compute_panoptic_quality from .regression import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric, SSIMMetric from .rocauc import ROCAUCMetric, compute_roc_auc from .surface_dice import SurfaceDiceMetric, compute_surface_dice diff --git a/monai/metrics/panoptic_quality.py b/monai/metrics/panoptic_quality.py new file mode 100644 index 0000000000..4bf87188d5 --- /dev/null +++ b/monai/metrics/panoptic_quality.py @@ -0,0 +1,292 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Sequence, Union + +import torch + +from monai.metrics.metric import CumulativeIterationMetric +from monai.metrics.utils import do_metric_reduction, remap_instance_id +from monai.utils import MetricReduction, ensure_tuple, optional_import + +linear_sum_assignment, _ = optional_import("scipy.optimize", name="linear_sum_assignment") + +__all__ = ["PanopticQualityMetric", "compute_panoptic_quality"] + + +class PanopticQualityMetric(CumulativeIterationMetric): + """ + Compute Panoptic Quality between two instance segmentation masks. If specifying `metric_name` to "SQ" or "RQ", + Segmentation Quality (SQ) or Recognition Quality (RQ) will be returned instead. + + Panoptic Quality is a metric used in panoptic segmentation tasks. This task unifies the typically distinct tasks + of semantic segmentation (assign a class label to each pixel) and + instance segmentation (detect and segment each object instance). Compared with semantic segmentation, panoptic + segmentation distinguish different instances that belong to same class. + Compared with instance segmentation, panoptic segmentation does not allow overlap and only one semantic label and + one instance id can be assigned to each pixel. + Please refer to the following paper for more details: + https://openaccess.thecvf.com/content_CVPR_2019/papers/Kirillov_Panoptic_Segmentation_CVPR_2019_paper.pdf + + This class also refers to the following implementation: + https://github.com/TissueImageAnalytics/CoNIC + + Args: + num_classes: number of classes. The number should not count the background. + metric_name: output metric. The value can be "pq", "sq" or "rq". + Except for input only one metric, multiple metrics are also supported via input a sequence of metric names + such as ("pq", "sq", "rq"). If input a sequence, a list of results with the same order + as the input names will be returned. + reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction. + match_iou_threshold: IOU threshould to determine the pairing between `y_pred` and `y`. Usually, + it should >= 0.5, the pairing between instances of `y_pred` and `y` are identical. + If set `match_iou_threshold` < 0.5, this function uses Munkres assignment to find the + maximal amout of unique pairing. + smooth_numerator: a small constant added to the numerator to avoid zero. + + """ + + def __init__( + self, + num_classes: int, + metric_name: Union[Sequence[str], str] = "pq", + reduction: Union[MetricReduction, str] = MetricReduction.MEAN_BATCH, + match_iou_threshold: float = 0.5, + smooth_numerator: float = 1e-6, + ) -> None: + super().__init__() + self.num_classes = num_classes + self.reduction = reduction + self.match_iou_threshold = match_iou_threshold + self.smooth_numerator = smooth_numerator + self.metric_name = ensure_tuple(metric_name) + + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore + """ + Args: + y_pred: Predictions. It must be in the form of B2HW and have integer type. The first channel and the + second channel represent the instance predictions and classification predictions respectively. + y: ground truth. It must have the same shape as `y_pred` and have integer type. The first channel and the + second channel represent the instance labels and classification labels respectively. + Values in the second channel of `y_pred` and `y` should be in the range of 0 to `self.num_classes`, + where 0 represents the background. + + Raises: + ValueError: when `y_pred` and `y` have different shapes. + ValueError: when `y_pred` and `y` have != 2 channels. + ValueError: when `y_pred` and `y` have != 4 dimensions. + + """ + if y_pred.shape != y.shape: + raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") + + if y_pred.shape[1] != 2: + raise ValueError( + f"for panoptic quality calculation, only 2 channels input is supported, got {y_pred.shape[1]}." + ) + + dims = y_pred.ndimension() + if dims != 4: + raise ValueError(f"y_pred should have 4 dimensions (batch, 2, h, w), got {dims}.") + + batch_size = y_pred.shape[0] + + outputs = torch.zeros([batch_size, self.num_classes, 4], device=y_pred.device) + + for b in range(batch_size): + true_instance, pred_instance = y[b, 0], y_pred[b, 0] + true_class, pred_class = y[b, 1], y_pred[b, 1] + for c in range(self.num_classes): + pred_instance_c = (pred_class == c + 1) * pred_instance + true_instance_c = (true_class == c + 1) * true_instance + + outputs[b, c] = compute_panoptic_quality( + pred=pred_instance_c, + gt=true_instance_c, + remap=True, + match_iou_threshold=self.match_iou_threshold, + output_confusion_matrix=True, + ) + + return outputs + + def aggregate(self, reduction: Union[MetricReduction, str, None] = None): + """ + Execute reduction logic for the output of `compute_panoptic_quality`. + + Args: + reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction. + + """ + data = self.get_buffer() + if not isinstance(data, torch.Tensor): + raise ValueError("the data to aggregate must be PyTorch Tensor.") + + # do metric reduction + f, _ = do_metric_reduction(data, reduction or self.reduction) + tp, fp, fn, iou_sum = f[..., 0], f[..., 1], f[..., 2], f[..., 3] + results = [] + for metric_name in self.metric_name: + metric_name = _check_panoptic_metric_name(metric_name) + if metric_name == "rq": + results.append(tp / (tp + 0.5 * fp + 0.5 * fn + self.smooth_numerator)) + elif metric_name == "sq": + results.append(iou_sum / (tp + self.smooth_numerator)) + else: + results.append(iou_sum / (tp + 0.5 * fp + 0.5 * fn + self.smooth_numerator)) + + return results[0] if len(results) == 1 else results + + +def compute_panoptic_quality( + pred: torch.Tensor, + gt: torch.Tensor, + metric_name: str = "pq", + remap: bool = True, + match_iou_threshold: float = 0.5, + smooth_numerator: float = 1e-6, + output_confusion_matrix: bool = False, +): + """Computes Panoptic Quality (PQ). If specifying `metric_name` to "SQ" or "RQ", + Segmentation Quality (SQ) or Recognition Quality (RQ) will be returned instead. + + In addition, if `output_confusion_matrix` is True, the function will return a tensor with shape 4, which + represents the true positive, false positive, false negative and the sum of iou. These four values are used to + calculate PQ, and returning them directly enables further calculation over all images. + + Args: + pred: input data to compute, it must be in the form of HW and have integer type. + gt: ground truth. It must have the same shape as `pred` and have integer type. + metric_name: output metric. The value can be "pq", "sq" or "rq". + remap: whether to remap `pred` and `gt` to ensure contiguous ordering of instance id. + match_iou_threshold: IOU threshould to determine the pairing between `pred` and `gt`. Usually, + it should >= 0.5, the pairing between instances of `pred` and `gt` are identical. + If set `match_iou_threshold` < 0.5, this function uses Munkres assignment to find the + maximal amout of unique pairing. + smooth_numerator: a small constant added to the numerator to avoid zero. + + Raises: + ValueError: when `pred` and `gt` have different shapes. + ValueError: when `match_iou_threshold` <= 0.0 or > 1.0. + + """ + + if gt.shape != pred.shape: + raise ValueError(f"pred and gt should have same shapes, got {pred.shape} and {gt.shape}.") + if match_iou_threshold <= 0.0 or match_iou_threshold > 1.0: + raise ValueError(f"'match_iou_threshold' should be within (0, 1], got: {match_iou_threshold}.") + + gt = gt.int() + pred = pred.int() + + if remap is True: + gt = remap_instance_id(gt) + pred = remap_instance_id(pred) + + pairwise_iou, true_id_list, pred_id_list = _get_pairwise_iou(pred, gt, device=pred.device) + paired_iou, paired_true, paired_pred = _get_paired_iou( + pairwise_iou, match_iou_threshold, device=pairwise_iou.device + ) + + unpaired_true = [idx for idx in true_id_list[1:] if idx not in paired_true] + unpaired_pred = [idx for idx in pred_id_list[1:] if idx not in paired_pred] + + tp, fp, fn = len(paired_true), len(unpaired_pred), len(unpaired_true) + iou_sum = paired_iou.sum() + + if output_confusion_matrix: + return torch.as_tensor([tp, fp, fn, iou_sum], device=pred.device) + + metric_name = _check_panoptic_metric_name(metric_name) + if metric_name == "rq": + return torch.as_tensor(tp / (tp + 0.5 * fp + 0.5 * fn + smooth_numerator), device=pred.device) + if metric_name == "sq": + return torch.as_tensor(iou_sum / (tp + smooth_numerator), device=pred.device) + return torch.as_tensor(iou_sum / (tp + 0.5 * fp + 0.5 * fn + smooth_numerator), device=pred.device) + + +def _get_id_list(gt: torch.Tensor): + id_list = list(gt.unique()) + # ensure id 0 is included + if 0 not in id_list: + id_list.insert(0, torch.tensor(0).int()) + + return id_list + + +def _get_pairwise_iou(pred: torch.Tensor, gt: torch.Tensor, device: Union[str, torch.device] = "cpu"): + pred_id_list = _get_id_list(pred) + true_id_list = _get_id_list(gt) + + pairwise_iou = torch.zeros([len(true_id_list) - 1, len(pred_id_list) - 1], dtype=torch.float, device=device) + true_masks: List[torch.Tensor] = [] + pred_masks: List[torch.Tensor] = [] + + for t in true_id_list[1:]: + t_mask = torch.as_tensor(gt == t, device=device).int() + true_masks.append(t_mask) + + for p in pred_id_list[1:]: + p_mask = torch.as_tensor(pred == p, device=device).int() + pred_masks.append(p_mask) + + for true_id in range(1, len(true_id_list)): + t_mask = true_masks[true_id - 1] + pred_true_overlap = pred[t_mask > 0] + pred_true_overlap_id = list(pred_true_overlap.unique()) + for pred_id in pred_true_overlap_id: + if pred_id == 0: + continue + p_mask = pred_masks[pred_id - 1] + total = (t_mask + p_mask).sum() + inter = (t_mask * p_mask).sum() + iou = inter / (total - inter) + pairwise_iou[true_id - 1, pred_id - 1] = iou + + return pairwise_iou, true_id_list, pred_id_list + + +def _get_paired_iou( + pairwise_iou: torch.Tensor, match_iou_threshold: float = 0.5, device: Union[str, torch.device] = "cpu" +): + if match_iou_threshold >= 0.5: + pairwise_iou[pairwise_iou <= match_iou_threshold] = 0.0 + paired_true, paired_pred = torch.nonzero(pairwise_iou)[:, 0], torch.nonzero(pairwise_iou)[:, 1] + paired_iou = pairwise_iou[paired_true, paired_pred] + paired_true += 1 + paired_pred += 1 + + return paired_iou, paired_true, paired_pred + + pairwise_iou = pairwise_iou.cpu().numpy() + paired_true, paired_pred = linear_sum_assignment(-pairwise_iou) + paired_iou = pairwise_iou[paired_true, paired_pred] + paired_true = torch.as_tensor(list(paired_true[paired_iou > match_iou_threshold] + 1), device=device) + paired_pred = torch.as_tensor(list(paired_pred[paired_iou > match_iou_threshold] + 1), device=device) + paired_iou = paired_iou[paired_iou > match_iou_threshold] + + return paired_iou, paired_true, paired_pred + + +def _check_panoptic_metric_name(metric_name: str): + metric_name = metric_name.replace(" ", "_") + metric_name = metric_name.lower() + if metric_name in ["panoptic_quality", "pq"]: + return "pq" + if metric_name in ["segmentation_quality", "sq"]: + return "sq" + if metric_name in ["recognition_quality", "rq"]: + return "rq" + raise ValueError(f"metric name: {metric_name} is wrong, please use 'pq', 'sq' or 'rq'.") diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 324d424c5c..0c06c7768a 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -217,3 +217,40 @@ def is_binary_tensor(input: torch.Tensor, name: str): raise ValueError(f"{name} must be of type PyTorch Tensor.") if not torch.all(input.byte() == input) or input.max() > 1 or input.min() < 0: warnings.warn(f"{name} should be a binarized tensor.") + + +def remap_instance_id(pred: torch.Tensor, by_size: bool = False): + """ + This function is used to rename all instance id of `pred`, so that the id is + contiguous. + For example: all ids of the input can be [0, 1, 2] rather than [0, 2, 5]. + This function is helpful for calculating metrics like Panoptic Quality (PQ). + The implementation refers to: + + https://github.com/vqdang/hover_net + + Args: + pred: segmentation predictions in the form of torch tensor. Each + value of the tensor should be an integer, and represents the prediction of its corresponding instance id. + by_size: if True, larget instance will be assigned a smaller id. + + """ + pred_id = list(pred.unique()) + # the original implementation has the limitation that if there is no 0 in pred, error will happen + pred_id = [i for i in pred_id if i != 0] + + if len(pred_id) == 0: + return pred + if by_size is True: + instance_size = [] + for instance_id in pred_id: + instance_size.append((pred == instance_id).sum()) + + pair_data = zip(pred_id, instance_size) + pair_list = sorted(pair_data, key=lambda x: x[1], reverse=True) # type: ignore + pred_id, _ = zip(*pair_list) + + new_pred = torch.zeros_like(pred, dtype=torch.int) + for idx, instance_id in enumerate(pred_id): + new_pred[pred == instance_id] = idx + 1 + return new_pred diff --git a/tests/min_tests.py b/tests/min_tests.py index 5bbc086adf..7c06d21374 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -39,6 +39,7 @@ def run_testsuit(): "test_check_missing_files", "test_compute_ho_ver_maps", "test_compute_ho_ver_maps_d", + "test_compute_panoptic_quality", "test_contrastive_loss", "test_csv_dataset", "test_csv_iterable_dataset", @@ -75,6 +76,7 @@ def run_testsuit(): "test_handler_hausdorff_distance", "test_handler_lr_scheduler", "test_handler_mean_dice", + "test_handler_panoptic_quality", "test_handler_mean_iou", "test_handler_metrics_saver", "test_handler_metrics_saver_dist", diff --git a/tests/test_compute_panoptic_quality.py b/tests/test_compute_panoptic_quality.py new file mode 100644 index 0000000000..cf5d0deb2a --- /dev/null +++ b/tests/test_compute_panoptic_quality.py @@ -0,0 +1,111 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import List + +import numpy as np +import torch +from parameterized import parameterized + +from monai.metrics import PanopticQualityMetric, compute_panoptic_quality +from tests.utils import SkipIfNoModule + +_device = "cuda:0" if torch.cuda.is_available() else "cpu" + +# TEST_FUNC_CASE related cases are used to test for single image with HW input shape + +sample_1 = torch.randint(low=0, high=5, size=(64, 64), device=_device) +sample_2_pred = torch.as_tensor([[0, 1, 1, 1], [0, 0, 0, 0], [2, 0, 3, 3], [4, 2, 2, 0]], device=_device) +sample_2_pred_need_remap = torch.as_tensor([[0, 7, 7, 7], [0, 0, 0, 0], [1, 0, 8, 8], [9, 1, 1, 0]], device=_device) +sample_2_gt = torch.as_tensor([[1, 1, 2, 1], [0, 0, 0, 0], [1, 3, 0, 0], [4, 3, 3, 3]], device=_device) +# if pred == gt, result should be 1 +TEST_FUNC_CASE_1 = [{"pred": sample_1, "gt": sample_1, "match_iou_threshold": 0.99}, 1.0] + +# test sample_2 when match_iou_threshold = 0.5 +TEST_FUNC_CASE_2 = [{"pred": sample_2_pred, "gt": sample_2_gt, "match_iou_threshold": 0.5}, 0.25] +# test sample_2 when match_iou_threshold = 0.3, metric_name = "sq" +TEST_FUNC_CASE_3 = [{"pred": sample_2_pred, "gt": sample_2_gt, "metric_name": "sq", "match_iou_threshold": 0.3}, 0.6] +# test sample_2 when match_iou_threshold = 0.3, pred has different order, metric_name = "RQ" +TEST_FUNC_CASE_4 = [ + {"pred": sample_2_pred_need_remap, "gt": sample_2_gt, "metric_name": "RQ", "match_iou_threshold": 0.3}, + 0.75, +] + +# TEST_CLS_CASE related cases are used to test the PanopticQualityMetric with B2HW input +sample_3_pred = torch.as_tensor( + [ + [[[2, 0, 1], [2, 1, 1], [0, 1, 1]], [[0, 1, 3], [0, 0, 0], [1, 2, 1]]], + [[[1, 1, 1], [3, 2, 0], [3, 2, 1]], [[1, 1, 3], [3, 1, 1], [0, 3, 0]]], + ], + device=_device, +) + +sample_3_gt = torch.as_tensor( + [ + [[[2, 0, 0], [2, 0, 0], [2, 2, 3]], [[3, 3, 3], [3, 2, 1], [2, 2, 3]]], + [[[1, 1, 1], [0, 0, 3], [0, 0, 3]], [[0, 1, 3], [2, 1, 0], [3, 0, 3]]], + ], + device=_device, +) + +# test sample_3, num_classes = 3, match_iou_threshold = 0.5 +TEST_CLS_CASE_1 = [{"num_classes": 3, "match_iou_threshold": 0.5}, sample_3_pred, sample_3_gt, (0.0, 0.0, 0.25)] + +# test sample_3, num_classes = 3, match_iou_threshold = 0.3 +TEST_CLS_CASE_2 = [{"num_classes": 3, "match_iou_threshold": 0.3}, sample_3_pred, sample_3_gt, (0.25, 0.5, 0.25)] + +# test sample_3, num_classes = 4, match_iou_threshold = 0.3, metric_name = "segmentation_quality" +TEST_CLS_CASE_3 = [ + {"num_classes": 4, "match_iou_threshold": 0.3, "metric_name": "segmentation_quality"}, + sample_3_pred, + sample_3_gt, + (0.5, 0.5, 1.0, 0.0), +] + +# test sample_3, num_classes = 3, match_iou_threshold = 0.4, reduction = "none", metric_name = "Recognition Quality" +TEST_CLS_CASE_4 = [ + {"num_classes": 3, "reduction": "none", "match_iou_threshold": 0.4, "metric_name": "Recognition Quality"}, + sample_3_pred, + sample_3_gt, + [[0.0, 1.0, 0.0], [0.6667, 0.0, 0.4]], +] + +# test sample_3, num_classes = 3, match_iou_threshold = 0.4, reduction = "none", multiple metrics +TEST_CLS_CASE_5 = [ + {"num_classes": 3, "reduction": "none", "match_iou_threshold": 0.4, "metric_name": ["Recognition Quality", "pq"]}, + sample_3_pred, + sample_3_gt, + [torch.as_tensor([[0.0, 1.0, 0.0], [0.6667, 0.0, 0.4]]), torch.as_tensor([[0.0, 0.5, 0.0], [0.3333, 0.0, 0.4]])], +] + + +@SkipIfNoModule("scipy.optimize") +class TestPanopticQualityMetric(unittest.TestCase): + @parameterized.expand([TEST_FUNC_CASE_1, TEST_FUNC_CASE_2, TEST_FUNC_CASE_3, TEST_FUNC_CASE_4]) + def test_value(self, input_params, expected_value): + result = compute_panoptic_quality(**input_params) + np.testing.assert_allclose(result.cpu().detach().item(), expected_value, atol=1e-4) + + @parameterized.expand([TEST_CLS_CASE_1, TEST_CLS_CASE_2, TEST_CLS_CASE_3, TEST_CLS_CASE_4, TEST_CLS_CASE_5]) + def test_value_class(self, input_params, y_pred, y_gt, expected_value): + metric = PanopticQualityMetric(**input_params) + metric(y_pred, y_gt) + outputs = metric.aggregate() + if isinstance(outputs, List): + for output, value in zip(outputs, expected_value): + np.testing.assert_allclose(output.cpu().numpy(), np.asarray(value), atol=1e-4) + else: + np.testing.assert_allclose(outputs.cpu().numpy(), np.asarray(expected_value), atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_panoptic_quality.py b/tests/test_handler_panoptic_quality.py new file mode 100644 index 0000000000..a852ee929a --- /dev/null +++ b/tests/test_handler_panoptic_quality.py @@ -0,0 +1,86 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from ignite.engine import Engine, Events +from parameterized import parameterized + +from monai.handlers import PanopticQuality, from_engine +from tests.utils import SkipIfNoModule, assert_allclose + +sample_1_pred = torch.as_tensor( + [[[0, 1, 1, 1], [0, 0, 5, 5], [2, 0, 3, 3], [2, 2, 2, 0]], [[0, 1, 1, 1], [0, 0, 0, 0], [2, 0, 3, 3], [4, 2, 2, 0]]] +) + +sample_1_gt = torch.as_tensor( + [[[0, 6, 6, 6], [1, 0, 5, 5], [1, 0, 3, 3], [1, 3, 2, 0]], [[0, 1, 1, 1], [0, 0, 1, 1], [2, 0, 3, 3], [4, 4, 4, 3]]] +) + +sample_2_pred = torch.as_tensor( + [[[3, 1, 1, 1], [3, 1, 1, 4], [3, 1, 4, 4], [3, 2, 2, 4]], [[0, 1, 1, 1], [2, 2, 2, 2], [2, 0, 0, 3], [4, 2, 2, 3]]] +) + +sample_2_gt = torch.as_tensor( + [[[0, 6, 6, 6], [1, 0, 5, 5], [1, 0, 3, 3], [1, 3, 2, 0]], [[0, 1, 1, 1], [2, 1, 1, 3], [2, 0, 0, 3], [4, 2, 2, 3]]] +) + +TEST_CASE_1 = [{"num_classes": 4, "output_transform": from_engine(["pred", "label"])}, [0.6667, 0.1538, 0.6667, 0.5714]] +TEST_CASE_2 = [ + { + "num_classes": 5, + "output_transform": from_engine(["pred", "label"]), + "metric_name": "rq", + "match_iou_threshold": 0.3, + }, + [0.6667, 0.7692, 0.8889, 0.5714, 0.0000], +] +TEST_CASE_3 = [ + { + "num_classes": 5, + "reduction": "mean", + "output_transform": from_engine(["pred", "label"]), + "metric_name": "SQ", + "match_iou_threshold": 0.2, + }, + 0.8235, +] + + +@SkipIfNoModule("scipy.optimize") +class TestHandlerPanopticQuality(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_compute(self, input_params, expected_avg): + metric = PanopticQuality(**input_params) + # set up engine + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + metric.attach(engine=engine, name="panoptic_quality") + # test input a list of channel-first tensor + y_pred = [sample_1_pred, sample_2_pred] + y = [sample_1_gt, sample_2_gt] + engine.state.output = {"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + y_pred = [sample_1_pred, sample_1_pred] + y = [sample_1_gt, sample_1_gt] + engine.state.output = {"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + + engine.fire_event(Events.EPOCH_COMPLETED) + assert_allclose(engine.state.metrics["panoptic_quality"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hovernet.py b/tests/test_hovernet.py index 1a37a26ec4..389bb8c10f 100644 --- a/tests/test_hovernet.py +++ b/tests/test_hovernet.py @@ -66,7 +66,7 @@ def check_branch(branch, mode): if branch.decoderblock1.convf.kernel_size != (1, 1): return True for block in branch.decoderblock1: - if type(block) is HoVerNet._DenseLayerDecoder: + if isinstance(block, HoVerNet._DenseLayerDecoder): if block.layers.conv1.kernel_size != (1, 1) or block.layers.conv2.kernel_size != (ksize, ksize): return True @@ -76,7 +76,7 @@ def check_branch(branch, mode): return True for block in branch.decoderblock2: - if type(block) is HoVerNet._DenseLayerDecoder: + if isinstance(block, HoVerNet._DenseLayerDecoder): if block.layers.conv1.kernel_size != (1, 1) or block.layers.conv2.kernel_size != (ksize, ksize): return True