Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[numpy] add op fabs, sometrue, round_ (#17619)
Browse files Browse the repository at this point in the history
* ok

* solve

* delete convenient method

* change sth
  • Loading branch information
Yiyan66 authored Feb 25, 2020
1 parent f9b2a63 commit 12d9191
Show file tree
Hide file tree
Showing 8 changed files with 217 additions and 27 deletions.
13 changes: 13 additions & 0 deletions python/mxnet/_numpy_op_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,19 @@ def _np_any(a, axis=None, keepdims=False, out=None):
pass


def _np_sometrue(a, axis=None, keepdims=False, out=None):
"""
Check whether some values are true.
Refer to `any` for full documentation.
See Also
--------
any : equivalent function; see for details.
"""
pass


def _np_cumsum(a, axis=None, dtype=None, out=None):
"""
Return the cumulative sum of the elements along a given axis.
Expand Down
60 changes: 57 additions & 3 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@

__all__ = ['shape', 'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'empty_like', 'invert', 'delete',
'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'bitwise_not',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'insert',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'insert', 'fabs',
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'matmul',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'sort',
'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit', 'dsplit',
'concatenate', 'append', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum',
'average', 'mean', 'maximum', 'minimum', 'around', 'round', 'round_',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'around', 'round',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
Expand Down Expand Up @@ -2251,6 +2251,41 @@ def abs(x, out=None, **kwargs):
return _unary_func_helper(x, _npi.abs, _np.abs, out=out, **kwargs)


@set_module('mxnet.ndarray.numpy')
@wrap_np_unary_func
def fabs(x, out=None, **kwargs):
r"""
Calculate the absolute value element-wise.
This function returns the absolute values (positive magnitude) of the
data in `x`. Complex values are not handled, use `absolute` to find the
absolute values of complex data.
Parameters
----------
x : ndarray or scalar
Input array.
out : ndarray or None, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
absolute : ndarray
An ndarray containing the absolute value of
each element in `x`. This is a scalar if `x` is a scalar.
Examples
--------
>>> np.fabs(-1)
1.0
>>> np.fabs(np.array([-1.2, 1.2]))s
array([ 1.2, 1.2])
"""
return _unary_func_helper(x, _npi.abs, _np.abs, out=out, **kwargs)


@set_module('mxnet.ndarray.numpy')
@wrap_np_unary_func
def absolute(x, out=None, **kwargs):
Expand Down Expand Up @@ -5457,6 +5492,25 @@ def around(x, decimals=0, out=None, **kwargs):

@set_module('mxnet.ndarray.numpy')
def round(x, decimals=0, out=None, **kwargs):
r"""
round(a, decimals=0, out=None)
Round an array to the given number of decimals.
See Also
--------
around : equivalent function; see for details.
"""
from ...numpy import ndarray
if isinstance(x, numeric_types):
return _np.around(x, decimals, **kwargs)
elif isinstance(x, ndarray):
return _npi.around(x, decimals, out=out, **kwargs)
else:
raise TypeError('type {} not supported'.format(str(type(x))))


@set_module('mxnet.ndarray.numpy')
def round_(x, decimals=0, out=None, **kwargs):
r"""
round_(a, decimals=0, out=None)
Round an array to the given number of decimals.
Expand Down
54 changes: 51 additions & 3 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@
'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'broadcast_to',
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'bitwise_not', 'delete',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'invert',
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'sqrt', 'cbrt', 'abs', 'absolute', 'fabs', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'histogram',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'argsort',
'sort', 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange',
'array_split', 'split', 'hsplit', 'vsplit', 'dsplit',
'concatenate', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'insert',
'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman',
'flip', 'flipud', 'fliplr', 'around', 'round', 'arctan2', 'hypot',
'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot',
'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal',
'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum', 'true_divide', 'nonzero',
Expand Down Expand Up @@ -3554,6 +3554,41 @@ def abs(x, out=None, **kwargs):
return _mx_nd_np.abs(x, out=out, **kwargs)


@set_module('mxnet.numpy')
@wrap_np_unary_func
def fabs(x, out=None, **kwargs):
r"""
Calculate the absolute value element-wise.
This function returns the absolute values (positive magnitude) of the
data in `x`. Complex values are not handled, use `absolute` to find the
absolute values of complex data.
Parameters
----------
x : ndarray or scalar
Input array.
out : ndarray or None, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
absolute : ndarray
An ndarray containing the absolute value of
each element in `x`. This is a scalar if `x` is a scalar.
Examples
--------
>>> np.fabs(-1)
1.0
>>> np.fabs(np.array([-1.2, 1.2]))s
array([ 1.2, 1.2])
"""
return _mx_nd_np.fabs(x, out=out, **kwargs)


@set_module('mxnet.numpy')
@wrap_np_unary_func
def absolute(x, out=None, **kwargs):
Expand Down Expand Up @@ -7285,6 +7320,19 @@ def around(x, decimals=0, out=None, **kwargs):

@set_module('mxnet.numpy')
def round(x, decimals=0, out=None, **kwargs):
r"""
round(a, decimals=0, out=None)
Round an array to the given number of decimals.
See Also
--------
around : equivalent function; see for details.
"""
return _mx_nd_np.round(x, decimals, out=out, **kwargs)


@set_module('mxnet.numpy')
def round_(x, decimals=0, out=None, **kwargs):
r"""
round_(a, decimals=0, out=None)
Round an array to the given number of decimals.
Expand All @@ -7293,7 +7341,7 @@ def round(x, decimals=0, out=None, **kwargs):
--------
around : equivalent function; see for details.
"""
return _mx_nd_np.around(x, decimals, out=out, **kwargs)
return _mx_nd_np.round_(x, decimals, out=out, **kwargs)


@set_module('mxnet.numpy')
Expand Down
3 changes: 3 additions & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,12 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
_NUMPY_ARRAY_FUNCTION_LIST = [
'all',
'any',
'sometrue',
'argmin',
'argmax',
'around',
'round',
'round_',
'argsort',
'sort',
'append',
Expand Down Expand Up @@ -230,6 +232,7 @@ def _register_array_function():
# https://docs.scipy.org/doc/numpy/reference/ufuncs.html#available-ufuncs
_NUMPY_ARRAY_UFUNC_LIST = [
'abs',
'fabs',
'add',
'arctan2',
'copysign',
Expand Down
50 changes: 47 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@

__all__ = ['zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'empty_like', 'bitwise_not', 'invert',
'delete', 'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2',
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'fabs', 'exp',
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'matmul',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram', 'insert',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'sort', 'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit', 'dsplit',
'concatenate', 'append', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum',
'average', 'mean', 'maximum', 'minimum', 'around', 'round', 'round_',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'around', 'round',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
Expand Down Expand Up @@ -2421,6 +2421,32 @@ def abs(x, out=None, **kwargs):
return _unary_func_helper(x, _npi.abs, _np.abs, out=out, **kwargs)


@set_module('mxnet.symbol.numpy')
@wrap_np_unary_func
def fabs(x, out=None, **kwargs):
r"""
Calculate the absolute value element-wise.
This function returns the absolute values (positive magnitude) of the
data in `x`. Complex values are not handled, use `absolute` to find the
absolute values of complex data.
Parameters
----------
x : _Symbol or scalar
Input array.
out : _Symbol or None
Dummy parameter to keep the consistency with the ndarray counterpart.
Returns
-------
absolute : _Symbol
An ndarray containing the absolute value of
each element in `x`. This is a scalar if `x` is a scalar.
"""
return _unary_func_helper(x, _npi.abs, _np.abs, out=out, **kwargs)


@set_module('mxnet.symbol.numpy')
@wrap_np_unary_func
def absolute(x, out=None, **kwargs):
Expand Down Expand Up @@ -5043,6 +5069,24 @@ def around(x, decimals=0, out=None, **kwargs):

@set_module('mxnet.symbol.numpy')
def round(x, decimals=0, out=None, **kwargs):
r"""
round(a, decimals=0, out=None)
Round an array to the given number of decimals.
See Also
--------
around : equivalent function; see for details.
"""
if isinstance(x, numeric_types):
return _np.around(x, decimals, **kwargs)
elif isinstance(x, _Symbol):
return _npi.around(x, decimals, out=out, **kwargs)
else:
raise TypeError('type {} not supported'.format(str(type(x))))


@set_module('mxnet.symbol.numpy')
def round_(x, decimals=0, out=None, **kwargs):
r"""
round_(a, decimals=0, out=None)
Round an array to the given number of decimals.
Expand Down
1 change: 1 addition & 0 deletions src/operator/numpy/np_broadcast_reduce_op_boolean.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ inline bool NumpyReduceAxesBoolType(const nnvm::NodeAttrs& attrs,
DMLC_REGISTER_PARAMETER(NumpyReduceAxesBoolParam);

NNVM_REGISTER_OP(_np_any)
.add_alias("_np_sometrue")
.set_attr_parser(ParamParser<NumpyReduceAxesBoolParam>)
.set_num_inputs(1)
.set_num_outputs(1)
Expand Down
26 changes: 26 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,19 @@ def _add_workload_any():
OpArgMngr.add_workload('any', d)


def _add_workload_sometrue():
# check bad element in all positions
for i in range(256-7):
d = np.array([False] * 256, dtype=bool)[7::]
d[i] = True
OpArgMngr.add_workload('sometrue', d)
# big array test for blocked libc loops
for i in list(range(9, 6000, 507)) + [7764, 90021, -10]:
d = np.array([False] * 100043, dtype=bool)
d[i] = True
OpArgMngr.add_workload('sometrue', d)


def _add_workload_unravel_index():
OpArgMngr.add_workload('unravel_index', indices=np.array([2],dtype=_np.int64), shape=(2, 2))
OpArgMngr.add_workload('unravel_index', np.array([(2*3 + 1)*6 + 4], dtype=_np.int64), (4, 3, 6))
Expand Down Expand Up @@ -819,6 +832,10 @@ def _add_workload_round():
OpArgMngr.add_workload('round', np.array([1.56, 72.54, 6.35, 3.25]), decimals=1)


def _add_workload_round_():
OpArgMngr.add_workload('round_', np.array([1.56, 72.54, 6.35, 3.25]), decimals=1)


def _add_workload_argsort():
for dtype in [np.int32, np.float32]:
a = np.arange(101, dtype=dtype)
Expand Down Expand Up @@ -1224,6 +1241,12 @@ def _add_workload_abs():
OpArgMngr.add_workload('abs', np.array([np.inf, -np.inf, np.nan]))


def _add_workload_fabs():
OpArgMngr.add_workload('fabs', np.random.uniform(size=(11,)).astype(np.float32))
OpArgMngr.add_workload('fabs', np.random.uniform(size=(5,)).astype(np.float64))
OpArgMngr.add_workload('fabs', np.array([np.inf, -np.inf, np.nan]))


def _add_workload_add(array_pool):
OpArgMngr.add_workload('add', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('add', array_pool['4x1'], 2)
Expand Down Expand Up @@ -2691,10 +2714,12 @@ def _prepare_workloads():

_add_workload_all()
_add_workload_any()
_add_workload_sometrue()
_add_workload_argmin()
_add_workload_argmax()
_add_workload_around()
_add_workload_round()
_add_workload_round_()
_add_workload_argsort()
_add_workload_sort()
_add_workload_append()
Expand Down Expand Up @@ -2775,6 +2800,7 @@ def _prepare_workloads():
_add_workload_meshgrid()
_add_workload_einsum()
_add_workload_abs()
_add_workload_fabs()
_add_workload_add(array_pool)
_add_workload_arctan2()
_add_workload_copysign()
Expand Down
Loading

0 comments on commit 12d9191

Please sign in to comment.