Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check class_id validity in DetectionDataset #1536

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ def _load_sample_annotation(self, sample_id: int) -> Dict[str, Union[np.ndarray,
# Filter out classes that are not in self.class_inclusion_list
if self.class_inclusion_list is not None:
sample_annotations = self._sub_class_annotation(annotation=sample_annotations)

return sample_annotations

def _load_all_annotations(self, n_samples: int) -> Tuple[Dict[int, Dict[str, Any]], Dict[int, Dict[str, Any]]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def _load_annotation(self, sample_id: int) -> dict:

yolo_format_target, invalid_labels = self._parse_yolo_label_file(
label_file_path=label_path,
num_classes=len(self.all_classes_list),
Louis-Dupont marked this conversation as resolved.
Show resolved Hide resolved
ignore_invalid_labels=self.ignore_invalid_labels,
show_warnings=self.show_all_warnings,
)
Expand All @@ -210,13 +211,20 @@ def _load_annotation(self, sample_id: int) -> dict:
return annotation

@staticmethod
def _parse_yolo_label_file(label_file_path: str, ignore_invalid_labels: bool = True, show_warnings: bool = True) -> Tuple[np.ndarray, List[str]]:
def _parse_yolo_label_file(
label_file_path: str,
ignore_invalid_labels: bool = True,
show_warnings: bool = True,
num_classes: Optional[int] = None,
) -> Tuple[np.ndarray, List[str]]:
"""Parse a single label file in yolo format.

#TODO: Add support for additional fields (with ConcatenatedTensorFormat)
:param label_file_path: Path to the label file in yolo format.
:param ignore_invalid_labels: Whether to ignore labels that fail to be parsed. If True ignores and logs a warning, otherwise raise an error.
:param show_warnings: Whether to show the warnings or not.
:param num_classes: Number of classes in the dataset. Used to ensure that class ids are within the range [0, num_classes - 1].
If None, the class id range check is skipped.

:return:
- labels: np.ndarray of shape (n_labels, 5) in yolo format (LABEL_NORMALIZED_CXCYWH)
Expand All @@ -229,12 +237,21 @@ def _parse_yolo_label_file(label_file_path: str, ignore_invalid_labels: bool = T
for line in filter(lambda x: x != "\n", lines):
try:
label_id, cx, cw, w, h = line.split()
labels_yolo_format.append([int(label_id), float(cx), float(cw), float(w), float(h)])
label_id, cx, cw, w, h = int(label_id), float(cx), float(cw), float(w), float(h)

if (num_classes is not None) and (label_id not in range(num_classes)):
raise ValueError(f"`class_id={label_id}` invalid. It should be between (0 - {num_classes - 1}).")

labels_yolo_format.append([label_id, cx, cw, w, h])
except Exception as e:
error_msg = (
f"Line `{line}` of file {label_file_path} will be ignored because not cannot be parsed to (label, cx, cy, w, h) format, "
f"with Exception:\n{e}"
)
if ignore_invalid_labels:
invalid_labels.append(line)
if show_warnings:
logger.warning(f"Line `{line}` of file {label_file_path} will be ignored because not in LABEL_NORMALIZED_CXCYWH format: {e}")
logger.warning(error_msg)
else:
raise e
raise RuntimeError(error_msg)
return np.array(labels_yolo_format) if labels_yolo_format else np.zeros((0, 5)), invalid_labels
3 changes: 2 additions & 1 deletion tests/deci_core_unit_test_suite_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from tests.end_to_end_tests import TestTrainer
from tests.unit_tests.detection_utils_test import TestDetectionUtils
from tests.unit_tests.detection_dataset_test import DetectionDatasetTest
from tests.unit_tests.detection_dataset_test import DetectionDatasetTest, TestParseYoloLabelFile
from tests.unit_tests.export_detection_model_test import TestDetectionModelExport
from tests.unit_tests.export_onnx_test import TestModelsONNXExport
from tests.unit_tests.export_pose_estimation_model_test import TestPoseEstimationModelExport
Expand Down Expand Up @@ -136,6 +136,7 @@ def _add_modules_to_unit_tests_suite(self):
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestRepVGGBlock))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LocalCkptHeadReplacementTest))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DetectionDatasetTest))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestParseYoloLabelFile))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestModelsONNXExport))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(MaxBatchesLoopBreakTest))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestTrainingUtils))
Expand Down
30 changes: 29 additions & 1 deletion tests/unit_tests/detection_dataset_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import unittest
from unittest.mock import patch, mock_open
from pathlib import Path
from typing import Dict
import numpy as np

from torch.utils.data import DataLoader

from super_gradients import Trainer
from super_gradients.training import models, dataloaders
from super_gradients.training.dataloaders import coco2017_train_yolo_nas, get_data_loader
from super_gradients.training.datasets import COCODetectionDataset
from super_gradients.training.datasets import COCODetectionDataset, YoloDarknetFormatDetectionDataset
from super_gradients.training.datasets.data_formats.default_formats import LABEL_CXCYWH
from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
from super_gradients.common.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException
Expand Down Expand Up @@ -194,5 +196,31 @@ def test_coco_detection_metrics_with_classwise_ap(self):
trainer.train(model=model, training_params=detection_train_params_yolox, train_loader=train_loader, valid_loader=valid_loader)


class TestParseYoloLabelFile(unittest.TestCase):
    """Unit tests for `YoloDarknetFormatDetectionDataset._parse_yolo_label_file`.

    `builtins.open` is patched with `mock_open`, so the label file content is supplied
    in-memory and no file needs to exist on disk.
    """

    def setUp(self):
        # With 3 classes, the only valid class ids are 0, 1 and 2.
        self.num_classes = 3
        self.sample_data_valid = "0 0.5 0.5 0.1 0.1\n1 0.6 0.6 0.2 0.2"
        # First line has only 2 fields instead of the 5 expected (label, cx, cy, w, h).
        self.sample_data_invalid_format = "0 0.5\n1 0.6 0.6 0.2 0.2"
        # Both class ids (-1 and 3) are outside [0, num_classes - 1].
        self.sample_data_invalid_class = "-1 0.5 0.5 0.1 0.1\n3 0.6 0.6 0.2 0.2"

    def _parse(self, file_content: str, **kwargs):
        """Run the parser on `file_content` as if it had been read from a label file.

        :param file_content: Raw text returned by the mocked label file.
        :param kwargs:       Extra keyword arguments forwarded to `_parse_yolo_label_file`.
        :return:             Whatever `_parse_yolo_label_file` returns, i.e. (labels, invalid_labels).
        """
        with patch("builtins.open", mock_open(read_data=file_content)):
            return YoloDarknetFormatDetectionDataset._parse_yolo_label_file("mock_path", num_classes=self.num_classes, **kwargs)

    def test_valid_label(self):
        """Well-formed lines with in-range class ids are all kept."""
        labels, invalid_labels = self._parse(self.sample_data_valid)
        np.testing.assert_array_equal(labels, np.array([[0, 0.5, 0.5, 0.1, 0.1], [1, 0.6, 0.6, 0.2, 0.2]]))
        self.assertEqual(invalid_labels, [])

    def test_invalid_format(self):
        """A line with the wrong number of fields is reported as invalid; the other line is kept."""
        labels, invalid_labels = self._parse(self.sample_data_invalid_format)
        np.testing.assert_array_equal(labels, np.array([[1, 0.6, 0.6, 0.2, 0.2]]))
        self.assertEqual(invalid_labels, ["0 0.5\n"])

    def test_invalid_class(self):
        """Lines whose class id is outside [0, num_classes - 1] are reported as invalid."""
        labels, invalid_labels = self._parse(self.sample_data_invalid_class)
        self.assertEqual(len(labels), 0)
        # Invalid lines are returned verbatim, hence the trailing newline on the first one only.
        self.assertEqual(invalid_labels, ["-1 0.5 0.5 0.1 0.1\n", "3 0.6 0.6 0.2 0.2"])

    def test_invalid_class_raises_when_not_ignored(self):
        """With `ignore_invalid_labels=False`, an out-of-range class id raises instead of being skipped."""
        with self.assertRaises(RuntimeError):
            self._parse(self.sample_data_invalid_class, ignore_invalid_labels=False)


# Run the tests of this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()