Skip to content

Commit

Permalink
Added basic data pre-processing pipeline.
Browse files Browse the repository at this point in the history
  • Loading branch information
joernweissenborn committed Feb 28, 2023
1 parent 2098c83 commit d49cf9d
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 0 deletions.
2 changes: 2 additions & 0 deletions glotaran/io/preprocessor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""Tools for data pre-processing."""
from glotaran.io.preprocessor.pipeline import PreProcessingPipeline
80 changes: 80 additions & 0 deletions glotaran/io/preprocessor/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""A pre-processor pipeline for data."""
from __future__ import annotations

from typing import Annotated

import xarray as xr
from pydantic import BaseModel
from pydantic import Field

from glotaran.io.preprocessor.preprocessor import CorrectBaselineAverage
from glotaran.io.preprocessor.preprocessor import CorrectBaselineValue

PipelineAction = Annotated[
CorrectBaselineValue | CorrectBaselineAverage,
Field(discriminator="action"),
]


class PreProcessingPipeline(BaseModel):
"""A pipeline for pre-processors."""

actions: list[PipelineAction] = Field(default_factory=list)

def apply(self, original: xr.DataArray) -> xr.DataArray:
"""Apply all pre-processors on data.
Parameters
----------
original: xr.DataArray
The data to process.
Returns
-------
xr.DataArray
"""
result = original.copy()

for action in self.actions:
result = action.apply(result)
return result

def _push_action(self, action: PipelineAction):
"""Push an action.
Parameters
----------
action: PipelineAction
The action to push.
"""
self.actions.append(action)

def correct_baseline_value(self, value: float) -> PreProcessingPipeline:
"""Correct a dataset by subtracting baseline value.
Parameters
----------
value: float
The value to subtract.
Returns
-------
PreProcessingPipeline
"""
self._push_action(CorrectBaselineValue(value=value))
return self

def correct_baseline_average(self, interval: dict[str, slice]) -> PreProcessingPipeline:
"""Correct a dataset by subtracting an average of the data.
Parameters
----------
interval: dict[str, slice]
The intervals to average.
Returns
-------
PreProcessingPipeline
"""
self._push_action(CorrectBaselineAverage(interval=interval))
return self
78 changes: 78 additions & 0 deletions glotaran/io/preprocessor/preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""A pre-processor pipeline for data."""
from __future__ import annotations

import abc
from typing import Literal

import xarray as xr
from pydantic import BaseModel


class PreProcessor(BaseModel, abc.ABC):
"""A base class for pre=processors."""

class Config:
"""Config for BaseModel."""

arbitrary_types_allowed = True

@abc.abstractmethod
def apply(self, data: xr.DataArray) -> xr.DataArray:
"""Apply the pre-processor.
Parameters
----------
data: xr.DataArray
The data to process.
Returns
-------
xr.DataArray
.. # noqa: DAR202
"""


class CorrectBaselineValue(PreProcessor):
"""Corrects a dataset by subtracting baseline value."""

action: Literal["baseline-value"] = "baseline-value"
value: float

def apply(self, data: xr.DataArray) -> xr.DataArray:
"""Apply the pre-processor.
Parameters
----------
data: xr.DataArray
The data to process.
Returns
-------
xr.DataArray
"""
return data - self.value


class CorrectBaselineAverage(PreProcessor):
"""Corrects a dataset by subtracting the average over a part of the data."""

action: Literal["baseline-average"] = "baseline-average"
interval: dict[str, slice]

def apply(self, data: xr.DataArray) -> xr.DataArray:
"""Apply the pre-processor.
Parameters
----------
data: xr.DataArray
The data to process.
Returns
-------
xr.DataArray
"""
selection = data
for axis, interval in self.interval.items():
selection = selection.sel({axis: interval})
return data - (selection.sum() / selection.size)
33 changes: 33 additions & 0 deletions glotaran/io/preprocessor/test/test_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import xarray as xr

from glotaran.io.preprocessor import PreProcessingPipeline


def test_correct_baseline_value():
pl = PreProcessingPipeline()
pl.correct_baseline_value(1)
data = xr.DataArray([[1]])
result = pl.apply(data)
assert result == data - 1


def test_correct_baseline_average():
pl = PreProcessingPipeline()
pl.correct_baseline_average({"dim_1": [0, 1]})
data = xr.DataArray([[1.1, 0.9]])
result = pl.apply(data)
assert (result == data - 1).all()


def test_to_from_dict():
pl = PreProcessingPipeline()
pl.correct_baseline_value(1)
pl.correct_baseline_average({"dim_1": [0, 1]})
pl_dict = pl.dict()
assert pl_dict == {
"actions": [
{"action": "baseline-value", "value": 1.0},
{"action": "baseline-average", "interval": {"dim_1": [0, 1]}},
]
}
assert PreProcessingPipeline.parse_obj(pl_dict) == pl
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ numpy==1.23.5
odfpy==1.4.1
openpyxl==3.1.1
pandas==1.5.3
pydantic==1.10.2
rich==13.3.1
ruamel.yaml==0.17.21
scipy==1.10.0
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ install_requires =
odfpy>=1.4.1
openpyxl>=3.0.10
pandas>=1.3.4
pydantic>=1.10.2
rich>=10.9.0
ruamel.yaml>=0.17.17
scipy>=1.7.2
Expand Down

0 comments on commit d49cf9d

Please sign in to comment.