-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Masked autoencoder pre-training for virtual staining models (#67)
* refactor data loading into its own module * update type annotations * move the logging module out * move old logging into utils * rename tests to match module name * bump torch * draft fcmae encoder * add stem to the encoder * wip: masked stem layernorm * wip: patchify masked features for linear * use mlp from timm * hack: POC training script for FCMAE * fix mask for fitting * remove training script * default architecture * fine-tuning options * fix cli for finetuning * draft combined data module * fix import * manual validation loss reduction * update linting new black version has different rules * update development guide * update type hints * bump iohub * draft ctmc v1 dataset * update tests * move test_data * remove path conversion * configurable normalizations (#68) * initial commit adding the normalization. * adding dataset_statistics to each fov to facilitate the configurable augmentations * fix indentation * ruff * test preprocessing * remove redundant field * cleanup --------- Co-authored-by: Ziwen Liu <ziwen.liu@czbiohub.org> * fix ctmc dataloading * add example ctmc v1 loading script * changing the normalization and augmentations default from None to empty list. * invert intensity transform * concatenated data module * subsample videos * livecell dataset * all sample fields are optional * fix multi-dataloader validation * lint * fixing preprocessing for varying array shapes (i.e., aics dataset) * update loading scripts * fix CombineMode * compose normalizations for predict and test stages * black * fix normalization in example config * fix collate when multi-sample transform is not used * ddp caching fixes * fix caching when using combined loader * move log values to GPU before syncing Lightning-AI/pytorch-lightning#18803 * removing normalize_source from configs. * typing fixes * fix test data path * fix test dataset * add docstring for ConcatDataModule * format --------- Co-authored-by: Eduardo Hirata-Miyasaki <edhiratam@gmail.com>
- Loading branch information
1 parent
41b9d10
commit f924cbc
Showing
36 changed files
with
1,527 additions
and
264 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,7 +61,6 @@ data: | |
- 256 | ||
- 256 | ||
caching: false | ||
normalize_source: false | ||
ground_truth_masks: null | ||
ckpt_path: null | ||
verbose: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from pathlib import Path | ||
|
||
from iohub import open_ome_zarr | ||
from monai.transforms import RandSpatialCropSamplesd | ||
from pytest import mark | ||
|
||
from viscy.data.hcs import HCSDataModule | ||
from viscy.light.trainer import VSTrainer | ||
|
||
|
||
@mark.parametrize("default_channels", [True, False])
def test_preprocess(small_hcs_dataset: Path, default_channels: bool):
    """Run ``VSTrainer.preprocess`` and check the normalization metadata.

    After preprocessing, every channel must have ``dataset_statistics``
    in the dataset-level ``normalization`` attributes, and every position
    must carry both ``dataset_statistics`` and ``fov_statistics`` for
    every channel.
    """
    if default_channels:
        # -1 is forwarded unchanged; presumably it selects all channels
        # (TODO confirm against VSTrainer.preprocess).
        requested_channels = -1
    else:
        with open_ome_zarr(small_hcs_dataset) as plate:
            requested_channels = plate.channel_names
    VSTrainer(accelerator="cpu").preprocess(
        small_hcs_dataset, channel_names=requested_channels, num_workers=2
    )
    with open_ome_zarr(small_hcs_dataset) as plate:
        # Always verify against the full channel list stored in the dataset,
        # regardless of what was requested above.
        stored_channels = plate.channel_names
        for name in stored_channels:
            assert "dataset_statistics" in plate.zattrs["normalization"][name]
        for _, position in plate.positions():
            per_fov_norm = position.zattrs["normalization"]
            for name in stored_channels:
                assert name in per_fov_norm
                assert "dataset_statistics" in per_fov_norm[name]
                assert "fov_statistics" in per_fov_norm[name]
|
||
|
||
@mark.parametrize("multi_sample_augmentation", [True, False])
def test_datamodule_setup_fit(preprocessed_hcs_dataset, multi_sample_augmentation):
    """Set up ``HCSDataModule`` for fitting and check training batch shapes.

    Runs with and without a multi-sample random crop augmentation; in both
    cases each training batch must have the configured batch size, channel
    counts, Z window and YX patch size.
    """
    z_window_size = 5
    channel_split = 2
    yx_patch_size = [128, 96]
    batch_size = 4
    with open_ome_zarr(preprocessed_hcs_dataset) as plate:
        channel_names = plate.channel_names
    transforms = []
    if multi_sample_augmentation:
        transforms.append(
            RandSpatialCropSamplesd(
                keys=channel_names,
                roi_size=[z_window_size, *yx_patch_size],
                num_samples=2,
            )
        )
    dm = HCSDataModule(
        data_path=preprocessed_hcs_dataset,
        source_channel=channel_names[:channel_split],
        target_channel=channel_names[channel_split:],
        z_window_size=z_window_size,
        batch_size=batch_size,
        num_workers=0,
        augmentations=transforms,
        architecture="3D",
        split_ratio=0.8,
        yx_patch_size=yx_patch_size,
    )
    dm.setup(stage="fit")
    # Expected shapes are loop-invariant, so build them once.
    spatial_dims = (z_window_size, *yx_patch_size)
    expected_source_shape = (batch_size, channel_split, *spatial_dims)
    expected_target_shape = (
        batch_size,
        len(channel_names) - channel_split,
        *spatial_dims,
    )
    for batch in dm.train_dataloader():
        assert batch["source"].shape == expected_source_shape
        assert batch["target"].shape == expected_target_shape
|
||
|
||
def test_datamodule_setup_predict(preprocessed_hcs_dataset):
    """Set up ``HCSDataModule`` for prediction and check the predict dataset.

    Verifies the number of samples in ``predict_dataset`` and the shape of
    the first sample's ``source`` entry.
    """
    z_window_size = 5
    channel_split = 2
    with open_ome_zarr(preprocessed_hcs_dataset) as plate:
        channel_names = plate.channel_names
        _, first_position = next(plate.positions())
        img = first_position[0]
        num_positions = sum(1 for _ in plate.positions())
    dm = HCSDataModule(
        data_path=preprocessed_hcs_dataset,
        source_channel=channel_names[:channel_split],
        target_channel=channel_names[channel_split:],
        z_window_size=z_window_size,
        batch_size=2,
        num_workers=0,
    )
    dm.setup(stage="predict")
    predict_set = dm.predict_dataset
    # Number of Z windows of the configured size that fit in one volume.
    windows_per_volume = img.slices - z_window_size + 1
    # NOTE(review): the factor 2 is presumably the number of time points in
    # the fixture dataset; confirm against the fixture definition.
    assert len(predict_set) == num_positions * 2 * windows_per_volume
    first_source = predict_set[0]["source"]
    assert first_source.shape == (
        channel_split,
        z_window_size,
        img.height,
        img.width,
    )
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from viscy.light.engine import FcmaeUNet | ||
|
||
|
||
def test_fcmae_vsunet() -> None:
    """Smoke test: build ``FcmaeUNet`` with a 3-in/1-out channel config.

    There are no assertions; the test passes as long as construction
    does not raise.
    """
    channel_config = {"in_channels": 3, "out_channels": 1}
    model = FcmaeUNet(model_config=channel_config, fit_mask_ratio=0.6)
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Oops, something went wrong.