Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

configurable normalizations #68

Merged
merged 8 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions examples/configs/fit_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ data:
batch_size: 32
num_workers: 16
yx_patch_size: [256, 256]
normalizations:
- class_path: viscy.transforms.NormalizeSampled
init_args:
keys: [source]
level: 'fov_statistics',
subtrahend: 'mean'
divisor: 'std'
- class_path: viscy.transforms.NormalizeSampled
init_args:
keys: [target_1]
level: 'fov_statistics',
subtrahend: 'median'
divisor: 'iqr'
augmentations:
- class_path: viscy.transforms.RandWeightedCropd
init_args:
Expand Down
141 changes: 39 additions & 102 deletions viscy/data/hcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from viscy.data.typing import ChannelMap, Sample
from viscy.transforms import NormalizeSampled


def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]:
"""
Expand Down Expand Up @@ -55,24 +58,6 @@ def _search_int_in_str(pattern: str, file_name: str) -> str:
raise ValueError(f"Cannot find pattern {pattern} in {file_name}.")


class ChannelMap(TypedDict, total=False):
"""Source and target channel names."""

source: Union[str, Sequence[str]]
# optional
target: Union[str, Sequence[str]]


class Sample(TypedDict, total=False):
"""Image sample type for mini-batches."""

index: tuple[str, int, int]
# optional
source: Union[Tensor, Sequence[Tensor]]
target: Union[Tensor, Sequence[Tensor]]
labels: Union[Tensor, Sequence[Tensor]]


def _collate_samples(batch: Sequence[Sample]) -> Sample:
"""Collate samples into a batch sample.

Expand All @@ -89,38 +74,6 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample:
return collated


class NormalizeSampled(MapTransform, InvertibleTransform):
"""Dictionary transform to only normalize target (fluorescence) channel.

:param Union[str, Iterable[str]] keys: keys to normalize
:param dict[str, dict] norm_meta: Plate normalization metadata
written in preprocessing
"""

def __init__(
self, keys: Union[str, Iterable[str]], norm_meta: dict[str, dict]
) -> None:
if set(keys) > set(norm_meta.keys()):
raise KeyError(f"{keys} is not a subset of {norm_meta.keys()}")
super().__init__(keys, allow_missing_keys=False)
self.norm_meta = norm_meta

def _stat(self, key: str) -> dict:
# FIXME: hard-coded key
return self.norm_meta[key]["dataset_statistics"]

