From 63dcc3316a37a134b69299af030a248b58d2bd21 Mon Sep 17 00:00:00 2001 From: Alexander Kalistratov Date: Thu, 20 Jul 2023 00:54:54 +0200 Subject: [PATCH 1/2] Use specialized kernel for f-arrays and sum by axis=1. Add keepdims support --- dpnp/dpnp_iface_mathematical.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 2c45f05e2b9a..a7e6800c6e22 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -1828,12 +1828,18 @@ def sum( elif where is not True: pass else: - if axis == (0,) and len(x.shape) == 2 and not keepdims: + if len(x.shape) == 2 and ( + (axis == (0,) and x.flags.c_contiguous) + or (axis == (1,) and x.flags.f_contiguous) + ): from dpctl.tensor._reduction import _default_reduction_dtype from dpnp.backend.extensions.sycl_ext import _sycl_ext_impl - input = dpnp.get_usm_ndarray(x) + input = x + if axis == (1,): + input = input.T + input = dpnp.get_usm_ndarray(input) queue = input.sycl_queue out_dtype = ( @@ -1850,7 +1856,14 @@ def sum( if sum: sum(input, output, []).wait() - return dpnp_array._create_from_usm_ndarray(output) + result = dpnp_array._create_from_usm_ndarray(output) + + if keepdims: + result = result.reshape((1,) + output.shape) + if axis == (1,): + result = result.T + + return result y = dpt.sum( dpnp.get_usm_ndarray(x), axis=axis, dtype=dtype, keepdims=keepdims From 48348c95c1114d2988932cbc8054e437588644c0 Mon Sep 17 00:00:00 2001 From: Alexander Kalistratov Date: Thu, 20 Jul 2023 19:28:47 +0200 Subject: [PATCH 2/2] Review remarks --- dpnp/dpnp_iface_mathematical.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index a7e6800c6e22..36e9804618f8 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -1859,9 +1859,11 @@ def sum( result = dpnp_array._create_from_usm_ndarray(output) if 
keepdims: - result = result.reshape((1,) + output.shape) - if axis == (1,): - result = result.T + if axis == (0,): + res_sh = (1,) + output.shape + else: + res_sh = output.shape + (1,) + result = result.reshape(res_sh) return result