Skip to content

Commit

Permalink
Tweaks.
Browse files Browse the repository at this point in the history
  • Loading branch information
joernweissenborn committed Mar 1, 2023
1 parent 47fc53c commit 1ea0664
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
6 changes: 3 additions & 3 deletions glotaran/io/preprocessor/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ def correct_baseline_value(self, value: float) -> PreProcessingPipeline:

def correct_baseline_average(
self,
selection: dict[str, slice | list[int] | int] | None = None,
select: dict[str, slice | list[int] | int] | None = None,
exclude: dict[str, slice | list[int] | int] | None = None,
) -> PreProcessingPipeline:
"""Correct a dataset by subtracting the average over a part of the data.
Parameters
----------
selection: dict[str, slice | list[int] | int] | None
select: dict[str, slice | list[int] | int] | None
The selection to average as dictionary of dimension and indexer.
The indexer can be a slice, a list or an integer value.
exclude: dict[str, slice | list[int] | int] | None
Expand All @@ -84,5 +84,5 @@ def correct_baseline_average(
-------
PreProcessingPipeline
"""
self._push_action(CorrectBaselineAverage(exclude=exclude, selection=selection))
self._push_action(CorrectBaselineAverage(exclude=exclude, select=select))
return self
5 changes: 2 additions & 3 deletions glotaran/io/preprocessor/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class CorrectBaselineAverage(PreProcessor):
"""Corrects a dataset by subtracting the average over a part of the data."""

action: Literal["baseline-average"] = "baseline-average"
selection: dict[str, slice | list[int] | int] | None = None
select: dict[str, slice | list[int] | int] | None = None
exclude: dict[str, slice | list[int] | int] | None = None

def apply(self, data: xr.DataArray) -> xr.DataArray:
Expand All @@ -73,5 +73,4 @@ def apply(self, data: xr.DataArray) -> xr.DataArray:
-------
xr.DataArray
"""
selection = data.sel(self.selection or {}).drop_sel(self.exclude or {})
return data - (selection.sum() / selection.size)
return data - data.sel(self.select or {}).drop_sel(self.exclude or {}).mean()
6 changes: 3 additions & 3 deletions glotaran/io/preprocessor/test/test_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ def test_correct_baseline_value():
@pytest.mark.parametrize("indexer", (slice(0, 2), [0, 1]))
def test_correct_baseline_average(indexer: slice | list[int]):
pl = PreProcessingPipeline()
pl.correct_baseline_average(selection={"dim_0": 0, "dim_1": indexer})
pl.correct_baseline_average(select={"dim_0": 0, "dim_1": indexer})
data = xr.DataArray([[1.1, 0.9]])
result = pl.apply(data)
assert (result == data - 1).all()


def test_correct_baseline_average_exclude():
pl = PreProcessingPipeline()
pl.correct_baseline_average(exclude={"dim_1": 1})
pl.correct_baseline_average(select={"dim_0": 0}, exclude={"dim_1": 1})
data = xr.DataArray([[1.1, 0.9]])
result = pl.apply(data)
print(result)
Expand All @@ -38,7 +38,7 @@ def test_to_from_dict():
assert pl_dict == {
"actions": [
{"action": "baseline-value", "value": 1.0},
{"action": "baseline-average", "selection": {"dim_1": slice(0, 2)}, "exclude": None},
{"action": "baseline-average", "select": {"dim_1": slice(0, 2)}, "exclude": None},
]
}
assert PreProcessingPipeline.parse_obj(pl_dict) == pl

0 comments on commit 1ea0664

Please sign in to comment.