Feature: Predict to disk (outerloop implementation) #253

Merged 4 commits on Oct 16, 2024
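For context, a minimal usage sketch of the new method (the checkpoint path, directory names and tile sizes are illustrative, not taken from this PR):

```python
from careamics import CAREamist

# load a trained model; checkpoint path and working directory are placeholders
careamist = CAREamist(source="checkpoints/last.ckpt", work_dir="runs/exp1")

# predict on every file under `data/test` and write the outputs to
# `runs/exp1/predictions`, mirroring the source directory structure
careamist.predict_to_disk(
    source="data/test",
    tile_size=(256, 256),  # for UNet models, divisible by 2**depth in every dimension
    tile_overlap=(48, 48),
)
```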
181 changes: 175 additions & 6 deletions src/careamics/careamist.py
@@ -20,7 +20,8 @@
SupportedData,
SupportedLogger,
)
from careamics.dataset.dataset_utils import reshape_array
from careamics.dataset.dataset_utils import list_files, reshape_array
from careamics.file_io import WriteFunc, get_write_func
from careamics.lightning import (
FCNModule,
HyperParametersCallback,
@@ -519,7 +520,7 @@ def predict( # numpydoc ignore=GL08
*,
batch_size: int = 1,
tile_size: Optional[tuple[int, ...]] = None,
tile_overlap: tuple[int, ...] = (48, 48),
tile_overlap: Optional[tuple[int, ...]] = (48, 48),
axes: Optional[str] = None,
data_type: Optional[Literal["tiff", "custom"]] = None,
tta_transforms: bool = False,
@@ -535,7 +536,7 @@ def predict( # numpydoc ignore=GL08
*,
batch_size: int = 1,
tile_size: Optional[tuple[int, ...]] = None,
tile_overlap: tuple[int, ...] = (48, 48),
tile_overlap: Optional[tuple[int, ...]] = (48, 48),
axes: Optional[str] = None,
data_type: Optional[Literal["array"]] = None,
tta_transforms: bool = False,
@@ -546,7 +547,7 @@ def predict(
self,
source: Union[PredictDataModule, Path, str, NDArray],
*,
batch_size: Optional[int] = None,
batch_size: int = 1,
tile_size: Optional[tuple[int, ...]] = None,
tile_overlap: Optional[tuple[int, ...]] = (48, 48),
axes: Optional[str] = None,
@@ -567,7 +568,7 @@ def predict(
configuration parameters will be used, with the `patch_size` instead of
`tile_size`.

Test-time augmentation (TTA) can be switched off using the `tta_transforms`
Test-time augmentation (TTA) can be switched on using the `tta_transforms`
parameter. The TTA augmentation applies all possible flip and 90 degrees
rotations to the prediction input and averages the predictions. TTA augmentation
should not be used if you did not train with these augmentations.
@@ -580,7 +581,7 @@ def predict(

Parameters
----------
source : CAREamicsPredData, pathlib.Path, str or numpy.ndarray
source : PredictDataModule, pathlib.Path, str or numpy.ndarray
Data to predict on.
batch_size : int, default=1
Batch size for prediction.
@@ -668,6 +669,174 @@ def predict(
)
return convert_outputs(predictions, self.pred_datamodule.tiled)

def predict_to_disk(
self,
source: Union[PredictDataModule, Path, str],
*,
batch_size: int = 1,
tile_size: Optional[tuple[int, ...]] = None,
tile_overlap: Optional[tuple[int, ...]] = (48, 48),
axes: Optional[str] = None,
data_type: Optional[Literal["tiff", "custom"]] = None,
tta_transforms: bool = False,
dataloader_params: Optional[dict] = None,
read_source_func: Optional[Callable] = None,
extension_filter: str = "",
write_type: Literal["tiff", "custom"] = "tiff",
write_extension: Optional[str] = None,
write_func: Optional[WriteFunc] = None,
write_func_kwargs: Optional[dict[str, Any]] = None,
prediction_dir: Union[Path, str] = "predictions",
**kwargs,
) -> None:
"""
Make predictions on the provided data and save outputs to files.

        The predictions will be saved in a new directory (by default 'predictions')
        within the set working directory. The directory structure within the
        prediction directory will match that of the source directory.

        The `source` must point to files, not arrays. The file names of the
        predictions will match those of the source. If there is more than one sample
        within a file, the samples will be saved to separate files, named after the
        corresponding source file with the sample index appended. E.g. if the source
        file name is 'images.tiff', the first sample's prediction will be saved as
        'images_0.tiff'. Input can be a `PredictDataModule` instance, or a path to a
        data file or directory.

If `data_type`, `axes` and `tile_size` are not provided, the training
configuration parameters will be used, with the `patch_size` instead of
`tile_size`.

Test-time augmentation (TTA) can be switched on using the `tta_transforms`
        parameter. The TTA augmentation applies all possible flips and 90-degree
        rotations to the prediction input and averages the predictions. TTA
        augmentation should not be used if you did not train with these augmentations.

Note that if you are using a UNet model and tiling, the tile size must be
divisible in every dimension by 2**d, where d is the depth of the model. This
avoids artefacts arising from the broken shift invariance induced by the
        pooling layers of the UNet. If your image is smaller than the tile size in
        some dimension, as may happen along Z, consider padding your image.

Parameters
----------
source : PredictDataModule, pathlib.Path or str
Data to predict on.
batch_size : int, default=1
Batch size for prediction.
tile_size : tuple of int, optional
Size of the tiles to use for prediction.
tile_overlap : tuple of int, default=(48, 48)
Overlap between tiles.
axes : str, optional
Axes of the input data, by default None.
data_type : {"array", "tiff", "custom"}, optional
Type of the input data.
        tta_transforms : bool, default=False
Whether to apply test-time augmentation.
dataloader_params : dict, optional
Parameters to pass to the dataloader.
read_source_func : Callable, optional
Function to read the source data.
extension_filter : str, default=""
Filter for the file extension.
write_type : {"tiff", "custom"}, default="tiff"
            The file format to save the predictions as.
write_extension : str, optional
If a known `write_type` is selected this argument is ignored. For a custom
`write_type` an extension to save the data with must be passed.
write_func : WriteFunc, optional
If a known `write_type` is selected this argument is ignored. For a custom
`write_type` a function to save the data must be passed. See notes below.
write_func_kwargs : dict of {str: any}, optional
Additional keyword arguments to be passed to the save function.
        prediction_dir : pathlib.Path or str, default="predictions"
The path to save the prediction results to. If `prediction_dir` is not
absolute, the directory will be assumed to be relative to the pre-set
`work_dir`. If the directory does not exist it will be created.
**kwargs : Any
Unused.

Raises
------
ValueError
If `write_type` is custom and `write_extension` is None.
ValueError
            If `write_type` is custom and `write_func` is None.
ValueError
            If `source` is not a `str`, `Path` or `PredictDataModule`.
"""
if write_func_kwargs is None:
write_func_kwargs = {}

        # resolve the output directory; relative paths are relative to the work_dir
        if Path(prediction_dir).is_absolute():
write_dir = Path(prediction_dir)
else:
write_dir = self.work_dir / prediction_dir
write_dir.mkdir(exist_ok=True, parents=True)

# guards for custom types
if write_type == SupportedData.CUSTOM:
if write_extension is None:
raise ValueError(
"A `write_extension` must be provided for custom write types."
)
if write_func is None:
raise ValueError(
"A `write_func` must be provided for custom write types."
)
else:
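            # known write types provide a built-in writer and file extension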
write_func = get_write_func(write_type)
write_extension = SupportedData.get_extension(write_type)

# extract file names
        if isinstance(source, PredictDataModule):
            # list the source files referenced by the datamodule
            source_file_paths = list_files(
                source.pred_data, source.data_type, source.extension_filter
            )
elif isinstance(source, (str, Path)):
            # arrays cannot be written to disk; the data type must be file-based
            assert self.cfg.data_config.data_type != "array"
data_type = data_type or self.cfg.data_config.data_type
extension_filter = SupportedData.get_extension_pattern(
SupportedData(data_type)
)
source_file_paths = list_files(source, data_type, extension_filter)
else:
raise ValueError(f"Unsupported source type: '{type(source)}'.")

# predict and write each file in turn
for source_path in source_file_paths:
# source_path is relative to original source path...
# should mirror original directory structure
prediction = self.predict(
source=source_path,
batch_size=batch_size,
tile_size=tile_size,
tile_overlap=tile_overlap,
axes=axes,
data_type=data_type,
tta_transforms=tta_transforms,
dataloader_params=dataloader_params,
read_source_func=read_source_func,
extension_filter=extension_filter,
**kwargs,
)
# TODO: cast to float16?
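            # join the per-sample predictions into a single array for writing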
write_data = np.concatenate(prediction)

# create directory structure and write path
file_write_dir = write_dir / source_path.parent.name
file_write_dir.mkdir(parents=True, exist_ok=True)
write_path = (file_write_dir / source_path.name).with_suffix(
write_extension
)

# write data
write_func(file_path=write_path, img=write_data)

def export_to_bmz(
self,
path_to_archive: Union[Path, str],
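For the custom write path, a sketch of what the guards expect from a caller (continuing the hypothetical `careamist` above; `write_npy` is an illustrative writer, not part of this PR):

```python
from pathlib import Path

import numpy as np
from numpy.typing import NDArray


def write_npy(file_path: Path, img: NDArray, **kwargs) -> None:
    # a custom writer must accept `file_path` and `img` keyword arguments,
    # since `predict_to_disk` calls `write_func(file_path=..., img=...)`
    np.save(file_path, img)


careamist.predict_to_disk(
    source="data/test",
    write_type="custom",
    write_extension=".npy",  # required when `write_type` is "custom"
    write_func=write_npy,    # required when `write_type` is "custom"
)
```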