Skip to content

Commit

Permalink
Reduce code duplication for sorting functions (#1914)
Browse files Browse the repository at this point in the history
* Remove limitations from dpnp.take implementation

* Add more test to cover specail cases and increase code coverage

* Applied pre-commit hook

* Corrected test_over_index

* Update docsctrings with resolving typos

* Use dpnp.reshape() to change shape and create dpnp array from usm_ndarray result

* Remove limitations from dpnp.place implementation

* Update relating tests

* Roll back changed in dpnp.vander

* Remove data sync at the end of function

* Update indexing functions

* Add missing test scenario

* Updated docstring in put_along_axis() and take_along_axis() and rolled back data synchronization

* Remove data synchronization for dpnp.put()

* Remove data synchronization for dpnp.nonzero()

* Remove data synchronization for dpnp.indices()

* Remove data synchronization for dpnp.extract()

* Update indexing functions

* Update sorting functions

* Remove data sync from sort() and agrsort()

* Remove data sync from dpnp.sort_complex()

* Remove data sync in dpnp.get_result_array()
  • Loading branch information
antonwolfy authored Jul 10, 2024
1 parent f19e989 commit 001b63f
Showing 1 changed file with 28 additions and 38 deletions.
66 changes: 28 additions & 38 deletions dpnp/dpnp_iface_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,28 @@
__all__ = ["argsort", "partition", "sort", "sort_complex"]


def _wrap_sort_argsort(a, _sorting_fn, axis=-1, kind=None, order=None):
"""Wrap a sorting call from dpctl.tensor interface."""

if order is not None:
raise NotImplementedError(
"order keyword argument is only supported with its default value."
)
if kind is not None and kind != "stable":
raise NotImplementedError(
"kind keyword argument can only be None or 'stable'."
)

usm_a = dpnp.get_usm_ndarray(a)
if axis is None:
usm_a = dpt.reshape(usm_a, -1)
axis = -1

axis = normalize_axis_index(axis, ndim=usm_a.ndim)
usm_res = _sorting_fn(usm_a, axis=axis)
return dpnp_array._create_from_usm_ndarray(usm_res)


def argsort(a, axis=-1, kind=None, order=None):
"""
Returns the indices that would sort an array.
Expand Down Expand Up @@ -134,24 +156,7 @@ def argsort(a, axis=-1, kind=None, order=None):
"""

if order is not None:
raise NotImplementedError(
"order keyword argument is only supported with its default value."
)
if kind is not None and kind != "stable":
raise NotImplementedError(
"kind keyword argument can only be None or 'stable'."
)

dpnp.check_supported_arrays_type(a)
if axis is None:
a = a.flatten()
axis = -1

axis = normalize_axis_index(axis, ndim=a.ndim)
return dpnp_array._create_from_usm_ndarray(
dpt.argsort(dpnp.get_usm_ndarray(a), axis=axis)
)
return _wrap_sort_argsort(a, dpt.argsort, axis=axis, kind=kind, order=order)


def partition(x1, kth, axis=-1, kind="introselect", order=None):
Expand Down Expand Up @@ -246,24 +251,7 @@ def sort(a, axis=-1, kind=None, order=None):
"""

if order is not None:
raise NotImplementedError(
"order keyword argument is only supported with its default value."
)
if kind is not None and kind != "stable":
raise NotImplementedError(
"kind keyword argument can only be None or 'stable'."
)

dpnp.check_supported_arrays_type(a)
if axis is None:
a = a.flatten()
axis = -1

axis = normalize_axis_index(axis, ndim=a.ndim)
return dpnp_array._create_from_usm_ndarray(
dpt.sort(dpnp.get_usm_ndarray(a), axis=axis)
)
return _wrap_sort_argsort(a, dpt.sort, axis=axis, kind=kind, order=order)


def sort_complex(a):
Expand Down Expand Up @@ -298,6 +286,8 @@ def sort_complex(a):
b = dpnp.sort(a)
if not dpnp.issubsctype(b.dtype, dpnp.complexfloating):
if b.dtype.char in "bhBH":
return b.astype(dpnp.complex64)
return b.astype(map_dtype_to_device(dpnp.complex128, b.sycl_device))
b_dt = dpnp.complex64
else:
b_dt = map_dtype_to_device(dpnp.complex128, b.sycl_device)
return b.astype(b_dt)
return b

0 comments on commit 001b63f

Please sign in to comment.