def __call__(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
d = dict(data)
for key in self.keys:
d[key] = (d[key] - self._stat(key)["median"]) / self._stat(key)["iqr"]
return d

def inverse(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
d = dict(data)
for key in self.keys:
d[key] = (d[key] * self._stat(key)["iqr"]) + self._stat(key)["median"]


class SlidingWindowDataset(Dataset):
"""Torch dataset where each element is a window of
(C, Z, Y, X) where C=2 (source and target) and Z is ``z_window_size``.
Expand Down Expand Up @@ -161,21 +114,24 @@ def _get_windows(self) -> None:
w = 0
self.window_keys = []
self.window_arrays = []
self.window_norm_meta = []
for fov in self.positions:
img_arr = fov["0"]
ts = img_arr.frames
zs = img_arr.slices - self.z_window_size + 1
w += ts * zs
self.window_keys.append(w)
self.window_arrays.append(img_arr)
self.window_norm_meta.append(fov.zattrs["normalization"])
self._max_window = w

def _find_window(self, index: int) -> tuple[int, int]:
"""Look up window given index."""
window_idx = sorted(self.window_keys + [index + 1]).index(index + 1)
w = self.window_keys[window_idx]
tz = index - self.window_keys[window_idx - 1] if window_idx > 0 else index
return self.window_arrays[self.window_keys.index(w)], tz
norm_meta = self.window_norm_meta[self.window_keys.index(w)]
return (self.window_arrays[self.window_keys.index(w)], tz, norm_meta)

def _read_img_window(
self, img: ImageArray, ch_idx: list[str], tz: int
Expand Down Expand Up @@ -216,7 +172,7 @@ def _stack_channels(
]

def __getitem__(self, index: int) -> Sample:
img, tz = self._find_window(index)
img, tz, norm_meta = self._find_window(index)
ch_names = self.channels["source"].copy()
ch_idx = self.source_ch_idx.copy()
if self.target_ch_idx is not None:
Expand All @@ -229,6 +185,7 @@ def __getitem__(self, index: int) -> Sample:
# since adding a reference to a tensor does not copy
# maybe write a weight map in preprocessing to use more information?
sample_images["weight"] = sample_images[self.channels["target"][0]]
sample_images["norm_meta"] = norm_meta
if self.transform:
sample_images = self.transform(sample_images)
# if isinstance(sample_images, list):
Expand All @@ -238,6 +195,7 @@ def __getitem__(self, index: int) -> Sample:
sample = {
"index": sample_index,
"source": self._stack_channels(sample_images, "source"),
"norm_meta": norm_meta,
}
if self.target_ch_idx is not None:
sample["target"] = self._stack_channels(sample_images, "target")
Expand Down Expand Up @@ -312,13 +270,13 @@ class HCSDataModule(LightningDataModule):
defaults to "2.5D"
:param tuple[int, int] yx_patch_size: patch size in (Y, X),
defaults to (256, 256)
:param Optional[list[MapTransform]] normalizations: MONAI dictionary transforms
applied to selected channels, defaults to None (no normalization)
:param Optional[list[MapTransform]] augmentations: MONAI dictionary transforms
applied to the training set, defaults to None (no augmentation)
:param bool caching: whether to decompress all the images and cache the result,
will store in ``/tmp/$SLURM_JOB_ID/`` if available,
defaults to False
:param bool normalize_source: whether to normalize the source channel,
defaults to False
:param Optional[Path] ground_truth_masks: path to the ground truth masks,
used in the test stage to compute segmentation metrics,
defaults to None
Expand All @@ -337,9 +295,9 @@ def __init__(
num_workers: int = 8,
architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D",
yx_patch_size: tuple[int, int] = (256, 256),
normalizations: Optional[list[MapTransform]] = None,
augmentations: Optional[list[MapTransform]] = None,
caching: bool = False,
normalize_source: bool = False,
ground_truth_masks: Optional[Path] = None,
predict_scale_source: Optional[float] = None,
):
Expand All @@ -353,21 +311,11 @@ def __init__(
self.z_window_size = z_window_size
self.split_ratio = split_ratio
self.yx_patch_size = yx_patch_size
self.normalizations = normalizations
self.augmentations = augmentations
self.caching = caching
self.normalize_source = normalize_source
self.ground_truth_masks = ground_truth_masks
self.tmp_zarr = None
if predict_scale_source is not None:
if not normalize_source:
raise ValueError(
"Intensity scaling must be applied to normalized source channels."
)
if predict_scale_source <= 0:
raise ValueError(
f"Intensity scaling {predict_scale_source} should be positive."
)
self.predict_scale_source = predict_scale_source

def prepare_data(self):
if not self.caching:
Expand Down Expand Up @@ -419,31 +367,22 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
else:
raise NotImplementedError(f"{stage} stage")

def _setup_eval(self, dataset_settings: dict) -> tuple[Plate, MapTransform]:
"""Setup stages where the target is available (evaluating performance)."""
dataset_settings["channels"]["target"] = self.target_channel
data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
plate = open_ome_zarr(data_path, mode="r")
# disable metadata tracking in MONAI for performance
set_track_meta(False)
# define training stage transforms
norm_keys = self.target_channel.copy()
if self.normalize_source:
norm_keys += self.source_channel
normalize_transform = NormalizeSampled(
norm_keys,
plate.zattrs["normalization"],
)
return plate, normalize_transform

def _setup_fit(self, dataset_settings: dict):
"""Set up the training and validation datasets."""
plate, normalize_transform = self._setup_eval(dataset_settings)
# Setup the transformations
# TODO: These have a fixed order for now... (normalization->augmentation->fit_transform)
fit_transform = self._fit_transform()
train_transform = Compose(
[normalize_transform] + self._train_transform() + fit_transform
self.normalizations + self._train_transform() + fit_transform
)
val_transform = Compose([normalize_transform] + fit_transform)
val_transform = Compose(self.normalizations + fit_transform)

dataset_settings["channels"]["target"] = self.target_channel
data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
plate = open_ome_zarr(data_path, mode="r")

# disable metadata tracking in MONAI for performance
set_track_meta(False)
# shuffle positions, randomness is handled globally
positions = [pos for _, pos in plate.positions()]
shuffled_indices = torch.randperm(len(positions))
Expand All @@ -465,26 +404,31 @@ def _setup_fit(self, dataset_settings: dict):
**train_dataset_settings,
)
self.val_dataset = SlidingWindowDataset(
positions[num_train_fovs:], transform=val_transform, **dataset_settings
positions[num_train_fovs:],
transform=val_transform,
**dataset_settings,
)

def _setup_test(self, dataset_settings: dict):
"""Set up the test stage."""
if self.batch_size != 1:
logging.warning(f"Ignoring batch size {self.batch_size} in test stage.")
plate, normalize_transform = self._setup_eval(dataset_settings)

dataset_settings["channels"]["target"] = self.target_channel
data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
plate = open_ome_zarr(data_path, mode="r")
if self.ground_truth_masks:
self.test_dataset = MaskTestDataset(
[p for _, p in plate.positions()],
transform=normalize_transform,
transform=self.normalizations,
ground_truth_masks=self.ground_truth_masks,
**dataset_settings,
norm_meta=plate.zattrs["normalization"] ** dataset_settings,
)
else:
self.test_dataset = SlidingWindowDataset(
[p for _, p in plate.positions()],
transform=normalize_transform,
**dataset_settings,
transform=self.normalizations,
norm_meta=plate.zattrs["normalization"] ** dataset_settings,
)

def _setup_predict(self, dataset_settings: dict):
Expand All @@ -506,16 +450,9 @@ def _setup_predict(self, dataset_settings: dict):
positions = [plate[fov_name]]
elif isinstance(dataset, Plate):
positions = [p for _, p in dataset.positions()]
norm_meta = dataset.zattrs["normalization"].copy()
if self.predict_scale_source is not None:
for ch in self.source_channel:
# FIXME: hard-coded key
norm_meta[ch]["dataset_statistics"]["iqr"] /= self.predict_scale_source
predict_transform = (
NormalizeSampled(self.source_channel, norm_meta)
if self.normalize_source
else None
)

predict_transform = self.normalizations

self.predict_dataset = SlidingWindowDataset(
positions=positions,
transform=predict_transform,
Expand Down
22 changes: 22 additions & 0 deletions viscy/data/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Callable, Iterable, Literal, Optional, Sequence, TypedDict, Union

from torch import Tensor


class Sample(TypedDict, total=False):
"""Image sample type for mini-batches."""

index: tuple[str, int, int]
# optional
source: Union[Tensor, Sequence[Tensor]]
target: Union[Tensor, Sequence[Tensor]]
labels: Union[Tensor, Sequence[Tensor]]
norm_meta: dict[str, dict]


class ChannelMap(TypedDict, total=False):
"""Source and target channel names."""

source: Union[str, Sequence[str]]
# optional
target: Union[str, Sequence[str]]
16 changes: 14 additions & 2 deletions viscy/preprocessing/preprocessing.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,31 @@ The statistics are added as dictionaries into the .zattrs file. An example of pl
}
```

FOV level statistics added to every position:
FOV level statistics added to every position as well as the dataset_statistics to read dataset statistics:

```json
"normalization": {
"Deconvolved-Nuc": {
"dataset_statistics": {
"iqr": 149.7620086669922,
"mean": 262.2070617675781,
"median": 65.5246353149414,
"std": 890.0471801757812
},
"fov_statistics": {
"iqr": 450.4745788574219,
"mean": 486.3854064941406,
"median": 83.43557739257812,
"std": 976.02392578125
}
},
"Phase3D": {
"Phase3D": {
"dataset_statistics": {
"iqr": 0.0011349652777425945,
"mean": -1.9603044165705796e-06,
"median": 3.388232289580628e-05,
"std": 0.005480962339788675
},
"fov_statistics": {
"iqr": 0.006403466919437051,
"mean": 0.0010083537781611085,
Expand Down
Loading
Loading