From deb1df4985e226e8077201d0e617e4e7f4e980c6 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Sat, 20 Oct 2018 08:39:53 +0530 Subject: [PATCH] set_attrs_type_key and test_format testcase added --- include/tvm/relay/attrs/transform.h | 4 +++- python/tvm/relay/op/transform.py | 2 +- src/relay/op/tensor/transform.cc | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 417e3b64dd710..6228b68252836 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -119,6 +119,9 @@ struct SplitAttrs : public tvm::AttrsNode { "the entries indicate where along axis the array is split."); TVM_ATTR_FIELD(axis).set_default(0) .describe("the axis to be splitted."); + } +}; + /*! \brief Attributes for StridedSlice operator */ struct StridedSliceAttrs : public tvm::AttrsNode { Array begin; @@ -134,7 +137,6 @@ struct StridedSliceAttrs : public tvm::AttrsNode { .describe("Stride values of the slice"); } }; - } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 37456265d8e37..fbb7aa37c4e11 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -314,6 +314,7 @@ def split(data, indices_or_sections, axis=0): ret_size = len(indices_or_sections) + 1 return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size) + def strided_slice(data, begin, end, stride=None): """Strided slice of an array.. @@ -339,4 +340,3 @@ def strided_slice(data, begin, end, stride=None): """ stride = stride or [] return _make.strided_slice(data, list(begin), list(end), list(stride)) ->>>>>>> [RELAY][OP]Strided slice diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 358bcbb5d7a57..2ee53cd000f5e 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -947,6 +947,7 @@ Examples:: .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(4) +.set_attrs_type_key("relay.attrs.StridedSliceAttrs") .add_type_rel("StridedSlice", StridedSliceRel); // Split