
Prediction Callback #148

Closed
conradkun opened this issue Jun 14, 2024 · 15 comments · Fixed by #266
Labels
feature New feature or request

Comments

@conradkun
Contributor

Hi! I recently came across this repo and it looks very promising for my use case!

As I was playing around I wanted to see how the predictions improved throughout the epochs. Ideally, I would like to have a small separate dataset on which the model could be run after each training epoch, or just a random draw of some images from the validation dataset.

I tried to implement this myself with PyTorch Lightning Callbacks, but I don't see a clear way to get around having to call trainer.predict inside the callback, and I fear that messes up the trainer.fit loop by deleting the validation losses it keeps track of.

Given that you may have a lot more intuition of how CAREamics works, do you have an idea of something that would work? I am happy to implement it myself and create a PR, but currently I do not see how I can avoid calling trainer.predict, since the prediction loop is modified to stitch the image together.

Thank you,

Conrad

@conradkun conradkun added the feature New feature or request label Jun 14, 2024
@melisande-c
Member

melisande-c commented Jun 17, 2024

Hi Conrad,

Thank you for your feedback!

Firstly, would you mind sharing a code snippet of what you have so far so we can have a clearer idea of the problem?

This discussion on the pytorch-lightning repo seems to suggest that calling trainer.predict in the Callback hook on_validation_epoch_end shouldn't cause any issues. However, it is over a year old and might be out of date.

We are currently in the process of removing the image-stitching logic from the prediction loop and implementing it as a separate function. This should be merged into main this week. It might make logging predictions throughout training easier. So watch this space!

Some solutions we will consider implementing in the future might include leveraging existing logging frameworks such as Weights & Biases. See their lightning integration docs, where they implement an example LogPredictionSamplesCallback class.
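The mechanism such a callback relies on is simple enough to sketch without Lightning installed: the trainer calls named hook methods on every registered callback at fixed points in the loop. `LogPredictionSamplesCallback` and `ToyTrainer` below are illustrative stand-ins, not the real W&B or Lightning classes:

```python
# Minimal sketch of the callback-hook pattern used by PyTorch Lightning.
# All names here are toy stand-ins for illustration only.

class Callback:
    """Base class: hooks are no-ops unless overridden."""
    def on_validation_epoch_end(self, trainer, pl_module):
        pass


class LogPredictionSamplesCallback(Callback):
    """Records something at the end of every validation epoch."""
    def __init__(self):
        self.logged = []

    def on_validation_epoch_end(self, trainer, pl_module):
        # in real code this would run prediction and log sample images
        self.logged.append(f"epoch {trainer.current_epoch}")


class ToyTrainer:
    """Calls each callback's hook once per 'epoch'."""
    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.current_epoch = 0

    def fit(self, epochs):
        for self.current_epoch in range(epochs):
            for cb in self.callbacks:
                cb.on_validation_epoch_end(self, pl_module=None)


cb = LogPredictionSamplesCallback()
ToyTrainer([cb]).fit(epochs=3)
print(cb.logged)  # ['epoch 0', 'epoch 1', 'epoch 2']
```

The real Lightning `Trainer` invokes the same hook names (`on_validation_epoch_end`, `on_train_start`, etc.) on each callback passed to it, which is what the W&B example class plugs into.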

Melisande

@conradkun
Contributor Author

Hi Melisande,

Thank you for your reply! Shortly after submitting my issue I came across issue #141, and indeed I am almost certain that it would solve the problem I am having.

I was basing my solution on the PyTorch Lightning discussion you shared, but I cannot do it exactly as they do since, as it stands, I need to use the prediction loop in order to stitch the images together (and all the solutions I could see involve calling predict_step). This should not be an issue anymore after the refactoring. In any case, I will wait until it is all merged before I share the (fairly huge) code snippet needed to reproduce it.

The W&B Callback you shared does pretty much the same as what I am doing with mine, but I must use TensorBoard for my project, unfortunately. Do you have any plans to allow custom Callbacks to be passed to the CAREamist class? (As of now I am just redefining _define_callbacks, which is a bit ugly.)

@jdeschamps
Member

jdeschamps commented Jun 18, 2024

Hi @conradkun,

Passing custom callbacks to the CAREamist is a great idea! We will actually need this very soon for the napari plugin, so here ya go: #150

We have not worked much on the loggers so far, but we will try to have some support of both TensorBoard and WandB, with examples of useful cases. Feedback and suggestions always welcome!

EDIT: Fixing link

EDIT 2: Regarding the TensorBoard callback, we actually do have a way to set TensorBoard as the logger (via Configuration.training_config.logger). We have not really tested it, so if you notice something missing that is necessary to make it work, let us know!!
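For reference, selecting that logger might look like the fragment below; the field path follows the comment above, but the accepted string value (`"tensorboard"`) is an assumption to check against the Configuration docs:

```python
# hypothetical config fragment: field path follows the comment above,
# the exact accepted logger value is an assumption
config.training_config.logger = "tensorboard"
careamist = CAREamist(config)
```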

jdeschamps added a commit that referenced this issue Jun 20, 2024
### Description

Custom callbacks are a powerful way to customize the information
compiled during training, or integrate the training in another
application. For instance, it will be necessary for the `napari` plugin
to relay the progression of the training to the UI using callbacks. This
was also raised here: #148.

This PR adds a mechanism to pass callbacks upon `CAREamist`
instantiation.

- **What**: Add the possibility to pass custom callbacks to the
`CAREamist`.
- **Why**: Enable users to further customize training or include
CAREamics in another application.
- **How**: Added a `callbacks` parameter to the constructor of
`CAREamist`.

### Changes Made

- **Modified**: Added a `callbacks` parameter to the constructor of `CAREamist`,
with a corresponding test in `test_careamist`.


### Related Issues

Linked to a request in
#148.


### Additional Notes and Examples

Here is the example from the doc PR:

```python
from pytorch_lightning.callbacks import Callback


# define a custom callback
class MyPrintingCallback(Callback): 
    def __init__(self):
        super().__init__()

        self.has_started = False
        self.has_ended = False

    def on_train_start(self, trainer, pl_module):
        self.has_started = True

    def on_train_end(self, trainer, pl_module):
        self.has_ended = True


my_callback = MyPrintingCallback() 

careamist = CAREamist(config, callbacks=[my_callback]) 
```

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [x] PR to the documentation exists (for bug fixes / features)
    - CAREamics/careamics-examples#2
    - CAREamics/careamics.github.io#8
@melisande-c
Member

melisande-c commented Jun 20, 2024

Hi @conradkun,

#141 has now been merged into main! I hope this will help to solve your issue.

We no longer use a custom prediction loop; instead, the outputs of Lightning's Trainer.predict are converted at the end of CAREamist.predict with the new function careamics.prediction_utils.convert_output. convert_output needs to know whether the prediction outputs are tiled; this is determined in the initialisation of CAREamicsPredictData, whose creation logic has been moved to the function careamics.prediction_utils.create_pred_datamodule. You will see this implemented in the CAREamist.predict method.
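The decoupling described above can be illustrated with a toy stand-in: Trainer.predict returns a list of per-batch outputs, and a separate conversion step stitches them when they are tiles. `toy_convert_output` below is a hypothetical simplification of careamics.prediction_utils.convert_output, whose real signature and stitching logic differ:

```python
# Toy stand-in for the described convert_output step: the prediction loop
# only produces a flat list of batch outputs; stitching happens afterwards.

def toy_convert_output(predictions, tiled):
    """If the outputs are tiles, 'stitch' them into one result;
    otherwise pass them through unchanged. (Real stitching reassembles
    image tiles spatially; here we just concatenate lists.)"""
    if tiled:
        return [sum(predictions, [])]
    return predictions


# pretend Trainer.predict returned two tiles of a single image
batch_outputs = [[1, 2], [3, 4]]
print(toy_convert_output(batch_outputs, tiled=True))   # [[1, 2, 3, 4]]
print(toy_convert_output(batch_outputs, tiled=False))  # [[1, 2], [3, 4]]
```

The point is that nothing in the prediction loop itself needs to be customized anymore; any caller (including a callback) can post-process the raw outputs.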

Do not hesitate to get in contact with any questions!

@conradkun
Contributor Author

Perfect, thanks for the heads up! I will reimplement what I had using the new setup soon and let you know if I come across any issues. Just saw you also merged the custom callbacks branch, so my life is even simpler! Thank you for the quick work.

@conradkun
Contributor Author

Unfortunately, I was not able to make it work, so here I go.

What do I want to achieve?

Given a CAREamicsPredictData datamodule, I want to be able to predict on it at specific times: at the end of each training epoch, after every X training steps, etc. The best tool for this is PyTorch Lightning's Callback.

I want to use a datamodule that is detached from both the training and the validation datasets found in the CAREamicsTrainData for a number of reasons:

  • I want to predict on a full image in order to visualize it as a whole (as opposed to only looking at the patches that the validation set may contain).
  • I want to be able to manually decide how many images get predicted on, meaning a setup where I dump a bunch of pictures I want to keep track of into a folder and pass that path to CAREamicsPredictData would be ideal.
  • Similarly to the previous point, I want them to always be the same images for comparison among models, as opposed to randomly drawing from a larger validation set.

I think such functionality is really useful for a model like N2V, where metrics are not entirely representative of the perceived performance. Picture a UI where a user is training on an image and selects a patch to see how it develops as the training progresses; this real-time feedback would give the user an idea of what's going on and whether training longer is a good idea, etc.

What have I tried?

With the new Callback functionality, I can pass the following directly to the CAREamist class (assuming some config):

```python
from pytorch_lightning.callbacks import Callback

from careamics import CAREamist
from careamics.prediction_utils import convert_outputs, create_pred_datamodule


class CustomPredictAfterValidationCallback(Callback):
    def __init__(self, pred_datamodule):
        super().__init__()
        self.pred_datamodule = pred_datamodule

    def setup(self, trainer, pl_module, stage):
        if stage in ("fit", "validate"):
            # set up the predict data for fit/validate, as we will call it
            # during `on_validation_epoch_end`
            # not sure if needed, but doesn't hurt until I get it to work
            self.pred_datamodule.prepare_data()
            self.pred_datamodule.setup("predict")

    def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.sanity_checking:  # optional skip
            return

        predictions = trainer.predict(model=pl_module, datamodule=self.pred_datamodule)
        return convert_outputs(predictions, self.pred_datamodule.tiled)


pred_datamodule = create_pred_datamodule(
    source="image.tiff",
    config=config,
)

predict_after_val_callback = CustomPredictAfterValidationCallback(
    pred_datamodule=pred_datamodule
)
engine = CAREamist(config, callbacks=[predict_after_val_callback])
```

For some reason, PyTorch Lightning really does not like this setup. In particular, I get the following error after calling fit:

```
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[32], line 1
----> 1 engine.train(datamodule=data_module)

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/careamics/careamist.py:322, in CAREamist.train(self, datamodule, train_source, val_source, train_target, val_target, use_in_memory, val_percentage, val_minimum_split)
    320 # train
    321 if datamodule is not None:
--> 322     self._train_on_datamodule(datamodule=datamodule)
    324 else:
    325     # raise error if target is provided to N2V
    326     if self.cfg.algorithm_config.algorithm == SupportedAlgorithm.N2V.value:

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/careamics/careamist.py:394, in CAREamist._train_on_datamodule(self, datamodule)
    391 # record datamodule
    392 self.train_datamodule = datamodule
--> 394 self.trainer.fit(self.model, datamodule=datamodule)

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    542 self.state.status = TrainerStatus.RUNNING
    543 self.training = True
--> 544 call._call_and_handle_interrupt(
    545     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    546 )

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42     if trainer.strategy.launcher is not None:
     43         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44     return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    573 assert self.state.fn is not None
    574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    575     self.state.fn,
    576     ckpt_path,
    577     model_provided=True,
    578     model_connected=self.lightning_module is not None,
    579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
    582 assert self.state.stopped
    583 self.training = False

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:987, in Trainer._run(self, model, ckpt_path)
    982 self._signal_connector.register_signal_handlers()
    984 # ----------------------------
    985 # RUN THE TRAINER
    986 # ----------------------------
--> 987 results = self._run_stage()
    989 # ----------------------------
    990 # POST-Training CLEAN UP
    991 # ----------------------------
    992 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:1033, in Trainer._run_stage(self)
   1031         self._run_sanity_check()
   1032     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1033         self.fit_loop.run()
   1034     return None
   1035 raise RuntimeError(f"Unexpected state {self.state}")

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:205, in _FitLoop.run(self)
    203 try:
    204     self.on_advance_start()
--> 205     self.advance()
    206     self.on_advance_end()
    207     self._restarting = False

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:363, in _FitLoop.advance(self)
    361 with self.trainer.profiler.profile("run_training_epoch"):
    362     assert self._data_fetcher is not None
--> 363     self.epoch_loop.run(self._data_fetcher)

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py:141, in _TrainingEpochLoop.run(self, data_fetcher)
    139 try:
    140     self.advance(data_fetcher)
--> 141     self.on_advance_end(data_fetcher)
    142     self._restarting = False
    143 except StopIteration:

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py:295, in _TrainingEpochLoop.on_advance_end(self, data_fetcher)
    291 if not self._should_accumulate():
    292     # clear gradients to not leave any unused memory during validation
    293     call._call_lightning_module_hook(self.trainer, "on_validation_model_zero_grad")
--> 295 self.val_loop.run()
    296 self.trainer.training = True
    297 self.trainer._logger_connector._first_loop_iter = first_loop_iter

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/utilities.py:182, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
    180     context_manager = torch.no_grad
    181 with context_manager():
--> 182     return loop_run(self, *args, **kwargs)

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py:142, in _EvaluationLoop.run(self)
    140         self._restarting = False
    141 self._store_dataloader_outputs()
--> 142 return self.on_run_end()

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py:254, in _EvaluationLoop.on_run_end(self)
    251 self.trainer._logger_connector._evaluation_epoch_end()
    253 # hook
--> 254 self._on_evaluation_epoch_end()
    256 logged_outputs, self._logged_outputs = self._logged_outputs, []  # free memory
    257 # include any logged outputs on epoch_end

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py:336, in _EvaluationLoop._on_evaluation_epoch_end(self)
    333 call._call_callback_hooks(trainer, hook_name)
    334 call._call_lightning_module_hook(trainer, hook_name)
--> 336 trainer._logger_connector.on_epoch_end()

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:195, in _LoggerConnector.on_epoch_end(self)
    193 def on_epoch_end(self) -> None:
    194     assert self._first_loop_iter is None
--> 195     metrics = self.metrics
    196     self._progress_bar_metrics.update(metrics["pbar"])
    197     self._callback_metrics.update(metrics["callback"])

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:233, in _LoggerConnector.metrics(self)
    231 """This function returns either batch or epoch metrics."""
    232 on_step = self._first_loop_iter is not None
--> 233 assert self.trainer._results is not None
```
    [234](https://file+.vscode-resource.vscode-cdn.net/Users/c.cardona/Documents/Projects/DL/noise2void/CAREamics/notebooks/~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:234) return self.trainer._results.metrics(on_step)

AssertionError:

I found a post with a similar error from April, which suggested using predict_step instead of calling predict directly on the Lightning Module. Incidentally, that is also what the aforementioned discussion from the Lightning forums seems to converge towards. So I tried version 2:

from pytorch_lightning import Callback

from careamics.prediction_utils import convert_outputs


class CustomPredictAfterValidationCallback(Callback):
    def __init__(self, pred_datamodule):
        self.pred_datamodule = pred_datamodule

    def setup(self, trainer, pl_module, stage):
        if stage in ("fit", "validate"):
            # setup the predict data for fit/validate, as we will call it during `on_validation_epoch_end`
            # not sure if needed, but doesn't hurt until I get it to work
            self.pred_datamodule.prepare_data()
            self.pred_datamodule.setup("predict")

    def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.sanity_checking:  # optional skip
            return

        # not entirely sure how preds are returned (and how they must be concatenated), take as pseudocode
        predictions = []
        for idx, batch in enumerate(self.pred_datamodule.predict_dataloader()):
            preds = pl_module.predict_step(batch, idx)  # breaks here
            predictions += preds
        return convert_outputs(predictions, self.pred_datamodule.tiled)

The problem with this approach is how the predict_step function in CAREamicsModule has changed:

if self._trainer.datamodule.tiled:
    x, *aux = batch
else:
    x = batch
    aux = []
Referencing _trainer directly means that it will be looking at the CAREamicsTrainData used for training, and not at the CAREamicsPredictData the current batch comes from. And training modules do not have a .tiled attribute.

What to do?

Somehow, predict_step should be getting the tiled information from the datamodule that is yielding the current batches. Any ideas?
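One possible direction, sketched here with hypothetical stand-in classes rather than CAREamics' actual datamodules: have `predict_step` fall back to a safe default when the `tiled` attribute is absent (training batches are never tiled), which amounts to a `getattr`/`hasattr` discrimination between the two datamodule types:

```python
def is_tiled(datamodule) -> bool:
    """Whether the datamodule yields tiled batches.

    Only a prediction datamodule defines `tiled`; a training datamodule
    does not, and its batches are never tiled, so default to False.
    """
    return bool(getattr(datamodule, "tiled", False))


class FakeTrainDataModule:  # stand-in: no `tiled` attribute
    pass


class FakePredictDataModule:  # stand-in for a tiled prediction datamodule
    tiled = True


print(is_tiled(FakeTrainDataModule()))    # False
print(is_tiled(FakePredictDataModule()))  # True
```

This is only a sketch of the idea; as discussed later in the thread, the fix that eventually landed uses a similar attribute check inside `predict_step`.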

@jdeschamps
Copy link
Member

jdeschamps commented Jun 27, 2024

Thanks for sharing your code!

We've investigated a little and identified what is basically preventing this approach. We will have a deeper look in the next weeks to see if some refactoring would make the current code base compatible with a prediction callback!

@conradkun
Copy link
Contributor Author

That's great to hear, thank you very much!

@conradkun
Copy link
Contributor Author

Hi both,

I know there has been some work done towards this feature, is it fully working yet? Let me know if I could help implementing it, otherwise :)

@jdeschamps
Copy link
Member

Hi Conrad! We are hosting I2K so we are a bit overwhelmed at the moment. Let us come back to you beginning of November!

@jdeschamps
Copy link
Member

jdeschamps commented Nov 11, 2024

Hi @conradkun,

Sorry for the delay, life's always more busy than expected.

I had a go at it and came up with a hacky way to make it fit together (https://github.com/CAREamics/careamics/actions/runs/11783874603?pr=266); in essence, it looks like this:

import numpy as np
from pytorch_lightning import Callback, Trainer

from careamics import CAREamist, Configuration
from careamics.lightning import PredictDataModule, create_predict_datamodule
from careamics.prediction_utils import convert_outputs

config = Configuration(**minimum_configuration)

class CustomPredictAfterValidationCallback(Callback):
    def __init__(self, pred_datamodule: PredictDataModule):
        self.pred_datamodule = pred_datamodule

        # prepare data and setup
        self.pred_datamodule.prepare_data()
        self.pred_datamodule.setup()
        self.pred_dataloader = pred_datamodule.predict_dataloader()
        self.data = None

    def on_validation_epoch_end(self, trainer: Trainer, pl_module):
        if trainer.sanity_checking:  # optional skip
            return

        # update statistics in the prediction dataset for coherence
        # (they can be computed online by the training dataset)
        self.pred_datamodule.predict_dataset.image_means = (
            trainer.datamodule.train_dataset.image_stats.means
        )
        self.pred_datamodule.predict_dataset.image_stds = (
            trainer.datamodule.train_dataset.image_stats.stds
        )

        # predict on the dataset
        outputs = []
        for idx, batch in enumerate(self.pred_dataloader):
            batch = pl_module.transfer_batch_to_device(batch, pl_module.device, 0)
            outputs.append(pl_module.predict_step(batch, batch_idx=idx))

        self.data = convert_outputs(outputs, self.pred_datamodule.tiled)
        # save data here


array = np.arange(32 * 32).reshape((32, 32))
pred_datamodule = create_predict_datamodule(
    pred_data=array,
    data_type=config.data_config.data_type,
    axes=config.data_config.axes,
    image_means=[11.8], # random placeholder
    image_stds=[3.14],
    # can choose tiling here
)

predict_after_val_callback = CustomPredictAfterValidationCallback(
    pred_datamodule=pred_datamodule
)
engine = CAREamist(config, callbacks=[predict_after_val_callback])
engine.train(train_source=array)

This raised a few issues pertaining to the logic of CAREamics that I detail in the PR. We are undergoing a dataset refactoring step, and we will consider the aforementioned points during the refactoring.

Regarding the PR, I need to test it with a real example before merging it (providing that it passes review), as I only made it run through a simple test! You are welcome to try it and comment in the PR, hoping that I am not making you try a faulty example.

We would also be super interested in knowing how and why you are using CAREamics! And if you ever want to use it in production, we would be happy to potentially integrate specific integration tests to make sure the API is stable for your use cases.

edit: Add comments to the code

@jdeschamps
Copy link
Member

Hi @conradkun !

I tested it on the SEM notebook, with a slightly modified version of the example code in my previous message (actually saving the images, and using tiling), and it seemed to have worked well!

It adds some complexity to our codebase, but only in a part we were not happy with anyway so at least we will have it in mind in the next refactoring and we will maintain this feature. We will merge the PR in a few moments.

I leave this issue open, so you can let us know whether it worked for your application!

@conradkun
Copy link
Contributor Author

Hi @jdeschamps !

Indeed life gets busy for me too, sorry for the late reply. This is great news, though! I will test it later this week, but it's a good sign that you got it to work on a test dataset yourself.

About why I am using CAREamics, I'm not sure about how much of it I am allowed to share publicly, but I would love to talk about it with you. Let me know if there is another channel where we could talk further about this.

@jdeschamps
Copy link
Member

About why I am using CAREamics, I'm not sure about how much of it I am allowed to share publicly, but I would love to talk about it with you. Let me know if there is another channel where we could talk further about this.

I thought so. 😄

If you are willing to discuss it on private channels, drop me an email and we could have a Zoom call!

@jdeschamps
Copy link
Member

jdeschamps commented Dec 6, 2024

I consider this issue closed! Don't hesitate to open a new one if you run into any problems!

federico-carrara added a commit to federico-carrara/careamics that referenced this issue Dec 10, 2024
#4)

* Update LVAE DataType list (CAREamics#251)

### Description

Update the list of possible data types for LVAE datasets according to
planned reproducibility experiments

* 3D model (CAREamics#240)

### Description

Adding 3D/2.5D to LVAE model

- **What**: Merge current implementation in the original repo with
refactoring done previously.

### Changes Made
- Relevant pydantic config
- Added encoder_conv_strides and decoder_conv_strides params. They
control the strides in conv layers, and can have len 2 or 3. They're
meant to make a choice between 2D and 3D convs and control the shape of
encoder/decoder (subject to change)
- Input shape is now a tuple containing shapes of 2/3 dimensions
- Stochastic layer are in the separate module (subject to change)
- NonStochastic is removed alongside with relevant parameters
- Some docs update
- Basic tests 

### Notes/Problems
- Docstrings are a mess, should be fixed in a separate PR later
- Lots of mypy etc. issues
- Some tests don't pass because we need to clarify multiscale count
param


---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [ ] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: melisande-c <milly.croft@gmail.com>
Co-authored-by: federico-carrara <federico1.carrara@mail.polimi.it>
Co-authored-by: federico-carrara <federico.carrara@fht.org>

* ci(pre-commit.ci): autoupdate (CAREamics#250)

<!--pre-commit.ci start-->
updates:
- [github.com/abravalheri/validate-pyproject: v0.19 →
v0.20.2](abravalheri/validate-pyproject@v0.19...v0.20.2)
- [github.com/astral-sh/ruff-pre-commit: v0.6.3 →
v0.6.9](astral-sh/ruff-pre-commit@v0.6.3...v0.6.9)
<!--pre-commit.ci end-->

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>

* Add image.sc badge

* Update README.md

* Feature: Predict to disk (outerloop implementation) (CAREamics#253)

### Description

- **What**: Add a `predict_to_disk` function to the `CAREamist` class.
- **Why**: So users can save predictions without having to write the
saving process themselves.
- **How**: This implementation loops through the files, predicts the
result and saves each one in turn. N.b. this will eventually be replaced
with the `PredictionWriterCallback` version.

### Changes Made

- **Added**: 
  - `predict_to_disk` function to `CAREamist` class.
  - tests
- **Modified**:
  - Some type hints.

### Additional Notes and Examples

Currently the results are saved to a directory called "predictions" but
maybe it should be configurable by the user?

This function can only be called on source data from a path because the
prediction files are saved with the same name as the source files.

Not the neatest as I expect this to be replaced soon.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

* refac: modularized + cleaned the code for LVAE losses (CAREamics#255)

### Description


- **What**: Modularized loss functions for LVAE based models. Now it is
possible to create custom losses for different algorithms. In addition,
superfluous code has been removed.
- **Why**: it allows to implement new losses for new algorithm in a more
clean and modular way using the existing building blocks. Moreover, the
code is now a bit simpler to read.
- **How**: Created general function to aggregate KL loss according to
different approaches. Simplified the way reconstruction loss is
computed. Changed all the tests accordingly.

NOTE: the API to call loss functions for existing algorithms (e.g.,
muSplit and denoiSplit) has been kept untouched.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* refac: reorganization of pydantic configs in LVAE model (CAREamics#256)

### Description

- **What**: Following the polishing of LVAE losses, we reorganized also
the pydantic models that handle losses and likelihoods for better
readability, usability, and overall organization.
- **Why**: Make clean more readable, usable, and organized for further
developments.
- **How**: Implemented `LVAELossConfig` to replace `LVAELossParameters`.
Removed unnecessary attributes from other pydantic models.

### Breaking changes

Refactoring of pydantic models will certainly break the examples in
`microSplit_reproducibility` repo.

### Note

This PR is based on CAREamics#255.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Fix(BMZ export): torchvision version; model validation  (CAREamics#257)

### Description

- **What**: Two fixes to the bmz export: The torchvision version in the
environment.yaml file was incorrectly set to be the torch version; and
bmz `ModelDescription` configuration had the incorrect parameter
`decimals` (it was meant to be `decimal` without the 's').
- **Why**: The incorrect env file meant the CI environment could not be
created, and the incorrect parameter meant the model validation could
not be run in the CI.
- **How**: Set the correct torchvision version in the env file. Removed
the configuration from the `ModelDesc`, following
bioimage-io/core-bioimage-io-python#418 the
decimal parameter is deprecated.

### Changes Made

- **Modified**: 
- func `create_env_text` in
`src/careamics/model_io/bioimage/bioimage_utils.py`
  - src/careamics/model_io/bmz_io.py
- func `create_model_description` in
`src/careamics/model_io/bioimage/model_description.py`

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

* Fix: Enforce dataloader params to have shuffle=True (CAREamics#259)

### Description

- **What**: It seems that in the `CAREamics` `TrainDataModule` the
dataloader does not have shuffle set to `True`.
- **Why**: Not shuffling the data during training can result in worse
training, e.g. overfitting.
- **How**: Allow users to explicitly pass `shuffle=False` with a warning;
otherwise `{"shuffle": True}` is added to the param dictionary, if the
dataset is not a subclass of `IterableDataset`.
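The enforcement described above can be sketched roughly as follows (a stdlib-only illustration with a hypothetical helper name; the actual logic lives in `TrainDataModule.train_dataloader`, which passes the params on to a `DataLoader`):

```python
import warnings


def resolve_dataloader_params(params: dict, dataset_is_iterable: bool) -> dict:
    """Enforce shuffle=True for map-style training datasets."""
    params = dict(params)  # do not mutate the caller's dict
    if dataset_is_iterable:
        # IterableDataset does not support shuffling via the DataLoader
        params.pop("shuffle", None)
    elif params.get("shuffle") is False:
        # user explicitly opted out: allow it, but warn
        warnings.warn("Shuffling is disabled; training may overfit.")
    else:
        params["shuffle"] = True
    return params


print(resolve_dataloader_params({"batch_size": 8}, dataset_is_iterable=False))
# {'batch_size': 8, 'shuffle': True}
```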

### Changes Made

- **Modified**: `TrainDataModule.train_dataloader`

### Additional Notes and Examples

See the discussion in CAREamics#258 for details.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>

* Performance test induced fixes (CAREamics#260)

Different changes happened during performance testing

### Changes Made

Pydantic configs
Losses
NM/Likelihood refac(from CAREamics#256 )
Tests
TODOs for later refactoring


---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* ci(pre-commit.ci): autoupdate (CAREamics#262)

<!--pre-commit.ci start-->
updates:
- [github.com/abravalheri/validate-pyproject: v0.20.2 →
v0.22](abravalheri/validate-pyproject@v0.20.2...v0.22)
- [github.com/astral-sh/ruff-pre-commit: v0.6.9 →
v0.7.2](astral-sh/ruff-pre-commit@v0.6.9...v0.7.2)
- [github.com/psf/black: 24.8.0 →
24.10.0](psf/black@24.8.0...24.10.0)
- [github.com/pre-commit/mirrors-mypy: v1.11.2 →
v1.13.0](pre-commit/mirrors-mypy@v1.11.2...v1.13.0)
- [github.com/kynan/nbstripout: 0.7.1 →
0.8.0](kynan/nbstripout@0.7.1...0.8.0)
<!--pre-commit.ci end-->

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix: Allow singleton channel in convenience functions (CAREamics#265)

### Description

Following CAREamics#159, this PR
allows creating a configuration with a singleton channel via the
configuration convenience functions.

- **What**: Allow singleton channel in convenience functions.
- **Why**: In some rare cases, a singleton channel might be present in
the data.
- **How**: Change the `if` statements and error raising conditions of
the convenience functions.

### Changes Made


- **Modified**: `configuration_factory.py`.

### Related Issues

- Fixes [Allow singleton channel in convenience
functions](CAREamics#159)

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [x] PR to the documentation exists (for bug fixes / features)

* feat: Add convenience function to read loss from CSVLogger (CAREamics#267)

### Description

Following CAREamics#252, this PR
enforces that the `CSVLogger` is always used (even if `WandB` is
requested for instance), and add an API entry point in `CAREamist` to
return a dictionary of the losses.

This allows users to simply plot the loss in a notebook after training
for instance. While they will be better off using `WandB` or
`TensorBoard`, this is enough for most users.

- **What**: Enforce `CSVLogger` and add functions to read the losses
from `metrics.csv`.
- **Why**: So that users have an easy way to plot the loss curves.
- **How**: Add a new `lightning_utils.py` file with the read csv
function, and call this method from `CAREamist`.

### Changes Made

- **Added**: `lightning_utils.py`.
- **Modified**: `CAREamist`.

### Related Issues


- Resolves CAREamics#252

### Additional Notes and Examples

An alternative path would have been to add a `Callback` and do the
logging ourselves. I decided for the solution that uses the `csv` file
that is anyway created by default (when there is no WandB or TB
loggers), to minimize the code that needs to be maintained.

One potential issue is the particular csv file read is chosen following
the experiment name recorded by `CAREamist` and the last `version_*`.
This may not be true if the paths have changed, but in most cases it
should be valid if called right after training.

Here is what it looks like in the notebooks:
``` python
import matplotlib.pyplot as plt

losses = careamist.get_losses()

plt.plot(losses["train_epoch"], losses["train_loss"], label="Train Loss")
plt.plot(losses["val_epoch"], losses["val_loss"], label="Val Loss")
``` 
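For reference, parsing such a `metrics.csv` could be sketched like this (a stdlib-only illustration assuming columns named `epoch`, `train_loss`, and `val_loss`; `careamist.get_losses()` is the actual entry point, and Lightning's real column names may differ):

```python
import csv
from collections import defaultdict


def read_losses(csv_lines) -> dict:
    """Collect per-epoch train/val losses from Lightning metrics.csv rows.

    `csv_lines` is any iterable of CSV text lines (e.g. an open file).
    """
    losses = defaultdict(list)
    for row in csv.DictReader(csv_lines):
        # each row logs either a train or a val entry; empty cells are skipped
        if row.get("train_loss"):
            losses["train_epoch"].append(int(row["epoch"]))
            losses["train_loss"].append(float(row["train_loss"]))
        if row.get("val_loss"):
            losses["val_epoch"].append(int(row["epoch"]))
            losses["val_loss"].append(float(row["val_loss"]))
    return dict(losses)


sample = [
    "epoch,train_loss,val_loss",
    "0,1.5,",
    "0,,1.2",
    "1,1.0,",
    "1,,0.9",
]
print(read_losses(sample)["train_loss"])  # [1.5, 1.0]
```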

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

* Refac: Rename config file to careamics.yaml in BMZ export (CAREamics#271)

### Description

- **What**: When exporting to bmz the config file is now called
`careamics.yaml`. Searching for the config file during loading has also
been made more restrictive: previously the function searched for any
`.yml` file in the attachments and now it searches specifically for
`careamics.yaml`.
- **Why**: Renaming the file makes it clearer to users it relates to
CAREamics' functionality and should prevent any future name clashes with
other tools. The config file loading was made more restrictive because
it was not very robust to possible cases where additional attachments
are used, and using the `export_to_bmz` function doesn't allow any
choice in the name of the config file.
- **How**: Modified config path in `export_to_bmz` and modified config
path search in `extract_model_path`.

### Changes Made

- **Modified**: 
  - `export_to_bmz`
  - `extract_model_path`
  - `save_configuration` docs

### Related Issues

- Resolves CAREamics#269 

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

* Feature: Load model from BMZ using URL (CAREamics#273)

### Description

- **What**: Now it is possible to pass a URL to the `load_from_bmz`
function to download and load BMZ files.
- **Why**: Not many users will have access to the model resource URLs,
but this functionality is useful for developing the CAREamics BMZ
compatibility script.
- **How**: 
- Type hint `path` as also `pydantic.HttpUrl` in `load_from_bmz` (as in
`bioimage.core`);
- Remove `path` validation checks from `load_from_bmz` and allow it to
be handled in `load_model_description`
- Call `download` on the file resources to download and get the correct
path.

### Changes Made

- **Modified**: 
  - `load_from_bmz`
  - `extract_model_path`

### Additional Notes and Examples

This will have merge conflicts with CAREamics#271.

There are currently no official tests (it does work), we can discuss
using the URL of one of the existing CAREamics models uploaded to the
BMZ or create a Mock.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

* feat: Enable prediction step during training (CAREamics#266)

### Description

Following CAREamics#148, I have been
exploring how to predict during training. This PR would allow adding
`Callback` that use `predict_step` during Training.

- **What**: Allow callbacks to call `predict_step` during training.
- **Why**: Some applications might require predicting consistently on
full images to assess training performances throughout training.
- **How**: Modified `FCNModule.predict_step` to make it compatible with
a `TrainDataModule` (all calls to `trainer.datamodule` were written with
the expectation that it returns a `PredictDataModule`).

### Changes Made

- **Modified**: `lightning_module.py`, `test_lightning_module.py`

### Related Issues

- Resolves CAREamics#148

### Additional Notes and Examples

```python
    import numpy as np
    from pytorch_lightning import Callback, Trainer

    from careamics import CAREamist, Configuration
    from careamics.lightning import PredictDataModule, create_predict_datamodule
    from careamics.prediction_utils import convert_outputs

    config = Configuration(**minimum_configuration)

    class CustomPredictAfterValidationCallback(Callback):
        def __init__(self, pred_datamodule: PredictDataModule):
            self.pred_datamodule = pred_datamodule

            # prepare data and setup
            self.pred_datamodule.prepare_data()
            self.pred_datamodule.setup()
            self.pred_dataloader = pred_datamodule.predict_dataloader()

        def on_validation_epoch_end(self, trainer: Trainer, pl_module):
            if trainer.sanity_checking:  # optional skip
                return

            # update statistics in the prediction dataset for coherence
            # (they can be computed online by the training dataset)
            self.pred_datamodule.predict_dataset.image_means = (
                trainer.datamodule.train_dataset.image_stats.means
            )
            self.pred_datamodule.predict_dataset.image_stds = (
                trainer.datamodule.train_dataset.image_stats.stds
            )

            # predict on the dataset
            outputs = []
            for idx, batch in enumerate(self.pred_dataloader):
                batch = pl_module.transfer_batch_to_device(batch, pl_module.device, 0)
                outputs.append(pl_module.predict_step(batch, batch_idx=idx))

            data = convert_outputs(outputs, self.pred_datamodule.tiled)
            # can save data here

    array = np.arange(32 * 32).reshape((32, 32))
    pred_datamodule = create_predict_datamodule(
        pred_data=array,
        data_type=config.data_config.data_type,
        axes=config.data_config.axes,
        image_means=[11.8], # random placeholder
        image_stds=[3.14],
        # can do tiling here
    )

    predict_after_val_callback = CustomPredictAfterValidationCallback(
        pred_datamodule=pred_datamodule
    )
    engine = CAREamist(config, callbacks=[predict_after_val_callback])
    engine.train(train_source=array)
```



Currently, this implementation is not fully satisfactory, and here are a
few important points:

- For this PR to work we need to discriminate between `TrainDataModule`
and `PredictDataModule` in `predict_step`, which is a bit of a hack as
it currently checks `hasattr(..., "tiled")`. The reason is to avoid a
circular import of `PredictDataModule`. We should revisit that.
- `TrainDataModule` and `PredictDataModule` have incompatible members:
`PredictDataModule` has `.tiled`, and the two have different naming
conventions for the statistics (`PredictDataModule` has `image_means`
and `image_stds`, while `TrainDataModule` has them wrapped in a `stats`
dataclass). These statistics are retrieved either through
`_trainer.datamodule.predict_dataset` or
`_trainer.datamodule.train_dataset`.
- We do not provide the `Callback` that would allow using such a feature.
We might want to do some of the heavy lifting here as well (see example).
- Probably the most serious issue, normalization is done in the datasets
but denormalization is performed in the `predict_step`. In our case,
that means that normalization could be applied by a `PredictDataModule`
(in the `Callback`) and the denormalization by the `TrainDataModule` (in
`predict_step`). That is incoherent and due to the way we wrote
CAREamics.

All in all, this draft exemplifies two problems with CAREamics:
- `TrainDataModule` and `PredictDataModule` have different members
- Normalization is done by the `DataModule` but denormalization by
`LightningModule`

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [x] PR to the documentation exists (for bug fixes / features)

* feat: added functions to load neuron&astrocytes dataset

* refac: added possibility to pick `kl_restricted` loss & other loss-related refactoring (CAREamics#272)

### Description

In some LVAE training examples (see `microSplit_reproducibility` repo)
there is the need to consider the *restricted KL loss*, instead of the
simple sample-wise one. This PR allows the user to pick that one.

- **What**: The KL loss type is no longer hardcoded in the loss
functions. Now it is possible to pick also the `restricted_kl` KL loss
type.
- **Why**: It is needed for some experiments/examples.
- **How**:  added an input parameter to the KL loss functions.

### Breaking changes

The `kl_type` parameter is added in the loss functions, so we need to be
careful of correctly specifying it in the examples in the
`microSplit_reproducibility` repo.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

* tmp: function to load 3D data, extract 2D slices and save them separately

* feat: updates to handle 2D slices

* refac: removed main + improved output of loading function

* A new enum for a new splitting task.  (CAREamics#270)

### Description
Adding a new enum type for a splitting task which I had missed
communicating earlier.

I am putting the relevant things in microsplit-reproducibility repo
after a brief chat with @veegalinova.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>

* feat: loading supported also for unmixed images

* refac: made `get_fnames` public + changed way GroupType is obtained

* fix: bug

* updated content of examples folder

* Fix(BMZ): Relax model output validation kwargs; extract weights and config file following new `spec` and `core` release (CAREamics#279)

### Description

- **What**: Relaxing the model output validation kwargs, both absolute
and relative tolerance, from the default, `1e-4`, to `1e-2`.
- **Why**: The defaults are quite strict, and some of our uploaded
models are stuck in pending because their outputs differ slightly from
the expected outputs.
- e.g. (actually, the absolute tolerance may need to be set to 0,
otherwise this case still might not pass after this PR):
  ```console
  Output and expected output disagree:
    Not equal to tolerance rtol=0.0001, atol=0.00015
  Mismatched elements: 40202 / 1048576 (3.83%)
  Max absolute difference: 0.1965332
  Max relative difference: 0.0003221
  ```
- **How**: In the model description config param, added the new test
kwargs. Additionally, updated `bmz_export` so that the test kwargs in
the model description are used during model testing at export time.
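The effect of relaxing the tolerances can be illustrated with `numpy.allclose`, which applies the same `atol + rtol * |expected|` criterion (the numbers below are made up for illustration, not taken from the failing model):

```python
import numpy as np

output = np.array([1.0, 0.5, 2.0])
expected = output + np.array([1e-3, 0.0, -1e-3])  # tiny numerical mismatch

# the old, strict tolerances reject the outputs...
assert not np.allclose(output, expected, rtol=1e-4, atol=1e-4)
# ...while the relaxed tolerances accept them
assert np.allclose(output, expected, rtol=1e-2, atol=1e-2)
```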

### Changes Made

- **Modified**: Describe existing features or files modified.
  - `create_model_description`: added test_kwargs to config param
- `export_to_bmz`: use test_kwargs in model description for model
testing at export time.

### Related Issues

- Resolves 
  - last checkbox in CAREamics#278 

EDIT:
This PR also fixes loading from BMZ following an incompatible release of
`bioimageio/core` (`0.7.0`) and `bioimageio/spec` (`0.5.3.5`). The
problem was `load_model_description` no longer unzips the archive file
but only streams the `rdf.yaml` file data. This means we have to now
extract the weights and careamics config from the zip to load them,
which can be done using
`bioimageio.spec._internal.io.resolve_and_extract`

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>

* refac: renamed and moved file to read CZI

* style: added progress bar to data loading function

* fix: fixed FP info fetching for dyes

* updated training for astro_neuron dset

* added script to train λSplit on astrocytes data

* updated training examples

* Fix(dependencies): Set bioimageio-core version greater than 0.7.0 (CAREamics#280)

### Description

- **What**: Set bioimageio-core version greater than 0.7.0
- **Why**: Following the new `bioimageio-core` release (0.7.0), we needed
to make some fixes (part of PR CAREamics#279). The most convenient function to
solve this problem, `resolve_and_extract`, only exists since 0.7.0.
- **How**: In pyproject.toml
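In `pyproject.toml` the pin would look roughly like this (the exact dependency specifier is an assumption; check the actual file):

```toml
[project]
dependencies = [
    "bioimageio.core>=0.7.0",  # resolve_and_extract exists since 0.7.0
]
```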

### Changes Made

- **Modified**: pyproject.toml

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

* added lambda parameters to saved configs

* added todos

* wip: eval examples for astro neuron dset

* fix: fixed a bug in KL loss aggregation (LVAE) (CAREamics#277)

### Description

Found a bug in the KL loss aggregation happening in the `LadderVAE`
model `training_step()`.
Specifically, the application of free-bits (`free_bits_kl()`, which
basically clamps the values of KL entries to a certain lower threshold)
was happening after the KL entries were rescaled. As a result, when the
free-bits threshold was set to 1, all the KL entries were clamped to 1,
since the rescaled values are normally much smaller than that.

- **What**: See above.
- **Why**: Clear bug in the code.
- **How**: Inverted the order of calls in the `get_kl_divergence_loss()`
function & adjusted some parts of the code to reflect the changes.
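The order-of-operations bug can be reproduced with a toy example (names and numbers are illustrative, not the actual `LadderVAE` code):

```python
import torch

def free_bits_kl(kl: torch.Tensor, free_bits: float) -> torch.Tensor:
    # clamp each KL entry to a lower threshold ("free bits")
    return torch.clamp(kl, min=free_bits)

kl_entries = torch.tensor([5.0, 8.0, 3.0])  # raw per-layer KL values
rescale = 100.0                             # e.g. dividing by layer size

# buggy order: rescale first, then clamp -> every entry collapses to 1.0,
# because the rescaled values all fall far below the threshold
buggy = free_bits_kl(kl_entries / rescale, free_bits=1.0)

# fixed order: clamp the raw entries first, then rescale
fixed = free_bits_kl(kl_entries, free_bits=1.0) / rescale
```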
---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>

* feat: added command to sort FPs by wavelength at peak intensity before creating FP ref matrix

* feat: function to split fnames in train and test

* fix: ONNX exportability compatibility test and fix (CAREamics#275)

fix: Fix for ONNX export of Maxblurpool layer and performance
optimization by registering kernel as a buffer so that it doesn't need
to be copied to the GPU over and over again.
### Description
- **What**: Converting the pretrained models to ONNX format gives an
error in the Maxpool layer used in the N2V2 architecture. This is mainly
because, in the Maxblurpool layer, the convolution kernel is dynamically
expanded to a size matching the number of channels in the input, but the
number of channels should be constant within the model.
- **Why**: Users can convert the pytorch models to ONNX for inference on
their platforms.
- **How**:
  - instead of using the symbolic variable x.size(1), explicitly cast it
to an integer and make it a constant.
  - register the kernel as a buffer to avoid the copy-to-GPU overhead.
  - add tests for ONNX exportability
### Changes Made

- **Added**:
  - onnx as a test dependency in pyproject.toml
  - `test_lightning_module_onnx_exportability.py`
- **Modified**: Maxblurpool module in `layers.py`
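A rough sketch of the two fixes (a simplified stand-in, not the actual `Maxblurpool` module in `layers.py`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaxBlurPool2d(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        blur = torch.tensor(
            [[1.0, 2.0, 1.0], [2.0, 4.0, 2.0], [1.0, 2.0, 1.0]]
        ) / 16.0
        # fix 1: the kernel is a registered buffer, so it moves to the GPU
        # once with the module instead of being rebuilt on every forward
        self.register_buffer("kernel", blur[None, None].repeat(channels, 1, 1, 1))
        # fix 2: the channel count is a plain Python int, so ONNX tracing
        # sees a constant rather than the symbolic x.size(1)
        self.channels = int(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.max_pool2d(x, kernel_size=2, stride=1)
        return F.conv2d(x, self.kernel, stride=2, padding=1, groups=self.channels)
```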

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>

* refac: changed the way train and test samples are taken

* feat: CL arg to avoid logging (for debugging runs)

* rm: removed outdated modules

* fix: set `None` as default for `custom_logger` when logging is disabled

* fix: fixed bugs resulting from previous merge

* fix: fixed more bugs related to previous merge + renamed `algorithm_type` into the more coherent `training_mode`

* fix: made a few changes to mirror updates in the model code

* refac: modified function to read CZI to make it compatible with CAREamics dsets

* refac: updated loading pipeline for 3D images (added padding to have same Z-dim)
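The Z-padding mentioned above could look like this (a hypothetical helper, not the actual loading code):

```python
import numpy as np

def pad_to_depth(stack: np.ndarray, target_z: int) -> np.ndarray:
    """Zero-pad a (Z, Y, X) stack at the bottom so all stacks share target_z."""
    missing = target_z - stack.shape[0]
    if missing < 0:
        raise ValueError("stack is deeper than target_z")
    return np.pad(stack, ((0, missing), (0, 0), (0, 0)))
```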

* fix: changed some parts of the code to allow 3D training

* updated training script for 3D case

* fix: deleting pre-loaded arrays of data after init dsets

* feat: added fn to sort fnames by exp ID and img ID

* refac: adjusted training script for 2D and loading all data in memory

* refac: added function to load 3D imgs callable in InMemoryDataset class

* feat: allowing passing kwargs to `read_source_fn`

* example: implemented more efficient training pipeline for 3D data

* style: cleaned outputs

* refac: renamed func `get_train_test_fnames` into `split_train_test_fnames`

* fix: solving issue of read_source_kwargs==None in patching

* fix: removed `dataloader_params` from serializable attributes in `DataConfig`

* refac: added dataloader_params to `InferenceConfig` to match organization of `DataConfig`

* rm: removed visualization funcs for λSplit, will be moved to a dedicated experiment repo

* rm: example scripts and notebooks

---------

Co-authored-by: Vera Galinova <32124316+veegalinova@users.noreply.github.com>
Co-authored-by: Igor Zubarev <zubarev.ia@gmail.com>
Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: melisande-c <milly.croft@gmail.com>
Co-authored-by: federico-carrara <federico.carrara@fht.org>
Co-authored-by: Melisande Croft <63270704+melisande-c@users.noreply.github.com>
Co-authored-by: ashesh <ashesh276@gmail.com>
Co-authored-by: nimiiit <nimiviswants@gmail.com>