This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add OP diag [numpy] (#16786)
* Add numpy.diag method

* Fix ci diag problem

* Merge and fix ci Diag problem

* Fix format issue

* Modified Diag method

* Fixed the Diag kernel and name

* Fixed sanity problem

* change K param

* change unittest and fix

* Fix diag code format

* update 3rd party
Tommliu authored and haojin2 committed Nov 18, 2019
1 parent 94e7ba7 commit 135c42c
Showing 7 changed files with 362 additions and 0 deletions.
35 changes: 35 additions & 0 deletions python/mxnet/_numpy_op_doc.py
@@ -1087,3 +1087,38 @@ def _npx_reshape(a, newshape, reverse=False, order='C'):
    (8, 3, 2, 4, 4, 2)
    """
    pass


def _np_diag(array, k=0):
    """
    Extract a diagonal or construct a diagonal array.

    - 1-D input: constructs a 2-D array with the input on its k-th diagonal
      and zeros everywhere else.
    - 2-D input: extracts the k-th diagonal.

    Parameters
    ----------
    array : ndarray
        Input array from which to extract a diagonal, or the values to place
        on the diagonal of the constructed array.
    k : int, optional
        Diagonal offset. Defaults to 0 (the main diagonal). Use k > 0 for
        diagonals above the main diagonal and k < 0 for diagonals below it.

    Returns
    -------
    out : ndarray
        The extracted diagonal or the constructed diagonal array.

    Examples
    --------
    >>> x = np.arange(9).reshape((3,3))
    >>> x
    array([[0, 1, 2],
           [3, 4, 5],
           [6, 7, 8]])
    >>> np.diag(x)
    array([0, 4, 8])
    >>> np.diag(x, k=1)
    array([1, 5])
    >>> np.diag(x, k=-1)
    array([3, 7])
    >>> np.diag(np.diag(x))
    array([[0, 0, 0],
           [0, 4, 0],
           [0, 0, 8]])
    """
    pass
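A quick sketch of the 1-D construction path with a nonzero offset, which the examples above do not cover. It uses plain NumPy as the reference; mxnet.numpy.diag is intended to match this behavior, so treat it as illustrative rather than authoritative:

import numpy as onp  # reference NumPy; the new op is meant to mirror it

v = onp.array([1, 2, 3])
m = onp.diag(v, k=1)      # 1-D input of length n -> (n + |k|, n + |k|) output
print(m)
# [[0 1 0 0]
#  [0 0 2 0]
#  [0 0 0 3]
#  [0 0 0 0]]
assert onp.array_equal(onp.diag(m, k=1), v)  # extracting the same diagonal round-trips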
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
@@ -93,6 +93,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
    'concatenate',
    'copy',
    'cumsum',
    'diag',
    'dot',
    'expand_dims',
    'fix',
205 changes: 205 additions & 0 deletions src/operator/numpy/np_matrix_op-inl.h
@@ -28,9 +28,14 @@
#include <vector>
#include <algorithm>
#include <string>
#include <utility>
#include "../tensor/matrix_op-inl.h"
#include "../nn/concat-inl.h"
#include "../../common/utils.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "../tensor/broadcast_reduce_op.h"

namespace mxnet {
namespace op {
@@ -945,6 +950,206 @@ void NumpyConcatenateBackward(const nnvm::NodeAttrs& attrs,
  });
}

struct NumpyDiagParam : public dmlc::Parameter<NumpyDiagParam> {
  int k;
  DMLC_DECLARE_PARAMETER(NumpyDiagParam) {
    DMLC_DECLARE_FIELD(k).set_default(0)
      .describe("Diagonal in question. The default is 0. "
                "Use k>0 for diagonals above the main diagonal, "
                "and k<0 for diagonals below the main diagonal. ");
  }
};

inline mxnet::TShape NumpyDiagShapeImpl(const mxnet::TShape &ishape,
                                        const int k) {
  CHECK_LE(ishape.ndim(), 2) << "Input must be 1- or 2-d";

  if (ishape.ndim() == 1) {
    auto s = ishape[0] + std::abs(k);
    return mxnet::TShape({s, s});
  }

  auto h = ishape[0];
  auto w = ishape[1];

  if (k > 0) {
    w -= k;
  } else if (k < 0) {
    h += k;
  }
  dim_t a = 0;
  auto s = std::max(std::min(h, w), a);
  // s is the length of diagonal with k as the offset

  int32_t n_dim = ishape.ndim() - 1;
  mxnet::TShape oshape(n_dim, -1);
  oshape[n_dim - 1] = s;
  return oshape;
}
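As a cross-check of the shape rule above, a small NumPy-only sketch (not part of the patch; the helper name is made up here) that computes the same output shape and compares it against numpy.diag:

import numpy as np

def diag_out_shape(ishape, k=0):
    # mirrors NumpyDiagShapeImpl: 1-D input -> square matrix, 2-D input -> diagonal length
    if len(ishape) == 1:
        s = ishape[0] + abs(k)
        return (s, s)
    h, w = ishape
    if k > 0:
        w -= k
    elif k < 0:
        h += k
    return (max(min(h, w), 0),)

for shape, k in [((5,), 2), ((3, 4), 0), ((3, 4), 2), ((3, 4), -2), ((3, 4), -5)]:
    assert diag_out_shape(shape, k) == np.diag(np.zeros(shape), k=k).shape, (shape, k)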

inline bool NumpyDiagOpShape(const nnvm::NodeAttrs &attrs,
                             mxnet::ShapeVector *in_attrs,
                             mxnet::ShapeVector *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);

  const mxnet::TShape &ishape = (*in_attrs)[0];
  if (!mxnet::ndim_is_known(ishape)) {
    return false;
  }

  const NumpyDiagParam &param = nnvm::get<NumpyDiagParam>(attrs.parsed);
  mxnet::TShape oshape = NumpyDiagShapeImpl(ishape, param.k);

  if (shape_is_none(oshape)) {
    LOG(FATAL) << "Diagonal does not exist.";
  }
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);

  return shape_is_known(out_attrs->at(0));
}

inline bool NumpyDiagOpType(const nnvm::NodeAttrs &attrs,
                            std::vector<int> *in_attrs,
                            std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);

  TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
  TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]);
  return (*out_attrs)[0] != -1;
}

// Kernel for the 2-D case: i indexes the extracted diagonal, and
// offset + stride * i is the matching flat index in the 2-D tensor.
// With back = true the roles are swapped, scattering gradients back onto the diagonal.
template <int ndim, int req, bool back>
struct diag {
  template <typename DType>
  MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *a,
                                  index_t stride, index_t offset) {
    using namespace mxnet_op;
    index_t j = offset + stride * i;

    if (back) {
      KERNEL_ASSIGN(out[j], req, a[i]);
    } else {
      KERNEL_ASSIGN(out[i], req, a[j]);
    }
  }
};

// Kernel for the 1-D case: builds a 2-D array with the input on the k-th
// diagonal (zeros elsewhere); with back = true it instead gathers the k-th
// diagonal of the output gradient.
template <int req, bool back>
struct diag_gen {
  template <typename DType>
  MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *a,
                                  mshadow::Shape<2> oshape, int k) {
    using namespace mxnet_op;

    auto j = unravel(i, oshape);
    if (j[1] == (j[0] + k)) {
      auto l = j[0] < j[1] ? j[0] : j[1];
      if (back) {
        KERNEL_ASSIGN(out[l], req, a[i]);
      } else {
        KERNEL_ASSIGN(out[i], req, a[l]);
      }
    } else if (!back) {
      KERNEL_ASSIGN(out[i], req, static_cast<DType>(0));
    }
  }
};

template <typename xpu, bool back>
void NumpyDiagOpImpl(const TBlob &in_data,
                     const TBlob &out_data,
                     const mxnet::TShape &ishape,
                     const mxnet::TShape &oshape,
                     index_t dsize,
                     const int &k,
                     mxnet_op::Stream<xpu> *s,
                     const OpReqType &req) {
  using namespace mxnet_op;
  using namespace mshadow;
  if (ishape.ndim() > 1) {
    index_t stride1 = ishape[1], stride2 = 1;
    // stride1 + stride2 is the stride for
    // iterating over the diagonal in question

    // the extra index offset introduced by k
    index_t offset;
    if (k > 0) {
      offset = stride2 * k;
    } else if (k < 0) {
      offset = stride1 * -k;
    } else {
      offset = 0;
    }

    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
        if (back && req != kAddTo && req != kNullOp) {
          out_data.FlatTo1D<xpu, DType>(s) = 0;
        }

        Kernel<diag<2, req_type, back>, xpu>::Launch(
            s, dsize, out_data.dptr<DType>(), in_data.dptr<DType>(),
            stride1 + stride2, offset);
      });
    });
  } else {
    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
        Kernel<diag_gen<req_type, back>, xpu>::Launch(
            s, dsize, out_data.dptr<DType>(), in_data.dptr<DType>(),
            Shape2(oshape[0], oshape[1]), k);
      });
    });
  }
}
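The stride/offset arithmetic above can be sanity-checked in a few lines of plain NumPy (a sketch only, not part of the patch): in a row-major (h, w) matrix, consecutive elements of any diagonal are w + 1 apart in flat memory, and k merely shifts the starting position.

import numpy as np

def diag_via_flat_index(a, k=0):
    # extract the k-th diagonal of a 2-D array with the same flat indexing as the kernel
    h, w = a.shape
    stride = w + 1                            # stride1 + stride2 in the C++ code
    offset = k if k > 0 else -k * w           # extra offset introduced by k
    length = max((min(h, w - k) if k >= 0 else min(h + k, w)), 0)
    flat = a.reshape(-1)
    return np.array([flat[offset + stride * i] for i in range(length)])

a = np.arange(12).reshape(3, 4)
for k in range(-2, 4):
    assert np.array_equal(diag_via_flat_index(a, k), np.diag(a, k)), k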

template <typename xpu>
void NumpyDiagOpForward(const nnvm::NodeAttrs &attrs,
                        const OpContext &ctx,
                        const std::vector<TBlob> &inputs,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &outputs) {
  using namespace mxnet_op;
  using namespace mshadow;
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  CHECK_EQ(req[0], kWriteTo);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  const TBlob &in_data = inputs[0];
  const TBlob &out_data = outputs[0];
  const mxnet::TShape &ishape = inputs[0].shape_;
  const mxnet::TShape &oshape = outputs[0].shape_;
  const NumpyDiagParam &param = nnvm::get<NumpyDiagParam>(attrs.parsed);

  NumpyDiagOpImpl<xpu, false>(in_data, out_data, ishape, oshape,
                              out_data.Size(), param.k, s, req[0]);
}

template <typename xpu>
void NumpyDiagOpBackward(const nnvm::NodeAttrs &attrs,
                         const OpContext &ctx,
                         const std::vector<TBlob> &inputs,
                         const std::vector<OpReqType> &req,
                         const std::vector<TBlob> &outputs) {
  using namespace mxnet_op;
  using namespace mshadow;
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  Stream<xpu> *s = ctx.get_stream<xpu>();

  const TBlob &in_data = inputs[0];
  const TBlob &out_data = outputs[0];
  const mxnet::TShape &ishape = inputs[0].shape_;
  const mxnet::TShape &oshape = outputs[0].shape_;
  const NumpyDiagParam &param = nnvm::get<NumpyDiagParam>(attrs.parsed);

  // inputs[0] is the output gradient, so the forward input/output shapes are swapped here
  NumpyDiagOpImpl<xpu, true>(in_data, out_data, oshape, ishape,
                             in_data.Size(), param.k, s, req[0]);
}
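For intuition about what the backward pass computes (the same kernels with back = true and the forward shapes swapped), a hedged NumPy sketch of the expected gradients; the helper names are illustrative only:

import numpy as np

def diag_extract_grad(ograd, in_shape, k=0):
    # gradient of 2-D extraction: scatter the incoming gradient back onto the k-th diagonal
    igrad = np.zeros(in_shape, dtype=ograd.dtype)
    rows = np.arange(len(ograd)) + max(-k, 0)
    cols = np.arange(len(ograd)) + max(k, 0)
    igrad[rows, cols] = ograd
    return igrad

def diag_construct_grad(ograd, k=0):
    # gradient of 1-D construction: gather the k-th diagonal of the incoming gradient
    return np.diag(ograd, k=k)

g = diag_extract_grad(np.array([1., 1.]), (3, 4), k=2)
assert g[0, 2] == 1. and g[1, 3] == 1. and g.sum() == 2.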

} // namespace op
} // namespace mxnet

23 changes: 23 additions & 0 deletions src/operator/numpy/np_matrix_op.cc
@@ -36,6 +36,7 @@ DMLC_REGISTER_PARAMETER(NumpyMoveaxisParam);
DMLC_REGISTER_PARAMETER(NumpyRot90Param);
DMLC_REGISTER_PARAMETER(NumpyReshapeParam);
DMLC_REGISTER_PARAMETER(NumpyXReshapeParam);
DMLC_REGISTER_PARAMETER(NumpyDiagParam);


bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
@@ -1302,5 +1303,27 @@ NNVM_REGISTER_OP(_npi_hsplit_backward)
})
.set_attr<FCompute>("FCompute<cpu>", HSplitOpBackward<cpu>);

NNVM_REGISTER_OP(_np_diag)
.set_attr_parser(ParamParser<NumpyDiagParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs &attrs) {
    return std::vector<std::string>{"data"};
  })
.set_attr<mxnet::FInferShape>("FInferShape", NumpyDiagOpShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyDiagOpType)
.set_attr<FCompute>("FCompute<cpu>", NumpyDiagOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_np_diag"})
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(NumpyDiagParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_np_diag)
.set_attr_parser(ParamParser<NumpyDiagParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", NumpyDiagOpBackward<cpu>);

} // namespace op
} // namespace mxnet
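A minimal end-to-end usage sketch for the newly registered operator from the Python frontend, assuming a build that includes this commit (exact dtypes may differ):

from mxnet import np, npx
npx.set_np()

x = np.arange(9).reshape(3, 3)
print(np.diag(x, k=1))         # first super-diagonal, expected [1., 5.]
print(np.diag(np.arange(3)))   # 3x3 matrix with [0., 1., 2.] on the main diagonal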
6 changes: 6 additions & 0 deletions src/operator/numpy/np_matrix_op.cu
@@ -118,5 +118,11 @@ NNVM_REGISTER_OP(_npi_hsplit_backward)
NNVM_REGISTER_OP(_npx_reshape)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);

NNVM_REGISTER_OP(_np_diag)
.set_attr<FCompute>("FCompute<gpu>", NumpyDiagOpForward<gpu>);

NNVM_REGISTER_OP(_backward_np_diag)
.set_attr<FCompute>("FCompute<gpu>", NumpyDiagOpBackward<gpu>);

} // namespace op
} // namespace mxnet
31 changes: 31 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
@@ -57,6 +57,35 @@ def get_workloads(name):
    return OpArgMngr._args.get(name, None)


def _add_workload_diag():
    def get_mat(n):
        data = _np.arange(n)
        data = _np.add.outer(data, data)
        return data

    A = np.array([[1, 2], [3, 4], [5, 6]])
    vals = (100 * np.arange(5)).astype('l')
    vals_c = (100 * np.array(get_mat(5)) + 1).astype('l')
    vals_f = _np.array((100 * get_mat(5) + 1), order='F', dtype='l')
    vals_f = np.array(vals_f)

    OpArgMngr.add_workload('diag', A, k=2)
    OpArgMngr.add_workload('diag', A, k=1)
    OpArgMngr.add_workload('diag', A, k=0)
    OpArgMngr.add_workload('diag', A, k=-1)
    OpArgMngr.add_workload('diag', A, k=-2)
    OpArgMngr.add_workload('diag', A, k=-3)
    OpArgMngr.add_workload('diag', vals, k=0)
    OpArgMngr.add_workload('diag', vals, k=2)
    OpArgMngr.add_workload('diag', vals, k=-2)
    OpArgMngr.add_workload('diag', vals_c, k=0)
    OpArgMngr.add_workload('diag', vals_c, k=2)
    OpArgMngr.add_workload('diag', vals_c, k=-2)
    OpArgMngr.add_workload('diag', vals_f, k=0)
    OpArgMngr.add_workload('diag', vals_f, k=2)
    OpArgMngr.add_workload('diag', vals_f, k=-2)

def _add_workload_concatenate(array_pool):
    OpArgMngr.add_workload('concatenate', [array_pool['4x1'], array_pool['4x1']])
    OpArgMngr.add_workload('concatenate', [array_pool['4x1'], array_pool['4x1']], axis=1)
@@ -89,6 +118,7 @@ def _add_workload_concatenate(array_pool):


def _add_workload_append():

    def get_new_shape(shape, axis):
        shape_lst = list(shape)
        if axis is not None:
@@ -1226,6 +1256,7 @@ def _prepare_workloads():
    _add_workload_copy()
    _add_workload_cumsum()
    _add_workload_ravel()
    _add_workload_diag()
    _add_workload_dot()
    _add_workload_expand_dims()
    _add_workload_fix()
