Skip to content

Commit

Permalink
Added basic data pre-processing pipeline.
Browse files Browse the repository at this point in the history
  • Loading branch information
joernweissenborn committed Mar 1, 2023
1 parent 2098c83 commit 9210faa
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 0 deletions.
2 changes: 2 additions & 0 deletions glotaran/io/preprocessor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""Tools for data pre-processing."""
from glotaran.io.preprocessor.pipeline import PreProcessingPipeline
83 changes: 83 additions & 0 deletions glotaran/io/preprocessor/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""A pre-processor pipeline for data."""
from __future__ import annotations

from typing import Annotated

import xarray as xr
from pydantic import BaseModel
from pydantic import Field

from glotaran.io.preprocessor.preprocessor import CorrectBaselineAverage
from glotaran.io.preprocessor.preprocessor import CorrectBaselineValue

PipelineAction = Annotated[
CorrectBaselineValue | CorrectBaselineAverage,
Field(discriminator="action"),
]


class PreProcessingPipeline(BaseModel):
"""A pipeline for pre-processors."""

actions: list[PipelineAction] = Field(default_factory=list)

def apply(self, original: xr.DataArray) -> xr.DataArray:
"""Apply all pre-processors on data.
Parameters
----------
original: xr.DataArray
The data to process.
Returns
-------
xr.DataArray
"""
result = original.copy()

for action in self.actions:
result = action.apply(result)
return result

def _push_action(self, action: PipelineAction):
"""Push an action.
Parameters
----------
action: PipelineAction
The action to push.
"""
self.actions.append(action)

def correct_baseline_value(self, value: float) -> PreProcessingPipeline:
"""Correct a dataset by subtracting baseline value.
Parameters
----------
value: float
The value to subtract.
Returns
-------
PreProcessingPipeline
"""
self._push_action(CorrectBaselineValue(value=value))
return self

def correct_baseline_average(
self, selection: dict[str, slice | list[int] | int]
) -> PreProcessingPipeline:
"""Correct a dataset by subtracting the average over a part of the data.
Parameters
----------
selection: dict[str, slice | list[int] | int]
The selection to average as dictionary of dimension and indexer.
The indexer can be a slice, a list or an integer value.
Returns
-------
PreProcessingPipeline
"""
self._push_action(CorrectBaselineAverage(selection=selection))
return self
76 changes: 76 additions & 0 deletions glotaran/io/preprocessor/preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""A pre-processor pipeline for data."""
from __future__ import annotations

import abc
from typing import Literal

import xarray as xr
from pydantic import BaseModel


class PreProcessor(BaseModel, abc.ABC):
"""A base class for pre=processors."""

class Config:
"""Config for BaseModel."""

arbitrary_types_allowed = True

@abc.abstractmethod
def apply(self, data: xr.DataArray) -> xr.DataArray:
"""Apply the pre-processor.
Parameters
----------
data: xr.DataArray
The data to process.
Returns
-------
xr.DataArray
.. # noqa: DAR202
"""


class CorrectBaselineValue(PreProcessor):
"""Corrects a dataset by subtracting baseline value."""

action: Literal["baseline-value"] = "baseline-value"
value: float

def apply(self, data: xr.DataArray) -> xr.DataArray:
"""Apply the pre-processor.
Parameters
----------
data: xr.DataArray
The data to process.
Returns
-------
xr.DataArray
"""
return data - self.value


class CorrectBaselineAverage(PreProcessor):
"""Corrects a dataset by subtracting the average over a part of the data."""

action: Literal["baseline-average"] = "baseline-average"
selection: dict[str, slice | list[int] | int]

def apply(self, data: xr.DataArray) -> xr.DataArray:
"""Apply the pre-processor.
Parameters
----------
data: xr.DataArray
The data to process.
Returns
-------
xr.DataArray
"""
selection = data.sel(self.selection)
return data - (selection.sum() / selection.size)
35 changes: 35 additions & 0 deletions glotaran/io/preprocessor/test/test_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest
import xarray as xr

from glotaran.io.preprocessor import PreProcessingPipeline


def test_correct_baseline_value():
pl = PreProcessingPipeline()
pl.correct_baseline_value(1)
data = xr.DataArray([[1]])
result = pl.apply(data)
assert result == data - 1


@pytest.mark.parametrize("indexer", (slice(0, 2), [0, 1]))
def test_correct_baseline_average(indexer: slice | list[int]):
pl = PreProcessingPipeline()
pl.correct_baseline_average({"dim_0": 0, "dim_1": indexer})
data = xr.DataArray([[1.1, 0.9]])
result = pl.apply(data)
assert (result == data - 1).all()


def test_to_from_dict():
pl = PreProcessingPipeline()
pl.correct_baseline_value(1)
pl.correct_baseline_average({"dim_1": slice(0, 2)})
pl_dict = pl.dict()
assert pl_dict == {
"actions": [
{"action": "baseline-value", "value": 1.0},
{"action": "baseline-average", "selection": {"dim_1": slice(0, 2)}},
]
}
assert PreProcessingPipeline.parse_obj(pl_dict) == pl
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ numpy==1.23.5
odfpy==1.4.1
openpyxl==3.1.1
pandas==1.5.3
pydantic==1.10.2
rich==13.3.1
ruamel.yaml==0.17.21
scipy==1.10.0
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ install_requires =
odfpy>=1.4.1
openpyxl>=3.0.10
pandas>=1.3.4
pydantic>=1.10.2
rich>=10.9.0
ruamel.yaml>=0.17.17
scipy>=1.7.2
Expand Down

0 comments on commit 9210faa

Please sign in to comment.