Skip to content

Commit

Permalink
Update sorting functions
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy committed Jul 7, 2024
1 parent f4e3917 commit 5b7e867
Showing 1 changed file with 39 additions and 38 deletions.
77 changes: 39 additions & 38 deletions dpnp/dpnp_iface_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,28 @@
__all__ = ["argsort", "partition", "sort", "sort_complex"]


def _wrap_sort_argsort(a, _sorting_fn, axis=-1, kind=None, order=None):
"""Wrap a sorting call from dpctl.tensor interface."""

if order is not None:
raise NotImplementedError(
"order keyword argument is only supported with its default value."
)
if kind is not None and kind != "stable":
raise NotImplementedError(
"kind keyword argument can only be None or 'stable'."
)

usm_a = dpnp.get_usm_ndarray(a)
if axis is None:
usm_a = dpt.reshape(usm_a, -1)
axis = -1

axis = normalize_axis_index(axis, ndim=usm_a.ndim)
usm_res = _sorting_fn(usm_a, axis=axis)
return dpnp_array._create_from_usm_ndarray(usm_res)


def argsort(a, axis=-1, kind=None, order=None):
"""
Returns the indices that would sort an array.
Expand Down Expand Up @@ -134,24 +156,11 @@ def argsort(a, axis=-1, kind=None, order=None):
"""

if order is not None:
raise NotImplementedError(
"order keyword argument is only supported with its default value."
)
if kind is not None and kind != "stable":
raise NotImplementedError(
"kind keyword argument can only be None or 'stable'."
)

dpnp.check_supported_arrays_type(a)
if axis is None:
a = a.flatten()
axis = -1

axis = normalize_axis_index(axis, ndim=a.ndim)
return dpnp_array._create_from_usm_ndarray(
dpt.argsort(dpnp.get_usm_ndarray(a), axis=axis)
result = _wrap_sort_argsort(
a, dpt.argsort, axis=axis, kind=kind, order=order
)
dpnp.synchronize_array_data(result)
return result


def partition(x1, kth, axis=-1, kind="introselect", order=None):
Expand Down Expand Up @@ -246,24 +255,9 @@ def sort(a, axis=-1, kind=None, order=None):
"""

if order is not None:
raise NotImplementedError(
"order keyword argument is only supported with its default value."
)
if kind is not None and kind != "stable":
raise NotImplementedError(
"kind keyword argument can only be None or 'stable'."
)

dpnp.check_supported_arrays_type(a)
if axis is None:
a = a.flatten()
axis = -1

axis = normalize_axis_index(axis, ndim=a.ndim)
return dpnp_array._create_from_usm_ndarray(
dpt.sort(dpnp.get_usm_ndarray(a), axis=axis)
)
result = _wrap_sort_argsort(a, dpt.sort, axis=axis, kind=kind, order=order)
dpnp.synchronize_array_data(result)
return result


def sort_complex(a):
Expand Down Expand Up @@ -295,9 +289,16 @@ def sort_complex(a):
"""

b = dpnp.sort(a)
b = _wrap_sort_argsort(a, dpt.sort)

if not dpnp.issubsctype(b.dtype, dpnp.complexfloating):
if b.dtype.char in "bhBH":
return b.astype(dpnp.complex64)
return b.astype(map_dtype_to_device(dpnp.complex128, b.sycl_device))
b_dt = dpnp.complex64
else:
b_dt = map_dtype_to_device(dpnp.complex128, b.sycl_device)

usm_b = dpt.astype(b.get_array(), b_dt)
b = dpnp_array._create_from_usm_ndarray(usm_b)

dpnp.synchronize_array_data(b)
return b

0 comments on commit 5b7e867

Please sign in to comment.