diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index a03e15a5a836c..683f5a28b4f46 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -381,6 +381,25 @@ struct OneHotAttrs : public tvm::AttrsNode { } }; // struct OneHotAttrs +/*! \brief Attributes used in matrix_set_diag operator */ +struct MatrixSetDiagAttrs : public tvm::AttrsNode { + int k1; + int k2; + bool super_diag_right_align; + bool sub_diag_right_align; + + TVM_DECLARE_ATTRS(MatrixSetDiagAttrs, "relay.attrs.MatrixSetDiagAttrs") { + TVM_ATTR_FIELD(k1).set_default(0).describe("Lower limit (included) of the range of diagonals."); + TVM_ATTR_FIELD(k2).set_default(0).describe("Upper limit (included) of the range of diagonals."); + TVM_ATTR_FIELD(super_diag_right_align) + .set_default(true) + .describe("Bool, true iff super-diagonal is right aligned (left-padded)."); + TVM_ATTR_FIELD(sub_diag_right_align) + .set_default(false) + .describe("Bool, true iff sub-diagonal is right aligned (left-padded)."); + } +}; // struct MatrixSetDiagAttrs + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 2c0d102e35b17..e01eb703cb992 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1524,29 +1524,60 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array } /*! - * \brief Returns a tensor with the diagonal of input tensor replaced with the provided diagonal. + * \brief Returns a tensor with the diagonal of input tensor replaced with the provided diagonals. * \param input input tensor. - * \param diagonal values to be filled in the diagonal. + * \param diagonal values to be filled in the diagonals. + * \param k1 lower limit (included) of the range of diagonals. + * \param k2 upper limit (included) of the range of diagonals. + * \param super_diag_right_align bool, true iff super-diagonal is right aligned (left-padded). + * \param sub_diag_right_align bool, true iff sub-diagonal is right aligned (left-padded). * \param name output tensor name. * \param tag output tensor tag. * \return new tensor with given diagonal values. */ -inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, +inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2, + bool super_diag_right_align, bool sub_diag_right_align, const std::string name = "T_matrix_set_diag", const std::string tag = kInjective) { size_t ndim = input->shape.size() - 1; + bool only_one_diagonal = k1 == k2; + return compute( input->shape, [&](const Array& iter_vars) { auto get_diag = [&]() { Array diagonal_indices; - for (size_t i = 0; i < ndim; i++) { + PrimExpr k, offset = 0; + for (size_t i = 0; i < ndim - 1; i++) { diagonal_indices.push_back(iter_vars[i]); } + if (only_one_diagonal) { + k = k1; + } else { + // Determining which diagonal/sub-diagonal/super-diagonal it is + k = iter_vars[ndim] - iter_vars[ndim - 1]; + diagonal_indices.push_back(k2 - k); + + // Calculating the offset in diagonal tensor for this diagonal + auto get_offset = [&](PrimExpr M, PrimExpr N) { + // offset = max_diagonal_length - diagonal_length + return diagonal->shape[diagonal->shape.size() - 1] - if_then_else(M < N, M, N); + }; + offset = if_then_else( + k >= 0, + super_diag_right_align ? get_offset(input->shape[ndim] - k, input->shape[ndim - 1]) + : 0, + sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k) + : 0); + } + diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) + + offset); return diagonal(diagonal_indices); }; - return if_then_else((PrimExpr)iter_vars[ndim] == iter_vars[ndim - 1], get_diag(), + return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1, + if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2, + get_diag(), input(iter_vars)), input(iter_vars)); }, name, tag); diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 8ccd14890581d..14ac454aec646 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1178,17 +1178,33 @@ def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0 return _make.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value) -def matrix_set_diag(data, diagonal): +def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"): """ - Returns a tensor with the diagonal of input tensor replaced with the provided diagonal values. + Returns a tensor with the diagonals of input tensor replaced with the provided diagonal values. Parameters ---------- data : relay.Expr Input Tensor. + diagonal : relay.Expr Values to be filled in the diagonal. + k : int or tuple of int, optional + Diagonal Offset(s). The diagonal or range of diagonals to set. (0 by default) + Positive value means superdiagonal, 0 refers to the main diagonal, and + negative value means subdiagonals. k can be a single integer (for a single diagonal) + or a pair of integers specifying the low and high ends of a matrix band. + k[0] must not be larger than k[1]. + + align : string, optional + Some diagonals are shorter than max_diag_len and need to be padded. + align is a string specifying how superdiagonals and subdiagonals should be aligned, + respectively. There are four possible alignments: "RIGHT_LEFT" (default), "LEFT_RIGHT", + "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals to the right + (left-pads the row) and subdiagonals to the left (right-pads the row). It is the packing + format LAPACK uses. cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment. + Returns ------- result : relay.Expr @@ -1216,7 +1232,22 @@ def matrix_set_diag(data, diagonal): [7, 5, 7, 7], [7, 7, 6, 7]]] """ - return _make.matrix_set_diag(data, diagonal) + if isinstance(k, (tuple, list)): + k_one = k[0] + if len(k) >= 2: + k_two = k[1] + else: + k_two = k[0] + else: + k_one = k + k_two = k + + super_diag_right_align = align[:5] == "RIGHT" + sub_diag_right_align = align[-5:] == "RIGHT" + + return _make.matrix_set_diag( + data, diagonal, k_one, k_two, super_diag_right_align, sub_diag_right_align + ) def adv_index(inputs): diff --git a/python/tvm/topi/testing/matrix_set_diag.py b/python/tvm/topi/testing/matrix_set_diag.py index 63edd0a6d6379..81a8f6cccafe1 100644 --- a/python/tvm/topi/testing/matrix_set_diag.py +++ b/python/tvm/topi/testing/matrix_set_diag.py @@ -19,20 +19,28 @@ import numpy as np -def matrix_set_diag(input_np, diagonal): +def matrix_set_diag(input_np, diagonal, k=0, align="RIGHT_LEFT"): """matrix_set_diag operator implemented in numpy. - Returns a numpy array with the diagonal of input array + Returns a numpy array with the diagonals of input array replaced with the provided diagonal values. Parameters ---------- - input : numpy.ndarray + input_np : numpy.ndarray Input Array. Shape = [D1, D2, D3, ... , Dn-1 , Dn] + diagonal : numpy.ndarray Values to be filled in the diagonal. - Shape = [D1, D2, D3, ... , Dn-1] + + k : int or tuple of int + Diagonal Offsets. + + align : string + Some diagonals are shorter than max_diag_len and need to be padded. + Possible Vales: + ["RIGHT_LEFT" (default), "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT"] Returns ------- @@ -41,8 +49,36 @@ def matrix_set_diag(input_np, diagonal): Shape = [D1, D2, D3, ... , Dn-1 , Dn] """ out = np.array(input_np, copy=True) - n = min(input_np.shape[-1], input_np.shape[-2]) - for i in range(n): - out[..., i, i] = diagonal[..., i] + cols = input_np.shape[-1] + rows = input_np.shape[-2] + + onlyOneDiagonal = True + if isinstance(k, (tuple, list)): + if len(k) < 2 or k[0] == k[1]: + k = k[0] + else: + onlyOneDiagonal = False + + if onlyOneDiagonal: + for i in range(diagonal.shape[-1]): + if k >= 0: + out[..., i, i + k] = diagonal[..., i] + else: + out[..., i - k, i] = diagonal[..., i] + else: + for ki in range(k[0], k[1] + 1): + diag_len = min(cols - max(ki, 0), rows + min(ki, 0)) + offset = 0 + if ki >= 0: + if align[:5] == "RIGHT": + offset = diagonal.shape[-1] - diag_len + else: + if align[-5:] == "RIGHT": + offset = diagonal.shape[-1] - diag_len + for i in range(diag_len): + if ki >= 0: + out[..., i, i + ki] = diagonal[..., k[1] - ki, i + offset] + else: + out[..., i - ki, i] = diagonal[..., k[1] - ki, i + offset] return out diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 6af0828da448c..c4e51a8858d17 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -806,17 +806,33 @@ def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0 return cpp.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value) -def matrix_set_diag(data, diagonal): +def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"): """ - Returns a tensor with the diagonal of input tensor replaced with the provided diagonal values. + Returns a tensor with the diagonals of input tensor replaced with the provided diagonal values. Parameters ---------- data : relay.Expr Input Tensor. + diagonal : relay.Expr Values to be filled in the diagonal. + k : int or tuple of int, optional + Diagonal Offset(s). The diagonal or range of diagonals to set. (0 by default) + Positive value means superdiagonal, 0 refers to the main diagonal, and + negative value means subdiagonals. k can be a single integer (for a single diagonal) + or a pair of integers specifying the low and high ends of a matrix band. + k[0] must not be larger than k[1]. + + align : string, optional + Some diagonals are shorter than max_diag_len and need to be padded. + align is a string specifying how superdiagonals and subdiagonals should be aligned, + respectively. There are four possible alignments: "RIGHT_LEFT" (default), "LEFT_RIGHT", + "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals to the right + (left-pads the row) and subdiagonals to the left (right-pads the row). It is the packing + format LAPACK uses. cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment. + Returns ------- result : relay.Expr @@ -836,7 +852,7 @@ def matrix_set_diag(data, diagonal): diagonal = [[1, 2, 3], [4, 5, 6]] - relay.matrix_set_diag(input, diagonal) = + topi.matrix_set_diag(input, diagonal) = [[[1, 7, 7, 7], [7, 2, 7, 7], [7, 7, 3, 7]], @@ -844,7 +860,22 @@ def matrix_set_diag(data, diagonal): [7, 5, 7, 7], [7, 7, 6, 7]]] """ - return cpp.matrix_set_diag(data, diagonal) + if isinstance(k, (tuple, list)): + k_one = k[0] + if len(k) >= 2: + k_two = k[1] + else: + k_two = k[0] + else: + k_one = k + k_two = k + + super_diag_right_align = align[:5] == "RIGHT" + sub_diag_right_align = align[-5:] == "RIGHT" + + return cpp.matrix_set_diag( + data, diagonal, k_one, k_two, super_diag_right_align, sub_diag_right_align + ) def adv_index(data, indices): diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index bb1b8d788df15..4faface03ac64 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3120,6 +3120,8 @@ RELAY_REGISTER_OP("sparse_to_dense") .set_attr("FTVMCompute", SparseToDenseCompute); // relay.matrix_set_diag +TVM_REGISTER_NODE_TYPE(MatrixSetDiagAttrs); + bool MatrixSetDiagRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [input, diagonal, result] @@ -3131,13 +3133,28 @@ bool MatrixSetDiagRel(const Array& types, int num_inputs, const Attrs& att const auto* diagonal = types[1].as(); CHECK(diagonal); + const auto param = attrs.as(); + CHECK_GE(param->k2, param->k1); + int d_ndims = diagonal->shape.size(); - for (int i = 0; i < d_ndims - 1; i++) { + int i_ndims = input->shape.size(); + + reporter->Assert(input->shape[i_ndims - 2] > -param->k1); + reporter->Assert(input->shape[i_ndims - 1] > param->k2); + + for (int i = 0; i < d_ndims - 2; i++) { reporter->AssertEQ(input->shape[i], diagonal->shape[i]); } - auto min_dim = if_then_else(input->shape[d_ndims - 1] >= input->shape[d_ndims], - input->shape[d_ndims], input->shape[d_ndims - 1]); - reporter->Assert(diagonal->shape[d_ndims - 1] >= min_dim); + if (param->k1 != param->k2) { + reporter->AssertEQ(diagonal->shape[d_ndims - 2], param->k2 - param->k1 + 1); + } else if (d_ndims >= 2) { + reporter->AssertEQ(input->shape[d_ndims - 2], diagonal->shape[d_ndims - 2]); + } + auto max_diag_len = if_then_else(input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0) <= + input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0), + input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0), + input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0)); + reporter->AssertEQ(diagonal->shape[d_ndims - 1], max_diag_len); reporter->Assign(types[2], TensorType(input->shape, input->dtype)); return true; @@ -3145,22 +3162,37 @@ bool MatrixSetDiagRel(const Array& types, int num_inputs, const Attrs& att Array MatrixSetDiagCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return Array{topi::matrix_set_diag(inputs[0], inputs[1])}; -} - -Expr MakeMatrixSetDiag(Expr input, Expr diagonal) { + const auto* param = attrs.as(); + CHECK(param != nullptr); + return Array{topi::matrix_set_diag(inputs[0], inputs[1], param->k1, param->k2, + param->super_diag_right_align, + param->sub_diag_right_align)}; +} + +Expr MakeMatrixSetDiag(Expr input, Expr diagonal, int k1, int k2, bool super_diag_right_align, + bool sub_diag_right_align) { + auto attrs = make_object(); + attrs->k1 = k1; + attrs->k2 = k2; + attrs->super_diag_right_align = super_diag_right_align; + attrs->sub_diag_right_align = sub_diag_right_align; static const Op& op = Op::Get("matrix_set_diag"); - return Call(op, {input, diagonal}, Attrs(), {}); + return Call(op, {input, diagonal}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.matrix_set_diag").set_body_typed(MakeMatrixSetDiag); RELAY_REGISTER_OP("matrix_set_diag") .describe( - R"code(Returns a tensor with the diagonal of input tensor replaced with the provided diagonal values. + R"code(Returns a tensor with the diagonals of input tensor replaced with the provided diagonal values. **input** Input tensor. **diagonal** Values to be filled in the diagonal. + **k1** Lower limit (included) of the range of diagonals. + **k2** Upper limit (included) of the range of diagonals. + **super_diag_right_align** Bool, true iff super-diagonal is right aligned (left-padded). + **sub_diag_right_align** Bool, true iff sub-diagonal is right aligned (left-padded). )code" TVM_ADD_FILELINE) + .set_attrs_type() .set_num_inputs(2) .add_argument("input", "Tensor", "Input Tensor.") .add_argument("diagonal", "Tensor", "Values to be filled in the diagonal.") diff --git a/src/topi/transform.cc b/src/topi/transform.cc index bf7e1e67c2471..d79952e2494f4 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -177,7 +177,11 @@ TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = matrix_set_diag(args[0], args[1]); + int k1 = args[2]; + int k2 = args[3]; + bool super_diag_right_align = args[4]; + bool sub_diag_right_align = args[5]; + *rv = matrix_set_diag(args[0], args[1], k1, k2, super_diag_right_align, sub_diag_right_align); }); TVM_REGISTER_GLOBAL("topi.adv_index").set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 55fba499514a9..edb7c460d5ba3 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -506,12 +506,10 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype): @tvm.testing.uses_gpu def test_matrix_set_diag(): - def _verify(input_shape, dtype): - diagonal_shape = list(input_shape[:-2]) - diagonal_shape.append(min(input_shape[-2], input_shape[-1])) + def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): input = relay.var("input", relay.TensorType(input_shape, dtype)) diagonal = relay.var("diagonal", relay.TensorType(diagonal_shape, dtype)) - out = relay.matrix_set_diag(input, diagonal) + out = relay.matrix_set_diag(input, diagonal, k, align) in_type = run_infer_type(input) out_type = run_infer_type(out) @@ -520,7 +518,7 @@ def _verify(input_shape, dtype): func = relay.Function([input, diagonal], out) input_np = np.random.randint(-100, 100, size=input_shape).astype(dtype) diagonal_np = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype) - out_np = tvm.topi.testing.matrix_set_diag(input_np, diagonal_np) + out_np = tvm.topi.testing.matrix_set_diag(input_np, diagonal_np, k, align) for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: @@ -528,9 +526,12 @@ def _verify(input_shape, dtype): out_relay = intrp.evaluate(func)(input_np, diagonal_np) tvm.testing.assert_allclose(out_relay.asnumpy(), out_np) - _verify((2, 2), "float32") - _verify((4, 3, 3), "int32") - _verify((2, 3, 4), "float32") + _verify((2, 2), (2,), "float32") + _verify((4, 3, 3), (4, 3), "int32") + _verify((2, 3, 4), (2, 3), "float32", 1) + _verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "LEFT_RIGHT") + _verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "LEFT_LEFT") + _verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "RIGHT_RIGHT") if __name__ == "__main__": diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index a32d41a27e178..f18b5397eefe8 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -715,12 +715,10 @@ def check_device(device, ctx): check_device(device, ctx) -def verify_matrix_set_diag(input_shape, dtype): - diagonal_shape = list(input_shape[:-2]) - diagonal_shape.append(min(input_shape[-2], input_shape[-1])) +def verify_matrix_set_diag(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): input = te.placeholder(shape=input_shape, name="input", dtype=dtype) diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype) - matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal) + matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, k, align) def check_device(device, ctx): ctx = tvm.context(device, 0) @@ -730,7 +728,7 @@ def check_device(device, ctx): fn = tvm.build(s, [input, diagonal, matrix_set_diag_result], device, name="matrix_set_diag") input_npy = np.random.randint(-100, 100, size=input_shape).astype(dtype) diagonal_npy = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype) - out_npy = tvm.topi.testing.matrix_set_diag(input_npy, diagonal_npy) + out_npy = tvm.topi.testing.matrix_set_diag(input_npy, diagonal_npy, k, align) input_nd = tvm.nd.array(input_npy, ctx) diagonal_nd = tvm.nd.array(diagonal_npy, ctx) out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(matrix_set_diag_result.dtype), ctx) @@ -1165,9 +1163,12 @@ def test_sparse_to_dense(): @tvm.testing.uses_gpu def test_matrix_set_diag(): for dtype in ["float32", "int32"]: - verify_matrix_set_diag((2, 2), dtype) - verify_matrix_set_diag((4, 3, 3), dtype) - verify_matrix_set_diag((2, 3, 4), dtype) + verify_matrix_set_diag((2, 2), (2,), dtype) + verify_matrix_set_diag((4, 3, 3), (4, 3), dtype) + verify_matrix_set_diag((2, 3, 4), (2, 3), dtype, 1) + verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_RIGHT") + verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_LEFT") + verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "RIGHT_RIGHT") @tvm.testing.uses_gpu