From fecbfdf66b8823ced7509c1a3b7ed56055a1aa95 Mon Sep 17 00:00:00 2001 From: Joern Weissenborn Date: Thu, 23 Feb 2023 23:31:22 +0100 Subject: [PATCH] Added basic data pre-processing pipeline. --- glotaran/io/preprocessor/__init__.py | 2 + glotaran/io/preprocessor/pipeline.py | 80 +++++++++++++++++++ glotaran/io/preprocessor/preprocessor.py | 78 ++++++++++++++++++ .../io/preprocessor/test/test_preprocessor.py | 33 ++++++++ requirements_dev.txt | 1 + setup.cfg | 1 + 6 files changed, 195 insertions(+) create mode 100644 glotaran/io/preprocessor/__init__.py create mode 100644 glotaran/io/preprocessor/pipeline.py create mode 100644 glotaran/io/preprocessor/preprocessor.py create mode 100644 glotaran/io/preprocessor/test/test_preprocessor.py diff --git a/glotaran/io/preprocessor/__init__.py b/glotaran/io/preprocessor/__init__.py new file mode 100644 index 000000000..f419c7669 --- /dev/null +++ b/glotaran/io/preprocessor/__init__.py @@ -0,0 +1,2 @@ +"""Tools for data pre-processing.""" +from glotaran.io.preprocessor.pipeline import PreProcessingPipeline diff --git a/glotaran/io/preprocessor/pipeline.py b/glotaran/io/preprocessor/pipeline.py new file mode 100644 index 000000000..554ec1157 --- /dev/null +++ b/glotaran/io/preprocessor/pipeline.py @@ -0,0 +1,80 @@ +"""A pre-processor pipeline for data.""" +from __future__ import annotations + +from typing import Annotated + +import xarray as xr +from pydantic import BaseModel +from pydantic import Field + +from glotaran.io.preprocessor.preprocessor import CorrectBaselineAverage +from glotaran.io.preprocessor.preprocessor import CorrectBaselineValue + +PipelineAction = Annotated[ + CorrectBaselineValue | CorrectBaselineAverage, + Field(discriminator="action"), +] + + +class PreProcessingPipeline(BaseModel): + """A pipeline for pre-processors.""" + + actions: list[PipelineAction] = Field(default_factory=list) + + def apply(self, original: xr.DataArray) -> xr.DataArray: + """Apply all pre-processors on data. 
+ + Parameters + ---------- + original: xr.DataArray + The data to process. + + Returns + ------- + xr.DataArray + """ + result = original.copy() + + for action in self.actions: + result = action.apply(result) + return result + + def _push_action(self, action: PipelineAction): + """Push an action. + + Parameters + ---------- + action: PipelineAction + The action to push. + """ + self.actions.append(action) + + def correct_baseline_value(self, value: float) -> PreProcessingPipeline: + """Correct a dataset by subtracting baseline value. + + Parameters + ---------- + value: float + The value to subtract. + + Returns + ------- + PreProcessingPipeline + """ + self._push_action(CorrectBaselineValue(value=value)) + return self + + def correct_baseline_average(self, interval: dict[str, list[int]]) -> PreProcessingPipeline: + """Correct a dataset by subtracting an average of the data. + + Parameters + ---------- + interval: dict[str, list[int]] + The intervals to average. + + Returns + ------- + PreProcessingPipeline + """ + self._push_action(CorrectBaselineAverage(interval=interval)) + return self diff --git a/glotaran/io/preprocessor/preprocessor.py b/glotaran/io/preprocessor/preprocessor.py new file mode 100644 index 000000000..f572803a5 --- /dev/null +++ b/glotaran/io/preprocessor/preprocessor.py @@ -0,0 +1,78 @@ +"""A pre-processor pipeline for data.""" +from __future__ import annotations + +import abc +from typing import Literal + +import xarray as xr +from pydantic import BaseModel + + +class PreProcessor(BaseModel, abc.ABC): + """A base class for pre-processors.""" + + class Config: + """Config for BaseModel.""" + + arbitrary_types_allowed = True + + @abc.abstractmethod + def apply(self, data: xr.DataArray) -> xr.DataArray: + """Apply the pre-processor. + + Parameters + ---------- + data: xr.DataArray + The data to process. + + Returns + ------- + xr.DataArray + + .. 
# noqa: DAR202 + """ + + +class CorrectBaselineValue(PreProcessor): + """Corrects a dataset by subtracting baseline value.""" + + action: Literal["baseline-value"] = "baseline-value" + value: float + + def apply(self, data: xr.DataArray) -> xr.DataArray: + """Apply the pre-processor. + + Parameters + ---------- + data: xr.DataArray + The data to process. + + Returns + ------- + xr.DataArray + """ + return data - self.value + + +class CorrectBaselineAverage(PreProcessor): + """Corrects a dataset by subtracting the average over a part of the data.""" + + action: Literal["baseline-average"] = "baseline-average" + interval: dict[str, list[int]] + + def apply(self, data: xr.DataArray) -> xr.DataArray: + """Apply the pre-processor. + + Parameters + ---------- + data: xr.DataArray + The data to process. + + Returns + ------- + xr.DataArray + """ + selection = data + for axis, interval in self.interval.items(): + selection = selection.sel({axis: interval}) + return data - (selection.sum() / selection.size) diff --git a/glotaran/io/preprocessor/test/test_preprocessor.py b/glotaran/io/preprocessor/test/test_preprocessor.py new file mode 100644 index 000000000..f88a3d97b --- /dev/null +++ b/glotaran/io/preprocessor/test/test_preprocessor.py @@ -0,0 +1,33 @@ +import xarray as xr + +from glotaran.io.preprocessor import PreProcessingPipeline + + +def test_correct_baseline_value(): + pl = PreProcessingPipeline() + pl.correct_baseline_value(1) + data = xr.DataArray([[1]]) + result = pl.apply(data) + assert result == data - 1 + + +def test_correct_baseline_average(): + pl = PreProcessingPipeline() + pl.correct_baseline_average({"dim_1": [0, 1]}) + data = xr.DataArray([[1.1, 0.9]]) + result = pl.apply(data) + assert (result == data - 1).all() + + +def test_to_from_dict(): + pl = PreProcessingPipeline() + pl.correct_baseline_value(1) + pl.correct_baseline_average({"dim_1": [0, 1]}) + pl_dict = pl.dict() + assert pl_dict == { + "actions": [ + {"action": "baseline-value", "value": 
1.0}, + {"action": "baseline-average", "interval": {"dim_1": [0, 1]}}, + ] + } + assert PreProcessingPipeline.parse_obj(pl_dict) == pl diff --git a/requirements_dev.txt b/requirements_dev.txt index 628e590b1..19e3f4a4b 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -12,6 +12,7 @@ numpy==1.23.5 odfpy==1.4.1 openpyxl==3.1.1 pandas==1.5.3 +pydantic==1.10.2 rich==13.3.1 ruamel.yaml==0.17.21 scipy==1.10.0 diff --git a/setup.cfg b/setup.cfg index b08492677..3080b3f3c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,6 +40,7 @@ install_requires = odfpy>=1.4.1 openpyxl>=3.0.10 pandas>=1.3.4 + pydantic>=1.10.2 rich>=10.9.0 ruamel.yaml>=0.17.17 scipy>=1.7.2