Prediction Callback #148
Hi Conrad,

Thank you for your feedback! First, would you mind sharing a code snippet of what you have so far, so that we can get a clearer idea of the problem? This discussion on the pytorch-lightning repo seems to suggest that calling …

We are currently removing the image stitching logic from the prediction loop and implementing it as a separate function. This should be merged into main this week, and it might make logging predictions throughout training easier. So watch this space!

Solutions we may consider in the future include leveraging existing logging frameworks such as Weights & Biases. See their Lightning integration docs, where they implement an example …

Melisande
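To make the stitching step concrete: a tiled prediction splits the image into overlapping tiles, runs the network on each tile, then reassembles the outputs. The NumPy sketch below is a generic illustration under that assumption, not the CAREamics implementation; `tile_image` and `stitch_tiles` are hypothetical names, and overlapping regions are simply averaged, whereas a real implementation may crop overlaps instead.

```python
import numpy as np


def tile_image(image: np.ndarray, tile: int, overlap: int) -> list:
    """Split a 2D image into overlapping square tiles.

    Returns (tile_array, (y0, x0)) pairs, where (y0, x0) is the tile's
    top-left corner in the original image.
    """
    step = tile - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than the tile size")
    h, w = image.shape
    ys = list(range(0, max(h - tile, 0) + 1, step))
    xs = list(range(0, max(w - tile, 0) + 1, step))
    # add a final row/column of tiles flush with the border if needed
    if ys[-1] + tile < h:
        ys.append(h - tile)
    if xs[-1] + tile < w:
        xs.append(w - tile)
    return [(image[y0:y0 + tile, x0:x0 + tile], (y0, x0)) for y0 in ys for x0 in xs]


def stitch_tiles(tiles: list, shape: tuple) -> np.ndarray:
    """Reassemble (possibly processed) tiles, averaging overlapping pixels."""
    acc = np.zeros(shape, dtype=np.float64)
    weight = np.zeros(shape, dtype=np.float64)
    for tile_array, (y0, x0) in tiles:
        th, tw = tile_array.shape
        acc[y0:y0 + th, x0:x0 + tw] += tile_array
        weight[y0:y0 + th, x0:x0 + tw] += 1.0
    return acc / weight
```

With an identity "network" the stitched result equals the input, which is a quick sanity check for any stitching function.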
Hi Melisande,

Thank you for your reply! Shortly after submitting my issue I came across issue #141, and I am almost certain that it would solve the problem I am having. I based my solution on the PyTorch Lightning discussion you shared, but I cannot follow it exactly since, as things stand, I need to use the prediction loop to stitch the images together (and all solutions I could see involve calling …).

The W&B callback you shared does pretty much what mine does, but unfortunately I must use TensorBoard for my project. Do you have any plans to allow custom callbacks to be passed to the …
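Logging a prediction to TensorBoard from a callback mostly comes down to turning the array into a displayable image. A minimal sketch, assuming a single 2D prediction: `to_display_image` is a hypothetical helper that clips to percentiles and rescales to [0, 1]. With Lightning's `TensorBoardLogger`, the underlying `SummaryWriter` is exposed as `trainer.logger.experiment`, and its `add_image` method accepts such an array with `dataformats="HW"`.

```python
import numpy as np


def to_display_image(pred: np.ndarray, low_pct: float = 1.0, high_pct: float = 99.0) -> np.ndarray:
    """Rescale a 2D prediction to float32 values in [0, 1] using percentile clipping."""
    lo, hi = np.percentile(pred, [low_pct, high_pct])
    if hi <= lo:
        # constant image: nothing to stretch, avoid dividing by zero
        return np.zeros_like(pred, dtype=np.float32)
    scaled = (pred - lo) / (hi - lo)
    return np.clip(scaled, 0.0, 1.0).astype(np.float32)
```

Inside a callback hook, the result could then be logged with something like `trainer.logger.experiment.add_image("prediction", to_display_image(pred), global_step=trainer.global_step, dataformats="HW")`, assuming TensorBoard is the configured logger.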
Hi @conradkun,

Passing custom …

We have not worked much on the loggers so far, but we will try to support both TensorBoard and WandB, with examples of useful cases. Feedback and suggestions are always welcome!

EDIT: Fixing link

EDIT2: Regarding the TensorBoard callback, we actually have a way to define TensorBoard as the logger (this is set in the …
### Description

Custom callbacks are a powerful way to customize the information compiled during training, or to integrate training into another application. For instance, the `napari` plugin will need callbacks to relay training progress to the UI. This was also raised in #148.

This PR adds a mechanism to pass callbacks upon `CAREamist` instantiation.

- **What**: Add the possibility to pass custom callbacks to the `CAREamist`.
- **Why**: Enable users to further customize training or to include CAREamics in another application.
- **How**: Added a `callbacks` parameter to the constructor of `CAREamist`.

### Changes Made

- **Modified**: added a `callbacks` parameter to the constructor of `CAREamist`, with a corresponding test in `test_careamist`.

### Related Issues

Linked to a request in #148.

### Additional Notes and Examples

Here is the example from the doc PR:

```python
from pytorch_lightning.callbacks import Callback

from careamics import CAREamist


# define a custom callback
class MyPrintingCallback(Callback):
    def __init__(self):
        super().__init__()
        self.has_started = False
        self.has_ended = False

    def on_train_start(self, trainer, pl_module):
        self.has_started = True

    def on_train_end(self, trainer, pl_module):
        self.has_ended = True


my_callback = MyPrintingCallback()
# `config` is a CAREamics `Configuration` defined beforehand
careamist = CAREamist(config, callbacks=[my_callback])
```

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [x] PR to the documentation exists (for bug fixes / features)
  - CAREamics/careamics-examples#2
  - CAREamics/careamics.github.io#8
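The mechanism this PR enables (user-supplied callbacks whose hooks the trainer invokes at fixed points) can be modeled in a few lines of plain Python. This is a toy model, not Lightning's implementation: `ToyCallback`, `ToyTrainer` and `EpochRecorder` are hypothetical stand-ins, whereas in CAREamics the list given to `CAREamist(config, callbacks=[...])` presumably ends up in the underlying Lightning `Trainer`.

```python
class ToyCallback:
    """Toy base class: every hook is a no-op unless a subclass overrides it."""

    def on_train_start(self, trainer):
        pass

    def on_validation_epoch_end(self, trainer):
        pass

    def on_train_end(self, trainer):
        pass


class ToyTrainer:
    """Minimal stand-in for a trainer that stores user callbacks at construction."""

    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])
        self.current_epoch = 0

    def _call_hook(self, name):
        # invoke the same hook on every registered callback, in order
        for cb in self.callbacks:
            getattr(cb, name)(self)

    def fit(self, max_epochs):
        self._call_hook("on_train_start")
        for epoch in range(max_epochs):
            self.current_epoch = epoch
            # ... training and validation steps would run here ...
            self._call_hook("on_validation_epoch_end")
        self._call_hook("on_train_end")


class EpochRecorder(ToyCallback):
    """Example user callback that records which hooks fired."""

    def __init__(self):
        self.events = []

    def on_train_start(self, trainer):
        self.events.append("start")

    def on_validation_epoch_end(self, trainer):
        self.events.append(f"val_end:{trainer.current_epoch}")

    def on_train_end(self, trainer):
        self.events.append("end")
```

Running `ToyTrainer(callbacks=[EpochRecorder()]).fit(max_epochs=2)` fires the start hook once, the validation-end hook once per epoch, and the end hook once.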
Hi @conradkun,

#141 has now been merged into main! I hope this will help solve your issue. We no longer use a custom prediction loop; instead, the outputs of Lightning's …

Do not hesitate to get in touch with any questions!
Perfect, thanks for the heads-up! I will reimplement what I had using the new setup soon and let you know if I come across any issues. I just saw you also merged the custom callbacks branch, so my life is even simpler! Thank you for the quick work.
Unfortunately, I was not able to make it work, so here I go.

**What do I want to achieve?**

Given a …, I want to use a datamodule that is detached from both the training and the validation datasets found in the …

I think such functionality is really useful for a model like N2V, where metrics are not entirely representative of the perceived performance. Picture a UI where a user is training on an image and selects a patch to see how it develops as training progresses; this real-time feedback would give the user an idea of what is going on and whether training longer is a good idea, etc.

**What have I tried?**

With the new …

```python
class CustomPredictAfterValidationCallback(Callback):
    def __init__(self, pred_datamodule):
        self.pred_datamodule = pred_datamodule

    def setup(self, trainer, pl_module, stage):
        if stage in ("fit", "validate"):
            # set up the predict data for fit/validate, as we will call it during `on_validation_epoch_end`
            # not sure if needed, but doesn't hurt until I get it to work
            self.pred_datamodule.prepare_data()
            self.pred_datamodule.setup("predict")

    def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.sanity_checking:  # optional skip
            return
        predictions = trainer.predict(model=pl_module, datamodule=self.pred_datamodule)
        return convert_outputs(predictions, self.pred_datamodule.tiled)


pred_datamodule = create_pred_datamodule(
    source="image.tiff",
    config=config
)
predict_after_val_callback = CustomPredictAfterValidationCallback(pred_datamodule=pred_datamodule)
engine = CAREamist(config, callbacks=[predict_after_val_callback])
```

For some reason, PyTorch Lightning really does not like this setup. In particular, I get the following error after trying to call …:
```
AssertionError                            Traceback (most recent call last)
Cell In[32], line 1
----> 1 engine.train(datamodule=data_module)

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/careamics/careamist.py:322, in CAREamist.train(self, datamodule, train_source, val_source, train_target, val_target, use_in_memory, val_percentage, val_minimum_split)
    320 # train
    321 if datamodule is not None:
--> 322     self._train_on_datamodule(datamodule=datamodule)
    324 else:
    325     # raise error if target is provided to N2V
    326     if self.cfg.algorithm_config.algorithm == SupportedAlgorithm.N2V.value:

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/careamics/careamist.py:394, in CAREamist._train_on_datamodule(self, datamodule)
    391 # record datamodule
    392 self.train_datamodule = datamodule
--> 394 self.trainer.fit(self.model, datamodule=datamodule)

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    542 self.state.status = TrainerStatus.RUNNING
    543 self.training = True
--> 544 call._call_and_handle_interrupt(
    545     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    546 )

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42     if trainer.strategy.launcher is not None:
     43         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44     return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    573 assert self.state.fn is not None
    574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    575     self.state.fn,
    576     ckpt_path,
    577     model_provided=True,
    578     model_connected=self.lightning_module is not None,
    579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
    582 assert self.state.stopped
    583 self.training = False

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:987, in Trainer._run(self, model, ckpt_path)
    982 self._signal_connector.register_signal_handlers()
    984 # ----------------------------
    985 # RUN THE TRAINER
    986 # ----------------------------
--> 987 results = self._run_stage()
    989 # ----------------------------
    990 # POST-Training CLEAN UP
    991 # ----------------------------
    992 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:1033, in Trainer._run_stage(self)
   1031     self._run_sanity_check()
   1032 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1033     self.fit_loop.run()
   1034 return None
   1035 raise RuntimeError(f"Unexpected state {self.state}")

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:205, in _FitLoop.run(self)
    203 try:
    204     self.on_advance_start()
--> 205     self.advance()
    206     self.on_advance_end()
    207     self._restarting = False

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:363, in _FitLoop.advance(self)
    361 with self.trainer.profiler.profile("run_training_epoch"):
    362     assert self._data_fetcher is not None
--> 363     self.epoch_loop.run(self._data_fetcher)

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py:141, in _TrainingEpochLoop.run(self, data_fetcher)
    139 try:
    140     self.advance(data_fetcher)
--> 141     self.on_advance_end(data_fetcher)
    142     self._restarting = False
    143 except StopIteration:

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py:295, in _TrainingEpochLoop.on_advance_end(self, data_fetcher)
    291 if not self._should_accumulate():
    292     # clear gradients to not leave any unused memory during validation
    293     call._call_lightning_module_hook(self.trainer, "on_validation_model_zero_grad")
--> 295 self.val_loop.run()
    296 self.trainer.training = True
    297 self.trainer._logger_connector._first_loop_iter = first_loop_iter

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/utilities.py:182, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
    180     context_manager = torch.no_grad
    181 with context_manager():
--> 182     return loop_run(self, *args, **kwargs)

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py:142, in _EvaluationLoop.run(self)
    140     self._restarting = False
    141 self._store_dataloader_outputs()
--> 142 return self.on_run_end()

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py:254, in _EvaluationLoop.on_run_end(self)
    251 self.trainer._logger_connector._evaluation_epoch_end()
    253 # hook
--> 254 self._on_evaluation_epoch_end()
    256 logged_outputs, self._logged_outputs = self._logged_outputs, []  # free memory
    257 # include any logged outputs on epoch_end

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py:336, in _EvaluationLoop._on_evaluation_epoch_end(self)
    333 call._call_callback_hooks(trainer, hook_name)
    334 call._call_lightning_module_hook(trainer, hook_name)
--> 336 trainer._logger_connector.on_epoch_end()

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:195, in _LoggerConnector.on_epoch_end(self)
    193 def on_epoch_end(self) -> None:
    194     assert self._first_loop_iter is None
--> 195     metrics = self.metrics
    196     self._progress_bar_metrics.update(metrics["pbar"])
    197     self._callback_metrics.update(metrics["callback"])

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:233, in _LoggerConnector.metrics(self)
    231 """This function returns either batch or epoch metrics."""
    232 on_step = self._first_loop_iter is not None
--> 233 assert self.trainer._results is not None
    234 return self.trainer._results.metrics(on_step)

AssertionError:
```

I found a post with a similar error from April, which suggested to use …:

```python
class CustomPredictAfterValidationCallback(Callback):
    def __init__(self, pred_datamodule):
        self.pred_datamodule = pred_datamodule

    def setup(self, trainer, pl_module, stage):
        if stage in ("fit", "validate"):
            # set up the predict data for fit/validate, as we will call it during `on_validation_epoch_end`
            # not sure if needed, but doesn't hurt until I get it to work
            self.pred_datamodule.prepare_data()
            self.pred_datamodule.setup("predict")

    def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.sanity_checking:  # optional skip
            return
        # not entirely sure how preds are returned (and how they must be concatenated), take as pseudocode
        predictions = []
        # `enumerate` yields (index, item) pairs
        for idx, batch in enumerate(self.pred_datamodule.predict_dataloader()):
            preds = pl_module.predict_step(batch, idx)  # breaks here
            predictions += preds
        return convert_outputs(predictions, self.pred_datamodule.tiled)
```

The problem with this approach is how the … (`careamics/src/careamics/lightning_module.py`, lines 151 to 155 at commit `0a29ea2`).

Referencing …

**What to do?**

Somehow, …
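One plausible reading of the traceback: the failing check is `assert self.trainer._results is not None`, raised while the validation loop finishes its epoch, right after the callback's nested `trainer.predict` call. That is consistent with the nested loop installing and then tearing down per-run state that the still-running `fit` relies on. The toy model below reproduces that failure mode in plain Python; it is a simplification, not Lightning's code, and `SharedStateTrainer`, `nested_predict`, `direct_predict` and `model` are hypothetical.

```python
class SharedStateTrainer:
    """Toy trainer whose loops share one mutable `results` slot (cf. `trainer._results`)."""

    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])
        self.results = None

    def _run_loop(self, name, on_epoch_end=None):
        self.results = {"loop": name}  # each run installs its own state...
        if on_epoch_end is not None:
            on_epoch_end()
        # ...its epoch-end bookkeeping needs that state to still be there...
        if self.results is None:
            raise AssertionError(f"{name} loop: results were torn down")
        self.results = None  # ...and it tears the state down when finished

    def _call_callbacks(self):
        for cb in self.callbacks:
            cb(self)

    def predict(self):
        self._run_loop("predict")

    def fit(self):
        # callbacks run at the end of the validation epoch, inside the loop
        self._run_loop("validation", on_epoch_end=self._call_callbacks)


def model(x):
    """Stand-in for the network."""
    return 2 * x


def nested_predict(trainer):
    # re-enters the trainer from inside a hook: clobbers the shared state
    trainer.predict()


def direct_predict(trainer):
    # calls the model directly: the trainer's state is left untouched
    trainer.latest_predictions = [model(x) for x in range(3)]
```

With `nested_predict` registered, `fit()` raises the `AssertionError`; with `direct_predict`, it completes and leaves `latest_predictions == [0, 2, 4]`.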
Thanks for sharing your code! We've investigated a little and identified what is basically preventing this approach. We will take a deeper look in the coming weeks to see whether some refactoring would make the current code base compatible with a prediction callback!
That's great to hear, thank you very much!
Hi both, I know some work has been done towards this feature; is it fully working yet? Otherwise, let me know if I could help implement it :)
Hi Conrad! We are hosting I2K, so we are a bit overwhelmed at the moment. Let us get back to you at the beginning of November!
Hi @conradkun,

Sorry for the delay, life's always busier than expected. I had a go at it and came up with a hacky way to make it fit together (https://github.com/CAREamics/careamics/actions/runs/11783874603?pr=266). In essence, it looks like this:

import numpy as np
from pytorch_lightning import Callback, Trainer

from careamics import CAREamist, Configuration
from careamics.lightning import PredictDataModule, create_predict_datamodule
from careamics.prediction_utils import convert_outputs

# `minimum_configuration` is a configuration dictionary assumed to be defined beforehand
config = Configuration(**minimum_configuration)


class CustomPredictAfterValidationCallback(Callback):
    def __init__(self, pred_datamodule: PredictDataModule):
        self.pred_datamodule = pred_datamodule

        # prepare data and setup
        self.pred_datamodule.prepare_data()
        self.pred_datamodule.setup()
        self.pred_dataloader = pred_datamodule.predict_dataloader()

        self.data = None

    def on_validation_epoch_end(self, trainer: Trainer, pl_module):
        if trainer.sanity_checking:  # optional skip
            return

        # update statistics in the prediction dataset for coherence
        # (they can be computed on-line by the training dataset)
        self.pred_datamodule.predict_dataset.image_means = (
            trainer.datamodule.train_dataset.image_stats.means
        )
        self.pred_datamodule.predict_dataset.image_stds = (
            trainer.datamodule.train_dataset.image_stats.stds
        )

        # predict on the dataset
        outputs = []
        for idx, batch in enumerate(self.pred_dataloader):
            batch = pl_module.transfer_batch_to_device(batch, pl_module.device, 0)
            outputs.append(pl_module.predict_step(batch, batch_idx=idx))

        self.data = convert_outputs(outputs, self.pred_datamodule.tiled)
        # save data here


array = np.arange(32 * 32).reshape((32, 32))

pred_datamodule = create_predict_datamodule(
    pred_data=array,
    data_type=config.data_config.data_type,
    axes=config.data_config.axes,
    image_means=[11.8],  # random placeholder
    image_stds=[3.14],
    # can choose tiling here
)

predict_after_val_callback = CustomPredictAfterValidationCallback(
    pred_datamodule=pred_datamodule
)
engine = CAREamist(config, callbacks=[predict_after_val_callback])
engine.train(train_source=array)

This raised a few issues pertaining to the logic of CAREamics that I detail in the PR. We are undergoing a dataset refactoring step, and we will take the aforementioned points into account during the refactoring.

Regarding the PR, I need to test it with a real example before merging it (provided that it passes review), as I only ran it through a simple test! You are welcome to try it and comment in the PR, though I hope I am not making you try a faulty example.

We would also be super interested in knowing how and why you are using CAREamics! And if you ever want to use it in production, we would be happy to add specific integration tests to make sure the API stays stable for your use cases.

edit: Add comments to the code
Hi @conradkun! I tested it on the SEM notebook with a slightly modified version of the example code from my previous message (actually saving the images, and using tiling), and it seems to have worked well! It adds some complexity to our codebase, but only in a part we were not happy with anyway, so at least we will keep it in mind during the next refactoring and we will maintain this feature. We will merge the PR in a few moments. I will leave this issue open so you can let us know whether it works for your application!
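For readers unfamiliar with tiled prediction: the image is cut into tiles, each tile is processed independently, and the outputs are stitched back together using the recorded tile positions. The sketch below uses plain Python lists and non-overlapping tiles to stay short; practical tiled prediction usually uses overlapping tiles and discards the overlaps when stitching, which is omitted here. `split_into_tiles` and `stitch_tiles` are hypothetical helpers, not CAREamics functions.

```python
from typing import List, Tuple

Image = List[List[float]]
Tile = Tuple[int, int, Image]  # (row offset, column offset, tile data)


def split_into_tiles(image: Image, tile_size: int) -> List[Tile]:
    """Cut a 2D image into non-overlapping tiles, remembering where each one came from."""
    tiles = []
    for r in range(0, len(image), tile_size):
        for c in range(0, len(image[0]), tile_size):
            data = [row[c:c + tile_size] for row in image[r:r + tile_size]]
            tiles.append((r, c, data))
    return tiles


def stitch_tiles(tiles: List[Tile], height: int, width: int) -> Image:
    """Write each tile back at its recorded offset to rebuild the full image."""
    out = [[0.0] * width for _ in range(height)]
    for r, c, data in tiles:
        for i, row in enumerate(data):
            out[r + i][c:c + len(row)] = row
    return out


image = [[float(r * 6 + c) for c in range(6)] for r in range(5)]
tiles = split_into_tiles(image, tile_size=4)  # edge tiles are smaller than 4x4

# Stand-in for the model: process each tile independently (here, double every pixel).
processed = [(r, c, [[2 * v for v in row] for row in data]) for r, c, data in tiles]
stitched = stitch_tiles(processed, height=5, width=6)
print(stitched[0])
```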
Hi @jdeschamps! Indeed, life gets busy for me too, sorry for the late reply. This is great news, though! I will test it later this week, but it's a good sign that you got it to work on a test dataset yourself. About why I am using CAREamics, I'm not sure how much of it I am allowed to share publicly, but I would love to talk about it with you. Let me know if there is another channel where we could talk further about this.
I thought so. 😄 If you are willing to discuss it on private channels, drop me an email and we could have a Zoom call!
I consider this issue closed! Don't hesitate to open a new one if you run into problems!
Hi! I recently came across this repo and it looks very promising for my use case!
As I was playing around I wanted to see how the predictions improved throughout the epochs. Ideally, I would like to have a small separate dataset on which the model could be run after each training epoch, or just a random draw of some images from the validation dataset.
I tried to implement this myself with PyTorch Lightning `Callbacks`, but I don't see a clear way around having to call `trainer.predict` inside the callback, and I fear that messes up the `trainer.fit` loop by deleting the validation losses it keeps track of.
Given that you may have a lot more intuition about how CAREamics works, do you have an idea of something that would work? I am happy to implement it myself and create a PR, but currently I do not see how I can avoid calling `trainer.predict`, since the prediction loop is modified to stitch the image together.
Thank you,
Conrad