diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
index 33158baf10a5..9e237cb0049b 100644
--- a/python/mxnet/_numpy_op_doc.py
+++ b/python/mxnet/_numpy_op_doc.py
@@ -1087,3 +1087,38 @@ def _npx_reshape(a, newshape, reverse=False, order='C'):
     (8, 3, 2, 4, 4, 2)
     """
     pass
+
+
+def _np_diag(array, k=0):
+    """
+    Extracts a diagonal or constructs a diagonal array.
+    - 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero.
+    - 2-D arrays: extracts the k-th diagonal.
+
+    Parameters
+    ----------
+    array : ndarray
+        The array to apply diag method.
+    k : int, optional
+        Diagonal offset: extracts or constructs the k-th diagonal of/from the input array. Default is 0.
+
+    Examples
+    --------
+    >>> x = np.arange(9).reshape((3,3))
+    >>> x
+    array([[0, 1, 2],
+           [3, 4, 5],
+           [6, 7, 8]])
+    >>> np.diag(x)
+    array([0, 4, 8])
+    >>> np.diag(x, k=1)
+    array([1, 5])
+    >>> np.diag(x, k=-1)
+    array([3, 7])
+
+    >>> np.diag(np.diag(x))
+    array([[0, 0, 0],
+           [0, 4, 0],
+           [0, 0, 8]])
+    """
+    pass
diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py
index ef9ed6cc363f..e233c9702c11 100644
--- a/python/mxnet/numpy_dispatch_protocol.py
+++ b/python/mxnet/numpy_dispatch_protocol.py
@@ -93,6 +93,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
     'concatenate',
     'copy',
     'cumsum',
+    'diag',
     'dot',
     'expand_dims',
     'fix',
diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h
index a9828f40436d..508968718af0 100644
--- a/src/operator/numpy/np_matrix_op-inl.h
+++ b/src/operator/numpy/np_matrix_op-inl.h
@@ -28,9 +28,14 @@
 #include 
 #include 
 #include 
+#include 
 #include "../tensor/matrix_op-inl.h"
 #include "../nn/concat-inl.h"
 #include "../../common/utils.h"
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+#include "../tensor/broadcast_reduce_op.h"
 
 namespace mxnet {
 namespace op {
@@ -945,6 +950,206 @@ void NumpyConcatenateBackward(const nnvm::NodeAttrs& attrs,
   });
 }
 
+struct NumpyDiagParam : public dmlc::Parameter<NumpyDiagParam> {
+  int k;
+  DMLC_DECLARE_PARAMETER(NumpyDiagParam) {
+    DMLC_DECLARE_FIELD(k).set_default(0)
+    .describe("Diagonal in question. The default is 0. "
+              "Use k>0 for diagonals above the main diagonal, "
+              "and k<0 for diagonals below the main diagonal. ");
+  }
+};
+
+inline mxnet::TShape NumpyDiagShapeImpl(const mxnet::TShape &ishape,
+                                        const int k) {
+  CHECK_LE(ishape.ndim(), 2) << "Input must be 1- or 2-d";
+
+  if (ishape.ndim() == 1) {
+    auto s = ishape[0] + std::abs(k);
+    return mxnet::TShape({s, s});
+  }
+
+  auto h = ishape[0];
+  auto w = ishape[1];
+
+  if (k > 0) {
+    w -= k;
+  } else if (k < 0) {
+    h += k;
+  }
+  dim_t a = 0;
+  auto s = std::max(std::min(h, w), a);
+  // s is the length of diagonal with k as the offset
+
+  int32_t n_dim = ishape.ndim() - 1;
+  mxnet::TShape oshape(n_dim, -1);
+  oshape[n_dim - 1] = s;
+  return oshape;
+}
+
+inline bool NumpyDiagOpShape(const nnvm::NodeAttrs &attrs,
+                             mxnet::ShapeVector *in_attrs,
+                             mxnet::ShapeVector *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  const mxnet::TShape &ishape = (*in_attrs)[0];
+  if (!mxnet::ndim_is_known(ishape)) {
+    return false;
+  }
+
+  const NumpyDiagParam &param = nnvm::get<NumpyDiagParam>(attrs.parsed);
+  mxnet::TShape oshape = NumpyDiagShapeImpl(ishape, param.k);
+
+  if (shape_is_none(oshape)) {
+    LOG(FATAL) << "Diagonal does not exist.";
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
+
+  return shape_is_known(out_attrs->at(0));
+}
+
+inline bool NumpyDiagOpType(const nnvm::NodeAttrs &attrs,
+                            std::vector<int> *in_attrs,
+                            std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]);
+  return (*out_attrs)[0] != -1;
+}
+
+template <int req, bool back>
+struct diag {
+  template <typename DType>
+  MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *a,
+                                  index_t stride, index_t offset) {
+    using namespace mxnet_op;
+
+    index_t j = offset + stride * i;
+
+    if (back) {
+      KERNEL_ASSIGN(out[j], req, a[i]);
+    } else {
+      KERNEL_ASSIGN(out[i], req, a[j]);
+    }
+  }
+};
+
+template <int req, bool back>
+struct diag_gen {
+  template <typename DType>
+  MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *a,
+                                  mshadow::Shape<2> oshape, int k) {
+    using namespace mxnet_op;
+
+    auto j = unravel(i, oshape);
+    if (j[1] == (j[0] + k)) {
+      auto l = j[0] < j[1] ? j[0] : j[1];
+      if (back) {
+        KERNEL_ASSIGN(out[l], req, a[i]);
+      } else {
+        KERNEL_ASSIGN(out[i], req, a[l]);
+      }
+    } else if (!back) {
+      KERNEL_ASSIGN(out[i], req, static_cast<DType>(0));
+    }
+  }
+};
+
+template <typename xpu, bool back>
+void NumpyDiagOpImpl(const TBlob &in_data,
+                     const TBlob &out_data,
+                     const mxnet::TShape &ishape,
+                     const mxnet::TShape &oshape,
+                     index_t dsize,
+                     const int &k,
+                     mxnet_op::Stream<xpu> *s,
+                     const OpReqType &req) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  if (ishape.ndim() > 1) {
+    index_t stride1 = ishape[1], stride2 = 1;
+    // stride1 + stride2 is the stride for
+    // iterating over the diagonal in question
+
+    // the extra index offset introduced by k
+    index_t offset;
+    if (k > 0) {
+      offset = stride2 * k;
+    } else if (k < 0) {
+      offset = stride1 * -k;
+    } else {
+      offset = 0;
+    }
+
+    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        if (back && req != kAddTo && req != kNullOp) {
+          out_data.FlatTo1D<xpu, DType>(s) = 0;
+        }
+
+        Kernel<diag<req_type, back>, xpu>::Launch(
+            s, dsize, out_data.dptr<DType>(), in_data.dptr<DType>(),
+            stride1 + stride2, offset);
+      });
+    });
+  } else {
+    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        Kernel<diag_gen<req_type, back>, xpu>::Launch(
+            s, dsize, out_data.dptr<DType>(), in_data.dptr<DType>(),
+            Shape2(oshape[0], oshape[1]), k);
+      });
+    });
+  }
+}
+
+template <typename xpu>
+void NumpyDiagOpForward(const nnvm::NodeAttrs &attrs,
+                        const OpContext &ctx,
+                        const std::vector<TBlob> &inputs,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_EQ(req[0], kWriteTo);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob &in_data = inputs[0];
+  const TBlob &out_data = outputs[0];
+  const mxnet::TShape &ishape = inputs[0].shape_;
+  const mxnet::TShape &oshape = outputs[0].shape_;
+  const NumpyDiagParam &param = nnvm::get<NumpyDiagParam>(attrs.parsed);
+
+  NumpyDiagOpImpl<xpu, false>(in_data, out_data, ishape, oshape,
+                              out_data.Size(), param.k, s, req[0]);
+}
+
+template <typename xpu>
+void NumpyDiagOpBackward(const nnvm::NodeAttrs &attrs,
+                         const OpContext &ctx,
+                         const std::vector<TBlob> &inputs,
+                         const std::vector<OpReqType> &req,
+                         const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  const TBlob &in_data = inputs[0];
+  const TBlob &out_data = outputs[0];
+  const mxnet::TShape &ishape = inputs[0].shape_;
+  const mxnet::TShape &oshape = outputs[0].shape_;
+  const NumpyDiagParam &param = nnvm::get<NumpyDiagParam>(attrs.parsed);
+
+  NumpyDiagOpImpl<xpu, true>(in_data, out_data, oshape, ishape,
+                             in_data.Size(), param.k, s, req[0]);
+}
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index 3967cde91d2a..912b32c2e8fb 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -36,6 +36,7 @@ DMLC_REGISTER_PARAMETER(NumpyMoveaxisParam);
 DMLC_REGISTER_PARAMETER(NumpyRot90Param);
 DMLC_REGISTER_PARAMETER(NumpyReshapeParam);
 DMLC_REGISTER_PARAMETER(NumpyXReshapeParam);
+DMLC_REGISTER_PARAMETER(NumpyDiagParam);
 
 
 bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
@@ -1302,5 +1303,27 @@ NNVM_REGISTER_OP(_npi_hsplit_backward)
 })
 .set_attr<FCompute>("FCompute", HSplitOpBackward<cpu>);
 
+NNVM_REGISTER_OP(_np_diag)
+.set_attr_parser(ParamParser<NumpyDiagParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs &attrs) {
+    return std::vector<std::string>{"data"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyDiagOpShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpyDiagOpType)
+.set_attr<FCompute>("FCompute", NumpyDiagOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_np_diag"})
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_arguments(NumpyDiagParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_np_diag)
+.set_attr_parser(ParamParser<NumpyDiagParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute", NumpyDiagOpBackward<cpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index 7ca205565413..33f5aab7717c 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -118,5 +118,11 @@ NNVM_REGISTER_OP(_npi_hsplit_backward)
 NNVM_REGISTER_OP(_npx_reshape)
 .set_attr<FCompute>("FCompute", UnaryOp::IdentityCompute<gpu>);
 
+NNVM_REGISTER_OP(_np_diag)
+.set_attr<FCompute>("FCompute", NumpyDiagOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_np_diag)
+.set_attr<FCompute>("FCompute", NumpyDiagOpBackward<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index 79a478679aaf..7584564bf387 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -57,6 +57,35 @@ def get_workloads(name):
     return OpArgMngr._args.get(name, None)
 
 
+def _add_workload_diag():
+    def get_mat(n):
+        data = _np.arange(n)
+        data = _np.add.outer(data, data)
+        return data
+
+    A = np.array([[1, 2], [3, 4], [5, 6]])
+    vals = (100 * np.arange(5)).astype('l')
+    vals_c = (100 * np.array(get_mat(5)) + 1).astype('l')
+    vals_f = _np.array((100 * get_mat(5) + 1), order='F', dtype='l')
+    vals_f = np.array(vals_f)
+
+    OpArgMngr.add_workload('diag', A, k=2)
+    OpArgMngr.add_workload('diag', A, k=1)
+    OpArgMngr.add_workload('diag', A, k=0)
+    OpArgMngr.add_workload('diag', A, k=-1)
+    OpArgMngr.add_workload('diag', A, k=-2)
+    OpArgMngr.add_workload('diag', A, k=-3)
+    OpArgMngr.add_workload('diag', vals, k=0)
+    OpArgMngr.add_workload('diag', vals, k=2)
+    OpArgMngr.add_workload('diag', vals, k=-2)
+    OpArgMngr.add_workload('diag', vals_c, k=0)
+    OpArgMngr.add_workload('diag', vals_c, k=2)
+    OpArgMngr.add_workload('diag', vals_c, k=-2)
+    OpArgMngr.add_workload('diag', vals_f, k=0)
+    OpArgMngr.add_workload('diag', vals_f, k=2)
+    OpArgMngr.add_workload('diag', vals_f, k=-2)
+
+
 def _add_workload_concatenate(array_pool):
     OpArgMngr.add_workload('concatenate', [array_pool['4x1'], array_pool['4x1']])
     OpArgMngr.add_workload('concatenate', [array_pool['4x1'], array_pool['4x1']], axis=1)
@@ -89,6 +118,7 @@ def _add_workload_concatenate(array_pool):
 
 
 def _add_workload_append():
+
     def get_new_shape(shape, axis):
         shape_lst = list(shape)
         if axis is not None:
@@ -1226,6 +1256,7 @@ def _prepare_workloads():
     _add_workload_copy()
     _add_workload_cumsum()
     _add_workload_ravel()
+    _add_workload_diag()
     _add_workload_dot()
     _add_workload_expand_dims()
     _add_workload_fix()
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 76b38eb45205..1e1b53e88c92 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -4529,6 +4529,67 @@ def hybrid_forward(self, F, x, *args, **kwargs):
         assert_almost_equal(ret.asnumpy(), expected_ret, atol=1e-5, rtol=1e-5, use_broadcast=False)
 
 
+@with_seed()
+@use_np
+def test_np_diag():
+    class TestDiag(HybridBlock):
+        def __init__(self, k=0):
+            super(TestDiag, self).__init__()
+            self._k = k
+
+        def hybrid_forward(self, F, a):
+            return F.np.diag(a, k=self._k)
+
+    shapes = [(), (2,), (1, 5), (2, 2), (2, 5), (3, 3), (4, 3)]
+    dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64]
+    range_k = 6
+    combination = itertools.product([False, True], shapes, dtypes, list(range(-range_k, range_k)))
+    for hybridize, shape, dtype, k in combination:
+        rtol = 1e-2 if dtype == np.float16 else 1e-3
+        atol = 1e-4 if dtype == np.float16 else 1e-5
+        test_diag = TestDiag(k)
+        if hybridize:
+            test_diag.hybridize()
+
+        x = np.random.uniform(-2.0, 2.0, size=shape).astype(dtype) if len(shape) != 0 else np.array(())
+        x.attach_grad()
+
+        np_out = _np.diag(x.asnumpy(), k)
+        with mx.autograd.record():
+            mx_out = test_diag(x)
+        assert mx_out.shape == np_out.shape
+        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
+
+        # check backward function
+        mx_out.backward()
+        if len(shape) == 0:
+            np_backward = np.array(())
+        elif len(shape) == 1:
+            np_backward = np.ones(shape[0])
+        else:
+            np_backward = np.zeros(shape)
+            h = shape[0]
+            w = shape[1]
+            if k > 0:
+                w -= k
+            else:
+                h += k
+            s = min(w, h)
+            if s > 0:
+                if k >= 0:
+                    for i in range(s):
+                        np_backward[0+i][k+i] = 1
+                else:
+                    for i in range(s):
+                        np_backward[-k+i][0+i] = 1
+        assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol)
+
+        # Test imperative once again
+        mx_out = np.diag(x, k)
+        np_out = _np.diag(x.asnumpy(), k)
+        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
+
+
 @with_seed()
 @use_np
 def test_np_nan_to_num():