diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 6e9cbca809..a620e06216 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -681,14 +681,14 @@ class WSIReader(ImageReader): Read whole slide images and extract patches. Args: - backend: backend library to load the images, available options: "cuCIM", "OpenSlide" and "Tifffile". + backend: backend library to load the images, available options: "cuCIM", "OpenSlide" and "TiffFile". level: the whole slide image level at which the image is extracted. (default=0) This is overridden if the level argument is provided in `get_data`. Note: - While "cucim" and "OpenSlide" backends both can load patches from large whole slide images - without loading the entire image into memory, "Tifffile" backend needs to load the entire image into memory - before extracting any patch; thus, memory consideration is needed when using "Tifffile" backend for + While "cuCIM" and "OpenSlide" backends both can load patches from large whole slide images + without loading the entire image into memory, "TiffFile" backend needs to load the entire image into memory + before extracting any patch; thus, memory consideration is needed when using "TiffFile" backend for patch extraction. """ @@ -765,19 +765,24 @@ def get_data( grid_shape: (row, columns) tuple define a grid to extract patches on that patch_size: (height, width) the size of extracted patches at the given level """ + # Verify inputs if level is None: - level = self.level - - if self.backend == "openslide" and size is None: - # the maximum size is set to WxH at the specified level - size = (img.shape[0] // (2 ** level) - location[0], img.shape[1] // (2 ** level) - location[1]) + level = self._check_level(img, level) + if size is None: + size = self._get_image_size(img, size, level, location) + # Extract patch (or the whole image) region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype) + # Add necessary metadata metadata: Dict = {} metadata["spatial_shape"] = np.asarray(region.shape[:-1]) metadata["original_channel_dim"] = -1 + + # Make it channel first region = EnsureChannelFirst()(region, metadata) + + # Split into patches if patch_size is None: patches = region else: @@ -788,30 +793,66 @@ def get_data( return patches, metadata + def _check_level(self, img, level): + level = self.level + + level_count = 0 + if self.backend == "openslide": + level_count = img.level_count + elif self.backend == "cucim": + level_count = img.resolutions["level_count"] + elif self.backend == "tifffile": + level_count = len(img.pages) + + if level > level_count - 1: + raise ValueError(f"The maximum level of this image is {level_count - 1} while level={level} is requested)!") + + return level + + def _get_image_size(self, img, size, level, location): + max_size = [] + downsampling_factor = [] + if self.backend == "openslide": + downsampling_factor = img.level_downsamples[level] + max_size = img.level_dimensions[level][::-1] + elif self.backend == "cucim": + downsampling_factor = img.resolutions["level_downsamples"][level] + max_size = img.resolutions["level_dimensions"][level][::-1] + elif self.backend == "tifffile": + level0_size = img.pages[0].shape[:2] + max_size = img.pages[level].shape[:2] + downsampling_factor = np.mean([level0_size[i] / max_size[i] for i in range(len(max_size))]) + + # subtract the top left corner of the patch from maximum size + level_location = [round(location[i] / downsampling_factor) for i in range(len(location))] + size = [max_size[i] - level_location[i] for i in range(len(max_size))] + + return size + def _extract_region( self, img_obj, - size: Optional[Tuple[int, int]], + size: Tuple[int, int], location: Tuple[int, int] = (0, 0), level: int = 0, dtype: DtypeLike = np.uint8, ): if self.backend == "tifffile": - with img_obj: - region = img_obj.asarray(level=level) - if size is None: - region = region[location[0] :, location[1] :] - else: - region = region[location[0] : location[0] + size[0], location[1] : location[1] + size[1]] - + # with img_obj: + region = img_obj.asarray(level=level) + if level != 0: + level0_size = img_obj.pages[0].shape[:2] + max_size = img_obj.pages[level].shape[:2] + location = ( + int(location[0] / level0_size[0] * max_size[0]), + int(location[1] / level0_size[1] * max_size[1]), + ) + region = region[location[0] : location[0] + size[0], location[1] : location[1] + size[1]] else: - # reverse the order of dimensions for size and location to be compatible with image shape + # reverse the order of dimensions for size and location to become WxH location = location[::-1] - if size is None: - region = img_obj.read_region(location=location, level=level) - else: - size = size[::-1] - region = img_obj.read_region(location=location, size=size, level=level) + size = size[::-1] + region = img_obj.read_region(location=location, size=size, level=level) region = self.convert_to_rgb_array(region, dtype) return region @@ -824,6 +865,7 @@ def convert_to_rgb_array(self, raw_region, dtype: DtypeLike = np.uint8): # convert to numpy (if not already in numpy) raw_region = np.asarray(raw_region, dtype=dtype) + # remove alpha channel if exist (RGBA) if raw_region.shape[-1] > 3: raw_region = raw_region[..., :3] diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py index 61eb2d82ce..e47a22908a 100644 --- a/tests/test_wsireader.py +++ b/tests/test_wsireader.py @@ -28,6 +28,8 @@ has_cucim = has_cucim and hasattr(cucim, "CuImage") _, has_osl = optional_import("openslide") imsave, has_tiff = optional_import("tifffile", name="imsave") +_, has_codec = optional_import("imagecodecs") +has_tiff = has_tiff and has_codec FILE_URL = "https://drive.google.com/uc?id=1sGTKZlJBIz53pfqTxoTqiIQzIoEzHLAe" base_name, extension = FILE_URL.split("id=")[1], ".tiff" @@ -69,6 +71,13 @@ np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]), ] +TEST_CASE_5 = [ + FILE_PATH, + {"location": (HEIGHT - 2, WIDTH - 2), "level": 0, "grid_shape": (1, 1)}, + np.array([[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[237, 237], [237, 237]]]), +] + + TEST_CASE_RGB_0 = [np.ones((3, 2, 2), dtype=np.uint8)] # CHW TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW @@ -92,7 +101,7 @@ def save_rgba_tiff(array: np.ndarray, filename: str, mode: str): return filename -@skipUnless(has_cucim or has_osl, "Requires cucim or openslide!") +@skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!") def setUpModule(): # noqa: N802 download_url(FILE_URL, FILE_PATH, "5a3cfd4fd725c50578ddb80b517b759f") @@ -104,25 +113,30 @@ class Tests(unittest.TestCase): @parameterized.expand([TEST_CASE_0]) def test_read_whole_image(self, file_path, level, expected_shape): reader = WSIReader(self.backend, level=level) - img_obj = reader.read(file_path) - img = reader.get_data(img_obj)[0] + with reader.read(file_path) as img_obj: + img = reader.get_data(img_obj)[0] self.assertTupleEqual(img.shape, expected_shape) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_5]) def test_read_region(self, file_path, patch_info, expected_img): + # Due to CPU memory limitation ignore tifffile at level 0. + if self.backend == "tifffile" and patch_info["level"] == 0: + return reader = WSIReader(self.backend) - img_obj = reader.read(file_path) # Read twice to check multiple calls - img = reader.get_data(img_obj, **patch_info)[0] - img = reader.get_data(img_obj, **patch_info)[0] + with reader.read(file_path) as img_obj: + img = reader.get_data(img_obj, **patch_info)[0] + img2 = reader.get_data(img_obj, **patch_info)[0] + self.assertTupleEqual(img.shape, img2.shape) + self.assertIsNone(assert_array_equal(img, img2)) self.assertTupleEqual(img.shape, expected_img.shape) self.assertIsNone(assert_array_equal(img, expected_img)) @parameterized.expand([TEST_CASE_3, TEST_CASE_4]) def test_read_patches(self, file_path, patch_info, expected_img): reader = WSIReader(self.backend) - img_obj = reader.read(file_path) - img = reader.get_data(img_obj, **patch_info)[0] + with reader.read(file_path) as img_obj: + img = reader.get_data(img_obj, **patch_info)[0] self.assertTupleEqual(img.shape, expected_img.shape) self.assertIsNone(assert_array_equal(img, expected_img)) @@ -140,8 +154,8 @@ def test_read_rgba(self, img_expected): os.path.join(os.path.dirname(__file__), "testing_data", f"temp_tiff_image_{mode}.tiff"), mode=mode, ) - img_obj = reader.read(file_path) - image[mode], _ = reader.get_data(img_obj) + with reader.read(file_path) as img_obj: + image[mode], _ = reader.get_data(img_obj) self.assertIsNone(assert_array_equal(image["RGB"], img_expected)) self.assertIsNone(assert_array_equal(image["RGBA"], img_expected)) @@ -149,7 +163,10 @@ def test_read_rgba(self, img_expected): @parameterized.expand([TEST_CASE_TRANSFORM_0]) def test_with_dataloader(self, file_path, level, expected_spatial_shape, expected_shape): train_transform = Compose( - [LoadImaged(keys=["image"], reader=WSIReader, backend="cuCIM", level=level), ToTensord(keys=["image"])] + [ + LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level), + ToTensord(keys=["image"]), + ] ) dataset = Dataset([{"image": file_path}], transform=train_transform) data_loader = DataLoader(dataset)