From 305df4369c3e3044d952288a6058538304246b41 Mon Sep 17 00:00:00 2001 From: Melisande Croft <63270704+melisande-c@users.noreply.github.com> Date: Thu, 22 Aug 2024 16:18:43 +0200 Subject: [PATCH] Feature: Add tiling method from Disentangle repository (#207) ### Description The Disentangle repository takes a slightly different approach to tiling results than the method in CAREamics. This PR reimplements the former's approach to ensure the reproduction of the results in the microSplit paper, in a way that should be compatible with the existing CAREamics prediction pipeline with minor refactoring. - **What**: Added a generator `extract_tiles` that yields the extracted tile and and tile information. This function is interchangeable with the existing `careamics.dataset.tiling.extract_tiles`. This means the resulting set of tiles can be passed to the existing `careamics.prediction_utils.stitch_prediction` to reconstruct the full image. The `InMemoryTiledPredDataset` can be refactored so the `extract_tiles` generator, can be chosen by the configuration. - **Why**: To ensure the reproduction of the results in the microSplit paper. - **How**: Moved relevant logic contained in the Disentangle classes [`GridIndexManager`](https://github.com/ashesh-0/Disentangle/blob/3dusplit/disentangle/data_loader/patch_index_manager.py) and [`MultiChDloader`](https://github.com/ashesh-0/Disentangle/blob/3dusplit/disentangle/data_loader/vanilla_dloader.py) to a set of functions contained within the module `careamics/dataset/tiling/lvae_tiled_patching.py`. This includes the `compute_tile_info` function that computes a CAREamics compatible `TileInformation` object from a given tile's location. ### Changes Made - **Added**: - `careamics/dataset/tiling/lvae_tiled_patching.py` - `tests/dataset/tiling/test_lvae_tiled_patching.py` - `src/careamics/prediction_utils/lvae_tiling_manager.py` - This can be ignored. I started by refactoring the `GridIndexManager` class (renamed to `TilingManager`). It is unused because I decided to create a functional implementation, more inline with what already exists in CAREamics. However, I have left it in, commented out, in case it is useful for the dataset implementation. ### Additional Notes and Examples Bit extra, but diagrams below demonstrate the difference between the two tiling methods. The grey area in the diagram illustrates where padding has been added to array. The Disentangle method ensures the stitched portion of a tile comes from the center. ![CAREamics_tiling](https://github.com/user-attachments/assets/1db2b00d-a8fc-47f7-ae2c-07af04f275bf) ![Asheshs_tiling](https://github.com/user-attachments/assets/3c3700dd-c48d-4fe7-933e-12200adbe6a7) #### Where to find original logic `extract_tiles` logic mostly comes from: https://github.com/ashesh-0/Disentangle/blob/ed99b2614a9e52b496947ccf157d3aaa8db52872/disentangle/data_loader/vanilla_dloader.py#L435-L465 Calculation of `overlap_crop_coords` and `stitch_coords` in `compute_tile_info` comes from: https://github.com/ashesh-0/Disentangle/blob/ed99b2614a9e52b496947ccf157d3aaa8db52872/disentangle/analysis/stitch_prediction.py#L36-L66 All other helper functions are a reimplementation of the logic in the methods of `GridIndexManager`. #### Note We should come with a better name for the two different approaches to avoid name clashes. --- **Please ensure your PR meets the following requirements:** - [x] Code builds and passes tests locally, including doctests - [x] New tests have been added (for bug fixes/features) - [x] Pre-commit passes - [ ] PR to the documentation exists (for bug fixes / features) --------- Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com> --- .../dataset/tiling/lvae_tiled_patching.py | 282 ++++++++++++++ .../prediction_utils/lvae_tiling_manager.py | 362 ++++++++++++++++++ .../tiling/test_lvae_tiled_patching.py | 188 +++++++++ 3 files changed, 832 insertions(+) create mode 100644 src/careamics/dataset/tiling/lvae_tiled_patching.py create mode 100644 src/careamics/prediction_utils/lvae_tiling_manager.py create mode 100644 tests/dataset/tiling/test_lvae_tiled_patching.py diff --git a/src/careamics/dataset/tiling/lvae_tiled_patching.py b/src/careamics/dataset/tiling/lvae_tiled_patching.py new file mode 100644 index 00000000..73deedb6 --- /dev/null +++ b/src/careamics/dataset/tiling/lvae_tiled_patching.py @@ -0,0 +1,282 @@ +"""Functions to reimpliment the tiling in the Disentangle repository.""" + +import builtins +import itertools +from typing import Any, Generator, Optional, Union + +import numpy as np +from numpy.typing import NDArray + +from careamics.config.tile_information import TileInformation + + +def extract_tiles( + arr: NDArray, + tile_size: NDArray[np.int_], + overlaps: NDArray[np.int_], + padding_kwargs: Optional[dict[str, Any]] = None, +) -> Generator[tuple[NDArray, TileInformation], None, None]: + """Generate tiles from the input array with specified overlap. + + The tiles cover the whole array; which will be additionally padded, to ensure that + the section of the tile that contributes to the final image comes from the center + of the tile. + + The method returns a generator that yields tuples of array and tile information, + the latter includes whether the tile is the last one, the coordinates of the + overlap crop, and the coordinates of the stitched tile. + + Input array should have shape SC(Z)YX, while the returned tiles have shape C(Z)YX, + where C can be a singleton. + + Parameters + ---------- + arr : np.ndarray + Array of shape (S, C, (Z), Y, X). + tile_size : 1D numpy.ndarray of tuple + Tile sizes in each dimension, of length 2 or 3. + overlaps : 1D numpy.ndarray of tuple + Overlap values in each dimension, of length 2 or 3. + padding_kwargs : dict, optional + The arguments of `np.pad` after the first two arguments, `array` and + `pad_width`. If not specified the default will be `{"mode": "reflect"}`. See + `numpy.pad` docs: + https://numpy.org/doc/stable/reference/generated/numpy.pad.html. + + Yields + ------ + Generator[Tuple[np.ndarray, TileInformation], None, None] + Tile generator, yields the tile and additional information. + """ + if padding_kwargs is None: + padding_kwargs = {"mode": "reflect"} + + # Iterate over num samples (S) + for sample_idx in range(arr.shape[0]): + sample = arr[sample_idx, ...] + data_shape = np.array(sample.shape) + + # add padding to ensure evenly spaced & overlapping tiles. + spatial_padding = compute_padding(data_shape, tile_size, overlaps) + padding = ((0, 0), *spatial_padding) + sample = np.pad(sample, padding, **padding_kwargs) + + # The number of tiles in each dimension, should be of length 2 or 3 + tile_grid_shape = compute_tile_grid_shape(data_shape, tile_size, overlaps) + # itertools.product is equivalent of nested loops + + stitch_size = tile_size - overlaps + for tile_grid_coords in itertools.product(*[range(n) for n in tile_grid_shape]): + + # calculate crop coordinates + crop_coords_start = np.array(tile_grid_coords) * stitch_size + crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = ( + ..., + *[ + slice(coords, coords + extent) + for coords, extent in zip(crop_coords_start, tile_size) + ], + ) + tile = sample[crop_slices] + + tile_info = compute_tile_info( + np.array(tile_grid_coords), + np.array(data_shape), + np.array(tile_size), + np.array(overlaps), + sample_idx, + ) + # TODO: kinda weird this is a generator, + # -> doesn't really save memory ? Don't think there are any places the + # tiles are not exracted all at the same time. + # Although I guess it would make sense for a zarr tile extractor. + yield tile, tile_info + + +def compute_tile_info( + tile_grid_coords: NDArray[np.int_], + data_shape: NDArray[np.int_], + tile_size: NDArray[np.int_], + overlaps: NDArray[np.int_], + sample_id: int = 0, +) -> TileInformation: + """ + Compute the tile information for a tile with the coordinates `tile_grid_coords`. + + Parameters + ---------- + tile_grid_coords : 1D np.array of int + The coordinates of the tile within the tile grid, ((Z), Y, X), i.e. for 2D + tiling the coordinates for the second tile in the first row of tiles would be + (0, 1). + data_shape : 1D np.array of int + The shape of the data, should be (C, (Z), Y, X) where Z is optional. + tile_size : 1D np.array of int + Tile sizes in each dimension, of length 2 or 3. + overlaps : 1D np.array of int + Overlap values in each dimension, of length 2 or 3. + sample_id : int, default=0 + An ID to identify which sample a tile belongs to. + + Returns + ------- + TileInformation + Information that describes how to crop and stitch a tile to create a full image. + """ + spatial_dims_shape = data_shape[-len(tile_size) :] + + # The extent of the tile which will make up part of the stitched image. + stitch_size = tile_size - overlaps + stitch_coords_start = tile_grid_coords * stitch_size + stitch_coords_end = stitch_coords_start + stitch_size + + tile_coords_start = stitch_coords_start - overlaps // 2 + + # --- replace out of bounds indices + out_of_lower_bound = stitch_coords_start < 0 + out_of_upper_bound = stitch_coords_end > spatial_dims_shape + stitch_coords_start[out_of_lower_bound] = 0 + stitch_coords_end[out_of_upper_bound] = spatial_dims_shape[out_of_upper_bound] + + # --- calculate overlap crop coords + overlap_crop_coords_start = stitch_coords_start - tile_coords_start + overlap_crop_coords_end = overlap_crop_coords_start + ( + stitch_coords_end - stitch_coords_start + ) + + # --- combine start and end + stitch_coords = tuple( + (start, end) for start, end in zip(stitch_coords_start, stitch_coords_end) + ) + overlap_crop_coords = tuple( + (start, end) + for start, end in zip(overlap_crop_coords_start, overlap_crop_coords_end) + ) + + # --- Check if last tile + tile_grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps)) + last_tile = (tile_grid_coords == (tile_grid_shape - 1)).all() + + tile_info = TileInformation( + array_shape=data_shape, + last_tile=last_tile, + overlap_crop_coords=overlap_crop_coords, + stitch_coords=stitch_coords, + sample_id=sample_id, + ) + return tile_info + + +def compute_padding( + data_shape: NDArray[np.int_], + tile_size: NDArray[np.int_], + overlaps: NDArray[np.int_], +) -> tuple[tuple[int, int], ...]: + """ + Calculate padding to ensure stitched data comes from the center of a tile. + + Padding is added to an array with shape `data_shape` so that when tiles are + stitched together, the data used always comes from the center of a tile, even for + tiles at the boundaries of the array. + + Parameters + ---------- + data_shape : 1D numpy.array of int + The shape of the data to be tiled and stitched together, (S, C, (Z), Y, X). + tile_size : 1D numpy.array of int + The tile size in each dimension, ((Z), Y, X). + overlaps : 1D numpy.array of int + The tile overlap in each dimension, ((Z), Y, X). + + Returns + ------- + tuple of (int, int) + A tuple specifying the padding to add in each dimension, each element is a two + element tuple specifying the padding to add before and after the data. This + can be used as the `pad_width` argument to `numpy.pad`. + """ + tile_grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps)) + covered_shape = (tile_size - overlaps) * tile_grid_shape + overlaps + + pad_before = overlaps // 2 + pad_after = covered_shape - data_shape[-len(tile_size) :] - pad_before + + return tuple((before, after) for before, after in zip(pad_before, pad_after)) + + +def n_tiles_1d(axis_size: int, tile_size: int, overlap: int) -> int: + """Calculate the number of tiles in a specific dimension. + + Parameters + ---------- + axis_size : int + The length of the data for in a specific dimension. + tile_size : int + The length of the tiles in a specific dimension. + overlap : int + The tile overlap in a specific dimension. + + Returns + ------- + int + The number of tiles that fit in one dimension given the arguments. + """ + return int(np.ceil(axis_size / (tile_size - overlap))) + + +def total_n_tiles( + data_shape: tuple[int, ...], tile_size: tuple[int, ...], overlaps: tuple[int, ...] +) -> int: + """Calculate The total number of tiles over all dimensions. + + Parameters + ---------- + data_shape : 1D numpy.array of int + The shape of the data to be tiled and stitched together, (S, C, (Z), Y, X). + tile_size : 1D numpy.array of int + The tile size in each dimension, ((Z), Y, X). + overlaps : 1D numpy.array of int + The tile overlap in each dimension, ((Z), Y, X). + + + Returns + ------- + int + The total number of tiles over all dimensions. + """ + result = 1 + # assume spatial dimension are the last dimensions so iterate backwards + for i in range(-1, -len(tile_size) - 1, -1): + result = result * n_tiles_1d(data_shape[i], tile_size[i], overlaps[i]) + + return result + + +def compute_tile_grid_shape( + data_shape: NDArray[np.int_], + tile_size: NDArray[np.int_], + overlaps: NDArray[np.int_], +) -> tuple[int, ...]: + """Calculate the number of tiles in each dimension. + + This can be thought of as a grid of tiles. + + Parameters + ---------- + data_shape : 1D numpy.array of int + The shape of the data to be tiled and stitched together, (S, C, (Z), Y, X). + tile_size : 1D numpy.array of int + The tile size in each dimension, ((Z), Y, X). + overlaps : 1D numpy.array of int + The tile overlap in each dimension, ((Z), Y, X). + + Returns + ------- + tuple of int + The number of tiles in each direction, ((Z, Y, X)). + """ + shape = [0 for _ in range(len(tile_size))] + # assume spatial dimension are the last dimensions so iterate backwards + for i in range(-1, -len(tile_size) - 1, -1): + shape[i] = n_tiles_1d(data_shape[i], tile_size[i], overlaps[i]) + return tuple(shape) diff --git a/src/careamics/prediction_utils/lvae_tiling_manager.py b/src/careamics/prediction_utils/lvae_tiling_manager.py new file mode 100644 index 00000000..1fc7ac81 --- /dev/null +++ b/src/careamics/prediction_utils/lvae_tiling_manager.py @@ -0,0 +1,362 @@ +"""Module contiaing tiling manager class.""" + +# # TODO: remove this file, left as a reference for now. + +# from typing import Any, Optional + +# import numpy as np +# from numpy.typing import NDArray + +# from careamics.config.tile_information import TileInformation +# from careamics.config.validators import check_axes_validity + + +# def calculate_padding( +# patch_start_location: NDArray, +# patch_size: NDArray, +# data_shape: NDArray, +# ) -> NDArray: +# patch_end_location = patch_start_location + patch_size + +# pad_before = np.zeros_like(patch_start_location) +# start_out_of_bounds = patch_start_location < 0 +# pad_before[start_out_of_bounds] = -patch_start_location[start_out_of_bounds] + +# pad_after = np.zeros_like(patch_start_location) +# end_out_of_bounds = patch_end_location > data_shape +# pad_after[end_out_of_bounds] = ( +# patch_end_location - data_shape +# )[end_out_of_bounds] + +# return np.stack([pad_before, pad_after], axis=1) + + +# def extract_tile( +# img: np.ndarray, +# grid_start_loc: tuple[int, ...], +# patch_size: tuple[int, ...], +# overlap: tuple[int, ...], +# padding: bool, +# padding_kwargs: Optional[dict[str, Any]] = None, +# ) -> NDArray: +# if padding_kwargs is None: +# padding_kwargs = {} + +# data_shape = img.shape +# patch_start_loc = np.array(grid_start_loc) - np.array(overlap) // 2 +# crop_slices = tuple( +# slice(max(0, start), min(start + size, dim_shape)) +# for start, size, dim_shape in zip(patch_start_loc, patch_size, data_shape) +# ) +# crop = img[crop_slices] +# if padding: +# pad = calculate_padding( +# patch_start_location=patch_start_loc, +# patch_size=patch_size, +# data_shape=data_shape, +# ) +# crop = np.pad(crop, pad, **padding_kwargs) + +# return crop + + +# class TilingManager: + +# def __init__( +# self, +# data_shape: tuple[int, ...], +# tile_size: tuple[int, ...], +# overlaps: tuple[int, ...], +# trim_boundary: tuple[int, ...], +# ): +# # --- validation +# if len(data_shape) != len(tile_size): +# raise ValueError( +# f"Data shape:{data_shape} and tile size:{tile_size} must have the " +# "same dimension" +# ) +# if len(data_shape) != len(overlaps): +# raise ValueError( +# f"Data shape:{data_shape} and tile overlaps:{overlaps} must have the " +# "same dimension" +# ) +# # overlaps = np.array(tile_size) - np.array(grid_shape) +# if (np.array(overlaps) < 0).any(): +# raise ValueError( +# "Tile overlap must be positive or zero in all dimension." +# ) +# if ((np.array(overlaps) % 2) != 0).any(): +# # TODO: currently not required by CAREamics tiling, +# # -> because floor divide is used. +# raise ValueError("Tile overlaps must be even.") + +# # initialize attributes +# self.data_shape = data_shape +# self.overlaps = overlaps +# self.grid_shape = tuple(np.array(tile_size) - np.array(overlaps)) +# self.patch_shape = tile_size +# self.trim_boundary = trim_boundary + +# def compute_tile_info(self, index: int, axes: str): + +# # TODO: better axis validation, data should already be in the form SC(Z)YX + +# # validate axes +# check_axes_validity(axes) +# # z will be -1 if not present +# spatial_axes = [axes.find("Z"), axes.find("Y"), axes.find("X")] + +# # convert to numpy for convenience +# data_shape = np.array(self.data_shape) +# patch_shape = np.array(self.patch_shape) + +# # --- calculate stitch coords +# stitch_coords_start = np.array(self.get_location_from_dataset_idx(index)) +# stitch_coords_end = stitch_coords_start + np.array(self.grid_shape) + +# # --- patch coords +# patch_coords_start = stitch_coords_start - np.array(self.overlaps) // 2 +# patch_coords_end = patch_coords_start + patch_shape + +# # --- replace out of bounds indices + +# out_of_lower_bound = stitch_coords_start < 0 +# out_of_upper_bound = stitch_coords_end > data_shape + +# stitch_coords_start[out_of_lower_bound] = 0 +# stitch_coords_end[out_of_upper_bound] = data_shape[out_of_upper_bound] + +# # --- calculate overlap crop coords +# overlap_crop_coords_start = stitch_coords_start - patch_coords_start +# overlap_crop_coords_end = overlap_crop_coords_start + ( +# stitch_coords_end - stitch_coords_start +# ) + +# # --- combine start and end +# stitch_coords = tuple( +# (stitch_coords_start[axis], stitch_coords_end[axis]) +# for axis in spatial_axes +# if axis != -1 +# ) +# overlap_crop_coords = tuple( +# (overlap_crop_coords_start[axis], overlap_crop_coords_end[axis]) +# for axis in spatial_axes +# if axis != -1 +# ) + +# channel_axis = axes.find("C") +# array_shape_processed = tuple( +# data_shape[axis] for axis in [channel_axis, *spatial_axes] if axis != -1 +# ) + +# tile_info = TileInformation( +# array_shape=array_shape_processed, +# last_tile=index == self.total_grid_count() - 1, +# overlap_crop_coords=overlap_crop_coords, +# stitch_coords=stitch_coords, +# sample_id=0, # TODO: in iterable dataset this is also always 0 pretty sure +# ) +# return tile_info + +# def patch_offset(self): +# return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2 + +# def get_individual_dim_grid_count(self, dim: int): +# """ +# Returns the number of the grid in the specified dimension, ignoring all other +# dimensions. +# """ +# assert dim < len( +# self.data_shape +# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}" +# assert dim >= 0, "Dimension must be greater than or equal to 0" + +# if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1: +# return self.data_shape[dim] +# elif self.trim_boundary is False: +# return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim])) +# else: +# excess_size = self.patch_shape[dim] - self.grid_shape[dim] +# return int( +# np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim]) +# ) + +# def total_grid_count(self): +# """ +# Returns the total number of grids in the dataset. +# """ +# return self.grid_count(0) * self.get_individual_dim_grid_count(0) + +# def grid_count(self, dim: int): +# """ +# Returns the total number of grids for one value in the specified dimension. +# """ +# assert dim < len( +# self.data_shape +# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}" +# assert dim >= 0, "Dimension must be greater than or equal to 0" +# if dim == len(self.data_shape) - 1: +# return 1 + +# return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1) + +# def get_grid_index(self, dim: int, coordinate: int): +# """ +# Returns the index of the grid in the specified dimension. +# """ +# assert dim < len( +# self.data_shape +# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}" +# assert dim >= 0, "Dimension must be greater than or equal to 0" +# assert ( +# coordinate < self.data_shape[dim] +# ), ( +# f"Coordinate {coordinate} is out of bounds for data " +# f"shape {self.data_shape}" +# ) +# if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1: +# return coordinate +# elif self.trim_boundary is False: +# return np.floor(coordinate / self.grid_shape[dim]) +# else: +# excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2 +# # can be <0 if coordinate is in [0,grid_shape[dim]] +# return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim])) + +# def dataset_idx_from_grid_idx(self, grid_idx: tuple): +# """ +# Returns the index of the grid in the dataset. +# """ +# assert len(grid_idx) == len( +# self.data_shape +# ), ( +# f"Dimension indices {grid_idx} must have the same dimension as data " +# f"shape {self.data_shape}" +# ) +# index = 0 +# for dim in range(len(grid_idx)): +# index += grid_idx[dim] * self.grid_count(dim) +# return index + +# def get_patch_location_from_dataset_idx(self, dataset_idx: int): +# """ +# Returns the patch location of the grid in the dataset. +# """ +# location = self.get_location_from_dataset_idx(dataset_idx) +# offset = self.patch_offset() +# return tuple(np.array(location) - np.array(offset)) + +# def get_dataset_idx_from_grid_location(self, location: tuple): +# assert len(location) == len( +# self.data_shape +# ), ( +# f"Location {location} must have the same dimension as data shape " +# f"{self.data_shape}" +# ) +# grid_idx = [ +# self.get_grid_index(dim, location[dim]) for dim in range(len(location)) +# ] +# return self.dataset_idx_from_grid_idx(tuple(grid_idx)) + +# def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int): +# """ +# Returns the grid-start coordinate of the grid in the specified dimension. +# """ +# assert dim < len( +# self.data_shape +# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}" +# assert dim >= 0, "Dimension must be greater than or equal to 0" +# assert dim_index < self.get_individual_dim_grid_count( +# dim +# ), ( +# f"Dimension index {dim_index} is out of bounds for data shape " +# f"{self.data_shape}" +# ) + +# if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1: +# return dim_index +# elif self.trim_boundary is False: +# return dim_index * self.grid_shape[dim] +# else: +# excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2 +# return dim_index * self.grid_shape[dim] + excess_size + +# def get_location_from_dataset_idx(self, dataset_idx: int): +# grid_idx = [] +# for dim in range(len(self.data_shape)): +# grid_idx.append(dataset_idx // self.grid_count(dim)) +# dataset_idx = dataset_idx % self.grid_count(dim) +# location = [ +# self.get_gridstart_location_from_dim_index(dim, grid_idx[dim]) +# for dim in range(len(self.data_shape)) +# ] +# return tuple(location) + +# def on_boundary(self, dataset_idx: int, dim: int): +# """ +# Returns True if the grid is on the boundary in the specified dimension. +# """ +# assert dim < len( +# self.data_shape +# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}" +# assert dim >= 0, "Dimension must be greater than or equal to 0" + +# if dim > 0: +# dataset_idx = dataset_idx % self.grid_count(dim - 1) + +# dim_index = dataset_idx // self.grid_count(dim) +# return ( +# dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1 +# ) + +# def next_grid_along_dim(self, dataset_idx: int, dim: int): +# """ +# Returns the index of the grid in the specified dimension in the specified " +# "direction. +# """ +# assert dim < len( +# self.data_shape +# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}" +# assert dim >= 0, "Dimension must be greater than or equal to 0" +# new_idx = dataset_idx + self.grid_count(dim) +# if new_idx >= self.total_grid_count(): +# return None +# return new_idx + +# def prev_grid_along_dim(self, dataset_idx: int, dim: int): +# """ +# Returns the index of the grid in the specified dimension in the specified " +# "direction. +# """ +# assert dim < len( +# self.data_shape +# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}" +# assert dim >= 0, "Dimension must be greater than or equal to 0" +# new_idx = dataset_idx - self.grid_count(dim) +# if new_idx < 0: +# return None + + +# if __name__ == "__main__": +# data_shape = (1, 1, 103, 103, 2) +# grid_shape = (1, 1, 16, 16, 2) +# patch_shape = (1, 1, 32, 32, 2) +# overlap = tuple(np.array(patch_shape) - np.array(grid_shape)) + +# trim_boundary = False +# manager = TilingManager( +# data_shape=data_shape, +# tile_size=patch_shape, +# overlaps=overlap, +# trim_boundary=trim_boundary, +# ) +# gc = manager.total_grid_count() +# print("Grid count", gc) +# for i in range(gc): +# loc = manager.get_location_from_dataset_idx(i) +# print(i, loc) +# inferred_i = manager.get_dataset_idx_from_grid_location(loc) +# assert i == inferred_i, f"Index mismatch: {i} != {inferred_i}" + +# for i in range(5): +# print(manager.on_boundary(40, i)) diff --git a/tests/dataset/tiling/test_lvae_tiled_patching.py b/tests/dataset/tiling/test_lvae_tiled_patching.py new file mode 100644 index 00000000..05e611bd --- /dev/null +++ b/tests/dataset/tiling/test_lvae_tiled_patching.py @@ -0,0 +1,188 @@ +import numpy as np +import pytest + +from careamics.config.tile_information import TileInformation +from careamics.dataset.tiling.lvae_tiled_patching import ( + compute_padding, + compute_tile_grid_shape, + compute_tile_info, + extract_tiles, + n_tiles_1d, + total_n_tiles, +) +from careamics.prediction_utils.stitch_prediction import stitch_prediction + + +@pytest.mark.parametrize( + "data_shape, tile_size, overlaps", + [ + # 2D + ((1, 3, 10, 9), (4, 4), (2, 2)), + ((1, 3, 10, 9), (8, 8), (4, 4)), + # 3D + ((1, 3, 8, 16, 17), (4, 4, 4), (2, 2, 2)), + ((1, 3, 8, 16, 17), (8, 8, 8), (4, 4, 4)), + ], +) +def test_extract_tiles(data_shape, tile_size, overlaps): + """Test extracted tiles are all the same size and can reconstruct the image.""" + + arr = np.random.random_sample(data_shape).astype(np.float32) + + tile_data_generator = extract_tiles( + arr=arr, tile_size=np.array(tile_size), overlaps=np.array(overlaps) + ) + + tiles = [] + tile_infos = [] + + # Assemble all tiles and their respective coordinates + for tile, tile_info in tile_data_generator: + + overlap_crop_coords = tile_info.overlap_crop_coords + stitch_coords = tile_info.stitch_coords + + # add data to lists + tiles.append(tile) + tile_infos.append(tile_info) + + # check tile shape, ignore channel dimension + assert tile.shape[1:] == tile_size + assert len(overlap_crop_coords) == len(stitch_coords) == len(tile_size) + + # stitch_prediction returns list + stitched_arr = stitch_prediction(tiles, tile_infos)[0] + + np.testing.assert_array_equal(arr, stitched_arr) + + +def test_compute_tile_info(): + """Test `compute_tile_info` for a selection of known results.""" + + # TODO: improve this test ? + + data_shape = np.array([1, 3, 10, 9]) + tile_size = np.array([4, 4]) + overlaps = np.array([2, 2]) + + # first example + tile_info = compute_tile_info((0, 0), data_shape[1:], tile_size, overlaps) + assert tile_info == TileInformation( + array_shape=tuple(data_shape[1:]), + last_tile=False, + overlap_crop_coords=((1, 3), (1, 3)), + stitch_coords=((0, 2), (0, 2)), + sample_id=0, + ) + + # second example + tile_info = compute_tile_info((2, 2), data_shape[1:], tile_size, overlaps) + assert tile_info == TileInformation( + array_shape=tuple(data_shape[1:]), + last_tile=False, + overlap_crop_coords=((1, 3), (1, 3)), + stitch_coords=((4, 6), (4, 6)), + sample_id=0, + ) + + # third example + tile_info = compute_tile_info((2, 4), data_shape[1:], tile_size, overlaps) + assert tile_info == TileInformation( + array_shape=tuple(data_shape[1:]), + last_tile=False, + overlap_crop_coords=((1, 3), (1, 2)), + stitch_coords=((4, 6), (8, 9)), + sample_id=0, + ) + + # fourth example + tile_info = compute_tile_info((4, 4), data_shape[1:], tile_size, overlaps) + assert tile_info == TileInformation( + array_shape=tuple(data_shape[1:]), + last_tile=True, + overlap_crop_coords=((1, 3), (1, 2)), + stitch_coords=((8, 10), (8, 9)), + sample_id=0, + ) + + +@pytest.mark.parametrize( + "data_shape, tile_size, overlaps", + [ + # 2D + ((1, 3, 10, 9), (4, 4), (2, 2)), + ((1, 3, 10, 9), (8, 8), (4, 4)), + # 3D + ((1, 3, 8, 16, 17), (4, 4, 4), (2, 2, 2)), + ((1, 3, 8, 16, 17), (8, 8, 8), (4, 4, 4)), + ], +) +def test_compute_padding(data_shape, tile_size, overlaps): + + padding = compute_padding( + np.array(data_shape), np.array(tile_size), np.array(overlaps) + ) + + for axis, (before, after) in enumerate(padding): + # padded array should be divisible by the stitch size + stitch_size = tile_size[axis] - overlaps[axis] + axis_size = data_shape[axis + 2] # + 2 for sample and channel dims + assert (before + axis_size + after) % stitch_size == 0 + + assert before == overlaps[axis] // 2 + + +@pytest.mark.parametrize( + "axis_size, tile_size, overlap", + [(9, 4, 2), (10, 8, 4), (17, 8, 4)], +) +def test_n_tiles_1d(axis_size, tile_size, overlap): + """Test calculating the number of tiles in a specific dimension.""" + result = n_tiles_1d(axis_size, tile_size, overlap) + assert result == int(np.ceil(axis_size / (tile_size - overlap))) + + +@pytest.mark.parametrize( + "data_shape, tile_size, overlaps", + [ + # 2D + ((1, 3, 10, 9), (4, 4), (2, 2)), + ((1, 3, 10, 9), (8, 8), (4, 4)), + # 3D + ((1, 3, 8, 16, 17), (4, 4, 4), (2, 2, 2)), + ((1, 3, 8, 16, 17), (8, 8, 8), (4, 4, 4)), + ], +) +def test_total_n_tiles(data_shape, tile_size, overlaps): + """Test calculating the total number of tiles.""" + + result = total_n_tiles(data_shape, tile_size, overlaps) + n_tiles = 1 + for i in range(-1, -len(tile_size) - 1, -1): + n_tiles = n_tiles * int(np.ceil(data_shape[i] / (tile_size[i] - overlaps[i]))) + + assert result == n_tiles + + +@pytest.mark.parametrize( + "data_shape, tile_size, overlaps", + [ + # 2D + ((1, 3, 10, 9), (4, 4), (2, 2)), + ((1, 3, 10, 9), (8, 8), (4, 4)), + # 3D + ((1, 3, 8, 16, 17), (4, 4, 4), (2, 2, 2)), + ((1, 3, 8, 16, 17), (8, 8, 8), (4, 4, 4)), + ], +) +def test_compute_tile_grid_shape(data_shape, tile_size, overlaps): + """Test computing tile grid shape.""" + + result = compute_tile_grid_shape(data_shape, tile_size, overlaps) + + tile_grid_shape = tuple( + int(np.ceil(data_shape[i] / (tile_size[i] - overlaps[i]))) + for i in range(-len(tile_size), 0, 1) + ) + + assert result == tile_grid_shape