From 1ea0418b021686800c4d8aa3c8c1e1cd14956b15 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Wed, 13 Nov 2024 16:07:23 -0800 Subject: [PATCH] fix CFD issue for dpnp.extract --- dpnp/dpnp_iface_indexing.py | 10 ++++------ tests/test_sycl_queue.py | 14 ++++++++++++++ tests/test_usm_type.py | 12 ++++++++++++ 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index f1d00842e5aa..fa4109acc683 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -51,10 +51,7 @@ dpnp_putmask, ) from .dpnp_array import dpnp_array -from .dpnp_utils import ( - call_origin, - get_usm_allocations, -) +from .dpnp_utils import call_origin, get_usm_allocations __all__ = [ "choose", @@ -535,11 +532,12 @@ def extract(condition, a): """ usm_a = dpnp.get_usm_ndarray(a) + usm_type, exec_q = get_usm_allocations([usm_a, condition]) usm_cond = dpnp.as_usm_ndarray( condition, dtype=dpnp.bool, - usm_type=usm_a.usm_type, - sycl_queue=usm_a.sycl_queue, + usm_type=usm_type, + sycl_queue=exec_q, ) if usm_cond.size != usm_a.size: diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 43dda9a3ed50..44e57196b5e6 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -406,6 +406,20 @@ def test_copy_operation(device): assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue) +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_extract(device): + x = dpnp.arange(3, device=device) + y = dpnp.array([True, False, True], device=device) + result = dpnp.extract(x, y) + + assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue) + assert_sycl_queue_equal(result.sycl_queue, y.sycl_queue) + + @pytest.mark.parametrize( "device", valid_devices, diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index f58c58605de0..e4fd16a5f12a 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -804,6 +804,18 @@ def test_concat_stack(func, data1, data2, usm_type_x, usm_type_y): assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) +@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types) +@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types) +def test_extract(usm_type_x, usm_type_y): + x = dp.arange(3, usm_type=usm_type_x) + y = dp.array([True, False, True], usm_type=usm_type_y) + z = dp.extract(y, x) + + assert x.usm_type == usm_type_x + assert y.usm_type == usm_type_y + assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) + + @pytest.mark.parametrize( "func,data1", [