
porting numpy-compatible hstack to master and add dstack for interoperability
haojin2 committed Jan 12, 2020
1 parent 25dc909 commit 8d1f746
Showing 9 changed files with 343 additions and 3 deletions.
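For context, a minimal usage sketch of the two operators this commit wires up (it mirrors the docstring examples in the diff below; assumes a build with the numpy interface, activated via `npx.set_np()`):

```python
from mxnet import np, npx
npx.set_np()  # activate numpy-compatible semantics

a = np.array([1, 2, 3])
b = np.array([2, 3, 4])

# hstack: 1-D inputs concatenate along the first axis ...
np.hstack((a, b))         # array([1., 2., 3., 2., 3., 4.])

# ... higher-rank inputs concatenate along the second axis
np.hstack((a.reshape(3, 1), b.reshape(3, 1))).shape  # (3, 2)

# dstack: inputs are promoted to 3-D and stacked along the third axis
np.dstack((a, b)).shape   # (1, 3, 2)
```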
42 changes: 41 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -35,7 +35,8 @@
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'row_stack', 'column_stack', 'dstack', 'average', 'mean', 'maximum', 'minimum',
'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'around', 'round', 'hypot', 'bitwise_xor',
'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot',
@@ -3790,6 +3791,45 @@ def column_stack(tup):
    return _npi.column_stack(*tup)


@set_module('mxnet.ndarray.numpy')
def hstack(arrays):
    """
    Stack arrays in sequence horizontally (column wise).

    This is equivalent to concatenation along the second axis,
    except for 1-D arrays, where it concatenates along the first axis.
    Rebuilds arrays divided by hsplit.

    This function makes most sense for arrays with up to 3 dimensions,
    for instance pixel data with a height (first axis), width (second axis),
    and r/g/b channels (third axis). The functions concatenate, stack and
    block provide more general stacking and concatenation operations.

    Parameters
    ----------
    arrays : sequence of ndarrays
        The arrays must have the same shape along all but the second axis,
        except 1-D arrays, which can be any length.

    Returns
    -------
    stacked : ndarray
        The array formed by stacking the given arrays.

    Examples
    --------
    >>> from mxnet import np, npx
    >>> a = np.array((1, 2, 3))
    >>> b = np.array((2, 3, 4))
    >>> np.hstack((a, b))
    array([1., 2., 3., 2., 3., 4.])
    >>> a = np.array([[1], [2], [3]])
    >>> b = np.array([[2], [3], [4]])
    >>> np.hstack((a, b))
    array([[1., 2.],
           [2., 3.],
           [3., 4.]])
    """
    return _npi.hstack(*arrays)
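The `_npi.hstack` backend realizes NumPy's axis-selection rule; as a hedged reference, the same semantics can be written in pure NumPy like this (`hstack_reference` is an illustrative name, not part of the commit):

```python
import numpy as onp

def hstack_reference(arrays):
    # promote scalars to 1-D, as the backend kernel does
    arrays = [onp.atleast_1d(a) for a in arrays]
    # 1-D inputs concatenate along axis 0, everything else along axis 1
    axis = 0 if arrays[0].ndim == 1 else 1
    return onp.concatenate(arrays, axis=axis)
```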


@set_module('mxnet.ndarray.numpy')
def dstack(arrays):
"""
41 changes: 40 additions & 1 deletion python/mxnet/numpy/multiarray.py
@@ -54,7 +54,7 @@
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'histogram',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'argsort',
'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split',
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'row_stack', 'column_stack', 'dstack',
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var',
'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman', 'flip', 'flipud',
'fliplr', 'around', 'round', 'arctan2', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
@@ -5563,6 +5563,45 @@ def column_stack(tup):
    return _mx_nd_np.column_stack(tup)


@set_module('mxnet.numpy')
def hstack(arrays):
    """
    Stack arrays in sequence horizontally (column wise).

    This is equivalent to concatenation along the second axis,
    except for 1-D arrays, where it concatenates along the first axis.
    Rebuilds arrays divided by hsplit.

    This function makes most sense for arrays with up to 3 dimensions,
    for instance pixel data with a height (first axis), width (second axis),
    and r/g/b channels (third axis). The functions concatenate, stack and
    block provide more general stacking and concatenation operations.

    Parameters
    ----------
    arrays : sequence of ndarrays
        The arrays must have the same shape along all but the second axis,
        except 1-D arrays, which can be any length.

    Returns
    -------
    stacked : ndarray
        The array formed by stacking the given arrays.

    Examples
    --------
    >>> from mxnet import np, npx
    >>> a = np.array((1, 2, 3))
    >>> b = np.array((2, 3, 4))
    >>> np.hstack((a, b))
    array([1., 2., 3., 2., 3., 4.])
    >>> a = np.array([[1], [2], [3]])
    >>> b = np.array([[2], [3], [4]])
    >>> np.hstack((a, b))
    array([[1., 2.],
           [2., 3.],
           [3., 4.]])
    """
    return _mx_nd_np.hstack(arrays)


@set_module('mxnet.numpy')
def dstack(arrays):
"""
2 changes: 2 additions & 0 deletions python/mxnet/numpy_dispatch_protocol.py
@@ -139,6 +139,8 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
    'vdot',
    'vstack',
    'column_stack',
    'hstack',
    'dstack',
    'zeros_like',
    'linalg.norm',
    'linalg.cholesky',
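Registering 'hstack' and 'dstack' here opts them into the NEP 18 `__array_function__` dispatch tests, so calling official NumPy's functions on MXNet ndarrays routes to the MXNet implementations. A hedged sketch of what that enables (assumes a NumPy version with the array-function protocol active):

```python
import numpy as onp
from mxnet import np, npx
npx.set_np()

a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
# onp.hstack dispatches to mxnet.numpy.hstack via __array_function__
out = onp.hstack((a, b))
print(type(out))  # mxnet.numpy.ndarray
```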
42 changes: 41 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -43,7 +43,8 @@
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'row_stack', 'column_stack', 'dstack', 'average', 'mean', 'maximum', 'minimum',
'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'around', 'round', 'hypot', 'bitwise_xor',
'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot',
@@ -3750,6 +3751,45 @@ def column_stack(tup):
    return _npi.column_stack(*tup)


@set_module('mxnet.symbol.numpy')
def hstack(arrays):
    """
    Stack arrays in sequence horizontally (column wise).

    This is equivalent to concatenation along the second axis,
    except for 1-D arrays, where it concatenates along the first axis.
    Rebuilds arrays divided by hsplit.

    This function makes most sense for arrays with up to 3 dimensions,
    for instance pixel data with a height (first axis), width (second axis),
    and r/g/b channels (third axis). The functions concatenate, stack and
    block provide more general stacking and concatenation operations.

    Parameters
    ----------
    arrays : sequence of _Symbols
        The arrays must have the same shape along all but the second axis,
        except 1-D arrays, which can be any length.

    Returns
    -------
    stacked : _Symbol
        The symbol resulting from stacking the given arrays.

    Examples
    --------
    >>> from mxnet import np, npx
    >>> a = np.array((1, 2, 3))
    >>> b = np.array((2, 3, 4))
    >>> np.hstack((a, b))
    array([1., 2., 3., 2., 3., 4.])
    >>> a = np.array([[1], [2], [3]])
    >>> b = np.array([[2], [3], [4]])
    >>> np.hstack((a, b))
    array([[1., 2.],
           [2., 3.],
           [3., 4.]])
    """
    return _npi.hstack(*arrays)
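The symbol-level registration is what lets the operator run under hybridization; a hedged Gluon sketch (the block name and the `F.np` access pattern are illustrative of the 1.x hybrid API, not part of this commit):

```python
from mxnet import np, npx
from mxnet.gluon import HybridBlock
npx.set_np()

class HStackBlock(HybridBlock):
    def hybrid_forward(self, F, a, b):
        # F is the ndarray namespace imperatively, the symbol namespace when hybridized
        return F.np.hstack((a, b))

blk = HStackBlock()
blk.hybridize()  # subsequent calls go through the _Symbol code path
print(blk(np.array([1., 2.]), np.array([3., 4.])))  # [1. 2. 3. 4.]
```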


@set_module('mxnet.symbol.numpy')
def dstack(arrays):
"""
44 changes: 44 additions & 0 deletions src/operator/nn/concat-inl.h
@@ -141,6 +141,28 @@ void ConcatCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
});
}

template<typename xpu>
void HStackCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
                   const std::vector<TBlob>& inputs,
                   const std::vector<OpReqType>& req,
                   const std::vector<TBlob>& outputs) {
  ConcatParam param = nnvm::get<ConcatParam>(attrs.parsed);
  // hstack concatenates 1-D inputs along axis 0 and everything else along axis 1
  param.dim = inputs[0].shape_.ndim() > 1 ? 1 : 0;
  std::vector<TBlob> modified_inputs(inputs.size());
  for (int i = 0; i < param.num_args; ++i) {
    // promote scalar (0-d) tensors to 1-D so they can be concatenated
    if (inputs[i].shape_.ndim() == 0) {
      modified_inputs[i] = inputs[i].reshape(TShape(1, 1));
    } else {
      modified_inputs[i] = inputs[i];
    }
  }
  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kData0].type_flag_, DType, {
    ConcatOp<xpu, DType> op;
    op.Init(param);
    op.Forward(ctx, modified_inputs, req, outputs);
  });
}

template<typename xpu>
void DStackCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
const std::vector<TBlob>& inputs,
@@ -185,6 +207,28 @@ void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
});
}

template<typename xpu>
void HStackGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  ConcatParam param = nnvm::get<ConcatParam>(attrs.parsed);
  // split the gradient along the same axis the forward pass concatenated on
  param.dim = inputs[0].shape_.ndim() > 1 ? 1 : 0;
  std::vector<TBlob> modified_outputs(outputs.size());
  for (int i = 0; i < param.num_args; ++i) {
    // scalar (0-d) gradients are viewed as 1-D, matching the forward promotion
    if (outputs[i].shape_.ndim() == 0) {
      modified_outputs[i] = outputs[i].reshape(TShape(1, 1));
    } else {
      modified_outputs[i] = outputs[i];
    }
  }
  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kOut].type_flag_, DType, {
    ConcatOp<xpu, DType> op;
    op.Init(param);
    op.Backward(ctx, inputs[concat_enum::kOut], req, modified_outputs);
  });
}
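In NumPy terms, the backward pass just splits the incoming gradient back along the concatenation axis; a hedged reference sketch (illustrative names, shapes assumed already promoted to at least 1-D as in the forward pass):

```python
import numpy as onp

def hstack_backward_reference(out_grad, in_shapes):
    axis = 0 if len(in_shapes[0]) == 1 else 1
    # cumulative extents along the concat axis give the split points
    split_points = onp.cumsum([s[axis] for s in in_shapes])[:-1]
    return onp.split(out_grad, split_points, axis=axis)
```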

template<typename xpu>
void DStackGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
const std::vector<TBlob>& inputs,
91 changes: 91 additions & 0 deletions src/operator/numpy/np_matrix_op.cc
@@ -466,6 +466,59 @@ NNVM_REGISTER_OP(_np_squeeze)
.add_argument("a", "NDArray-or-Symbol", "data to squeeze")
.add_arguments(SqueezeParam::__FIELDS__());

bool HStackShape(const nnvm::NodeAttrs& attrs,
                 mxnet::ShapeVector *in_shape,
                 mxnet::ShapeVector *out_shape) {
  using namespace mshadow;
  ConcatParam param_ = nnvm::get<ConcatParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
  mxnet::TShape dshape;
  dim_t size = 0;
  bool has_unknown_dim_size = false;
  int axis = (*in_shape)[0].ndim() > 1 ? 1 : 0;
  param_.dim = axis;
  for (int i = 0; i < param_.num_args; ++i) {
    // a scalar tensor is treated as a one-dimensional vector
    if ((*in_shape)[i].ndim() == 0) {
      (*in_shape)[i] = mxnet::TShape(1, 1);
    }
    mxnet::TShape &tmp = (*in_shape)[i];
    if (tmp.ndim() > 0) {
      CheckAxis(axis, tmp.ndim());
      if (!mxnet::dim_size_is_known(tmp, axis)) {
        has_unknown_dim_size = true;
      } else {
        size += tmp[axis];
      }
      tmp[axis] = -1;
      shape_assign(&dshape, tmp);
    }
  }

  mxnet::TShape tmp = (*out_shape)[0];
  if (tmp.ndim() > 0) {
    axis = CheckAxis(param_.dim, tmp.ndim());
    tmp[axis] = -1;
    shape_assign(&dshape, tmp);
  }

  if (dshape.ndim() == -1) return false;
  CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated";

  for (int i = 0; i < param_.num_args; ++i) {
    CHECK(shape_assign(&(*in_shape)[i], dshape))
        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
  }

  if (!has_unknown_dim_size) {
    dshape[axis] = size;
  }
  CHECK(shape_assign(&(*out_shape)[0], dshape))
      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];

  return shape_is_known(dshape);
}
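In effect the shape pass promotes scalars to 1-D, checks that all inputs agree on every axis except the concatenation axis, and sums the extents along that axis; a hedged Python sketch of the same logic (illustrative name, fully known shapes assumed):

```python
def hstack_infer_shape(in_shapes):
    # promote scalar (0-d) shapes to 1-D, as HStackShape does
    shapes = [list(s) if len(s) > 0 else [1] for s in in_shapes]
    axis = 1 if len(shapes[0]) > 1 else 0

    def rest(shp):  # every dimension except the concat axis
        return [d for i, d in enumerate(shp) if i != axis]

    out = list(shapes[0])
    for s in shapes[1:]:
        assert rest(s) == rest(out), "incompatible input shapes"
        out[axis] += s[axis]
    return tuple(out)
```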

bool DStackShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_shape,
mxnet::ShapeVector *out_shape) {
@@ -986,6 +1039,44 @@ NNVM_REGISTER_OP(_backward_np_vstack)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", NumpyVstackBackward<cpu>);

NNVM_REGISTER_OP(_npi_hstack)
.describe(R"code(Stack tensors horizontally (in second dimension))code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  return params.num_args;
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
    std::vector<std::string> ret;
    for (int i = 0; i < params.num_args; ++i) {
      ret.push_back(std::string("data") + std::to_string(i));
    }
    return ret;
  })
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"out"};
  })
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<nnvm::FInferType>("FInferType", ConcatType)
.set_attr<mxnet::FInferShape>("FInferShape", HStackShape)
.set_attr<FCompute>("FCompute<cpu>", HStackCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", NumpyConcatGrad{"_backward_np_hstack"})
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_np_hstack)
.set_num_outputs([](const NodeAttrs& attrs) {
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  return params.num_args;
})
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", HStackGradCompute<cpu>);
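A quick autograd check of the registered forward/backward pair (hedged sketch; MXNet's default head gradient of ones makes each input's gradient a ones-slice of matching extent):

```python
from mxnet import autograd, np, npx
npx.set_np()

a = np.array([1., 2., 3.])
b = np.array([4., 5.])
a.attach_grad()
b.attach_grad()
with autograd.record():
    out = np.hstack((a, b))
out.backward()
print(a.grad, b.grad)  # [1. 1. 1.] [1. 1.]
```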

NNVM_REGISTER_OP(_npi_dstack)
.describe(R"code(Stack tensors in sequence depthwise (in third dimension))code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
6 changes: 6 additions & 0 deletions src/operator/numpy/np_matrix_op.cu
@@ -53,6 +53,12 @@ NNVM_REGISTER_OP(_npi_vstack)
NNVM_REGISTER_OP(_backward_np_vstack)
.set_attr<FCompute>("FCompute<gpu>", NumpyVstackBackward<gpu>);

NNVM_REGISTER_OP(_npi_hstack)
.set_attr<FCompute>("FCompute<gpu>", HStackCompute<gpu>);

NNVM_REGISTER_OP(_backward_np_hstack)
.set_attr<FCompute>("FCompute<gpu>", HStackGradCompute<gpu>);

NNVM_REGISTER_OP(_npi_dstack)
.set_attr<FCompute>("FCompute<gpu>", DStackCompute<gpu>);

14 changes: 14 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
@@ -1437,6 +1437,18 @@ def _add_workload_column_stack():
    OpArgMngr.add_workload('column_stack', [np.array(_np.arange(3)) for _ in range(2)])


def _add_workload_hstack(array_pool):
    OpArgMngr.add_workload('hstack', (np.random.uniform(size=(1, 4)), np.random.uniform(size=(1, 4))))
    OpArgMngr.add_workload('hstack', array_pool['4x1'])
    OpArgMngr.add_workload('hstack', array_pool['1x1x0'])


def _add_workload_dstack(array_pool):
    OpArgMngr.add_workload('dstack', (np.random.uniform(size=(5, 1, 2)), np.random.uniform(size=(5, 1, 3))))
    OpArgMngr.add_workload('dstack', array_pool['4x1'])
    OpArgMngr.add_workload('dstack', array_pool['1x1x0'])
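For reference, dstack's semantics in pure NumPy terms, which these workloads exercise against official NumPy (hedged sketch, illustrative name):

```python
import numpy as onp

def dstack_reference(arrays):
    # promote every input to 3-D, then concatenate along the third axis
    return onp.concatenate([onp.atleast_3d(a) for a in arrays], axis=2)
```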


def _add_workload_equal(array_pool):
    # TODO(junwu): fp16 does not work yet with TVM generated ops
    # OpArgMngr.add_workload('equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
@@ -1737,6 +1749,8 @@ def _prepare_workloads():
    _add_workload_vdot()
    _add_workload_vstack(array_pool)
    _add_workload_column_stack()
    _add_workload_hstack(array_pool)
    _add_workload_dstack(array_pool)
    _add_workload_equal(array_pool)
    _add_workload_not_equal(array_pool)
    _add_workload_greater(array_pool)