From 47a1ec4221f869ba0c0cc9791648034d555f5e99 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Wed, 16 Aug 2023 11:22:38 -0500 Subject: [PATCH 1/3] create Unary/BinaryElementwisefunc during module import --- dpnp/dpnp_algo/dpnp_elementwise_common.py | 537 +++++++++++++--------- 1 file changed, 310 insertions(+), 227 deletions(-) diff --git a/dpnp/dpnp_algo/dpnp_elementwise_common.py b/dpnp/dpnp_algo/dpnp_elementwise_common.py index 14c5f6058d8b..4cf68ff55f0e 100644 --- a/dpnp/dpnp_algo/dpnp_elementwise_common.py +++ b/dpnp/dpnp_algo/dpnp_elementwise_common.py @@ -215,6 +215,14 @@ def dpnp_add(x1, x2, out=None, order="K"): """ +bitwise_and_func = BinaryElementwiseFunc( + "bitwise_and", + ti._bitwise_and_result_type, + ti._bitwise_and, + _bitwise_and_docstring_, +) + + def dpnp_bitwise_and(x1, x2, out=None, order="K"): """Invokes bitwise_and() from dpctl.tensor implementation for bitwise_and() function.""" @@ -223,13 +231,9 @@ def dpnp_bitwise_and(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "bitwise_and", - ti._bitwise_and_result_type, - ti._bitwise_and, - _bitwise_and_docstring_, + res_usm = bitwise_and_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -259,6 +263,14 @@ def dpnp_bitwise_and(x1, x2, out=None, order="K"): """ +bitwise_or_func = BinaryElementwiseFunc( + "bitwise_or", + ti._bitwise_or_result_type, + ti._bitwise_or, + _bitwise_or_docstring_, +) + + def dpnp_bitwise_or(x1, x2, out=None, order="K"): """Invokes bitwise_or() from dpctl.tensor implementation for bitwise_or() function.""" @@ -267,13 +279,9 @@ def dpnp_bitwise_or(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "bitwise_or", - ti._bitwise_or_result_type, - ti._bitwise_or, - _bitwise_or_docstring_, + res_usm = bitwise_or_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -303,6 +311,14 @@ def dpnp_bitwise_or(x1, x2, out=None, order="K"): """ +bitwise_xor_func = BinaryElementwiseFunc( + "bitwise_xor", + ti._bitwise_xor_result_type, + ti._bitwise_xor, + _bitwise_xor_docstring_, +) + + def dpnp_bitwise_xor(x1, x2, out=None, order="K"): """Invokes bitwise_xor() from dpctl.tensor implementation for bitwise_xor() function.""" @@ -311,13 +327,9 @@ def dpnp_bitwise_xor(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "bitwise_xor", - ti._bitwise_xor_result_type, - ti._bitwise_xor, - _bitwise_xor_docstring_, + res_usm = bitwise_xor_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -393,33 +405,35 @@ def dpnp_ceil(x, out=None, order="K"): """ +def _call_cos(src, dst, sycl_queue, depends=None): + """A callback to register in UnaryElementwiseFunc class of dpctl.tensor""" + + if depends is None: + depends = [] + + if vmi._mkl_cos_to_call(sycl_queue, src, dst): + # call pybind11 extension for cos() function from OneMKL VM + return vmi._cos(sycl_queue, src, dst, depends) + return ti._cos(src, dst, sycl_queue, depends) + + +cos_func = UnaryElementwiseFunc( + "cos", ti._cos_result_type, _call_cos, _cos_docstring +) + + def dpnp_cos(x, out=None, order="K"): """ Invokes cos() function from pybind11 extension of OneMKL VM if possible. Otherwise fully relies on dpctl.tensor implementation for cos() function. - """ - def _call_cos(src, dst, sycl_queue, depends=None): - """A callback to register in UnaryElementwiseFunc class of dpctl.tensor""" - - if depends is None: - depends = [] - - if vmi._mkl_cos_to_call(sycl_queue, src, dst): - # call pybind11 extension for cos() function from OneMKL VM - return vmi._cos(sycl_queue, src, dst, depends) - return ti._cos(src, dst, sycl_queue, depends) - # dpctl.tensor only works with usm_ndarray x1_usm = dpnp.get_usm_ndarray(x) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = UnaryElementwiseFunc( - "cos", ti._cos_result_type, _call_cos, _cos_docstring - ) - res_usm = func(x1_usm, out=out_usm, order=order) + res_usm = cos_func(x1_usm, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -447,57 +461,62 @@ def _call_cos(src, dst, sycl_queue, depends=None): """ -def dpnp_divide(x1, x2, out=None, order="K"): - """ - Invokes div() function from pybind11 extension of OneMKL VM if possible. +def _call_divide(src1, src2, dst, sycl_queue, depends=None): + """A callback to register in BinaryElementwiseFunc class of dpctl.tensor""" - Otherwise fully relies on dpctl.tensor implementation for divide() function. + if depends is None: + depends = [] - """ + if vmi._mkl_div_to_call(sycl_queue, src1, src2, dst): + # call pybind11 extension for div() function from OneMKL VM + return vmi._div(sycl_queue, src1, src2, dst, depends) + return ti._divide(src1, src2, dst, sycl_queue, depends) - def _call_divide(src1, src2, dst, sycl_queue, depends=None): - """A callback to register in BinaryElementwiseFunc class of dpctl.tensor""" - if depends is None: - depends = [] +def _call_divide_inplace(lhs, rhs, sycl_queue, depends=None): + """In place workaround until dpctl.tensor provides the functionality.""" - if vmi._mkl_div_to_call(sycl_queue, src1, src2, dst): - # call pybind11 extension for div() function from OneMKL VM - return vmi._div(sycl_queue, src1, src2, dst, depends) - return ti._divide(src1, src2, dst, sycl_queue, depends) + if depends is None: + depends = [] - def _call_divide_inplace(lhs, rhs, sycl_queue, depends=None): - """In place workaround until dpctl.tensor provides the functionality.""" + # allocate temporary memory for out array + out = dpt.empty_like(lhs, dtype=dpnp.result_type(lhs.dtype, rhs.dtype)) - if depends is None: - depends = [] + # call a general callback + div_ht_, div_ev_ = _call_divide(lhs, rhs, out, sycl_queue, depends) - # allocate temporary memory for out array - out = dpt.empty_like(lhs, dtype=dpnp.result_type(lhs.dtype, rhs.dtype)) + # store the result into left input array and return events + cp_ht_, cp_ev_ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=lhs, sycl_queue=sycl_queue, depends=[div_ev_] + ) + dpctl.SyclEvent.wait_for([div_ht_]) + return (cp_ht_, cp_ev_) - # call a general callback - div_ht_, div_ev_ = _call_divide(lhs, rhs, out, sycl_queue, depends) - # store the result into left input array and return events - cp_ht_, cp_ev_ = ti._copy_usm_ndarray_into_usm_ndarray( - src=out, dst=lhs, sycl_queue=sycl_queue, depends=[div_ev_] - ) - dpctl.SyclEvent.wait_for([div_ht_]) - return (cp_ht_, cp_ev_) +divide_func = BinaryElementwiseFunc( + "divide", + ti._divide_result_type, + _call_divide, + _divide_docstring_, + _call_divide_inplace, +) + + +def dpnp_divide(x1, x2, out=None, order="K"): + """ + Invokes div() function from pybind11 extension of OneMKL VM if possible. + + Otherwise fully relies on dpctl.tensor implementation for divide() function. + """ # dpctl.tensor only works with usm_ndarray or scalar x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1) x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "divide", - ti._divide_result_type, - _call_divide, - _divide_docstring_, - _call_divide_inplace, + res_usm = divide_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -525,6 +544,11 @@ def _call_divide_inplace(lhs, rhs, sycl_queue, depends=None): """ +equal_func = BinaryElementwiseFunc( + "equal", ti._equal_result_type, ti._equal, _equal_docstring_ +) + + def dpnp_equal(x1, x2, out=None, order="K"): """Invokes equal() from dpctl.tensor implementation for equal() function.""" @@ -533,10 +557,9 @@ def dpnp_equal(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "equal", ti._equal_result_type, ti._equal, _equal_docstring_ + res_usm = equal_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -617,6 +640,14 @@ def dpnp_floor(x, out=None, order="K"): """ +floor_divide_func = BinaryElementwiseFunc( + "floor_divide", + ti._floor_divide_result_type, + ti._floor_divide, + _floor_divide_docstring_, +) + + def dpnp_floor_divide(x1, x2, out=None, order="K"): """Invokes floor_divide() from dpctl.tensor implementation for floor_divide() function.""" @@ -625,13 +656,9 @@ def dpnp_floor_divide(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "floor_divide", - ti._floor_divide_result_type, - ti._floor_divide, - _floor_divide_docstring_, + res_usm = floor_divide_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -659,6 +686,11 @@ def dpnp_floor_divide(x1, x2, out=None, order="K"): """ +greater_func = BinaryElementwiseFunc( + "greater", ti._greater_result_type, ti._greater, _greater_docstring_ +) + + def dpnp_greater(x1, x2, out=None, order="K"): """Invokes greater() from dpctl.tensor implementation for greater() function.""" @@ -667,10 +699,9 @@ def dpnp_greater(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "greater", ti._greater_result_type, ti._greater, _greater_docstring_ + res_usm = greater_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -698,6 +729,14 @@ def dpnp_greater(x1, x2, out=None, order="K"): """ +greater_equal_func = BinaryElementwiseFunc( + "greater_equal", + ti._greater_equal_result_type, + ti._greater_equal, + _greater_equal_docstring_, +) + + def dpnp_greater_equal(x1, x2, out=None, order="K"): """Invokes greater_equal() from dpctl.tensor implementation for greater_equal() function.""" @@ -706,13 +745,9 @@ def dpnp_greater_equal(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "greater_equal", - ti._greater_equal_result_type, - ti._greater_equal, - _greater_equal_docstring_, + res_usm = greater_equal_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -738,6 +773,14 @@ def dpnp_greater_equal(x1, x2, out=None, order="K"): """ +invert_func = UnaryElementwiseFunc( + "invert", + ti._bitwise_invert_result_type, + ti._bitwise_invert, + _invert_docstring, +) + + def dpnp_invert(x, out=None, order="K"): """Invokes bitwise_invert() from dpctl.tensor implementation for invert() function.""" @@ -745,13 +788,7 @@ def dpnp_invert(x, out=None, order="K"): x_usm = dpnp.get_usm_ndarray(x) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = UnaryElementwiseFunc( - "invert", - ti._bitwise_invert_result_type, - ti._bitwise_invert, - _invert_docstring, - ) - res_usm = func(x_usm, out=out_usm, order=order) + res_usm = invert_func(x_usm, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -777,6 +814,11 @@ def dpnp_invert(x, out=None, order="K"): """ +isfinite_func = UnaryElementwiseFunc( + "isfinite", ti._isfinite_result_type, ti._isfinite, _isfinite_docstring +) + + def dpnp_isfinite(x, out=None, order="K"): """Invokes isfinite() from dpctl.tensor implementation for isfinite() function.""" @@ -784,10 +826,7 @@ def dpnp_isfinite(x, out=None, order="K"): x1_usm = dpnp.get_usm_ndarray(x) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = UnaryElementwiseFunc( - "isfinite", ti._isfinite_result_type, ti._isfinite, _isfinite_docstring - ) - res_usm = func(x1_usm, out=out_usm, order=order) + res_usm = isfinite_func(x1_usm, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -812,6 +851,11 @@ def dpnp_isfinite(x, out=None, order="K"): """ +isinf_func = UnaryElementwiseFunc( + "isinf", ti._isinf_result_type, ti._isinf, _isinf_docstring +) + + def dpnp_isinf(x, out=None, order="K"): """Invokes isinf() from dpctl.tensor implementation for isinf() function.""" @@ -819,10 +863,7 @@ def dpnp_isinf(x, out=None, order="K"): x1_usm = dpnp.get_usm_ndarray(x) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = UnaryElementwiseFunc( - "isinf", ti._isinf_result_type, ti._isinf, _isinf_docstring - ) - res_usm = func(x1_usm, out=out_usm, order=order) + res_usm = isinf_func(x1_usm, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -847,6 +888,11 @@ def dpnp_isinf(x, out=None, order="K"): """ +isnan_func = UnaryElementwiseFunc( + "isnan", ti._isnan_result_type, ti._isnan, _isnan_docstring +) + + def dpnp_isnan(x, out=None, order="K"): """Invokes isnan() from dpctl.tensor implementation for isnan() function.""" @@ -854,10 +900,7 @@ def dpnp_isnan(x, out=None, order="K"): x1_usm = dpnp.get_usm_ndarray(x) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = UnaryElementwiseFunc( - "isnan", ti._isnan_result_type, ti._isnan, _isnan_docstring - ) - res_usm = func(x1_usm, out=out_usm, order=order) + res_usm = isnan_func(x1_usm, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -887,6 +930,14 @@ def dpnp_isnan(x, out=None, order="K"): """ +leftt_shift_func = BinaryElementwiseFunc( + "bitwise_leftt_shift", + ti._bitwise_left_shift_result_type, + ti._bitwise_left_shift, + _left_shift_docstring_, +) + + def dpnp_left_shift(x1, x2, out=None, order="K"): """Invokes bitwise_left_shift() from dpctl.tensor implementation for left_shift() function.""" @@ -895,13 +946,9 @@ def dpnp_left_shift(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "bitwise_leftt_shift", - ti._bitwise_left_shift_result_type, - ti._bitwise_left_shift, - _left_shift_docstring_, + res_usm = leftt_shift_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -929,6 +976,11 @@ def dpnp_left_shift(x1, x2, out=None, order="K"): """ +less_func = BinaryElementwiseFunc( + "less", ti._less_result_type, ti._less, _less_docstring_ +) + + def dpnp_less(x1, x2, out=None, order="K"): """Invokes less() from dpctl.tensor implementation for less() function.""" @@ -937,10 +989,9 @@ def dpnp_less(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "less", ti._less_result_type, ti._less, _less_docstring_ + res_usm = less_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -968,6 +1019,14 @@ def dpnp_less(x1, x2, out=None, order="K"): """ +less_equal_func = BinaryElementwiseFunc( + "less_equal", + ti._less_equal_result_type, + ti._less_equal, + _less_equal_docstring_, +) + + def dpnp_less_equal(x1, x2, out=None, order="K"): """Invokes less_equal() from dpctl.tensor implementation for less_equal() function.""" @@ -976,13 +1035,9 @@ def dpnp_less_equal(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "less_equal", - ti._less_equal_result_type, - ti._less_equal, - _less_equal_docstring_, + res_usm = less_equal_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -1004,33 +1059,35 @@ def dpnp_less_equal(x1, x2, out=None, order="K"): """ +def _call_log(src, dst, sycl_queue, depends=None): + """A callback to register in UnaryElementwiseFunc class of dpctl.tensor""" + + if depends is None: + depends = [] + + if vmi._mkl_ln_to_call(sycl_queue, src, dst): + # call pybind11 extension for ln() function from OneMKL VM + return vmi._ln(sycl_queue, src, dst, depends) + return ti._log(src, dst, sycl_queue, depends) + + +log_func = UnaryElementwiseFunc( + "log", ti._log_result_type, _call_log, _log_docstring +) + + def dpnp_log(x, out=None, order="K"): """ Invokes log() function from pybind11 extension of OneMKL VM if possible. Otherwise fully relies on dpctl.tensor implementation for log() function. - """ - def _call_log(src, dst, sycl_queue, depends=None): - """A callback to register in UnaryElementwiseFunc class of dpctl.tensor""" - - if depends is None: - depends = [] - - if vmi._mkl_ln_to_call(sycl_queue, src, dst): - # call pybind11 extension for ln() function from OneMKL VM - return vmi._ln(sycl_queue, src, dst, depends) - return ti._log(src, dst, sycl_queue, depends) - # dpctl.tensor only works with usm_ndarray x1_usm = dpnp.get_usm_ndarray(x) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = UnaryElementwiseFunc( - "log", ti._log_result_type, _call_log, _log_docstring - ) - res_usm = func(x1_usm, out=out_usm, order=order) + res_usm = log_func(x1_usm, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -1057,6 +1114,14 @@ def _call_log(src, dst, sycl_queue, depends=None): """ +logical_and_func = BinaryElementwiseFunc( + "logical_and", + ti._logical_and_result_type, + ti._logical_and, + _logical_and_docstring_, +) + + def dpnp_logical_and(x1, x2, out=None, order="K"): """Invokes logical_and() from dpctl.tensor implementation for logical_and() function.""" @@ -1065,13 +1130,9 @@ def dpnp_logical_and(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "logical_and", - ti._logical_and_result_type, - ti._logical_and, - _logical_and_docstring_, + res_usm = logical_and_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -1093,6 +1154,14 @@ def dpnp_logical_and(x1, x2, out=None, order="K"): """ +logical_not_func = UnaryElementwiseFunc( + "logical_not", + ti._logical_not_result_type, + ti._logical_not, + _logical_not_docstring_, +) + + def dpnp_logical_not(x, out=None, order="K"): """Invokes logical_not() from dpctl.tensor implementation for logical_not() function.""" @@ -1100,13 +1169,7 @@ def dpnp_logical_not(x, out=None, order="K"): x_usm = dpnp.get_usm_ndarray(x) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = UnaryElementwiseFunc( - "logical_not", - ti._logical_not_result_type, - ti._logical_not, - _logical_not_docstring_, - ) - res_usm = func(x_usm, out=out_usm, order=order) + res_usm = logical_not_func(x_usm, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -1133,6 +1196,14 @@ def dpnp_logical_not(x, out=None, order="K"): """ +logical_or_func = BinaryElementwiseFunc( + "logical_or", + ti._logical_or_result_type, + ti._logical_or, + _logical_or_docstring_, +) + + def dpnp_logical_or(x1, x2, out=None, order="K"): """Invokes logical_or() from dpctl.tensor implementation for logical_or() function.""" @@ -1141,13 +1212,9 @@ def dpnp_logical_or(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "logical_or", - ti._logical_or_result_type, - ti._logical_or, - _logical_or_docstring_, + res_usm = logical_or_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -1173,6 +1240,13 @@ def dpnp_logical_or(x1, x2, out=None, order="K"): An array containing the element-wise logical XOR results. """ +logical_xor_func = BinaryElementwiseFunc( + "logical_xor", + ti._logical_xor_result_type, + ti._logical_xor, + _logical_xor_docstring_, +) + def dpnp_logical_xor(x1, x2, out=None, order="K"): """Invokes logical_xor() from dpctl.tensor implementation for logical_xor() function.""" @@ -1182,13 +1256,9 @@ def dpnp_logical_xor(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "logical_xor", - ti._logical_xor_result_type, - ti._logical_xor, - _logical_xor_docstring_, + res_usm = logical_xor_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -1264,6 +1334,13 @@ def dpnp_multiply(x1, x2, out=None, order="K"): The data type of the returned array is determined by the Type Promotion Rules. """ +not_equal_func = BinaryElementwiseFunc( + "not_equal", + ti._not_equal_result_type, + ti._not_equal, + _not_equal_docstring_, +) + def dpnp_not_equal(x1, x2, out=None, order="K"): """Invokes not_equal() from dpctl.tensor implementation for not_equal() function.""" @@ -1273,13 +1350,9 @@ def dpnp_not_equal(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "not_equal", - ti._not_equal_result_type, - ti._not_equal, - _not_equal_docstring_, + res_usm = not_equal_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -1351,6 +1424,14 @@ def dpnp_remainder(x1, x2, out=None, order="K"): """ +right_shift_func = BinaryElementwiseFunc( + "bitwise_right_shift", + ti._bitwise_right_shift_result_type, + ti._bitwise_right_shift, + _right_shift_docstring_, +) + + def dpnp_right_shift(x1, x2, out=None, order="K"): """Invokes bitwise_right_shift() from dpctl.tensor implementation for right_shift() function.""" @@ -1359,13 +1440,9 @@ def dpnp_right_shift(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "bitwise_right_shift", - ti._bitwise_right_shift_result_type, - ti._bitwise_right_shift, - _right_shift_docstring_, + res_usm = right_shift_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -1388,33 +1465,35 @@ def dpnp_right_shift(x1, x2, out=None, order="K"): """ +def _call_sin(src, dst, sycl_queue, depends=None): + """A callback to register in UnaryElementwiseFunc class of dpctl.tensor""" + + if depends is None: + depends = [] + + if vmi._mkl_sin_to_call(sycl_queue, src, dst): + # call pybind11 extension for sin() function from OneMKL VM + return vmi._sin(sycl_queue, src, dst, depends) + return ti._sin(src, dst, sycl_queue, depends) + + +sin_func = UnaryElementwiseFunc( + "sin", ti._sin_result_type, _call_sin, _sin_docstring +) + + def dpnp_sin(x, out=None, order="K"): """ Invokes sin() function from pybind11 extension of OneMKL VM if possible. Otherwise fully relies on dpctl.tensor implementation for sin() function. - """ - def _call_sin(src, dst, sycl_queue, depends=None): - """A callback to register in UnaryElementwiseFunc class of dpctl.tensor""" - - if depends is None: - depends = [] - - if vmi._mkl_sin_to_call(sycl_queue, src, dst): - # call pybind11 extension for sin() function from OneMKL VM - return vmi._sin(sycl_queue, src, dst, depends) - return ti._sin(src, dst, sycl_queue, depends) - # dpctl.tensor only works with usm_ndarray x1_usm = dpnp.get_usm_ndarray(x) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = UnaryElementwiseFunc( - "sin", ti._sin_result_type, _call_sin, _sin_docstring - ) - res_usm = func(x1_usm, out=out_usm, order=order) + res_usm = sin_func(x1_usm, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -1436,36 +1515,38 @@ def _call_sin(src, dst, sycl_queue, depends=None): """ +def _call_sqrt(src, dst, sycl_queue, depends=None): + """A callback to register in UnaryElementwiseFunc class of dpctl.tensor""" + + if depends is None: + depends = [] + + if vmi._mkl_sqrt_to_call(sycl_queue, src, dst): + # call pybind11 extension for sqrt() function from OneMKL VM + return vmi._sqrt(sycl_queue, src, dst, depends) + return ti._sqrt(src, dst, sycl_queue, depends) + + +sqrt_func = UnaryElementwiseFunc( + "sqrt", + ti._sqrt_result_type, + _call_sqrt, + _sqrt_docstring_, +) + + def dpnp_sqrt(x, out=None, order="K"): """ Invokes sqrt() function from pybind11 extension of OneMKL VM if possible. Otherwise fully relies on dpctl.tensor implementation for sqrt() function. - """ - def _call_sqrt(src, dst, sycl_queue, depends=None): - """A callback to register in UnaryElementwiseFunc class of dpctl.tensor""" - - if depends is None: - depends = [] - - if vmi._mkl_sqrt_to_call(sycl_queue, src, dst): - # call pybind11 extension for sqrt() function from OneMKL VM - return vmi._sqrt(sycl_queue, src, dst, depends) - return ti._sqrt(src, dst, sycl_queue, depends) - # dpctl.tensor only works with usm_ndarray or scalar x_usm = dpnp.get_usm_ndarray(x) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = UnaryElementwiseFunc( - "sqrt", - ti._sqrt_result_type, - _call_sqrt, - _sqrt_docstring_, - ) - res_usm = func(x_usm, out=out_usm, order=order) + res_usm = sqrt_func(x_usm, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -1487,36 +1568,38 @@ def _call_sqrt(src, dst, sycl_queue, depends=None): """ +def _call_square(src, dst, sycl_queue, depends=None): + """A callback to register in UnaryElementwiseFunc class of dpctl.tensor""" + + if depends is None: + depends = [] + + if vmi._mkl_sqr_to_call(sycl_queue, src, dst): + # call pybind11 extension for sqr() function from OneMKL VM + return vmi._sqr(sycl_queue, src, dst, depends) + return ti._square(src, dst, sycl_queue, depends) + + +square_func = UnaryElementwiseFunc( + "square", + ti._square_result_type, + _call_square, + _square_docstring_, +) + + def dpnp_square(x, out=None, order="K"): """ Invokes sqr() function from pybind11 extension of OneMKL VM if possible. Otherwise fully relies on dpctl.tensor implementation for square() function. - """ - def _call_square(src, dst, sycl_queue, depends=None): - """A callback to register in UnaryElementwiseFunc class of dpctl.tensor""" - - if depends is None: - depends = [] - - if vmi._mkl_sqr_to_call(sycl_queue, src, dst): - # call pybind11 extension for sqr() function from OneMKL VM - return vmi._sqr(sycl_queue, src, dst, depends) - return ti._square(src, dst, sycl_queue, depends) - # dpctl.tensor only works with usm_ndarray or scalar x_usm = dpnp.get_usm_ndarray(x) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = UnaryElementwiseFunc( - "square", - ti._square_result_type, - _call_square, - _square_docstring_, - ) - res_usm = func(x_usm, out=out_usm, order=order) + res_usm = square_func(x_usm, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) From afb84af724f65a99197c90591f652aac8179f173 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Wed, 16 Aug 2023 12:07:43 -0500 Subject: [PATCH 2/3] implement mkl version of add, subtract and multiply --- dpnp/backend/extensions/vm/add.hpp | 81 +++++++++++++++ dpnp/backend/extensions/vm/mul.hpp | 81 +++++++++++++++ dpnp/backend/extensions/vm/sub.hpp | 81 +++++++++++++++ dpnp/backend/extensions/vm/types_matrix.hpp | 75 ++++++++++++++ dpnp/backend/extensions/vm/vm_py.cpp | 105 ++++++++++++++++++- dpnp/dpnp_algo/dpnp_elementwise_common.py | 106 ++++++++++++++------ tests/test_mathematical.py | 2 - tests/test_usm_type.py | 2 +- 8 files changed, 497 insertions(+), 36 deletions(-) create mode 100644 dpnp/backend/extensions/vm/add.hpp create mode 100644 dpnp/backend/extensions/vm/mul.hpp create mode 100644 dpnp/backend/extensions/vm/sub.hpp diff --git a/dpnp/backend/extensions/vm/add.hpp b/dpnp/backend/extensions/vm/add.hpp new file mode 100644 index 000000000000..d80755f0ec0b --- /dev/null +++ b/dpnp/backend/extensions/vm/add.hpp @@ -0,0 +1,81 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +#include "common.hpp" +#include "types_matrix.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace vm +{ +template +sycl::event add_contig_impl(sycl::queue exec_q, + const std::int64_t n, + const char *in_a, + const char *in_b, + char *out_y, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + const T *a = reinterpret_cast(in_a); + const T *b = reinterpret_cast(in_b); + T *y = reinterpret_cast(out_y); + + return mkl_vm::add(exec_q, + n, // number of elements to be calculated + a, // pointer `a` containing 1st input vector of size n + b, // pointer `b` containing 2nd input vector of size n + y, // pointer `y` to the output vector of size n + depends); +} + +template +struct AddContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename types::AddOutputType::value_type, void>) + { + return nullptr; + } + else { + return add_contig_impl; + } + } +}; +} // namespace vm +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/vm/mul.hpp b/dpnp/backend/extensions/vm/mul.hpp new file mode 100644 index 000000000000..4f827a056192 --- /dev/null +++ b/dpnp/backend/extensions/vm/mul.hpp @@ -0,0 +1,81 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +#include "common.hpp" +#include "types_matrix.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace vm +{ +template +sycl::event mul_contig_impl(sycl::queue exec_q, + const std::int64_t n, + const char *in_a, + const char *in_b, + char *out_y, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + const T *a = reinterpret_cast(in_a); + const T *b = reinterpret_cast(in_b); + T *y = reinterpret_cast(out_y); + + return mkl_vm::mul(exec_q, + n, // number of elements to be calculated + a, // pointer `a` containing 1st input vector of size n + b, // pointer `b` containing 2nd input vector of size n + y, // pointer `y` to the output vector of size n + depends); +} + +template +struct MulContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename types::MulOutputType::value_type, void>) + { + return nullptr; + } + else { + return mul_contig_impl; + } + } +}; +} // namespace vm +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/vm/sub.hpp b/dpnp/backend/extensions/vm/sub.hpp new file mode 100644 index 000000000000..f7bec14d48bf --- /dev/null +++ b/dpnp/backend/extensions/vm/sub.hpp @@ -0,0 +1,81 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +#include "common.hpp" +#include "types_matrix.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace vm +{ +template +sycl::event sub_contig_impl(sycl::queue exec_q, + const std::int64_t n, + const char *in_a, + const char *in_b, + char *out_y, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + const T *a = reinterpret_cast(in_a); + const T *b = reinterpret_cast(in_b); + T *y = reinterpret_cast(out_y); + + return mkl_vm::sub(exec_q, + n, // number of elements to be calculated + a, // pointer `a` containing 1st input vector of size n + b, // pointer `b` containing 2nd input vector of size n + y, // pointer `y` to the output vector of size n + depends); +} + +template +struct SubContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename types::SubOutputType::value_type, void>) + { + return nullptr; + } + else { + return sub_contig_impl; + } + } +}; +} // namespace vm +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/vm/types_matrix.hpp b/dpnp/backend/extensions/vm/types_matrix.hpp index cd4fd76d4bef..584e39de50aa 100644 --- a/dpnp/backend/extensions/vm/types_matrix.hpp +++ b/dpnp/backend/extensions/vm/types_matrix.hpp @@ -43,6 +43,31 @@ namespace vm { namespace types { +/** + * @brief A factory to define pairs of supported types for which + * MKL VM library provides support in oneapi::mkl::vm::add function. + * + * @tparam T Type of input vectors `a` and `b` and of result vector `y`. + */ +template +struct AddOutputType +{ + using value_type = typename std::disjunction< + dpctl_td_ns::BinaryTypeMapResultEntry, + T, + std::complex, + std::complex>, + dpctl_td_ns::BinaryTypeMapResultEntry, + T, + std::complex, + std::complex>, + dpctl_td_ns::BinaryTypeMapResultEntry, + dpctl_td_ns::BinaryTypeMapResultEntry, + dpctl_td_ns::DefaultResultEntry>::result_type; +}; + /** * @brief A factory to define pairs of supported types for which * MKL VM library provides support in oneapi::mkl::vm::div function. @@ -83,6 +108,56 @@ struct CeilOutputType dpctl_td_ns::DefaultResultEntry>::result_type; }; +/** + * @brief A factory to define pairs of supported types for which + * MKL VM library provides support in oneapi::mkl::vm::mul function. + * + * @tparam T Type of input vectors `a` and `b` and of result vector `y`. + */ +template +struct MulOutputType +{ + using value_type = typename std::disjunction< + dpctl_td_ns::BinaryTypeMapResultEntry, + T, + std::complex, + std::complex>, + dpctl_td_ns::BinaryTypeMapResultEntry, + T, + std::complex, + std::complex>, + dpctl_td_ns::BinaryTypeMapResultEntry, + dpctl_td_ns::BinaryTypeMapResultEntry, + dpctl_td_ns::DefaultResultEntry>::result_type; +}; + +/** + * @brief A factory to define pairs of supported types for which + * MKL VM library provides support in oneapi::mkl::vm::sub function. + * + * @tparam T Type of input vectors `a` and `b` and of result vector `y`. + */ +template +struct SubOutputType +{ + using value_type = typename std::disjunction< + dpctl_td_ns::BinaryTypeMapResultEntry, + T, + std::complex, + std::complex>, + dpctl_td_ns::BinaryTypeMapResultEntry, + T, + std::complex, + std::complex>, + dpctl_td_ns::BinaryTypeMapResultEntry, + dpctl_td_ns::BinaryTypeMapResultEntry, + dpctl_td_ns::DefaultResultEntry>::result_type; +}; + /** * @brief A factory to define pairs of supported types for which * MKL VM library provides support in oneapi::mkl::vm::cos function. diff --git a/dpnp/backend/extensions/vm/vm_py.cpp b/dpnp/backend/extensions/vm/vm_py.cpp index c7435ae9e2eb..170776c01ce3 100644 --- a/dpnp/backend/extensions/vm/vm_py.cpp +++ b/dpnp/backend/extensions/vm/vm_py.cpp @@ -30,15 +30,18 @@ #include #include -#include "ceil.hpp" #include "common.hpp" +#include "add.hpp" +#include "ceil.hpp" #include "cos.hpp" #include "div.hpp" #include "floor.hpp" #include "ln.hpp" +#include "mul.hpp" #include "sin.hpp" #include "sqr.hpp" #include "sqrt.hpp" +#include "sub.hpp" #include "trunc.hpp" #include "types_matrix.hpp" @@ -48,7 +51,10 @@ namespace vm_ext = dpnp::backend::ext::vm; using vm_ext::binary_impl_fn_ptr_t; using vm_ext::unary_impl_fn_ptr_t; +static binary_impl_fn_ptr_t add_dispatch_vector[dpctl_td_ns::num_types]; static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types]; +static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types]; +static binary_impl_fn_ptr_t sub_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t ceil_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types]; @@ -64,6 +70,39 @@ PYBIND11_MODULE(_vm_impl, m) using arrayT = dpctl::tensor::usm_ndarray; using event_vecT = std::vector; + // BinaryUfunc: ==== Add(x1, x2) ==== + { + vm_ext::init_ufunc_dispatch_vector( + add_dispatch_vector); + + auto add_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, + arrayT dst, const event_vecT &depends = {}) { + return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends, + add_dispatch_vector); + }; + m.def("_add", add_pyapi, + "Call `add` function from OneMKL VM library to performs element " + "by element addition of vector `src1` by vector `src2` " + "to resulting vector `dst`", + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("depends") = py::list()); + + auto add_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1, + arrayT src2, arrayT dst) { + return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst, + add_dispatch_vector); + }; + m.def("_mkl_add_to_call", add_need_to_call_pyapi, + "Check input arguments to answer if `add` function from " + "OneMKL VM library can be used", + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), + py::arg("dst")); + } + + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + // BinaryUfunc: ==== Div(x1, x2) ==== { vm_ext::init_ufunc_dispatch_vector; + + // BinaryUfunc: ==== Mul(x1, x2) ==== + { + vm_ext::init_ufunc_dispatch_vector( + mul_dispatch_vector); + + auto mul_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, + arrayT dst, const event_vecT &depends = {}) { + return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends, + mul_dispatch_vector); + }; + m.def("_mul", mul_pyapi, + "Call `mul` function from OneMKL VM library to performs element " + "by element multiplication of vector `src1` by vector `src2` " + "to resulting vector `dst`", + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("depends") = py::list()); + + auto mul_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1, + arrayT src2, arrayT dst) { + return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst, + mul_dispatch_vector); + }; + m.def("_mkl_mul_to_call", mul_need_to_call_pyapi, + "Check input arguments to answer if `mul` function from " + "OneMKL VM library can be used", + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), + py::arg("dst")); + } + + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + + // BinaryUfunc: ==== Sub(x1, x2) ==== + { + vm_ext::init_ufunc_dispatch_vector( + sub_dispatch_vector); + + auto sub_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, + arrayT dst, const event_vecT &depends = {}) { + return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends, + sub_dispatch_vector); + }; + m.def("_sub", sub_pyapi, + "Call `sub` function from OneMKL VM library to performs element " + "by element subtraction of vector `src1` by vector `src2` " + "to resulting vector `dst`", + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("depends") = py::list()); + + auto sub_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1, + arrayT src2, arrayT dst) { + return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst, + sub_dispatch_vector); + }; + m.def("_mkl_sub_to_call", sub_need_to_call_pyapi, + "Check input arguments to answer if `sub` function from " + "OneMKL VM library can be used", + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), + py::arg("dst")); } // UnaryUfunc: ==== Cos(x) ==== diff --git a/dpnp/dpnp_algo/dpnp_elementwise_common.py b/dpnp/dpnp_algo/dpnp_elementwise_common.py index 4cf68ff55f0e..926aa22453d7 100644 --- a/dpnp/dpnp_algo/dpnp_elementwise_common.py +++ b/dpnp/dpnp_algo/dpnp_elementwise_common.py @@ -163,18 +163,33 @@ def check_nd_call_func( Default: "K". Returns: dpnp.ndarray: - an array containing the result of element-wise division. The data type + an array containing the result of element-wise addition. The data type of the returned array is determined by the Type Promotion Rules. """ +def _call_add(src1, src2, dst, sycl_queue, depends=None): + """A callback to register in BinaryElementwiseFunc class of dpctl.tensor""" + + if depends is None: + depends = [] + + if vmi._mkl_add_to_call(sycl_queue, src1, src2, dst): + # call pybind11 extension for add() function from OneMKL VM + return vmi._add(sycl_queue, src1, src2, dst, depends) + return ti._add(src1, src2, dst, sycl_queue, depends) + + +add_func = BinaryElementwiseFunc( + "add", ti._add_result_type, _call_add, _add_docstring_, ti._add_inplace +) + + def dpnp_add(x1, x2, out=None, order="K"): """ - Invokes add() from dpctl.tensor implementation for add() function. - - TODO: add a pybind11 extension of add() from OneMKL VM where possible - and would be performance effective. + Invokes add() function from pybind11 extension of OneMKL VM if possible. + Otherwise fully relies on dpctl.tensor implementation for add() function. """ # dpctl.tensor only works with usm_ndarray or scalar @@ -182,10 +197,9 @@ def dpnp_add(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "add", ti._add_result_type, ti._add, _add_docstring_, ti._add_inplace + res_usm = add_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -1281,18 +1295,37 @@ def dpnp_logical_xor(x1, x2, out=None, order="K"): Default: "K". Returns: dpnp.ndarray: - an array containing the result of element-wise division. The data type + an array containing the result of element-wise multiplication. The data type of the returned array is determined by the Type Promotion Rules. """ +def _call_multiply(src1, src2, dst, sycl_queue, depends=None): + """A callback to register in BinaryElementwiseFunc class of dpctl.tensor""" + + if depends is None: + depends = [] + + if vmi._mkl_mul_to_call(sycl_queue, src1, src2, dst): + # call pybind11 extension for mul() function from OneMKL VM + return vmi._mul(sycl_queue, src1, src2, dst, depends) + return ti._multiply(src1, src2, dst, sycl_queue, depends) + + +multiply_func = BinaryElementwiseFunc( + "multiply", + ti._multiply_result_type, + _call_multiply, + _multiply_docstring_, + ti._multiply_inplace, +) + + def dpnp_multiply(x1, x2, out=None, order="K"): """ - Invokes multiply() from dpctl.tensor implementation for multiply() function. - - TODO: add a pybind11 extension of mul() from OneMKL VM where possible - and would be performance effective. + Invokes mul() function from pybind11 extension of OneMKL VM if possible. + Otherwise fully relies on dpctl.tensor implementation for multiply() function. """ # dpctl.tensor only works with usm_ndarray or scalar @@ -1300,14 +1333,9 @@ def dpnp_multiply(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "multiply", - ti._multiply_result_type, - ti._multiply, - _multiply_docstring_, - ti._multiply_inplace, + res_usm = multiply_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -1622,18 +1650,37 @@ def dpnp_square(x, out=None, order="K"): Default: "K". Returns: dpnp.ndarray: - an array containing the result of element-wise division. The data type + an array containing the result of element-wise subtraction. The data type of the returned array is determined by the Type Promotion Rules. """ +def _call_subtract(src1, src2, dst, sycl_queue, depends=None): + """A callback to register in BinaryElementwiseFunc class of dpctl.tensor""" + + if depends is None: + depends = [] + + if vmi._mkl_sub_to_call(sycl_queue, src1, src2, dst): + # call pybind11 extension for sub() function from OneMKL VM + return vmi._sub(sycl_queue, src1, src2, dst, depends) + return ti._subtract(src1, src2, dst, sycl_queue, depends) + + +subtract_func = BinaryElementwiseFunc( + "subtract", + ti._subtract_result_type, + _call_subtract, + _subtract_docstring_, + ti._subtract_inplace, +) + + def dpnp_subtract(x1, x2, out=None, order="K"): """ - Invokes subtract() from dpctl.tensor implementation for subtract() function. - - TODO: add a pybind11 extension of sub() from OneMKL VM where possible - and would be performance effective. + Invokes sub() function from pybind11 extension of OneMKL VM if possible. + Otherwise fully relies on dpctl.tensor implementation for subtract() function. """ # TODO: discuss with dpctl if the check is needed to be moved there @@ -1652,14 +1699,9 @@ def dpnp_subtract(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "subtract", - ti._subtract_result_type, - ti._subtract, - _subtract_docstring_, - ti._subtract_inplace, + res_usm = subtract_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index 9d944e3c597a..31fc664a8d53 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -772,7 +772,6 @@ def test_out_overlap(self, dtype): @pytest.mark.parametrize( "dtype", get_all_dtypes(no_bool=True, no_none=True) ) - @pytest.mark.skip("mute unttil in-place support in dpctl is done") def test_inplace_strided_out(self, dtype): size = 21 @@ -862,7 +861,6 @@ def test_out_overlap(self, dtype): @pytest.mark.parametrize( "dtype", get_all_dtypes(no_bool=True, no_none=True) ) - @pytest.mark.skip("mute unttil in-place support in dpctl is done") def test_inplace_strided_out(self, dtype): size = 21 diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 345a2d552509..239cf887ab41 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -10,7 +10,7 @@ @pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types) @pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types) -def test_coerced_usm_types_sum(usm_type_x, usm_type_y): +def test_coerced_usm_types_add(usm_type_x, usm_type_y): x = dp.arange(1000, usm_type=usm_type_x) y = dp.arange(1000, usm_type=usm_type_y) From 24116e10f3fa7b5516ceb39a21c91fc5fbe24e7c Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Fri, 18 Aug 2023 12:02:50 -0500 Subject: [PATCH 3/3] address comments --- dpnp/backend/extensions/vm/types_matrix.hpp | 120 ++++++------ dpnp/backend/extensions/vm/vm_py.cpp | 196 ++++++++++---------- 2 files changed, 154 insertions(+), 162 deletions(-) diff --git a/dpnp/backend/extensions/vm/types_matrix.hpp b/dpnp/backend/extensions/vm/types_matrix.hpp index 584e39de50aa..ba88f192908d 100644 --- a/dpnp/backend/extensions/vm/types_matrix.hpp +++ b/dpnp/backend/extensions/vm/types_matrix.hpp @@ -68,31 +68,6 @@ struct AddOutputType dpctl_td_ns::DefaultResultEntry>::result_type; }; -/** - * @brief A factory to define pairs of supported types for which - * MKL VM library provides support in oneapi::mkl::vm::div function. - * - * @tparam T Type of input vectors `a` and `b` and of result vector `y`. - */ -template -struct DivOutputType -{ - using value_type = typename std::disjunction< - dpctl_td_ns::BinaryTypeMapResultEntry, - T, - std::complex, - std::complex>, - dpctl_td_ns::BinaryTypeMapResultEntry, - T, - std::complex, - std::complex>, - dpctl_td_ns::BinaryTypeMapResultEntry, - dpctl_td_ns::BinaryTypeMapResultEntry, - dpctl_td_ns::DefaultResultEntry>::result_type; -}; - /** * @brief A factory to define pairs of supported types for which * MKL VM library provides support in oneapi::mkl::vm::ceil function. @@ -109,38 +84,32 @@ struct CeilOutputType }; /** - * @brief A factory to define pairs of supported types for which - * MKL VM library provides support in oneapi::mkl::vm::mul function. + * @brief A factory to define pairs of supported types for which + * MKL VM library provides support in oneapi::mkl::vm::cos function. * - * @tparam T Type of input vectors `a` and `b` and of result vector `y`. + * @tparam T Type of input vector `a` and of result vector `y`. */ template -struct MulOutputType +struct CosOutputType { using value_type = typename std::disjunction< - dpctl_td_ns::BinaryTypeMapResultEntry, - T, - std::complex, - std::complex>, - dpctl_td_ns::BinaryTypeMapResultEntry, - T, - std::complex, - std::complex>, - dpctl_td_ns::BinaryTypeMapResultEntry, - dpctl_td_ns::BinaryTypeMapResultEntry, + dpctl_td_ns:: + TypeMapResultEntry, std::complex>, + dpctl_td_ns:: + TypeMapResultEntry, std::complex>, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; /** * @brief A factory to define pairs of supported types for which - * MKL VM library provides support in oneapi::mkl::vm::sub function. + * MKL VM library provides support in oneapi::mkl::vm::div function. * * @tparam T Type of input vectors `a` and `b` and of result vector `y`. */ template -struct SubOutputType +struct DivOutputType { using value_type = typename std::disjunction< dpctl_td_ns::BinaryTypeMapResultEntry function. + * MKL VM library provides support in oneapi::mkl::vm::floor function. * * @tparam T Type of input vector `a` and of result vector `y`. */ template -struct CosOutputType +struct FloorOutputType { using value_type = typename std::disjunction< - dpctl_td_ns:: - TypeMapResultEntry, std::complex>, - dpctl_td_ns:: - TypeMapResultEntry, std::complex>, dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; @@ -179,14 +144,18 @@ struct CosOutputType /** * @brief A factory to define pairs of supported types for which - * MKL VM library provides support in oneapi::mkl::vm::floor function. + * MKL VM library provides support in oneapi::mkl::vm::ln function. * * @tparam T Type of input vector `a` and of result vector `y`. */ template -struct FloorOutputType +struct LnOutputType { using value_type = typename std::disjunction< + dpctl_td_ns:: + TypeMapResultEntry, std::complex>, + dpctl_td_ns:: + TypeMapResultEntry, std::complex>, dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; @@ -194,20 +163,26 @@ struct FloorOutputType /** * @brief A factory to define pairs of supported types for which - * MKL VM library provides support in oneapi::mkl::vm::ln function. + * MKL VM library provides support in oneapi::mkl::vm::mul function. * - * @tparam T Type of input vector `a` and of result vector `y`. + * @tparam T Type of input vectors `a` and `b` and of result vector `y`. */ template -struct LnOutputType +struct MulOutputType { using value_type = typename std::disjunction< - dpctl_td_ns:: - TypeMapResultEntry, std::complex>, - dpctl_td_ns:: - TypeMapResultEntry, std::complex>, - dpctl_td_ns::TypeMapResultEntry, - dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::BinaryTypeMapResultEntry, + T, + std::complex, + std::complex>, + dpctl_td_ns::BinaryTypeMapResultEntry, + T, + std::complex, + std::complex>, + dpctl_td_ns::BinaryTypeMapResultEntry, + dpctl_td_ns::BinaryTypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; @@ -264,6 +239,31 @@ struct SqrtOutputType dpctl_td_ns::DefaultResultEntry>::result_type; }; +/** + * @brief A factory to define pairs of supported types for which + * MKL VM library provides support in oneapi::mkl::vm::sub function. + * + * @tparam T Type of input vectors `a` and `b` and of result vector `y`. + */ +template +struct SubOutputType +{ + using value_type = typename std::disjunction< + dpctl_td_ns::BinaryTypeMapResultEntry, + T, + std::complex, + std::complex>, + dpctl_td_ns::BinaryTypeMapResultEntry, + T, + std::complex, + std::complex>, + dpctl_td_ns::BinaryTypeMapResultEntry, + dpctl_td_ns::BinaryTypeMapResultEntry, + dpctl_td_ns::DefaultResultEntry>::result_type; +}; + /** * @brief A factory to define pairs of supported types for which * MKL VM library provides support in oneapi::mkl::vm::trunc function. diff --git a/dpnp/backend/extensions/vm/vm_py.cpp b/dpnp/backend/extensions/vm/vm_py.cpp index 170776c01ce3..e4f470bee3ce 100644 --- a/dpnp/backend/extensions/vm/vm_py.cpp +++ b/dpnp/backend/extensions/vm/vm_py.cpp @@ -30,9 +30,9 @@ #include #include -#include "common.hpp" #include "add.hpp" #include "ceil.hpp" +#include "common.hpp" #include "cos.hpp" #include "div.hpp" #include "floor.hpp" @@ -52,17 +52,16 @@ using vm_ext::binary_impl_fn_ptr_t; using vm_ext::unary_impl_fn_ptr_t; static binary_impl_fn_ptr_t add_dispatch_vector[dpctl_td_ns::num_types]; -static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types]; -static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types]; -static binary_impl_fn_ptr_t sub_dispatch_vector[dpctl_td_ns::num_types]; - static unary_impl_fn_ptr_t ceil_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types]; +static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t floor_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types]; +static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t sin_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t sqr_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t sqrt_dispatch_vector[dpctl_td_ns::num_types]; +static binary_impl_fn_ptr_t sub_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t trunc_dispatch_vector[dpctl_td_ns::num_types]; PYBIND11_MODULE(_vm_impl, m) @@ -100,39 +99,6 @@ PYBIND11_MODULE(_vm_impl, m) py::arg("dst")); } - using arrayT = dpctl::tensor::usm_ndarray; - using event_vecT = std::vector; - - // BinaryUfunc: ==== Div(x1, x2) ==== - { - vm_ext::init_ufunc_dispatch_vector( - div_dispatch_vector); - - auto div_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, - arrayT dst, const event_vecT &depends = {}) { - return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends, - div_dispatch_vector); - }; - m.def("_div", div_pyapi, - "Call `div` function from OneMKL VM library to performs element " - "by element division of vector `src1` by vector `src2` " - "to resulting vector `dst`", - py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("depends") = py::list()); - - auto div_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1, - arrayT src2, arrayT dst) { - return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst, - div_dispatch_vector); - }; - m.def("_mkl_div_to_call", div_need_to_call_pyapi, - "Check input arguments to answer if `div` function from " - "OneMKL VM library can be used", - py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), - py::arg("dst")); - } - // UnaryUfunc: ==== Ceil(x) ==== { vm_ext::init_ufunc_dispatch_vector; - - // BinaryUfunc: ==== Mul(x1, x2) ==== - { - vm_ext::init_ufunc_dispatch_vector( - mul_dispatch_vector); - - auto mul_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, - arrayT dst, const event_vecT &depends = {}) { - return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends, - mul_dispatch_vector); - }; - m.def("_mul", mul_pyapi, - "Call `mul` function from OneMKL VM library to performs element " - "by element multiplication of vector `src1` by vector `src2` " - "to resulting vector `dst`", - py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("depends") = py::list()); - - auto mul_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1, - arrayT src2, arrayT dst) { - return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst, - mul_dispatch_vector); - }; - m.def("_mkl_mul_to_call", mul_need_to_call_pyapi, - "Check input arguments to answer if `mul` function from " - "OneMKL VM library can be used", - py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), - py::arg("dst")); - } - - using arrayT = dpctl::tensor::usm_ndarray; - using event_vecT = std::vector; - - // BinaryUfunc: ==== Sub(x1, x2) ==== - { - vm_ext::init_ufunc_dispatch_vector( - sub_dispatch_vector); - - auto sub_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, - arrayT dst, const event_vecT &depends = {}) { - return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends, - sub_dispatch_vector); - }; - m.def("_sub", sub_pyapi, - "Call `sub` function from OneMKL VM library to performs element " - "by element subtraction of vector `src1` by vector `src2` " - "to resulting vector `dst`", - py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("depends") = py::list()); - - auto sub_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1, - arrayT src2, arrayT dst) { - return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst, - sub_dispatch_vector); - }; - m.def("_mkl_sub_to_call", sub_need_to_call_pyapi, - "Check input arguments to answer if `sub` function from " - "OneMKL VM library can be used", - py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), - py::arg("dst")); } // UnaryUfunc: ==== Cos(x) ==== @@ -253,6 +155,36 @@ PYBIND11_MODULE(_vm_impl, m) py::arg("sycl_queue"), py::arg("src"), py::arg("dst")); } + // BinaryUfunc: ==== Div(x1, x2) ==== + { + vm_ext::init_ufunc_dispatch_vector( + div_dispatch_vector); + + auto div_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, + arrayT dst, const event_vecT &depends = {}) { + return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends, + div_dispatch_vector); + }; + m.def("_div", div_pyapi, + "Call `div` function from OneMKL VM library to performs element " + "by element division of vector `src1` by vector `src2` " + "to resulting vector `dst`", + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("depends") = py::list()); + + auto div_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1, + arrayT src2, arrayT dst) { + return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst, + div_dispatch_vector); + }; + m.def("_mkl_div_to_call", div_need_to_call_pyapi, + "Check input arguments to answer if `div` function from " + "OneMKL VM library can be used", + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), + py::arg("dst")); + } + // UnaryUfunc: ==== Floor(x) ==== { vm_ext::init_ufunc_dispatch_vector( + mul_dispatch_vector); + + auto mul_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, + arrayT dst, const event_vecT &depends = {}) { + return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends, + mul_dispatch_vector); + }; + m.def("_mul", mul_pyapi, + "Call `mul` function from OneMKL VM library to performs element " + "by element multiplication of vector `src1` by vector `src2` " + "to resulting vector `dst`", + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("depends") = py::list()); + + auto mul_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1, + arrayT src2, arrayT dst) { + return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst, + mul_dispatch_vector); + }; + m.def("_mkl_mul_to_call", mul_need_to_call_pyapi, + "Check input arguments to answer if `mul` function from " + "OneMKL VM library can be used", + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), + py::arg("dst")); + } + // UnaryUfunc: ==== Sin(x) ==== { vm_ext::init_ufunc_dispatch_vector( + sub_dispatch_vector); + + auto sub_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, + arrayT dst, const event_vecT &depends = {}) { + return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends, + sub_dispatch_vector); + }; + m.def("_sub", sub_pyapi, + "Call `sub` function from OneMKL VM library to performs element " + "by element subtraction of vector `src1` by vector `src2` " + "to resulting vector `dst`", + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("depends") = py::list()); + + auto sub_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1, + arrayT src2, arrayT dst) { + return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst, + sub_dispatch_vector); + }; + m.def("_mkl_sub_to_call", sub_need_to_call_pyapi, + "Check input arguments to answer if `sub` function from " + "OneMKL VM library can be used", + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), + py::arg("dst")); + } + // UnaryUfunc: ==== Trunc(x) ==== { vm_ext::init_ufunc_dispatch_vector