diff --git a/CHANGELOG.md b/CHANGELOG.md index f428b033c7d6..a2d935fcc1d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ This is a bug-fix release. ### Fixed +* Resolved an issue with Compute Follows Data inconsistency in `dpnp.extract` function [#2172](https://github.com/IntelPython/dpnp/pull/2172) * Resolved a compilation error when building with DPC++ 2025.1 compiler [#2211](https://github.com/IntelPython/dpnp/pull/2211) diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index 3c94c7091ec9..a51c42bb4c40 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -51,10 +51,7 @@ dpnp_putmask, ) from .dpnp_array import dpnp_array -from .dpnp_utils import ( - call_origin, - get_usm_allocations, -) +from .dpnp_utils import call_origin, get_usm_allocations __all__ = [ "choose", @@ -585,11 +582,12 @@ def extract(condition, a): """ usm_a = dpnp.get_usm_ndarray(a) + usm_type, exec_q = get_usm_allocations([usm_a, condition]) usm_cond = dpnp.as_usm_ndarray( condition, dtype=dpnp.bool, - usm_type=usm_a.usm_type, - sycl_queue=usm_a.sycl_queue, + usm_type=usm_type, + sycl_queue=exec_q, ) if usm_cond.size != usm_a.size: diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index e1ae1d8e65dc..a44f156b81c0 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -406,6 +406,20 @@ def test_copy_operation(device): assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue) +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_extract(device): + x = dpnp.arange(3, device=device) + y = dpnp.array([True, False, True], device=device) + result = dpnp.extract(x, y) + + assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue) + assert_sycl_queue_equal(result.sycl_queue, y.sycl_queue) + + @pytest.mark.parametrize( "device", valid_devices, diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 592340d6c0db..2dcd1bbf98fa 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -769,6 +769,18 @@ def test_concat_stack(func, data1, data2, usm_type_x, usm_type_y): assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) +@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types) +@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types) +def test_extract(usm_type_x, usm_type_y): + x = dp.arange(3, usm_type=usm_type_x) + y = dp.array([True, False, True], usm_type=usm_type_y) + z = dp.extract(y, x) + + assert x.usm_type == usm_type_x + assert y.usm_type == usm_type_y + assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) + + @pytest.mark.parametrize( "func,data1", [