diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index f1d00842e5aa..fa4109acc683 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -51,10 +51,7 @@ dpnp_putmask, ) from .dpnp_array import dpnp_array -from .dpnp_utils import ( - call_origin, - get_usm_allocations, -) +from .dpnp_utils import call_origin, get_usm_allocations __all__ = [ "choose", @@ -535,11 +532,12 @@ def extract(condition, a): """ usm_a = dpnp.get_usm_ndarray(a) + usm_type, exec_q = get_usm_allocations([usm_a, condition]) usm_cond = dpnp.as_usm_ndarray( condition, dtype=dpnp.bool, - usm_type=usm_a.usm_type, - sycl_queue=usm_a.sycl_queue, + usm_type=usm_type, + sycl_queue=exec_q, ) if usm_cond.size != usm_a.size: diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index cddf3e9269b3..87e13dcb6586 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -406,6 +406,20 @@ def test_copy_operation(device): assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue) +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_extract(device): + x = dpnp.arange(3, device=device) + y = dpnp.array([True, False, True], device=device) + result = dpnp.extract(x, y) + + assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue) + assert_sycl_queue_equal(result.sycl_queue, y.sycl_queue) + + @pytest.mark.parametrize( "device", valid_devices, diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 1a889954d946..d6926ab16a4e 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -812,6 +812,18 @@ def test_concat_stack(func, data1, data2, usm_type_x, usm_type_y): assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) +@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types) +@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types) +def test_extract(usm_type_x, usm_type_y): + x = dp.arange(3, usm_type=usm_type_x) + y = dp.array([True, False, True], usm_type=usm_type_y) + z = dp.extract(y, x) + + assert x.usm_type == usm_type_x + assert y.usm_type == usm_type_y + assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) + + @pytest.mark.parametrize( "func,data1", [