# ---------------------------------------------------------------------------
# New file: src/careamics/dataset/tiling/lvae_tiled_patching.py
# ---------------------------------------------------------------------------
"""Functions to reimplement the tiling in the Disentangle repository."""

import builtins
import itertools
from typing import Any, Generator, Optional, Union

import numpy as np
from numpy.typing import NDArray

from careamics.config.tile_information import TileInformation


def extract_tiles(
    arr: NDArray,
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
    padding_kwargs: Optional[dict[str, Any]] = None,
) -> Generator[tuple[NDArray, TileInformation], None, None]:
    """Generate tiles from the input array with specified overlap.

    The tiles cover the whole array, which is additionally padded to ensure that
    the section of the tile that contributes to the final image comes from the
    center of the tile.

    The method returns a generator that yields tuples of array and tile information,
    the latter includes whether the tile is the last one, the coordinates of the
    overlap crop, and the coordinates of the stitched tile.

    Input array should have shape SC(Z)YX, while the returned tiles have shape C(Z)YX,
    where C can be a singleton.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (S, C, (Z), Y, X).
    tile_size : 1D numpy.ndarray of int
        Tile sizes in each dimension, of length 2 or 3.
    overlaps : 1D numpy.ndarray of int
        Overlap values in each dimension, of length 2 or 3.
    padding_kwargs : dict, optional
        The arguments of `np.pad` after the first two arguments, `array` and
        `pad_width`. If not specified the default will be `{"mode": "reflect"}`. See
        `numpy.pad` docs:
        https://numpy.org/doc/stable/reference/generated/numpy.pad.html.

    Yields
    ------
    Generator[Tuple[np.ndarray, TileInformation], None, None]
        Tile generator, yields the tile and additional information.
    """
    if padding_kwargs is None:
        padding_kwargs = {"mode": "reflect"}

    # Any integer sequence is accepted; the arithmetic below needs arrays.
    tile_size = np.asarray(tile_size)
    overlaps = np.asarray(overlaps)

    # The step between neighbouring tiles; identical for every sample.
    stitch_size = tile_size - overlaps

    # Iterate over num samples (S)
    for sample_idx in range(arr.shape[0]):
        sample = arr[sample_idx, ...]
        data_shape = np.array(sample.shape)

        # add padding to ensure evenly spaced & overlapping tiles.
        spatial_padding = compute_padding(data_shape, tile_size, overlaps)
        # a single (0, 0) entry: exactly one leading non-spatial axis (C) is
        # left unpadded.
        padding = ((0, 0), *spatial_padding)
        sample = np.pad(sample, padding, **padding_kwargs)

        # The number of tiles in each dimension, should be of length 2 or 3
        tile_grid_shape = compute_tile_grid_shape(data_shape, tile_size, overlaps)

        # itertools.product is equivalent of nested loops
        for tile_grid_coords in itertools.product(*[range(n) for n in tile_grid_shape]):

            # calculate crop coordinates (in the padded sample)
            crop_coords_start = np.array(tile_grid_coords) * stitch_size
            crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
                ...,
                *[
                    slice(coords, coords + extent)
                    for coords, extent in zip(crop_coords_start, tile_size)
                ],
            )
            tile = sample[crop_slices]

            # tile info is expressed relative to the unpadded sample shape
            tile_info = compute_tile_info(
                np.array(tile_grid_coords),
                data_shape,
                tile_size,
                overlaps,
                sample_idx,
            )
            # TODO: kinda weird this is a generator,
            #   -> doesn't really save memory ? Don't think there are any places the
            #   tiles are not extracted all at the same time.
            #   Although I guess it would make sense for a zarr tile extractor.
            yield tile, tile_info


