-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added basic data pre-processing pipeline.
- Loading branch information
1 parent
2098c83
commit d49cf9d
Showing
6 changed files
with
195 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
"""Tools for data pre-processing.""" | ||
from glotaran.io.preprocessor.pipeline import PreProcessingPipeline |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
"""A pre-processor pipeline for data.""" | ||
from __future__ import annotations | ||
|
||
from typing import Annotated | ||
|
||
import xarray as xr | ||
from pydantic import BaseModel | ||
from pydantic import Field | ||
|
||
from glotaran.io.preprocessor.preprocessor import CorrectBaselineAverage | ||
from glotaran.io.preprocessor.preprocessor import CorrectBaselineValue | ||
|
||
PipelineAction = Annotated[ | ||
CorrectBaselineValue | CorrectBaselineAverage, | ||
Field(discriminator="action"), | ||
] | ||
|
||
|
||
class PreProcessingPipeline(BaseModel): | ||
"""A pipeline for pre-processors.""" | ||
|
||
actions: list[PipelineAction] = Field(default_factory=list) | ||
|
||
def apply(self, original: xr.DataArray) -> xr.DataArray: | ||
"""Apply all pre-processors on data. | ||
Parameters | ||
---------- | ||
original: xr.DataArray | ||
The data to process. | ||
Returns | ||
------- | ||
xr.DataArray | ||
""" | ||
result = original.copy() | ||
|
||
for action in self.actions: | ||
result = action.apply(result) | ||
return result | ||
|
||
def _push_action(self, action: PipelineAction): | ||
"""Push an action. | ||
Parameters | ||
---------- | ||
action: PipelineAction | ||
The action to push. | ||
""" | ||
self.actions.append(action) | ||
|
||
def correct_baseline_value(self, value: float) -> PreProcessingPipeline: | ||
"""Correct a dataset by subtracting baseline value. | ||
Parameters | ||
---------- | ||
value: float | ||
The value to subtract. | ||
Returns | ||
------- | ||
PreProcessingPipeline | ||
""" | ||
self._push_action(CorrectBaselineValue(value=value)) | ||
return self | ||
|
||
def correct_baseline_average(self, interval: dict[str, slice]) -> PreProcessingPipeline: | ||
"""Correct a dataset by subtracting an average of the data. | ||
Parameters | ||
---------- | ||
interval: dict[str, slice] | ||
The intervals to average. | ||
Returns | ||
------- | ||
PreProcessingPipeline | ||
""" | ||
self._push_action(CorrectBaselineAverage(interval=interval)) | ||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
"""A pre-processor pipeline for data.""" | ||
from __future__ import annotations | ||
|
||
import abc | ||
from typing import Literal | ||
|
||
import xarray as xr | ||
from pydantic import BaseModel | ||
|
||
|
||
class PreProcessor(BaseModel, abc.ABC): | ||
"""A base class for pre=processors.""" | ||
|
||
class Config: | ||
"""Config for BaseModel.""" | ||
|
||
arbitrary_types_allowed = True | ||
|
||
@abc.abstractmethod | ||
def apply(self, data: xr.DataArray) -> xr.DataArray: | ||
"""Apply the pre-processor. | ||
Parameters | ||
---------- | ||
data: xr.DataArray | ||
The data to process. | ||
Returns | ||
------- | ||
xr.DataArray | ||
.. # noqa: DAR202 | ||
""" | ||
|
||
|
||
class CorrectBaselineValue(PreProcessor): | ||
"""Corrects a dataset by subtracting baseline value.""" | ||
|
||
action: Literal["baseline-value"] = "baseline-value" | ||
value: float | ||
|
||
def apply(self, data: xr.DataArray) -> xr.DataArray: | ||
"""Apply the pre-processor. | ||
Parameters | ||
---------- | ||
data: xr.DataArray | ||
The data to process. | ||
Returns | ||
------- | ||
xr.DataArray | ||
""" | ||
return data - self.value | ||
|
||
|
||
class CorrectBaselineAverage(PreProcessor): | ||
"""Corrects a dataset by subtracting the average over a part of the data.""" | ||
|
||
action: Literal["baseline-average"] = "baseline-average" | ||
interval: dict[str, slice] | ||
|
||
def apply(self, data: xr.DataArray) -> xr.DataArray: | ||
"""Apply the pre-processor. | ||
Parameters | ||
---------- | ||
data: xr.DataArray | ||
The data to process. | ||
Returns | ||
------- | ||
xr.DataArray | ||
""" | ||
selection = data | ||
for axis, interval in self.interval.items(): | ||
selection = selection.sel({axis: interval}) | ||
return data - (selection.sum() / selection.size) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import xarray as xr | ||
|
||
from glotaran.io.preprocessor import PreProcessingPipeline | ||
|
||
|
||
def test_correct_baseline_value(): | ||
pl = PreProcessingPipeline() | ||
pl.correct_baseline_value(1) | ||
data = xr.DataArray([[1]]) | ||
result = pl.apply(data) | ||
assert result == data - 1 | ||
|
||
|
||
def test_correct_baseline_average(): | ||
pl = PreProcessingPipeline() | ||
pl.correct_baseline_average({"dim_1": [0, 1]}) | ||
data = xr.DataArray([[1.1, 0.9]]) | ||
result = pl.apply(data) | ||
assert (result == data - 1).all() | ||
|
||
|
||
def test_to_from_dict(): | ||
pl = PreProcessingPipeline() | ||
pl.correct_baseline_value(1) | ||
pl.correct_baseline_average({"dim_1": [0, 1]}) | ||
pl_dict = pl.dict() | ||
assert pl_dict == { | ||
"actions": [ | ||
{"action": "baseline-value", "value": 1.0}, | ||
{"action": "baseline-average", "interval": {"dim_1": [0, 1]}}, | ||
] | ||
} | ||
assert PreProcessingPipeline.parse_obj(pl_dict) == pl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters