Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Support for Median #907

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,48 @@ def std(input, axis=None, ddof=0, keepdims=False, corrected=False):
return ret


def median(x: TensorLike, axis=None) -> TensorVariable:
    """
    Compute the median of the tensor `x` along the given axis(es).

    Parameters
    ----------
    x: TensorVariable
        The input tensor.
    axis: None or int or (list of int) (see `Sum`)
        Axis or axes over which the median is taken.
        None means all axes (like numpy).

    Returns
    -------
    TensorVariable
        The median values with the reduced axes removed. For an even
        number of reduced elements this is the mean of the two middle
        elements; for an odd number it is the middle element cast to the
        dtype of that mean, so both cases share one output dtype.
    """
    from pytensor.ifelse import ifelse

    x = as_tensor_variable(x)
    ndim = x.type.ndim
    reduced = (
        list(range(ndim))
        if axis is None
        else list(normalize_axis_tuple(axis, ndim))
    )
    kept = [dim for dim in range(ndim) if dim not in reduced]

    # Move the reduced axes to the end so they can be handled as one
    # trailing axis.
    flattened = x.transpose(*kept, *reduced)
    if len(reduced) > 1:
        # Review note: reshape only when several axes have to be merged;
        # a single reduced axis is already in place after the transpose.
        kept_shape = [x.shape[dim] for dim in kept]
        flattened = flattened.reshape((*kept_shape, -1))

    n_reduced = flattened.shape[-1]
    middle = n_reduced // 2

    # Sort along the trailing axis and pick the middle candidates.
    # Review note: plain indexing on the last axis is enough here,
    # take_along_axis is not needed.
    ordered = flattened.sort(axis=-1)
    upper_middle = ordered[..., middle]
    lower_middle = ordered[..., middle - 1]

    mean_of_middle = (upper_middle + lower_middle) / 2.0
    single_middle = upper_middle.astype(mean_of_middle.type.dtype)

    # The element count is symbolic, so the even/odd choice is made at
    # run time through `ifelse`.
    size_is_even = eq(mod(n_reduced, 2), 0)
    return ifelse(size_is_even, mean_of_middle, single_middle, name="median")


@scalar_elemwise(symbolname="scalar_maximum")
def maximum(x, y):
"""elemwise maximum. See max for the maximum in one tensor"""
Expand Down Expand Up @@ -3015,6 +3057,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
"sum",
"prod",
"mean",
"median",
"var",
"std",
"std",
Expand Down
31 changes: 31 additions & 0 deletions tests/tensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
max_and_argmax,
maximum,
mean,
median,
min,
minimum,
mod,
Expand Down Expand Up @@ -3735,3 +3736,33 @@ def test_nan_to_num(nan, posinf, neginf):
out,
np.nan_to_num(y, nan=nan, posinf=posinf, neginf=neginf),
)


@pytest.mark.parametrize(
    "ndim, axis",
    [
        (2, None),
        (2, 1),
        (2, (0, 1)),
        (3, None),
        (3, (1, 2)),
        (4, (1, 3, 0)),
    ],
)
def test_median(ndim, axis):
    """Compare `median` with `np.median` on odd- and even-sized inputs."""
    # Even shape is (2, 4, ..., 2 * ndim); subtracting one makes every
    # dimension odd, i.e. (1, 3, ..., 2 * ndim - 1).
    even_shape = np.arange(1, ndim + 1) * 2
    odd_shape = even_shape - 1

    even_sample = np.random.rand(*even_shape)
    odd_sample = np.random.rand(*odd_shape)

    # Compile once with unknown static shape so the same function serves
    # both inputs.
    x = tensor(dtype="float64", shape=(None,) * ndim)
    median_fn = function([x], median(x, axis=axis))

    for sample in (odd_sample, even_sample):
        actual = median_fn(sample)
        reference = np.median(sample, axis=axis)
        assert np.allclose(actual, reference)
Loading