-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added basic data pre-processing pipeline.
- Loading branch information
1 parent
2098c83
commit 9210faa
Showing
6 changed files
with
198 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
"""Tools for data pre-processing.""" | ||
from glotaran.io.preprocessor.pipeline import PreProcessingPipeline |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
"""A pre-processor pipeline for data.""" | ||
from __future__ import annotations | ||
|
||
from typing import Annotated | ||
|
||
import xarray as xr | ||
from pydantic import BaseModel | ||
from pydantic import Field | ||
|
||
from glotaran.io.preprocessor.preprocessor import CorrectBaselineAverage | ||
from glotaran.io.preprocessor.preprocessor import CorrectBaselineValue | ||
|
||
# Union of every action a pipeline can hold. The ``action`` field of each
# model is declared as the discriminator, i.e. its literal value identifies
# which of the union members an entry belongs to.
PipelineAction = Annotated[
    CorrectBaselineValue | CorrectBaselineAverage,
    Field(discriminator="action"),
]
|
||
|
||
class PreProcessingPipeline(BaseModel):
    """An ordered collection of pre-processing actions for data.

    Actions are registered through the ``correct_*`` builder methods, each of
    which returns the pipeline itself so that calls can be chained. The
    registered actions are executed in registration order by :meth:`apply`.
    """

    actions: list[PipelineAction] = Field(default_factory=list)

    def apply(self, original: xr.DataArray) -> xr.DataArray:
        """Run all registered actions, in order, on a copy of the data.

        Parameters
        ----------
        original: xr.DataArray
            The data to process. The pipeline starts from a copy of it.

        Returns
        -------
        xr.DataArray
            The output of the last action, or the plain copy if the pipeline
            holds no actions.
        """
        processed = original.copy()
        for step in self.actions:
            # Each action consumes the output of the previous one.
            processed = step.apply(processed)
        return processed

    def _push_action(self, action: PipelineAction):
        """Append an action to the end of the pipeline.

        Parameters
        ----------
        action: PipelineAction
            The action to append.
        """
        self.actions.append(action)

    def correct_baseline_value(self, value: float) -> PreProcessingPipeline:
        """Register an action that subtracts a fixed baseline value.

        Parameters
        ----------
        value: float
            The value to subtract from the data.

        Returns
        -------
        PreProcessingPipeline
            This pipeline, to allow chaining.
        """
        baseline_action = CorrectBaselineValue(value=value)
        self._push_action(baseline_action)
        return self

    def correct_baseline_average(
        self, selection: dict[str, slice | list[int] | int]
    ) -> PreProcessingPipeline:
        """Register an action that subtracts the average over a part of the data.

        Parameters
        ----------
        selection: dict[str, slice | list[int] | int]
            The part of the data to average, given as a mapping from
            dimension name to indexer. An indexer can be a slice, a list or
            a single integer.

        Returns
        -------
        PreProcessingPipeline
            This pipeline, to allow chaining.
        """
        average_action = CorrectBaselineAverage(selection=selection)
        self._push_action(average_action)
        return self
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
"""Pre-processors which can be applied to data.""" | ||
from __future__ import annotations | ||
|
||
import abc | ||
from typing import Literal | ||
|
||
import xarray as xr | ||
from pydantic import BaseModel | ||
|
||
|
||
class PreProcessor(BaseModel, abc.ABC):
    """Abstract base class for pre-processors.

    Concrete pre-processors have to implement :meth:`apply`.
    """

    class Config:
        """Pydantic configuration of the model."""

        # Lets pydantic accept field types it ships no validator for.
        arbitrary_types_allowed = True

    @abc.abstractmethod
    def apply(self, data: xr.DataArray) -> xr.DataArray:
        """Apply the pre-processor to the data.

        Parameters
        ----------
        data: xr.DataArray
            The data to process.

        Returns
        -------
        xr.DataArray

        .. # noqa: DAR202
        """
|
||
|
||
class CorrectBaselineValue(PreProcessor):
    """Pre-processor which subtracts a fixed baseline value from the data."""

    # Literal tag identifying this kind of action.
    action: Literal["baseline-value"] = "baseline-value"
    # The baseline which gets subtracted from the data.
    value: float

    def apply(self, data: xr.DataArray) -> xr.DataArray:
        """Subtract the configured baseline value from the data.

        Parameters
        ----------
        data: xr.DataArray
            The data to process.

        Returns
        -------
        xr.DataArray
            The data with ``value`` subtracted.
        """
        baseline = self.value
        return data - baseline
|
||
|
||
class CorrectBaselineAverage(PreProcessor):
    """Pre-processor which subtracts the average over a part of the data."""

    # Literal tag identifying this kind of action.
    action: Literal["baseline-average"] = "baseline-average"
    # Mapping of dimension name to indexer, passed to ``DataArray.sel``.
    selection: dict[str, slice | list[int] | int]

    def apply(self, data: xr.DataArray) -> xr.DataArray:
        """Subtract the average of the selected region from the whole data.

        Parameters
        ----------
        data: xr.DataArray
            The data to process.

        Returns
        -------
        xr.DataArray
            The data with the average over ``selection`` subtracted.
        """
        baseline_region = data.sel(self.selection)
        # Use ``mean`` instead of ``sum() / size``: for float data the xarray
        # reductions skip NaN by default, whereas ``size`` still counts those
        # entries, so dividing the sum by ``size`` would underestimate the
        # average as soon as the selected region contains missing values.
        return data - baseline_region.mean()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import pytest | ||
import xarray as xr | ||
|
||
from glotaran.io.preprocessor import PreProcessingPipeline | ||
|
||
|
||
def test_correct_baseline_value():
    """A constant baseline value is subtracted from every data point."""
    pl = PreProcessingPipeline()
    pl.correct_baseline_value(1)
    data = xr.DataArray([[1]])
    result = pl.apply(data)
    # Reduce explicitly instead of relying on the truth value of the
    # comparison result, which is only usable for a one-element array.
    assert (result == data - 1).all()
|
||
|
||
@pytest.mark.parametrize("indexer", (slice(0, 2), [0, 1]))
def test_correct_baseline_average(indexer: slice | list[int]):
    """The average over the selection is subtracted for slice and list indexers."""
    data = xr.DataArray([[1.1, 0.9]])
    pipeline = PreProcessingPipeline()
    pipeline.correct_baseline_average({"dim_0": 0, "dim_1": indexer})
    corrected = pipeline.apply(data)
    # The two selected values average to 1.
    expected = data - 1
    assert (corrected == expected).all()
|
||
|
||
def test_to_from_dict():
    """A pipeline survives a round trip through its dictionary representation."""
    pipeline = PreProcessingPipeline()
    pipeline.correct_baseline_value(1)
    pipeline.correct_baseline_average({"dim_1": slice(0, 2)})

    expected_dict = {
        "actions": [
            {"action": "baseline-value", "value": 1.0},
            {"action": "baseline-average", "selection": {"dim_1": slice(0, 2)}},
        ]
    }
    serialized = pipeline.dict()
    assert serialized == expected_dict

    restored = PreProcessingPipeline.parse_obj(serialized)
    assert restored == pipeline
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters