Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3415 Update WSIReader #3417

Merged
merged 14 commits into from
Dec 2, 2021
Merged
88 changes: 65 additions & 23 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,14 +681,14 @@ class WSIReader(ImageReader):
Read whole slide images and extract patches.

Args:
backend: backend library to load the images, available options: "cuCIM", "OpenSlide" and "Tifffile".
backend: backend library to load the images, available options: "cuCIM", "OpenSlide" and "TiffFile".
level: the whole slide image level at which the image is extracted. (default=0)
This is overridden if the level argument is provided in `get_data`.

Note:
While "cucim" and "OpenSlide" backends both can load patches from large whole slide images
without loading the entire image into memory, "Tifffile" backend needs to load the entire image into memory
before extracting any patch; thus, memory consideration is needed when using "Tifffile" backend for
While "cuCIM" and "OpenSlide" backends both can load patches from large whole slide images
without loading the entire image into memory, "TiffFile" backend needs to load the entire image into memory
before extracting any patch; thus, memory consideration is needed when using "TiffFile" backend for
patch extraction.
"""

Expand Down Expand Up @@ -765,19 +765,24 @@ def get_data(
grid_shape: (row, columns) tuple define a grid to extract patches on that
patch_size: (height, width) the size of extracted patches at the given level
"""
# Verify inputs
if level is None:
level = self.level

if self.backend == "openslide" and size is None:
# the maximum size is set to WxH at the specified level
size = (img.shape[0] // (2 ** level) - location[0], img.shape[1] // (2 ** level) - location[1])
level = self._check_level(img, level)
if size is None:
size = self._get_image_size(img, size, level, location)

# Extract patch (or the whole image)
region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype)

# Add necessary metadata
metadata: Dict = {}
metadata["spatial_shape"] = np.asarray(region.shape[:-1])
metadata["original_channel_dim"] = -1

# Make it channel first
region = EnsureChannelFirst()(region, metadata)

# Split into patches
if patch_size is None:
patches = region
else:
Expand All @@ -788,30 +793,66 @@ def get_data(

return patches, metadata

def _check_level(self, img, level):
level = self.level

level_count = 0
if self.backend == "openslide":
level_count = img.level_count
elif self.backend == "cucim":
level_count = img.resolutions["level_count"]
elif self.backend == "tifffile":
level_count = len(img.pages)

if level > level_count - 1:
raise ValueError(f"The maximum level of this image is {level_count - 1} while level={level} is requested)!")

return level

def _get_image_size(self, img, size, level, location):
max_size = []
downsampling_factor = []
if self.backend == "openslide":
downsampling_factor = img.level_downsamples[level]
max_size = img.level_dimensions[level][::-1]
elif self.backend == "cucim":
downsampling_factor = img.resolutions["level_downsamples"][level]
max_size = img.resolutions["level_dimensions"][level][::-1]
elif self.backend == "tifffile":
level0_size = img.pages[0].shape[:2]
max_size = img.pages[level].shape[:2]
downsampling_factor = np.mean([level0_size[i] / max_size[i] for i in range(len(max_size))])

# subtract the top left corner of the patch from maximum size
level_location = [round(location[i] / downsampling_factor) for i in range(len(location))]
size = [max_size[i] - level_location[i] for i in range(len(max_size))]

return size

def _extract_region(
self,
img_obj,
size: Optional[Tuple[int, int]],
size: Tuple[int, int],
location: Tuple[int, int] = (0, 0),
level: int = 0,
dtype: DtypeLike = np.uint8,
):
if self.backend == "tifffile":
with img_obj:
region = img_obj.asarray(level=level)
if size is None:
region = region[location[0] :, location[1] :]
else:
region = region[location[0] : location[0] + size[0], location[1] : location[1] + size[1]]

# with img_obj:
region = img_obj.asarray(level=level)
if level != 0:
level0_size = img_obj.pages[0].shape[:2]
max_size = img_obj.pages[level].shape[:2]
location = (
int(location[0] / level0_size[0] * max_size[0]),
int(location[1] / level0_size[1] * max_size[1]),
)
region = region[location[0] : location[0] + size[0], location[1] : location[1] + size[1]]
else:
# reverse the order of dimensions for size and location to be compatible with image shape
# reverse the order of dimensions for size and location to become WxH
location = location[::-1]
if size is None:
region = img_obj.read_region(location=location, level=level)
else:
size = size[::-1]
region = img_obj.read_region(location=location, size=size, level=level)
size = size[::-1]
region = img_obj.read_region(location=location, size=size, level=level)

region = self.convert_to_rgb_array(region, dtype)
return region
Expand All @@ -824,6 +865,7 @@ def convert_to_rgb_array(self, raw_region, dtype: DtypeLike = np.uint8):

# convert to numpy (if not already in numpy)
raw_region = np.asarray(raw_region, dtype=dtype)

# remove alpha channel if exist (RGBA)
if raw_region.shape[-1] > 3:
raw_region = raw_region[..., :3]
Expand Down
41 changes: 29 additions & 12 deletions tests/test_wsireader.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
has_cucim = has_cucim and hasattr(cucim, "CuImage")
_, has_osl = optional_import("openslide")
imsave, has_tiff = optional_import("tifffile", name="imsave")
_, has_codec = optional_import("imagecodecs")
wyli marked this conversation as resolved.
Show resolved Hide resolved
has_tiff = has_tiff and has_codec

FILE_URL = "https://drive.google.com/uc?id=1sGTKZlJBIz53pfqTxoTqiIQzIoEzHLAe"
base_name, extension = FILE_URL.split("id=")[1], ".tiff"
Expand Down Expand Up @@ -69,6 +71,13 @@
np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]),
]

