Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add op isnan isinf
Browse files Browse the repository at this point in the history
  • Loading branch information
Alicia1529 committed Feb 6, 2020
1 parent 8e0dc92 commit 5d6567b
Show file tree
Hide file tree
Showing 9 changed files with 515 additions and 4 deletions.
102 changes: 101 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory',
'diff', 'resize', 'nan_to_num', 'where', 'bincount']
'diff', 'resize', 'nan_to_num', 'isnan', 'isinf', 'where', 'bincount']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -6543,6 +6543,106 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):
raise TypeError('type {} not supported'.format(str(type(x))))


@set_module('mxnet.ndarray.numpy')
@wrap_np_unary_func
def isnan(x, out=None, **kwargs):
"""
Test element-wise for NaN and return result as a boolean array.
Parameters
----------
x : ndarray
Input array.
out : ndarray or None, optional
A location into which the result is stored.
If provided, it must have the same shape and dtype as input ndarray.
If not provided or `None`, a freshly-allocated array is returned.
Returns
-------
y : ndarray or bool
True where x is NaN, false otherwise.
This is a scalar if x is a scalar.
Notes
-----
NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic (IEEE 754).
This means that Not a Number is not equivalent to infinity.
This function differs from the original `numpy.isnan
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.isnan.html>`_ in
the following aspects:
- Does not support complex number for now
- Input type does not support Python native iterables(list, tuple, ...).
- ``out`` param: cannot perform auto broadcasting. ``out`` ndarray's shape must be the same as the expected output.
- ``out`` param: cannot perform auto type cast. ``out`` ndarray's dtype must be the same as the expected output.
- ``out`` param does not support scalar input case.
Examples
--------
>>> np.isnan(np.nan)
True
>>> np.isnan(np.inf)
False
>>> np.isnan(np.array([np.log(-1.),1.,np.log(0)]))
array([ True, False, False])
"""
return _unary_func_helper(x, _npi.isnan, _np.isnan, out=out, **kwargs)


@set_module('mxnet.ndarray.numpy')
@wrap_np_unary_func
def isinf(x, out=None, **kwargs):
"""
Test element-wise for positive or negative infinity.
Parameters
----------
x : ndarray
Input array.
out : ndarray or None, optional
A location into which the result is stored.
If provided, it must have the same shape and dtype as input ndarray.
If not provided or `None`, a freshly-allocated array is returned.
Returns
-------
y : ndarray or bool
True where x is positive or negative infinity, false otherwise.
This is a scalar if x is a scalar.
Notes
-----
NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic (IEEE 754).
This means that Not a Number is not equivalent to infinity.
This function differs from the original `numpy.isnan
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.isnan.html>`_ in
the following aspects:
- Does not support complex number for now
- Input type does not support Python native iterables(list, tuple, ...).
- ``out`` param: cannot perform auto broadcasting. ``out`` ndarray's shape must be the same as the expected output.
- ``out`` param: cannot perform auto type cast. ``out`` ndarray's dtype must be the same as the expected output.
- ``out`` param does not support scalar input case.
Examples
--------
>>> np.isinf(np.inf)
True
>>> np.isinf(np.nan)
False
>>> np.isinf(np.array([np.inf, -np.inf, 1.0, np.nan]))
array([ True, True, False, False])
>>> x = np.array([-np.inf, 0., np.inf])
>>> y = np.array([True, True, True], dtype=np.bool_)
>>> np.isinf(x, y)
array([ True, False, True])
>>> y
array([ True, False, True])
"""
return _unary_func_helper(x, _npi.isinf, _np.isinf, out=out, **kwargs)


@set_module('mxnet.ndarray.numpy')
def where(condition, x=None, y=None):
"""where(condition, [x, y])
Expand Down
104 changes: 102 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@
'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal',
'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero',
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where',
'bincount']
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'resize',
'nan_to_num', 'isnan', 'isinf', 'where', 'bincount']

__all__ += fallback.__all__

Expand Down Expand Up @@ -8625,6 +8625,106 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):
return _mx_nd_np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)


@set_module('mxnet.numpy')
@wrap_np_unary_func
def isnan(x, out=None, **kwargs):
"""
Test element-wise for NaN and return result as a boolean array.
Parameters
----------
x : ndarray
Input array.
out : ndarray or None, optional
A location into which the result is stored.
If provided, it must have the same shape and dtype as input ndarray.
If not provided or `None`, a freshly-allocated array is returned.
Returns
-------
y : ndarray or bool
True where x is NaN, false otherwise.
This is a scalar if x is a scalar.
Notes
-----
NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic (IEEE 754).
This means that Not a Number is not equivalent to infinity.
This function differs from the original `numpy.isnan
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.isnan.html>`_ in
the following aspects:
- Does not support complex number for now
- Input type does not support Python native iterables(list, tuple, ...).
- ``out`` param: cannot perform auto broadcasting. ``out`` ndarray's shape must be the same as the expected output.
- ``out`` param: cannot perform auto type cast. ``out`` ndarray's dtype must be the same as the expected output.
- ``out`` param does not support scalar input case.
Examples
--------
>>> np.isnan(np.nan)
True
>>> np.isnan(np.inf)
False
>>> np.isnan(np.array([np.log(-1.),1.,np.log(0)]))
array([ True, False, False])
"""
return _mx_nd_np.isnan(x, out=out, **kwargs)


@set_module('mxnet.numpy')
@wrap_np_unary_func
def isinf(x, out=None, **kwargs):
"""
Test element-wise for positive or negative infinity.
Parameters
----------
x : ndarray
Input array.
out : ndarray or None, optional
A location into which the result is stored.
If provided, it must have the same shape and dtype as input ndarray.
If not provided or `None`, a freshly-allocated array is returned.
Returns
-------
y : ndarray or bool
True where x is positive or negative infinity, false otherwise.
This is a scalar if x is a scalar.
Notes
-----
NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic (IEEE 754).
This means that Not a Number is not equivalent to infinity.
This function differs from the original `numpy.isnan
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.isnan.html>`_ in
the following aspects:
- Does not support complex number for now
- Input type does not support Python native iterables(list, tuple, ...).
- ``out`` param: cannot perform auto broadcasting. ``out`` ndarray's shape must be the same as the expected output.
- ``out`` param: cannot perform auto type cast. ``out`` ndarray's dtype must be the same as the expected output.
- ``out`` param does not support scalar input case.
Examples
--------
>>> np.isinf(np.inf)
True
>>> np.isinf(np.nan)
False
>>> np.isinf(np.array([np.inf, -np.inf, 1.0, np.nan]))
array([ True, True, False, False])
>>> x = np.array([-np.inf, 0., np.inf])
>>> y = np.array([True, True, True], dtype=np.bool_)
>>> np.isinf(x, y)
array([ True, False, True])
>>> y
array([ True, False, True])
"""
return _mx_nd_np.isinf(x, out=out, **kwargs)


@set_module('mxnet.numpy')
def where(condition, x=None, y=None):
"""where(condition, [x, y])
Expand Down
2 changes: 2 additions & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'bincount',
'empty_like',
'nan_to_num',
'isnan',
'isinf',
]


