Skip to content

Commit

Permalink
Added exclude to correct_baseline_average.
Browse files Browse the repository at this point in the history
  • Loading branch information
joernweissenborn committed Mar 3, 2023
1 parent 9f57298 commit c1dbca4
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
11 changes: 8 additions & 3 deletions glotaran/io/preprocessor/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,24 @@ def correct_baseline_value(self, value: float) -> PreProcessingPipeline:
return self

def correct_baseline_average(
self, selection: dict[str, slice | list[int] | int]
self,
selection: dict[str, slice | list[int] | int] | None = None,
exclude: dict[str, slice | list[int] | int] | None = None,
) -> PreProcessingPipeline:
"""Correct a dataset by subtracting the average over a part of the data.
Parameters
----------
selection: dict[str, slice | list[int] | int]
selection: dict[str, slice | list[int] | int] | None
The selection to average as dictionary of dimension and indexer.
The indexer can be a slice, a list or an integer value.
exclude: dict[str, slice | list[int] | int] | None
Excluded regions from the average as dictionary of dimension and indexer.
The indexer can be a slice, a list or an integer value.
Returns
-------
PreProcessingPipeline
"""
self._push_action(CorrectBaselineAverage(selection=selection))
self._push_action(CorrectBaselineAverage(exclude=exclude, selection=selection))
return self
5 changes: 3 additions & 2 deletions glotaran/io/preprocessor/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ class CorrectBaselineAverage(PreProcessor):
"""Corrects a dataset by subtracting the average over a part of the data."""

action: Literal["baseline-average"] = "baseline-average"
selection: dict[str, slice | list[int] | int]
selection: dict[str, slice | list[int] | int] | None = None
exclude: dict[str, slice | list[int] | int] | None = None

def apply(self, data: xr.DataArray) -> xr.DataArray:
"""Apply the pre-processor.
Expand All @@ -72,5 +73,5 @@ def apply(self, data: xr.DataArray) -> xr.DataArray:
-------
xr.DataArray
"""
selection = data.sel(self.selection)
selection = data.sel(self.selection or {}).drop_sel(self.exclude or {})
return data - (selection.sum() / selection.size)
13 changes: 11 additions & 2 deletions glotaran/io/preprocessor/test/test_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,21 @@ def test_correct_baseline_value():
@pytest.mark.parametrize("indexer", (slice(0, 2), [0, 1]))
def test_correct_baseline_average(indexer: slice | list[int]):
pl = PreProcessingPipeline()
pl.correct_baseline_average({"dim_0": 0, "dim_1": indexer})
pl.correct_baseline_average(selection={"dim_0": 0, "dim_1": indexer})
data = xr.DataArray([[1.1, 0.9]])
result = pl.apply(data)
assert (result == data - 1).all()


def test_correct_baseline_average_exclude():
pl = PreProcessingPipeline()
pl.correct_baseline_average(exclude={"dim_1": 1})
data = xr.DataArray([[1.1, 0.9]])
result = pl.apply(data)
print(result)
assert (result == data - 1.1).all()


def test_to_from_dict():
pl = PreProcessingPipeline()
pl.correct_baseline_value(1)
Expand All @@ -29,7 +38,7 @@ def test_to_from_dict():
assert pl_dict == {
"actions": [
{"action": "baseline-value", "value": 1.0},
{"action": "baseline-average", "selection": {"dim_1": slice(0, 2)}},
{"action": "baseline-average", "selection": {"dim_1": slice(0, 2)}, "exclude": None},
]
}
assert PreProcessingPipeline.parse_obj(pl_dict) == pl

0 comments on commit c1dbca4

Please sign in to comment.