Skip to content

Commit

Permalink
Added basic dataprocessor.
Browse files Browse the repository at this point in the history
  • Loading branch information
joernweissenborn committed Feb 23, 2023
1 parent 2098c83 commit 92c0927
Showing 1 changed file with 83 additions and 0 deletions.
83 changes: 83 additions & 0 deletions glotaran/io/preprocssor/preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from __future__ import annotations

import abc
from typing import Annotated
from typing import Literal

try:
from typing import Self
except ImportError:
Self = "PreProcessor"

import xarray as xr
from pydantic import BaseModel
from pydantic import Field


class PreProcessorAction(BaseModel, abc.ABC):
class Config:
arbitrary_types_allowed = True

@abc.abstractmethod
def apply(self, data: xr.DataArray) -> xr.DataArray:
pass


class CorrectBaselineValue(PreProcessorAction):
action: Literal["baseline-value"] = "baseline-value"
value: float

def apply(self, data: xr.DataArray) -> xr.DataArray:
return data - self.value


class CorrectBaselineAverage(PreProcessorAction):
action: Literal["baseline-average"] = "baseline-average"
average: dict[str, slice]

def apply(self, data: xr.DataArray) -> xr.DataArray:
selection = data
for axis, interval in self.average.items():
selection = selection.sel({axis: interval})
return data - (selection.sum() / selection.size)


# TODO
# class ShiftAlongAxis(PreProcessorAction):
# action: Literal["shift-along-axis"] = "shift-along-axis"
# axis: str
# value: float | list[float]


# TODO
# class AverageAxis(PreProcessorAction):
# action: Literal["average"] = "average"
# new_coord: Literal["first", "last", "mean"] = "mean"
# axis: str
# number: int


class PreProcessorActions(BaseModel):
__root__: Annotated[
CorrectBaselineValue | CorrectBaselineAverage,
Field(discriminator="action"),
]


class PreProcessor:
def __init__(self, original: xr.DataArray, actions: list[PreProcessorAction] | None = None):
self._actions: list[PreProcessorAction] = actions or []
self._original: xr.DataArray = original

def get_result(self) -> xr.DataArray:
data = self._original
for action in self._actions:
data = action.apply(data)
return data

def push_action(self, action: PreProcessorAction):
self._actions.append(action)

def correct_baseline_value(self, value: float) -> Self:
self.push_action(CorrectBaselineValue(value=value))
return self

0 comments on commit 92c0927

Please sign in to comment.