def compute_tile_info(
    tile_grid_coords: NDArray[np.int_],
    data_shape: NDArray[np.int_],
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
    sample_id: int = 0,
) -> TileInformation:
    """
    Compute the tile information for a tile with the coordinates `tile_grid_coords`.

    Parameters
    ----------
    tile_grid_coords : 1D np.array of int
        The coordinates of the tile within the tile grid, ((Z), Y, X), i.e. for 2D
        tiling the coordinates for the second tile in the first row of tiles would be
        (0, 1).
    data_shape : 1D np.array of int
        The shape of the data, should be (C, (Z), Y, X) where Z is optional.
    tile_size : 1D np.array of int
        Tile sizes in each dimension, of length 2 or 3.
    overlaps : 1D np.array of int
        Overlap values in each dimension, of length 2 or 3.
    sample_id : int, default=0
        An ID to identify which sample a tile belongs to.

    Returns
    -------
    TileInformation
        Information that describes how to crop and stitch a tile to create a full image.
    """
    # Plain sequences (e.g. tuples) are accepted as well as arrays.
    tile_grid_coords = np.asarray(tile_grid_coords)
    tile_size = np.asarray(tile_size)
    overlaps = np.asarray(overlaps)
    spatial_dims_shape = np.asarray(data_shape)[-len(tile_size) :]

    # The extent of the tile which will make up part of the stitched image.
    stitch_size = tile_size - overlaps
    stitch_coords_start = tile_grid_coords * stitch_size
    stitch_coords_end = stitch_coords_start + stitch_size

    # Tile origin in data coordinates; computed before clipping, so it can be
    # negative for tiles that start in the padded border.
    tile_coords_start = stitch_coords_start - overlaps // 2

    # --- replace out of bounds indices
    stitch_coords_start = np.maximum(stitch_coords_start, 0)
    stitch_coords_end = np.minimum(stitch_coords_end, spatial_dims_shape)

    # --- calculate overlap crop coords
    overlap_crop_coords_start = stitch_coords_start - tile_coords_start
    overlap_crop_coords_end = overlap_crop_coords_start + (
        stitch_coords_end - stitch_coords_start
    )

    # --- combine start and end (as builtin ints rather than NumPy scalars)
    stitch_coords = tuple(
        (int(start), int(end))
        for start, end in zip(stitch_coords_start, stitch_coords_end)
    )
    overlap_crop_coords = tuple(
        (int(start), int(end))
        for start, end in zip(overlap_crop_coords_start, overlap_crop_coords_end)
    )

    # --- Check if last tile
    tile_grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps))
    last_tile = bool((tile_grid_coords == (tile_grid_shape - 1)).all())

    tile_info = TileInformation(
        array_shape=data_shape,
        last_tile=last_tile,
        overlap_crop_coords=overlap_crop_coords,
        stitch_coords=stitch_coords,
        sample_id=sample_id,
    )
    return tile_info


def compute_padding(
    data_shape: NDArray[np.int_],
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
) -> tuple[tuple[int, int], ...]:
    """
    Calculate padding to ensure stitched data comes from the center of a tile.

    Padding is added to an array with shape `data_shape` so that when tiles are
    stitched together, the data used always comes from the center of a tile, even for
    tiles at the boundaries of the array.

    Parameters
    ----------
    data_shape : 1D numpy.array of int
        The shape of the data to be tiled and stitched together. Only the trailing
        `len(tile_size)` entries, the spatial axes ((Z), Y, X), are read, so any
        leading axes (e.g. S and/or C) are ignored.
    tile_size : 1D numpy.array of int
        The tile size in each dimension, ((Z), Y, X).
    overlaps : 1D numpy.array of int
        The tile overlap in each dimension, ((Z), Y, X).

    Returns
    -------
    tuple of (int, int)
        A tuple specifying the padding to add in each spatial dimension, each element
        is a two element tuple specifying the padding to add before and after the
        data. Prepended with one `(0, 0)` entry per non-spatial axis it can be used
        as the `pad_width` argument to `numpy.pad`.
    """
    tile_size = np.asarray(tile_size)
    overlaps = np.asarray(overlaps)
    spatial_dims_shape = np.asarray(data_shape)[-len(tile_size) :]

    tile_grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps))
    # Total extent spanned by the full grid of overlapping tiles.
    covered_shape = (tile_size - overlaps) * tile_grid_shape + overlaps

    pad_before = overlaps // 2
    pad_after = covered_shape - spatial_dims_shape - pad_before

    return tuple(
        (int(before), int(after)) for before, after in zip(pad_before, pad_after)
    )


