From 3d8035b40005623d99dffef9962516abb086758b Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Wed, 14 Aug 2019 21:52:10 -0700 Subject: [PATCH] Numpy-compatible concatenate upstream (#15894) * numpy-compatible concatenate upstream * extend ci deadline --- ci/jenkins/Jenkinsfile_unix_cpu | 2 +- python/mxnet/ndarray/numpy/_op.py | 25 ++++++++++- python/mxnet/numpy/multiarray.py | 26 +++++++++++- python/mxnet/symbol/numpy/_symbol.py | 25 ++++++++++- src/operator/numpy/np_matrix_op-inl.h | 1 + src/operator/numpy/np_matrix_op.cc | 58 +++++++++++++++++++++++++- src/operator/numpy/np_matrix_op.cu | 7 +++- tests/python/unittest/test_numpy_op.py | 51 ++++++++++++++++++++++ 8 files changed, 189 insertions(+), 6 deletions(-) diff --git a/ci/jenkins/Jenkinsfile_unix_cpu b/ci/jenkins/Jenkinsfile_unix_cpu index fa0942988d9c..3533a123ea86 100644 --- a/ci/jenkins/Jenkinsfile_unix_cpu +++ b/ci/jenkins/Jenkinsfile_unix_cpu @@ -21,7 +21,7 @@ // See documents at https://jenkins.io/doc/book/pipeline/jenkinsfile/ // timeout in minutes -max_time = 180 +max_time = 240 node('utility') { // Loading the utilities requires a node context unfortunately diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index db4d86179e59..d7f3fd1ace54 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -27,7 +27,7 @@ from ..ndarray import NDArray __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange', 'split'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate'] @set_module('mxnet.ndarray.numpy') @@ -682,3 +682,26 @@ def split(ary, indices_or_sections, axis=0): if not isinstance(ret, list): return [ret] return ret + + +@set_module('mxnet.ndarray.numpy') +def concatenate(seq, axis=0, out=None): + """Join a sequence of arrays along an existing axis. + Parameters + ---------- + a1, a2, ... : sequence of array_like + The arrays must have the same shape, except in the dimension + corresponding to `axis` (the first, by default). + axis : int, optional + The axis along which the arrays will be joined. If axis is None, + arrays are flattened before use. Default is 0. + out : ndarray, optional + If provided, the destination to place the result. The shape must be + correct, matching that of what concatenate would have returned if no + out argument were specified. + Returns + ------- + res : ndarray + The concatenated array. + """ + return _npi.concatenate(*seq, dim=axis, out=out) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 0f05aa8da50e..8988b4eb19c9 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -44,7 +44,8 @@ from ..ndarray.numpy import _internal as _npi __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', - 'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split'] + 'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', + 'concatenate'] # This function is copied from ndarray.py since pylint @@ -1853,3 +1854,26 @@ def split(ary, indices_or_sections, axis=0): If `indices_or_sections` is given as an integer, but a split does not result in equal division.""" return _mx_nd_np.split(ary, indices_or_sections, axis=axis) + + +@set_module('mxnet.numpy') +def concatenate(seq, axis=0, out=None): + """Join a sequence of arrays along an existing axis. + Parameters + ---------- + a1, a2, ... : sequence of array_like + The arrays must have the same shape, except in the dimension + corresponding to `axis` (the first, by default). + axis : int, optional + The axis along which the arrays will be joined. If axis is None, + arrays are flattened before use. Default is 0. + out : ndarray, optional + If provided, the destination to place the result. The shape must be + correct, matching that of what concatenate would have returned if no + out argument were specified. + Returns + ------- + res : ndarray + The concatenated array. + """ + return _mx_nd_np.concatenate(seq, axis=axis, out=out) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index bf4a6d159363..a6699d60871a 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -30,7 +30,7 @@ from . import _internal as _npi __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange', 'split'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate'] def _num_outputs(sym): @@ -1312,4 +1312,27 @@ def split(ary, indices_or_sections, axis=0): return ret +@set_module('mxnet.symbol.numpy') +def concatenate(seq, axis=0, out=None): + """Join a sequence of arrays along an existing axis. + Parameters + ---------- + a1, a2, ... : sequence of array_like + The arrays must have the same shape, except in the dimension + corresponding to `axis` (the first, by default). + axis : int, optional + The axis along which the arrays will be joined. If axis is None, + arrays are flattened before use. Default is 0. + out : ndarray, optional + If provided, the destination to place the result. The shape must be + correct, matching that of what concatenate would have returned if no + out argument were specified. + Returns + ------- + res : ndarray + The concatenated array. + """ + return _npi.concatenate(*seq, dim=axis, out=out) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 44a6c909c9cf..6d3d9ea5ec85 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -27,6 +27,7 @@ #include #include "../tensor/matrix_op-inl.h" +#include "../nn/concat-inl.h" namespace mxnet { namespace op { diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 1f31a650e771..73340981037d 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -24,7 +24,6 @@ */ #include "./np_matrix_op-inl.h" -#include "../nn/concat-inl.h" namespace mxnet { namespace op { @@ -248,5 +247,62 @@ NNVM_REGISTER_OP(_np_squeeze) .add_argument("a", "NDArray-or-Symbol[]", "data to squeeze") .add_arguments(SqueezeParam::__FIELDS__()); +bool ConcatShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_shape, + mxnet::ShapeVector *out_shape); + +bool ConcatType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type); + +struct NumpyConcatGrad { + const char *op_name; + std::vector operator()(const nnvm::NodePtr& n, + const std::vector& ograds) const { + CHECK_EQ(ograds.size(), 1); + std::vector heads(ograds.begin(), ograds.end()); + return MakeGradNode(op_name, n, heads, n->attrs.dict); + } +}; + + +NNVM_REGISTER_OP(_npi_concatenate) +.describe(R"code(Join a sequence of arrays along an existing axis.)code" ADD_FILELINE) +.set_num_inputs([](const NodeAttrs& attrs) { + const ConcatParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const ConcatParam& params = nnvm::get(attrs.parsed); + std::vector ret; + for (int i = 0; i < params.num_args; ++i) { + ret.push_back(std::string("data") + std::to_string(i)); + } + return ret; +}) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"out"}; +}) +.set_attr("key_var_num_args", "num_args") +.set_attr("FInferType", ConcatType) +.set_attr("FInferShape", ConcatShape) +.set_attr("FCompute", ConcatCompute) +.set_attr("FGradient", NumpyConcatGrad{"_backward_np_concat"}) +.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate") +.add_arguments(ConcatParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_np_concat) +.set_num_outputs([](const NodeAttrs& attrs) { + const ConcatParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_attr("FCompute", ConcatGradCompute); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index efeba8e089af..f192560f4ac9 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -23,7 +23,6 @@ * \brief GPU Implementation of numpy matrix operations */ #include "./np_matrix_op-inl.h" -#include "../nn/concat-inl.h" namespace mxnet { namespace op { @@ -37,5 +36,11 @@ NNVM_REGISTER_OP(_np_reshape) NNVM_REGISTER_OP(_np_squeeze) .set_attr("FCompute", UnaryOp::IdentityCompute); +NNVM_REGISTER_OP(_npi_concatenate) +.set_attr("FCompute", ConcatCompute); + +NNVM_REGISTER_OP(_backward_np_concat) +.set_attr("FCompute", ConcatGradCompute); + } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index aba262f88042..2291bcdb6d3d 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -829,6 +829,57 @@ def get_indices(axis_size): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_np_concat(): + class TestConcat(HybridBlock): + def __init__(self, axis=None): + super(TestConcat, self).__init__() + self._axis = axis + + def hybrid_forward(self, F, a, *args): + return F.np.concatenate([a] + list(args), axis=self._axis) + + def get_new_shape(shape, axis): + shape_lst = list(shape) + shape_lst[axis] = random.randint(0, 3) + return tuple(shape_lst) + + for shape in [(0, 0), (2, 3)]: + for hybridize in [True, False]: + for axis in range(2): + # test gluon + test_concat = TestConcat(axis=axis) + if hybridize: + test_concat.hybridize() + + a = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() + a.attach_grad() + b = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() + b.attach_grad() + c = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() + c.attach_grad() + d = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray() + d.attach_grad() + expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis) + with mx.autograd.record(): + y = test_concat(a, b, c, d) + assert y.shape == expected_ret.shape + assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5) + + y.backward() + + assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5) + assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5) + + # test imperative + mx_out = np.concatenate([a, b, c, d], axis=axis) + np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + if __name__ == '__main__': import nose nose.runmodule()