Skip to content

Commit

Permalink
Review comment fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Oct 28, 2018
1 parent 41b6e19 commit 3c3f641
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 22 deletions.
8 changes: 4 additions & 4 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,16 @@ struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {

/*! \brief Attributes for StridedSlice operator */
struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
Array<IndexExpr> begin;
Array<IndexExpr> end;
Array<IndexExpr> stride;
Array<Integer> begin;
Array<Integer> end;
Array<Integer> strides;

TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
TVM_ATTR_FIELD(begin)
.describe("Indices for begin of slice, begin index is also inclusive");
TVM_ATTR_FIELD(end)
.describe("Indices for end of slice, end index is also inclusive");
TVM_ATTR_FIELD(stride).set_default(Array<IndexExpr>({}))
TVM_ATTR_FIELD(strides).set_default(Array<Integer>({}))
.describe("Stride values of the slice");
}
};
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def split(data, indices_or_sections, axis=0):
return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size)


def strided_slice(data, begin, end, stride=None):
def strided_slice(data, begin, end, strides=None):
"""Strided slice of an array..
Parameters
Expand All @@ -352,7 +352,7 @@ def strided_slice(data, begin, end, stride=None):
end: list of int
Indicies indicating end of the slice.
stride: list of int, optional
strides: list of int, optional
Specifies the stride values, it can be negative in that case,
the input tensor will be reversed in that particular axis.
Expand All @@ -361,5 +361,5 @@ def strided_slice(data, begin, end, stride=None):
ret : relay.Expr
The computed result.
"""
stride = stride or []
return _make.strided_slice(data, list(begin), list(end), list(stride))
strides = strides or []
return _make.strided_slice(data, list(begin), list(end), list(strides))
36 changes: 25 additions & 11 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,8 @@ RELAY_REGISTER_OP("broadcast_to_like")
.add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.")
.set_support_level(10)
.add_type_rel("BroadCastToLike", BroadCastToLikeRel);


// strided_slice
TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
bool StridedSliceRel(const Array<Type>& types,
Expand All @@ -908,7 +910,7 @@ bool StridedSliceRel(const Array<Type>& types,
auto dshape = data->shape;
auto num_axis = dshape.size();

std::vector<IndexExpr> begin_vec;
std::vector<Integer> begin_vec;
for (auto i : param->begin) {
begin_vec.push_back(i);
}
Expand All @@ -924,8 +926,8 @@ bool StridedSliceRel(const Array<Type>& types,
end_vec.push_back(dshape[i]);
}

std::vector<IndexExpr> stride_vec;
for (auto i : param->stride) {
std::vector<Integer> stride_vec;
for (auto i : param->strides) {
stride_vec.push_back(i);
}
for (auto i = stride_vec.size(); i < num_axis; ++i) {
Expand All @@ -934,10 +936,22 @@ bool StridedSliceRel(const Array<Type>& types,
std::vector<IndexExpr> oshape(dshape.size());

for (size_t i = 0; i < num_axis; ++i) {
auto begin_range = reporter->Assert(stride_vec[i] < 0) ? -1 : 0;
auto end_range = reporter->Assert(stride_vec[i] < 0) ? dshape[i] - 1 : dshape[i];
auto begin = reporter->Assert(begin_vec[i] < 0) ? dshape[i] + begin_vec[i] : begin_vec[i];
auto end = reporter->Assert(end_vec[i] < 0) ? dshape[i] + end_vec[i] : end_vec[i];
const int64_t* stride_t = as_const_int(stride_vec[i]);
CHECK(stride_t != nullptr) << "Stride cannot be symbolic.";
int64_t stride_v = stride_t[0];

const int64_t* begin_t = as_const_int(begin_vec[i]);
CHECK(begin_t != nullptr) << "Begin index cannot be symbolic.";
int64_t begin_v = begin_t[0];

const int64_t* end_t = as_const_int(end_vec[i]);
CHECK(end_t != nullptr) << "End index cannot be symbolic.";
int64_t end_v = end_t[0];

auto begin_range = make_const(Int(64), (stride_v < 0) ? -1 : 0);
auto end_range = (stride_v < 0) ? dshape[i] - 1 : dshape[i];
auto begin = (begin_v < 0) ? dshape[i] + begin_vec[i] : begin_vec[i];
auto end = (end_v < 0) ? dshape[i] + end_vec[i] : end_vec[i];

begin = min(max(begin, begin_range), end_range);
end = min(max(end, begin_range), end_range);
Expand All @@ -958,13 +972,13 @@ bool StridedSliceRel(const Array<Type>& types,

// Positional relay function to create StridedSlice operator used by frontend FFI.
Expr MakeStridedSlice(Expr data,
Array<IndexExpr> begin,
Array<IndexExpr> end,
Array<IndexExpr> stride) {
Array<Integer> begin,
Array<Integer> end,
Array<Integer> strides) {
auto attrs = make_node<StridedSliceAttrs>();
attrs->begin = std::move(begin);
attrs->end = std::move(end);
attrs->stride = std::move(stride);
attrs->strides = std::move(strides);
static const Op& op = Op::Get("strided_slice");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
Expand Down
12 changes: 9 additions & 3 deletions tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,22 @@ def test_reduce_functions():
verify_reduce(func, (128, 24, 128), (0, 2), False, False, (24,))
verify_reduce(func, (128, 24, 128), (0, 1), True, False, (1, 1, 128))
verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1))


def verify_strided_slice(data, begin, end, stride, output):
x = relay.var("x", relay.TensorType(data, "float32"))
z = relay.strided_slice(x, begin=begin, end=end, stride=stride)
z = relay.strided_slice(x, begin=begin, end=end, strides=stride)
zz = relay.ir_pass.infer_type(z)
assert "begin=" in z.astext()
assert "end=" in z.astext()
if stride:
assert "stride=" in z.astext()
assert zz.checked_type == relay.ty.TensorType(output, "float32")
assert "strides=" in z.astext()
if output:
assert zz.checked_type == relay.ty.TensorType(output, "float32")

def test_strided_slice():
d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
verify_strided_slice((d1, d2, d3), [0, 0, 0], [4, -5, 4], [1, -1, 2], None)
verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2))
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3))
verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
Expand All @@ -111,6 +116,7 @@ def test_strided_slice():
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3))
verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3))


if __name__ == "__main__":
test_binary_op()
test_cmp_type()
Expand Down

0 comments on commit 3c3f641

Please sign in to comment.