-
Notifications
You must be signed in to change notification settings - Fork 123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added Support for Median #907
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1566,6 +1566,48 @@ def std(input, axis=None, ddof=0, keepdims=False, corrected=False): | |
return ret | ||
|
||
|
||
def median(x: TensorLike, axis=None) -> TensorVariable: | ||
""" | ||
Computes the median along the given axis(es) of a tensor `input`. | ||
|
||
Parameters | ||
---------- | ||
x: TensorVariable | ||
The input tensor. | ||
axis: None or int or (list of int) (see `Sum`) | ||
Compute the median along this axis of the tensor. | ||
None means all axes (like numpy). | ||
""" | ||
from pytensor.ifelse import ifelse | ||
|
||
x = as_tensor_variable(x) | ||
x_ndim = x.type.ndim | ||
if axis is None: | ||
axis = list(range(x_ndim)) | ||
else: | ||
axis = list(normalize_axis_tuple(axis, x_ndim)) | ||
|
||
non_axis = [i for i in range(x_ndim) if i not in axis] | ||
non_axis_shape = [x.shape[i] for i in non_axis] | ||
|
||
# Put axis at the end and unravel them | ||
x_raveled = x.transpose(*non_axis, *axis) | ||
if len(axis) > 1: | ||
x_raveled = x_raveled.reshape((*non_axis_shape, -1)) | ||
raveled_size = x_raveled.shape[-1] | ||
k = raveled_size // 2 | ||
|
||
# Sort the input tensor along the specified axis and pick median value | ||
x_sorted = x_raveled.sort(axis=-1) | ||
k_values = x_sorted[..., k] | ||
km1_values = x_sorted[..., k - 1] | ||
Comment on lines
+1602
to
+1603
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I simplified the indexing, we can use simple indexing instead of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks great! I did not know we can use simple indexing so conveniently. |
||
|
||
even_median = (k_values + km1_values) / 2.0 | ||
odd_median = k_values.astype(even_median.type.dtype) | ||
even_k = eq(mod(raveled_size, 2), 0) | ||
return ifelse(even_k, even_median, odd_median, name="median") | ||
|
||
|
||
@scalar_elemwise(symbolname="scalar_maximum") | ||
def maximum(x, y): | ||
"""elemwise maximum. See max for the maximum in one tensor""" | ||
|
@@ -3015,6 +3057,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None): | |
"sum", | ||
"prod", | ||
"mean", | ||
"median", | ||
"var", | ||
"std", | ||
"std", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a small optimization to avoid reshaping when not needed