# Reconstructed from git patch e78d4096bb6ac7bc8074558eba521be581a37d85
# From: Joern Weissenborn
# Date: Thu, 23 Feb 2023 23:31:22 +0100
# Subject: [PATCH] Added basic data pre-processing pipeline.
# Target file: glotaran/io/preprocessor/preprocessor.py (new file)
"""Basic pre-processing actions and a pipeline that applies them in order."""
from __future__ import annotations

import abc
from typing import Annotated
from typing import Literal

try:
    from typing import Self
except ImportError:
    # ``typing.Self`` only exists on Python 3.11+. ``Self`` is used solely as
    # the return annotation of ``PreProcessingPipeline`` methods, so the
    # fallback forward reference has to name that class.
    Self = "PreProcessingPipeline"

import xarray as xr
from pydantic import BaseModel
from pydantic import Field


class PreProcessor(BaseModel, abc.ABC):
    """Abstract base class of a single pre-processing action."""

    class Config:
        # Needed so that non-pydantic field types (e.g. ``slice``) are accepted.
        arbitrary_types_allowed = True

    @abc.abstractmethod
    def apply(self, data: xr.DataArray) -> xr.DataArray:
        """Apply the action to ``data`` and return the result.

        Parameters
        ----------
        data : xr.DataArray
            The data to process.

        Returns
        -------
        xr.DataArray
        """


class CorrectBaselineValue(PreProcessor):
    """Subtract a fixed value from the data."""

    # Discriminator used by ``PreProcessorActions`` to pick this class.
    action: Literal["baseline-value"] = "baseline-value"
    # The value subtracted from every element.
    value: float

    def apply(self, data: xr.DataArray) -> xr.DataArray:
        """Return ``data`` with ``value`` subtracted.

        Parameters
        ----------
        data : xr.DataArray
            The data to process.

        Returns
        -------
        xr.DataArray
        """
        return data - self.value


class CorrectBaselineAverage(PreProcessor):
    """Subtract the average of a selected region from the data."""

    # Discriminator used by ``PreProcessorActions`` to pick this class.
    action: Literal["baseline-average"] = "baseline-average"
    # Maps an axis name to the slice passed to ``DataArray.sel`` for that axis.
    average: dict[str, slice]

    def apply(self, data: xr.DataArray) -> xr.DataArray:
        """Return ``data`` minus the average over the selected region.

        The region is obtained by applying ``sel`` once per entry of
        ``average``; the average is the sum of the selected elements divided
        by their count.

        Parameters
        ----------
        data : xr.DataArray
            The data to process.

        Returns
        -------
        xr.DataArray

        Raises
        ------
        ValueError
            If the selection contains no elements.
        """
        selection = data
        for axis, interval in self.average.items():
            selection = selection.sel({axis: interval})
        if selection.size == 0:
            # Dividing by a size of zero would silently turn all data into NaN.
            raise ValueError(
                f"Baseline selection {self.average!r} is empty, cannot compute average."
            )
        return data - (selection.sum() / selection.size)


# TODO
# class ShiftAlongAxis(PreProcessor):
#     action: Literal["shift-along-axis"] = "shift-along-axis"
#     axis: str
#     value: float | list[float]


# TODO
# class AverageAxis(PreProcessor):
#     action: Literal["average"] = "average"
#     new_coord: Literal["first", "last", "mean"] = "mean"
#     axis: str
#     number: int


class PreProcessorActions(BaseModel):
    """Root model holding one action, selected by its ``action`` field."""

    __root__: Annotated[
        CorrectBaselineValue | CorrectBaselineAverage,
        Field(discriminator="action"),
    ]


class PreProcessingPipeline:
    """An ordered list of pre-processing actions with a fluent builder API."""

    def __init__(self, original: xr.DataArray, actions: list[PreProcessor] | None = None):
        """Create a pipeline.

        Parameters
        ----------
        original : xr.DataArray
            Accepted for interface compatibility; not used by the constructor.
            The data to process is passed to ``apply`` instead.
        actions : list[PreProcessor] | None
            Initial actions. The list is copied so that adding actions to the
            pipeline never mutates the caller's list.
        """
        self._actions: list[PreProcessor] = list(actions) if actions else []

    def apply(self, original: xr.DataArray) -> xr.DataArray:
        """Apply all actions in order to a copy of ``original``.

        Parameters
        ----------
        original : xr.DataArray
            The data to process; it is copied before the first action runs.

        Returns
        -------
        xr.DataArray
            The result of the last action, or the copy if there are no actions.
        """
        result = original.copy()
        for action in self._actions:
            result = action.apply(result)
        return result

    # The method was originally published under the misspelled name ``pply``;
    # keep that name as an alias so existing callers continue to work.
    pply = apply

    def _push_action(self, action: PreProcessor) -> None:
        """Append ``action`` to the end of the pipeline."""
        self._actions.append(action)

    def correct_baseline_value(self, value: float) -> Self:
        """Append a ``CorrectBaselineValue`` action and return the pipeline.

        Parameters
        ----------
        value : float
            The value to subtract from the data.

        Returns
        -------
        Self
            This pipeline, to allow method chaining.
        """
        self._push_action(CorrectBaselineValue(value=value))
        return self