Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix the problems pointed out in comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike Mao committed Aug 15, 2019
1 parent 8adabd2 commit 411bdc3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
45 changes: 23 additions & 22 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,29 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
else:
return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype)


@set_module('mxnet.symbol.numpy')
def expand_dims(a, axis):
"""Expand the shape of an array.
Insert a new axis that will appear at the `axis` position in the expanded
Parameters
----------
a : _Symbol
Input array.
axis : int
Position in the expanded axes where the new axis is placed.
Returns
-------
res : _Symbol
Output array. The number of dimensions is one greater than that of
the input array.
"""
return _npi.expand_dims(a, axis)


def _unary_func_helper(x, fn_array, fn_scalar, out=None, **kwargs):
"""Helper function for unary operators.
Parameters
Expand Down Expand Up @@ -1197,28 +1220,6 @@ def invert(x, out=None, **kwargs):
return _unary_func_helper(x, _npi.invert, _np.invert, out=out, **kwargs)


@set_module('mxnet.symbol.numpy')
def expand_dims(a, axis):
"""Expand the shape of an array.
Insert a new axis that will appear at the `axis` position in the expanded
Parameters
----------
a : _Symbol
Input array.
axis : int
Position in the expanded axes where the new axis is placed.
Returns
-------
res : _Symbol
Output array. The number of dimensions is one greater than that of
the input array.
"""
return _npi.expand_dims(a, axis)


@set_module('mxnet.symbol.numpy')
def tile(A, reps):
r"""
Expand Down
4 changes: 2 additions & 2 deletions src/operator/numpy/np_elemwise_unary_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ invert([13]) = array([242])
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"x"};
return std::vector<std::string>{"x"};
})
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::invert>)
.add_argument("x", "NDArray-or-Symbol", "The input array.")
Expand Down

0 comments on commit 411bdc3

Please sign in to comment.