Skip to content

Commit

Permalink
Merge pull request #20 from MothNik/fix/median_performance
Browse files Browse the repository at this point in the history
Fix/median performance
  • Loading branch information
deepak7376 authored Apr 9, 2024
2 parents 5bd08f9 + b805ff7 commit 08d2a01
Show file tree
Hide file tree
Showing 10 changed files with 83 additions and 27 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
run: |
python -m pip install --upgrade pip
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
- name: Run tests
run: |
export PYTHONPATH="${PYTHONPATH}:/robustbase/"
Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,13 @@ celerybeat.pid

# Environments
.env
.venv
.venv*
env/
venv/
ENV/
env.bak/
venv.bak/
.vscode/

# Spyder project settings
.spyderproject
Expand Down
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ This package provides functions to calculate the following robust statistical es

```python
from robustbase.stats import Qn

x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# With bias correction
Expand All @@ -37,11 +37,11 @@ res = Qn(x, finite_corr=False) # result: 4.43828

```python
from robustbase.stats import Sn

x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# With bias correction
res = Sn(x) # result: 3.5778
res = Sn(x) # result: 3.5778

# Without bias correction
res = Sn(x, finite_corr=False) # result: 3.5778
Expand Down Expand Up @@ -75,7 +75,7 @@ For local development setup:
```sh
git clone https://github.com/deepak7376/robustbase
cd robustbase
pip install -r requirements.txt
pip install -r requirements.txt -r requirements-dev.txt
```

## Recent Changes
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pytest>=8.1.1
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,3 @@ certifi>=2019.11.28
docutils>=0.15.2
numpy>=1.18.0
statistics>=1.0.3.5
pytest>=8.1.1
8 changes: 4 additions & 4 deletions robustbase/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .robustbase import Qn
from .robustbase import Sn
from .robustbase import iqr
from .robustbase import mad
from .robustbase import Qn # noqa: F401
from .robustbase import Sn # noqa: F401
from .robustbase import iqr # noqa: F401
from .robustbase import mad # noqa: F401


8 changes: 4 additions & 4 deletions robustbase/stats/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .Qn import Qn
from .Sn import Sn
from .iqr import iqr
from .mad import mad
from .iqr import iqr # noqa: F401
from .mad import mad # noqa: F401
from .Qn import Qn # noqa: F401
from .Sn import Sn # noqa: F401
4 changes: 2 additions & 2 deletions robustbase/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .mean import mean
from .median import median
from .mean import mean # noqa: F401
from .median import median # noqa: F401
29 changes: 18 additions & 11 deletions robustbase/utils/median.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,29 @@ def median(x, low=False, high=False):
Parameters:
- x: list or array-like, numeric vector of observations.
- low: bool, if True, return the low median for even sample size.
- low: bool, if True, return the low median for even sample size. If ``True``, ``high`` is ignored.
- high: bool, if True, return the high median for even sample size.
Returns:
- float: Median value.
"""
sorted_x = np.sort(x)
n = len(sorted_x)

n = len(x)
if n == 0:
raise ValueError("Empty list provided.")


# for odd sample size, all three medians are the same
if n % 2 == 1:
return sorted_x[n // 2]
elif low:
return sorted_x[n // 2 - 1]
elif high:
return sorted_x[n // 2]
else:
return (sorted_x[n // 2 - 1] + sorted_x[n // 2]) / 2
return np.median(a=x)

# for even sample sizes, the median is the average of the two middle values if
# neither the low nor high median is requested
if not (low or high):
return np.median(a=x)

# otherwise, either the low or the high median are found via introselect
median_idx = n // 2
if low:
median_idx -= 1

return np.partition(a=x, kth=median_idx)[median_idx]
47 changes: 47 additions & 0 deletions tests/test_median.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Optional, Tuple, Union

import numpy as np
import pytest

from robustbase.utils.median import median

X_EMPTY = []
X_ODD_N = [5.5, 3.2, -10.0, -2.1, 8.4]
X_EVEN_N = [5.5, 3.2, -10.0, -2.1, 8.4, 0.0]


@pytest.mark.parametrize("as_array", [False, True])
@pytest.mark.parametrize(
"comb",
[
(X_EMPTY, False, False, None),
(X_EMPTY, True, False, None),
(X_EMPTY, False, True, None),
(X_EMPTY, True, True, None),
(X_ODD_N, False, False, 3.2),
(X_ODD_N, True, False, 3.2),
(X_ODD_N, False, True, 3.2),
(X_ODD_N, True, True, 3.2),
(X_EVEN_N, False, False, 1.6),
(X_EVEN_N, True, False, 0.0),
(X_EVEN_N, False, True, 3.2),
(X_EVEN_N, True, True, 0.0),
],
)
def test_median(
comb: Tuple[Union[list, np.ndarray], bool, bool, Optional[float]],
as_array: bool,
):
x, low, high, expected = comb
if as_array:
x = np.array(x)

# for empty samples, an error should be raised
if expected is None:
with pytest.raises(ValueError):
median(x, low=low, high=high)

return

# otherwise, the expected median should be returned
assert median(x, low=low, high=high) == expected

0 comments on commit 08d2a01

Please sign in to comment.