From c6692f1419808b26ca205aded069e9f62436e908 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 10 Jan 2024 15:19:18 -0800 Subject: [PATCH 01/74] refactor data loading into its own module --- examples/demo_dlmbl/debug_log_graph.py | 2 +- examples/demo_dlmbl/solution.py | 2 +- tests/light/test_data.py | 2 +- viscy/cli/cli.py | 2 +- viscy/data/__init__.py | 0 viscy/{light/data.py => data/hcs.py} | 0 viscy/light/engine.py | 2 +- viscy/light/predict_writer.py | 2 +- viscy/scripts/profiling.py | 2 +- 9 files changed, 7 insertions(+), 7 deletions(-) create mode 100644 viscy/data/__init__.py rename viscy/{light/data.py => data/hcs.py} (100%) diff --git a/examples/demo_dlmbl/debug_log_graph.py b/examples/demo_dlmbl/debug_log_graph.py index 1819b02f..ec987118 100644 --- a/examples/demo_dlmbl/debug_log_graph.py +++ b/examples/demo_dlmbl/debug_log_graph.py @@ -19,7 +19,7 @@ from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard # HCSDataModule makes it easy to load data during training. -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule # Trainer class and UNet. from viscy.light.engine import VSUNet diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 933f939d..2c81aa6f 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -83,7 +83,7 @@ from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard # HCSDataModule makes it easy to load data during training. -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule # training augmentations from viscy.transforms import ( diff --git a/tests/light/test_data.py b/tests/light/test_data.py index 263f8f90..153f175f 100644 --- a/tests/light/test_data.py +++ b/tests/light/test_data.py @@ -4,7 +4,7 @@ from iohub import open_ome_zarr from pytest import mark -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule from viscy.light.trainer import VSTrainer diff --git a/viscy/cli/cli.py b/viscy/cli/cli.py index 0946bb0f..f9a55f12 100644 --- a/viscy/cli/cli.py +++ b/viscy/cli/cli.py @@ -9,7 +9,7 @@ from lightning.pytorch.cli import LightningCLI from lightning.pytorch.loggers import TensorBoardLogger -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule from viscy.light.engine import VSUNet from viscy.light.trainer import VSTrainer diff --git a/viscy/data/__init__.py b/viscy/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/viscy/light/data.py b/viscy/data/hcs.py similarity index 100% rename from viscy/light/data.py rename to viscy/data/hcs.py diff --git a/viscy/light/engine.py b/viscy/light/engine.py index f165a056..74f14aaa 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -25,8 +25,8 @@ structural_similarity_index_measure, ) +from viscy.data.hcs import Sample from viscy.evaluation.evaluation_metrics import mean_average_precision, ms_ssim_25d -from viscy.light.data import Sample from viscy.unet.networks.Unet2D import Unet2d from viscy.unet.networks.Unet21D import Unet21d from viscy.unet.networks.Unet25D import Unet25d diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py index a6ae88cb..7a58009c 100644 --- a/viscy/light/predict_writer.py +++ b/viscy/light/predict_writer.py @@ -9,7 +9,7 @@ from lightning.pytorch.callbacks import BasePredictionWriter from numpy.typing import DTypeLike, NDArray -from viscy.light.data import HCSDataModule, Sample +from viscy.data.hcs import 
HCSDataModule, Sample __all__ = ["HCSPredictionWriter"] _logger = logging.getLogger("lightning.pytorch") diff --git a/viscy/scripts/profiling.py b/viscy/scripts/profiling.py index 0c947f45..a0c3ca6d 100644 --- a/viscy/scripts/profiling.py +++ b/viscy/scripts/profiling.py @@ -2,7 +2,7 @@ from profilehooks import profile -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule dataset = "/path/to/dataset.zarr" From 3d8e7e2646a10e9483120ad4e12be736342cf621 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 10 Jan 2024 15:26:59 -0800 Subject: [PATCH 02/74] update type annotations --- viscy/unet/networks/Unet21D.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/viscy/unet/networks/Unet21D.py b/viscy/unet/networks/Unet21D.py index 7c32e34b..51ed9839 100644 --- a/viscy/unet/networks/Unet21D.py +++ b/viscy/unet/networks/Unet21D.py @@ -1,11 +1,11 @@ -from typing import Callable, Literal, Optional, Sequence, Union +from typing import Callable, Literal, Sequence import timm import torch from monai.networks.blocks import Convolution, ResidualUnit, UpSample from monai.networks.blocks.dynunet_block import get_conv_layer from monai.networks.utils import normal_init -from torch import nn +from torch import Tensor, nn def icnr_init( @@ -45,7 +45,7 @@ def _get_convnext_stage( in_channels: int, out_channels: int, depth: int, - upsample_factor: Optional[int] = None, + upsample_factor: int | None = None, ) -> nn.Module: stage = timm.models.convnext.ConvNeXtStage( in_chs=in_channels, @@ -83,7 +83,7 @@ def __init__( stride=kernel_size, ) - def forward(self, x: torch.Tensor): + def forward(self, x: Tensor): x = self.conv(x) b, c, d, h, w = x.shape # project Z/depth into channels @@ -101,7 +101,7 @@ def __init__( mode: Literal["deconv", "pixelshuffle"], conv_blocks: int, norm_name: str, - upsample_pre_conv: Optional[Union[Literal["default"], Callable]], + upsample_pre_conv: Literal["default"] | Callable | None, ) -> None: super().__init__() spatial_dims = 2 @@ -145,11 +145,11 @@ def __init__( upsample_factor=conv_weight_init_factor, ) - def forward(self, inp: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: + def forward(self, inp: Tensor, skip: Tensor) -> Tensor: """ - :param torch.Tensor inp: Low resolution features - :param torch.Tensor skip: High resolution skip connection features - :return torch.Tensor: High resolution features + :param Tensor inp: Low resolution features + :param Tensor skip: High resolution skip connection features + :return Tensor: High resolution features """ inp = self.upsample(inp) inp = torch.cat([inp, skip], dim=1) @@ -192,7 +192,7 @@ def __init__( self.out = nn.PixelShuffle(2) self.out_stack_depth = out_stack_depth - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: x = self.upsample(x) d = self.out_stack_depth + 2 b, c, h, w = x.shape @@ -209,7 +209,7 @@ class UnsqueezeHead(nn.Module): def __init__(self) -> None: super().__init__() - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: x = x.unsqueeze(2) return x @@ -222,7 +222,7 @@ def __init__( mode: Literal["deconv", "pixelshuffle"], conv_blocks: int, strides: list[int], - upsample_pre_conv: Optional[Union[Literal["default"], Callable]], + upsample_pre_conv: Literal["default"] | Callable | None, ) -> None: super().__init__() self.decoder_stages = nn.ModuleList([]) @@ -240,7 +240,7 @@ def __init__( ) self.decoder_stages.append(stage) - def forward(self, features: 
Sequence[torch.Tensor]) -> torch.Tensor: + def forward(self, features: Sequence[Tensor]) -> Tensor: feat = features[0] # padding features.append(None) @@ -328,7 +328,7 @@ def num_blocks(self) -> int: """2-times downscaling factor of the smallest feature map""" return 6 - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: x = self.stem(x) x: list = self.encoder_stages(x) x.reverse() From fdcbf5536133291cee298c654fac5645ca4acfab Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 10 Jan 2024 16:01:28 -0800 Subject: [PATCH 03/74] move the logging module out --- viscy/unet/{utils => }/logging.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename viscy/unet/{utils => }/logging.py (100%) diff --git a/viscy/unet/utils/logging.py b/viscy/unet/logging.py similarity index 100% rename from viscy/unet/utils/logging.py rename to viscy/unet/logging.py From a2913817e0c432933ba5c81c3662993c579d7e66 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 10 Jan 2024 16:03:10 -0800 Subject: [PATCH 04/74] move old logging into utils --- viscy/{unet => utils}/logging.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename viscy/{unet => utils}/logging.py (100%) diff --git a/viscy/unet/logging.py b/viscy/utils/logging.py similarity index 100% rename from viscy/unet/logging.py rename to viscy/utils/logging.py From 3cf8fa23c73ce754e27498a07879ccb23db7d170 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 11 Jan 2024 09:31:21 -0800 Subject: [PATCH 05/74] rename tests to match module name --- tests/{torch_unet => unet}/networks/Unet25D_tests.py | 0 tests/{torch_unet => unet}/networks/Unet2D_tests.py | 0 tests/{torch_unet => unet}/networks/layers/ConvBlock2D_tests.py | 0 tests/{torch_unet => unet}/networks/layers/ConvBlock3D_tests.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename tests/{torch_unet => unet}/networks/Unet25D_tests.py (100%) rename tests/{torch_unet => unet}/networks/Unet2D_tests.py (100%) rename tests/{torch_unet => unet}/networks/layers/ConvBlock2D_tests.py (100%) rename tests/{torch_unet => unet}/networks/layers/ConvBlock3D_tests.py (100%) diff --git a/tests/torch_unet/networks/Unet25D_tests.py b/tests/unet/networks/Unet25D_tests.py similarity index 100% rename from tests/torch_unet/networks/Unet25D_tests.py rename to tests/unet/networks/Unet25D_tests.py diff --git a/tests/torch_unet/networks/Unet2D_tests.py b/tests/unet/networks/Unet2D_tests.py similarity index 100% rename from tests/torch_unet/networks/Unet2D_tests.py rename to tests/unet/networks/Unet2D_tests.py diff --git a/tests/torch_unet/networks/layers/ConvBlock2D_tests.py b/tests/unet/networks/layers/ConvBlock2D_tests.py similarity index 100% rename from tests/torch_unet/networks/layers/ConvBlock2D_tests.py rename to tests/unet/networks/layers/ConvBlock2D_tests.py diff --git a/tests/torch_unet/networks/layers/ConvBlock3D_tests.py b/tests/unet/networks/layers/ConvBlock3D_tests.py similarity index 100% rename from tests/torch_unet/networks/layers/ConvBlock3D_tests.py rename to tests/unet/networks/layers/ConvBlock3D_tests.py From d4cd41db42ecf62b94ab26e5bbc9a4d7feecfcac Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 11 Jan 2024 09:31:30 -0800 Subject: [PATCH 06/74] bump torch --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8d60ee1d..b60cd534 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ license = { file = "LICENSE" } authors = [{ name = "CZ Biohub SF", email = 
"compmicro@czbiohub.org" }] dependencies = [ "iohub==0.1.0rc0", - "torch>=2.0.0", + "torch>=2.1.2", "timm>=0.9.5", "tensorboard>=2.13.0", "lightning>=2.0.1", From e87d3969617de3bc7a0b47e136b8e1270dad1ea6 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 11 Jan 2024 16:35:30 -0800 Subject: [PATCH 07/74] draft fcmae encoder --- tests/unet/__init__.py | 0 tests/unet/test_fcmae.py | 43 ++++++ viscy/unet/networks/Unet21D.py | 2 +- viscy/unet/networks/fcmae.py | 235 +++++++++++++++++++++++++++++++++ 4 files changed, 279 insertions(+), 1 deletion(-) create mode 100644 tests/unet/__init__.py create mode 100644 tests/unet/test_fcmae.py create mode 100644 viscy/unet/networks/fcmae.py diff --git a/tests/unet/__init__.py b/tests/unet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py new file mode 100644 index 00000000..ae8e0ec6 --- /dev/null +++ b/tests/unet/test_fcmae.py @@ -0,0 +1,43 @@ +import torch + +from viscy.unet.networks.fcmae import ( + MaskedConvNeXtV2Block, + MaskedConvNeXtV2Stage, + MaskedGlobalResponseNorm, +) + + +def test_masked_grn() -> None: + x = torch.rand(2, 3, 4, 5) + grn = MaskedGlobalResponseNorm(3, channels_last=False) + grn.gamma.data = torch.ones_like(grn.gamma.data) + mask = torch.ones((1, 1, 4, 5), dtype=torch.bool) + mask[:, :, 2:, 2:] = False + normalized = grn(x) + assert not torch.allclose(normalized, x) + assert torch.allclose(grn(x, mask)[:, :, 2:, 2:], grn(x[:, :, 2:, 2:])) + grn = MaskedGlobalResponseNorm(5, channels_last=True) + grn.gamma.data = torch.ones_like(grn.gamma.data) + mask = torch.ones((1, 3, 4, 1), dtype=torch.bool) + mask[:, 1:, 2:, :] = False + assert torch.allclose(grn(x, mask)[:, 1:, 2:, :], grn(x[:, 1:, 2:, :])) + + +def test_masked_convnextv2_block() -> None: + x = torch.rand(2, 3, 4, 5) + mask = x[0, 0] > 0.5 + block = MaskedConvNeXtV2Block(3, 3 * 2) + assert len(block(x).unique()) == x.numel() * 2 + block = MaskedConvNeXtV2Block(3, 3) + masked_out = block(x, mask) + assert len(masked_out[:, :, mask].unique()) == x.shape[1] + + +def test_masked_convnextv2_stage() -> None: + x = torch.rand(2, 3, 16, 16) + mask = torch.rand(4, 4) > 0.5 + stage = MaskedConvNeXtV2Stage(3, 3, kernel_size=7, stride=2, num_blocks=2) + out = stage(x) + assert out.shape == (2, 3, 8, 8) + masked_out = stage(x, mask) + assert not torch.allclose(masked_out, out) diff --git a/viscy/unet/networks/Unet21D.py b/viscy/unet/networks/Unet21D.py index 51ed9839..c4320240 100644 --- a/viscy/unet/networks/Unet21D.py +++ b/viscy/unet/networks/Unet21D.py @@ -12,7 +12,7 @@ def icnr_init( conv: nn.Module, upsample_factor: int, upsample_dims: int, - init=nn.init.kaiming_normal_, + init: Callable = nn.init.kaiming_normal_, ): """ ICNR initialization for 2D/3D kernels adapted from Aitken et al.,2017 , diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py new file mode 100644 index 00000000..818e8f88 --- /dev/null +++ b/viscy/unet/networks/fcmae.py @@ -0,0 +1,235 @@ +""" +Fully Convolutional Masked Autoencoder as described in ConvNeXt V2 +based on the official JAX example in +https://github.com/facebookresearch/ConvNeXt-V2/blob/main/TRAINING.md#implementing-fcmae-with-masked-convolution-in-jax +also referring to timm's dense implementation of the encoder in ``timm.models.convnext`` +""" + + +from typing import Callable, Literal, Sequence + +import torch +from timm.layers import DropPath, LayerNorm2d, create_conv2d, trunc_normal_ +from timm.models.convnext import Downsample +from torch import 
BoolTensor, Tensor, nn + + +def _upsample_mask(mask: BoolTensor, features: Tensor) -> BoolTensor: + mask = mask[..., :, :][None, None] + if features.shape[-2:] != mask.shape[-2:]: + if not all(i % j == 0 for i, j in zip(features.shape[-2:], mask.shape[-2:])): + raise ValueError( + f"feature map shape {features.shape} must be divisible by " + f"mask shape {mask.shape}." + ) + mask = mask.repeat_interleave( + features.shape[-2] // mask.shape[-2], dim=-2 + ).repeat_interleave(features.shape[-1] // mask.shape[-1], dim=-1) + return mask + + +class MaskedGlobalResponseNorm(nn.Module): + """ + Masked Global Response Normalization. + + :param int dim: number of input channels + :param float eps: small value added for numerical stability, + defaults to 1e-6 + :param bool channels_last: BHWC (True) or BCHW (False) dimension ordering, + defaults to False + """ + + def __init__( + self, dim: int, eps: float = 1e-6, channels_last: bool = False + ) -> None: + super().__init__() + if channels_last: + self.spatial_dim = (1, 2) + self.channel_dim = -1 + weights_shape = (1, 1, 1, dim) + else: + self.spatial_dim = (2, 3) + self.channel_dim = 1 + weights_shape = (1, dim, 1, 1) + self.gamma = nn.Parameter(torch.zeros(weights_shape)) + self.beta = nn.Parameter(torch.zeros(weights_shape)) + self.eps = eps + + def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + """ + :param Tensor x: input tensor, BHWC or BCHW + :param BoolTensor | None mask: boolean mask, defaults to None + :return Tensor: normalized tensor + """ + samples = x if mask is None else x * ~mask + g_x = samples.norm(p=2, dim=self.spatial_dim, keepdim=True) + n_x = g_x / (g_x.mean(dim=self.channel_dim, keepdim=True) + self.eps) + return x + torch.addcmul(self.beta, self.gamma, x * n_x) + + +class MaskedConvNeXtV2Block(nn.Module): + """Masked ConvNeXt V2 Block. 
+ + :param int in_channels: input channels + :param int | None out_channels: output channels, defaults to None + :param int kernel_size: depth-wise convolution kernel size, defaults to 7 + :param int stride: downsample stride, defaults to 1 + :param int mlp_ratio: MLP expansion ratio, defaults to 4 + :param float drop_path: drop path rate, defaults to 0.0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + kernel_size: int = 7, + stride: int = 1, + mlp_ratio: int = 4, + drop_path: float = 0.0, + ) -> None: + super().__init__() + out_channels = out_channels or in_channels + self.dwconv = create_conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + depthwise=True, + ) + self.layernorm = LayerNorm2d(out_channels) + self.pwconv1 = nn.Conv2d(out_channels, mlp_ratio * out_channels, kernel_size=1) + self.act = nn.GELU() + self.grn = MaskedGlobalResponseNorm(mlp_ratio * out_channels) + self.pwconv2 = nn.Conv2d(mlp_ratio * out_channels, out_channels, kernel_size=1) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + if in_channels != out_channels or stride > 1: + self.shortcut = Downsample(in_channels, out_channels, stride=stride) + else: + self.shortcut = nn.Identity() + + def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + """ + :param Tensor x: input tensor (BCHW) + :param BoolTensor | None mask: boolean mask, defaults to None + :return Tensor: output tensor (BCHW) + """ + shortcut = self.shortcut(x) + if mask is not None: + x *= ~mask + x = self.dwconv(x) + if mask is not None: + x *= ~mask + x = self.layernorm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x, mask) + x = self.pwconv2(x) + x = self.drop_path(x) + shortcut + return x + + +class MaskedConvNeXtV2Stage(nn.Module): + """Masked ConvNeXt V2 Stage. + + :param int in_channels: input channels + :param int out_channels: output channels + :param int kernel_size: depth-wise convolution kernel size, defaults to 7 + :param int stride: downsampling factor of this stage, defaults to 2 + :param int num_blocks: number of residual blocks, defaults to 2 + :param Sequence[float] | None drop_path_rates: drop path rates of each block, + defaults to None + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 7, + stride: int = 2, + num_blocks: int = 2, + drop_path_rates: Sequence[float] | None = None, + ) -> None: + super().__init__() + if drop_path_rates is None: + drop_path_rates = [0.0] * num_blocks + elif len(drop_path_rates) != num_blocks: + raise ValueError( + "length of drop_path_rates must be equal to " + f"the number of blocks {num_blocks}, got {len(drop_path_rates)}." 
+ ) + if in_channels != out_channels or stride > 1: + downsample_kernel_size = stride if stride > 1 else 1 + self.downsample = nn.Sequential( + LayerNorm2d(in_channels), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=downsample_kernel_size, + stride=stride, + padding=0, + ), + ) + in_channels = out_channels + else: + self.downsample = nn.Identity() + self.blocks = nn.ModuleList() + for i in range(num_blocks): + self.blocks.append( + MaskedConvNeXtV2Block( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + drop_path=drop_path_rates[i], + ) + ) + in_channels = out_channels + + def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + """ + :param Tensor x: input tensor (BCHW) + :param BoolTensor | None mask: boolean mask, defaults to None + :return Tensor: output tensor (BCHW) + """ + x = self.downsample(x) + if mask is not None: + mask = _upsample_mask(mask, x) + for block in self.blocks: + x = block(x, mask) + return x + + +class MaskedMultiscaleEncoder(nn.Module): + def __init__( + self, + in_channels: int, + stage_blocks: Sequence[int] = (3, 3, 9, 3), + dims: Sequence[int] = (96, 192, 384, 768), + drop_path_rate: float = 0.0, + ) -> None: + super().__init__() + self.stages = nn.ModuleList() + chs = [in_channels, *dims] + for i, num_blocks in enumerate(stage_blocks): + self.stages.append( + MaskedConvNeXtV2Stage( + chs[i], + chs[i + 1], + kernel_size=7, + stride=2, + num_blocks=num_blocks, + drop_path_rates=[drop_path_rate] * num_blocks, + ) + ) + + def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + """ + :param Tensor x: input tensor (BCHW) + :param BoolTensor | None mask: boolean mask, defaults to None + :return Tensor: output tensor (BCHW) + """ + features = [] + for stage in self.stages: + x = stage(x, mask) + features.append(x) + return features From dccce5f785581300dd4387f2d5f0548be50af5bf Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 12 Jan 2024 15:14:36 -0800 Subject: [PATCH 08/74] add stem to the encoder --- tests/unet/test_fcmae.py | 37 +++++++++++++ viscy/unet/networks/fcmae.py | 101 ++++++++++++++++++++++++++++++----- 2 files changed, 125 insertions(+), 13 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index ae8e0ec6..73dc5920 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -1,9 +1,12 @@ import torch from viscy.unet.networks.fcmae import ( + AdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, MaskedGlobalResponseNorm, + MaskedMultiscaleEncoder, + upsample_mask, ) @@ -41,3 +44,37 @@ def test_masked_convnextv2_stage() -> None: assert out.shape == (2, 3, 8, 8) masked_out = stage(x, mask) assert not torch.allclose(masked_out, out) + + +def test_adaptive_projection() -> None: + proj = AdaptiveProjection(3, 12, kernel_size_2d=4, kernel_depth=5, in_stack_depth=5) + assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) + assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) + proj = AdaptiveProjection( + 3, 12, kernel_size_2d=(2, 4), kernel_depth=5, in_stack_depth=15 + ) + assert proj(torch.rand(2, 3, 15, 6, 8)).shape == (2, 12, 3, 2) + + +def test_masked_multiscale_encoder() -> None: + xy_size = 64 + dims = [12, 24, 48, 96] + x = torch.rand(2, 3, 5, xy_size, xy_size) + encoder = MaskedMultiscaleEncoder(3, dims=dims) + # auto_masked_features, mask = encoder(x, mask_ratio=0.5) + auto_masked_features = encoder(x) + target_shape = list(x.shape) + target_shape.pop(1) + pre_masked_features = encoder(x) #encoder(x * 
~upsample_mask(mask, target_shape).unsqueeze(1)) + assert len(auto_masked_features) == len(pre_masked_features) == 4 + for i, (dim, afeat, pfeat) in enumerate( + zip(dims, auto_masked_features, pre_masked_features) + ): + assert afeat.shape[0] == x.shape[0] + assert afeat.shape[1] == dim + stride = 2 * 2 ** (i + 1) + assert afeat.shape[2] == afeat.shape[3] == xy_size // stride + assert torch.allclose(afeat, pfeat, rtol=1e-1, atol=5e-2), ( + i, + (afeat - pfeat).abs().max(), + ) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 818e8f88..71644955 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -11,20 +11,19 @@ import torch from timm.layers import DropPath, LayerNorm2d, create_conv2d, trunc_normal_ from timm.models.convnext import Downsample -from torch import BoolTensor, Tensor, nn +from torch import BoolTensor, Size, Tensor, nn -def _upsample_mask(mask: BoolTensor, features: Tensor) -> BoolTensor: - mask = mask[..., :, :][None, None] - if features.shape[-2:] != mask.shape[-2:]: - if not all(i % j == 0 for i, j in zip(features.shape[-2:], mask.shape[-2:])): +def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: + if target[-2:] != mask.shape[-2:]: + if not all(i % j == 0 for i, j in zip(target, mask.shape)): raise ValueError( - f"feature map shape {features.shape} must be divisible by " + f"feature map shape {target} must be divisible by " f"mask shape {mask.shape}." ) mask = mask.repeat_interleave( - features.shape[-2] // mask.shape[-2], dim=-2 - ).repeat_interleave(features.shape[-1] // mask.shape[-1], dim=-1) + target[-2] // mask.shape[-2], dim=-2 + ).repeat_interleave(target[-1] // mask.shape[-1], dim=-1) return mask @@ -193,12 +192,64 @@ def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: """ x = self.downsample(x) if mask is not None: - mask = _upsample_mask(mask, x) + mask = upsample_mask(mask, x.shape) for block in self.blocks: x = block(x, mask) return x +class AdaptiveProjection(nn.Module): + """ + Patchifying layer for projecting 2D or 3D input into 2D feature maps. + Masking is not needed because the mask will cover entire patches. 
+ + :param int in_channels: input channels + :param int out_channels: output channels + :param Sequence[int, int] | int kernel_size_2d: kernel width and height + :param int kernel_depth: kernel depth for 3D input + :param int in_stack_depth: input stack depth for 3D input + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size_2d: tuple[int, int] | int = 4, + kernel_depth: int = 5, + in_stack_depth: int = 5, + ) -> None: + super().__init__() + ratio = in_stack_depth // kernel_depth + if isinstance(kernel_size_2d, int): + kernel_size_2d = [kernel_size_2d] * 2 + kernel_size_3d = [kernel_depth, *kernel_size_2d] + self.conv3d = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels // ratio, + kernel_size=kernel_size_3d, + stride=kernel_size_3d, + ) + self.conv2d = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size_2d, + stride=kernel_size_2d, + ) + + def forward(self, x: Tensor) -> Tensor: + """ + :param Tensor x: input tensor (BCDHW) + :return Tensor: output tensor (BCHW) + """ + if x.shape[2] > 1: + x = self.conv3d(x) + b, c, d, h, w = x.shape + # project Z/depth into channels + # return a view when possible (contiguous) + return x.reshape(b, c * d, h, w) + return self.conv2d(x.squeeze(2)) + + class MaskedMultiscaleEncoder(nn.Module): def __init__( self, @@ -208,28 +259,52 @@ def __init__( drop_path_rate: float = 0.0, ) -> None: super().__init__() + stem_kernel_size_2d = 4 + self.stem = nn.Sequential( + AdaptiveProjection( + in_channels, dims[0], kernel_size_2d=stem_kernel_size_2d, kernel_depth=5 + ), + LayerNorm2d(dims[0]), + ) self.stages = nn.ModuleList() - chs = [in_channels, *dims] + chs = [dims[0], *dims] for i, num_blocks in enumerate(stage_blocks): + stride = 1 if i == 0 else 2 self.stages.append( MaskedConvNeXtV2Stage( chs[i], chs[i + 1], kernel_size=7, - stride=2, + stride=stride, num_blocks=num_blocks, drop_path_rates=[drop_path_rate] * num_blocks, ) ) + self.total_stride = stem_kernel_size_2d * 2 ** (len(self.stages) - 1) - def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: """ :param Tensor x: input tensor (BCHW) - :param BoolTensor | None mask: boolean mask, defaults to None + :param float mask_ratio: ratio of the feature maps to mask, + defaults to 0.0 (no masking) :return Tensor: output tensor (BCHW) """ + if mask_ratio > 0.0: + noise = torch.rand( + x.shape[0], + 1, + x.shape[-2] // self.total_stride, + x.shape[-1] // self.total_stride, + device=x.device, + ) + mask = noise > mask_ratio + else: + mask = None + x = self.stem(x) features = [] for stage in self.stages: x = stage(x, mask) features.append(x) + if mask is not None: + return features, mask return features From 55087315f6783417acad17500e0fe3b47899b125 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 12 Jan 2024 15:56:23 -0800 Subject: [PATCH 09/74] wip: masked stem layernorm --- tests/unet/test_fcmae.py | 15 +++++++------- viscy/unet/networks/fcmae.py | 38 ++++++++++++++++++++++++++++-------- 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index 73dc5920..b9a3d389 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -1,7 +1,7 @@ import torch from viscy.unet.networks.fcmae import ( - AdaptiveProjection, + MaskedAdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, MaskedGlobalResponseNorm, @@ -47,10 +47,12 @@ def test_masked_convnextv2_stage() -> 
None: def test_adaptive_projection() -> None: - proj = AdaptiveProjection(3, 12, kernel_size_2d=4, kernel_depth=5, in_stack_depth=5) + proj = MaskedAdaptiveProjection( + 3, 12, kernel_size_2d=4, kernel_depth=5, in_stack_depth=5 + ) assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) - proj = AdaptiveProjection( + proj = MaskedAdaptiveProjection( 3, 12, kernel_size_2d=(2, 4), kernel_depth=5, in_stack_depth=15 ) assert proj(torch.rand(2, 3, 15, 6, 8)).shape == (2, 12, 3, 2) @@ -61,11 +63,10 @@ def test_masked_multiscale_encoder() -> None: dims = [12, 24, 48, 96] x = torch.rand(2, 3, 5, xy_size, xy_size) encoder = MaskedMultiscaleEncoder(3, dims=dims) - # auto_masked_features, mask = encoder(x, mask_ratio=0.5) - auto_masked_features = encoder(x) + auto_masked_features, mask = encoder(x, mask_ratio=0.5) target_shape = list(x.shape) target_shape.pop(1) - pre_masked_features = encoder(x) #encoder(x * ~upsample_mask(mask, target_shape).unsqueeze(1)) + pre_masked_features = encoder(x * ~upsample_mask(mask, target_shape).unsqueeze(1)) assert len(auto_masked_features) == len(pre_masked_features) == 4 for i, (dim, afeat, pfeat) in enumerate( zip(dims, auto_masked_features, pre_masked_features) @@ -74,7 +75,7 @@ def test_masked_multiscale_encoder() -> None: assert afeat.shape[1] == dim stride = 2 * 2 ** (i + 1) assert afeat.shape[2] == afeat.shape[3] == xy_size // stride - assert torch.allclose(afeat, pfeat, rtol=1e-1, atol=5e-2), ( + assert torch.allclose(afeat, pfeat, rtol=5e-2, atol=5e-2), ( i, (afeat - pfeat).abs().max(), ) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 71644955..416c50ad 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -9,7 +9,14 @@ from typing import Callable, Literal, Sequence import torch -from timm.layers import DropPath, LayerNorm2d, create_conv2d, trunc_normal_ +from timm.layers import ( + DropPath, + GlobalResponseNormMlp, + LayerNorm2d, + LayerNorm, + create_conv2d, + trunc_normal_, +) from timm.models.convnext import Downsample from torch import BoolTensor, Size, Tensor, nn @@ -198,10 +205,9 @@ def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: return x -class AdaptiveProjection(nn.Module): +class MaskedAdaptiveProjection(nn.Module): """ - Patchifying layer for projecting 2D or 3D input into 2D feature maps. - Masking is not needed because the mask will cover entire patches. + Masked patchifying layer for projecting 2D or 3D input into 2D feature maps. 
:param int in_channels: input channels :param int out_channels: output channels @@ -235,19 +241,35 @@ def __init__( kernel_size=kernel_size_2d, stride=kernel_size_2d, ) + self.norm = nn.LayerNorm(out_channels) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: Tensor, mask: BoolTensor = None) -> Tensor: """ :param Tensor x: input tensor (BCDHW) + :param BoolTensor mask: boolean mask (B1HW), defaults to None :return Tensor: output tensor (BCHW) """ + # no need to mask before convolutions since patches do not spill over if x.shape[2] > 1: x = self.conv3d(x) b, c, d, h, w = x.shape # project Z/depth into channels # return a view when possible (contiguous) - return x.reshape(b, c * d, h, w) - return self.conv2d(x.squeeze(2)) + x = x.reshape(b, c * d, h, w) + else: + x = self.conv2d(x.squeeze(2)) + out_shape = x.shape + if mask is not None: + mask = upsample_mask(mask, x.shape) + x = x[mask] + else: + x = x.flatten(2) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) + if mask is not None: + out = torch.zeros(out_shape, device=x.device) + out[mask] = x class MaskedMultiscaleEncoder(nn.Module): @@ -261,7 +283,7 @@ def __init__( super().__init__() stem_kernel_size_2d = 4 self.stem = nn.Sequential( - AdaptiveProjection( + MaskedAdaptiveProjection( in_channels, dims[0], kernel_size_2d=stem_kernel_size_2d, kernel_depth=5 ), LayerNorm2d(dims[0]), From 3eec48ed78908eb44edf8cd96991da2b79c8cece Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 16 Jan 2024 20:23:32 -0800 Subject: [PATCH 10/74] wip: patchify masked features for linear --- tests/unet/test_fcmae.py | 51 +++++++++++++-- viscy/unet/networks/fcmae.py | 122 +++++++++++++++++++++-------------- 2 files changed, 119 insertions(+), 54 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index b9a3d389..fc534981 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -4,13 +4,52 @@ MaskedAdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, - MaskedGlobalResponseNorm, + # MaskedGlobalResponseNorm, MaskedMultiscaleEncoder, + generate_mask, + masked_patchify, + masked_unpatchify, upsample_mask, ) -def test_masked_grn() -> None: +def test_generate_mask(): + w = 64 + s = 16 + m = 0.75 + mask = generate_mask((2, 3, w, w), stride=s, mask_ratio=m) + assert mask.shape == (2, 1, w // s, w // s) + assert mask.dtype == torch.bool + ratio = mask.sum((2, 3)) / mask.numel() * mask.shape[0] + assert torch.allclose(ratio, torch.ones_like(ratio) * m) + + +def test_masked_patchify(): + b, c, h, w = 2, 3, 4, 8 + x = torch.rand(b, c, h, w) + mask_ratio = 0.75 + mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio) + mask = upsample_mask(mask, x.shape) + feat = masked_patchify(x, mask) + assert feat.shape == (b, int(h * w * (1 - mask_ratio)), c) + + +def test_unmasked_patchify_roundtrip(): + x = torch.rand(2, 3, 4, 8) + y = masked_unpatchify(masked_patchify(x), out_shape=x.shape) + assert torch.allclose(x, y) + + +def test_masked_patchify_roundtrip(): + x = torch.rand(2, 3, 4, 8) + mask = generate_mask(x.shape, stride=2, mask_ratio=0.5) + mask = upsample_mask(mask, x.shape) + y = masked_unpatchify(masked_patchify(x, mask), out_shape=x.shape, mask=mask) + assert torch.all((y == 0) ^ (x == y)) + assert torch.all((y == 0)[:, 0:1] == mask) + + +def test_masked_grn(): x = torch.rand(2, 3, 4, 5) grn = MaskedGlobalResponseNorm(3, channels_last=False) grn.gamma.data = torch.ones_like(grn.gamma.data) @@ -36,7 +75,7 @@ def test_masked_convnextv2_block() -> None: assert 
len(masked_out[:, :, mask].unique()) == x.shape[1] -def test_masked_convnextv2_stage() -> None: +def test_masked_convnextv2_stage(): x = torch.rand(2, 3, 16, 16) mask = torch.rand(4, 4) > 0.5 stage = MaskedConvNeXtV2Stage(3, 3, kernel_size=7, stride=2, num_blocks=2) @@ -46,19 +85,21 @@ def test_masked_convnextv2_stage() -> None: assert not torch.allclose(masked_out, out) -def test_adaptive_projection() -> None: +def test_adaptive_projection(): proj = MaskedAdaptiveProjection( 3, 12, kernel_size_2d=4, kernel_depth=5, in_stack_depth=5 ) assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) + mask = torch.rand(2, 1, 2, 2) > 0.5 + masked_out = proj(torch.rand(2, 3, 5, 16, 16), mask) proj = MaskedAdaptiveProjection( 3, 12, kernel_size_2d=(2, 4), kernel_depth=5, in_stack_depth=15 ) assert proj(torch.rand(2, 3, 15, 6, 8)).shape == (2, 12, 3, 2) -def test_masked_multiscale_encoder() -> None: +def test_masked_multiscale_encoder(): xy_size = 64 dims = [12, 24, 48, 96] x = torch.rand(2, 3, 5, xy_size, xy_size) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 416c50ad..d852f780 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -2,18 +2,18 @@ Fully Convolutional Masked Autoencoder as described in ConvNeXt V2 based on the official JAX example in https://github.com/facebookresearch/ConvNeXt-V2/blob/main/TRAINING.md#implementing-fcmae-with-masked-convolution-in-jax -also referring to timm's dense implementation of the encoder in ``timm.models.convnext`` +and timm's dense implementation of the encoder in ``timm.models.convnext`` """ from typing import Callable, Literal, Sequence import torch +import torch.nn.functional as F from timm.layers import ( DropPath, GlobalResponseNormMlp, LayerNorm2d, - LayerNorm, create_conv2d, trunc_normal_, ) @@ -21,7 +21,27 @@ from torch import BoolTensor, Size, Tensor, nn +def generate_mask(target: Size, stride: int, mask_ratio: float) -> BoolTensor: + """ + :param Size target: target shape + :param int stride: total stride + :param float mask_ratio: ratio of the pixels to mask + :return BoolTensor: boolean mask (N, H*W) + """ + m_height = target[-2] // stride + m_width = target[-1] // stride + mask_numel = m_height * m_width + masked_elements = int(mask_numel * mask_ratio) + mask = torch.rand(target[0], mask_numel).argsort(1) < masked_elements + return mask.reshape(target[0], 1, m_height, m_width) + + def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: + """ + :param BoolTensor mask: low-resolution boolean mask (B1HW) + :param Size target: target size (BCHW) + :return BoolTensor: upsampled boolean mask (B1HW) + """ if target[-2:] != mask.shape[-2:]: if not all(i % j == 0 for i, j in zip(target, mask.shape)): raise ValueError( @@ -34,43 +54,48 @@ def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: return mask -class MaskedGlobalResponseNorm(nn.Module): +def masked_patchify(features: Tensor, mask: BoolTensor | None = None) -> Tensor: """ - Masked Global Response Normalization. 
- - :param int dim: number of input channels - :param float eps: small value added for numerical stability, - defaults to 1e-6 - :param bool channels_last: BHWC (True) or BCHW (False) dimension ordering, - defaults to False + :param Tensor features: input image features (BCHW) + :param BoolTensor mask: boolean mask (B1HW) + :return Tensor: masked channel-last features (BLC, L = H * W * mask_ratio) """ + if mask is None: + return features.flatten(2).permute(0, 2, 1) + b, c = features.shape[:2] + # (B, C, H, W) -> (B, H, W, C) + features = features.permute(0, 2, 3, 1) + # (B, H, W, C) -> (B * L, C) -> (B, L, C) + features = features[~mask[:, 0]].reshape(b, -1, c) - def __init__( - self, dim: int, eps: float = 1e-6, channels_last: bool = False - ) -> None: - super().__init__() - if channels_last: - self.spatial_dim = (1, 2) - self.channel_dim = -1 - weights_shape = (1, 1, 1, dim) - else: - self.spatial_dim = (2, 3) - self.channel_dim = 1 - weights_shape = (1, dim, 1, 1) - self.gamma = nn.Parameter(torch.zeros(weights_shape)) - self.beta = nn.Parameter(torch.zeros(weights_shape)) - self.eps = eps + # kernel_size = tuple(features.shape[-i] // mask.shape[-i] for i in (2, 1)) + # # (B, C, H, W) -> (B, C * H_patch * Wp, H_grid * Wg) + # features = F.unfold(features, kernel_size=kernel_size, stride=kernel_size) + # patch_size = kernel_size[0] * kernel_size[1] + # # (B, C * Hp * Wp, Hg * Wg) -> (B, C, Hp * Wp, Hg * Wg) -> (B, Hg * Wg, C, Hp * Wp) + # features = features.view(b, c, patch_size, -1).permute(0, 3, 1, 2) + # # (B, 1, Hg, Wg) -> (B, Hg*Wg) + # idx = ~mask.flatten(1) + # # (B, Hg * Wg, C, Hp * Wp) -> (B * L, C, Hp * Wp) -> (B, L, C, Hp * Wp) + # features = features[idx].view(b, -1, c, patch_size) + # # (B, L, C, Hp * Wp) -> (B, L, Hp * Wp, C) -> (B, L * Hp * Wp, C) + # features = features.permute(0, 1, 3, 2).reshape(b, -1, c) + return features - def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: - """ - :param Tensor x: input tensor, BHWC or BCHW - :param BoolTensor | None mask: boolean mask, defaults to None - :return Tensor: normalized tensor - """ - samples = x if mask is None else x * ~mask - g_x = samples.norm(p=2, dim=self.spatial_dim, keepdim=True) - n_x = g_x / (g_x.mean(dim=self.channel_dim, keepdim=True) + self.eps) - return x + torch.addcmul(self.beta, self.gamma, x * n_x) + +def masked_unpatchify( + features: Tensor, out_shape: Size, mask: BoolTensor | None = None +) -> Tensor: + if mask is None: + # (B, L, C) -> (B, C, L) -> (B, C, H, W) + return features.permute(0, 2, 1).reshape(out_shape) + b, c, w, h = out_shape + out = torch.zeros((b, w, h, c), device=features.device, dtype=features.dtype) + # (B, L, C) -> (B * L, C) + features = features.reshape(-1, c) + out[~mask[:, 0]] = features + # (B, H, W, C) -> (B, C, H, W) + return out.permute(0, 3, 1, 2) class MaskedConvNeXtV2Block(nn.Module): @@ -102,11 +127,13 @@ def __init__( stride=stride, depthwise=True, ) - self.layernorm = LayerNorm2d(out_channels) - self.pwconv1 = nn.Conv2d(out_channels, mlp_ratio * out_channels, kernel_size=1) - self.act = nn.GELU() - self.grn = MaskedGlobalResponseNorm(mlp_ratio * out_channels) - self.pwconv2 = nn.Conv2d(mlp_ratio * out_channels, out_channels, kernel_size=1) + self.layernorm = nn.LayerNorm(out_channels) + mid_channels = mlp_ratio * out_channels + self.mlp = GlobalResponseNormMlp( + in_features=out_channels, + hidden_features=mid_channels, + out_features=out_channels, + ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() if in_channels != 
out_channels or stride > 1: self.shortcut = Downsample(in_channels, out_channels, stride=stride) @@ -125,6 +152,8 @@ def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: x = self.dwconv(x) if mask is not None: x *= ~mask + out_shape = x.shape + x = masked_project(x, mask) x = self.layernorm(x) x = self.pwconv1(x) x = self.act(x) @@ -268,8 +297,10 @@ def forward(self, x: Tensor, mask: BoolTensor = None) -> Tensor: x = self.norm(x) x = x.permute(0, 2, 1) if mask is not None: - out = torch.zeros(out_shape, device=x.device) + out = torch.zeros(out_shape, device=x.device, dtype=x.dtype) out[mask] = x + return out + return x.reshape(out_shape) class MaskedMultiscaleEncoder(nn.Module): @@ -312,14 +343,7 @@ def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: :return Tensor: output tensor (BCHW) """ if mask_ratio > 0.0: - noise = torch.rand( - x.shape[0], - 1, - x.shape[-2] // self.total_stride, - x.shape[-1] // self.total_stride, - device=x.device, - ) - mask = noise > mask_ratio + mask = generate_mask(x.shape, self.total_stride, mask_ratio) else: mask = None x = self.stem(x) From 8c54febcf71f8074fe6e7c40198ac1a673cf7678 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 16 Jan 2024 21:37:39 -0800 Subject: [PATCH 11/74] use mlp from timm --- tests/unet/test_fcmae.py | 51 ++++++------------- viscy/unet/networks/fcmae.py | 95 +++++++++++++++--------------------- 2 files changed, 55 insertions(+), 91 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index fc534981..ba0d7a24 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -4,7 +4,6 @@ MaskedAdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, - # MaskedGlobalResponseNorm, MaskedMultiscaleEncoder, generate_mask, masked_patchify, @@ -30,7 +29,7 @@ def test_masked_patchify(): mask_ratio = 0.75 mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio) mask = upsample_mask(mask, x.shape) - feat = masked_patchify(x, mask) + feat = masked_patchify(x, ~mask) assert feat.shape == (b, int(h * w * (1 - mask_ratio)), c) @@ -44,40 +43,28 @@ def test_masked_patchify_roundtrip(): x = torch.rand(2, 3, 4, 8) mask = generate_mask(x.shape, stride=2, mask_ratio=0.5) mask = upsample_mask(mask, x.shape) - y = masked_unpatchify(masked_patchify(x, mask), out_shape=x.shape, mask=mask) + y = masked_unpatchify(masked_patchify(x, ~mask), out_shape=x.shape, unmasked=~mask) assert torch.all((y == 0) ^ (x == y)) assert torch.all((y == 0)[:, 0:1] == mask) -def test_masked_grn(): - x = torch.rand(2, 3, 4, 5) - grn = MaskedGlobalResponseNorm(3, channels_last=False) - grn.gamma.data = torch.ones_like(grn.gamma.data) - mask = torch.ones((1, 1, 4, 5), dtype=torch.bool) - mask[:, :, 2:, 2:] = False - normalized = grn(x) - assert not torch.allclose(normalized, x) - assert torch.allclose(grn(x, mask)[:, :, 2:, 2:], grn(x[:, :, 2:, 2:])) - grn = MaskedGlobalResponseNorm(5, channels_last=True) - grn.gamma.data = torch.ones_like(grn.gamma.data) - mask = torch.ones((1, 3, 4, 1), dtype=torch.bool) - mask[:, 1:, 2:, :] = False - assert torch.allclose(grn(x, mask)[:, 1:, 2:, :], grn(x[:, 1:, 2:, :])) - - def test_masked_convnextv2_block() -> None: x = torch.rand(2, 3, 4, 5) - mask = x[0, 0] > 0.5 + mask = generate_mask(x.shape, stride=1, mask_ratio=0.5) block = MaskedConvNeXtV2Block(3, 3 * 2) - assert len(block(x).unique()) == x.numel() * 2 + unmasked_out = block(x) + assert len(unmasked_out.unique()) == x.numel() * 2 + all_unmasked = torch.ones_like(mask) + empty_masked_out = block(x, 
all_unmasked) + assert torch.allclose(unmasked_out, empty_masked_out) block = MaskedConvNeXtV2Block(3, 3) masked_out = block(x, mask) - assert len(masked_out[:, :, mask].unique()) == x.shape[1] + assert len(masked_out.unique()) == mask.sum() * x.shape[1] + 1 def test_masked_convnextv2_stage(): x = torch.rand(2, 3, 16, 16) - mask = torch.rand(4, 4) > 0.5 + mask = generate_mask(x.shape, stride=4, mask_ratio=0.5) stage = MaskedConvNeXtV2Stage(3, 3, kernel_size=7, stride=2, num_blocks=2) out = stage(x) assert out.shape == (2, 3, 8, 8) @@ -91,8 +78,9 @@ def test_adaptive_projection(): ) assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) - mask = torch.rand(2, 1, 2, 2) > 0.5 - masked_out = proj(torch.rand(2, 3, 5, 16, 16), mask) + mask = generate_mask((1, 3, 5, 8, 8), stride=4, mask_ratio=0.6) + masked_out = proj(torch.rand(1, 3, 5, 16, 16), mask) + assert masked_out.shape == (1, 12, 4, 4) proj = MaskedAdaptiveProjection( 3, 12, kernel_size_2d=(2, 4), kernel_depth=5, in_stack_depth=15 ) @@ -104,19 +92,12 @@ def test_masked_multiscale_encoder(): dims = [12, 24, 48, 96] x = torch.rand(2, 3, 5, xy_size, xy_size) encoder = MaskedMultiscaleEncoder(3, dims=dims) - auto_masked_features, mask = encoder(x, mask_ratio=0.5) + auto_masked_features, _ = encoder(x, mask_ratio=0.5) target_shape = list(x.shape) target_shape.pop(1) - pre_masked_features = encoder(x * ~upsample_mask(mask, target_shape).unsqueeze(1)) - assert len(auto_masked_features) == len(pre_masked_features) == 4 - for i, (dim, afeat, pfeat) in enumerate( - zip(dims, auto_masked_features, pre_masked_features) - ): + assert len(auto_masked_features) == 4 + for i, (dim, afeat) in enumerate(zip(dims, auto_masked_features)): assert afeat.shape[0] == x.shape[0] assert afeat.shape[1] == dim stride = 2 * 2 ** (i + 1) assert afeat.shape[2] == afeat.shape[3] == xy_size // stride - assert torch.allclose(afeat, pfeat, rtol=5e-2, atol=5e-2), ( - i, - (afeat - pfeat).abs().max(), - ) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index d852f780..a2e6849e 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -54,46 +54,38 @@ def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: return mask -def masked_patchify(features: Tensor, mask: BoolTensor | None = None) -> Tensor: +def masked_patchify(features: Tensor, unmasked: BoolTensor | None = None) -> Tensor: """ :param Tensor features: input image features (BCHW) - :param BoolTensor mask: boolean mask (B1HW) + :param BoolTensor unmasked: boolean foreground mask (B1HW) :return Tensor: masked channel-last features (BLC, L = H * W * mask_ratio) """ - if mask is None: + if unmasked is None: return features.flatten(2).permute(0, 2, 1) b, c = features.shape[:2] # (B, C, H, W) -> (B, H, W, C) features = features.permute(0, 2, 3, 1) # (B, H, W, C) -> (B * L, C) -> (B, L, C) - features = features[~mask[:, 0]].reshape(b, -1, c) - - # kernel_size = tuple(features.shape[-i] // mask.shape[-i] for i in (2, 1)) - # # (B, C, H, W) -> (B, C * H_patch * Wp, H_grid * Wg) - # features = F.unfold(features, kernel_size=kernel_size, stride=kernel_size) - # patch_size = kernel_size[0] * kernel_size[1] - # # (B, C * Hp * Wp, Hg * Wg) -> (B, C, Hp * Wp, Hg * Wg) -> (B, Hg * Wg, C, Hp * Wp) - # features = features.view(b, c, patch_size, -1).permute(0, 3, 1, 2) - # # (B, 1, Hg, Wg) -> (B, Hg*Wg) - # idx = ~mask.flatten(1) - # # (B, Hg * Wg, C, Hp * Wp) -> (B * L, C, Hp * Wp) -> (B, L, C, Hp * Wp) - # 
features = features[idx].view(b, -1, c, patch_size) - # # (B, L, C, Hp * Wp) -> (B, L, Hp * Wp, C) -> (B, L * Hp * Wp, C) - # features = features.permute(0, 1, 3, 2).reshape(b, -1, c) + features = features[unmasked[:, 0]].reshape(b, -1, c) return features def masked_unpatchify( - features: Tensor, out_shape: Size, mask: BoolTensor | None = None + features: Tensor, out_shape: Size, unmasked: BoolTensor | None = None ) -> Tensor: - if mask is None: - # (B, L, C) -> (B, C, L) -> (B, C, H, W) + """ + :param Tensor features: dense channel-last features (BLC) + :param Size out_shape: output shape (BCHW) + :param BoolTensor | None unmasked: boolean foreground mask, defaults to None + :return Tensor: masked features (BCHW) + """ + if unmasked is None: return features.permute(0, 2, 1).reshape(out_shape) b, c, w, h = out_shape out = torch.zeros((b, w, h, c), device=features.device, dtype=features.dtype) # (B, L, C) -> (B * L, C) features = features.reshape(-1, c) - out[~mask[:, 0]] = features + out[unmasked[:, 0]] = features # (B, H, W, C) -> (B, C, H, W) return out.permute(0, 3, 1, 2) @@ -140,25 +132,23 @@ def __init__( else: self.shortcut = nn.Identity() - def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: """ :param Tensor x: input tensor (BCHW) - :param BoolTensor | None mask: boolean mask, defaults to None + :param BoolTensor | None unmasked: boolean foreground mask, defaults to None :return Tensor: output tensor (BCHW) """ shortcut = self.shortcut(x) - if mask is not None: - x *= ~mask + if unmasked is not None: + x *= unmasked x = self.dwconv(x) - if mask is not None: - x *= ~mask + if unmasked is not None: + x *= unmasked out_shape = x.shape - x = masked_project(x, mask) + x = masked_patchify(x, unmasked=unmasked) x = self.layernorm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.grn(x, mask) - x = self.pwconv2(x) + x = self.mlp(x.unsqueeze(1)).squeeze(1) + x = masked_unpatchify(x, out_shape=out_shape, unmasked=unmasked) x = self.drop_path(x) + shortcut return x @@ -220,17 +210,17 @@ def __init__( ) in_channels = out_channels - def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: """ :param Tensor x: input tensor (BCHW) - :param BoolTensor | None mask: boolean mask, defaults to None + :param BoolTensor | None unmasked: boolean foreground mask, defaults to None :return Tensor: output tensor (BCHW) """ x = self.downsample(x) - if mask is not None: - mask = upsample_mask(mask, x.shape) + if unmasked is not None: + unmasked = upsample_mask(unmasked, x.shape) for block in self.blocks: - x = block(x, mask) + x = block(x, unmasked) return x @@ -272,10 +262,10 @@ def __init__( ) self.norm = nn.LayerNorm(out_channels) - def forward(self, x: Tensor, mask: BoolTensor = None) -> Tensor: + def forward(self, x: Tensor, unmasked: BoolTensor = None) -> Tensor: """ :param Tensor x: input tensor (BCDHW) - :param BoolTensor mask: boolean mask (B1HW), defaults to None + :param BoolTensor unmasked: boolean foreground mask (B1HW), defaults to None :return Tensor: output tensor (BCHW) """ # no need to mask before convolutions since patches do not spill over @@ -288,19 +278,12 @@ def forward(self, x: Tensor, mask: BoolTensor = None) -> Tensor: else: x = self.conv2d(x.squeeze(2)) out_shape = x.shape - if mask is not None: - mask = upsample_mask(mask, x.shape) - x = x[mask] - else: - x = x.flatten(2) - x = x.permute(0, 
2, 1) + if unmasked is not None: + unmasked = upsample_mask(unmasked, x.shape) + x = masked_patchify(x, unmasked=unmasked) x = self.norm(x) - x = x.permute(0, 2, 1) - if mask is not None: - out = torch.zeros(out_shape, device=x.device, dtype=x.dtype) - out[mask] = x - return out - return x.reshape(out_shape) + x = masked_unpatchify(x, out_shape=out_shape, unmasked=unmasked) + return x class MaskedMultiscaleEncoder(nn.Module): @@ -343,14 +326,14 @@ def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: :return Tensor: output tensor (BCHW) """ if mask_ratio > 0.0: - mask = generate_mask(x.shape, self.total_stride, mask_ratio) + unmasked = ~generate_mask(x.shape, self.total_stride, mask_ratio) else: - mask = None + unmasked = None x = self.stem(x) features = [] for stage in self.stages: - x = stage(x, mask) + x = stage(x, unmasked=unmasked) features.append(x) - if mask is not None: - return features, mask + if unmasked is not None: + return features, unmasked return features From 83ecf4a7fcc138fcc9cab7f6b4c1ab6c5ce149a0 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jan 2024 00:14:58 -0800 Subject: [PATCH 12/74] hack: POC training script for FCMAE --- tests/light/test_engine.py | 10 +++ tests/unet/test_fcmae.py | 8 +++ viscy/light/engine.py | 44 +++++++++++++ viscy/scripts/train_fcmae.py | 66 ++++++++++++++++++++ viscy/unet/networks/fcmae.py | 117 +++++++++++++++++++++++++++++------ 5 files changed, 225 insertions(+), 20 deletions(-) create mode 100644 tests/light/test_engine.py create mode 100644 viscy/scripts/train_fcmae.py diff --git a/tests/light/test_engine.py b/tests/light/test_engine.py new file mode 100644 index 00000000..c6013365 --- /dev/null +++ b/tests/light/test_engine.py @@ -0,0 +1,10 @@ +from viscy.light.engine import FcmaeUNet + + +def test_fcmae_vsunet() -> None: + model = FcmaeUNet( + architecture="fcmae", + model_config=dict(in_channels=3), + train_mask_ratio=0.6, + ) + diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index ba0d7a24..870f1138 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -1,6 +1,7 @@ import torch from viscy.unet.networks.fcmae import ( + FullyConvolutionalMAE, MaskedAdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, @@ -101,3 +102,10 @@ def test_masked_multiscale_encoder(): assert afeat.shape[1] == dim stride = 2 * 2 ** (i + 1) assert afeat.shape[2] == afeat.shape[3] == xy_size // stride + + +def test_fcmae(): + x = torch.rand(2, 3, 5, 128, 128) + model = FullyConvolutionalMAE(3) + assert model(x).shape == x.shape + assert model(x, mask_ratio=0.6).shape == x.shape diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 74f14aaa..0262cc7d 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -27,6 +27,7 @@ from viscy.data.hcs import Sample from viscy.evaluation.evaluation_metrics import mean_average_precision, ms_ssim_25d +from viscy.unet.networks.fcmae import FullyConvolutionalMAE from viscy.unet.networks.Unet2D import Unet2d from viscy.unet.networks.Unet21D import Unet21d from viscy.unet.networks.Unet25D import Unet25d @@ -43,6 +44,7 @@ # same class with out_stack_depth > 1 "2.2D": Unet21d, "2.5D": Unet25d, + "fcmae": FullyConvolutionalMAE, } @@ -367,3 +369,45 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" ) + + +class FcmaeUNet(VSUNet): + def __init__(self, train_mask_ratio: float = 0.0, **kwargs): + super().__init__(**kwargs) + self.train_mask_ratio = 
train_mask_ratio + + def forward(self, x, mask_ratio: float = 0.0): + return self.model(x, mask_ratio) + + def training_step(self, batch: Sample, batch_idx: int): + source = batch["source"] + target = batch["target"] + pred, mask = self.forward(source, mask_ratio=self.train_mask_ratio) + loss = F.mse_loss(pred, target, reduction="none") + loss = (loss * mask).sum() / mask.sum() + self.log( + "loss/train", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target, pred)) + ) + return loss + + def validation_step(self, batch: Sample, batch_idx: int): + source = batch["source"] + target = batch["target"] + pred, mask = self.forward(source, mask_ratio=self.train_mask_ratio) + loss = F.mse_loss(pred, target, reduction="none") + loss = (loss.mean(2) * mask).sum() / mask.sum() + self.log("loss/validate", loss, sync_dist=True) + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target, pred)) + ) diff --git a/viscy/scripts/train_fcmae.py b/viscy/scripts/train_fcmae.py new file mode 100644 index 00000000..692bef6d --- /dev/null +++ b/viscy/scripts/train_fcmae.py @@ -0,0 +1,66 @@ +# %% +from lightning.pytorch.loggers import TensorBoardLogger +from torch import set_float32_matmul_precision + +from viscy.data.hcs import HCSDataModule +from viscy.light.engine import FcmaeUNet +from viscy.light.trainer import VSTrainer +from viscy.transforms import ( + RandAdjustContrastd, + RandAffined, + RandGaussianNoised, + RandGaussianSmoothd, + RandScaleIntensityd, + RandWeightedCropd, +) + +# %% +model = FcmaeUNet( + architecture="fcmae", + model_config=dict(in_channels=1), + train_mask_ratio=0.6, +) + +# %% +ch = "reconstructed-labelfree" + +data = HCSDataModule( + data_path="/hpc/projects/comp.micro/virtual_staining/datasets/training/raw-and-reconstructed.zarr", + source_channel=ch, + target_channel=ch, + z_window_size=5, + batch_size=64, + num_workers=12, + architecture="3D", + augmentations=[ + RandWeightedCropd(ch, ch, spatial_size=[-1, 512, 512], num_samples=2), + RandAffined( + ch, + prob=0.5, + rotate_range=[3.14, 0.0, 0.0], + shear_range=[0.0, 0.05, 0.05], + scale_range=[0.2, 0.3, 0.3], + ), + RandAdjustContrastd(ch, prob=0.3, gamma=[0.75, 1.5]), + RandScaleIntensityd(ch, prob=0.3, factors=0.5), + RandGaussianNoised(ch, prob=0.5, mean=0.0, std=5.0), + RandGaussianSmoothd( + ch, prob=0.5, sigma_z=[0.25, 1.5], sigma_y=[0.25, 1.5], sigma_x=[0.25, 1.5] + ), + ], +) + + +# %% +set_float32_matmul_precision("high") + +trainer = VSTrainer( + fast_dev_run=False, + max_epochs=50, + logger=TensorBoardLogger( + save_dir="/hpc/mydata/ziwen.liu/fcmae", version="test_0", log_graph=False + ), +) +trainer.fit(model, data) + +# %% diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index a2e6849e..ad9d9559 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -9,19 +9,36 @@ from typing import Callable, Literal, Sequence import torch -import torch.nn.functional as F -from timm.layers import ( +from timm.models.convnext import ( + Downsample, DropPath, GlobalResponseNormMlp, LayerNorm2d, create_conv2d, trunc_normal_, ) -from timm.models.convnext import Downsample from torch import BoolTensor, Size, Tensor, nn +from viscy.unet.networks.Unet21D import PixelToVoxelHead, Unet2dDecoder, UnsqueezeHead -def generate_mask(target: Size, stride: int, mask_ratio: float) -> 
BoolTensor: + +def _init_weights(module: nn.Module) -> None: + """Initialize weights of the given module.""" + if isinstance(module, nn.Conv2d): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + + +def generate_mask( + target: Size, stride: int, mask_ratio: float, device: str +) -> BoolTensor: """ :param Size target: target shape :param int stride: total stride @@ -32,7 +49,7 @@ def generate_mask(target: Size, stride: int, mask_ratio: float) -> BoolTensor: m_width = target[-1] // stride mask_numel = m_height * m_width masked_elements = int(mask_numel * mask_ratio) - mask = torch.rand(target[0], mask_numel).argsort(1) < masked_elements + mask = torch.rand(target[0], mask_numel, device=device).argsort(1) < masked_elements return mask.reshape(target[0], 1, m_height, m_width) @@ -293,14 +310,16 @@ def __init__( stage_blocks: Sequence[int] = (3, 3, 9, 3), dims: Sequence[int] = (96, 192, 384, 768), drop_path_rate: float = 0.0, + stem_kernel_size: Sequence[int] = (5, 4, 4), + in_stack_depth: int = 5, ) -> None: super().__init__() - stem_kernel_size_2d = 4 - self.stem = nn.Sequential( - MaskedAdaptiveProjection( - in_channels, dims[0], kernel_size_2d=stem_kernel_size_2d, kernel_depth=5 - ), - LayerNorm2d(dims[0]), + self.stem = MaskedAdaptiveProjection( + in_channels, + dims[0], + kernel_size_2d=stem_kernel_size[1:], + kernel_depth=stem_kernel_size[0], + in_stack_depth=in_stack_depth, ) self.stages = nn.ModuleList() chs = [dims[0], *dims] @@ -316,24 +335,82 @@ def __init__( drop_path_rates=[drop_path_rate] * num_blocks, ) ) - self.total_stride = stem_kernel_size_2d * 2 ** (len(self.stages) - 1) + self.total_stride = stem_kernel_size[1] * 2 ** (len(self.stages) - 1) + self.apply(_init_weights) - def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: + def forward(self, x: Tensor, mask_ratio: float = 0.0) -> list[Tensor]: """ - :param Tensor x: input tensor (BCHW) + :param Tensor x: input tensor (BCDHW) :param float mask_ratio: ratio of the feature maps to mask, defaults to 0.0 (no masking) - :return Tensor: output tensor (BCHW) + :return list[Tensor]: output tensors (list of BCHW) + :return BoolTensor | None: boolean foreground mask, None if no masking """ if mask_ratio > 0.0: - unmasked = ~generate_mask(x.shape, self.total_stride, mask_ratio) + mask = generate_mask( + x.shape, self.total_stride, mask_ratio, device=x.device + ) + b, c, d, h, w = x.shape + unmasked = ~mask + mask = upsample_mask(mask, (b, d, h, w)) else: - unmasked = None + mask = unmasked = None x = self.stem(x) features = [] for stage in self.stages: x = stage(x, unmasked=unmasked) features.append(x) - if unmasked is not None: - return features, unmasked - return features + return features, mask + + +class FullyConvolutionalMAE(nn.Module): + def __init__( + self, + in_channels: int, + encoder_blocks: Sequence[int] = [3, 3, 9, 3], + dims: Sequence[int] = [96, 192, 384, 768], + encoder_drop_path_rate: float = 0.0, + head_expansion_ratio: int = 4, + stem_kernel_size: Sequence[int] = (5, 4, 4), + in_stack_depth: int = 5, + ) -> None: + super().__init__() + self.encoder = MaskedMultiscaleEncoder( + in_channels=in_channels, + stage_blocks=encoder_blocks, + dims=dims, + drop_path_rate=encoder_drop_path_rate, + stem_kernel_size=stem_kernel_size, + 
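
As an aside, the argsort trick in `generate_mask` above can be checked in isolation. The following is a self-contained sketch that mirrors the function shown here (it does not import it), plus an illustrative pixel-level expansion standing in for `upsample_mask`:

```python
import torch

def sketch_generate_mask(
    target: torch.Size, stride: int, mask_ratio: float, device: str
) -> torch.BoolTensor:
    # one boolean decision per (stride x stride) patch
    m_height = target[-2] // stride
    m_width = target[-1] // stride
    mask_numel = m_height * m_width
    masked_elements = int(mask_numel * mask_ratio)
    # ranks of i.i.d. uniform noise form a random permutation, so
    # thresholding the ranks masks exactly `masked_elements` patches
    mask = torch.rand(target[0], mask_numel, device=device).argsort(1) < masked_elements
    return mask.reshape(target[0], 1, m_height, m_width)

mask = sketch_generate_mask(
    torch.Size((2, 1, 5, 128, 128)), stride=32, mask_ratio=0.75, device="cpu"
)
assert mask.shape == (2, 1, 4, 4)
assert mask.sum(dim=(1, 2, 3)).tolist() == [12, 12]  # exactly 75% of 16 patches
# expand each patch decision to pixels (illustrative stand-in for upsample_mask)
pixel_mask = mask.repeat_interleave(32, dim=2).repeat_interleave(32, dim=3)
assert pixel_mask.shape == (2, 1, 128, 128)
```
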
in_stack_depth=in_stack_depth, + ) + decoder_channels = list(dims) + decoder_channels.reverse() + decoder_channels[-1] = ( + (in_stack_depth + 2) * in_channels * 2**2 * head_expansion_ratio + ) + self.decoder = Unet2dDecoder( + decoder_channels, + norm_name="instance", + mode="pixelshuffle", + conv_blocks=1, + strides=[2] * (len(dims) - 1) + [stem_kernel_size[-1]], + upsample_pre_conv=None, + ) + if in_stack_depth == 1: + self.head = UnsqueezeHead() + else: + self.head = PixelToVoxelHead( + in_channels=decoder_channels[-1], + out_channels=in_channels, + out_stack_depth=in_stack_depth, + expansion_ratio=head_expansion_ratio, + pool=True, + ) + self.out_stack_depth = in_stack_depth + self.num_blocks = 6 + + def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: + x, mask = self.encoder(x, mask_ratio=mask_ratio) + x.reverse() + x = self.decoder(x) + return self.head(x), mask From 2fffc9928ae6499d6a4850b59f7cfdd1f6994fe5 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jan 2024 10:25:08 -0800 Subject: [PATCH 13/74] fix mask for fitting --- tests/unet/test_fcmae.py | 8 ++++++-- viscy/light/engine.py | 24 ++++++++++++------------ viscy/scripts/train_fcmae.py | 23 +++++++++++++++++------ viscy/unet/networks/fcmae.py | 6 +++--- 4 files changed, 38 insertions(+), 23 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index 870f1138..36fb673e 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -107,5 +107,9 @@ def test_masked_multiscale_encoder(): def test_fcmae(): x = torch.rand(2, 3, 5, 128, 128) model = FullyConvolutionalMAE(3) - assert model(x).shape == x.shape - assert model(x, mask_ratio=0.6).shape == x.shape + y, m = model(x) + assert y.shape == x.shape + assert m is None + y, m = model(x, mask_ratio=0.6) + assert y.shape == x.shape + assert m.shape == (2, 1, 128, 128) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 0262cc7d..85254077 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -372,19 +372,23 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): class FcmaeUNet(VSUNet): - def __init__(self, train_mask_ratio: float = 0.0, **kwargs): + def __init__(self, fit_mask_ratio: float = 0.0, **kwargs): super().__init__(**kwargs) - self.train_mask_ratio = train_mask_ratio + self.fit_mask_ratio = fit_mask_ratio def forward(self, x, mask_ratio: float = 0.0): return self.model(x, mask_ratio) - def training_step(self, batch: Sample, batch_idx: int): + def forward_fit(self, batch: Sample): source = batch["source"] target = batch["target"] - pred, mask = self.forward(source, mask_ratio=self.train_mask_ratio) + pred, mask = self.forward(source, mask_ratio=self.fit_mask_ratio) loss = F.mse_loss(pred, target, reduction="none") - loss = (loss * mask).sum() / mask.sum() + loss = (loss.mean(2) * mask).sum() / mask.sum() + return source, target, pred, mask, loss + + def training_step(self, batch: Sample, batch_idx: int): + source, target, pred, mask, loss = self.forward_fit(batch) self.log( "loss/train", loss, @@ -396,18 +400,14 @@ def training_step(self, batch: Sample, batch_idx: int): ) if batch_idx < self.log_batches_per_epoch: self.training_step_outputs.extend( - self._detach_sample((source, target, pred)) + self._detach_sample((source, target * mask.unsqueeze(2), pred)) ) return loss def validation_step(self, batch: Sample, batch_idx: int): - source = batch["source"] - target = batch["target"] - pred, mask = self.forward(source, mask_ratio=self.train_mask_ratio) - loss = F.mse_loss(pred, 
target, reduction="none") - loss = (loss.mean(2) * mask).sum() / mask.sum() + source, target, pred, mask, loss = self.forward_fit(batch) self.log("loss/validate", loss, sync_dist=True) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( - self._detach_sample((source, target, pred)) + self._detach_sample((source, target * mask.unsqueeze(2), pred)) ) diff --git a/viscy/scripts/train_fcmae.py b/viscy/scripts/train_fcmae.py index 692bef6d..0c098454 100644 --- a/viscy/scripts/train_fcmae.py +++ b/viscy/scripts/train_fcmae.py @@ -1,4 +1,5 @@ # %% +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger from torch import set_float32_matmul_precision @@ -17,8 +18,11 @@ # %% model = FcmaeUNet( architecture="fcmae", - model_config=dict(in_channels=1), - train_mask_ratio=0.6, + model_config=dict( + in_channels=1, encoder_blocks=[3, 3, 27, 3], dims=[128, 256, 512, 1024] + ), + fit_mask_ratio=0.6, + schedule="WarmupCosine", ) # %% @@ -32,8 +36,10 @@ batch_size=64, num_workers=12, architecture="3D", + yx_patch_size=[384, 384], + normalize_source=True, augmentations=[ - RandWeightedCropd(ch, ch, spatial_size=[-1, 512, 512], num_samples=2), + RandWeightedCropd(ch, ch, spatial_size=[-1, 768, 768], num_samples=2), RandAffined( ch, prob=0.5, @@ -55,11 +61,16 @@ set_float32_matmul_precision("high") trainer = VSTrainer( - fast_dev_run=False, - max_epochs=50, + fast_dev_run=True, + precision="16-mixed", + max_epochs=100, logger=TensorBoardLogger( - save_dir="/hpc/mydata/ziwen.liu/fcmae", version="test_0", log_graph=False + save_dir="/hpc/mydata/ziwen.liu/fcmae", version="test_1", log_graph=False ), + callbacks=[ + LearningRateMonitor(logging_interval="step"), + ModelCheckpoint(monitor="loss/validate", save_top_k=5, every_n_epochs=1), + ], ) trainer.fit(model, data) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index ad9d9559..7f69cf8f 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -6,7 +6,7 @@ """ -from typing import Callable, Literal, Sequence +from typing import Sequence import torch from timm.models.convnext import ( @@ -43,7 +43,7 @@ def generate_mask( :param Size target: target shape :param int stride: total stride :param float mask_ratio: ratio of the pixels to mask - :return BoolTensor: boolean mask (N, H*W) + :return BoolTensor: boolean mask (B1HW) """ m_height = target[-2] // stride m_width = target[-1] // stride @@ -352,7 +352,7 @@ def forward(self, x: Tensor, mask_ratio: float = 0.0) -> list[Tensor]: ) b, c, d, h, w = x.shape unmasked = ~mask - mask = upsample_mask(mask, (b, d, h, w)) + mask = upsample_mask(mask, (b, 1, h, w)) else: mask = unmasked = None x = self.stem(x) From 2a598b28a38acc14fa185026936b757b7695acc9 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jan 2024 10:29:58 -0800 Subject: [PATCH 14/74] remove training script --- viscy/scripts/train_fcmae.py | 77 ------------------------------------ 1 file changed, 77 deletions(-) delete mode 100644 viscy/scripts/train_fcmae.py diff --git a/viscy/scripts/train_fcmae.py b/viscy/scripts/train_fcmae.py deleted file mode 100644 index 0c098454..00000000 --- a/viscy/scripts/train_fcmae.py +++ /dev/null @@ -1,77 +0,0 @@ -# %% -from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint -from lightning.pytorch.loggers import TensorBoardLogger -from torch import set_float32_matmul_precision - -from viscy.data.hcs import HCSDataModule -from viscy.light.engine import 
FcmaeUNet -from viscy.light.trainer import VSTrainer -from viscy.transforms import ( - RandAdjustContrastd, - RandAffined, - RandGaussianNoised, - RandGaussianSmoothd, - RandScaleIntensityd, - RandWeightedCropd, -) - -# %% -model = FcmaeUNet( - architecture="fcmae", - model_config=dict( - in_channels=1, encoder_blocks=[3, 3, 27, 3], dims=[128, 256, 512, 1024] - ), - fit_mask_ratio=0.6, - schedule="WarmupCosine", -) - -# %% -ch = "reconstructed-labelfree" - -data = HCSDataModule( - data_path="/hpc/projects/comp.micro/virtual_staining/datasets/training/raw-and-reconstructed.zarr", - source_channel=ch, - target_channel=ch, - z_window_size=5, - batch_size=64, - num_workers=12, - architecture="3D", - yx_patch_size=[384, 384], - normalize_source=True, - augmentations=[ - RandWeightedCropd(ch, ch, spatial_size=[-1, 768, 768], num_samples=2), - RandAffined( - ch, - prob=0.5, - rotate_range=[3.14, 0.0, 0.0], - shear_range=[0.0, 0.05, 0.05], - scale_range=[0.2, 0.3, 0.3], - ), - RandAdjustContrastd(ch, prob=0.3, gamma=[0.75, 1.5]), - RandScaleIntensityd(ch, prob=0.3, factors=0.5), - RandGaussianNoised(ch, prob=0.5, mean=0.0, std=5.0), - RandGaussianSmoothd( - ch, prob=0.5, sigma_z=[0.25, 1.5], sigma_y=[0.25, 1.5], sigma_x=[0.25, 1.5] - ), - ], -) - - -# %% -set_float32_matmul_precision("high") - -trainer = VSTrainer( - fast_dev_run=True, - precision="16-mixed", - max_epochs=100, - logger=TensorBoardLogger( - save_dir="/hpc/mydata/ziwen.liu/fcmae", version="test_1", log_graph=False - ), - callbacks=[ - LearningRateMonitor(logging_interval="step"), - ModelCheckpoint(monitor="loss/validate", save_top_k=5, every_n_epochs=1), - ], -) -trainer.fit(model, data) - -# %% From b9b188067221c8b156627cf537c7e2496510ec67 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jan 2024 14:11:54 -0800 Subject: [PATCH 15/74] default architecture --- viscy/light/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 85254077..e1f699eb 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -373,7 +373,7 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): class FcmaeUNet(VSUNet): def __init__(self, fit_mask_ratio: float = 0.0, **kwargs): - super().__init__(**kwargs) + super().__init__(architecture="fcmae", **kwargs) self.fit_mask_ratio = fit_mask_ratio def forward(self, x, mask_ratio: float = 0.0): From fd7700d0ea70339f467c0c431eca4f0c78201f5b Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 22 Jan 2024 15:04:03 -0800 Subject: [PATCH 16/74] fine-tuning options --- viscy/unet/networks/fcmae.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 7f69cf8f..0799d8fb 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -367,12 +367,15 @@ class FullyConvolutionalMAE(nn.Module): def __init__( self, in_channels: int, + out_channels: int, encoder_blocks: Sequence[int] = [3, 3, 9, 3], dims: Sequence[int] = [96, 192, 384, 768], encoder_drop_path_rate: float = 0.0, head_expansion_ratio: int = 4, stem_kernel_size: Sequence[int] = (5, 4, 4), in_stack_depth: int = 5, + decoder_conv_blocks: int = 1, + pretraining: bool = True, ) -> None: super().__init__() self.encoder = MaskedMultiscaleEncoder( @@ -392,7 +395,7 @@ def __init__( decoder_channels, norm_name="instance", mode="pixelshuffle", - conv_blocks=1, + conv_blocks=decoder_conv_blocks, strides=[2] * (len(dims) - 1) + [stem_kernel_size[-1]], 
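
As a quick sanity check of the channel and stride bookkeeping in this constructor, here is the same arithmetic in plain Python (values are the defaults from this file; no viscy imports needed):

```python
dims = [96, 192, 384, 768]
in_channels, in_stack_depth = 1, 5
head_expansion_ratio = 4
stem_kernel_size = (5, 4, 4)

# decoder runs the encoder dims in reverse; its last stage must emit enough
# channels for the pixel-shuffle head: (D + 2) slices x 2x2 subpixels x expansion
decoder_channels = list(reversed(dims))
decoder_channels[-1] = (in_stack_depth + 2) * in_channels * 2**2 * head_expansion_ratio
assert decoder_channels == [768, 384, 192, 112]

# one 2x upsampling between stages, then the stem's 4x stride at the end
strides = [2] * (len(dims) - 1) + [stem_kernel_size[-1]]
assert strides == [2, 2, 2, 4]
```
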
upsample_pre_conv=None, ) @@ -401,16 +404,20 @@ def __init__( else: self.head = PixelToVoxelHead( in_channels=decoder_channels[-1], - out_channels=in_channels, + out_channels=out_channels, out_stack_depth=in_stack_depth, expansion_ratio=head_expansion_ratio, pool=True, ) self.out_stack_depth = in_stack_depth self.num_blocks = 6 + self.pretraining = pretraining def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: x, mask = self.encoder(x, mask_ratio=mask_ratio) x.reverse() x = self.decoder(x) - return self.head(x), mask + x = self.head(x) + if self.pretraining: + return x, mask + return x From 054249f14e7dac4e4040edf53d55232831ef3fe6 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 24 Jan 2024 14:12:19 -0800 Subject: [PATCH 17/74] fix cli for finetuning --- viscy/data/hcs.py | 4 ++-- viscy/light/engine.py | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 01191db1..f8bb6a22 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -334,7 +334,7 @@ def __init__( split_ratio: float = 0.8, batch_size: int = 16, num_workers: int = 8, - architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D"] = "2.5D", + architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D", yx_patch_size: tuple[int, int] = (256, 256), augmentations: Optional[list[MapTransform]] = None, caching: bool = False, @@ -348,7 +348,7 @@ def __init__( self.target_channel = _ensure_channel_list(target_channel) self.batch_size = batch_size self.num_workers = num_workers - self.target_2d = False if architecture in ["2.2D", "3D"] else True + self.target_2d = False if architecture in ["2.2D", "3D", "fcmae"] else True self.z_window_size = z_window_size self.split_ratio = split_ratio self.yx_patch_size = yx_patch_size diff --git a/viscy/light/engine.py b/viscy/light/engine.py index e1f699eb..e6a2dfa4 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -118,11 +118,12 @@ class VSUNet(LightningModule): def __init__( self, - architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D"], + architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"], model_config: dict = {}, loss_function: Union[nn.Module, MixedLoss] = None, lr: float = 1e-3, schedule: Literal["WarmupCosine", "Constant"] = "Constant", + freeze_encoder: bool = False, log_batches_per_epoch: int = 8, log_samples_per_batch: int = 1, example_input_yx_shape: Sequence[int] = (256, 256), @@ -162,6 +163,7 @@ def __init__( self.test_cellpose_model_path = test_cellpose_model_path self.test_cellpose_diameter = test_cellpose_diameter self.test_evaluate_cellpose = test_evaluate_cellpose + self.freeze_encoder = freeze_encoder def forward(self, x) -> torch.Tensor: return self.model(x) @@ -331,6 +333,9 @@ def on_predict_start(self): self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) def configure_optimizers(self): + if self.freeze_encoder: + self.model: FullyConvolutionalMAE + self.model.encoder.requires_grad_(False) optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr) if self.schedule == "WarmupCosine": scheduler = WarmupCosineSchedule( From d867e101b3e006ed9dc819722280b2bae8ea5560 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 24 Jan 2024 14:56:10 -0800 Subject: [PATCH 18/74] draft combined data module --- viscy/data/combined.py | 62 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 viscy/data/combined.py diff --git a/viscy/data/combined.py b/viscy/data/combined.py new file mode 100644 index 
00000000..6b8dd63c --- /dev/null +++ b/viscy/data/combined.py @@ -0,0 +1,62 @@ +from typing import Literal, Sequence + +from lightning.pytorch import LightningDataModule +from lightning.pytorch.utilities import combined_loader + +_MODES = Literal["min_size", "max_size_cycle", "max_size", "sequential"] + + +class CombinedDataModule(LightningDataModule): + """Wrapper for combining multiple data modules. + For supported modes, see ``lightning.pytorch.utilities.combined_loader``. + + :param Sequence[LightningDataModule] data_modules: data modules to combine + :param str train_mode: mode in training stage, defaults to "max_size_cycle" + :param str val_mode: mode in validation stage, defaults to "sequential" + :param str test_mode: mode in testing stage, defaults to "sequential" + :param str predict_mode: mode in prediction stage, defaults to "sequential" + """ + + def __init__( + self, + data_modules: Sequence[LightningDataModule], + train_mode: _MODES = "max_size_cycle", + val_mode: _MODES = "sequential", + test_mode: _MODES = "sequential", + predict_mode: _MODES = "sequential", + ): + super().__init__() + self.data_modules = data_modules + self.train_mode = train_mode + self.val_mode = val_mode + self.test_mode = test_mode + self.predict_mode = predict_mode + + def prepare_data(self): + for dm in self.data_modules: + dm.prepare_data() + + def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + for dm in self.data_modules: + dm.setup(stage) + + def train_dataloader(self): + return combined_loader( + [dm.train_dataloader() for dm in self.data_modules], mode=self.train_mode + ) + + def val_dataloader(self): + return combined_loader( + [dm.val_dataloader() for dm in self.data_modules], mode=self.val_mode + ) + + def test_dataloader(self): + return combined_loader( + [dm.test_dataloader() for dm in self.data_modules], mode=self.test_mode + ) + + def predict_dataloader(self): + return combined_loader( + [dm.predict_dataloader() for dm in self.data_modules], + mode=self.predict_mode, + ) From b06a30077c83402b71390de160f1ce404ca98240 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 25 Jan 2024 15:52:42 -0800 Subject: [PATCH 19/74] fix import --- viscy/data/combined.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 6b8dd63c..5da700dd 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -1,7 +1,7 @@ from typing import Literal, Sequence from lightning.pytorch import LightningDataModule -from lightning.pytorch.utilities import combined_loader +from lightning.pytorch.utilities.combined_loader import CombinedLoader _MODES = Literal["min_size", "max_size_cycle", "max_size", "sequential"] @@ -41,22 +41,22 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]): dm.setup(stage) def train_dataloader(self): - return combined_loader( + return CombinedLoader( [dm.train_dataloader() for dm in self.data_modules], mode=self.train_mode ) def val_dataloader(self): - return combined_loader( + return CombinedLoader( [dm.val_dataloader() for dm in self.data_modules], mode=self.val_mode ) def test_dataloader(self): - return combined_loader( + return CombinedLoader( [dm.test_dataloader() for dm in self.data_modules], mode=self.test_mode ) def predict_dataloader(self): - return combined_loader( + return CombinedLoader( [dm.predict_dataloader() for dm in self.data_modules], mode=self.predict_mode, ) From 39eafab77f97a046a20c5bc4944bf9f24dc11ca1 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: 
Fri, 26 Jan 2024 21:35:49 -0800 Subject: [PATCH 20/74] manual validation loss reduction --- viscy/light/engine.py | 48 ++++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index e6a2dfa4..ebd2fd60 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -10,7 +10,7 @@ from monai.optimizers import WarmupCosineSchedule from monai.transforms import DivisiblePad from skimage.exposure import rescale_intensity -from torch import nn +from torch import Tensor, nn from torch.nn import functional as F from torch.optim.lr_scheduler import ConstantLR from torchmetrics.functional import ( @@ -165,7 +165,7 @@ def __init__( self.test_evaluate_cellpose = test_evaluate_cellpose self.freeze_encoder = freeze_encoder - def forward(self, x) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: return self.model(x) def training_step(self, batch: Sample, batch_idx: int): @@ -230,7 +230,7 @@ def test_step(self, batch: Sample, batch_idx: int): else: self._log_segmentation_metrics(None, None) - def _log_regression_metrics(self, pred: torch.Tensor, target: torch.Tensor): + def _log_regression_metrics(self, pred: Tensor, target: Tensor): # paired image translation metrics self.log_dict( { @@ -253,7 +253,7 @@ def _log_regression_metrics(self, pred: torch.Tensor, target: torch.Tensor): on_epoch=True, ) - def _cellpose_predict(self, pred: torch.Tensor, name: str) -> torch.ShortTensor: + def _cellpose_predict(self, pred: Tensor, name: str) -> torch.ShortTensor: pred_labels_np = self.cellpose_model.eval( pred.cpu().numpy(), channels=[0, 0], diameter=self.test_cellpose_diameter )[0].astype(np.int16) @@ -350,7 +350,7 @@ def configure_optimizers(self): ) return [optimizer], [scheduler] - def _detach_sample(self, imgs: Sequence[torch.Tensor]): + def _detach_sample(self, imgs: Sequence[Tensor]): num_samples = min(imgs[0].shape[0], self.log_samples_per_batch) return [ [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] @@ -380,11 +380,12 @@ class FcmaeUNet(VSUNet): def __init__(self, fit_mask_ratio: float = 0.0, **kwargs): super().__init__(architecture="fcmae", **kwargs) self.fit_mask_ratio = fit_mask_ratio + self.validation_losses = [] - def forward(self, x, mask_ratio: float = 0.0): + def forward(self, x: Tensor, mask_ratio: float = 0.0): return self.model(x, mask_ratio) - def forward_fit(self, batch: Sample): + def forward_fit(self, batch: Sample) -> tuple[Tensor]: source = batch["source"] target = batch["target"] pred, mask = self.forward(source, mask_ratio=self.fit_mask_ratio) @@ -392,27 +393,40 @@ def forward_fit(self, batch: Sample): loss = (loss.mean(2) * mask).sum() / mask.sum() return source, target, pred, mask, loss - def training_step(self, batch: Sample, batch_idx: int): - source, target, pred, mask, loss = self.forward_fit(batch) + def training_step(self, batch: Sequence[Sample], batch_idx: int): + losses = [] + batch_size = 0 + for b in batch: + source, target, pred, mask, loss = self.forward_fit(b) + losses.append(loss) + batch_size += source.shape[0] + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target * mask.unsqueeze(2), pred)) + ) + loss_step = torch.stack(losses).mean() self.log( "loss/train", - loss, + loss_step, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, + batch_size=batch_size, ) - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - 
self._detach_sample((source, target * mask.unsqueeze(2), pred)) - ) - return loss + return loss_step - def validation_step(self, batch: Sample, batch_idx: int): + def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source, target, pred, mask, loss = self.forward_fit(batch) - self.log("loss/validate", loss, sync_dist=True) + self.validation_losses.append(loss.detach()) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( self._detach_sample((source, target * mask.unsqueeze(2), pred)) ) + + def on_validation_epoch_end(self): + super().on_validation_epoch_end() + self.log( + "loss/validate", torch.stack(self.validation_losses).mean(), sync_dist=True + ) From 9fbf7a551e0613e0173d7de05ba6f9dfd911d709 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 2 Feb 2024 09:55:29 -0800 Subject: [PATCH 21/74] update linting new black version has different rules --- pyproject.toml | 17 +++++++++++------ viscy/evaluation/evaluation_metrics.py | 1 + viscy/light/engine.py | 26 +++++++++++++------------- viscy/preprocessing/generate_masks.py | 1 + viscy/unet/networks/fcmae.py | 1 - viscy/utils/image_utils.py | 4 +--- viscy/utils/normalize.py | 1 + 7 files changed, 28 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b60cd534..67142b4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,15 @@ metrics = [ "ptflops>=0.7", ] visual = ["ipykernel", "graphviz", "torchview"] -dev = ["pytest", "pytest-cov", "hypothesis", "profilehooks", "onnxruntime"] +dev = [ + "pytest", + "pytest-cov", + "hypothesis", + "ruff", + "black", + "profilehooks", + "onnxruntime", +] [project.scripts] viscy = "viscy.cli.cli:main" @@ -39,12 +47,9 @@ viscy = "viscy.cli.cli:main" write_to = "viscy/_version.py" [tool.black] -src = ["viscy"] line-length = 88 [tool.ruff] src = ["viscy", "tests"] -extend-select = ["I001"] - -[tool.ruff.isort] -known-first-party = ["viscy"] +lint.extend-select = ["I001"] +lint.isort.known-first-party = ["viscy"] diff --git a/viscy/evaluation/evaluation_metrics.py b/viscy/evaluation/evaluation_metrics.py index 589370bd..fb83c06b 100644 --- a/viscy/evaluation/evaluation_metrics.py +++ b/viscy/evaluation/evaluation_metrics.py @@ -1,4 +1,5 @@ """Metrics for model evaluation""" + from typing import Sequence, Union from warnings import warn diff --git a/viscy/light/engine.py b/viscy/light/engine.py index ebd2fd60..c6197a9a 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -272,19 +272,19 @@ def _log_segmentation_metrics( self.log_dict( { # semantic segmentation - "test_metrics/accuracy": accuracy( - pred_binary, target_binary, task="binary" - ) - if compute - else -1, - "test_metrics/dice": dice(pred_binary, target_binary) - if compute - else -1, - "test_metrics/jaccard": jaccard_index( - pred_binary, target_binary, task="binary" - ) - if compute - else -1, + "test_metrics/accuracy": ( + accuracy(pred_binary, target_binary, task="binary") + if compute + else -1 + ), + "test_metrics/dice": ( + dice(pred_binary, target_binary) if compute else -1 + ), + "test_metrics/jaccard": ( + jaccard_index(pred_binary, target_binary, task="binary") + if compute + else -1 + ), "test_metrics/mAP": coco_metrics["map"] if compute else -1, "test_metrics/mAP_50": coco_metrics["map_50"] if compute else -1, "test_metrics/mAP_75": coco_metrics["map_75"] if compute else -1, diff --git a/viscy/preprocessing/generate_masks.py b/viscy/preprocessing/generate_masks.py index f88f8fbe..491bc406 100644 --- 
a/viscy/preprocessing/generate_masks.py +++ b/viscy/preprocessing/generate_masks.py @@ -1,4 +1,5 @@ """Generate masks from sum of flurophore channels""" + import iohub.ngff as ngff import viscy.utils.aux_utils as aux_utils diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 0799d8fb..97771365 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -5,7 +5,6 @@ and timm's dense implementation of the encoder in ``timm.models.convnext`` """ - from typing import Sequence import torch diff --git a/viscy/utils/image_utils.py b/viscy/utils/image_utils.py index f9020dc9..a9569116 100644 --- a/viscy/utils/image_utils.py +++ b/viscy/utils/image_utils.py @@ -21,9 +21,7 @@ def im_bit_convert(im, bit=16, norm=False, limit=[]): / (limit[1] - limit[0] + sys.float_info.epsilon) * (2**bit - 1) ) - im = np.clip( - im, 0, 2**bit - 1 - ) # clip the values to avoid wrap-around by np.astype + im = np.clip(im, 0, 2**bit - 1) # clip the values to avoid wrap-around by np.astype if bit == 8: im = im.astype(np.uint8, copy=False) # convert to 8 bit else: diff --git a/viscy/utils/normalize.py b/viscy/utils/normalize.py index 93c11713..73753acb 100644 --- a/viscy/utils/normalize.py +++ b/viscy/utils/normalize.py @@ -1,4 +1,5 @@ """Image normalization related functions""" + import sys import numpy as np From e00f5f3bd0415c1c3b8db924bd600cd2354e4cb7 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 2 Feb 2024 10:01:36 -0800 Subject: [PATCH 22/74] update development guide --- CONTRIBUTING.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3b40b075..44db5bbc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -10,7 +10,19 @@ then make an editable installation with all the optional dependencies: pip install -e ".[dev,visual,metrics]" ``` -## Testing +## CI requirements + +Lint with Ruff: + +```sh +ruff check viscy +``` + +Format the code with Black: + +```sh +black viscy +``` Run tests with `pytest`: From 9e345b6c3b59a70a3b7c0bcde8bce184e46c3833 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 13 Feb 2024 15:27:26 -0800 Subject: [PATCH 23/74] update type hints --- viscy/data/hcs.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index f8bb6a22..218ea414 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -23,10 +23,11 @@ MultiSampleTrait, RandAffined, ) +from torch import Tensor from torch.utils.data import DataLoader, Dataset -def _ensure_channel_list(str_or_seq: Union[str, Sequence[str]]): +def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]: """ Ensure channel argument is a list of strings. 
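
These `str | Sequence[str]` unions use PEP 604 syntax, consistent with the `requires-python = ">=3.10"` constraint in `pyproject.toml`. A minimal sketch of the behavior the annotation describes (an illustration, not the library implementation):

```python
from typing import Sequence

def ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]:
    # a bare string is promoted to a one-element list;
    # any other sequence is copied into a plain list
    if isinstance(str_or_seq, str):
        return [str_or_seq]
    return list(str_or_seq)

assert ensure_channel_list("Phase3D") == ["Phase3D"]
assert ensure_channel_list(("Nuclei", "Membrane")) == ["Nuclei", "Membrane"]
```
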
@@ -67,9 +68,9 @@ class Sample(TypedDict, total=False): index: tuple[str, int, int] # optional - source: Union[torch.Tensor, Sequence[torch.Tensor]] - target: Union[torch.Tensor, Sequence[torch.Tensor]] - labels: Union[torch.Tensor, Sequence[torch.Tensor]] + source: Union[Tensor, Sequence[Tensor]] + target: Union[Tensor, Sequence[Tensor]] + labels: Union[Tensor, Sequence[Tensor]] def _collate_samples(batch: Sequence[Sample]) -> Sample: @@ -83,7 +84,7 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: elemment = batch[0] collated = {} for key in elemment.keys(): - data: list[list[torch.Tensor]] = [sample[key] for sample in batch] + data: list[list[Tensor]] = [sample[key] for sample in batch] collated[key] = collate_meta_tensor([im for imgs in data for im in imgs]) return collated @@ -108,13 +109,13 @@ def _stat(self, key: str) -> dict: # FIXME: hard-coded key return self.norm_meta[key]["dataset_statistics"] - def __call__(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + def __call__(self, data: dict[str, Tensor]) -> dict[str, Tensor]: d = dict(data) for key in self.keys: d[key] = (d[key] - self._stat(key)["median"]) / self._stat(key)["iqr"] return d - def inverse(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + def inverse(self, data: dict[str, Tensor]) -> dict[str, Tensor]: d = dict(data) for key in self.keys: d[key] = (d[key] * self._stat(key)["iqr"]) + self._stat(key)["median"] @@ -128,7 +129,7 @@ class SlidingWindowDataset(Dataset): :param ChannelMap channels: source and target channel names, e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] transform: + :param Callable[[dict[str, Tensor]], dict[str, Tensor]] transform: a callable that transforms data, defaults to None """ @@ -137,7 +138,7 @@ def __init__( positions: list[Position], channels: ChannelMap, z_window_size: int, - transform: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, + transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] = None, ) -> None: super().__init__() self.positions = positions @@ -178,14 +179,14 @@ def _find_window(self, index: int) -> tuple[int, int]: def _read_img_window( self, img: ImageArray, ch_idx: list[str], tz: int - ) -> tuple[tuple[torch.Tensor], tuple[str, int, int]]: + ) -> tuple[tuple[Tensor], tuple[str, int, int]]: """Read image window as tensor. :param ImageArray img: NGFF image array :param list[int] channels: list of channel indices to read, output channel ordering will reflect the sequence :param int tz: window index within the FOV, counted Z-first - :return tuple[torch.Tensor], tuple[str, int, int]: + :return tuple[Tensor], tuple[str, int, int]: tuple of (C=1, Z, Y, X) image tensors, tuple of image name, time index, and Z index """ @@ -203,8 +204,8 @@ def __len__(self) -> int: return self._max_window def _stack_channels( - self, sample_images: list[dict[str, torch.Tensor]], key: str - ) -> torch.Tensor: + self, sample_images: list[dict[str, Tensor]], key: str + ) -> Tensor: """Stack single-channel images into a multi-channel tensor.""" if not isinstance(sample_images, list): return torch.stack([sample_images[ch][0] for ch in self.channels[key]]) @@ -258,7 +259,7 @@ class MaskTestDataset(SlidingWindowDataset): :param ChannelMap channels: source and target channel names, e.g. 
``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] transform: + :param Callable[[dict[str, Tensor]], dict[str, Tensor]] transform: a callable that transforms data, defaults to None """ @@ -267,7 +268,7 @@ def __init__( positions: list[Position], channels: ChannelMap, z_window_size: int, - transform: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, + transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] = None, ground_truth_masks: str = None, ) -> None: super().__init__(positions, channels, z_window_size, transform) @@ -527,7 +528,7 @@ def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample if self.trainer: if self.trainer.predicting: predicting = True - if predicting or isinstance(batch, torch.Tensor): + if predicting or isinstance(batch, Tensor): # skipping example input array return batch if self.target_2d: From 96deca5f0020fb9a99388f37d0640e656253145e Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 20 Feb 2024 14:44:02 -0800 Subject: [PATCH 24/74] bump iohub --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 67142b4f..8f6978de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ requires-python = ">=3.10" license = { file = "LICENSE" } authors = [{ name = "CZ Biohub SF", email = "compmicro@czbiohub.org" }] dependencies = [ - "iohub==0.1.0rc0", + "iohub==0.1.0", "torch>=2.1.2", "timm>=0.9.5", "tensorboard>=2.13.0", From e06aa574634dd504755dd21e40346071ea7a6b00 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 23 Feb 2024 21:42:29 -0800 Subject: [PATCH 25/74] draft ctmc v1 dataset --- viscy/data/ctmc_v1.py | 67 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 viscy/data/ctmc_v1.py diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py new file mode 100644 index 00000000..8c42f85d --- /dev/null +++ b/viscy/data/ctmc_v1.py @@ -0,0 +1,67 @@ +import logging +from pathlib import Path + +import numpy as np +from iohub.ngff import ImageArray, Plate, Position, TransformationMeta, open_ome_zarr +from lightning.pytorch import LightningDataModule +from monai.transforms import Compose, MapTransform +from torch import Tensor +from torch.utils.data import DataLoader + +from viscy.data.hcs import ChannelMap, SlidingWindowDataset + + +class CTMCv1DataModule(LightningDataModule): + """ + Autoregression data module for the CTMCv1 dataset. + Training and validation datasets are stored in separate HCS OME-Zarr stores. 
+ """ + + def __init__( + self, + train_data_path: str | Path, + val_data_path: str | Path, + train_transforms: list[MapTransform], + val_transforms: list[MapTransform], + batch_size: int = 16, + num_workers: int = 8, + channel_name: str = "DIC", + ) -> None: + super().__init__() + self.train_data_path = Path(train_data_path) + self.val_data_path = Path(val_data_path) + self.train_transforms = train_transforms + self.val_transforms = val_transforms + self.channel_map = ChannelMap(source=channel_name, target=channel_name) + self.batch_size = batch_size + self.num_workers = num_workers + + def setup(self, stage: str) -> None: + if stage != "fit": + raise NotImplementedError("Only fit stage is supported") + train_plate = open_ome_zarr(self.train_data_path, mode="r") + val_plate = open_ome_zarr(self.val_data_path, mode="r") + train_positions = [p for _, p in train_plate.positions()] + val_positions = [p for _, p in val_plate.positions()] + self.train_dataset = SlidingWindowDataset( + train_positions, + channels=self.channel_map, + z_window_size=1, + transform=Compose(self.train_transform), + ) + self.val_dataset = SlidingWindowDataset( + val_positions, + channels=self.channel_map, + z_window_size=1, + transform=Compose(self.val_transform), + ) + + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) + + def val_dataloader(self) -> DataLoader: + return DataLoader( + self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) From 72de113f8c5a678f4383d374da925517d1cace6b Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 23 Feb 2024 22:33:41 -0800 Subject: [PATCH 26/74] update tests --- tests/light/test_engine.py | 5 +---- tests/unet/test_fcmae.py | 14 +++++++------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/light/test_engine.py b/tests/light/test_engine.py index c6013365..9ce182f5 100644 --- a/tests/light/test_engine.py +++ b/tests/light/test_engine.py @@ -3,8 +3,5 @@ def test_fcmae_vsunet() -> None: model = FcmaeUNet( - architecture="fcmae", - model_config=dict(in_channels=3), - train_mask_ratio=0.6, + model_config=dict(in_channels=3, out_channels=1), fit_mask_ratio=0.6 ) - diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index 36fb673e..4ed441b4 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -17,7 +17,7 @@ def test_generate_mask(): w = 64 s = 16 m = 0.75 - mask = generate_mask((2, 3, w, w), stride=s, mask_ratio=m) + mask = generate_mask((2, 3, w, w), stride=s, mask_ratio=m, device="cpu") assert mask.shape == (2, 1, w // s, w // s) assert mask.dtype == torch.bool ratio = mask.sum((2, 3)) / mask.numel() * mask.shape[0] @@ -28,7 +28,7 @@ def test_masked_patchify(): b, c, h, w = 2, 3, 4, 8 x = torch.rand(b, c, h, w) mask_ratio = 0.75 - mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio) + mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio, device=x.device) mask = upsample_mask(mask, x.shape) feat = masked_patchify(x, ~mask) assert feat.shape == (b, int(h * w * (1 - mask_ratio)), c) @@ -42,7 +42,7 @@ def test_unmasked_patchify_roundtrip(): def test_masked_patchify_roundtrip(): x = torch.rand(2, 3, 4, 8) - mask = generate_mask(x.shape, stride=2, mask_ratio=0.5) + mask = generate_mask(x.shape, stride=2, mask_ratio=0.5, device=x.device) mask = upsample_mask(mask, x.shape) y = masked_unpatchify(masked_patchify(x, ~mask), out_shape=x.shape, unmasked=~mask) assert torch.all((y == 0) ^ (x == y)) @@ -51,7 
+51,7 @@ def test_masked_patchify_roundtrip(): def test_masked_convnextv2_block() -> None: x = torch.rand(2, 3, 4, 5) - mask = generate_mask(x.shape, stride=1, mask_ratio=0.5) + mask = generate_mask(x.shape, stride=1, mask_ratio=0.5, device=x.device) block = MaskedConvNeXtV2Block(3, 3 * 2) unmasked_out = block(x) assert len(unmasked_out.unique()) == x.numel() * 2 @@ -65,7 +65,7 @@ def test_masked_convnextv2_block() -> None: def test_masked_convnextv2_stage(): x = torch.rand(2, 3, 16, 16) - mask = generate_mask(x.shape, stride=4, mask_ratio=0.5) + mask = generate_mask(x.shape, stride=4, mask_ratio=0.5, device=x.device) stage = MaskedConvNeXtV2Stage(3, 3, kernel_size=7, stride=2, num_blocks=2) out = stage(x) assert out.shape == (2, 3, 8, 8) @@ -79,7 +79,7 @@ def test_adaptive_projection(): ) assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) - mask = generate_mask((1, 3, 5, 8, 8), stride=4, mask_ratio=0.6) + mask = generate_mask((1, 3, 5, 8, 8), stride=4, mask_ratio=0.6, device="cpu") masked_out = proj(torch.rand(1, 3, 5, 16, 16), mask) assert masked_out.shape == (1, 12, 4, 4) proj = MaskedAdaptiveProjection( @@ -106,7 +106,7 @@ def test_masked_multiscale_encoder(): def test_fcmae(): x = torch.rand(2, 3, 5, 128, 128) - model = FullyConvolutionalMAE(3) + model = FullyConvolutionalMAE(3, 3) y, m = model(x) assert y.shape == x.shape assert m is None From 13d0aa0574665d0da4f17407033fc964da00e602 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 23 Feb 2024 23:47:56 -0800 Subject: [PATCH 27/74] move test_data --- tests/data/__init__.py | 0 tests/{light => data}/test_data.py | 0 viscy/data/ctmc_v1.py | 24 ++++++++++++++++-------- 3 files changed, 16 insertions(+), 8 deletions(-) create mode 100644 tests/data/__init__.py rename tests/{light => data}/test_data.py (100%) diff --git a/tests/data/__init__.py b/tests/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/light/test_data.py b/tests/data/test_data.py similarity index 100% rename from tests/light/test_data.py rename to tests/data/test_data.py diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index 8c42f85d..df1d3223 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -1,11 +1,8 @@ -import logging from pathlib import Path -import numpy as np -from iohub.ngff import ImageArray, Plate, Position, TransformationMeta, open_ome_zarr +from iohub.ngff import open_ome_zarr from lightning.pytorch import LightningDataModule from monai.transforms import Compose, MapTransform -from torch import Tensor from torch.utils.data import DataLoader from viscy.data.hcs import ChannelMap, SlidingWindowDataset @@ -39,8 +36,11 @@ def __init__( def setup(self, stage: str) -> None: if stage != "fit": raise NotImplementedError("Only fit stage is supported") - train_plate = open_ome_zarr(self.train_data_path, mode="r") - val_plate = open_ome_zarr(self.val_data_path, mode="r") + self._setup_fit() + + def _setup_fit(self) -> None: + train_plate = open_ome_zarr(self.train_data_path) + val_plate = open_ome_zarr(self.val_data_path) train_positions = [p for _, p in train_plate.positions()] val_positions = [p for _, p in val_plate.positions()] self.train_dataset = SlidingWindowDataset( @@ -58,10 +58,18 @@ def setup(self, stage: str) -> None: def train_dataloader(self) -> DataLoader: return DataLoader( - self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers + self.train_dataset, + batch_size=self.batch_size, + 
num_workers=self.num_workers,
+            persistent_workers=bool(self.num_workers),
+            shuffle=True,
         )
 
     def val_dataloader(self) -> DataLoader:
         return DataLoader(
-            self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers
+            self.val_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            persistent_workers=bool(self.num_workers),
+            shuffle=False,
         )

From 78aed971aa2bea34e89c0db20bde883ebb98a06e Mon Sep 17 00:00:00 2001
From: Ziwen Liu
Date: Fri, 23 Feb 2024 23:53:15 -0800
Subject: [PATCH 28/74] remove path conversion

---
 viscy/data/ctmc_v1.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py
index df1d3223..0d65a36a 100644
--- a/viscy/data/ctmc_v1.py
+++ b/viscy/data/ctmc_v1.py
@@ -25,8 +25,8 @@ def __init__(
         channel_name: str = "DIC",
     ) -> None:
         super().__init__()
-        self.train_data_path = Path(train_data_path)
-        self.val_data_path = Path(val_data_path)
+        self.train_data_path = train_data_path
+        self.val_data_path = val_data_path
         self.train_transforms = train_transforms
         self.val_transforms = val_transforms

From 74e7db3633aed6d04c0995aee6f1db70abb51045 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki
Date: Mon, 26 Feb 2024 09:31:49 -0800
Subject: [PATCH 29/74] configurable normalizations (#68)

* initial commit adding the normalization.

* adding dataset_statistics to each fov to facilitate the configurable augmentations

* fix indentation

* ruff

* test preprocessing

* remove redundant field

* cleanup

---------

Co-authored-by: Ziwen Liu
---
 examples/configs/fit_example.yml     | 13 +++
 tests/conftest.py                    |  2 +
 tests/data/test_data.py              | 33 ++----
 viscy/data/hcs.py                    | 146 +++++++--------------------
 viscy/data/typing.py                 | 22 ++++
 viscy/preprocessing/preprocessing.md | 16 ++-
 viscy/transforms.py                  | 39 +++++++
 7 files changed, 139 insertions(+), 132 deletions(-)
 create mode 100644 viscy/data/typing.py

diff --git a/examples/configs/fit_example.yml b/examples/configs/fit_example.yml
index 017c57f0..fd17071e 100644
--- a/examples/configs/fit_example.yml
+++ b/examples/configs/fit_example.yml
@@ -37,6 +37,19 @@ data:
     batch_size: 32
     num_workers: 16
     yx_patch_size: [256, 256]
+    normalizations:
+      - class_path: viscy.transforms.NormalizeSampled
+        init_args:
+          keys: [source]
+          level: 'fov_statistics'
+          subtrahend: 'mean'
+          divisor: 'std'
+      - class_path: viscy.transforms.NormalizeSampled
+        init_args:
+          keys: [target_1]
+          level: 'fov_statistics'
+          subtrahend: 'median'
+          divisor: 'iqr'
     augmentations:
       - class_path: viscy.transforms.RandWeightedCropd
         init_args:
diff --git a/tests/conftest.py b/tests/conftest.py
index 9ad6630c..198e51ac 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -36,6 +36,8 @@ def preprocessed_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path:
     norm_meta = {channel: {"dataset_statistics": expected} for channel in channel_names}
     with open_ome_zarr(dataset_path, mode="r+") as dataset:
         dataset.zattrs["normalization"] = norm_meta
+        for _, fov in dataset.positions():
+            fov.zattrs["normalization"] = norm_meta
     return dataset_path
 
 
diff --git a/tests/data/test_data.py b/tests/data/test_data.py
index 153f175f..fb3d8620 100644
--- a/tests/data/test_data.py
+++ b/tests/data/test_data.py
@@ -18,6 +18,16 @@ def test_preprocess(small_hcs_dataset: Path, default_channels: bool):
         channel_names = dataset.channel_names
     trainer = VSTrainer(accelerator="cpu")
     trainer.preprocess(data_path, channel_names=channel_names, num_workers=2)
+    with 
open_ome_zarr(data_path) as dataset: + channel_names = dataset.channel_names + for channel in channel_names: + assert "dataset_statistics" in dataset.zattrs["normalization"][channel] + for _, fov in dataset.positions(): + norm_metadata = fov.zattrs["normalization"] + for channel in channel_names: + assert channel in norm_metadata + assert "dataset_statistics" in norm_metadata[channel] + assert "fov_statistics" in norm_metadata[channel] def test_datamodule_setup_predict(preprocessed_hcs_dataset): @@ -45,26 +55,3 @@ def test_datamodule_setup_predict(preprocessed_hcs_dataset): img.height, img.width, ) - - -def test_datamodule_predict_scales(preprocessed_hcs_dataset): - data_path = preprocessed_hcs_dataset - with open_ome_zarr(data_path) as dataset: - channel_names = dataset.channel_names - - def get_normalized_stack(predict_scale_source): - factor = 1 if predict_scale_source is None else predict_scale_source - dm = HCSDataModule( - data_path=data_path, - source_channel=channel_names[:2], - target_channel=channel_names[2:], - z_window_size=5, - batch_size=2, - num_workers=0, - predict_scale_source=predict_scale_source, - normalize_source=True, - ) - dm.setup(stage="predict") - return dm.predict_dataset[0]["source"] / factor - - assert torch.allclose(get_normalized_stack(None), get_normalized_stack(2)) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 218ea414..bb0be09c 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -5,7 +5,7 @@ import tempfile from glob import glob from pathlib import Path -from typing import Callable, Iterable, Literal, Optional, Sequence, TypedDict, Union +from typing import Callable, Literal, Optional, Sequence, Union import numpy as np import torch @@ -18,7 +18,6 @@ from monai.transforms import ( CenterSpatialCropd, Compose, - InvertibleTransform, MapTransform, MultiSampleTrait, RandAffined, @@ -26,6 +25,8 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset +from viscy.data.typing import ChannelMap, Sample + def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]: """ @@ -55,24 +56,6 @@ def _search_int_in_str(pattern: str, file_name: str) -> str: raise ValueError(f"Cannot find pattern {pattern} in {file_name}.") -class ChannelMap(TypedDict, total=False): - """Source and target channel names.""" - - source: Union[str, Sequence[str]] - # optional - target: Union[str, Sequence[str]] - - -class Sample(TypedDict, total=False): - """Image sample type for mini-batches.""" - - index: tuple[str, int, int] - # optional - source: Union[Tensor, Sequence[Tensor]] - target: Union[Tensor, Sequence[Tensor]] - labels: Union[Tensor, Sequence[Tensor]] - - def _collate_samples(batch: Sequence[Sample]) -> Sample: """Collate samples into a batch sample. @@ -89,38 +72,6 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: return collated -class NormalizeSampled(MapTransform, InvertibleTransform): - """Dictionary transform to only normalize target (fluorescence) channel. 
- - :param Union[str, Iterable[str]] keys: keys to normalize - :param dict[str, dict] norm_meta: Plate normalization metadata - written in preprocessing - """ - - def __init__( - self, keys: Union[str, Iterable[str]], norm_meta: dict[str, dict] - ) -> None: - if set(keys) > set(norm_meta.keys()): - raise KeyError(f"{keys} is not a subset of {norm_meta.keys()}") - super().__init__(keys, allow_missing_keys=False) - self.norm_meta = norm_meta - - def _stat(self, key: str) -> dict: - # FIXME: hard-coded key - return self.norm_meta[key]["dataset_statistics"] - - def __call__(self, data: dict[str, Tensor]) -> dict[str, Tensor]: - d = dict(data) - for key in self.keys: - d[key] = (d[key] - self._stat(key)["median"]) / self._stat(key)["iqr"] - return d - - def inverse(self, data: dict[str, Tensor]) -> dict[str, Tensor]: - d = dict(data) - for key in self.keys: - d[key] = (d[key] * self._stat(key)["iqr"]) + self._stat(key)["median"] - - class SlidingWindowDataset(Dataset): """Torch dataset where each element is a window of (C, Z, Y, X) where C=2 (source and target) and Z is ``z_window_size``. @@ -161,6 +112,7 @@ def _get_windows(self) -> None: w = 0 self.window_keys = [] self.window_arrays = [] + self.window_norm_meta = [] for fov in self.positions: img_arr = fov["0"] ts = img_arr.frames @@ -168,6 +120,7 @@ def _get_windows(self) -> None: w += ts * zs self.window_keys.append(w) self.window_arrays.append(img_arr) + self.window_norm_meta.append(fov.zattrs["normalization"]) self._max_window = w def _find_window(self, index: int) -> tuple[int, int]: @@ -175,7 +128,8 @@ def _find_window(self, index: int) -> tuple[int, int]: window_idx = sorted(self.window_keys + [index + 1]).index(index + 1) w = self.window_keys[window_idx] tz = index - self.window_keys[window_idx - 1] if window_idx > 0 else index - return self.window_arrays[self.window_keys.index(w)], tz + norm_meta = self.window_norm_meta[self.window_keys.index(w)] + return (self.window_arrays[self.window_keys.index(w)], tz, norm_meta) def _read_img_window( self, img: ImageArray, ch_idx: list[str], tz: int @@ -216,7 +170,7 @@ def _stack_channels( ] def __getitem__(self, index: int) -> Sample: - img, tz = self._find_window(index) + img, tz, norm_meta = self._find_window(index) ch_names = self.channels["source"].copy() ch_idx = self.source_ch_idx.copy() if self.target_ch_idx is not None: @@ -229,6 +183,7 @@ def __getitem__(self, index: int) -> Sample: # since adding a reference to a tensor does not copy # maybe write a weight map in preprocessing to use more information? 
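
The `norm_meta` dictionary attached to each sample here follows the structure written during preprocessing ("fov_statistics"/"dataset_statistics" entries holding keys such as "mean" and "std"). A simplified stand-in (synthetic values, not the viscy transform) shows how a downstream dictionary transform can consume it:

```python
import torch

sample = {
    "Phase3D": torch.randn(1, 5, 256, 256) * 3.0 + 7.0,
    "norm_meta": {"Phase3D": {"fov_statistics": {"mean": 7.0, "std": 3.0}}},
}

def normalize_sampled(sample: dict, keys=("Phase3D",), level="fov_statistics") -> dict:
    # read the statistics that the dataset attached to this sample
    for key in keys:
        stats = sample["norm_meta"][key][level]
        sample[key] = (sample[key] - stats["mean"]) / stats["std"]
    return sample

normalized = normalize_sampled(sample)
assert abs(normalized["Phase3D"].mean().item()) < 0.1  # roughly zero-centered
```
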
sample_images["weight"] = sample_images[self.channels["target"][0]] + sample_images["norm_meta"] = norm_meta if self.transform: sample_images = self.transform(sample_images) # if isinstance(sample_images, list): @@ -238,6 +193,7 @@ def __getitem__(self, index: int) -> Sample: sample = { "index": sample_index, "source": self._stack_channels(sample_images, "source"), + "norm_meta": norm_meta, } if self.target_ch_idx is not None: sample["target"] = self._stack_channels(sample_images, "target") @@ -312,18 +268,16 @@ class HCSDataModule(LightningDataModule): defaults to "2.5D" :param tuple[int, int] yx_patch_size: patch size in (Y, X), defaults to (256, 256) + :param Optional[list[MapTransform]] normalizations: MONAI dictionary transforms + applied to selected channels, defaults to None (no normalization) :param Optional[list[MapTransform]] augmentations: MONAI dictionary transforms applied to the training set, defaults to None (no augmentation) :param bool caching: whether to decompress all the images and cache the result, will store in ``/tmp/$SLURM_JOB_ID/`` if available, defaults to False - :param bool normalize_source: whether to normalize the source channel, - defaults to False :param Optional[Path] ground_truth_masks: path to the ground truth masks, used in the test stage to compute segmentation metrics, defaults to None - :param Optional[float] predict_scale_source: scale the source channel intensity, - defaults to None (no scaling) """ def __init__( @@ -337,11 +291,10 @@ def __init__( num_workers: int = 8, architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D", yx_patch_size: tuple[int, int] = (256, 256), + normalizations: Optional[list[MapTransform]] = None, augmentations: Optional[list[MapTransform]] = None, caching: bool = False, - normalize_source: bool = False, ground_truth_masks: Optional[Path] = None, - predict_scale_source: Optional[float] = None, ): super().__init__() self.data_path = Path(data_path) @@ -353,21 +306,11 @@ def __init__( self.z_window_size = z_window_size self.split_ratio = split_ratio self.yx_patch_size = yx_patch_size + self.normalizations = normalizations self.augmentations = augmentations self.caching = caching - self.normalize_source = normalize_source self.ground_truth_masks = ground_truth_masks self.tmp_zarr = None - if predict_scale_source is not None: - if not normalize_source: - raise ValueError( - "Intensity scaling must be applied to normalized source channels." - ) - if predict_scale_source <= 0: - raise ValueError( - f"Intensity scaling {predict_scale_source} should be positive." 
-                )
-        self.predict_scale_source = predict_scale_source
 
     def prepare_data(self):
         if not self.caching:
@@ -419,31 +362,22 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
         else:
             raise NotImplementedError(f"{stage} stage")
 
-    def _setup_eval(self, dataset_settings: dict) -> tuple[Plate, MapTransform]:
-        """Setup stages where the target is available (evaluating performance)."""
-        dataset_settings["channels"]["target"] = self.target_channel
-        data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
-        plate = open_ome_zarr(data_path, mode="r")
-        # disable metadata tracking in MONAI for performance
-        set_track_meta(False)
-        # define training stage transforms
-        norm_keys = self.target_channel.copy()
-        if self.normalize_source:
-            norm_keys += self.source_channel
-        normalize_transform = NormalizeSampled(
-            norm_keys,
-            plate.zattrs["normalization"],
-        )
-        return plate, normalize_transform
-
     def _setup_fit(self, dataset_settings: dict):
         """Set up the training and validation datasets."""
-        plate, normalize_transform = self._setup_eval(dataset_settings)
+        # Setup the transformations
+        # TODO: These have a fixed order for now... (normalization->augmentation->fit_transform)
         fit_transform = self._fit_transform()
         train_transform = Compose(
-            [normalize_transform] + self._train_transform() + fit_transform
+            self.normalizations + self._train_transform() + fit_transform
         )
-        val_transform = Compose([normalize_transform] + fit_transform)
+        val_transform = Compose(self.normalizations + fit_transform)
+
+        dataset_settings["channels"]["target"] = self.target_channel
+        data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
+        plate = open_ome_zarr(data_path, mode="r")
+
+        # disable metadata tracking in MONAI for performance
+        set_track_meta(False)
         # shuffle positions, randomness is handled globally
         positions = [pos for _, pos in plate.positions()]
         shuffled_indices = torch.randperm(len(positions))
@@ -465,26 +399,31 @@ def _setup_fit(self, dataset_settings: dict):
             **train_dataset_settings,
         )
         self.val_dataset = SlidingWindowDataset(
-            positions[num_train_fovs:], transform=val_transform, **dataset_settings
+            positions[num_train_fovs:],
+            transform=val_transform,
+            **dataset_settings,
        )
 
     def _setup_test(self, dataset_settings: dict):
         """Set up the test stage."""
         if self.batch_size != 1:
             logging.warning(f"Ignoring batch size {self.batch_size} in test stage.")
-        plate, normalize_transform = self._setup_eval(dataset_settings)
+
+        dataset_settings["channels"]["target"] = self.target_channel
+        data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
+        plate = open_ome_zarr(data_path, mode="r")
         if self.ground_truth_masks:
             self.test_dataset = MaskTestDataset(
                 [p for _, p in plate.positions()],
-                transform=normalize_transform,
+                transform=self.normalizations,
                 ground_truth_masks=self.ground_truth_masks,
-                **dataset_settings,
+                norm_meta=plate.zattrs["normalization"],
+                **dataset_settings,
             )
         else:
             self.test_dataset = SlidingWindowDataset(
                 [p for _, p in plate.positions()],
-                transform=normalize_transform,
-                **dataset_settings,
+                transform=self.normalizations,
+                norm_meta=plate.zattrs["normalization"],
+                **dataset_settings,
             )
 
     def _setup_predict(self, dataset_settings: dict):
@@ -506,16 +445,9 @@ def _setup_predict(self, dataset_settings: dict):
             positions = [plate[fov_name]]
         elif isinstance(dataset, Plate):
             positions = [p for _, p in dataset.positions()]
-        norm_meta = dataset.zattrs["normalization"].copy()
-        if self.predict_scale_source is not None:
-            for ch in self.source_channel:
-                # FIXME: 
diff --git a/viscy/data/typing.py b/viscy/data/typing.py
new file mode 100644
index 00000000..c6b7c32f
--- /dev/null
+++ b/viscy/data/typing.py
@@ -0,0 +1,22 @@
+from typing import Sequence, TypedDict, Union
+
+from torch import Tensor
+
+
+class Sample(TypedDict, total=False):
+    """Image sample type for mini-batches."""
+
+    index: tuple[str, int, int]
+    # optional
+    source: Union[Tensor, Sequence[Tensor]]
+    target: Union[Tensor, Sequence[Tensor]]
+    labels: Union[Tensor, Sequence[Tensor]]
+    norm_meta: dict[str, dict]
+
+
+class ChannelMap(TypedDict, total=False):
+    """Source and target channel names."""
+
+    source: Union[str, Sequence[str]]
+    # optional
+    target: Union[str, Sequence[str]]
diff --git a/viscy/preprocessing/preprocessing.md b/viscy/preprocessing/preprocessing.md
index 76d508c5..809b456f 100644
--- a/viscy/preprocessing/preprocessing.md
+++ b/viscy/preprocessing/preprocessing.md
@@ -87,11 +87,17 @@ The statistics are added as dictionaries into the .zattrs file. An example of pl
 }
 ```
 
-FOV level statistics added to every position:
+FOV-level statistics are added to every position, along with the dataset-level statistics so both can be read from each FOV:
 
 ```json
 "normalization": {
     "Deconvolved-Nuc": {
+        "dataset_statistics": {
+            "iqr": 149.7620086669922,
+            "mean": 262.2070617675781,
+            "median": 65.5246353149414,
+            "std": 890.0471801757812
+        },
         "fov_statistics": {
             "iqr": 450.4745788574219,
             "mean": 486.3854064941406,
@@ -99,7 +105,13 @@ FOV level statistics added to every position:
             "std": 976.02392578125
         }
     },
-    "Phase3D": {
+    "Phase3D": {
+        "dataset_statistics": {
+            "iqr": 0.0011349652777425945,
+            "mean": -1.9603044165705796e-06,
+            "median": 3.388232289580628e-05,
+            "std": 0.005480962339788675
+        },
         "fov_statistics": {
             "iqr": 0.006403466919437051,
             "mean": 0.0010083537781611085,
diff --git a/viscy/transforms.py b/viscy/transforms.py
index cb3d2622..7ce192af 100644
--- a/viscy/transforms.py
+++ b/viscy/transforms.py
@@ -3,6 +3,7 @@
 from typing import Sequence, Union
 
 from monai.transforms import (
+    MapTransform,
     RandAdjustContrastd,
     RandAffined,
     RandGaussianNoised,
@@ -10,6 +11,9 @@
     RandScaleIntensityd,
     RandWeightedCropd,
 )
+from typing_extensions import Iterable, Literal
+
+from viscy.data.typing import Sample
 
 
 class RandWeightedCropd(RandWeightedCropd):
@@ -118,3 +122,38 @@ def __init__(
             sigma_z=sigma_z,
             **kwargs,
         )
+
+
+class NormalizeSampled(MapTransform):
+    """
+    Normalize the sample with pre-computed statistics
+    :param Union[str, Iterable[str]] keys: keys to normalize
+    :param Literal["fov_statistics", "dataset_statistics"] level: statistics level to use
+    :param str subtrahend: subtrahend for normalization, defaults to "mean"
+    :param str divisor: divisor for normalization, defaults to "std"
+    """
+
+    def __init__(
+        self,
+        keys: Union[str, Iterable[str]],
+        level: Literal["fov_statistics", "dataset_statistics"],
+        subtrahend="mean",
+        divisor="std",
+    ) -> None:
+        super().__init__(keys, allow_missing_keys=False)
+        self.subtrahend = subtrahend
+        self.divisor = divisor
+        self.level = level
+
+    # TODO: need to implement the case where the preprocessing already exists
+    def __call__(self, sample: Sample) -> Sample:
+        for key in self.keys:
+            if key in self.keys:
+                level_meta = sample["norm_meta"][key][self.level]
+                
subtrahend_val = level_meta[self.subtrahend] + divisor_val = level_meta[self.divisor] + 1e-8 # avoid div by zero + sample[key] = (sample[key] - subtrahend_val) / divisor_val + return sample + + def _normalize(): + NotImplementedError("_normalization() not implemented") From 9b3b032100b480f9340b8aa8b124e8116f232820 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 27 Feb 2024 17:33:53 -0800 Subject: [PATCH 30/74] fix ctmc dataloading --- viscy/data/ctmc_v1.py | 6 +++--- viscy/data/hcs.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index 0d65a36a..47844d68 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -29,7 +29,7 @@ def __init__( self.val_data_path = val_data_path self.train_transforms = train_transforms self.val_transforms = val_transforms - self.channel_map = ChannelMap(source=channel_name, target=channel_name) + self.channel_map = ChannelMap(source=[channel_name], target=[channel_name]) self.batch_size = batch_size self.num_workers = num_workers @@ -47,13 +47,13 @@ def _setup_fit(self) -> None: train_positions, channels=self.channel_map, z_window_size=1, - transform=Compose(self.train_transform), + transform=Compose(self.train_transforms), ) self.val_dataset = SlidingWindowDataset( val_positions, channels=self.channel_map, z_window_size=1, - transform=Compose(self.val_transform), + transform=Compose(self.val_transforms), ) def train_dataloader(self) -> DataLoader: diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index bb0be09c..2c7397c0 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -120,7 +120,7 @@ def _get_windows(self) -> None: w += ts * zs self.window_keys.append(w) self.window_arrays.append(img_arr) - self.window_norm_meta.append(fov.zattrs["normalization"]) + self.window_norm_meta.append(fov.zattrs.get("normalization", 0)) self._max_window = w def _find_window(self, index: int) -> tuple[int, int]: From a3569364ac18897858471c78d4a4c6f3381c6d1c Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 27 Feb 2024 17:34:30 -0800 Subject: [PATCH 31/74] add example ctmc v1 loading script --- viscy/scripts/load_ctmc_v1.py | 68 +++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 viscy/scripts/load_ctmc_v1.py diff --git a/viscy/scripts/load_ctmc_v1.py b/viscy/scripts/load_ctmc_v1.py new file mode 100644 index 00000000..e5c19094 --- /dev/null +++ b/viscy/scripts/load_ctmc_v1.py @@ -0,0 +1,68 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +from monai.transforms import ( + CenterSpatialCropd, + NormalizeIntensityd, + RandAffined, + RandScaleIntensityd, +) +from tqdm import tqdm + +from viscy.data.ctmc_v1 import CTMCv1DataModule + +# %% +data_path = Path("") + +normalize_transform = NormalizeIntensityd(keys=["DIC"], channel_wise=True) +crop_transform = CenterSpatialCropd(keys=["DIC"], roi_size=[1, 256, 256]) + +data = CTMCv1DataModule( + train_data_path=data_path / "CTMCV1_test.zarr", + val_data_path=data_path / "CTMCV1_train.zarr", + train_transforms=[ + normalize_transform, + RandAffined( + keys=["DIC"], + rotate_range=[3.14, 0.0, 0.0], + shear_range=[0.0, 0.3, 0.3], + scale_range=[0.0, 0.3, 0.3], + prob=0.8, + ), + RandScaleIntensityd(keys=["DIC"], factors=0.3, prob=0.5), + crop_transform, + ], + val_transforms=[normalize_transform, crop_transform], + batch_size=4, + num_workers=0, + channel_name="DIC", +) + +# %% +data.setup("fit") +dmt = data.train_dataloader() +dmv = data.val_dataloader() + +# %% +for batch in tqdm(dmt): + img = 
batch["source"] + f, ax = plt.subplots(4, 4, figsize=(12, 12)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + +# %% +for batch in tqdm(dmv): + img = batch["source"] + f, ax = plt.subplots(4, 4, figsize=(12, 12)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + + +# %% From bac26bedeb1037bf0eec44fb7f1b65fd3da7b653 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 28 Feb 2024 15:52:41 -0800 Subject: [PATCH 32/74] changing the normalization and augmentations default from None to empty list. --- viscy/data/hcs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 2c7397c0..af9a03a8 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -269,9 +269,9 @@ class HCSDataModule(LightningDataModule): :param tuple[int, int] yx_patch_size: patch size in (Y, X), defaults to (256, 256) :param Optional[list[MapTransform]] normalizations: MONAI dictionary transforms - applied to selected channels, defaults to None (no normalization) + applied to selected channels, defaults to [] (no normalization) :param Optional[list[MapTransform]] augmentations: MONAI dictionary transforms - applied to the training set, defaults to None (no augmentation) + applied to the training set, defaults to [] (no augmentation) :param bool caching: whether to decompress all the images and cache the result, will store in ``/tmp/$SLURM_JOB_ID/`` if available, defaults to False @@ -291,8 +291,8 @@ def __init__( num_workers: int = 8, architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D", yx_patch_size: tuple[int, int] = (256, 256), - normalizations: Optional[list[MapTransform]] = None, - augmentations: Optional[list[MapTransform]] = None, + normalizations: Optional[list[MapTransform]] = [], + augmentations: Optional[list[MapTransform]] = [], caching: bool = False, ground_truth_masks: Optional[Path] = None, ): From 0b598c7e1b9fc0bbbc2aa08307964300f66b7f8a Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:03:58 -0800 Subject: [PATCH 33/74] invert intensity transform --- viscy/transforms.py | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/viscy/transforms.py b/viscy/transforms.py index 7ce192af..88e7f738 100644 --- a/viscy/transforms.py +++ b/viscy/transforms.py @@ -8,9 +8,12 @@ RandAffined, RandGaussianNoised, RandGaussianSmoothd, + RandomizableTransform, RandScaleIntensityd, RandWeightedCropd, ) +from monai.transforms.transform import Randomizable +from numpy.random.mtrand import RandomState as RandomState from typing_extensions import Iterable, Literal from viscy.data.typing import Sample @@ -148,12 +151,34 @@ def __init__( # TODO: need to implement the case where the preprocessing already exists def __call__(self, sample: Sample) -> Sample: for key in self.keys: - if key in self.keys: - level_meta = sample["norm_meta"][key][self.level] - subtrahend_val = level_meta[self.subtrahend] - divisor_val = level_meta[self.divisor] + 1e-8 # avoid div by zero - sample[key] = (sample[key] - subtrahend_val) / divisor_val + level_meta = sample["norm_meta"][key][self.level] + subtrahend_val = level_meta[self.subtrahend] + divisor_val = level_meta[self.divisor] + 1e-8 # avoid div by zero + sample[key] = (sample[key] - subtrahend_val) / divisor_val return sample 
     def _normalize():
         NotImplementedError("_normalization() not implemented")
+
+
+class RandInvertIntensityd(MapTransform, RandomizableTransform):
+    """
+    Randomly invert the intensity of the image.
+    """
+
+    def __init__(self, keys: Union[str, Iterable[str]], prob: float = 0.1) -> None:
+        MapTransform.__init__(self, keys)
+        RandomizableTransform.__init__(self, prob)
+
+    def __call__(self, sample: Sample) -> Sample:
+        self.randomize(None)
+        for key in self.keys:
+            if self._do_transform and key in sample:
+                sample[key] = -sample[key]
+        return sample
+
+    def set_random_state(
+        self, seed: int | None = None, state: RandomState | None = None
+    ) -> Randomizable:
+        super().set_random_state(seed, state)
+        return self
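
A quick sketch of exercising the new transform on a dictionary sample; the channel key and probability are arbitrary placeholders chosen for illustration, not values from this series:

```python
import torch
from monai.transforms import Compose

from viscy.transforms import RandInvertIntensityd

# Toy single-channel sample; "DIC" is a hypothetical key for illustration.
sample = {"DIC": torch.randn(1, 1, 64, 64)}
augment = Compose([RandInvertIntensityd(keys=["DIC"], prob=0.5)])
out = augment(sample)  # the tensor is negated in roughly half of the calls
```
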
From ddb30e9d05ebb7378a0ec1ee29acab6a33b32a14 Mon Sep 17 00:00:00 2001
From: Ziwen Liu
Date: Thu, 29 Feb 2024 11:04:17 -0800
Subject: [PATCH 34/74] concatenated data module

---
 viscy/data/combined.py | 72 +++++++++++++++++++++++++++++++++++++++---
 1 file changed, 67 insertions(+), 5 deletions(-)

diff --git a/viscy/data/combined.py b/viscy/data/combined.py
index 5da700dd..45072909 100644
--- a/viscy/data/combined.py
+++ b/viscy/data/combined.py
@@ -1,9 +1,19 @@
+from enum import Enum
 from typing import Literal, Sequence
 
 from lightning.pytorch import LightningDataModule
 from lightning.pytorch.utilities.combined_loader import CombinedLoader
+from torch import Tensor
+from torch.utils.data import ConcatDataset, DataLoader
 
-_MODES = Literal["min_size", "max_size_cycle", "max_size", "sequential"]
+from viscy.data.hcs import _collate_samples
+
+
+class CombineMode(Enum):
+    MIN_SIZE = "min_size"
+    MAX_SIZE_CYCLE = "max_size_cycle"
+    MAX_SIZE = "max_size"
+    SEQUENTIAL = "sequential"
 
 
 class CombinedDataModule(LightningDataModule):
@@ -20,10 +30,10 @@ class CombinedDataModule(LightningDataModule):
     def __init__(
         self,
         data_modules: Sequence[LightningDataModule],
-        train_mode: _MODES = "max_size_cycle",
-        val_mode: _MODES = "sequential",
-        test_mode: _MODES = "sequential",
-        predict_mode: _MODES = "sequential",
+        train_mode: CombineMode = CombineMode.MAX_SIZE_CYCLE,
+        val_mode: CombineMode = CombineMode.SEQUENTIAL,
+        test_mode: CombineMode = CombineMode.SEQUENTIAL,
+        predict_mode: CombineMode = CombineMode.SEQUENTIAL,
     ):
         super().__init__()
         self.data_modules = data_modules
@@ -60,3 +70,55 @@ def predict_dataloader(self):
             [dm.predict_dataloader() for dm in self.data_modules],
             mode=self.predict_mode,
         )
+
+
+class ConcatDataModule(LightningDataModule):
+    def __init__(self, data_modules: Sequence[LightningDataModule]):
+        super().__init__()
+        self.data_modules = data_modules
+        self.num_workers = data_modules[0].num_workers
+        self.batch_size = data_modules[0].batch_size
+        for dm in data_modules:
+            if dm.num_workers != self.num_workers:
+                raise ValueError("Inconsistent number of workers")
+            if dm.batch_size != self.batch_size:
+                raise ValueError("Inconsistent batch size")
+
+    def prepare_data(self):
+        for dm in self.data_modules:
+            dm.prepare_data()
+
+    def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
+        self.train_patches_per_stack = 0
+        for dm in self.data_modules:
+            dm.setup(stage)
+            if patches := getattr(dm, "train_patches_per_stack", 0):
+                if self.train_patches_per_stack == 0:
+                    self.train_patches_per_stack = patches
+                elif self.train_patches_per_stack != patches:
+                    raise ValueError("Inconsistent patches per stack")
+        if stage != "fit":
+            raise NotImplementedError("Only fit stage is supported")
+        self.train_dataset = ConcatDataset(
+            [dm.train_dataset for dm in self.data_modules]
+        )
+        self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.data_modules])
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size // self.train_patches_per_stack,
+            num_workers=self.num_workers,
+            shuffle=True,
+            persistent_workers=bool(self.num_workers),
+            collate_fn=_collate_samples,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,
+            persistent_workers=bool(self.num_workers),
+        )

From 950475584f15534638c4c83d6e3fcf21314bb1e7 Mon Sep 17 00:00:00 2001
From: Ziwen Liu
Date: Thu, 29 Feb 2024 11:04:37 -0800
Subject: [PATCH 35/74] subsample videos

---
 viscy/data/ctmc_v1.py | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py
index 47844d68..d666fdcb 100644
--- a/viscy/data/ctmc_v1.py
+++ b/viscy/data/ctmc_v1.py
@@ -6,12 +6,33 @@
 from torch.utils.data import DataLoader
 
 from viscy.data.hcs import ChannelMap, SlidingWindowDataset
+from viscy.data.typing import Sample
+
+
+class CTMCv1ValidationDataset(SlidingWindowDataset):
+    subsample_rate: int = 30
+
+    def __len__(self) -> int:
+        # sample every 30th frame in the videos
+        return super().__len__() // self.subsample_rate
+
+    def __getitem__(self, index: int) -> Sample:
+        index = index * self.subsample_rate
+        return super().__getitem__(index)
 
 
 class CTMCv1DataModule(LightningDataModule):
     """
     Autoregression data module for the CTMCv1 dataset.
     Training and validation datasets are stored in separate HCS OME-Zarr stores.
+
+    :param str | Path train_data_path: Path to the training dataset
+    :param str | Path val_data_path: Path to the validation dataset
+    :param list[MapTransform] train_transforms: List of transforms for training
+    :param list[MapTransform] val_transforms: List of transforms for validation
+    :param int batch_size: Batch size, defaults to 16
+    :param int num_workers: Number of workers, defaults to 8
+    :param str channel_name: Name of the DIC channel, defaults to "DIC"
     """
 
     def __init__(
@@ -49,7 +70,7 @@ def _setup_fit(self) -> None:
             z_window_size=1,
             transform=Compose(self.train_transforms),
         )
-        self.val_dataset = SlidingWindowDataset(
+        self.val_dataset = CTMCv1ValidationDataset(
             val_positions,
             channels=self.channel_map,
             z_window_size=1,

From 808e39c02763f21c6cb07b79d2c60c2501220021 Mon Sep 17 00:00:00 2001
From: Ziwen Liu
Date: Thu, 29 Feb 2024 11:04:48 -0800
Subject: [PATCH 36/74] livecell dataset

---
 viscy/data/livecell.py | 98 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 98 insertions(+)
 create mode 100644 viscy/data/livecell.py

diff --git a/viscy/data/livecell.py b/viscy/data/livecell.py
new file mode 100644
index 00000000..5d83f099
--- /dev/null
+++ b/viscy/data/livecell.py
@@ -0,0 +1,98 @@
+import json
+from pathlib import Path
+
+import torch
+from lightning.pytorch import LightningDataModule
+from monai.transforms import Compose, Transform
+from tifffile import imread
+from torch.utils.data import DataLoader, Dataset
+
+from viscy.data.typing import Sample
+
+
+class LiveCellDataset(Dataset):
+    """
+    LiveCell dataset.
+
+    :param list[Path] images: List of paths to single-page, single-channel TIFF files.
+ :param Transform | Compose transform: Transform to apply to the dataset + """ + + def __init__(self, images: list[Path], transform: Transform | Compose) -> None: + self.images = images + self.transform = transform + + def __len__(self) -> int: + return len(self.images) + + def __getitem__(self, idx: int) -> Sample: + image = imread(self.images[idx])[None, None] + image = torch.from_numpy(image).to(torch.float32) + image = self.transform(image) + return {"source": image, "target": image} + + +class LiveCellDataModule(LightningDataModule): + def __init__( + self, + train_val_images: Path, + train_annotations: Path, + val_annotations: Path, + train_transforms: list[Transform], + val_transforms: list[Transform], + batch_size: int = 16, + num_workers: int = 8, + ) -> None: + super().__init__() + self.train_val_images = Path(train_val_images) + if not self.train_val_images.is_dir(): + raise NotADirectoryError(str(train_val_images)) + self.train_annotations = Path(train_annotations) + if not self.train_annotations.is_file(): + raise FileNotFoundError(str(train_annotations)) + self.val_annotations = Path(val_annotations) + if not self.val_annotations.is_file(): + raise FileNotFoundError(str(val_annotations)) + self.train_transforms = Compose(train_transforms) + self.val_transforms = Compose(val_transforms) + self.batch_size = batch_size + self.num_workers = num_workers + + def setup(self, stage: str) -> None: + if stage != "fit": + raise NotImplementedError("Only fit stage is supported") + self._setup_fit() + + def _parse_image_names(self, annotations: Path) -> list[Path]: + with open(annotations) as f: + images = [f["file_name"] for f in json.load(f)["images"]] + return sorted(images) + + def _setup_fit(self) -> None: + train_images = self._parse_image_names(self.train_annotations) + val_images = self._parse_image_names(self.val_annotations) + self.train_dataset = LiveCellDataset( + [self.train_val_images / f for f in train_images], + transform=self.train_transforms, + ) + self.val_dataset = LiveCellDataset( + [self.train_val_images / f for f in val_images], + transform=self.val_transforms, + ) + + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader: + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + ) From 43d641db2e448336be64a6ccd17ecd4a8c218b95 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:05:04 -0800 Subject: [PATCH 37/74] all sample fields are optional --- viscy/data/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/data/typing.py b/viscy/data/typing.py index c6b7c32f..aef7dea7 100644 --- a/viscy/data/typing.py +++ b/viscy/data/typing.py @@ -6,8 +6,8 @@ class Sample(TypedDict, total=False): """Image sample type for mini-batches.""" + # all optional index: tuple[str, int, int] - # optional source: Union[Tensor, Sequence[Tensor]] target: Union[Tensor, Sequence[Tensor]] labels: Union[Tensor, Sequence[Tensor]] From 42f81cfd2093e020e97db1238cc897e623fedcb1 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:05:19 -0800 Subject: [PATCH 38/74] fix multi-dataloader validation --- viscy/light/engine.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py 
index 4d18e9c4..6c284954 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -194,12 +194,12 @@ def training_step(self, batch: Sample, batch_idx: int): ) return loss - def validation_step(self, batch: Sample, batch_idx: int): + def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source = batch["source"] target = batch["target"] pred = self.forward(source) loss = self.loss_function(pred, target) - self.log("loss/validate", loss, sync_dist=True) + self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( self._detach_sample((source, target, pred)) @@ -425,7 +425,15 @@ def training_step(self, batch: Sequence[Sample], batch_idx: int): def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source, target, pred, mask, loss = self.forward_fit(batch) - self.validation_losses.append(loss.detach()) + if dataloader_idx + 1 > len(self.validation_losses): + self.validation_losses.append([]) + self.validation_losses[dataloader_idx].append(loss.detach()) + self.log( + f"loss/val/{dataloader_idx}", + loss, + sync_dist=True, + batch_size=source.shape[0], + ) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( self._detach_sample((source, target * mask.unsqueeze(2), pred)) @@ -433,6 +441,6 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 def on_validation_epoch_end(self): super().on_validation_epoch_end() - self.log( - "loss/validate", torch.stack(self.validation_losses).mean(), sync_dist=True - ) + # average within each dataloader + loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses] + self.log("loss/validate", torch.tensor(loss_means).mean(), sync_dist=True) From 4546fc77b8ee469b1a93f9689404b3fc47cc622d Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:08:26 -0800 Subject: [PATCH 39/74] lint --- viscy/data/combined.py | 1 - 1 file changed, 1 deletion(-) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 45072909..d70b9333 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -3,7 +3,6 @@ from lightning.pytorch import LightningDataModule from lightning.pytorch.utilities.combined_loader import CombinedLoader -from torch import Tensor from torch.utils.data import ConcatDataset, DataLoader from viscy.data.hcs import _collate_samples From 306f3efadce651298647a1a8e60dcdd95eccb6d0 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 29 Feb 2024 13:13:25 -0800 Subject: [PATCH 40/74] fixing preprocessing for varying array shapes (i.e aics dataset) --- viscy/utils/meta_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/viscy/utils/meta_utils.py b/viscy/utils/meta_utils.py index d644dadf..961b6696 100644 --- a/viscy/utils/meta_utils.py +++ b/viscy/utils/meta_utils.py @@ -104,8 +104,9 @@ def generate_normalization_metadata( positions, fov_sample_values = mp_utils.mp_sample_im_pixels( this_channels_args, num_workers ) - dataset_sample_values = np.stack(fov_sample_values, 0) - + dataset_sample_values = np.concatenate( + [arr.flatten() for arr in fov_sample_values] + ) fov_level_statistics = mp_utils.mp_get_val_stats(fov_sample_values, num_workers) dataset_level_statistics = mp_utils.get_val_stats(dataset_sample_values) From 1a0e3ced8711bcdae7c6698c899aff45f7bdc777 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 1 Mar 2024 20:51:50 -0800 Subject: [PATCH 41/74] update 
loading scripts --- viscy/scripts/load_ctmc_v1.py | 38 ++++++++++----- viscy/scripts/load_livecell.py | 85 ++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 11 deletions(-) create mode 100644 viscy/scripts/load_livecell.py diff --git a/viscy/scripts/load_ctmc_v1.py b/viscy/scripts/load_ctmc_v1.py index e5c19094..41cef698 100644 --- a/viscy/scripts/load_ctmc_v1.py +++ b/viscy/scripts/load_ctmc_v1.py @@ -5,7 +5,11 @@ from monai.transforms import ( CenterSpatialCropd, NormalizeIntensityd, + RandAdjustContrastd, RandAffined, + RandFlipd, + RandGaussianNoised, + RandGaussianSmoothd, RandScaleIntensityd, ) from tqdm import tqdm @@ -13,10 +17,11 @@ from viscy.data.ctmc_v1 import CTMCv1DataModule # %% -data_path = Path("") +channel = "DIC" +data_path = Path("/hpc/reference/imaging/ctmc") -normalize_transform = NormalizeIntensityd(keys=["DIC"], channel_wise=True) -crop_transform = CenterSpatialCropd(keys=["DIC"], roi_size=[1, 256, 256]) +normalize_transform = NormalizeIntensityd(keys=[channel], channel_wise=True) +crop_transform = CenterSpatialCropd(keys=[channel], roi_size=[1, 224, 224]) data = CTMCv1DataModule( train_data_path=data_path / "CTMCV1_test.zarr", @@ -24,19 +29,29 @@ train_transforms=[ normalize_transform, RandAffined( - keys=["DIC"], + keys=[channel], rotate_range=[3.14, 0.0, 0.0], - shear_range=[0.0, 0.3, 0.3], - scale_range=[0.0, 0.3, 0.3], + scale_range=[0.0, [-0.6, 0.1], [-0.6, 0.1]], prob=0.8, + padding_mode="zeros", + ), + RandFlipd(keys=[channel], prob=0.5, spatial_axis=(1,2)), + RandAdjustContrastd(keys=[channel], prob=0.5, gamma=(0.8, 1.2)), + RandScaleIntensityd(keys=[channel], factors=0.3, prob=0.5), + RandGaussianNoised(keys=[channel], prob=0.5, mean=0.0, std=0.2), + RandGaussianSmoothd( + keys=[channel], + sigma_x=(0.05, 0.3), + sigma_y=(0.05, 0.3), + sigma_z=(0.05, 0.0), + prob=0.5, ), - RandScaleIntensityd(keys=["DIC"], factors=0.3, prob=0.5), crop_transform, ], val_transforms=[normalize_transform, crop_transform], - batch_size=4, + batch_size=32, num_workers=0, - channel_name="DIC", + channel_name=channel, ) # %% @@ -47,7 +62,8 @@ # %% for batch in tqdm(dmt): img = batch["source"] - f, ax = plt.subplots(4, 4, figsize=(12, 12)) + img[:, :, :, 32:64, 32:64] = 0 + f, ax = plt.subplots(5, 5, figsize=(15, 15)) for sample, a in zip(img, ax.flatten()): a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) a.axis("off") @@ -57,7 +73,7 @@ # %% for batch in tqdm(dmv): img = batch["source"] - f, ax = plt.subplots(4, 4, figsize=(12, 12)) + f, ax = plt.subplots(5, 5, figsize=(15, 15)) for sample, a in zip(img, ax.flatten()): a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) a.axis("off") diff --git a/viscy/scripts/load_livecell.py b/viscy/scripts/load_livecell.py new file mode 100644 index 00000000..cfaf2dfe --- /dev/null +++ b/viscy/scripts/load_livecell.py @@ -0,0 +1,85 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +from monai.transforms import ( + CenterSpatialCrop, + NormalizeIntensity, + RandAdjustContrast, + RandAffine, + RandFlip, + RandGaussianNoise, + RandGaussianSmooth, + RandScaleIntensity, + RandSpatialCrop, +) +from tqdm import tqdm + +from viscy.data.livecell import LiveCellDataModule + +# %% +data_path = Path("/hpc/reference/imaging/livecell") + +normalize_transform = NormalizeIntensity(channel_wise=True) +crop_transform = CenterSpatialCrop(roi_size=[1, 224, 224]) + +data = LiveCellDataModule( + train_val_images=data_path / "images" / "livecell_train_val_images", + train_annotations=data_path 
+ / "annotations" + / "livecell_coco_train_images_only.json", + val_annotations=data_path / "annotations" / "livecell_coco_val_images_only.json", + train_transforms=[ + normalize_transform, + RandSpatialCrop(roi_size=[1, 384, 384]), + RandAffine( + rotate_range=[3.14, 0.0, 0.0], + scale_range=[0.0, [-0.2, 0.8], [-0.2, 0.8]], + prob=0.8, + padding_mode="zeros", + ), + RandFlip(prob=0.5, spatial_axis=(1, 2)), + RandAdjustContrast(prob=0.5, gamma=(0.8, 1.2)), + RandScaleIntensity(factors=0.3, prob=0.5), + RandGaussianNoise(prob=0.5, mean=0.0, std=0.3), + RandGaussianSmooth( + sigma_x=(0.05, 0.3), + sigma_y=(0.05, 0.3), + sigma_z=(0.05, 0.0), + prob=0.5, + ), + crop_transform, + ], + val_transforms=[normalize_transform, crop_transform], + batch_size=16, + num_workers=0, +) + +# %% +data.setup("fit") +dmt = data.train_dataloader() +dmv = data.val_dataloader() + +# %% +for batch in tqdm(dmt): + img = batch["target"] + img[:, :, :, 32:64, 32:64] = 0 + f, ax = plt.subplots(4, 4, figsize=(15, 15)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + +# %% +for batch in tqdm(dmv): + img = batch["source"] + f, ax = plt.subplots(4, 4, figsize=(12, 12)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + + +# %% From d3ec94d2c0142bf073b2019ab9ac6eba4312eddd Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 1 Mar 2024 21:26:12 -0800 Subject: [PATCH 42/74] fix CombineMode --- viscy/data/combined.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index d70b9333..13e64e21 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -36,10 +36,10 @@ def __init__( ): super().__init__() self.data_modules = data_modules - self.train_mode = train_mode - self.val_mode = val_mode - self.test_mode = test_mode - self.predict_mode = predict_mode + self.train_mode = CombineMode(train_mode).value + self.val_mode = CombineMode(val_mode).value + self.test_mode = CombineMode(test_mode).value + self.predict_mode = CombineMode(predict_mode).value def prepare_data(self): for dm in self.data_modules: From 02e6d0b09ce4a8d4c59556df072a31642f8c0dda Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 1 Mar 2024 23:32:41 -0800 Subject: [PATCH 43/74] always use untrainable head for FCMAE --- viscy/unet/networks/fcmae.py | 54 ++++++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 15 deletions(-) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 97771365..2b398117 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -8,6 +8,7 @@ from typing import Sequence import torch +from monai.networks.blocks import UpSample from timm.models.convnext import ( Downsample, DropPath, @@ -18,7 +19,7 @@ ) from torch import BoolTensor, Size, Tensor, nn -from viscy.unet.networks.Unet21D import PixelToVoxelHead, Unet2dDecoder, UnsqueezeHead +from viscy.unet.networks.Unet21D import Unet2dDecoder def _init_weights(module: nn.Module) -> None: @@ -362,6 +363,35 @@ def forward(self, x: Tensor, mask_ratio: float = 0.0) -> list[Tensor]: return features, mask +class PixelToVoxelShuffleHead(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + out_stack_depth: int = 5, + xy_scaling: int = 4, + pool: bool = False, + ) -> None: + super().__init__() + self.out_channels = out_channels 
+ self.out_stack_depth = out_stack_depth + self.upsample = UpSample( + spatial_dims=2, + in_channels=in_channels, + out_channels=out_stack_depth * out_channels, + scale_factor=xy_scaling, + mode="pixelshuffle", + pre_conv=None, + apply_pad_pool=pool, + ) + + def forward(self, x: Tensor) -> Tensor: + x = self.upsample(x) + b, _, h, w = x.shape + x = x.reshape(b, self.out_channels, self.out_stack_depth, h, w) + return x + + class FullyConvolutionalMAE(nn.Module): def __init__( self, @@ -370,7 +400,6 @@ def __init__( encoder_blocks: Sequence[int] = [3, 3, 9, 3], dims: Sequence[int] = [96, 192, 384, 768], encoder_drop_path_rate: float = 0.0, - head_expansion_ratio: int = 4, stem_kernel_size: Sequence[int] = (5, 4, 4), in_stack_depth: int = 5, decoder_conv_blocks: int = 1, @@ -387,9 +416,7 @@ def __init__( ) decoder_channels = list(dims) decoder_channels.reverse() - decoder_channels[-1] = ( - (in_stack_depth + 2) * in_channels * 2**2 * head_expansion_ratio - ) + decoder_channels[-1] = (in_stack_depth + 2) * in_channels * 2**2 self.decoder = Unet2dDecoder( decoder_channels, norm_name="instance", @@ -398,16 +425,13 @@ def __init__( strides=[2] * (len(dims) - 1) + [stem_kernel_size[-1]], upsample_pre_conv=None, ) - if in_stack_depth == 1: - self.head = UnsqueezeHead() - else: - self.head = PixelToVoxelHead( - in_channels=decoder_channels[-1], - out_channels=out_channels, - out_stack_depth=in_stack_depth, - expansion_ratio=head_expansion_ratio, - pool=True, - ) + self.head = PixelToVoxelShuffleHead( + in_channels=decoder_channels[-1], + out_channels=out_channels, + out_stack_depth=in_stack_depth, + xy_scaling=stem_kernel_size[-1], + pool=True, + ) self.out_stack_depth = in_stack_depth self.num_blocks = 6 self.pretraining = pretraining From e18d305dbc21d2bd4632af978ab388795bd37cd0 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Sat, 2 Mar 2024 11:03:52 -0800 Subject: [PATCH 44/74] move log values to GPU before syncing https://github.com/Lightning-AI/pytorch-lightning/issues/18803 --- viscy/light/engine.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 6c284954..3dc92b74 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -413,7 +413,7 @@ def training_step(self, batch: Sequence[Sample], batch_idx: int): loss_step = torch.stack(losses).mean() self.log( "loss/train", - loss_step, + loss_step.to(self.device), on_step=True, on_epoch=True, prog_bar=True, @@ -430,7 +430,7 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 self.validation_losses[dataloader_idx].append(loss.detach()) self.log( f"loss/val/{dataloader_idx}", - loss, + loss.to(self.device), sync_dist=True, batch_size=source.shape[0], ) @@ -443,4 +443,8 @@ def on_validation_epoch_end(self): super().on_validation_epoch_end() # average within each dataloader loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses] - self.log("loss/validate", torch.tensor(loss_means).mean(), sync_dist=True) + self.log( + "loss/validate", + torch.tensor(loss_means).mean().to(self.device), + sync_dist=True, + ) From 01c71cf18148cc335a7e423c17d63666bf3a8bb0 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Sat, 2 Mar 2024 11:04:20 -0800 Subject: [PATCH 45/74] custom head --- tests/unet/test_fcmae.py | 8 ++++++++ viscy/unet/networks/fcmae.py | 6 ++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index 4ed441b4..9f3fc805 100644 --- 
a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -2,6 +2,7 @@ from viscy.unet.networks.fcmae import ( FullyConvolutionalMAE, + PixelToVoxelShuffleHead, MaskedAdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, @@ -104,6 +105,13 @@ def test_masked_multiscale_encoder(): assert afeat.shape[2] == afeat.shape[3] == xy_size // stride +def test_pixel_to_voxel_shuffle_head(): + head = PixelToVoxelShuffleHead(240, 3, out_stack_depth=5, xy_scaling=4) + x = torch.rand(2, 240, 16, 16) + y = head(x) + assert y.shape == (2, 3, 5, 64, 64) + + def test_fcmae(): x = torch.rand(2, 3, 5, 128, 128) model = FullyConvolutionalMAE(3, 3) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 2b398117..821f8421 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -5,6 +5,7 @@ and timm's dense implementation of the encoder in ``timm.models.convnext`` """ +import math from typing import Sequence import torch @@ -416,7 +417,7 @@ def __init__( ) decoder_channels = list(dims) decoder_channels.reverse() - decoder_channels[-1] = (in_stack_depth + 2) * in_channels * 2**2 + decoder_channels[-1] = out_channels * in_stack_depth * stem_kernel_size[-1] ** 2 self.decoder = Unet2dDecoder( decoder_channels, norm_name="instance", @@ -433,7 +434,8 @@ def __init__( pool=True, ) self.out_stack_depth = in_stack_depth - self.num_blocks = 6 + # TODO: replace num_blocks with explicit strides for all models + self.num_blocks = len(dims) * int(math.log2(stem_kernel_size[-1])) self.pretraining = pretraining def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: From dd64b31aab4c38370589cf4423668a64dceab2a8 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 4 Mar 2024 14:29:23 -0800 Subject: [PATCH 46/74] ddp caching fixes --- viscy/data/hcs.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index af9a03a8..205b1e9e 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -199,10 +199,6 @@ def __getitem__(self, index: int) -> Sample: sample["target"] = self._stack_channels(sample_images, "target") return sample - def __del__(self): - """Close the Zarr store when the dataset instance gets GC'ed.""" - self.positions[0].zgroup.store.close() - class MaskTestDataset(SlidingWindowDataset): """Torch dataset where each element is a window of @@ -310,7 +306,13 @@ def __init__( self.augmentations = augmentations self.caching = caching self.ground_truth_masks = ground_truth_masks - self.tmp_zarr = None + self.prepare_data_per_node = True + + @property + def cache_path(self): + return Path( + tempfile.gettempdir(), os.getenv("SLURM_JOB_ID"), self.data_path.name + ) def prepare_data(self): if not self.caching: @@ -322,20 +324,14 @@ def prepare_data(self): console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) logger.addHandler(console_handler) - os.mkdir(self.trainer.logger.log_dir) + os.makedirs(self.trainer.logger.log_dir, exist_ok=True) file_handler = logging.FileHandler( os.path.join(self.trainer.logger.log_dir, "data.log") ) file_handler.setLevel(logging.DEBUG) logger.addHandler(file_handler) - # cache in temporary directory - self.tmp_zarr = os.path.join( - tempfile.gettempdir(), - os.getenv("SLURM_JOB_ID"), - os.path.basename(self.data_path), - ) - logger.info(f"Caching dataset at {self.tmp_zarr}.") - tmp_store = zarr.NestedDirectoryStore(self.tmp_zarr) + logger.info(f"Caching dataset at {self.cache_path}.") + tmp_store = 
zarr.NestedDirectoryStore(self.cache_path)
         with open_ome_zarr(self.data_path, mode="r") as lazy_plate:
             _, skipped, _ = zarr.copy(
                 lazy_plate.zgroup,
@@ -373,7 +369,7 @@ def _setup_fit(self, dataset_settings: dict):
         val_transform = Compose(self.normalizations + fit_transform)
 
         dataset_settings["channels"]["target"] = self.target_channel
-        data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
+        data_path = self.cache_path if self.caching else self.data_path
         plate = open_ome_zarr(data_path, mode="r")
 
         # disable metadata tracking in MONAI for performance
@@ -410,7 +406,7 @@ def _setup_test(self, dataset_settings: dict):
         logging.warning(f"Ignoring batch size {self.batch_size} in test stage.")
 
         dataset_settings["channels"]["target"] = self.target_channel
-        data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
+        data_path = self.cache_path if self.cache_path else self.data_path
         plate = open_ome_zarr(data_path, mode="r")
         if self.ground_truth_masks:
             self.test_dataset = MaskTestDataset(
@@ -476,7 +472,9 @@ def train_dataloader(self):
             num_workers=self.num_workers,
             shuffle=True,
             persistent_workers=bool(self.num_workers),
+            prefetch_factor=4,
             collate_fn=_collate_samples,
+            drop_last=True,
         )
 
     def val_dataloader(self):

From b3ea8d726c6f8c4adb51088ac7b02a3f7fe20b95 Mon Sep 17 00:00:00 2001
From: Ziwen Liu
Date: Mon, 4 Mar 2024 14:29:48 -0800
Subject: [PATCH 47/74] fix caching when using combined loader

---
 viscy/data/combined.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/viscy/data/combined.py b/viscy/data/combined.py
index 13e64e21..7fc071b8 100644
--- a/viscy/data/combined.py
+++ b/viscy/data/combined.py
@@ -40,9 +40,11 @@ def __init__(
         self.val_mode = CombineMode(val_mode).value
         self.test_mode = CombineMode(test_mode).value
         self.predict_mode = CombineMode(predict_mode).value
+        self.prepare_data_per_node = True
 
     def prepare_data(self):
         for dm in self.data_modules:
+            dm.trainer = self.trainer
             dm.prepare_data()

From d3db2bb2268d15bb27df6310c99ff6d80986c63c Mon Sep 17 00:00:00 2001
From: Ziwen Liu
Date: Mon, 4 Mar 2024 15:04:29 -0800
Subject: [PATCH 48/74] compose normalizations for predict and test stages

---
 viscy/data/hcs.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py
index af9a03a8..db1fea21 100644
--- a/viscy/data/hcs.py
+++ b/viscy/data/hcs.py
@@ -412,17 +412,18 @@ def _setup_test(self, dataset_settings: dict):
         dataset_settings["channels"]["target"] = self.target_channel
         data_path = self.tmp_zarr if self.tmp_zarr else self.data_path
         plate = open_ome_zarr(data_path, mode="r")
+        test_transform = Compose(self.normalizations)
         if self.ground_truth_masks:
             self.test_dataset = MaskTestDataset(
                 [p for _, p in plate.positions()],
-                transform=self.normalizations,
+                transform=test_transform,
                 ground_truth_masks=self.ground_truth_masks,
                 norm_meta=plate.zattrs["normalization"],
                 **dataset_settings,
             )
         else:
             self.test_dataset = SlidingWindowDataset(
                 [p for _, p in plate.positions()],
-                transform=self.normalizations,
+                transform=test_transform,
                 norm_meta=plate.zattrs["normalization"],
                 **dataset_settings,
             )
 
@@ -445,9 +446,7 @@ def _setup_predict(self, dataset_settings: dict):
             positions = [plate[fov_name]]
         elif isinstance(dataset, Plate):
             positions = [p for _, p in dataset.positions()]
-
-        predict_transform = self.normalizations
-
+        predict_transform = Compose(self.normalizations)
         self.predict_dataset = SlidingWindowDataset(
             positions=positions,
transform=predict_transform, From a549d4e19d5b1ffc5a81d2ba4c9eee5154cc207d Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 4 Mar 2024 15:06:34 -0800 Subject: [PATCH 49/74] black --- viscy/scripts/load_ctmc_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/scripts/load_ctmc_v1.py b/viscy/scripts/load_ctmc_v1.py index 41cef698..d4326b81 100644 --- a/viscy/scripts/load_ctmc_v1.py +++ b/viscy/scripts/load_ctmc_v1.py @@ -35,7 +35,7 @@ prob=0.8, padding_mode="zeros", ), - RandFlipd(keys=[channel], prob=0.5, spatial_axis=(1,2)), + RandFlipd(keys=[channel], prob=0.5, spatial_axis=(1, 2)), RandAdjustContrastd(keys=[channel], prob=0.5, gamma=(0.8, 1.2)), RandScaleIntensityd(keys=[channel], factors=0.3, prob=0.5), RandGaussianNoised(keys=[channel], prob=0.5, mean=0.0, std=0.2), From a38da8b02eea0c1691bad996f23870f4741bb325 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 5 Mar 2024 18:53:13 -0800 Subject: [PATCH 50/74] fix normalization in example config --- examples/configs/fit_example.yml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/configs/fit_example.yml b/examples/configs/fit_example.yml index fd17071e..32e3fdee 100644 --- a/examples/configs/fit_example.yml +++ b/examples/configs/fit_example.yml @@ -39,17 +39,17 @@ data: yx_patch_size: [256, 256] normalizations: - class_path: viscy.transforms.NormalizeSampled - init_args: - keys: [source] - level: 'fov_statistics', - subtrahend: 'mean' - divisor: 'std' + init_args: + keys: [source] + level: "fov_statistics" + subtrahend: "mean" + divisor: "std" - class_path: viscy.transforms.NormalizeSampled - init_args: - keys: [target_1] - level: 'fov_statistics', - subtrahend: 'median' - divisor: 'iqr' + init_args: + keys: [target_1] + level: "fov_statistics" + subtrahend: "median" + divisor: "iqr" augmentations: - class_path: viscy.transforms.RandWeightedCropd init_args: From af317c413fcb51742fec0cf6f3ecccc211956a64 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 5 Mar 2024 18:53:13 -0800 Subject: [PATCH 51/74] fix normalization in example config --- examples/configs/fit_example.yml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/configs/fit_example.yml b/examples/configs/fit_example.yml index fd17071e..32e3fdee 100644 --- a/examples/configs/fit_example.yml +++ b/examples/configs/fit_example.yml @@ -39,17 +39,17 @@ data: yx_patch_size: [256, 256] normalizations: - class_path: viscy.transforms.NormalizeSampled - init_args: - keys: [source] - level: 'fov_statistics', - subtrahend: 'mean' - divisor: 'std' + init_args: + keys: [source] + level: "fov_statistics" + subtrahend: "mean" + divisor: "std" - class_path: viscy.transforms.NormalizeSampled - init_args: - keys: [target_1] - level: 'fov_statistics', - subtrahend: 'median' - divisor: 'iqr' + init_args: + keys: [target_1] + level: "fov_statistics" + subtrahend: "median" + divisor: "iqr" augmentations: - class_path: viscy.transforms.RandWeightedCropd init_args: From 96aac512c333bca6f391bdd64d99cf6f44758221 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 5 Mar 2024 18:52:52 -0800 Subject: [PATCH 52/74] prefetch more in validation --- viscy/data/hcs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index a9445ffc..ed996248 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -482,6 +482,7 @@ def val_dataloader(self): batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, + prefetch_factor=4, 
persistent_workers=bool(self.num_workers), ) From d9a471dd3f9ad3c3179a15657540a9c3a1375793 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 5 Mar 2024 22:16:23 -0800 Subject: [PATCH 53/74] fix collate when multi-sample transform is not used --- tests/data/test_data.py | 50 ++++++++++++++++++++++++++++++++++++++++- viscy/data/hcs.py | 14 ++++++++---- 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/tests/data/test_data.py b/tests/data/test_data.py index fb3d8620..8eb06352 100644 --- a/tests/data/test_data.py +++ b/tests/data/test_data.py @@ -1,7 +1,7 @@ from pathlib import Path -import torch from iohub import open_ome_zarr +from monai.transforms import RandSpatialCropSamplesd from pytest import mark from viscy.data.hcs import HCSDataModule @@ -30,6 +30,54 @@ def test_preprocess(small_hcs_dataset: Path, default_channels: bool): assert "fov_statistics" in norm_metadata[channel] +@mark.parametrize("multi_sample_augmentation", [True, False]) +def test_datamodule_setup_fit(preprocessed_hcs_dataset, multi_sample_augmentation): + data_path = preprocessed_hcs_dataset + z_window_size = 5 + channel_split = 2 + split_ratio = 0.8 + yx_patch_size = [128, 96] + batch_size = 4 + with open_ome_zarr(data_path) as dataset: + channel_names = dataset.channel_names + if multi_sample_augmentation: + transforms = [ + RandSpatialCropSamplesd( + keys=channel_names, + roi_size=[z_window_size, *yx_patch_size], + num_samples=2, + ) + ] + else: + transforms = [] + dm = HCSDataModule( + data_path=data_path, + source_channel=channel_names[:channel_split], + target_channel=channel_names[channel_split:], + z_window_size=z_window_size, + batch_size=batch_size, + num_workers=0, + augmentations=transforms, + architecture="3D", + split_ratio=split_ratio, + yx_patch_size=yx_patch_size, + ) + dm.setup(stage="fit") + for batch in dm.train_dataloader(): + assert batch["source"].shape == ( + batch_size, + channel_split, + z_window_size, + *yx_patch_size, + ) + assert batch["target"].shape == ( + batch_size, + len(channel_names) - channel_split, + z_window_size, + *yx_patch_size, + ) + + def test_datamodule_setup_predict(preprocessed_hcs_dataset): data_path = preprocessed_hcs_dataset z_window_size = 5 diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index db1fea21..94ad4aac 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -64,11 +64,15 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: as is the case with ``train_patches_per_stack > 1``. 
:return Sample: Batch sample (dictionary of tensors) """ - elemment = batch[0] collated = {} - for key in elemment.keys(): - data: list[list[Tensor]] = [sample[key] for sample in batch] - collated[key] = collate_meta_tensor([im for imgs in data for im in imgs]) + for key in batch[0].keys(): + data = [] + for sample in batch: + if isinstance(sample[key], list): + data.extend(sample[key]) + else: + data.append(sample[key]) + collated[key] = collate_meta_tensor(data) return collated @@ -475,6 +479,7 @@ def train_dataloader(self): num_workers=self.num_workers, shuffle=True, persistent_workers=bool(self.num_workers), + prefetch_factor=4 if self.num_workers else None, collate_fn=_collate_samples, ) @@ -484,6 +489,7 @@ def val_dataloader(self): batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, + prefetch_factor=4 if self.num_workers else None, persistent_workers=bool(self.num_workers), ) From 669ee83d816e1f219587a2f1c1997826b3ff8b99 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 4 Mar 2024 14:29:23 -0800 Subject: [PATCH 54/74] ddp caching fixes --- viscy/data/hcs.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 94ad4aac..b7652721 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -203,10 +203,6 @@ def __getitem__(self, index: int) -> Sample: sample["target"] = self._stack_channels(sample_images, "target") return sample - def __del__(self): - """Close the Zarr store when the dataset instance gets GC'ed.""" - self.positions[0].zgroup.store.close() - class MaskTestDataset(SlidingWindowDataset): """Torch dataset where each element is a window of @@ -314,7 +310,13 @@ def __init__( self.augmentations = augmentations self.caching = caching self.ground_truth_masks = ground_truth_masks - self.tmp_zarr = None + self.prepare_data_per_node = True + + @property + def cache_path(self): + return Path( + tempfile.gettempdir(), os.getenv("SLURM_JOB_ID"), self.data_path.name + ) def prepare_data(self): if not self.caching: @@ -326,20 +328,14 @@ def prepare_data(self): console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) logger.addHandler(console_handler) - os.mkdir(self.trainer.logger.log_dir) + os.makedirs(self.trainer.logger.log_dir, exist_ok=True) file_handler = logging.FileHandler( os.path.join(self.trainer.logger.log_dir, "data.log") ) file_handler.setLevel(logging.DEBUG) logger.addHandler(file_handler) - # cache in temporary directory - self.tmp_zarr = os.path.join( - tempfile.gettempdir(), - os.getenv("SLURM_JOB_ID"), - os.path.basename(self.data_path), - ) - logger.info(f"Caching dataset at {self.tmp_zarr}.") - tmp_store = zarr.NestedDirectoryStore(self.tmp_zarr) + logger.info(f"Caching dataset at {self.cache_path}.") + tmp_store = zarr.NestedDirectoryStore(self.cache_path) with open_ome_zarr(self.data_path, mode="r") as lazy_plate: _, skipped, _ = zarr.copy( lazy_plate.zgroup, @@ -377,7 +373,7 @@ def _setup_fit(self, dataset_settings: dict): val_transform = Compose(self.normalizations + fit_transform) dataset_settings["channels"]["target"] = self.target_channel - data_path = self.tmp_zarr if self.tmp_zarr else self.data_path + data_path = self.cache_path if self.caching else self.data_path plate = open_ome_zarr(data_path, mode="r") # disable metadata tracking in MONAI for performance @@ -414,7 +410,7 @@ def _setup_test(self, dataset_settings: dict): logging.warning(f"Ignoring batch size {self.batch_size} in test stage.") 
dataset_settings["channels"]["target"] = self.target_channel - data_path = self.tmp_zarr if self.tmp_zarr else self.data_path + data_path = self.cache_path if self.cache_path else self.data_path plate = open_ome_zarr(data_path, mode="r") test_transform = Compose(self.normalizations) if self.ground_truth_masks: @@ -481,6 +477,7 @@ def train_dataloader(self): persistent_workers=bool(self.num_workers), prefetch_factor=4 if self.num_workers else None, collate_fn=_collate_samples, + drop_last=True, ) def val_dataloader(self): From b2e23b88e937646418ca82e06b1b62cc9266ab49 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 4 Mar 2024 14:29:48 -0800 Subject: [PATCH 55/74] fix caching when using combined loader --- viscy/data/combined.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 13e64e21..7fc071b8 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -40,9 +40,11 @@ def __init__( self.val_mode = CombineMode(val_mode).value self.test_mode = CombineMode(test_mode).value self.predict_mode = CombineMode(predict_mode).value + self.prepare_data_per_node = True def prepare_data(self): for dm in self.data_modules: + dm.trainer = self.trainer dm.prepare_data() def setup(self, stage: Literal["fit", "validate", "test", "predict"]): From 8132b6879006833281fc473d7c678c9508d29e8e Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 6 Mar 2024 13:37:26 -0800 Subject: [PATCH 56/74] typing fixes --- viscy/data/hcs.py | 63 +++++++++++++++++++++++------------------ viscy/data/typing.py | 67 +++++++++++++++++++++++++++++++++++--------- 2 files changed, 90 insertions(+), 40 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index b7652721..ed4d0373 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -25,7 +25,7 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset -from viscy.data.typing import ChannelMap, Sample +from viscy.data.typing import ChannelMap, HCSStackIndex, NormMeta, Sample def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]: @@ -64,11 +64,11 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: as is the case with ``train_patches_per_stack > 1``. :return Sample: Batch sample (dictionary of tensors) """ - collated = {} + collated: Sample = {} for key in batch[0].keys(): data = [] for sample in batch: - if isinstance(sample[key], list): + if isinstance(sample[key], Sequence): data.extend(sample[key]) else: data.append(sample[key]) @@ -84,7 +84,7 @@ class SlidingWindowDataset(Dataset): :param ChannelMap channels: source and target channel names, e.g. 
``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param Callable[[dict[str, Tensor]], dict[str, Tensor]] transform: + :param Callable[[dict[str, Tensor]], dict[str, Tensor]] | None transform: a callable that transforms data, defaults to None """ @@ -93,7 +93,7 @@ def __init__( positions: list[Position], channels: ChannelMap, z_window_size: int, - transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] = None, + transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, ) -> None: super().__init__() self.positions = positions @@ -116,18 +116,18 @@ def _get_windows(self) -> None: w = 0 self.window_keys = [] self.window_arrays = [] - self.window_norm_meta = [] + self.window_norm_meta: list[NormMeta | None] = [] for fov in self.positions: - img_arr = fov["0"] + img_arr: ImageArray = fov["0"] ts = img_arr.frames zs = img_arr.slices - self.z_window_size + 1 w += ts * zs self.window_keys.append(w) self.window_arrays.append(img_arr) - self.window_norm_meta.append(fov.zattrs.get("normalization", 0)) + self.window_norm_meta.append(fov.zattrs.get("normalization", None)) self._max_window = w - def _find_window(self, index: int) -> tuple[int, int]: + def _find_window(self, index: int) -> tuple[ImageArray, int, NormMeta | None]: """Look up window given index.""" window_idx = sorted(self.window_keys + [index + 1]).index(index + 1) w = self.window_keys[window_idx] @@ -136,16 +136,16 @@ def _find_window(self, index: int) -> tuple[int, int]: return (self.window_arrays[self.window_keys.index(w)], tz, norm_meta) def _read_img_window( - self, img: ImageArray, ch_idx: list[str], tz: int - ) -> tuple[tuple[Tensor], tuple[str, int, int]]: + self, img: ImageArray, ch_idx: list[int], tz: int + ) -> tuple[list[Tensor], HCSStackIndex]: """Read image window as tensor. :param ImageArray img: NGFF image array - :param list[int] channels: list of channel indices to read, + :param list[int] ch_idx: list of channel indices to read, output channel ordering will reflect the sequence :param int tz: window index within the FOV, counted Z-first - :return tuple[Tensor], tuple[str, int, int]: - tuple of (C=1, Z, Y, X) image tensors, + :return list[Tensor], HCSStackIndex: + list of (C=1, Z, Y, X) image tensors, tuple of image name, time index, and Z index """ zs = img.shape[-3] - self.z_window_size + 1 @@ -162,8 +162,8 @@ def __len__(self) -> int: return self._max_window def _stack_channels( - self, sample_images: list[dict[str, Tensor]], key: str - ) -> Tensor: + self, sample_images: list[dict[str, Tensor]] | dict[str, Tensor], key: str + ) -> Tensor | list[Tensor]: """Stack single-channel images into a multi-channel tensor.""" if not isinstance(sample_images, list): return torch.stack([sample_images[ch][0] for ch in self.channels[key]]) @@ -187,7 +187,8 @@ def __getitem__(self, index: int) -> Sample: # since adding a reference to a tensor does not copy # maybe write a weight map in preprocessing to use more information? 
sample_images["weight"] = sample_images[self.channels["target"][0]] - sample_images["norm_meta"] = norm_meta + if norm_meta is not None: + sample_images["norm_meta"] = norm_meta if self.transform: sample_images = self.transform(sample_images) # if isinstance(sample_images, list): @@ -224,7 +225,7 @@ def __init__( positions: list[Position], channels: ChannelMap, z_window_size: int, - transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] = None, + transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, ground_truth_masks: str = None, ) -> None: super().__init__(positions, channels, z_window_size, transform) @@ -268,9 +269,9 @@ class HCSDataModule(LightningDataModule): defaults to "2.5D" :param tuple[int, int] yx_patch_size: patch size in (Y, X), defaults to (256, 256) - :param Optional[list[MapTransform]] normalizations: MONAI dictionary transforms + :param list[MapTransform] normalizations: MONAI dictionary transforms applied to selected channels, defaults to [] (no normalization) - :param Optional[list[MapTransform]] augmentations: MONAI dictionary transforms + :param list[MapTransform] augmentations: MONAI dictionary transforms applied to the training set, defaults to [] (no augmentation) :param bool caching: whether to decompress all the images and cache the result, will store in ``/tmp/$SLURM_JOB_ID/`` if available, @@ -291,8 +292,8 @@ def __init__( num_workers: int = 8, architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D", yx_patch_size: tuple[int, int] = (256, 256), - normalizations: Optional[list[MapTransform]] = [], - augmentations: Optional[list[MapTransform]] = [], + normalizations: list[MapTransform] = [], + augmentations: list[MapTransform] = [], caching: bool = False, ground_truth_masks: Optional[Path] = None, ): @@ -315,9 +316,20 @@ def __init__( @property def cache_path(self): return Path( - tempfile.gettempdir(), os.getenv("SLURM_JOB_ID"), self.data_path.name + tempfile.gettempdir(), + os.getenv("SLURM_JOB_ID", "viscy_cache"), + self.data_path.name, ) + def _data_log_path(self) -> Path: + log_dir = Path.cwd() + if self.trainer: + if self.trainer.logger: + if self.trainer.logger.log_dir: + log_dir = Path(self.trainer.logger.log_dir) + log_dir.mkdir(parents=True, exist_ok=True) + return log_dir / "data.log" + def prepare_data(self): if not self.caching: return @@ -328,10 +340,7 @@ def prepare_data(self): console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) logger.addHandler(console_handler) - os.makedirs(self.trainer.logger.log_dir, exist_ok=True) - file_handler = logging.FileHandler( - os.path.join(self.trainer.logger.log_dir, "data.log") - ) + file_handler = logging.FileHandler(self._data_log_path()) file_handler.setLevel(logging.DEBUG) logger.addHandler(file_handler) logger.info(f"Caching dataset at {self.cache_path}.") diff --git a/viscy/data/typing.py b/viscy/data/typing.py index aef7dea7..1eabba75 100644 --- a/viscy/data/typing.py +++ b/viscy/data/typing.py @@ -1,22 +1,63 @@ -from typing import Sequence, TypedDict, Union +from __future__ import annotations -from torch import Tensor +from typing import TYPE_CHECKING, NamedTuple, Sequence, TypedDict, TypeVar + +if TYPE_CHECKING: + from torch import Tensor + +T = TypeVar("T") +OneOrSeq = T | Sequence[T] + + +class LevelNormStats(TypedDict): + mean: float + std: float + median: float + iqr: float + + +class ChannelNormStats(TypedDict): + dataset_statistics: LevelNormStats + fov_statistics: LevelNormStats + + +NormMeta = dict[str, ChannelNormStats] + 
+
+class HCSStackIndex(NamedTuple):
+    """HCS stack index."""
+
+    # name of the image array, e.g. "A/1/0/0"
+    image: str
+    time: int
+    z: int
 
 
 class Sample(TypedDict, total=False):
-    """Image sample type for mini-batches."""
+    """
+    Image sample type for mini-batches.
+    All fields are optional.
+    """
+
+    index: HCSStackIndex
+    # Image data
+    source: OneOrSeq[Tensor]
+    target: OneOrSeq[Tensor]
+    weight: OneOrSeq[Tensor]
+    # Instance segmentation masks
+    labels: OneOrSeq[Tensor]
+    # None: not available
+    norm_meta: NormMeta | None
+
+
+class _ChannelMap(TypedDict):
+    """Source channel names."""
 
-    # all optional
-    index: tuple[str, int, int]
-    source: Union[Tensor, Sequence[Tensor]]
-    target: Union[Tensor, Sequence[Tensor]]
-    labels: Union[Tensor, Sequence[Tensor]]
-    norm_meta: dict[str, dict]
+    source: OneOrSeq[str]
 
 
-class ChannelMap(TypedDict, total=False):
+class ChannelMap(_ChannelMap, total=False):
     """Source and target channel names."""
 
-    source: Union[str, Sequence[str]]
-    # optional
-    target: Union[str, Sequence[str]]
+    # TODO: use typing.NotRequired when upgrading to Python 3.11
+    target: OneOrSeq[str]

From 4c7a484852dca102db80cc8003193ca1af731e15 Mon Sep 17 00:00:00 2001
From: Ziwen Liu
Date: Wed, 6 Mar 2024 13:53:21 -0800
Subject: [PATCH 57/74] fix test dataset

---
 viscy/data/hcs.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py
index ed4d0373..a791ba33 100644
--- a/viscy/data/hcs.py
+++ b/viscy/data/hcs.py
@@ -427,13 +427,13 @@ def _setup_test(self, dataset_settings: dict):
                 [p for _, p in plate.positions()],
                 transform=test_transform,
                 ground_truth_masks=self.ground_truth_masks,
-                norm_meta=plate.zattrs["normalization"] ** dataset_settings,
+                **dataset_settings,
             )
         else:
             self.test_dataset = SlidingWindowDataset(
                 [p for _, p in plate.positions()],
                 transform=test_transform,
-                norm_meta=plate.zattrs["normalization"] ** dataset_settings,
+                **dataset_settings,
             )
 
     def _setup_predict(self, dataset_settings: dict):

From 7cfe40309c220d433894c2b270dc2064eea73916 Mon Sep 17 00:00:00 2001
From: Ziwen Liu
Date: Wed, 13 Mar 2024 16:22:25 -0700
Subject: [PATCH 58/74] fix invert transform

---
 viscy/transforms.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/viscy/transforms.py b/viscy/transforms.py
index 88e7f738..3775d154 100644
--- a/viscy/transforms.py
+++ b/viscy/transforms.py
@@ -166,8 +166,13 @@ class RandInvertIntensityd(MapTransform, RandomizableTransform):
     Randomly invert the intensity of the image.
""" - def __init__(self, keys: Union[str, Iterable[str]], prob: float = 0.1) -> None: - MapTransform.__init__(self, keys) + def __init__( + self, + keys: Union[str, Iterable[str]], + prob: float = 0.1, + allow_missing_keys: bool = False, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys=allow_missing_keys) RandomizableTransform.__init__(self, prob) def __call__(self, sample: Sample) -> Sample: From 0b22f1aae7aca9e793401ec66d4c3e10ae6382f0 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 13 Mar 2024 16:23:12 -0700 Subject: [PATCH 59/74] add ddp prepare flag for combined data module --- viscy/data/combined.py | 1 + 1 file changed, 1 insertion(+) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 7fc071b8..7f610d2f 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -84,6 +84,7 @@ def __init__(self, data_modules: Sequence[LightningDataModule]): raise ValueError("Inconsistent number of workers") if dm.batch_size != self.batch_size: raise ValueError("Inconsistent batch size") + self.prepare_data_per_node = True def prepare_data(self): for dm in self.data_modules: From ed01065058c400654dfe65d9018a0908163a6d15 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 13 Mar 2024 16:23:43 -0700 Subject: [PATCH 60/74] remove redundant operations --- viscy/data/hcs.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index a791ba33..f4d22521 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -191,14 +191,11 @@ def __getitem__(self, index: int) -> Sample: sample_images["norm_meta"] = norm_meta if self.transform: sample_images = self.transform(sample_images) - # if isinstance(sample_images, list): - # sample_images = sample_images[0] if "weight" in sample_images: del sample_images["weight"] sample = { "index": sample_index, "source": self._stack_channels(sample_images, "source"), - "norm_meta": norm_meta, } if self.target_ch_idx is not None: sample["target"] = self._stack_channels(sample_images, "target") From c12fbf7e2af8ee4d96a46b6f8f559a1ce5b5155e Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 13 Mar 2024 17:11:00 -0700 Subject: [PATCH 61/74] filter empty detections --- viscy/evaluation/evaluation_metrics.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/viscy/evaluation/evaluation_metrics.py b/viscy/evaluation/evaluation_metrics.py index fb83c06b..921b0e4e 100644 --- a/viscy/evaluation/evaluation_metrics.py +++ b/viscy/evaluation/evaluation_metrics.py @@ -126,13 +126,15 @@ def labels_to_masks(labels: torch.ShortTensor) -> torch.BoolTensor: """ if labels.ndim != 2: raise ValueError(f"Labels must be 2D, got shape {labels.shape}.") + segments = torch.unique(labels) + n_instances = segments.numel() - 1 masks = torch.zeros( - (labels.max(), *labels.shape), dtype=torch.bool, device=labels.device + (n_instances, *labels.shape), dtype=torch.bool, device=labels.device ) # TODO: optimize this? - for segment in range(labels.max()): + for s, segment in enumerate(segments): # start from label value 1, i.e. 
skip background label
-        masks[segment] = labels == (segment + 1)
+        masks[s - 1] = labels == segment
     return masks
 

From f226801fd7edded00ff186d3f36112135821a079 Mon Sep 17 00:00:00 2001
From: Ziwen Liu
Date: Thu, 14 Mar 2024 14:05:51 -0700
Subject: [PATCH 62/74] pass trainer to underlying data modules in concatenated

---
 viscy/data/combined.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/viscy/data/combined.py b/viscy/data/combined.py
index 7f610d2f..1d29b53e 100644
--- a/viscy/data/combined.py
+++ b/viscy/data/combined.py
@@ -88,6 +88,7 @@ def __init__(self, data_modules: Sequence[LightningDataModule]):
 
     def prepare_data(self):
         for dm in self.data_modules:
+            dm.trainer = self.trainer
             dm.prepare_data()
 
     def setup(self, stage: Literal["fit", "validate", "test", "predict"]):

From 073acf4f71fcd342b1dd85065b1e030765d408e0 Mon Sep 17 00:00:00 2001
From: Ziwen Liu
Date: Thu, 14 Mar 2024 14:06:30 -0700
Subject: [PATCH 63/74] hack: add test dataloader for LiveCell dataset

---
 viscy/data/livecell.py | 92 +++++++++++++++++++++++++++++++++++++-----
 1 file changed, 81 insertions(+), 11 deletions(-)

diff --git a/viscy/data/livecell.py b/viscy/data/livecell.py
index 5d83f099..461b3195 100644
--- a/viscy/data/livecell.py
+++ b/viscy/data/livecell.py
@@ -3,7 +3,8 @@
 
 import torch
 from lightning.pytorch import LightningDataModule
-from monai.transforms import Compose, Transform
+from monai.transforms import Compose, MapTransform
+from pycocotools.coco import COCO
 from tifffile import imread
 from torch.utils.data import DataLoader, Dataset
 
@@ -15,10 +16,10 @@ class LiveCellDataset(Dataset):
     LiveCell dataset.
 
     :param list[Path] images: List of paths to single-page, single-channel TIFF files.
-    :param Transform | Compose transform: Transform to apply to the dataset
+    :param MapTransform | Compose transform: Transform to apply to the dataset
     """
 
-    def __init__(self, images: list[Path], transform: Transform | Compose) -> None:
+    def __init__(self, images: list[Path], transform: MapTransform | Compose) -> None:
         self.images = images
         self.transform = transform
 
@@ -32,14 +33,62 @@ def __getitem__(self, idx: int) -> Sample:
         return {"source": image, "target": image}
 
 
+class LiveCellTestDataset(Dataset):
+    """
+    LiveCell test dataset.
+
+    :param Path image_dir: Directory containing single-page, single-channel TIFF files.
+ :param MapTransform | Compose transform: Transform to apply to the dataset + """ + + def __init__( + self, + image_dir: Path, + transform: MapTransform | Compose, + annotations: Path, + load_target: bool = False, + load_labels: bool = False, + ) -> None: + self.image_dir = image_dir + self.transform = transform + self.coco = COCO(str(annotations)) + self.image_ids = list(self.coco.imgs.keys()) + self.load_target = load_target + self.load_labels = load_labels + + def __len__(self) -> int: + return len(self.image_ids) + + def __getitem__(self, idx: int) -> Sample: + image_id = self.image_ids[idx] + image_path = self.image_dir / self.coco.imgs[image_id]["file_name"] + image = imread(image_path)[None, None] + image = torch.from_numpy(image).to(torch.float32) + sample = Sample(source=image) + if self.load_target: + sample["target"] = image + if self.load_labels: + labels = torch.zeros_like(image) + anns = self.coco.loadAnns(self.coco.getAnnIds(image_id)) or [] + for i, ann in enumerate(anns): + mask = torch.from_numpy(self.coco.annToMask(ann)) + labels[0, 0] += mask * (i + 1) + sample["labels"] = labels + self.transform(sample) + return sample + + class LiveCellDataModule(LightningDataModule): def __init__( self, - train_val_images: Path, - train_annotations: Path, - val_annotations: Path, - train_transforms: list[Transform], - val_transforms: list[Transform], + train_val_images: Path | None = None, + test_images: Path | None = None, + train_annotations: Path | None = None, + val_annotations: Path | None = None, + test_annotations: Path | None = None, + train_transforms: list[MapTransform] = [], + val_transforms: list[MapTransform] = [], + test_transforms: list[MapTransform] = [], batch_size: int = 16, num_workers: int = 8, ) -> None: @@ -47,21 +96,29 @@ def __init__( self.train_val_images = Path(train_val_images) if not self.train_val_images.is_dir(): raise NotADirectoryError(str(train_val_images)) + self.test_images = Path(test_images) + if not self.test_images.is_dir(): + raise NotADirectoryError(str(test_images)) self.train_annotations = Path(train_annotations) if not self.train_annotations.is_file(): raise FileNotFoundError(str(train_annotations)) self.val_annotations = Path(val_annotations) if not self.val_annotations.is_file(): raise FileNotFoundError(str(val_annotations)) + self.test_annotations = Path(test_annotations) + if not self.test_annotations.is_file(): + raise FileNotFoundError(str(test_annotations)) self.train_transforms = Compose(train_transforms) self.val_transforms = Compose(val_transforms) + self.test_transforms = Compose(test_transforms) self.batch_size = batch_size self.num_workers = num_workers def setup(self, stage: str) -> None: - if stage != "fit": - raise NotImplementedError("Only fit stage is supported") - self._setup_fit() + if stage == "fit": + self._setup_fit() + elif stage == "test": + self._setup_test() def _parse_image_names(self, annotations: Path) -> list[Path]: with open(annotations) as f: @@ -80,6 +137,14 @@ def _setup_fit(self) -> None: transform=self.val_transforms, ) + def _setup_test(self) -> None: + self.test_dataset = LiveCellTestDataset( + self.test_images, + transform=self.test_transforms, + annotations=self.test_annotations, + load_labels=True, + ) + def train_dataloader(self) -> DataLoader: return DataLoader( self.train_dataset, @@ -96,3 +161,8 @@ def val_dataloader(self) -> DataLoader: num_workers=self.num_workers, persistent_workers=bool(self.num_workers), ) + + def test_dataloader(self) -> DataLoader: + return DataLoader( + 
self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers
+        )

From 2771fdbb9dc5bb980bc47d56ba870d308081d715 Mon Sep 17 00:00:00 2001
From: Ziwen Liu
Date: Thu, 28 Mar 2024 09:08:29 -0700
Subject: [PATCH 64/74] test datasets for livecell and ctmc

---
 viscy/data/ctmc_v1.py  |  3 +--
 viscy/data/livecell.py | 21 +++++++++++++------
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py
index d666fdcb..727f241f 100644
--- a/viscy/data/ctmc_v1.py
+++ b/viscy/data/ctmc_v1.py
@@ -10,9 +10,8 @@
 
 
 class CTMCv1ValidationDataset(SlidingWindowDataset):
-    subsample_rate: int = 30
-
-    def __len__(self) -> int:
+    def __len__(self, subsample_rate: int = 30) -> int:
         # sample every 30th frame in the videos
         return super().__len__() // subsample_rate
 
diff --git a/viscy/data/livecell.py b/viscy/data/livecell.py
index 461b3195..bb8bb56c 100644
--- a/viscy/data/livecell.py
+++ b/viscy/data/livecell.py
@@ -7,6 +7,7 @@
 from pycocotools.coco import COCO
 from tifffile import imread
 from torch.utils.data import DataLoader, Dataset
+from torchvision.ops import box_convert
 
 from viscy.data.typing import Sample
 
@@ -61,19 +62,27 @@ def __len__(self) -> int:
         return len(self.image_ids)
 
     def __getitem__(self, idx: int) -> Sample:
         image_id = self.image_ids[idx]
-        image_path = self.image_dir / self.coco.imgs[image_id]["file_name"]
+        file_name = self.coco.imgs[image_id]["file_name"]
+        image_path = self.image_dir / file_name
         image = imread(image_path)[None, None]
         image = torch.from_numpy(image).to(torch.float32)
         sample = Sample(source=image)
         if self.load_target:
             sample["target"] = image
         if self.load_labels:
-            labels = torch.zeros_like(image)
             anns = self.coco.loadAnns(self.coco.getAnnIds(image_id)) or []
-            for i, ann in enumerate(anns):
-                mask = torch.from_numpy(self.coco.annToMask(ann))
-                labels[0, 0] += mask * (i + 1)
-            sample["labels"] = labels
+            boxes = [torch.tensor(ann["bbox"]).to(torch.float32) for ann in anns]
+            masks = [
+                torch.from_numpy(self.coco.annToMask(ann)).to(torch.bool)
+                for ann in anns
+            ]
+            dets = {
+                "boxes": box_convert(torch.stack(boxes), in_fmt="xywh", out_fmt="xyxy"),
+                "labels": torch.zeros(len(anns)).to(torch.uint8),
+                "masks": torch.stack(masks),
+            }
+            sample["detections"] = dets
+            sample["file_name"] = file_name
         self.transform(sample)
         return sample

From 178df34ede6f4db68ff33ebb71084b9d82f3718e Mon Sep 17 00:00:00 2001
From: Ziwen Liu
Date: Fri, 12 Apr 2024 15:15:40 -0700
Subject: [PATCH 65/74] fix merge error

---
 viscy/data/combined.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/viscy/data/combined.py b/viscy/data/combined.py
index 1d29b53e..31ea9f6c 100644
--- a/viscy/data/combined.py
+++ b/viscy/data/combined.py
@@ -74,6 +74,14 @@
 
 
 class ConcatDataModule(LightningDataModule):
+    """
+    Concatenate multiple data modules.
+    The concatenated data module will have the same
+    batch size and number of workers as the first data module.
+    Each element will be sampled uniformly regardless of its original data module.
+
+ :param Sequence[LightningDataModule] data_modules: data modules to concatenate + """ + def __init__(self, data_modules: Sequence[LightningDataModule]): super().__init__() self.data_modules = data_modules From 77149e067e26185e63862b3c2daa069558dd3612 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 12 Apr 2024 15:17:30 -0700 Subject: [PATCH 66/74] fix merge error --- viscy/data/hcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 78958d18..77bcc1ed 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -417,7 +417,7 @@ def _setup_test(self, dataset_settings: dict): logging.warning(f"Ignoring batch size {self.batch_size} in test stage.") dataset_settings["channels"]["target"] = self.target_channel - data_path = self.cache_path if self.cache_path else self.data_path + data_path = self.cache_path if self.caching else self.data_path plate = open_ome_zarr(data_path, mode="r") test_transform = Compose(self.normalizations) if self.ground_truth_masks: From 3b1ff5c052bb214716f6b30ef6ca7ed586fea070 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 22 Apr 2024 07:52:57 -0700 Subject: [PATCH 67/74] fix mAP default for over 100 detections --- viscy/evaluation/evaluation_metrics.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/viscy/evaluation/evaluation_metrics.py b/viscy/evaluation/evaluation_metrics.py index 921b0e4e..bb89858f 100644 --- a/viscy/evaluation/evaluation_metrics.py +++ b/viscy/evaluation/evaluation_metrics.py @@ -9,7 +9,7 @@ from monai.metrics.regression import compute_ssim_and_cs from scipy.optimize import linear_sum_assignment from skimage.measure import label, regionprops -from torchmetrics.detection import MeanAveragePrecision +from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchvision.ops import masks_to_boxes @@ -172,7 +172,12 @@ def mean_average_precision( :py:class:`torchmetrics.detection.MeanAveragePrecision` :return dict[str, torch.Tensor]: COCO-style metrics """ - map_metric = MeanAveragePrecision(box_format="xyxy", iou_type="segm", **kwargs) + defaults = dict( + iou_type="segm", box_format="xyxy", max_detection_thresholds=[1, 100, 10000] + ) + if not kwargs: + kwargs = {} + map_metric = MeanAveragePrecision(**(defaults | kwargs)) map_metric.update( [labels_to_detection(pred_labels)], [labels_to_detection(target_labels)] ) From 31522aeb7a062d555a5220cb53f66b67d8283275 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 22 Apr 2024 07:53:06 -0700 Subject: [PATCH 68/74] bump torchmetric --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8f6978de..5f0a184f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dynamic = ["version"] metrics = [ "cellpose==2.1.0", "scikit-learn>=1.1.3", - "torchmetrics[detection]>=1.0.0", + "torchmetrics[detection]>=1.3.1", "ptflops>=0.7", ] visual = ["ipykernel", "graphviz", "torchview"] From bf1b9d39bf97078531d3a28bf6c3e729ef75a353 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 22 Apr 2024 07:53:34 -0700 Subject: [PATCH 69/74] fix combined loader training for virtual staining task --- viscy/light/engine.py | 61 ++++++++++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 3dc92b74..3ab642f7 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -150,6 +150,7 @@ def __init__( self.log_batches_per_epoch = log_batches_per_epoch 
self.log_samples_per_batch = log_samples_per_batch self.training_step_outputs = [] + self.validation_losses = [] self.validation_step_outputs = [] # required to log the graph if architecture == "2D": @@ -175,31 +176,46 @@ def forward(self, x: Tensor) -> Tensor: return self.model(x) def training_step(self, batch: Sample, batch_idx: int): - source = batch["source"] - target = batch["target"] - pred = self.forward(source) - loss = self.loss_function(pred, target) + losses = [] + batch_size = 0 + for b in batch: + source = b["source"] + target = b["target"] + pred = self.forward(source) + loss = self.loss_function(pred, target) + losses.append(loss) + batch_size += source.shape[0] + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target, pred)) + ) + loss_step = torch.stack(losses).mean() self.log( "loss/train", - loss, + loss_step.to(self.device), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, + batch_size=batch_size, ) - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - self._detach_sample((source, target, pred)) - ) - return loss + return loss_step def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - source = batch["source"] - target = batch["target"] + source: Tensor = batch["source"] + target: Tensor = batch["target"] pred = self.forward(source) loss = self.loss_function(pred, target) - self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False) + if dataloader_idx + 1 > len(self.validation_losses): + self.validation_losses.append([]) + self.validation_losses[dataloader_idx].append(loss.detach()) + self.log( + f"loss/val/{dataloader_idx}", + loss.to(self.device), + sync_dist=True, + batch_size=source.shape[0], + ) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( self._detach_sample((source, target, pred)) @@ -309,8 +325,16 @@ def on_train_epoch_end(self): self.training_step_outputs = [] def on_validation_epoch_end(self): + super().on_validation_epoch_end() self._log_samples("val_samples", self.validation_step_outputs) self.validation_step_outputs = [] + # average within each dataloader + loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses] + self.log( + "loss/validate", + torch.tensor(loss_means).mean().to(self.device), + sync_dist=True, + ) def on_test_start(self): """Load CellPose model for segmentation.""" @@ -386,7 +410,6 @@ class FcmaeUNet(VSUNet): def __init__(self, fit_mask_ratio: float = 0.0, **kwargs): super().__init__(architecture="fcmae", **kwargs) self.fit_mask_ratio = fit_mask_ratio - self.validation_losses = [] def forward(self, x: Tensor, mask_ratio: float = 0.0): return self.model(x, mask_ratio) @@ -438,13 +461,3 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 self.validation_step_outputs.extend( self._detach_sample((source, target * mask.unsqueeze(2), pred)) ) - - def on_validation_epoch_end(self): - super().on_validation_epoch_end() - # average within each dataloader - loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses] - self.log( - "loss/validate", - torch.tensor(loss_means).mean().to(self.device), - sync_dist=True, - ) From d2a63c12b52b4e3cdc6707cdf56902ba018c4e98 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 23 Apr 2024 22:21:05 -0700 Subject: [PATCH 70/74] fix non-combined data loader training --- viscy/light/engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) 
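Background for this fix: with a combined loader, Lightning passes training_step
a sequence of per-dataloader sub-batches, while a plain DataLoader passes a
single Sample dict, so the step has to dispatch on the batch type. A minimal
standalone sketch of that dispatch (hypothetical step, model, and loss_fn
stand-ins for illustration, not the engine code itself):

    from collections.abc import Sequence

    import torch
    from torch import Tensor

    def step(batch, model, loss_fn) -> Tensor:
        # A single dataloader yields one dict-like sample, which is not a
        # Sequence; wrap it so the multi-loader loop handles both cases.
        if not isinstance(batch, Sequence):
            batch = [batch]
        losses = []
        for b in batch:
            pred = model(b["source"])
            losses.append(loss_fn(pred, b["target"]))
        # Average so each sub-batch contributes equally to the step loss.
        return torch.stack(losses).mean()
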
diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 3ab642f7..927165be 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -175,9 +175,11 @@ def __init__( def forward(self, x: Tensor) -> Tensor: return self.model(x) - def training_step(self, batch: Sample, batch_idx: int): + def training_step(self, batch: Sample | Sequence[Sample], batch_idx: int): losses = [] batch_size = 0 + if not isinstance(batch, Sequence): + batch = [batch] for b in batch: source = b["source"] target = b["target"] From bd29616c637e13276eddf191034e6c87a91270b0 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Sat, 4 May 2024 17:13:52 -0700 Subject: [PATCH 71/74] add fcmae to graph script --- viscy/scripts/network_diagram.py | 39 +++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/viscy/scripts/network_diagram.py b/viscy/scripts/network_diagram.py index 419d69a4..c8491c49 100644 --- a/viscy/scripts/network_diagram.py +++ b/viscy/scripts/network_diagram.py @@ -1,7 +1,7 @@ # %% from torchview import draw_graph -from viscy.light.engine import VSUNet +from viscy.light.engine import FcmaeUNet, VSUNet # %% 2D UNet model = VSUNet( @@ -93,3 +93,40 @@ graph22d # %% If you want to save the graphs as SVG files: # model_graph.visual_graph.render(format="svg") + +# %% +model = FcmaeUNet( + model_config=dict( + in_channels=1, + out_channels=1, + encoder_blocks=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + decoder_conv_blocks=1, + stem_kernel_size=(1, 2, 2), + in_stack_depth=1, + ), + fit_mask_ratio=0.5, + schedule="WarmupCosine", + lr=2e-4, + log_batches_per_epoch=2, + log_samples_per_batch=2, +) + +model_graph = draw_graph( + model, + (model.example_input_array), + graph_name="VSCyto2D", + roll=True, + depth=3, +) + +fcmae = model_graph.visual_graph +fcmae + +# %% + +model_graph.visual_graph.render( + format="svg", +) + +# %% From b98c34cae7aea2536ba15a1addc1463248551d16 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 5 Jun 2024 05:23:03 -0700 Subject: [PATCH 72/74] fix type hint --- viscy/unet/networks/fcmae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 821f8421..aca710e9 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -339,7 +339,7 @@ def __init__( self.total_stride = stem_kernel_size[1] * 2 ** (len(self.stages) - 1) self.apply(_init_weights) - def forward(self, x: Tensor, mask_ratio: float = 0.0) -> list[Tensor]: + def forward(self, x: Tensor, mask_ratio: float = 0.0) -> tuple[list[Tensor], BoolTensor | None]: """ :param Tensor x: input tensor (BCDHW) :param float mask_ratio: ratio of the feature maps to mask, From 80521891dc92141b9cddfd4e5b3ac226ffce8b8c Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 5 Jun 2024 05:36:04 -0700 Subject: [PATCH 73/74] format --- viscy/unet/networks/fcmae.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 9176985f..47a88b46 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -339,7 +339,9 @@ def __init__( self.total_stride = stem_kernel_size[1] * 2 ** (len(self.stages) - 1) self.apply(_init_weights) - def forward(self, x: Tensor, mask_ratio: float = 0.0) -> tuple[list[Tensor], BoolTensor | None]: + def forward( + self, x: Tensor, mask_ratio: float = 0.0 + ) -> tuple[list[Tensor], BoolTensor | None]: """ :param Tensor x: input tensor (BCDHW) :param float mask_ratio: ratio of the feature maps to 
mask,

From bbf22fb678c7472df91de9327dc81f657b1a7473 Mon Sep 17 00:00:00 2001
From: Ziwen Liu
Date: Wed, 5 Jun 2024 06:10:13 -0700
Subject: [PATCH 74/74] add back convolution option for fcmae head

---
 tests/unet/test_fcmae.py     | 15 ++++++++++++++-
 viscy/unet/networks/fcmae.py | 37 +++++++++++++++++++++++++++---------
 2 files changed, 42 insertions(+), 10 deletions(-)

diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py
index 9f3fc805..f22efa4c 100644
--- a/tests/unet/test_fcmae.py
+++ b/tests/unet/test_fcmae.py
@@ -2,11 +2,11 @@
 
 from viscy.unet.networks.fcmae import (
     FullyConvolutionalMAE,
-    PixelToVoxelShuffleHead,
     MaskedAdaptiveProjection,
     MaskedConvNeXtV2Block,
     MaskedConvNeXtV2Stage,
     MaskedMultiscaleEncoder,
+    PixelToVoxelShuffleHead,
     generate_mask,
     masked_patchify,
     masked_unpatchify,
@@ -121,3 +121,16 @@ def test_fcmae():
     y, m = model(x, mask_ratio=0.6)
     assert y.shape == x.shape
     assert m.shape == (2, 1, 128, 128)
+
+
+def test_fcmae_head_conv():
+    x = torch.rand(2, 3, 5, 128, 128)
+    model = FullyConvolutionalMAE(
+        3, 3, head_conv=True, head_conv_expansion_ratio=4, head_conv_pool=True
+    )
+    y, m = model(x)
+    assert y.shape == x.shape
+    assert m is None
+    y, m = model(x, mask_ratio=0.6)
+    assert y.shape == x.shape
+    assert m.shape == (2, 1, 128, 128)
diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py
index 47a88b46..6c2f6f45 100644
--- a/viscy/unet/networks/fcmae.py
+++ b/viscy/unet/networks/fcmae.py
@@ -20,7 +20,7 @@
 )
 from torch import BoolTensor, Size, Tensor, nn
 
-from viscy.unet.networks.Unet22D import Unet2dDecoder
+from viscy.unet.networks.Unet22D import PixelToVoxelHead, Unet2dDecoder
 
 
 def _init_weights(module: nn.Module) -> None:
@@ -407,6 +407,9 @@ def __init__(
         in_stack_depth: int = 5,
         decoder_conv_blocks: int = 1,
         pretraining: bool = True,
+        head_conv: bool = False,
+        head_conv_expansion_ratio: int = 4,
+        head_conv_pool: bool = True,
     ) -> None:
         super().__init__()
         self.encoder = MaskedMultiscaleEncoder(
@@ -419,7 +422,14 @@ def __init__(
         )
         decoder_channels = list(dims)
         decoder_channels.reverse()
-        decoder_channels[-1] = out_channels * in_stack_depth * stem_kernel_size[-1] ** 2
+        if head_conv:
+            decoder_channels[-1] = (
+                (in_stack_depth + 2) * in_channels * 2**2 * head_conv_expansion_ratio
+            )
+        else:
+            decoder_channels[-1] = (
+                out_channels * in_stack_depth * stem_kernel_size[-1] ** 2
+            )
         self.decoder = Unet2dDecoder(
             decoder_channels,
             norm_name="instance",
@@ -428,13 +438,22 @@ def __init__(
             strides=[2] * (len(dims) - 1) + [stem_kernel_size[-1]],
             upsample_pre_conv=None,
         )
-        self.head = PixelToVoxelShuffleHead(
-            in_channels=decoder_channels[-1],
-            out_channels=out_channels,
-            out_stack_depth=in_stack_depth,
-            xy_scaling=stem_kernel_size[-1],
-            pool=True,
-        )
+        if head_conv:
+            self.head = PixelToVoxelHead(
+                in_channels=decoder_channels[-1],
+                out_channels=out_channels,
+                out_stack_depth=in_stack_depth,
+                expansion_ratio=head_conv_expansion_ratio,
+                pool=head_conv_pool,
+            )
+        else:
+            self.head = PixelToVoxelShuffleHead(
+                in_channels=decoder_channels[-1],
+                out_channels=out_channels,
+                out_stack_depth=in_stack_depth,
+                xy_scaling=stem_kernel_size[-1],
+                pool=True,
+            )
         self.out_stack_depth = in_stack_depth
         # TODO: replace num_blocks with explicit strides for all models
         self.num_blocks = len(dims) * int(math.log2(stem_kernel_size[-1]))
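
The new head option is exercised the same way as in the test added above; a
minimal usage sketch (assuming a viscy checkout with this series applied):

    import torch

    from viscy.unet.networks.fcmae import FullyConvolutionalMAE

    # Use the convolutional PixelToVoxelHead instead of the
    # pixel-shuffle reconstruction head.
    model = FullyConvolutionalMAE(
        in_channels=3,
        out_channels=3,
        head_conv=True,
        head_conv_expansion_ratio=4,
        head_conv_pool=True,
    )
    x = torch.rand(2, 3, 5, 128, 128)  # BCDHW input
    y, mask = model(x, mask_ratio=0.6)  # masked pretraining forward pass
    assert y.shape == x.shape
    assert mask.shape == (2, 1, 128, 128)  # mask stays at 2D (YX) resolution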