def n_tiles_1d(axis_size: int, tile_size: int, overlap: int) -> int:
    """Calculate the number of tiles in a specific dimension.

    Parameters
    ----------
    axis_size : int
        The length of the data in a specific dimension.
    tile_size : int
        The length of the tiles in a specific dimension.
    overlap : int
        The tile overlap in a specific dimension.

    Returns
    -------
    int
        The number of tiles that fit in one dimension given the arguments.

    Raises
    ------
    ValueError
        If `tile_size` is not strictly greater than `overlap`.
    """
    stitch_size = tile_size - overlap
    if stitch_size <= 0:
        # A non-positive step would divide by zero or give a negative count.
        raise ValueError(
            f"Tile size ({tile_size}) must be greater than the overlap ({overlap})."
        )
    return int(np.ceil(axis_size / stitch_size))


def total_n_tiles(
    data_shape: tuple[int, ...], tile_size: tuple[int, ...], overlaps: tuple[int, ...]
) -> int:
    """Calculate the total number of tiles over all dimensions.

    Parameters
    ----------
    data_shape : tuple of int
        The shape of the data to be tiled and stitched together. Only the trailing
        `len(tile_size)` entries, the spatial axes ((Z), Y, X), are read.
    tile_size : tuple of int
        The tile size in each dimension, ((Z), Y, X).
    overlaps : tuple of int
        The tile overlap in each dimension, ((Z), Y, X).

    Returns
    -------
    int
        The total number of tiles over all dimensions (for a single sample).
    """
    result = 1
    for n_tiles in compute_tile_grid_shape(data_shape, tile_size, overlaps):
        result = result * n_tiles

    return result


def compute_tile_grid_shape(
    data_shape: NDArray[np.int_],
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
) -> tuple[int, ...]:
    """Calculate the number of tiles in each dimension.

    This can be thought of as a grid of tiles.

    Parameters
    ----------
    data_shape : 1D numpy.array of int
        The shape of the data to be tiled and stitched together. Only the trailing
        `len(tile_size)` entries, the spatial axes ((Z), Y, X), are read.
    tile_size : 1D numpy.array of int
        The tile size in each dimension, ((Z), Y, X).
    overlaps : 1D numpy.array of int
        The tile overlap in each dimension, ((Z), Y, X).

    Returns
    -------
    tuple of int
        The number of tiles in each direction, ((Z), Y, X).
    """
    # assume spatial dimension are the last dimensions, so index from the end
    return tuple(
        n_tiles_1d(data_shape[i], tile_size[i], overlaps[i])
        for i in range(-len(tile_size), 0)
    )


# ---------------------------------------------------------------------------
# New file: src/careamics/prediction_utils/lvae_tiling_manager.py
# ---------------------------------------------------------------------------
"""Module containing tiling manager class."""

# # TODO: remove this file, left as a reference for now.

# from typing import Any, Optional

# import numpy as np
# from numpy.typing import NDArray

# from careamics.config.tile_information import TileInformation
# from careamics.config.validators import check_axes_validity


# def calculate_padding(
#     patch_start_location: NDArray,
#     patch_size: NDArray,
#     data_shape: NDArray,
# ) -> NDArray:
#     patch_end_location = patch_start_location + patch_size

#     pad_before = np.zeros_like(patch_start_location)
#     start_out_of_bounds = patch_start_location < 0
#     pad_before[start_out_of_bounds] = -patch_start_location[start_out_of_bounds]

#     pad_after = np.zeros_like(patch_start_location)
#     end_out_of_bounds = patch_end_location > data_shape
#     pad_after[end_out_of_bounds] = (
#         patch_end_location - data_shape
#     )[end_out_of_bounds]

#     return np.stack([pad_before, pad_after], axis=1)


# def extract_tile(
#     img: np.ndarray,
#     grid_start_loc: tuple[int, ...],
#     patch_size: tuple[int, ...],
#     overlap: tuple[int, ...],
#     padding: bool,
#     padding_kwargs: Optional[dict[str, Any]] = None,
# ) -> NDArray:
#     if padding_kwargs is None:
#         padding_kwargs = {}

