
numpy-compatible concatenate upstream
haojin2 committed Aug 12, 2019
1 parent c3f5eea commit 18f8d8e
Showing 6 changed files with 251 additions and 3 deletions.
26 changes: 25 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -25,7 +25,8 @@
from ...context import current_context
from . import _internal as _npi

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot']
__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot',
           'concatenate']


@set_module('mxnet.ndarray.numpy')
@@ -364,3 +365,26 @@ def tensordot(a, b, axes=2):
raise ValueError('Axes length mismatch')

return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)


@set_module('mxnet.ndarray.numpy')
def concatenate(seq, axis=0, out=None):
"""Join a sequence of arrays along an existing axis.
Parameters
----------
a1, a2, ... : sequence of array_like
The arrays must have the same shape, except in the dimension
corresponding to `axis` (the first, by default).
axis : int, optional
The axis along which the arrays will be joined. If axis is None,
arrays are flattened before use. Default is 0.
out : ndarray, optional
If provided, the destination to place the result. The shape must be
correct, matching that of what concatenate would have returned if no
out argument were specified.
Returns
-------
res : ndarray
The concatenated array.
"""
return _npi.concatenate(*seq, dim=axis, out=out)
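
A minimal imperative sketch of the new ndarray frontend (an illustration, not part of the diff), assuming a build that includes this change; here `np` is `mxnet.numpy` and `_np` is the reference NumPy:

    import numpy as _np
    from mxnet import np

    a = np.array([[1., 2.], [3., 4.]])
    b = np.array([[5., 6.]])

    out = np.concatenate([a, b], axis=0)   # join along rows -> shape (3, 2)
    ref = _np.concatenate([a.asnumpy(), b.asnumpy()], axis=0)
    assert out.shape == (3, 2)
    assert _np.allclose(out.asnumpy(), ref)
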
25 changes: 24 additions & 1 deletion python/mxnet/numpy/multiarray.py
@@ -44,7 +44,7 @@
from ..ndarray.numpy import _internal as _npi

__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide',
           'mod', 'power', 'tensordot']
           'mod', 'power', 'tensordot', 'concatenate']


# This function is copied from ndarray.py since pylint
@@ -1606,3 +1606,26 @@ def tensordot(a, b, axes=2):
[ 4928., 5306.]])
"""
return _mx_nd_np.tensordot(a, b, axes)


@set_module('mxnet.numpy')
def concatenate(seq, axis=0, out=None):
"""Join a sequence of arrays along an existing axis.
Parameters
----------
a1, a2, ... : sequence of array_like
The arrays must have the same shape, except in the dimension
corresponding to `axis` (the first, by default).
axis : int, optional
The axis along which the arrays will be joined. If axis is None,
arrays are flattened before use. Default is 0.
out : ndarray, optional
If provided, the destination to place the result. The shape must be
correct, matching that of what concatenate would have returned if no
out argument were specified.
Returns
-------
res : ndarray
The concatenated array.
"""
return _mx_nd_np.concatenate(seq, axis=axis, out=out)
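
The `mxnet.numpy` entry point simply forwards to the ndarray implementation above. A further hedged sketch, also exercising `axis=1` and the forwarded `out=` argument; it assumes the backend fills a pre-allocated destination the way other imperative operators do, which the test below does not cover:

    from mxnet import np

    x = np.array([[1., 2.], [3., 4.]])
    y = np.array([[5.], [6.]])

    z = np.concatenate([x, y], axis=1)       # shape (2, 3)
    assert z.shape == (2, 3)

    buf = np.zeros((2, 3))                   # pre-allocated destination (assumed to be honored)
    np.concatenate([x, y], axis=1, out=buf)  # result written into buf
    assert buf.shape == (2, 3)
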
26 changes: 25 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -28,7 +28,8 @@
from .._internal import _set_np_symbol_class
from . import _internal as _npi

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot']
__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot',
           'concatenate']


def _num_outputs(sym):
@@ -1065,4 +1066,27 @@ def tensordot(a, b, axes=2):
return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)


@set_module('mxnet.symbol.numpy')
def concatenate(seq, axis=0, out=None):
"""Join a sequence of arrays along an existing axis.
Parameters
----------
a1, a2, ... : sequence of array_like
The arrays must have the same shape, except in the dimension
corresponding to `axis` (the first, by default).
axis : int, optional
The axis along which the arrays will be joined. If axis is None,
arrays are flattened before use. Default is 0.
out : ndarray, optional
If provided, the destination to place the result. The shape must be
correct, matching that of what concatenate would have returned if no
out argument were specified.
Returns
-------
res : ndarray
The concatenated array.
"""
return _npi.concatenate(*seq, dim=axis, out=out)


_set_np_symbol_class(_Symbol)
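
Inside a hybridized Gluon block, `F` resolves to `mxnet.symbol.numpy`, so `F.np.concatenate` builds the `_npi_concatenate` node registered in np_matrix_op.cc below. A small sketch of that path (an illustration, not part of the diff); `use_np` is the same decorator the unit test applies, assumed importable from `mxnet.util` on this branch:

    from mxnet import np
    from mxnet.gluon import HybridBlock
    from mxnet.util import use_np   # assumption: same decorator as @use_np in the test below

    @use_np
    def demo():
        class Concat2(HybridBlock):
            def hybrid_forward(self, F, a, b):
                # F is mxnet.symbol.numpy once hybridized
                return F.np.concatenate([a, b], axis=0)

        block = Concat2()
        block.hybridize()   # forward pass now traces through mxnet.symbol.numpy
        out = block(np.ones((2, 3)), np.zeros((1, 3)))
        assert out.shape == (3, 3)

    demo()
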
89 changes: 89 additions & 0 deletions src/operator/numpy/np_matrix_op.cc
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_matrix_op.cc
* \brief CPU Implementation of numpy matrix operations
*/

#include "../nn/concat-inl.h"

namespace mxnet {
namespace op {

bool ConcatShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_shape,
mxnet::ShapeVector *out_shape);

bool ConcatType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type,
std::vector<int> *out_type);

struct NumpyConcatGrad {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) const {
CHECK_EQ(ograds.size(), 1);
std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
return MakeGradNode(op_name, n, heads, n->attrs.dict);
}
};


NNVM_REGISTER_OP(_npi_concatenate)
.describe(R"code(Join a sequence of arrays along an existing axis.)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
return params.num_args;
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
std::vector<std::string> ret;
for (int i = 0; i < params.num_args; ++i) {
ret.push_back(std::string("data") + std::to_string(i));
}
return ret;
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"out"};
})
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<nnvm::FInferType>("FInferType", ConcatType)
.set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", NumpyConcatGrad{"_backward_np_concat"})
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_np_concat)
.set_num_outputs([](const NodeAttrs& attrs) {
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
return params.num_args;
})
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);

} // namespace op
} // namespace mxnet
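
`NumpyConcatGrad` re-emits the forward node's attributes on `_backward_np_concat`, which slices the incoming output gradient back onto the inputs. A hedged sketch of the observable effect at the Python level, mirroring the autograd portion of the unit test: with no head gradient supplied, `backward()` uses ones, so each input's gradient comes back as an array of ones.

    import numpy as _np
    from mxnet import np, autograd

    a = np.array([[1., 2.], [3., 4.]])
    b = np.array([[5., 6.]])
    a.attach_grad()
    b.attach_grad()

    with autograd.record():
        y = np.concatenate([a, b], axis=0)
    y.backward()   # head gradient defaults to ones, as in the unit test

    assert _np.allclose(a.grad.asnumpy(), _np.ones(a.shape))
    assert _np.allclose(b.grad.asnumpy(), _np.ones(b.shape))
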
37 changes: 37 additions & 0 deletions src/operator/numpy/np_matrix_op.cu
@@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_matrix_op.cu
* \brief GPU Implementation of numpy matrix operations
*/
#include "../nn/concat-inl.h"

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_npi_concatenate)
.set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>);

NNVM_REGISTER_OP(_backward_np_concat)
.set_attr<FCompute>("FCompute<gpu>", ConcatGradCompute<gpu>);

} // namespace op
} // namespace mxnet
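
The GPU registrations reuse the same `ConcatCompute`/`ConcatGradCompute` kernels as the legacy Concat operator, so the numpy frontend runs on GPU simply by placing the inputs on a GPU context. A hedged sketch, assuming a CUDA build of MXNet with at least one visible device:

    import mxnet as mx
    from mxnet import np

    ctx = mx.gpu(0)
    a = mx.nd.ones((2, 3), ctx=ctx).as_np_ndarray()
    b = mx.nd.zeros((1, 3), ctx=ctx).as_np_ndarray()

    c = np.concatenate([a, b], axis=0)   # dispatches to FCompute<gpu>, i.e. ConcatCompute<gpu>
    assert c.shape == (3, 3)
    assert c.context == ctx
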
51 changes: 51 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -332,6 +332,57 @@ def hybrid_forward(self, F, a):
assert same(a.grad.asnumpy(), expected_grad)


@with_seed()
@use_np
def test_np_concat():
    class TestConcat(HybridBlock):
        def __init__(self, axis=None):
            super(TestConcat, self).__init__()
            self._axis = axis

        def hybrid_forward(self, F, a, *args):
            return F.np.concatenate([a] + list(args), axis=self._axis)

    def get_new_shape(shape, axis):
        shape_lst = list(shape)
        shape_lst[axis] = random.randint(0, 3)
        return tuple(shape_lst)

    for shape in [(0, 0), (2, 3)]:
        for hybridize in [True, False]:
            for axis in range(2):
                # test gluon
                test_concat = TestConcat(axis=axis)
                if hybridize:
                    test_concat.hybridize()

                a = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
                a.attach_grad()
                b = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
                b.attach_grad()
                c = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
                c.attach_grad()
                d = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
                d.attach_grad()
                expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
                with mx.autograd.record():
                    y = test_concat(a, b, c, d)
                assert y.shape == expected_ret.shape
                assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)

                y.backward()

                assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
                assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5)
                assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5)
                assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5)

                # test imperative
                mx_out = np.concatenate([a, b, c, d], axis=axis)
                np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
import nose
nose.runmodule()
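
With nose (the runner used by the __main__ block above), the new case can be run on its own via `nosetests -v tests/python/unittest/test_numpy_op.py:test_np_concat`, assuming the repository's Python test dependencies are installed.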
