diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index 1291dac25a..d8fd815ce9 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -83,6 +83,7 @@ def __init__( self.set_random_state(seed=seed) tarfile_name = os.path.join(root_dir, self.compressed_file_name) dataset_dir = os.path.join(root_dir, self.dataset_folder_name) + self.num_class = 0 if download: download_and_extract(self.resource, tarfile_name, root_dir, self.md5) @@ -98,6 +99,10 @@ def __init__( def randomize(self, data: Optional[Any] = None) -> None: self.rann = self.R.random() + def get_num_classes(self) -> int: + """Get number of classes.""" + return self.num_class + def _generate_data_list(self, dataset_dir: str) -> List[Dict]: """ Raises: @@ -105,20 +110,22 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: """ class_names = sorted((x for x in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, x)))) - num_class = len(class_names) + self.num_class = len(class_names) image_files = [ [ os.path.join(dataset_dir, class_names[i], x) for x in os.listdir(os.path.join(dataset_dir, class_names[i])) ] - for i in range(num_class) + for i in range(self.num_class) ] - num_each = [len(image_files[i]) for i in range(num_class)] + num_each = [len(image_files[i]) for i in range(self.num_class)] image_files_list = [] image_class = [] - for i in range(num_class): + class_name = [] + for i in range(self.num_class): image_files_list.extend(image_files[i]) image_class.extend([i] * num_each[i]) + class_name.extend([class_names[i]] * num_each[i]) num_total = len(image_class) data = [] @@ -138,7 +145,7 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: raise ValueError( f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].' ) - data.append({"image": image_files_list[i], "label": image_class[i]}) + data.append({"image": image_files_list[i], "label": image_class[i], "class_name": class_name[i]}) return data diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index 28263e0722..0887734a7c 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -52,6 +52,7 @@ def _test_dataset(dataset): # testing from data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False) + data.get_num_classes() _test_dataset(data) data = MedNISTDataset(root_dir=testing_dir, section="test", download=False) self.assertTupleEqual(data[0]["image"].shape, (64, 64))