TEST_CASE_5 = [
FILE_PATH,
{"location": (HEIGHT - 2, WIDTH - 2), "level": 0, "grid_shape": (1, 1)},
np.array([[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[237, 237], [237, 237]]]),
]


TEST_CASE_RGB_0 = [np.ones((3, 2, 2), dtype=np.uint8)] # CHW

TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW
Expand All @@ -92,7 +101,7 @@ def save_rgba_tiff(array: np.ndarray, filename: str, mode: str):
return filename


@skipUnless(has_cucim or has_osl, "Requires cucim or openslide!")
@skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!")
def setUpModule(): # noqa: N802
download_url(FILE_URL, FILE_PATH, "5a3cfd4fd725c50578ddb80b517b759f")

Expand All @@ -104,25 +113,30 @@ class Tests(unittest.TestCase):
@parameterized.expand([TEST_CASE_0])
def test_read_whole_image(self, file_path, level, expected_shape):
reader = WSIReader(self.backend, level=level)
img_obj = reader.read(file_path)
img = reader.get_data(img_obj)[0]
with reader.read(file_path) as img_obj:
img = reader.get_data(img_obj)[0]
self.assertTupleEqual(img.shape, expected_shape)

@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_5])
def test_read_region(self, file_path, patch_info, expected_img):
# Due to CPU memory limitation ignore tifffile at level 0.
if self.backend == "tifffile" and patch_info["level"] == 0:
return
reader = WSIReader(self.backend)
img_obj = reader.read(file_path)
# Read twice to check multiple calls
img = reader.get_data(img_obj, **patch_info)[0]
img = reader.get_data(img_obj, **patch_info)[0]
with reader.read(file_path) as img_obj:
img = reader.get_data(img_obj, **patch_info)[0]
img2 = reader.get_data(img_obj, **patch_info)[0]
self.assertTupleEqual(img.shape, img2.shape)
self.assertIsNone(assert_array_equal(img, img2))
self.assertTupleEqual(img.shape, expected_img.shape)
self.assertIsNone(assert_array_equal(img, expected_img))

@parameterized.expand([TEST_CASE_3, TEST_CASE_4])
def test_read_patches(self, file_path, patch_info, expected_img):
reader = WSIReader(self.backend)
img_obj = reader.read(file_path)
img = reader.get_data(img_obj, **patch_info)[0]
with reader.read(file_path) as img_obj:
img = reader.get_data(img_obj, **patch_info)[0]
self.assertTupleEqual(img.shape, expected_img.shape)
self.assertIsNone(assert_array_equal(img, expected_img))

Expand All @@ -140,16 +154,19 @@ def test_read_rgba(self, img_expected):
os.path.join(os.path.dirname(__file__), "testing_data", f"temp_tiff_image_{mode}.tiff"),
mode=mode,
)
img_obj = reader.read(file_path)
image[mode], _ = reader.get_data(img_obj)
with reader.read(file_path) as img_obj:
image[mode], _ = reader.get_data(img_obj)

self.assertIsNone(assert_array_equal(image["RGB"], img_expected))
self.assertIsNone(assert_array_equal(image["RGBA"], img_expected))

@parameterized.expand([TEST_CASE_TRANSFORM_0])
def test_with_dataloader(self, file_path, level, expected_spatial_shape, expected_shape):
train_transform = Compose(
[LoadImaged(keys=["image"], reader=WSIReader, backend="cuCIM", level=level), ToTensord(keys=["image"])]
[
LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level),
ToTensord(keys=["image"]),
]
)
dataset = Dataset([{"image": file_path}], transform=train_transform)
data_loader = DataLoader(dataset)
Expand Down