#     data_shape = img.shape
#     patch_start_loc = np.array(grid_start_loc) - np.array(overlap) // 2
#     crop_slices = tuple(
#         slice(max(0, start), min(start + size, dim_shape))
#         for start, size, dim_shape in zip(patch_start_loc,
patch_size, data_shape) +# ) +# crop = img[crop_slices] +# if padding: +# pad = calculate_padding( +# patch_start_location=patch_start_loc, +# patch_size=patch_size, +# data_shape=data_shape, +# ) +# crop = np.pad(crop, pad, **padding_kwargs) + +# return crop + + +# class TilingManager: + +# def __init__( +# self, +# data_shape: tuple[int, ...], +# tile_size: tuple[int, ...], +# overlaps: tuple[int, ...], +# trim_boundary: tuple[int, ...], +# ): +# # --- validation +# if len(data_shape) != len(tile_size): +# raise ValueError( +# f"Data shape:{data_shape} and tile size:{tile_size} must have the " +# "same dimension" +# ) +# if len(data_shape) != len(overlaps): +# raise ValueError( +# f"Data shape:{data_shape} and tile overlaps:{overlaps} must have the " +# "same dimension" +# ) +# # overlaps = np.array(tile_size) - np.array(grid_shape) +# if (np.array(overlaps) < 0).any(): +# raise ValueError( +# "Tile overlap must be positive or zero in all dimension." +# ) +# if ((np.array(overlaps) % 2) != 0).any(): +# # TODO: currently not required by CAREamics tiling, +# # -> because floor divide is used. 
+# raise ValueError("Tile overlaps must be even.") + +# # initialize attributes +# self.data_shape = data_shape +# self.overlaps = overlaps +# self.grid_shape = tuple(np.array(tile_size) - np.array(overlaps)) +# self.patch_shape = tile_size +# self.trim_boundary = trim_boundary + +# def compute_tile_info(self, index: int, axes: str): + +# # TODO: better axis validation, data should already be in the form SC(Z)YX + +# # validate axes +# check_axes_validity(axes) +# # z will be -1 if not present +# spatial_axes = [axes.find("Z"), axes.find("Y"), axes.find("X")] + +# # convert to numpy for convenience +# data_shape = np.array(self.data_shape) +# patch_shape = np.array(self.patch_shape) + +# # --- calculate stitch coords +# stitch_coords_start = np.array(self.get_location_from_dataset_idx(index)) +# stitch_coords_end = stitch_coords_start + np.array(self.grid_shape) + +# # --- patch coords +# patch_coords_start = stitch_coords_start - np.array(self.overlaps) // 2 +# patch_coords_end = patch_coords_start + patch_shape + +# # --- replace out of bounds indices + +# out_of_lower_bound = stitch_coords_start < 0 +# out_of_upper_bound = stitch_coords_end > data_shape + +# stitch_coords_start[out_of_lower_bound] = 0 +# stitch_coords_end[out_of_upper_bound] = data_shape[out_of_upper_bound] + +# # --- calculate overlap crop coords +# overlap_crop_coords_start = stitch_coords_start - patch_coords_start +# overlap_crop_coords_end = overlap_crop_coords_start + ( +# stitch_coords_end - stitch_coords_start +# ) + +# # --- combine start and end +# stitch_coords = tuple( +# (stitch_coords_start[axis], stitch_coords_end[axis]) +# for axis in spatial_axes +# if axis != -1 +# ) +# overlap_crop_coords = tuple( +# (overlap_crop_coords_start[axis], overlap_crop_coords_end[axis]) +# for axis in spatial_axes +# if axis != -1 +# ) + +# channel_axis = axes.find("C") +# array_shape_processed = tuple( +# data_shape[axis] for axis in [channel_axis, *spatial_axes] if axis != -1 +# ) + +# tile_info = 
TileInformation( +# array_shape=array_shape_processed, +# last_tile=index == self.total_grid_count() - 1, +# overlap_crop_coords=overlap_crop_coords, +# stitch_coords=stitch_coords, +# sample_id=0, # TODO: in iterable dataset this is also always 0 pretty sure +# ) +# return tile_info + +# def patch_offset(self): +# return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2 + +# def get_individual_dim_grid_count(self, dim: int): +# """ +# Returns the number of the grid in the specified dimension, ignoring all other +# dimensions. +# """ +# assert dim < len( +# self.data_shape +# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}" +# assert dim >= 0, "Dimension must be greater than or equal to 0" + +# if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1: +# return self.data_shape[dim] +# elif self.trim_boundary is False: +# return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim])) +# else: +# excess_size = self.patch_shape[dim] - self.grid_shape[dim] +# return int( +# np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim]) +# ) + +# def total_grid_count(self): +# """ +# Returns the total number of grids in the dataset. +# """ +# return self.grid_count(0) * self.get_individual_dim_grid_count(0) + +# def grid_count(self, dim: int): +# """ +# Returns the total number of grids for one value in the specified dimension. +# """ +# assert dim < len( +# self.data_shape +# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}" +# assert dim >= 0, "Dimension must be greater than or equal to 0" +# if dim == len(self.data_shape) - 1: +# return 1 + +# return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1) + +# def get_grid_index(self, dim: int, coordinate: int): +# """ +# Returns the index of the grid in the specified dimension. 
+# """ +# assert dim < len( +# self.data_shape +# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}" +# assert dim >= 0, "Dimension must be greater than or equal to 0" +# assert ( +# coordinate < self.data_shape[dim] +# ), ( +# f"Coordinate {coordinate} is out of bounds for data " +# f"shape {self.data_shape}" +# ) +# if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1: +# return coordinate +# elif self.trim_boundary is False: +# return np.floor(coordinate / self.grid_shape[dim]) +# else: +# excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2 +# # can be <0 if coordinate is in [0,grid_shape[dim]] +# return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim])) + +# def dataset_idx_from_grid_idx(self, grid_idx: tuple): +# """ +# Returns the index of the grid in the dataset. +# """ +# assert len(grid_idx) == len( +# self.data_shape +# ), ( +# f"Dimension indices {grid_idx} must have the same dimension as data " +# f"shape {self.data_shape}" +# ) +# index = 0 +# for dim in range(len(grid_idx)): +# index += grid_idx[dim] * self.grid_count(dim) +# return index + +# def get_patch_location_from_dataset_idx(self, dataset_idx: int): +# """ +# Returns the patch location of the grid in the dataset. +# """ +# location = self.get_location_from_dataset_idx(dataset_idx) +# offset = self.patch_offset() +# return tuple(np.array(location) - np.array(offset)) + +# def get_dataset_idx_from_grid_location(self, location: tuple): +# assert len(location) == len( +# self.data_shape +# ), ( +# f"Location {location} must have the same dimension as data shape " +# f"{self.data_shape}" +# ) +# grid_idx = [ +# self.get_grid_index(dim, location[dim]) for dim in range(len(location)) +# ] +# return self.dataset_idx_from_grid_idx(tuple(grid_idx)) + +# def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int): +# """ +# Returns the grid-start coordinate of the grid in the specified dimension. 
+# """ +# assert dim < len( +# self.data_shape +# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}" +# assert dim >= 0, "Dimension must be greater than or equal to 0" +# assert dim_index < self.get_individual_dim_grid_count( +# dim +# ), ( +# f"Dimension index {dim_index} is out of bounds for data shape " +# f"{self.data_shape}" +# ) + +# if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1: +# return dim_index +# elif self.trim_boundary is False: +# return dim_index * self.grid_shape[dim] +# else: +# excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2 +# return dim_index * self.grid_shape[dim] + excess_size + +# def get_location_from_dataset_idx(self, dataset_idx: int): +# grid_idx = [] +# for dim in range(len(self.data_shape)): +# grid_idx.append(dataset_idx // self.grid_count(dim)) +# dataset_idx = dataset_idx % self.grid_count(dim) +# location = [ +# self.get_gridstart_location_from_dim_index(dim, grid_idx[dim]) +# for dim in range(len(self.data_shape)) +# ] +# return tuple(location) + +# def on_boundary(self, dataset_idx: int, dim: int): +# """ +# Returns True if the grid is on the boundary in the specified dimension. +# """ +# assert dim < len( +# self.data_shape +# ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}" +# assert dim >= 0, "Dimension must be greater than or equal to 0" + +# if dim > 0: +# dataset_idx = dataset_idx % self.grid_count(dim - 1) + +# dim_index = dataset_idx // self.grid_count(dim) +# return ( +# dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1 +# ) + +# def next_grid_along_dim(self, dataset_idx: int, dim: int): +# """ +# Returns the index of the grid in the specified dimension in the specified " +# "direction. 
#         """
#         assert dim < len(
#             self.data_shape
#         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
#         assert dim >= 0, "Dimension must be greater than or equal to 0"
#         new_idx = dataset_idx + self.grid_count(dim)
#         if new_idx >= self.total_grid_count():
#             return None
#         return new_idx

#     def prev_grid_along_dim(self, dataset_idx: int, dim: int):
#         """
#         Returns the index of the grid in the specified dimension in the specified "
#         "direction.
#         """
#         assert dim < len(
#             self.data_shape
#         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
#         assert dim >= 0, "Dimension must be greater than or equal to 0"
#         new_idx = dataset_idx - self.grid_count(dim)
#         if new_idx < 0:
#             return None


# if __name__ == "__main__":
#     data_shape = (1, 1, 103, 103, 2)
#     grid_shape = (1, 1, 16, 16, 2)
#     patch_shape = (1, 1, 32, 32, 2)
#     overlap = tuple(np.array(patch_shape) - np.array(grid_shape))

#     trim_boundary = False
#     manager = TilingManager(
#         data_shape=data_shape,
#         tile_size=patch_shape,
#         overlaps=overlap,
#         trim_boundary=trim_boundary,
#     )
#     gc = manager.total_grid_count()
#     print("Grid count", gc)
#     for i in range(gc):
#         loc = manager.get_location_from_dataset_idx(i)
#         print(i, loc)
#         inferred_i = manager.get_dataset_idx_from_grid_location(loc)
#         assert i == inferred_i, f"Index mismatch: {i} != {inferred_i}"

#     for i in range(5):
#         print(manager.on_boundary(40, i))

# ---------------------------------------------------------------------------
# New file: tests/dataset/tiling/test_lvae_tiled_patching.py
# ---------------------------------------------------------------------------
import numpy as np
import pytest

from careamics.config.tile_information import TileInformation
from careamics.dataset.tiling.lvae_tiled_patching import (
    compute_padding,
    compute_tile_grid_shape,
    compute_tile_info,
    extract_tiles,
    n_tiles_1d,
    total_n_tiles,
)
from careamics.prediction_utils.stitch_prediction import stitch_prediction


@pytest.mark.parametrize(
    "data_shape, tile_size, overlaps",
    [
        # 2D
        ((1, 3, 10, 9), (4, 4), (2, 2)),
        ((1, 3, 10, 9), (8, 8), (4, 4)),
        # 3D
        ((1, 3, 8, 16, 17), (4, 4, 4), (2, 2, 2)),
        ((1, 3, 8, 16, 17), (8, 8, 8), (4, 4, 4)),
    ],
)
def test_extract_tiles(data_shape, tile_size, overlaps):
    """Test extracted tiles are all the same size and can reconstruct the image."""
    # Seeded generator so that a failure can be reproduced exactly.
    rng = np.random.default_rng(seed=42)
    arr = rng.random(data_shape, dtype=np.float32)

    tile_data_generator = extract_tiles(
        arr=arr, tile_size=np.array(tile_size), overlaps=np.array(overlaps)
    )

    tiles = []
    tile_infos = []

    # Assemble all tiles and their respective coordinates
    for tile, tile_info in tile_data_generator:

        overlap_crop_coords = tile_info.overlap_crop_coords
        stitch_coords = tile_info.stitch_coords

        # add data to lists
        tiles.append(tile)
        tile_infos.append(tile_info)

        # check tile shape, ignore channel dimension
        assert tile.shape[1:] == tile_size
        assert len(overlap_crop_coords) == len(stitch_coords) == len(tile_size)

    # one full grid of tiles is produced per sample
    n_samples = data_shape[0]
    assert len(tiles) == n_samples * total_n_tiles(data_shape, tile_size, overlaps)

    # stitch_prediction returns list
    stitched_arr = stitch_prediction(tiles, tile_infos)[0]

    np.testing.assert_array_equal(arr, stitched_arr)


@pytest.mark.parametrize(
    "tile_grid_coords, last_tile, overlap_crop_coords, stitch_coords",
    [
        # first tile of the grid
        ((0, 0), False, ((1, 3), (1, 3)), ((0, 2), (0, 2))),
        # tile in the middle of the grid
        ((2, 2), False, ((1, 3), (1, 3)), ((4, 6), (4, 6))),
        # tile at the end of the X axis, stitched region is clipped to the data
        ((2, 4), False, ((1, 3), (1, 2)), ((4, 6), (8, 9))),
        # final tile of the grid
        ((4, 4), True, ((1, 3), (1, 2)), ((8, 10), (8, 9))),
    ],
)
def test_compute_tile_info(
    tile_grid_coords, last_tile, overlap_crop_coords, stitch_coords
):
    """Test `compute_tile_info` for a selection of known results."""
    data_shape = np.array([1, 3, 10, 9])
    tile_size = np.array([4, 4])
    overlaps = np.array([2, 2])

    # the sample dimension is dropped, `compute_tile_info` expects C(Z)YX
    tile_info = compute_tile_info(
        np.array(tile_grid_coords), data_shape[1:], tile_size, overlaps
    )
    assert tile_info == TileInformation(
        array_shape=tuple(data_shape[1:]),
        last_tile=last_tile,
        overlap_crop_coords=overlap_crop_coords,
        stitch_coords=stitch_coords,
        sample_id=0,
    )


@pytest.mark.parametrize(
    "data_shape, tile_size, overlaps",
    [
        # 2D
        ((1, 3, 10, 9), (4, 4), (2, 2)),
        ((1, 3, 10, 9), (8, 8), (4, 4)),
        # 3D
        ((1, 3, 8, 16, 17), (4, 4, 4), (2, 2, 2)),
        ((1, 3, 8, 16, 17), (8, 8, 8), (4, 4, 4)),
    ],
)
def test_compute_padding(data_shape, tile_size, overlaps):
    """Test the padded axes are a whole number of strides, padded by half an overlap."""
    padding = compute_padding(
        np.array(data_shape), np.array(tile_size), np.array(overlaps)
    )

    for axis, (before, after) in enumerate(padding):
        # padded array should be divisible by the stitch size
        stitch_size = tile_size[axis] - overlaps[axis]
        axis_size = data_shape[axis + 2]  # + 2 for sample and channel dims
        assert (before + axis_size + after) % stitch_size == 0

        assert before == overlaps[axis] // 2
@pytest.mark.parametrize(
    "axis_size, tile_size, overlap, expected",
    [(9, 4, 2, 5), (10, 8, 4, 3), (17, 8, 4, 5)],
)
def test_n_tiles_1d(axis_size, tile_size, overlap, expected):
    """Test calculating the number of tiles in a specific dimension."""
    # Expected values are written out rather than recomputed with the same
    # formula as the implementation, which would pass for any formula.
    assert n_tiles_1d(axis_size, tile_size, overlap) == expected


@pytest.mark.parametrize(
    "data_shape, tile_size, overlaps, expected",
    [
        # 2D
        ((1, 3, 10, 9), (4, 4), (2, 2), 5 * 5),
        ((1, 3, 10, 9), (8, 8), (4, 4), 3 * 3),
        # 3D
        ((1, 3, 8, 16, 17), (4, 4, 4), (2, 2, 2), 4 * 8 * 9),
        ((1, 3, 8, 16, 17), (8, 8, 8), (4, 4, 4), 2 * 4 * 5),
    ],
)
def test_total_n_tiles(data_shape, tile_size, overlaps, expected):
    """Test calculating the total number of tiles."""
    # `expected` is the product of the per-axis tile counts over (Z)YX.
    assert total_n_tiles(data_shape, tile_size, overlaps) == expected


@pytest.mark.parametrize(
    "data_shape, tile_size, overlaps, expected",
    [
        # 2D
        ((1, 3, 10, 9), (4, 4), (2, 2), (5, 5)),
        ((1, 3, 10, 9), (8, 8), (4, 4), (3, 3)),
        # 3D
        ((1, 3, 8, 16, 17), (4, 4, 4), (2, 2, 2), (4, 8, 9)),
        ((1, 3, 8, 16, 17), (8, 8, 8), (4, 4, 4), (2, 4, 5)),
    ],
)
def test_compute_tile_grid_shape(data_shape, tile_size, overlaps, expected):
    """Test computing tile grid shape."""
    # One entry per spatial axis, in (Z)YX order.
    assert compute_tile_grid_shape(data_shape, tile_size, overlaps) == expected