Skip to content

Commit

Permalink
Added basic data pre-processing pipeline.
Browse files Browse the repository at this point in the history
  • Loading branch information
joernweissenborn committed Feb 26, 2023
1 parent 2098c83 commit e78d409
Showing 1 changed file with 82 additions and 0 deletions.
82 changes: 82 additions & 0 deletions glotaran/io/preprocessor/preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from __future__ import annotations

import abc
from typing import Annotated
from typing import Literal

try:
from typing import Self
except ImportError:
Self = "PreProcessor"

import xarray as xr
from pydantic import BaseModel
from pydantic import Field


class PreProcessor(BaseModel, abc.ABC):
    """Abstract base model for a single pre-processing action.

    Concrete subclasses declare their parameters as model fields and
    implement :meth:`apply`.
    """

    class Config:
        # Permit field types pydantic has no built-in validator for.
        arbitrary_types_allowed = True

    @abc.abstractmethod
    def apply(self, data: xr.DataArray) -> xr.DataArray:
        """Apply the action to ``data`` and return the processed array.

        Parameters
        ----------
        data : xr.DataArray
            The data to process.

        Returns
        -------
        xr.DataArray
            The processed data.
        """


class CorrectBaselineValue(PreProcessor):
    """Baseline correction that subtracts a fixed value from the data."""

    # Discriminator tag identifying this action.
    action: Literal["baseline-value"] = "baseline-value"
    # The constant that is subtracted from the data.
    value: float

    def apply(self, data: xr.DataArray) -> xr.DataArray:
        """Return ``data`` with :attr:`value` subtracted.

        Parameters
        ----------
        data : xr.DataArray
            The data to correct.

        Returns
        -------
        xr.DataArray
            The result of ``data - self.value``.
        """
        corrected = data - self.value
        return corrected


class CorrectBaselineAverage(PreProcessor):
    """Baseline correction that subtracts the average over a selected region."""

    # Discriminator tag identifying this action.
    action: Literal["baseline-average"] = "baseline-average"
    # Maps an axis name to the slice passed to ``DataArray.sel`` for that axis.
    average: dict[str, slice]

    def apply(self, data: xr.DataArray) -> xr.DataArray:
        """Subtract the average of the selected region from ``data``.

        The region is obtained by applying every entry of :attr:`average`
        in turn via ``sel``. The baseline is the sum of the selected values
        divided by the number of selected elements.

        Parameters
        ----------
        data : xr.DataArray
            The data to correct.

        Returns
        -------
        xr.DataArray
            ``data`` minus the computed baseline.
        """
        region = data
        for dim_name, dim_slice in self.average.items():
            # Narrow the region one axis at a time.
            region = region.sel({dim_name: dim_slice})
        baseline = region.sum() / region.size
        return data - baseline


# TODO: Implement an action that shifts the data along an axis (sketch, not active yet).
# class ShiftAlongAxis(PreProcessor):
#     action: Literal["shift-along-axis"] = "shift-along-axis"
#     axis: str
#     value: float | list[float]


# TODO: Implement an action that averages the data along an axis (sketch, not active yet).
# class AverageAxis(PreProcessor):
#     action: Literal["average"] = "average"
#     new_coord: Literal["first", "last", "mean"] = "mean"
#     axis: str
#     number: int


class PreProcessorActions(BaseModel):
    """Root model wrapping any one of the available pre-processing actions.

    The concrete action class is chosen through the ``action`` field, which
    serves as the discriminator of the union.
    """

    __root__: Annotated[
        Annotated[CorrectBaselineValue | CorrectBaselineAverage, Field(discriminator="action")],
        Field(discriminator="action"),
    ]


class PreProcessingPipeline:
    """Ordered chain of pre-processing actions applied to a data array.

    Actions run in the order in which they were supplied or appended; each
    action receives the output of the previous one.
    """

    def __init__(self, original: xr.DataArray, actions: list[PreProcessor] | None = None):
        """Create a pipeline.

        Parameters
        ----------
        original : xr.DataArray
            Not used by the pipeline; kept so existing callers keep working.
        actions : list[PreProcessor] | None
            Initial actions. The list is copied so that appending actions to
            the pipeline never mutates the caller's list.
        """
        self._actions: list[PreProcessor] = list(actions or [])

    def apply(self, original: xr.DataArray) -> xr.DataArray:
        """Run all actions on a copy of ``original``.

        Parameters
        ----------
        original : xr.DataArray
            The data to process. It is copied first and not modified here.

        Returns
        -------
        xr.DataArray
            The result after every action has been applied in order.
        """
        result = original.copy()
        for action in self._actions:
            result = action.apply(result)
        return result

    # Backward-compatible alias: the method was first published under this
    # misspelled name.
    pply = apply

    def _push_action(self, action: PreProcessor):
        """Append ``action`` to the end of the pipeline."""
        self._actions.append(action)

    def correct_baseline_value(self, value: float) -> Self:
        """Append a constant baseline subtraction and return the pipeline.

        Parameters
        ----------
        value : float
            The value to subtract from the data.

        Returns
        -------
        Self
            This pipeline, to allow method chaining.
        """
        self._push_action(CorrectBaselineValue(value=value))
        return self

0 comments on commit e78d409

Please sign in to comment.