Skip to content

Commit

Permalink
Reuse add(), multiply() and subtract() from dpctl (#1430)
Browse files Browse the repository at this point in the history
* Reuse add(), multiply() and subtract() from dpctl

* add in-place support
  • Loading branch information
antonwolfy committed Jun 15, 2023
1 parent ac59bc3 commit 66e5a9f
Show file tree
Hide file tree
Showing 10 changed files with 247 additions and 147 deletions.
13 changes: 0 additions & 13 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
cdef enum DPNPFuncName "DPNPFuncName":
DPNP_FN_ABSOLUTE
DPNP_FN_ABSOLUTE_EXT
DPNP_FN_ADD
DPNP_FN_ADD_EXT
DPNP_FN_ALL
DPNP_FN_ALL_EXT
DPNP_FN_ALLCLOSE
Expand Down Expand Up @@ -117,7 +115,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_DIAG_INDICES_EXT
DPNP_FN_DIAGONAL
DPNP_FN_DIAGONAL_EXT
DPNP_FN_DIVIDE
DPNP_FN_DOT
DPNP_FN_DOT_EXT
DPNP_FN_EDIFF1D
Expand Down Expand Up @@ -203,8 +200,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_MINIMUM_EXT
DPNP_FN_MODF
DPNP_FN_MODF_EXT
DPNP_FN_MULTIPLY
DPNP_FN_MULTIPLY_EXT
DPNP_FN_NANVAR
DPNP_FN_NANVAR_EXT
DPNP_FN_NEGATIVE
Expand Down Expand Up @@ -323,8 +318,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_SQUARE_EXT
DPNP_FN_STD
DPNP_FN_STD_EXT
DPNP_FN_SUBTRACT
DPNP_FN_SUBTRACT_EXT
DPNP_FN_SUM
DPNP_FN_SUM_EXT
DPNP_FN_SVD
Expand Down Expand Up @@ -523,8 +516,6 @@ cpdef dpnp_descriptor dpnp_copy(dpnp_descriptor x1)
"""
Mathematical functions
"""
cpdef dpnp_descriptor dpnp_add(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
dpnp_descriptor out=*, object where=*)
cpdef dpnp_descriptor dpnp_arctan2(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
dpnp_descriptor out=*, object where=*)
cpdef dpnp_descriptor dpnp_hypot(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
Expand All @@ -533,15 +524,11 @@ cpdef dpnp_descriptor dpnp_maximum(dpnp_descriptor x1_obj, dpnp_descriptor x2_ob
dpnp_descriptor out=*, object where=*)
cpdef dpnp_descriptor dpnp_minimum(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
dpnp_descriptor out=*, object where=*)
cpdef dpnp_descriptor dpnp_multiply(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
dpnp_descriptor out=*, object where=*)
cpdef dpnp_descriptor dpnp_negative(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_power(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
dpnp_descriptor out=*, object where=*)
cpdef dpnp_descriptor dpnp_remainder(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
dpnp_descriptor out=*, object where=*)
cpdef dpnp_descriptor dpnp_subtract(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
dpnp_descriptor out=*, object where=*)


"""
Expand Down
27 changes: 0 additions & 27 deletions dpnp/dpnp_algo/dpnp_algo_mathematical.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ and the rest of the library

__all__ += [
"dpnp_absolute",
"dpnp_add",
"dpnp_arctan2",
"dpnp_around",
"dpnp_ceil",
Expand All @@ -57,7 +56,6 @@ __all__ += [
"dpnp_maximum",
"dpnp_minimum",
"dpnp_modf",
"dpnp_multiply",
"dpnp_nancumprod",
"dpnp_nancumsum",
"dpnp_nanprod",
Expand All @@ -67,7 +65,6 @@ __all__ += [
"dpnp_prod",
"dpnp_remainder",
"dpnp_sign",
"dpnp_subtract",
"dpnp_sum",
"dpnp_trapz",
"dpnp_trunc"
Expand Down Expand Up @@ -123,14 +120,6 @@ cpdef utils.dpnp_descriptor dpnp_absolute(utils.dpnp_descriptor x1):
return result


cpdef utils.dpnp_descriptor dpnp_add(utils.dpnp_descriptor x1_obj,
utils.dpnp_descriptor x2_obj,
object dtype=None,
utils.dpnp_descriptor out=None,
object where=True):
return call_fptr_2in_1out_strides(DPNP_FN_ADD_EXT, x1_obj, x2_obj, dtype, out, where)


cpdef utils.dpnp_descriptor dpnp_arctan2(utils.dpnp_descriptor x1_obj,
utils.dpnp_descriptor x2_obj,
object dtype=None,
Expand Down Expand Up @@ -426,14 +415,6 @@ cpdef tuple dpnp_modf(utils.dpnp_descriptor x1):
return (result1.get_pyobj(), result2.get_pyobj())


cpdef utils.dpnp_descriptor dpnp_multiply(utils.dpnp_descriptor x1_obj,
utils.dpnp_descriptor x2_obj,
object dtype=None,
utils.dpnp_descriptor out=None,
object where=True):
return call_fptr_2in_1out_strides(DPNP_FN_MULTIPLY_EXT, x1_obj, x2_obj, dtype, out, where)


cpdef utils.dpnp_descriptor dpnp_nancumprod(utils.dpnp_descriptor x1):
cur_x1 = dpnp_copy(x1).get_pyobj()

Expand Down Expand Up @@ -586,14 +567,6 @@ cpdef utils.dpnp_descriptor dpnp_sign(utils.dpnp_descriptor x1):
return call_fptr_1in_1out_strides(DPNP_FN_SIGN_EXT, x1)


cpdef utils.dpnp_descriptor dpnp_subtract(utils.dpnp_descriptor x1_obj,
utils.dpnp_descriptor x2_obj,
object dtype=None,
utils.dpnp_descriptor out=None,
object where=True):
return call_fptr_2in_1out_strides(DPNP_FN_SUBTRACT_EXT, x1_obj, x2_obj, dtype, out, where)


cpdef utils.dpnp_descriptor dpnp_sum(utils.dpnp_descriptor x1,
object axis=None,
object dtype=None,
Expand Down
131 changes: 130 additions & 1 deletion dpnp/dpnp_algo/dpnp_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,55 @@


__all__ = [
"dpnp_divide"
"dpnp_add",
"dpnp_divide",
"dpnp_multiply",
"dpnp_subtract"
]


_add_docstring_ = """
add(x1, x2, out=None, order='K')
Calculates the sum for each element `x1_i` of the input array `x1` with
the respective element `x2_i` of the input array `x2`.
Args:
x1 (dpnp.ndarray):
First input array, expected to have numeric data type.
x2 (dpnp.ndarray):
Second input array, also expected to have numeric data type.
out ({None, dpnp.ndarray}, optional):
Output array to populate.
Array have the correct shape and the expected data type.
order ("C","F","A","K", None, optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
dpnp.ndarray:
an array containing the result of element-wise division. The data type
of the returned array is determined by the Type Promotion Rules.
"""

def dpnp_add(x1, x2, out=None, order='K'):
"""
Invokes add() from dpctl.tensor implementation for add() function.
TODO: add a pybind11 extension of add() from OneMKL VM where possible
and would be performance effective.
"""

# dpctl.tensor only works with usm_ndarray or scalar
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
out_usm = None if out is None else dpnp.get_usm_ndarray(out)

func = BinaryElementwiseFunc("add", ti._add_result_type, ti._add,
_add_docstring_, ti._add_inplace)
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
return dpnp_array._create_from_usm_ndarray(res_usm)


_divide_docstring_ = """
divide(x1, x2, out=None, order='K')
Expand Down Expand Up @@ -88,3 +133,87 @@ def _call_divide(src1, src2, dst, sycl_queue, depends=[]):
func = BinaryElementwiseFunc("divide", ti._divide_result_type, _call_divide, _divide_docstring_)
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
return dpnp_array._create_from_usm_ndarray(res_usm)


_multiply_docstring_ = """
multiply(x1, x2, out=None, order='K')
Calculates the product for each element `x1_i` of the input array `x1`
with the respective element `x2_i` of the input array `x2`.
Args:
x1 (dpnp.ndarray):
First input array, expected to have numeric data type.
x2 (dpnp.ndarray):
Second input array, also expected to have numeric data type.
out ({None, dpnp.ndarray}, optional):
Output array to populate.
Array have the correct shape and the expected data type.
order ("C","F","A","K", None, optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
dpnp.ndarray:
an array containing the result of element-wise division. The data type
of the returned array is determined by the Type Promotion Rules.
"""

def dpnp_multiply(x1, x2, out=None, order='K'):
"""
Invokes multiply() from dpctl.tensor implementation for multiply() function.
TODO: add a pybind11 extension of mul() from OneMKL VM where possible
and would be performance effective.
"""

# dpctl.tensor only works with usm_ndarray or scalar
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
out_usm = None if out is None else dpnp.get_usm_ndarray(out)

func = BinaryElementwiseFunc("multiply", ti._multiply_result_type, ti._multiply,
_multiply_docstring_, ti._multiply_inplace)
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
return dpnp_array._create_from_usm_ndarray(res_usm)


_subtract_docstring_ = """
subtract(x1, x2, out=None, order='K')
Calculates the difference bewteen each element `x1_i` of the input
array `x1` and the respective element `x2_i` of the input array `x2`.
Args:
x1 (dpnp.ndarray):
First input array, expected to have numeric data type.
x2 (dpnp.ndarray):
Second input array, also expected to have numeric data type.
out ({None, dpnp.ndarray}, optional):
Output array to populate.
Array have the correct shape and the expected data type.
order ("C","F","A","K", None, optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
dpnp.ndarray:
an array containing the result of element-wise division. The data type
of the returned array is determined by the Type Promotion Rules.
"""

def dpnp_subtract(x1, x2, out=None, order='K'):
"""
Invokes subtract() from dpctl.tensor implementation for subtract() function.
TODO: add a pybind11 extension of sub() from OneMKL VM where possible
and would be performance effective.
"""

# dpctl.tensor only works with usm_ndarray or scalar
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
out_usm = None if out is None else dpnp.get_usm_ndarray(out)

func = BinaryElementwiseFunc("subtract", ti._subtract_result_type, ti._subtract,
_subtract_docstring_, ti._subtract_inplace)
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
return dpnp_array._create_from_usm_ndarray(res_usm)
10 changes: 8 additions & 2 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,15 @@ def __irshift__(self, other):
dpnp.right_shift(self, other, out=self)
return self

# '__isub__',
def __isub__(self, other):
dpnp.subtract(self, other, out=self)
return self

# '__iter__',
# '__itruediv__',

def __itruediv__(self, other):
dpnp.true_divide(self, other, out=self)
return self

def __ixor__(self, other):
dpnp.bitwise_xor(self, other, out=self)
Expand Down
Loading

0 comments on commit 66e5a9f

Please sign in to comment.