2D FCMAE #71

Merged
merged 80 commits on Jun 11, 2024

Commits (80)
c6692f1
refactor data loading into its own module
ziw-liu Jan 10, 2024
3d8e7e2
update type annotations
ziw-liu Jan 10, 2024
fdcbf55
move the logging module out
ziw-liu Jan 11, 2024
a291381
move old logging into utils
ziw-liu Jan 11, 2024
3cf8fa2
rename tests to match module name
ziw-liu Jan 11, 2024
d4cd41d
bump torch
ziw-liu Jan 11, 2024
e87d396
draft fcmae encoder
ziw-liu Jan 12, 2024
dccce5f
add stem to the encoder
ziw-liu Jan 12, 2024
5508731
wip: masked stem layernorm
ziw-liu Jan 12, 2024
3eec48e
wip: patchify masked features for linear
ziw-liu Jan 17, 2024
8c54feb
use mlp from timm
ziw-liu Jan 17, 2024
83ecf4a
hack: POC training script for FCMAE
ziw-liu Jan 17, 2024
2fffc99
fix mask for fitting
ziw-liu Jan 17, 2024
2a598b2
remove training script
ziw-liu Jan 17, 2024
b9b1880
default architecture
ziw-liu Jan 17, 2024
fd7700d
fine-tuning options
ziw-liu Jan 22, 2024
054249f
fix cli for finetuning
ziw-liu Jan 24, 2024
d867e10
draft combined data module
ziw-liu Jan 24, 2024
b06a300
fix import
ziw-liu Jan 25, 2024
39eafab
manual validation loss reduction
ziw-liu Jan 27, 2024
9fbf7a5
update linting
ziw-liu Feb 2, 2024
e00f5f3
update development guide
ziw-liu Feb 2, 2024
9e345b6
update type hints
ziw-liu Feb 13, 2024
96deca5
bump iohub
ziw-liu Feb 20, 2024
e06aa57
draft ctmc v1 dataset
ziw-liu Feb 24, 2024
ea8b300
Merge branch 'main' into fcmae
ziw-liu Feb 24, 2024
72de113
update tests
ziw-liu Feb 24, 2024
13d0aa0
move test_data
ziw-liu Feb 24, 2024
78aed97
remove path conversion
ziw-liu Feb 24, 2024
74e7db3
configurable normalizations (#68)
edyoshikun Feb 26, 2024
9b3b032
fix ctmc dataloading
ziw-liu Feb 28, 2024
a356936
add example ctmc v1 loading script
ziw-liu Feb 28, 2024
bac26be
changing the normalization and augmentations default from None to emp…
edyoshikun Feb 28, 2024
0b598c7
invert intensity transform
ziw-liu Feb 29, 2024
ddb30e9
concatenated data module
ziw-liu Feb 29, 2024
9504755
subsample videos
ziw-liu Feb 29, 2024
808e39c
livecell dataset
ziw-liu Feb 29, 2024
43d641d
all sample fields are optional
ziw-liu Feb 29, 2024
42f81cf
fix multi-dataloader validation
ziw-liu Feb 29, 2024
4546fc7
lint
ziw-liu Feb 29, 2024
306f3ef
fixing preprocessing for varying array shapes (i.e. AICS dataset)
edyoshikun Feb 29, 2024
1a0e3ce
update loading scripts
ziw-liu Mar 2, 2024
d3ec94d
fix CombineMode
ziw-liu Mar 2, 2024
02e6d0b
always use untrainable head for FCMAE
ziw-liu Mar 2, 2024
e18d305
move log values to GPU before syncing
ziw-liu Mar 2, 2024
01c71cf
custom head
ziw-liu Mar 2, 2024
dd64b31
ddp caching fixes
ziw-liu Mar 4, 2024
b3ea8d7
fix caching when using combined loader
ziw-liu Mar 4, 2024
d3db2bb
compose normalizations for predict and test stages
ziw-liu Mar 4, 2024
d5a3fd6
Merge branch 'fcmae' into 2d-fcmae
ziw-liu Mar 4, 2024
a549d4e
black
ziw-liu Mar 4, 2024
d74e731
Merge branch 'fcmae' into 2d-fcmae
ziw-liu Mar 4, 2024
a38da8b
fix normalization in example config
ziw-liu Mar 6, 2024
af317c4
fix normalization in example config
ziw-liu Mar 6, 2024
96aac51
prefetch more in validation
ziw-liu Mar 6, 2024
d9a471d
fix collate when multi-sample transform is not used
ziw-liu Mar 6, 2024
669ee83
ddp caching fixes
ziw-liu Mar 4, 2024
b2e23b8
fix caching when using combined loader
ziw-liu Mar 4, 2024
acdf362
Merge branch 'fcmae' into 2d-fcmae
ziw-liu Mar 6, 2024
8132b68
typing fixes
ziw-liu Mar 6, 2024
4c7a484
fix test dataset
ziw-liu Mar 6, 2024
7cfe403
fix invert transform
ziw-liu Mar 13, 2024
0b22f1a
add ddp prepare flag for combined data module
ziw-liu Mar 13, 2024
ed01065
remove redundant operations
ziw-liu Mar 13, 2024
c12fbf7
filter empty detections
ziw-liu Mar 14, 2024
f226801
pass trainer to underlying data modules in concatenated
ziw-liu Mar 14, 2024
073acf4
hack: add test dataloader for LiveCell dataset
ziw-liu Mar 14, 2024
2771fdb
test datasets for livecell and ctmc
ziw-liu Mar 28, 2024
1732974
Merge branch 'main' into 2d-fcmae
ziw-liu Apr 12, 2024
178df34
fix merge error
ziw-liu Apr 12, 2024
77149e0
fix merge error
ziw-liu Apr 12, 2024
3b1ff5c
fix mAP default for over 100 detections
ziw-liu Apr 22, 2024
31522ae
bump torchmetric
ziw-liu Apr 22, 2024
bf1b9d3
fix combined loader training for virtual staining task
ziw-liu Apr 22, 2024
d2a63c1
fix non-combined data loader training
ziw-liu Apr 24, 2024
bd29616
add fcmae to graph script
ziw-liu May 5, 2024
b98c34c
fix type hint
ziw-liu Jun 5, 2024
464ae0c
Merge branch 'main' into 2d-fcmae
ziw-liu Jun 5, 2024
8052189
format
ziw-liu Jun 5, 2024
bbf22fb
add back convolution option for fcmae head
ziw-liu Jun 5, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -26,7 +26,7 @@ dynamic = ["version"]
 metrics = [
     "cellpose==2.1.0",
     "scikit-learn>=1.1.3",
-    "torchmetrics[detection]>=1.0.0",
+    "torchmetrics[detection]>=1.3.1",
     "ptflops>=0.7",
 ]
 visual = ["ipykernel", "graphviz", "torchview"]
21 changes: 21 additions & 0 deletions tests/unet/test_fcmae.py
@@ -6,6 +6,7 @@
     MaskedConvNeXtV2Block,
     MaskedConvNeXtV2Stage,
     MaskedMultiscaleEncoder,
+    PixelToVoxelShuffleHead,
     generate_mask,
     masked_patchify,
     masked_unpatchify,
@@ -104,6 +105,13 @@ def test_masked_multiscale_encoder():
         assert afeat.shape[2] == afeat.shape[3] == xy_size // stride


+def test_pixel_to_voxel_shuffle_head():
+    head = PixelToVoxelShuffleHead(240, 3, out_stack_depth=5, xy_scaling=4)
+    x = torch.rand(2, 240, 16, 16)
+    y = head(x)
+    assert y.shape == (2, 3, 5, 64, 64)
+
+
 def test_fcmae():
     x = torch.rand(2, 3, 5, 128, 128)
     model = FullyConvolutionalMAE(3, 3)
@@ -113,3 +121,16 @@ def test_fcmae():
     y, m = model(x, mask_ratio=0.6)
     assert y.shape == x.shape
     assert m.shape == (2, 1, 128, 128)
+
+
+def test_fcmae_head_conv():
+    x = torch.rand(2, 3, 5, 128, 128)
+    model = FullyConvolutionalMAE(
+        3, 3, head_conv=True, head_conv_expansion_ratio=4, head_conv_pool=True
+    )
+    y, m = model(x)
+    assert y.shape == x.shape
+    assert m is None
+    y, m = model(x, mask_ratio=0.6)
+    assert y.shape == x.shape
+    assert m.shape == (2, 1, 128, 128)
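The head's shape contract follows pixel-shuffle arithmetic: the input channel count must factor as out_channels × out_stack_depth × xy_scaling². A quick check of the numbers used in the test above (assumes only the shapes the test asserts):

    # Shape bookkeeping for PixelToVoxelShuffleHead(240, 3, out_stack_depth=5, xy_scaling=4)
    in_channels, out_channels, out_stack_depth, xy_scaling = 240, 3, 5, 4
    assert in_channels == out_channels * out_stack_depth * xy_scaling**2  # 3 * 5 * 16 = 240
    # so (2, 240, 16, 16) can be rearranged to (2, 3, 5, 16 * 4, 16 * 4) == (2, 3, 5, 64, 64)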
3 changes: 2 additions & 1 deletion viscy/data/combined.py
@@ -79,7 +79,6 @@ class ConcatDataModule(LightningDataModule):
     The concatenated data module will have the same
     batch size and number of workers as the first data module.
     Each element will be sampled uniformly regardless of their original data module.
-
     :param Sequence[LightningDataModule] data_modules: data modules to concatenate
     """
@@ -93,9 +92,11 @@ def __init__(self, data_modules: Sequence[LightningDataModule]):
                 raise ValueError("Inconsistent number of workers")
             if dm.batch_size != self.batch_size:
                 raise ValueError("Inconsistent batch size")
+        self.prepare_data_per_node = True

     def prepare_data(self):
         for dm in self.data_modules:
+            dm.trainer = self.trainer
             dm.prepare_data()

     def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
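A minimal sketch of ConcatDataModule's consistency checks, using stand-in data modules (hypothetical, exposing only the attributes the constructor inspects):

    from lightning.pytorch import LightningDataModule
    from viscy.data.combined import ConcatDataModule

    class DummyDM(LightningDataModule):
        # stand-in with the two attributes ConcatDataModule compares
        batch_size = 16
        num_workers = 8

    combined = ConcatDataModule([DummyDM(), DummyDM()])
    # mismatched batch_size or num_workers would raise ValueError;
    # prepare_data() also forwards the trainer to each child module before preparing it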
3 changes: 1 addition & 2 deletions viscy/data/ctmc_v1.py
@@ -10,9 +10,8 @@


 class CTMCv1ValidationDataset(SlidingWindowDataset):
-    subsample_rate: int = 30

-    def __len__(self) -> int:
+    def __len__(self, subsample_rate: int = 30) -> int:
         # sample every 30th frame in the videos
         return super().__len__() // self.subsample_rate
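A toy check of the subsampled length arithmetic (made-up numbers):

    full_length = 3000    # hypothetical window count over all validation videos
    subsample_rate = 30
    assert full_length // subsample_rate == 100  # keep one sample per 30 frames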
2 changes: 0 additions & 2 deletions viscy/data/hcs.py
@@ -191,8 +191,6 @@ def __getitem__(self, index: int) -> Sample:
             sample_images["norm_meta"] = norm_meta
         if self.transform:
             sample_images = self.transform(sample_images)
-        # if isinstance(sample_images, list):
-        #     sample_images = sample_images[0]
         if "weight" in sample_images:
             del sample_images["weight"]
         sample = {
101 changes: 90 additions & 11 deletions viscy/data/livecell.py
@@ -3,9 +3,11 @@

 import torch
 from lightning.pytorch import LightningDataModule
-from monai.transforms import Compose, Transform
+from monai.transforms import Compose, MapTransform
+from pycocotools.coco import COCO
 from tifffile import imread
 from torch.utils.data import DataLoader, Dataset
+from torchvision.ops import box_convert

 from viscy.data.typing import Sample

@@ -15,10 +17,10 @@ class LiveCellDataset(Dataset):
     LiveCell dataset.

     :param list[Path] images: List of paths to single-page, single-channel TIFF files.
-    :param Transform | Compose transform: Transform to apply to the dataset
+    :param MapTransform | Compose transform: Transform to apply to the dataset
     """

-    def __init__(self, images: list[Path], transform: Transform | Compose) -> None:
+    def __init__(self, images: list[Path], transform: MapTransform | Compose) -> None:
         self.images = images
         self.transform = transform

@@ -32,36 +34,100 @@ def __getitem__(self, idx: int) -> Sample:
         return {"source": image, "target": image}


+class LiveCellTestDataset(Dataset):
+    """
+    LiveCell dataset.
+
+    :param list[Path] images: List of paths to single-page, single-channel TIFF files.
+    :param MapTransform | Compose transform: Transform to apply to the dataset
+    """
+
+    def __init__(
+        self,
+        image_dir: Path,
+        transform: MapTransform | Compose,
+        annotations: Path,
+        load_target: bool = False,
+        load_labels: bool = False,
+    ) -> None:
+        self.image_dir = image_dir
+        self.transform = transform
+        self.coco = COCO(str(annotations))
+        self.image_ids = list(self.coco.imgs.keys())
+        self.load_target = load_target
+        self.load_labels = load_labels
+
+    def __len__(self) -> int:
+        return len(self.image_ids)
+
+    def __getitem__(self, idx: int) -> Sample:
+        image_id = self.image_ids[idx]
+        file_name = self.coco.imgs[image_id]["file_name"]
+        image_path = self.image_dir / file_name
+        image = imread(image_path)[None, None]
+        image = torch.from_numpy(image).to(torch.float32)
+        sample = Sample(source=image)
+        if self.load_target:
+            sample["target"] = image
+        if self.load_labels:
+            anns = self.coco.loadAnns(self.coco.getAnnIds(image_id)) or []
+            boxes = [torch.tensor(ann["bbox"]).to(torch.float32) for ann in anns]
+            masks = [
+                torch.from_numpy(self.coco.annToMask(ann)).to(torch.bool)
+                for ann in anns
+            ]
+            dets = {
+                "boxes": box_convert(torch.stack(boxes), in_fmt="xywh", out_fmt="xyxy"),
+                "labels": torch.zeros(len(anns)).to(torch.uint8),
+                "masks": torch.stack(masks),
+            }
+            sample["detections"] = dets
+            sample["file_name"] = file_name
+        self.transform(sample)
+        return sample
+
+
 class LiveCellDataModule(LightningDataModule):
     def __init__(
         self,
-        train_val_images: Path,
-        train_annotations: Path,
-        val_annotations: Path,
-        train_transforms: list[Transform],
-        val_transforms: list[Transform],
+        train_val_images: Path | None = None,
+        test_images: Path | None = None,
+        train_annotations: Path | None = None,
+        val_annotations: Path | None = None,
+        test_annotations: Path | None = None,
+        train_transforms: list[MapTransform] = [],
+        val_transforms: list[MapTransform] = [],
+        test_transforms: list[MapTransform] = [],
         batch_size: int = 16,
         num_workers: int = 8,
     ) -> None:
         super().__init__()
         self.train_val_images = Path(train_val_images)
         if not self.train_val_images.is_dir():
             raise NotADirectoryError(str(train_val_images))
+        self.test_images = Path(test_images)
+        if not self.test_images.is_dir():
+            raise NotADirectoryError(str(test_images))
         self.train_annotations = Path(train_annotations)
         if not self.train_annotations.is_file():
             raise FileNotFoundError(str(train_annotations))
         self.val_annotations = Path(val_annotations)
         if not self.val_annotations.is_file():
             raise FileNotFoundError(str(val_annotations))
+        self.test_annotations = Path(test_annotations)
+        if not self.test_annotations.is_file():
+            raise FileNotFoundError(str(test_annotations))
         self.train_transforms = Compose(train_transforms)
         self.val_transforms = Compose(val_transforms)
+        self.test_transforms = Compose(test_transforms)
         self.batch_size = batch_size
         self.num_workers = num_workers

     def setup(self, stage: str) -> None:
-        if stage != "fit":
-            raise NotImplementedError("Only fit stage is supported")
-        self._setup_fit()
+        if stage == "fit":
+            self._setup_fit()
+        elif stage == "test":
+            self._setup_test()

     def _parse_image_names(self, annotations: Path) -> list[Path]:
         with open(annotations) as f:
@@ -80,6 +146,14 @@ def _setup_fit(self) -> None:
             transform=self.val_transforms,
         )

+    def _setup_test(self) -> None:
+        self.test_dataset = LiveCellTestDataset(
+            self.test_images,
+            transform=self.test_transforms,
+            annotations=self.test_annotations,
+            load_labels=True,
+        )
+
     def train_dataloader(self) -> DataLoader:
         return DataLoader(
             self.train_dataset,
@@ -96,3 +170,8 @@ def val_dataloader(self) -> DataLoader:
             num_workers=self.num_workers,
             persistent_workers=bool(self.num_workers),
         )
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(
+            self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers
+        )
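A usage sketch for the new test stage, with placeholder paths. Note that as written, __init__ converts and validates every path argument, so the train/val paths must also exist even when only the test stage is used:

    from viscy.data.livecell import LiveCellDataModule

    dm = LiveCellDataModule(
        train_val_images="/data/livecell/images",      # placeholder paths
        test_images="/data/livecell/images",
        train_annotations="/data/livecell/train.json",
        val_annotations="/data/livecell/val.json",
        test_annotations="/data/livecell/test.json",
        test_transforms=[],                            # add MapTransforms as needed
    )
    dm.setup("test")
    loader = dm.test_dataloader()  # yields Samples with "source", "detections", "file_name"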
9 changes: 7 additions & 2 deletions viscy/evaluation/evaluation_metrics.py
@@ -9,7 +9,7 @@
 from monai.metrics.regression import compute_ssim_and_cs
 from scipy.optimize import linear_sum_assignment
 from skimage.measure import label, regionprops
-from torchmetrics.detection import MeanAveragePrecision
+from torchmetrics.detection.mean_ap import MeanAveragePrecision
 from torchvision.ops import masks_to_boxes


@@ -172,7 +172,12 @@ def mean_average_precision(
     :py:class:`torchmetrics.detection.MeanAveragePrecision`
     :return dict[str, torch.Tensor]: COCO-style metrics
     """
-    map_metric = MeanAveragePrecision(box_format="xyxy", iou_type="segm", **kwargs)
+    defaults = dict(
+        iou_type="segm", box_format="xyxy", max_detection_thresholds=[1, 100, 10000]
+    )
+    if not kwargs:
+        kwargs = {}
+    map_metric = MeanAveragePrecision(**(defaults | kwargs))
     map_metric.update(
         [labels_to_detection(pred_labels)], [labels_to_detection(target_labels)]
     )
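For reference, the defaults | kwargs union lets callers override any default while keeping the raised detection cap; right-hand keys win on collision (Python 3.9+ dict semantics):

    defaults = dict(
        iou_type="segm", box_format="xyxy", max_detection_thresholds=[1, 100, 10000]
    )
    overrides = {"iou_type": "bbox"}  # hypothetical caller kwargs
    merged = defaults | overrides     # right-hand side wins on shared keys
    assert merged["iou_type"] == "bbox"
    assert merged["max_detection_thresholds"] == [1, 100, 10000]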
65 changes: 40 additions & 25 deletions viscy/light/engine.py
@@ -146,6 +146,7 @@ def __init__(
         self.log_batches_per_epoch = log_batches_per_epoch
         self.log_samples_per_batch = log_samples_per_batch
         self.training_step_outputs = []
+        self.validation_losses = []
         self.validation_step_outputs = []
         # required to log the graph
         if architecture == "2D":
@@ -170,32 +171,49 @@ def __init__(
     def forward(self, x: Tensor) -> Tensor:
         return self.model(x)

-    def training_step(self, batch: Sample, batch_idx: int):
-        source = batch["source"]
-        target = batch["target"]
-        pred = self.forward(source)
-        loss = self.loss_function(pred, target)
+    def training_step(self, batch: Sample | Sequence[Sample], batch_idx: int):
+        losses = []
+        batch_size = 0
+        if not isinstance(batch, Sequence):
+            batch = [batch]
+        for b in batch:
+            source = b["source"]
+            target = b["target"]
+            pred = self.forward(source)
+            loss = self.loss_function(pred, target)
+            losses.append(loss)
+            batch_size += source.shape[0]
+            if batch_idx < self.log_batches_per_epoch:
+                self.training_step_outputs.extend(
+                    self._detach_sample((source, target, pred))
+                )
+        loss_step = torch.stack(losses).mean()
         self.log(
             "loss/train",
-            loss,
+            loss_step.to(self.device),
             on_step=True,
             on_epoch=True,
             prog_bar=True,
             logger=True,
             sync_dist=True,
+            batch_size=batch_size,
         )
-        if batch_idx < self.log_batches_per_epoch:
-            self.training_step_outputs.extend(
-                self._detach_sample((source, target, pred))
-            )
-        return loss
+        return loss_step
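The Sequence branch above handles the list of sub-batches that Lightning's CombinedLoader emits when several dataloaders are trained together; a toy sketch of that behavior (assumes lightning.pytorch 2.x, not VisCy code):

    import torch
    from lightning.pytorch.utilities import CombinedLoader
    from torch.utils.data import DataLoader, TensorDataset

    loader_a = DataLoader(TensorDataset(torch.zeros(8, 1)), batch_size=4)
    loader_b = DataLoader(TensorDataset(torch.ones(12, 1)), batch_size=4)
    combined = CombinedLoader([loader_a, loader_b], mode="max_size_cycle")
    for batch, batch_idx, dataloader_idx in combined:
        # each `batch` is a list with one sub-batch per dataloader,
        # which is what the Sequence[Sample] branch iterates over
        assert isinstance(batch, list) and len(batch) == 2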

     def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
-        source = batch["source"]
-        target = batch["target"]
+        source: Tensor = batch["source"]
+        target: Tensor = batch["target"]
         pred = self.forward(source)
         loss = self.loss_function(pred, target)
-        self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False)
+        if dataloader_idx + 1 > len(self.validation_losses):
+            self.validation_losses.append([])
+        self.validation_losses[dataloader_idx].append(loss.detach())
+        self.log(
+            f"loss/val/{dataloader_idx}",
+            loss.to(self.device),
+            sync_dist=True,
+            batch_size=source.shape[0],
+        )
         if batch_idx < self.log_batches_per_epoch:
             self.validation_step_outputs.extend(
                 self._detach_sample((source, target, pred))
@@ -305,8 +323,16 @@ def on_train_epoch_end(self):
         self.training_step_outputs = []

     def on_validation_epoch_end(self):
+        super().on_validation_epoch_end()
         self._log_samples("val_samples", self.validation_step_outputs)
         self.validation_step_outputs = []
+        # average within each dataloader
[Review comment]
Contributor: I know this is the end of validation, so it won't backprop, and there may be no need to detach the tensor before doing any logging. Is detaching here the common practice? I know it is more relevant in training_step and validation_step. If we don't detach here, does it affect anything? Just curious.

Collaborator (author): To log the loss value (through the Lightning logger), the detaching is automatic. We only need to care about it when logging manually (images).
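Illustrating the thread's point: values passed to self.log are detached by Lightning automatically, while tensors kept for manual logging (such as images) should be detached explicitly. A toy LightningModule method sketch, not VisCy code:

    def validation_step(self, batch, batch_idx):
        pred = self.forward(batch["source"])
        loss = self.loss_function(pred, batch["target"])
        self.log("loss/validate", loss)  # Lightning detaches logged scalars itself
        # manual logging (e.g. images kept for epoch-end plotting) needs explicit detach:
        self.validation_step_outputs.append(pred.detach().cpu())
        return loss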

+        loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses]
+        self.log(
+            "loss/validate",
+            torch.tensor(loss_means).mean().to(self.device),
+            sync_dist=True,
+        )

     def on_test_start(self):
         """Load CellPose model for segmentation."""
@@ -382,7 +408,6 @@ class FcmaeUNet(VSUNet):
     def __init__(self, fit_mask_ratio: float = 0.0, **kwargs):
         super().__init__(architecture="fcmae", **kwargs)
         self.fit_mask_ratio = fit_mask_ratio
-        self.validation_losses = []

     def forward(self, x: Tensor, mask_ratio: float = 0.0):
         return self.model(x, mask_ratio)
@@ -434,13 +459,3 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
         self.validation_step_outputs.extend(
             self._detach_sample((source, target * mask.unsqueeze(2), pred))
         )
-
-    def on_validation_epoch_end(self):
-        super().on_validation_epoch_end()
-        # average within each dataloader
-        loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses]
-        self.log(
-            "loss/validate",
-            torch.tensor(loss_means).mean().to(self.device),
-            sync_dist=True,
-        )
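The reduction above averages within each dataloader before averaging across them, so short dataloaders are not drowned out by long ones. A numeric illustration with made-up losses:

    import torch

    validation_losses = [[0.2, 0.4], [1.0] * 6]  # two dataloaders of unequal length
    loss_means = [torch.tensor(losses).mean() for losses in validation_losses]
    balanced = torch.tensor(loss_means).mean()  # (0.3 + 1.0) / 2 = 0.65, as logged above
    pooled = torch.cat([torch.tensor(l) for l in validation_losses]).mean()  # 6.6 / 8 = 0.825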