Expand Down
78 changes: 77 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum',
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff',
'resize', 'nan_to_num', 'where', 'bincount']
'resize', 'nan_to_num', 'isnan', 'isinf', 'where', 'bincount']


@set_module('mxnet.symbol.numpy')
Expand Down Expand Up @@ -5883,6 +5883,82 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):
raise TypeError('type {} not supported'.format(str(type(x))))


@set_module('mxnet.symbol.numpy')
@wrap_np_unary_func
def isnan(x, out=None, **kwargs):
"""
Test element-wise for NaN and return result as a boolean array.
Parameters
----------
x : _Symbol
Input array.
out : _Symbol or None, optional
A location into which the result is stored.
If provided, it must have the same shape and dtype as input ndarray.
If not provided or `None`, a freshly-allocated array is returned.
Returns
-------
y : _Symbol or bool
True where x is NaN, false otherwise.
This is a scalar if x is a scalar.
Notes
-----
NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic (IEEE 754).
This means that Not a Number is not equivalent to infinity.
This function differs from the original `numpy.isnan
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.isnan.html>`_ in
the following aspects:
- Does not support complex number for now
- Input type does not support Python native iterables(list, tuple, ...).
- ``out`` param: cannot perform auto broadcasting. ``out`` ndarray's shape must be the same as the expected output.
- ``out`` param: cannot perform auto type cast. ``out`` ndarray's dtype must be the same as the expected output.
- ``out`` param does not support scalar input case.
"""
return _unary_func_helper(x, _npi.isnan, _np.isnan, out=out, **kwargs)


@set_module('mxnet.symbol.numpy')
@wrap_np_unary_func
def isinf(x, out=None, **kwargs):
"""
Test element-wise for positive or negative infinity.
Parameters
----------
x : _Symbol
Input array.
out : ndarray or None, optional
A location into which the result is stored.
If provided, it must have the same shape and dtype as input ndarray.
If not provided or `None`, a freshly-allocated array is returned.
Returns
-------
y : _Symbol or bool
True where x is positive or negative infinity, false otherwise.
This is a scalar if x is a scalar.
Notes
-----
NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic (IEEE 754).
This means that Not a Number is not equivalent to infinity.
This function differs from the original `numpy.isnan
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.isnan.html>`_ in
the following aspects:
- Does not support complex number for now
- Input type does not support Python native iterables(list, tuple, ...).
- ``out`` param: cannot perform auto broadcasting. ``out`` ndarray's shape must be the same as the expected output.
- ``out`` param: cannot perform auto type cast. ``out`` ndarray's dtype must be the same as the expected output.
- ``out`` param does not support scalar input case.
"""
return _unary_func_helper(x, _npi.isinf, _np.isinf, out=out, **kwargs)


@set_module('mxnet.symbol.numpy')
def where(condition, x, y):
"""
Expand Down
45 changes: 45 additions & 0 deletions src/operator/numpy/np_elemwise_unary_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@
namespace mxnet {
namespace op {

inline bool NumpyInferBoolType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool);
return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
}

MXNET_OPERATOR_REGISTER_UNARY(_npx_relu)
.describe(R"code(Computes rectified linear activation.
.. math::
Expand Down Expand Up @@ -462,5 +471,41 @@ NNVM_REGISTER_OP(_npi_backward_nan_to_num)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", NumpyNanToNumOpBackward<cpu>);

NNVM_REGISTER_OP(_npi_isnan)
.describe("" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", NumpyInferBoolType)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyIsNanOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"x"};
})
.add_argument("x", "NDArray-or-Symbol", "The input array.");

NNVM_REGISTER_OP(_npi_isinf)
.describe("" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", NumpyInferBoolType)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyIsInfOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"x"};
})
.add_argument("x", "NDArray-or-Symbol", "The input array.");

} // namespace op
} // namespace mxnet
Loading

0 comments on commit 5d6567b

Please sign in to comment.