forked from CAREamics/careamics
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refac: refactored serialization of arrays for noise model and likelih…
…ood configs (CAREamics#232) ### Description - **What**: Implemented logic to serialize arrays defined within pydantic models (specifically with reference to `nm_model.py` and `likelihood_model.py`). - **Why**: Because during training we need to save info in the configs. Specifically, since arrays cannot be deserialized by default we need to decide whether to keep them or not. - **How**: Excluded large arrays from serialization, wrote custom serializers for others. ### Changes Made - **Added**: - Custom serializer `array_to_json()`. - 2 different custom deserializers: `list_to_numpy()`, `list_to_torch()`. - **Modified**: Excluded some arrays from serialization. - **Removed**: None. ### For further discussion NOTE1: why deserializer takes in a list? Because in our use case we need deserializers to move config files usually stored as dicts into pydantic models. But, such dicts are often the result of loading config file mainly stored as `json`, `pkl`, or `yml`. The loader for these file types automatically deserialize strings into lists. That's why what is left to do for us is to move lists into arrays or tensors. NOTE2: why 2 different deserializers? Because in some cases we want to deserialize list to torch tensor and in some other cases we prefer numpy arrays. After discussing with @jdeschamps we realized that some large arrays should not be part of configs (e.g., `signal`, `observation` in `nm_model.py`). Therefore, I left TODOs to remind about discussing this for future refactoring. --- **Please ensure your PR meets the following requirements:** - [x] Code builds and passes tests locally, including doctests - [x] New tests have been added (for bug fixes/features) - [x] Pre-commit passes - [ ] PR to the documentation exists (for bug fixes / features) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>
- Loading branch information
1 parent
44aee3e
commit 1b60f07
Showing
4 changed files
with
177 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
"""A script for serializers in the careamics package.""" | ||
|
||
import ast | ||
import json | ||
from typing import Union | ||
|
||
import numpy as np | ||
import torch | ||
|
||
|
||
def _array_to_json(arr: Union[np.ndarray, torch.Tensor]) -> str: | ||
"""Convert an array to a list and then to a JSON string. | ||
Parameters | ||
---------- | ||
arr : Union[np.ndarray, torch.Tensor] | ||
Array to be serialized. | ||
Returns | ||
------- | ||
str | ||
JSON string representing the array. | ||
""" | ||
return json.dumps(arr.tolist()) | ||
|
||
|
||
def _to_numpy(lst: Union[str, list]) -> np.ndarray: | ||
"""Deserialize a list or string representing a list into `np.ndarray`. | ||
Parameters | ||
---------- | ||
lst : list | ||
List or string representing a list with the array content to be deserialized. | ||
Returns | ||
------- | ||
np.ndarray | ||
The deserialized array. | ||
""" | ||
if isinstance(lst, str): | ||
lst = ast.literal_eval(lst) | ||
return np.asarray(lst) | ||
|
||
|
||
def _to_torch(lst: Union[str, list]) -> torch.Tensor: | ||
"""Deserialize list or string representing a list into `torch.Tensor`. | ||
Parameters | ||
---------- | ||
lst : Union[str, list] | ||
List or string representing a list swith the array content to be deserialized. | ||
Returns | ||
------- | ||
torch.Tensor | ||
The deserialized tensor. | ||
""" | ||
if isinstance(lst, str): | ||
lst = ast.literal_eval(lst) | ||
return torch.tensor(lst) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import json | ||
from pathlib import Path | ||
from typing import Union | ||
|
||
import numpy as np | ||
import pytest | ||
import torch | ||
from pydantic import BaseModel, ConfigDict | ||
|
||
from careamics.config.likelihood_model import Tensor | ||
from careamics.config.nm_model import Array | ||
|
||
|
||
class MyArray(BaseModel): | ||
|
||
model_config = ConfigDict(arbitrary_types_allowed=True) | ||
|
||
arr: Array | ||
|
||
|
||
class MyTensor(BaseModel): | ||
|
||
model_config = ConfigDict(arbitrary_types_allowed=True) | ||
|
||
arr: Tensor | ||
|
||
|
||
@pytest.mark.parametrize("arr", [np.array([1, 2]), torch.tensor([1, 2])]) | ||
def test_serialize_array(arr: Union[np.ndarray, torch.Tensor]): | ||
"""Test array_to_json function.""" | ||
arr_model = MyArray(arr=arr) | ||
assert arr_model.model_dump() == {"arr": "[1, 2]"} | ||
|
||
|
||
@pytest.mark.parametrize("arr", [np.array([1, 2]), torch.tensor([1, 2])]) | ||
def test_serialize_tensor(arr: Union[np.ndarray, torch.Tensor]): | ||
"""Test array_to_json function.""" | ||
arr_model = MyTensor(arr=arr) | ||
assert arr_model.model_dump() == {"arr": "[1, 2]"} | ||
|
||
|
||
def test_deserialize_array(tmp_path: Path): | ||
"""Test list_to_numpy function.""" | ||
arr_model = MyArray(arr=np.array([1, 2])) | ||
# save to JSON | ||
with open(tmp_path / "array_config.json", "w") as f: | ||
f.write(arr_model.model_dump_json()) | ||
# load from JSON | ||
with open(tmp_path / "array_config.json") as f: | ||
config = json.load(f) | ||
new_arr_model = MyArray(**config) | ||
assert np.array_equal(new_arr_model.arr, np.array([1, 2])) | ||
|
||
|
||
def test_deserialize_tensor(tmp_path: Path): | ||
"""Test list_to_tensor function.""" | ||
arr_model = MyTensor(arr=torch.tensor([1, 2])) | ||
# save to JSON | ||
with open(tmp_path / "tensor_config.json", "w") as f: | ||
f.write(arr_model.model_dump_json()) | ||
# load from JSON | ||
with open(tmp_path / "tensor_config.json") as f: | ||
config = json.load(f) | ||
new_arr_model = MyTensor(**config) | ||
assert torch.equal(new_arr_model.arr, torch.tensor([1, 2])) |