configurable normalizations (#68)
* initial commit adding the normalization.

* adding dataset_statistics to each FOV to facilitate the configurable normalizations

* fix indentation

* ruff

* test preprocessing

* remove redundant field

* cleanup

---------

Co-authored-by: Ziwen Liu <ziwen.liu@czbiohub.org>
edyoshikun and ziw-liu authored Feb 26, 2024
1 parent 78aed97 commit 74e7db3
Showing 7 changed files with 139 additions and 132 deletions.
13 changes: 13 additions & 0 deletions examples/configs/fit_example.yml
@@ -37,6 +37,19 @@ data:
batch_size: 32
num_workers: 16
yx_patch_size: [256, 256]
normalizations:
- class_path: viscy.transforms.NormalizeSampled
init_args:
keys: [source]
level: 'fov_statistics'
subtrahend: 'mean'
divisor: 'std'
- class_path: viscy.transforms.NormalizeSampled
init_args:
keys: [target_1]
level: 'fov_statistics'
subtrahend: 'median'
divisor: 'iqr'
augmentations:
- class_path: viscy.transforms.RandWeightedCropd
init_args:
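The two NormalizeSampled entries above are built by the Lightning CLI from class_path and init_args and are applied before the augmentations. As a rough sketch of what such a level-aware normalization transform can look like, assuming it reads the per-sample norm_meta dictionary that SlidingWindowDataset attaches in viscy/data/hcs.py below (the class name and body here are illustrative, not viscy's actual implementation):

from typing import Sequence, Union

from monai.transforms import MapTransform


class NormalizeSampledSketch(MapTransform):
    """Subtract and divide selected keys by statistics looked up in the
    per-sample ``norm_meta`` dictionary written during preprocessing."""

    def __init__(
        self,
        keys: Union[str, Sequence[str]],
        level: str = "fov_statistics",
        subtrahend: str = "mean",
        divisor: str = "std",
    ) -> None:
        super().__init__(keys, allow_missing_keys=False)
        self.level = level
        self.subtrahend = subtrahend
        self.divisor = divisor

    def __call__(self, data: dict) -> dict:
        d = dict(data)
        for key in self.keys:
            # look up statistics at the configured level, e.g.
            # norm_meta[key]["fov_statistics"]["mean"]
            stats = d["norm_meta"][key][self.level]
            d[key] = (d[key] - stats[self.subtrahend]) / stats[self.divisor]
        return d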
2 changes: 2 additions & 0 deletions tests/conftest.py
@@ -36,6 +36,8 @@ def preprocessed_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path:
norm_meta = {channel: {"dataset_statistics": expected} for channel in channel_names}
with open_ome_zarr(dataset_path, mode="r+") as dataset:
dataset.zattrs["normalization"] = norm_meta
for _, fov in dataset.positions():
fov.zattrs["normalization"] = norm_meta
return dataset_path


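The fixture above writes the same dataset_statistics block to the plate and to every FOV; actual preprocessing additionally stores per-FOV fov_statistics, which test_data.py below asserts. A hypothetical example of the per-FOV metadata after preprocessing (channel name and values are illustrative only):

norm_meta = {
    "Phase": {
        "dataset_statistics": {"mean": 0.0, "std": 1.0, "median": 0.0, "iqr": 1.0},
        "fov_statistics": {"mean": 0.02, "std": 0.9, "median": 0.01, "iqr": 1.1},
    }
}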
33 changes: 10 additions & 23 deletions tests/data/test_data.py
@@ -18,6 +18,16 @@ def test_preprocess(small_hcs_dataset: Path, default_channels: bool):
channel_names = dataset.channel_names
trainer = VSTrainer(accelerator="cpu")
trainer.preprocess(data_path, channel_names=channel_names, num_workers=2)
with open_ome_zarr(data_path) as dataset:
channel_names = dataset.channel_names
for channel in channel_names:
assert "dataset_statistics" in dataset.zattrs["normalization"][channel]
for _, fov in dataset.positions():
norm_metadata = fov.zattrs["normalization"]
for channel in channel_names:
assert channel in norm_metadata
assert "dataset_statistics" in norm_metadata[channel]
assert "fov_statistics" in norm_metadata[channel]


def test_datamodule_setup_predict(preprocessed_hcs_dataset):
@@ -45,26 +55,3 @@ def test_datamodule_setup_predict(preprocessed_hcs_dataset):
img.height,
img.width,
)


def test_datamodule_predict_scales(preprocessed_hcs_dataset):
data_path = preprocessed_hcs_dataset
with open_ome_zarr(data_path) as dataset:
channel_names = dataset.channel_names

def get_normalized_stack(predict_scale_source):
factor = 1 if predict_scale_source is None else predict_scale_source
dm = HCSDataModule(
data_path=data_path,
source_channel=channel_names[:2],
target_channel=channel_names[2:],
z_window_size=5,
batch_size=2,
num_workers=0,
predict_scale_source=predict_scale_source,
normalize_source=True,
)
dm.setup(stage="predict")
return dm.predict_dataset[0]["source"] / factor

assert torch.allclose(get_normalized_stack(None), get_normalized_stack(2))
146 changes: 39 additions & 107 deletions viscy/data/hcs.py
@@ -5,7 +5,7 @@
import tempfile
from glob import glob
from pathlib import Path
from typing import Callable, Iterable, Literal, Optional, Sequence, TypedDict, Union
from typing import Callable, Literal, Optional, Sequence, Union

import numpy as np
import torch
@@ -18,14 +18,15 @@
from monai.transforms import (
CenterSpatialCropd,
Compose,
InvertibleTransform,
MapTransform,
MultiSampleTrait,
RandAffined,
)
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from viscy.data.typing import ChannelMap, Sample


def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]:
"""
@@ -55,24 +56,6 @@ def _search_int_in_str(pattern: str, file_name: str) -> str:
raise ValueError(f"Cannot find pattern {pattern} in {file_name}.")


class ChannelMap(TypedDict, total=False):
"""Source and target channel names."""

source: Union[str, Sequence[str]]
# optional
target: Union[str, Sequence[str]]


class Sample(TypedDict, total=False):
"""Image sample type for mini-batches."""

index: tuple[str, int, int]
# optional
source: Union[Tensor, Sequence[Tensor]]
target: Union[Tensor, Sequence[Tensor]]
labels: Union[Tensor, Sequence[Tensor]]


def _collate_samples(batch: Sequence[Sample]) -> Sample:
"""Collate samples into a batch sample.
@@ -89,38 +72,6 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample:
return collated


class NormalizeSampled(MapTransform, InvertibleTransform):
"""Dictionary transform to only normalize target (fluorescence) channel.
:param Union[str, Iterable[str]] keys: keys to normalize
:param dict[str, dict] norm_meta: Plate normalization metadata
written in preprocessing
"""

def __init__(
self, keys: Union[str, Iterable[str]], norm_meta: dict[str, dict]
) -> None:
if set(keys) > set(norm_meta.keys()):
raise KeyError(f"{keys} is not a subset of {norm_meta.keys()}")
super().__init__(keys, allow_missing_keys=False)
self.norm_meta = norm_meta

def _stat(self, key: str) -> dict:
# FIXME: hard-coded key
return self.norm_meta[key]["dataset_statistics"]

def __call__(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
d = dict(data)
for key in self.keys:
d[key] = (d[key] - self._stat(key)["median"]) / self._stat(key)["iqr"]
return d

def inverse(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
d = dict(data)
for key in self.keys:
d[key] = (d[key] * self._stat(key)["iqr"]) + self._stat(key)["median"]


class SlidingWindowDataset(Dataset):
"""Torch dataset where each element is a window of
(C, Z, Y, X) where C=2 (source and target) and Z is ``z_window_size``.
@@ -161,21 +112,24 @@ def _get_windows(self) -> None:
w = 0
self.window_keys = []
self.window_arrays = []
self.window_norm_meta = []
for fov in self.positions:
img_arr = fov["0"]
ts = img_arr.frames
zs = img_arr.slices - self.z_window_size + 1
w += ts * zs
self.window_keys.append(w)
self.window_arrays.append(img_arr)
self.window_norm_meta.append(fov.zattrs["normalization"])
self._max_window = w

def _find_window(self, index: int) -> tuple[int, int]:
"""Look up window given index."""
window_idx = sorted(self.window_keys + [index + 1]).index(index + 1)
w = self.window_keys[window_idx]
tz = index - self.window_keys[window_idx - 1] if window_idx > 0 else index
return self.window_arrays[self.window_keys.index(w)], tz
norm_meta = self.window_norm_meta[self.window_keys.index(w)]
return (self.window_arrays[self.window_keys.index(w)], tz, norm_meta)

def _read_img_window(
self, img: ImageArray, ch_idx: list[str], tz: int
@@ -216,7 +170,7 @@ def _stack_channels(
]

def __getitem__(self, index: int) -> Sample:
img, tz = self._find_window(index)
img, tz, norm_meta = self._find_window(index)
ch_names = self.channels["source"].copy()
ch_idx = self.source_ch_idx.copy()
if self.target_ch_idx is not None:
@@ -229,6 +183,7 @@ def __getitem__(self, index: int) -> Sample:
# since adding a reference to a tensor does not copy
# maybe write a weight map in preprocessing to use more information?
sample_images["weight"] = sample_images[self.channels["target"][0]]
sample_images["norm_meta"] = norm_meta
if self.transform:
sample_images = self.transform(sample_images)
# if isinstance(sample_images, list):
@@ -238,6 +193,7 @@ def __getitem__(self, index: int) -> Sample:
sample = {
"index": sample_index,
"source": self._stack_channels(sample_images, "source"),
"norm_meta": norm_meta,
}
if self.target_ch_idx is not None:
sample["target"] = self._stack_channels(sample_images, "target")
@@ -312,18 +268,16 @@ class HCSDataModule(LightningDataModule):
defaults to "2.5D"
:param tuple[int, int] yx_patch_size: patch size in (Y, X),
defaults to (256, 256)
:param Optional[list[MapTransform]] normalizations: MONAI dictionary transforms
applied to selected channels, defaults to None (no normalization)
:param Optional[list[MapTransform]] augmentations: MONAI dictionary transforms
applied to the training set, defaults to None (no augmentation)
:param bool caching: whether to decompress all the images and cache the result,
will store in ``/tmp/$SLURM_JOB_ID/`` if available,
defaults to False
:param bool normalize_source: whether to normalize the source channel,
defaults to False
:param Optional[Path] ground_truth_masks: path to the ground truth masks,
used in the test stage to compute segmentation metrics,
defaults to None
:param Optional[float] predict_scale_source: scale the source channel intensity,
defaults to None (no scaling)
"""

def __init__(
Expand All @@ -337,11 +291,10 @@ def __init__(
num_workers: int = 8,
architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D",
yx_patch_size: tuple[int, int] = (256, 256),
normalizations: Optional[list[MapTransform]] = None,
augmentations: Optional[list[MapTransform]] = None,
caching: bool = False,
normalize_source: bool = False,
ground_truth_masks: Optional[Path] = None,
predict_scale_source: Optional[float] = None,
):
super().__init__()
self.data_path = Path(data_path)
@@ -353,21 +306,11 @@ def __init__(
self.z_window_size = z_window_size
self.split_ratio = split_ratio
self.yx_patch_size = yx_patch_size
self.normalizations = normalizations
self.augmentations = augmentations
self.caching = caching
self.normalize_source = normalize_source
self.ground_truth_masks = ground_truth_masks
self.tmp_zarr = None
if predict_scale_source is not None:
if not normalize_source:
raise ValueError(
"Intensity scaling must be applied to normalized source channels."
)
if predict_scale_source <= 0:
raise ValueError(
f"Intensity scaling {predict_scale_source} should be positive."
)
self.predict_scale_source = predict_scale_source

def prepare_data(self):
if not self.caching:
@@ -419,31 +362,22 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
else:
raise NotImplementedError(f"{stage} stage")

def _setup_eval(self, dataset_settings: dict) -> tuple[Plate, MapTransform]:
"""Setup stages where the target is available (evaluating performance)."""
dataset_settings["channels"]["target"] = self.target_channel
data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
plate = open_ome_zarr(data_path, mode="r")
# disable metadata tracking in MONAI for performance
set_track_meta(False)
# define training stage transforms
norm_keys = self.target_channel.copy()
if self.normalize_source:
norm_keys += self.source_channel
normalize_transform = NormalizeSampled(
norm_keys,
plate.zattrs["normalization"],
)
return plate, normalize_transform

def _setup_fit(self, dataset_settings: dict):
"""Set up the training and validation datasets."""
plate, normalize_transform = self._setup_eval(dataset_settings)
# Setup the transformations
# TODO: These have a fixed order for now... (normalization->augmentation->fit_transform)
fit_transform = self._fit_transform()
train_transform = Compose(
[normalize_transform] + self._train_transform() + fit_transform
self.normalizations + self._train_transform() + fit_transform
)
val_transform = Compose([normalize_transform] + fit_transform)
val_transform = Compose(self.normalizations + fit_transform)

dataset_settings["channels"]["target"] = self.target_channel
data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
plate = open_ome_zarr(data_path, mode="r")

# disable metadata tracking in MONAI for performance
set_track_meta(False)
# shuffle positions, randomness is handled globally
positions = [pos for _, pos in plate.positions()]
shuffled_indices = torch.randperm(len(positions))
@@ -465,26 +399,31 @@ def _setup_fit(self, dataset_settings: dict):
**train_dataset_settings,
)
self.val_dataset = SlidingWindowDataset(
positions[num_train_fovs:], transform=val_transform, **dataset_settings
positions[num_train_fovs:],
transform=val_transform,
**dataset_settings,
)

def _setup_test(self, dataset_settings: dict):
"""Set up the test stage."""
if self.batch_size != 1:
logging.warning(f"Ignoring batch size {self.batch_size} in test stage.")
plate, normalize_transform = self._setup_eval(dataset_settings)

dataset_settings["channels"]["target"] = self.target_channel
data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
plate = open_ome_zarr(data_path, mode="r")
if self.ground_truth_masks:
self.test_dataset = MaskTestDataset(
[p for _, p in plate.positions()],
transform=normalize_transform,
transform=self.normalizations,
ground_truth_masks=self.ground_truth_masks,
**dataset_settings,
norm_meta=plate.zattrs["normalization"],
**dataset_settings,
)
else:
self.test_dataset = SlidingWindowDataset(
[p for _, p in plate.positions()],
transform=normalize_transform,
**dataset_settings,
transform=self.normalizations,
norm_meta=plate.zattrs["normalization"],
**dataset_settings,
)

def _setup_predict(self, dataset_settings: dict):
Expand All @@ -506,16 +445,9 @@ def _setup_predict(self, dataset_settings: dict):
positions = [plate[fov_name]]
elif isinstance(dataset, Plate):
positions = [p for _, p in dataset.positions()]
norm_meta = dataset.zattrs["normalization"].copy()
if self.predict_scale_source is not None:
for ch in self.source_channel:
# FIXME: hard-coded key
norm_meta[ch]["dataset_statistics"]["iqr"] /= self.predict_scale_source
predict_transform = (
NormalizeSampled(self.source_channel, norm_meta)
if self.normalize_source
else None
)

predict_transform = self.normalizations

self.predict_dataset = SlidingWindowDataset(
positions=positions,
transform=predict_transform,
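With the HCSDataModule changes above, normalization is now passed in as a list of transforms instead of being derived from normalize_source and predict_scale_source. A usage sketch under assumptions: the data path and channel names are placeholders, and the NormalizeSampled signature is inferred from fit_example.yml rather than from the library source.

from viscy.data.hcs import HCSDataModule
from viscy.transforms import NormalizeSampled

dm = HCSDataModule(
    data_path="plate.zarr",  # placeholder path to a preprocessed HCS store
    source_channel="Phase",  # placeholder channel names
    target_channel=["Nuclei"],
    z_window_size=5,
    batch_size=32,
    num_workers=16,
    architecture="2.5D",
    yx_patch_size=(256, 256),
    normalizations=[
        NormalizeSampled(
            keys=["Phase"],
            level="fov_statistics",
            subtrahend="mean",
            divisor="std",
        )
    ],
)
dm.setup(stage="fit")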
22 changes: 22 additions & 0 deletions viscy/data/typing.py
@@ -0,0 +1,22 @@
from typing import Sequence, TypedDict, Union

from torch import Tensor


class Sample(TypedDict, total=False):
"""Image sample type for mini-batches."""

index: tuple[str, int, int]
# optional
source: Union[Tensor, Sequence[Tensor]]
target: Union[Tensor, Sequence[Tensor]]
labels: Union[Tensor, Sequence[Tensor]]
norm_meta: dict[str, dict]


class ChannelMap(TypedDict, total=False):
"""Source and target channel names."""

source: Union[str, Sequence[str]]
# optional
target: Union[str, Sequence[str]]