Skip to content

Commit

Permalink
Noise models for paper notebooks (#285)
Browse files Browse the repository at this point in the history
### Description

Small refac for submission notebooks. Training code works if called
explicitly. Training inside careamics needs further engineering.
NM code needs refactoring

### Changes Made

- **Added**: 
- Basic training
- plotting function


**Please ensure your PR meets the following requirements:**

- [x ] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [x ] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: federico-carrara <federico1.carrara@mail.polimi.it>
Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>
  • Loading branch information
4 people authored Dec 3, 2024
1 parent 19d9203 commit b4fa28f
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 141 deletions.
48 changes: 24 additions & 24 deletions src/careamics/config/nm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
Field,
PlainSerializer,
PlainValidator,
model_validator,
)
from typing_extensions import Annotated, Self
from typing_extensions import Annotated

from careamics.utils.serializers import _array_to_json, _to_numpy

Expand Down Expand Up @@ -90,28 +89,29 @@ class GaussianMixtureNMConfig(BaseModel):
tol: float = Field(default=1e-10)
"""Tolerance used in the computation of the noise model likelihood."""

@model_validator(mode="after")
def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
"""Validate paths provided in the config.
Returns
-------
Self
Returns itself.
"""
if self.path and (self.signal is not None or self.observation is not None):
raise ValueError(
"Either only 'path' to pre-trained noise model should be"
"provided or only signal and observation in form of paths"
"or numpy arrays."
)
if not self.path and (self.signal is None or self.observation is None):
raise ValueError(
"Either only 'path' to pre-trained noise model should be"
"provided or only signal and observation in form of paths"
"or numpy arrays."
)
return self
# @model_validator(mode="after")
# def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
# """Validate paths provided in the config.

# Returns
# -------
# Self
# Returns itself.
# """
# if self.path and (self.signal is not None or self.observation is not None):
# raise ValueError(
# "Either only 'path' to pre-trained noise model should be"
# "provided or only signal and observation in form of paths"
# "or numpy arrays."
# )
# if not self.path and (self.signal is None or self.observation is None):
# raise ValueError(
# "Either only 'path' to pre-trained noise model should be"
# "provided or only signal and observation in form of paths"
# "or numpy arrays."
# )
# return self
# TODO revisit validation


# The noise model is given by a set of GMMs, one for each target
Expand Down
3 changes: 2 additions & 1 deletion src/careamics/config/vae_algorithm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ def algorithm_cross_validation(self: Self) -> Self:
"Algorithm `denoisplit` with loss `denoisplit` only supports "
"`predict_logvar` as `None`."
)

if self.noise_model is None:
raise ValueError("Algorithm `denoisplit` requires a noise model.")
# TODO: what if algorithm is not musplit or denoisplit (HDN?)
# TODO: what if algorithm is not musplit or denoisplit
return self

@model_validator(mode="after")
Expand Down
5 changes: 3 additions & 2 deletions src/careamics/lvae_training/dataset/multifile_dataset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Union, Callable, Sequence
from collections.abc import Sequence
from typing import Callable, Union

import numpy as np
from numpy.typing import NDArray

from .config import DatasetConfig
from .lc_dataset import LCMultiChDloader
from .multich_dataset import MultiChDloader
from .types import DataSplitType
from .lc_dataset import LCMultiChDloader


class TwoChannelData(Sequence):
Expand Down
Loading

0 comments on commit b4fa28f

Please sign in to comment.