Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Axes #103

Merged
merged 38 commits into from
Mar 5, 2024
Merged

Axes #103

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
10848ad
l nb draft
CatEek Feb 6, 2024
214d40f
(WIP) Fix some config tests
jdeschamps Feb 12, 2024
81fa5de
Use StrEnum
jdeschamps Feb 12, 2024
19b5e71
Fix config tests, cleanup
jdeschamps Feb 13, 2024
b75d150
Config draft and tests
jdeschamps Feb 13, 2024
3f41cea
median masking wip
CatEek Feb 13, 2024
47f4e1a
Refactor Ligthning modules
jdeschamps Feb 14, 2024
204c21a
API with dispatch
jdeschamps Feb 14, 2024
c12f4a4
Add Custom model
jdeschamps Feb 15, 2024
08a7437
Fix issue with str in train dispatch
jdeschamps Feb 15, 2024
534d1d7
Fix issue with str in predict dispatch
jdeschamps Feb 15, 2024
993f83e
Add array entry point to CAREamist
jdeschamps Feb 19, 2024
93aee11
Fix error, and tests
jdeschamps Feb 19, 2024
73ebac4
Fix oversight in method dispatch
jdeschamps Feb 20, 2024
b5e1ef9
Add BaseEnum to use with Enum values
jdeschamps Feb 20, 2024
8a5388c
Raise more errors and improve error messages
jdeschamps Feb 20, 2024
de56191
Small refactoring, remove unused transform
jdeschamps Feb 20, 2024
9fcf1f0
Add test for uniform_manipulate
jdeschamps Feb 20, 2024
2fbc574
Add ManipulateN2V test
jdeschamps Feb 20, 2024
6948238
Refactor read_tiff, prevent transform error
jdeschamps Feb 22, 2024
89ac2bd
Add tests and fix call error
jdeschamps Feb 22, 2024
a763634
Add val splitting to the datasets
jdeschamps Feb 22, 2024
1a0c5e9
Add 3D compatible transforms
jdeschamps Feb 23, 2024
0a82650
iter median
CatEek Feb 26, 2024
a3781fb
Merge remote-tracking branch 'origin/refactoring' into axes
CatEek Feb 26, 2024
7944257
Iterable pipeline with simple test
jdeschamps Feb 26, 2024
4ee1450
Fix error in C axes for array
jdeschamps Feb 26, 2024
d6ef7a8
struct mask draft
CatEek Feb 28, 2024
4eb2e30
apply struct comments
CatEek Feb 28, 2024
a240d2c
Add test and small patching refactoring
jdeschamps Feb 28, 2024
53c55dc
apply struct tests
CatEek Feb 29, 2024
85da0a7
Merge remote-tracking branch 'origin/refactoring' into axes
CatEek Feb 29, 2024
973c073
Fix call to median
jdeschamps Mar 5, 2024
95a0fb7
strat coords fix
CatEek Mar 5, 2024
9ef6091
Merge remote-tracking branch 'origin/refactoring' into axes
CatEek Mar 5, 2024
d755427
lightning nb + transforms small refac
CatEek Mar 5, 2024
c371473
style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] Mar 5, 2024
3280564
Quick CI fix: Add dependency for test
jdeschamps Mar 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 30 additions & 40 deletions examples/2D/n2v/example_BSD68_lightning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,18 @@
"source": [
"from pathlib import Path\n",
"\n",
"import tifffile\n",
"import matplotlib.pyplot as plt\n",
"import tifffile\n",
"from careamics_portfolio import PortfolioManager\n",
"from pytorch_lightning import Trainer\n",
"import albumentations as Aug\n",
"\n",
"from careamics_portfolio import PortfolioManager\n",
"from careamics.lightning_module import (\n",
" CAREamicsModule,\n",
" CAREamicsTrainDataModule,\n",
"from careamics import CAREamicsModule\n",
"from careamics.lightning_prediction import CAREamicsFiring\n",
"from careamics.ligthning_datamodule import (\n",
" CAREamicsPredictDataModule,\n",
" CAREamicsFiring,\n",
" CAREamicsTrainDataModule,\n",
")\n",
"from careamics.utils.metrics import psnr\n",
"from careamics.transforms import ManipulateN2V"
"from careamics.utils.metrics import psnr"
]
},
{
Expand All @@ -38,8 +36,18 @@
"metadata": {},
"outputs": [],
"source": [
"# Download and unzip the files\n",
"# Explore portfolio\n",
"portfolio = PortfolioManager()\n",
"print(portfolio.denoising)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Download and unzip the files\n",
"root_path = Path(\"data\")\n",
"files = portfolio.denoising.N2V_BSD68.download(root_path)\n",
"print(f\"List of downloaded files: {files}\")"
Expand All @@ -55,7 +63,12 @@
"train_path = data_path / \"train\"\n",
"val_path = data_path / \"val\"\n",
"test_path = data_path / \"test\" / \"images\"\n",
"gt_path = data_path / \"test\" / \"gt\""
"gt_path = data_path / \"test\" / \"gt\"\n",
"\n",
"train_path.mkdir(parents=True, exist_ok=True)\n",
"val_path.mkdir(parents=True, exist_ok=True)\n",
"test_path.mkdir(parents=True, exist_ok=True)\n",
"gt_path.mkdir(parents=True, exist_ok=True)"
]
},
{
Expand Down Expand Up @@ -118,7 +131,9 @@
" algorithm=\"n2v\",\n",
" loss=\"n2v\",\n",
" architecture=\"UNet\",\n",
")\n"
" optimizer_parameters={\"lr\": 1e-4},\n",
" lr_scheduler_parameters={\"factor\": 0.5, \"patience\": 10},\n",
")"
]
},
{
Expand All @@ -129,17 +144,6 @@
"### Define the Transforms"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"transforms = Aug.Compose(\n",
" [Aug.Flip(), Aug.RandomRotate90(), Aug.Normalize(), ManipulateN2V()],\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand All @@ -161,7 +165,6 @@
" patch_size=(64, 64),\n",
" axes=\"SYX\",\n",
" batch_size=128,\n",
" transforms=transforms,\n",
" num_workers=4,\n",
")"
]
Expand All @@ -182,7 +185,7 @@
"metadata": {},
"outputs": [],
"source": [
"trainer = Trainer(max_epochs=1)"
"trainer = Trainer(max_epochs=50)"
]
},
{
Expand All @@ -202,17 +205,6 @@
"### Define a prediction datamodule"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"transforms_predict = Aug.Compose(\n",
" [Aug.Normalize()],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -225,8 +217,6 @@
" tile_size=(256, 256),\n",
" axes=\"YX\",\n",
" batch_size=1,\n",
" num_workers=0,\n",
" transforms=transforms_predict,\n",
")"
]
},
Expand Down Expand Up @@ -324,7 +314,7 @@
"psnr_total = 0\n",
"\n",
"for pred, gt in zip(preds, gts):\n",
" psnr_total += psnr(gt, pred)\n",
" psnr_total += psnr(gt, pred.squeeze())\n",
"\n",
"print(f\"PSNR total: {psnr_total / len(preds)}\")"
]
Expand Down Expand Up @@ -353,7 +343,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.18"
},
"vscode": {
"interpreter": {
Expand Down
9 changes: 4 additions & 5 deletions examples/careamics_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from careamics import CAREamist, Configuration


def main():
config_dict ={
config_dict = {
"experiment_name": "ConfigTest",
"working_directory": ".",
"algorithm": {
Expand All @@ -14,9 +15,7 @@ def main():
"optimizer": {
"name": "Adam",
},
"lr_scheduler": {
"name": "ReduceLROnPlateau"
},
"lr_scheduler": {"name": "ReduceLROnPlateau"},
},
"training": {
"num_epochs": 1,
Expand All @@ -42,5 +41,5 @@ def main():
# print(pred.shape)


if __name__ == '__main__':
if __name__ == "__main__":
main()
18 changes: 10 additions & 8 deletions examples/careamics_lightning_api.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
"import albumentations as Aug\n",
"from pytorch_lightning import Trainer\n",
"\n",
"\n",
"from careamics import (\n",
" CAREamicsModule,\n",
" CAREamicsTrainDataModule,\n",
")\n",
"from careamics.transforms import ManipulateN2V\n"
"from careamics.transforms import ManipulateN2V"
]
},
{
Expand All @@ -26,20 +25,20 @@
"# Instantiate ligthning module\n",
"model = CAREamicsModule(\n",
" algorithm=\"n2v\",\n",
" loss=\"n2v\", \n",
" loss=\"n2v\",\n",
" architecture=\"UNet\",\n",
" model_parameters={\n",
" # parameters such as depth, n2v2, etc. See UNet definition.\n",
" },\n",
" optimizer=\"Adam\", # see SupportedOptimizer\n",
" optimizer=\"Adam\", # see SupportedOptimizer\n",
" optimizer_parameters={\n",
" \"lr\": 1e-4,\n",
" # parameters from torch.optim\n",
" },\n",
" lr_scheduler=\"ReduceLROnPlateau\", # see SupportedLRScheduler\n",
" lr_scheduler=\"ReduceLROnPlateau\", # see SupportedLRScheduler\n",
" lr_scheduler_parameters={\n",
" # parameters from torch.optim.lr_scheduler\n",
" }\n",
" },\n",
")"
]
},
Expand Down Expand Up @@ -68,9 +67,12 @@
"outputs": [],
"source": [
"# define function to read data\n",
"\n",
"\n",
"def read_my_data_type(file):\n",
" pass\n",
"\n",
"\n",
"# Create your transforms using albumentations\n",
"transforms = Aug.Compose(\n",
" [Aug.Flip(), Aug.RandomRotate90(), Aug.Normalize(), ManipulateN2V()],\n",
Expand All @@ -83,13 +85,13 @@
"train_data_module = CAREamicsTrainDataModule(\n",
" train_path=train_path,\n",
" val_path=val_path,\n",
" data_type=\"custom\", # this forces read_source_func to be specified\n",
" data_type=\"custom\", # this forces read_source_func to be specified\n",
" patch_size=(64, 64),\n",
" axes=\"SYX\",\n",
" batch_size=128,\n",
" transforms=transforms,\n",
" num_workers=4,\n",
" read_source_func = read_my_data_type # function to read data\n",
" read_source_func=read_my_data_type, # function to read data\n",
")"
]
},
Expand Down
Loading
Loading