From 27f8ba643d7db1dd19e9a24e6794240e2b1fd6d3 Mon Sep 17 00:00:00 2001 From: hanke580 <38852697+hanke580@users.noreply.github.com> Date: Mon, 23 Mar 2020 15:36:39 +0800 Subject: [PATCH] [Numpy] Add op fmax, fmin, fmod (#17567) * [Numpy] Add op fmax, fmin * Fix sanity * Fix bug of gpu part, add scalar compute * Finish cpu,gpu test of fmax, fmin * * Fix 3rd Party * * Prune redundent alias * * Add op fmod (rfmod still need check) * * Fix rfmod function * * Impl FFI * * Fix windows oversize by adding files Co-authored-by: Han --- benchmark/python/ffi/benchmark_ffi.py | 3 + python/mxnet/ndarray/numpy/_op.py | 76 +++++++++- python/mxnet/numpy/multiarray.py | 96 +++++++++++- python/mxnet/numpy_dispatch_protocol.py | 3 + python/mxnet/symbol/numpy/_symbol.py | 24 ++- .../np_elemwise_broadcast_op_extended_sec.cc | 56 +++++++ src/operator/mshadow_op.h | 50 ++++++ .../np_elemwise_broadcast_op_extended_sec.cc | 142 ++++++++++++++++++ .../np_elemwise_broadcast_op_extended_sec.cu | 77 ++++++++++ src/operator/operator_tune.cc | 4 + .../unittest/test_numpy_interoperability.py | 24 +++ tests/python/unittest/test_numpy_op.py | 10 ++ 12 files changed, 559 insertions(+), 6 deletions(-) create mode 100644 src/api/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc create mode 100644 src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc create mode 100644 src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cu diff --git a/benchmark/python/ffi/benchmark_ffi.py b/benchmark/python/ffi/benchmark_ffi.py index 1983de594b28..98addb02ffda 100644 --- a/benchmark/python/ffi/benchmark_ffi.py +++ b/benchmark/python/ffi/benchmark_ffi.py @@ -80,6 +80,9 @@ def prepare_workloads(): OpArgMngr.add_workload("ones_like", pool['2x2']) OpArgMngr.add_workload("random.uniform", low=0, high=1, size=1) OpArgMngr.add_workload("where", pool['2x3'], pool['2x3'], pool['2x1']) + OpArgMngr.add_workload("fmax", pool['2x2'], pool['2x2']) + OpArgMngr.add_workload("fmin", pool['2x2'], pool['2x2']) + OpArgMngr.add_workload("fmod", pool['2x2'], pool['2x2']) OpArgMngr.add_workload("may_share_memory", pool['2x3'][:0], pool['2x3'][:1]) OpArgMngr.add_workload("roll", pool["2x2"], 1, axis=0) OpArgMngr.add_workload("rot90", pool["2x2"], 2) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 32519d142c1e..84aea04b10c3 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -30,7 +30,8 @@ __all__ = ['shape', 'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'empty_like', 'invert', 'delete', - 'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'bitwise_not', + 'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'fmod', + 'power', 'bitwise_not', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'insert', 'fabs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'matmul', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram', @@ -38,7 +39,7 @@ 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit', 'dsplit', 'concatenate', 'append', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack', - 'average', 'mean', 'maximum', 'minimum', 'around', 'round', 'round_', 'flatnonzero', + 'average', 'mean', 'maximum', 'fmax', 'minimum', 'fmin', 'around', 'round', 'round_', 'flatnonzero', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', @@ -1169,6 +1170,35 @@ def mod(x1, x2, out=None, **kwargs): return _api_internal.mod(x1, x2, out) +@set_module('mxnet.ndarray.numpy') +@wrap_np_binary_func +def fmod(x1, x2, out=None, **kwargs): + """ + Return element-wise remainder of division. + + Parameters + ---------- + x1 : ndarray or scalar + Dividend array. + + x2 : ndarray or scalar + Divisor array. + + out : ndarray + A location into which the result is stored. If provided, it must have a shape + that the inputs broadcast to. If not provided or None, a freshly-allocated array + is returned. + + Returns + ------- + out : ndarray or scalar + This is a scalar if both x1 and x2 are scalars. + """ + if isinstance(x1, numeric_types) and isinstance(x2, numeric_types): + _np.fmod(x1, x2, out=out) + return _api_internal.fmod(x1, x2, out) + + @set_module('mxnet.ndarray.numpy') def delete(arr, obj, axis=None): """ @@ -4366,6 +4396,27 @@ def maximum(x1, x2, out=None, **kwargs): return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out) +@set_module('mxnet.ndarray.numpy') +@wrap_np_binary_func +def fmax(x1, x2, out=None, **kwargs): + """ + Returns element-wise maximum of the input arrays with broadcasting. (Ignores NaNs) + + Parameters + ---------- + x1, x2 : scalar or mxnet.numpy.ndarray + The arrays holding the elements to be compared. They must have the same shape, + or shapes that can be broadcast to a single shape. + + Returns + ------- + out : mxnet.numpy.ndarray or scalar + The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.""" + if isinstance(x1, numeric_types) and isinstance(x2, numeric_types): + _np.fmax(x1, x2, out=out) + return _api_internal.fmax(x1, x2, out) + + @set_module('mxnet.ndarray.numpy') @wrap_np_binary_func def minimum(x1, x2, out=None, **kwargs): @@ -4385,6 +4436,27 @@ def minimum(x1, x2, out=None, **kwargs): return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out) +@set_module('mxnet.ndarray.numpy') +@wrap_np_binary_func +def fmin(x1, x2, out=None, **kwargs): + """ + Returns element-wise minimum of the input arrays with broadcasting. (Ignores NaNs) + + Parameters + ---------- + x1, x2 : scalar or mxnet.numpy.ndarray + The arrays holding the elements to be compared. They must have the same shape, + or shapes that can be broadcast to a single shape. + + Returns + ------- + out : mxnet.numpy.ndarray or scalar + The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.""" + if isinstance(x1, numeric_types) and isinstance(x2, numeric_types): + _np.fmin(x1, x2, out=out) + return _api_internal.fmin(x1, x2, out) + + @set_module('mxnet.ndarray.numpy') def swapaxes(a, axis1, axis2): """Interchange two axes of an array. diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 04476919dbd0..25d46912ea6e 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -53,7 +53,8 @@ __all__ = ['ndarray', 'empty', 'empty_like', 'array', 'shape', 'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'all', 'any', 'broadcast_to', - 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'bitwise_not', 'delete', + 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'fmod', 'power', 'bitwise_not', + 'delete', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'invert', 'sqrt', 'cbrt', 'abs', 'absolute', 'fabs', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'histogram', @@ -61,7 +62,8 @@ 'sort', 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit', 'dsplit', 'flatnonzero', 'concatenate', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack', - 'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'insert', + 'average', 'mean', 'maximum', 'fmax', 'minimum', 'fmin', + 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'insert', 'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', @@ -3147,6 +3149,38 @@ def mod(x1, x2, out=None, **kwargs): return _mx_nd_np.mod(x1, x2, out=out) +@set_module('mxnet.numpy') +@wrap_np_binary_func +def fmod(x1, x2, out=None, **kwargs): + """ + Return element-wise remainder of division. + + Parameters + ---------- + x1 : ndarray or scalar + Dividend array. + + x2 : ndarray or scalar + Divisor array. + + out : ndarray + A location into which the result is stored. If provided, it must have a shape + that the inputs broadcast to. If not provided or None, a freshly-allocated array + is returned. + + Returns + ------- + out : ndarray or scalar + This is a scalar if both x1 and x2 are scalars. + + Examples + -------- + >>> np.fmod(np.arange(7), 5) + array([0., 1., 2., 3., 4., 0., 1.]) + """ + return _mx_nd_np.fmod(x1, x2, out=out) + + @set_module('mxnet.numpy') @wrap_np_binary_func def matmul(a, b, out=None, **kwargs): @@ -6185,6 +6219,35 @@ def maximum(x1, x2, out=None, **kwargs): return _mx_nd_np.maximum(x1, x2, out=out) +@set_module('mxnet.numpy') +@wrap_np_binary_func +def fmax(x1, x2, out=None, **kwargs): + """ + Returns element-wise maximum of the input arrays with broadcasting. (Ignores NaNs) + + Parameters + ---------- + x1, x2 : scalar or mxnet.numpy.ndarray + The arrays holding the elements to be compared. They must have the same shape, + or shapes that can be broadcast to a single shape. + + Returns + ------- + out : mxnet.numpy.ndarray or scalar + The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. + + Examples + -------- + >>> np.fmax(np.array([2, 3, 4]), np.array([1, 5, 2])) + array([2., 5., 4.]) + + >>> np.fmax(np.eye(2), np.array([0.5, 2])) # broadcasting + array([[1. , 2. ], + [0.5, 2. ]]) + """ + return _mx_nd_np.fmax(x1, x2, out=out) + + @set_module('mxnet.numpy') @wrap_np_binary_func def minimum(x1, x2, out=None, **kwargs): @@ -6214,6 +6277,35 @@ def minimum(x1, x2, out=None, **kwargs): return _mx_nd_np.minimum(x1, x2, out=out) +@set_module('mxnet.numpy') +@wrap_np_binary_func +def fmin(x1, x2, out=None, **kwargs): + """ + Returns element-wise minimum of the input arrays with broadcasting. (Ignores NaNs) + + Parameters + ---------- + x1, x2 : scalar or mxnet.numpy.ndarray + The arrays holding the elements to be compared. They must have the same shape, + or shapes that can be broadcast to a single shape. + + Returns + ------- + out : mxnet.numpy.ndarray or scalar + The fmin of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. + + Examples + -------- + >>> np.fmin(np.array([2, 3, 4]), np.array([1, 5, 2])) + array([1., 3., 2.]) + + >>> np.fmin(np.eye(2), np.array([0.5, 2])) # broadcasting + array([[0.5, 0. ], + [0. , 1. ]]) + """ + return _mx_nd_np.fmin(x1, x2, out=out) + + @set_module('mxnet.numpy') def swapaxes(a, axis1, axis2): """Interchange two axes of an array. diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 781ec55b3796..a4b251b55607 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -248,6 +248,7 @@ def _register_array_function(): 'negative', 'power', 'mod', + 'fmod', 'matmul', 'absolute', 'rint', @@ -277,7 +278,9 @@ def _register_array_function(): 'arccosh', 'arctanh', 'maximum', + 'fmax', 'minimum', + 'fmin', 'ceil', 'trunc', 'floor', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 717049a6a8bf..d29768b49d1b 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -36,14 +36,16 @@ from builtins import slice as py_slice __all__ = ['zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'empty_like', 'bitwise_not', 'invert', - 'delete', 'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2', + 'delete', 'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'fmod', + 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'fabs', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'matmul', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram', 'insert', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'sort', 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit', 'dsplit', 'concatenate', 'append', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack', - 'average', 'mean', 'maximum', 'minimum', 'any', 'all', 'around', 'round', 'round_', 'flatnonzero', + 'average', 'mean', 'maximum', 'fmax', 'minimum', 'fmin', 'any', 'all', 'around', 'round', 'round_', + 'flatnonzero', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', @@ -1620,6 +1622,12 @@ def mod(x1, x2, out=None, **kwargs): return _ufunc_helper(x1, x2, _npi.mod, _np.mod, _npi.mod_scalar, _npi.rmod_scalar, out) +@set_module('mxnet.symbol.numpy') +@wrap_np_binary_func +def fmod(x1, x2, out=None, **kwargs): + return _ufunc_helper(x1, x2, _npi.fmod, _np.fmod, _npi.fmod_scalar, _npi.rfmod_scalar, out) + + @set_module('mxnet.symbol.numpy') @wrap_np_binary_func def remainder(x1, x2, out=None, **kwargs): @@ -4127,12 +4135,24 @@ def maximum(x1, x2, out=None, **kwargs): return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out) +@set_module('mxnet.symbol.numpy') +@wrap_np_binary_func +def fmax(x1, x2, out=None, **kwargs): + return _ufunc_helper(x1, x2, _npi.fmax, _np.fmax, _npi.fmax_scalar, None, out) + + @set_module('mxnet.symbol.numpy') @wrap_np_binary_func def minimum(x1, x2, out=None, **kwargs): return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out) +@set_module('mxnet.symbol.numpy') +@wrap_np_binary_func +def fmin(x1, x2, out=None, **kwargs): + return _ufunc_helper(x1, x2, _npi.fmin, _np.fmin, _npi.fmin_scalar, None, out) + + @set_module('mxnet.symbol.numpy') def all(a, axis=None, out=None, keepdims=False): """ diff --git a/src/api/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc b/src/api/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc new file mode 100644 index 000000000000..248af4dd6e3e --- /dev/null +++ b/src/api/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_elemwise_broadcast_op_extended_sec.cc + * \brief Implementation of the API of functions in src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc + */ +#include +#include +#include "../utils.h" +#include "../ufunc_helper.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.fmax") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_fmax"); + const nnvm::Op* op_scalar = Op::Get("_npi_fmax_scalar"); + UFuncHelper(args, ret, op, op_scalar, nullptr); +}); + +MXNET_REGISTER_API("_npi.fmin") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_fmin"); + const nnvm::Op* op_scalar = Op::Get("_npi_fmin_scalar"); + UFuncHelper(args, ret, op, op_scalar, nullptr); +}); + +MXNET_REGISTER_API("_npi.fmod") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_fmod"); + const nnvm::Op* op_scalar = Op::Get("_npi_fmod_scalar"); + const nnvm::Op* op_rscalar = Op::Get("_npi_rfmod_scalar"); + UFuncHelper(args, ret, op, op_scalar, op_rscalar); +}); + +} // namespace mxnet diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 9106ee222542..9f732343d882 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -807,6 +807,28 @@ struct mod : public mxnet_op::tunable { } }; +struct fmod : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (b == DType(0)) { + return DType(0); + } else { + return DType(::fmod(static_cast(a), static_cast(b))); + } + } +}; + +struct rfmod : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (a == DType(0)) { + return DType(0); + } else { + return DType(::fmod(static_cast(b), static_cast(a))); + } + } +}; + template<> MSHADOW_XINLINE mshadow::half::half2_t mod::Map (mshadow::half::half2_t a, @@ -1143,6 +1165,20 @@ struct maximum : public mxnet_op::tunable { } }; +/*! \brief used for computing binary operator fmax */ +struct fmax : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (IsNan(b)) { + return a; + } else if (IsNan(a)) { + return b; + } else { + return (a > b ? a : b); + } + } +}; + /*! \brief used for computing binary operator minimum */ struct minimum : public mxnet_op::tunable { template @@ -1155,6 +1191,20 @@ struct minimum : public mxnet_op::tunable { } }; +/*! \brief used for computing binary operator fmin */ +struct fmin : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (IsNan(b)) { + return a; + } else if (IsNan(a)) { + return b; + } else { + return (a < b ? a : b); + } + } +}; + /*! \brief boolean any/all kernel that determines whether elem is NonZero */ struct NonZero { template diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc new file mode 100644 index 000000000000..7455da139a14 --- /dev/null +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_elemwise_broadcast_op_extended_sec.cc + * \brief CPU Implementation of extended functions for elementwise numpy binary broadcast operator. (Second extended file) + */ + +#include "../../common/utils.h" +#include "./np_elemwise_broadcast_op.h" + +namespace mxnet { +namespace op { + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser([](NodeAttrs* attrs) { \ + attrs->parsed = std::stod(attrs->dict["scalar"]); \ + }) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}}; \ + }) \ + .add_argument("data", "NDArray-or-Symbol", "source input") \ + .add_argument("scalar", "float", "scalar input") + +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_fmax) +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_fmax"}); + +NNVM_REGISTER_OP(_backward_npi_fmax) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 1}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_fmax_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_fmax_scalar"}); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_fmax_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.set_attr("FCompute", BinaryScalarOp::Backward); + +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_fmin) +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_fmin"}); + +NNVM_REGISTER_OP(_backward_npi_fmin) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 1}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_fmin_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_fmin_scalar"}); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_fmin_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.set_attr("FCompute", BinaryScalarOp::Backward); + +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_fmod) +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_fmod"}); + +NNVM_REGISTER_OP(_backward_npi_fmod) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 1}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_fmod_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_fmod_scalar"}); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_fmod_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.set_attr("FCompute", BinaryScalarOp::Backward); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rfmod_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rfmod_scalar"}); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rfmod_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.set_attr("FCompute", BinaryScalarOp::Backward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cu b/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cu new file mode 100644 index 000000000000..fa2f3bf080c7 --- /dev/null +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cu @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_elemwise_broadcast_op_extended_sec.cu + * \brief GPU Implementation of extended functions for elementwise binary broadcast operator. (Second extended file) + */ + +#include "./np_elemwise_broadcast_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_fmax) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_backward_npi_fmax) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +NNVM_REGISTER_OP(_npi_fmax_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_fmax_scalar) +.set_attr("FCompute", BinaryScalarOp::Backward); + +NNVM_REGISTER_OP(_npi_fmin) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_backward_npi_fmin) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +NNVM_REGISTER_OP(_npi_fmin_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_fmin_scalar) +.set_attr("FCompute", BinaryScalarOp::Backward); + +NNVM_REGISTER_OP(_npi_fmod) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_backward_npi_fmod) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +NNVM_REGISTER_OP(_npi_fmod_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_fmod_scalar) +.set_attr("FCompute", BinaryScalarOp::Backward); + +NNVM_REGISTER_OP(_npi_rfmod_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_rfmod_scalar) +.set_attr("FCompute", BinaryScalarOp::Backward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 0cc0dc92f884..b76e341b9fc6 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -345,6 +345,8 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_rgrad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div_rgrad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mod); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fmod); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rfmod); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mod_rgrad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rmod); // NOLINT() @@ -375,7 +377,9 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gelu_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::prelu_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::elu_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fmax); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fmin); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_left); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::hypot_grad_left); // NOLINT() diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 2d1710eeedc9..f58002c8634b 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1415,6 +1415,13 @@ def _add_workload_mod(array_pool): OpArgMngr.add_workload('mod', array_pool['4x1'], array_pool['1x1x0']) +def _add_workload_fmod(array_pool): + OpArgMngr.add_workload('fmod', array_pool['4x1'], array_pool['1x2']) + OpArgMngr.add_workload('fmod', array_pool['4x1'], 2) + OpArgMngr.add_workload('fmod', 2, array_pool['4x1']) + OpArgMngr.add_workload('fmod', array_pool['4x1'], array_pool['1x1x0']) + + def _add_workload_remainder(): # test remainder basic OpArgMngr.add_workload('remainder', np.array([0, 1, 2, 4, 2], dtype=np.float16), @@ -1481,6 +1488,13 @@ def _add_workload_maximum(array_pool): OpArgMngr.add_workload('maximum', array_pool['4x1'], array_pool['1x1x0']) +def _add_workload_fmax(array_pool): + OpArgMngr.add_workload('fmax', array_pool['4x1'], array_pool['1x2']) + OpArgMngr.add_workload('fmax', array_pool['4x1'], 2) + OpArgMngr.add_workload('fmax', 2, array_pool['4x1']) + OpArgMngr.add_workload('fmax', array_pool['4x1'], array_pool['1x1x0']) + + def _add_workload_minimum(array_pool): OpArgMngr.add_workload('minimum', array_pool['4x1'], array_pool['1x2']) OpArgMngr.add_workload('minimum', array_pool['4x1'], 2) @@ -1488,6 +1502,13 @@ def _add_workload_minimum(array_pool): OpArgMngr.add_workload('minimum', array_pool['4x1'], array_pool['1x1x0']) +def _add_workload_fmin(array_pool): + OpArgMngr.add_workload('fmin', array_pool['4x1'], array_pool['1x2']) + OpArgMngr.add_workload('fmin', array_pool['4x1'], 2) + OpArgMngr.add_workload('fmin', 2, array_pool['4x1']) + OpArgMngr.add_workload('fmin', array_pool['4x1'], array_pool['1x1x0']) + + def _add_workload_negative(array_pool): OpArgMngr.add_workload('negative', array_pool['4x1']) @@ -2833,9 +2854,12 @@ def _prepare_workloads(): _add_workload_multiply(array_pool) _add_workload_power(array_pool) _add_workload_mod(array_pool) + _add_workload_fmod(array_pool) _add_workload_remainder() _add_workload_maximum(array_pool) + _add_workload_fmax(array_pool) _add_workload_minimum(array_pool) + _add_workload_fmin(array_pool) _add_workload_negative(array_pool) _add_workload_absolute(array_pool) _add_workload_sign(array_pool) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 85dbe70397d5..ef5de655cbca 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2254,6 +2254,12 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): [lambda y, x1, x2: -_np.floor(x1 / x2), lambda y, x1, x2: _np.zeros(y.shape)], [[_np.float16, _np.float32, _np.float64], [_np.int32]]), + 'fmod': (1.0, 10.0, + [lambda y, x1, x2: _np.ones(y.shape), + lambda y, x1, x2: _np.zeros(y.shape)], + [lambda y, x1, x2: -_np.floor(x1 / x2), + lambda y, x1, x2: _np.zeros(y.shape)], + [[_np.float16, _np.float32, _np.float64], [_np.int32]]), 'remainder': (1.0, 10.0, [lambda y, x1, x2: _np.ones(y.shape), lambda y, x1, x2: _np.zeros(y.shape)], @@ -2268,8 +2274,12 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): 'bitwise_or': (-100, 100, [None], None, [[_np.int32]]), 'maximum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 >= x2)], [lambda y, x1, x2: _np.ones(y.shape) * (x1 < x2)]), + 'fmax': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 >= x2)], + [lambda y, x1, x2: _np.ones(y.shape) * (x1 < x2)]), 'minimum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 <= x2)], [lambda y, x1, x2: _np.ones(y.shape) * (x1 > x2)]), + 'fmin': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 <= x2)], + [lambda y, x1, x2: _np.ones(y.shape) * (x1 > x2)]), 'copysign': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (((x1 * x2) >= 0).astype(_np.float32) - ((x1 * x2) < 0).astype(_np.float32))], [lambda y, x1, x2: _np.zeros(y.shape)]),