Skip to content

Commit

Permalink
refactor type rel
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent ecfe3cd commit fbb099c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 153 deletions.
28 changes: 14 additions & 14 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -673,19 +673,19 @@ inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_
return std::make_tuple(begin_vec, end_vec, stride_vec);
}

inline Array<PrimExpr> StridedSliceCanonicalizeBegin(const Tensor& x,
inline Array<PrimExpr> StridedSliceCanonicalizeBegin(const Array<PrimExpr>& ishape,
const std::vector<int64_t>& begin,
const std::vector<int64_t>& strides,
const Array<Integer>& axes, DataType dtype,
std::string slice_mode = "end") {
Array<PrimExpr> begin_expr;
for (size_t i = 0; i < axes.size(); ++i) {
if (x->shape[axes[i]]->IsInstance<tvm::IntImmNode>()) {
int64_t dim_i = GetConstInt(x->shape[axes[i]]);
if (ishape[axes[i]]->IsInstance<tvm::IntImmNode>()) {
int64_t dim_i = GetConstInt(ishape[axes[i]]);
int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]);
begin_expr.push_back(make_const(dtype, begin_i));
} else {
auto idim = x->shape[axes[i]];
auto idim = ishape[axes[i]];
auto b_expr = make_const(dtype, begin[i]);
PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr;
auto s = strides[i];
Expand All @@ -700,20 +700,20 @@ inline Array<PrimExpr> StridedSliceCanonicalizeBegin(const Tensor& x,
return begin_expr;
}

inline Array<PrimExpr> StridedSliceOutputShape(const Tensor& x, const std::vector<int64_t>& begin,
inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape, const std::vector<int64_t>& begin,
const std::vector<int64_t>& end,
const std::vector<int64_t>& strides,
const Array<Integer>& axes, std::string slice_mode,
const Array<PrimExpr>& begin_canonicalized) {
size_t src_tensor_dim = x->shape.size();
size_t src_tensor_dim = ishape.size();
Array<PrimExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(x->shape[i]);
out_shape.push_back(ishape[i]);
}

for (size_t i = 0; i < axes.size(); ++i) {
if (x->shape[axes[i]]->IsInstance<tvm::IntImmNode>()) {
const int64_t dim_i = GetConstInt(x->shape[axes[i]]);
if (ishape[axes[i]]->IsInstance<tvm::IntImmNode>()) {
const int64_t dim_i = GetConstInt(ishape[axes[i]]);
ICHECK(begin_canonicalized[i]->IsInstance<tvm::IntImmNode>());
int64_t begin_i = GetConstInt(begin_canonicalized[i]);
int64_t end_i = CanonicalizeIndex(end[i], dim_i, strides[i]);
Expand All @@ -731,16 +731,16 @@ inline Array<PrimExpr> StridedSliceOutputShape(const Tensor& x, const std::vecto
return out_shape;
}

inline Array<PrimExpr> StridedSliceOutputShape(const Tensor& x, const Array<Integer>& begin,
inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape, const Array<Integer>& begin,
const Array<Integer>& end,
const Array<Integer>& strides,
const Array<Integer>& axes,
const std::string& slice_mode) {
std::vector<int64_t> begin_vec, end_vec, strides_vec;
std::tie(begin_vec, end_vec, strides_vec) = ToVec(begin, end, strides, slice_mode);
auto begin_canonicalized =
StridedSliceCanonicalizeBegin(x, begin_vec, strides_vec, axes, begin[0]->dtype, slice_mode);
return StridedSliceOutputShape(x, begin_vec, end_vec, strides_vec, axes, slice_mode,
StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes, begin[0]->dtype, slice_mode);
return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode,
begin_canonicalized);
}

Expand All @@ -757,9 +757,9 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& beg
std::tie(begin_vec, end_vec, strides_vec) = ToVec(begin, end, strides, slice_mode);

auto begin_expr =
StridedSliceCanonicalizeBegin(x, begin_vec, strides_vec, axes, begin[0]->dtype, slice_mode);
StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, axes, begin[0]->dtype, slice_mode);
auto out_shape =
StridedSliceOutputShape(x, begin_vec, end_vec, strides_vec, axes, slice_mode, begin_expr);
StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, axes, slice_mode, begin_expr);

return te::compute(
out_shape,
Expand Down
155 changes: 16 additions & 139 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2445,147 +2445,24 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
return false;
}

auto dshape = data->shape;
int64_t num_axis = dshape.size();

// calculate output shape
std::vector<IndexExpr> oshape(num_axis);
if (param->begin && param->end && param->strides) {
const bool has_axes(param->axes);
// stride will be set as 1 if slice mode is enabled
std::vector<int64_t> stride_vec(num_axis, 1);
if (param->slice_mode == "end") {
for (size_t i = 0; i < param->strides.value().size(); ++i) {
ICHECK(param->strides.value()[i].defined());
stride_vec[i] = param->strides.value()[i]->value;
}
}
const int64_t max_range = std::numeric_limits<int64_t>::max();
std::vector<int64_t> begin_vec;
for (size_t i = 0; i < param->begin.value().size(); ++i) {
if (!param->begin.value()[i].defined()) {
begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
} else {
begin_vec.push_back(param->begin.value()[i]->value);
}
}

std::vector<int64_t> end_vec;
for (size_t i = 0; i < param->end.value().size(); ++i) {
// allow end to be None
if (!param->end.value()[i].defined()) {
end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
} else if (param->slice_mode == "size") {
if (param->end.value()[i]->value < 0) {
end_vec.push_back(max_range);
} else {
end_vec.push_back(begin_vec[i] + param->end.value()[i]->value);
}
} else if (param->slice_mode == "end") {
end_vec.push_back(param->end.value()[i]->value);
} else {
LOG(FATAL) << "Unsupported slice mode: " << param->slice_mode;
}
}
ICHECK(param->begin) << "strided_slice recieved invalid begin " << param->begin;
ICHECK(param->end) << "strided_slice recieved invalid end " << param->end;
ICHECK(param->strides) << "strided_slice recieved invalid strides " << param->strides;

if (!has_axes) {
for (int64_t i = begin_vec.size(); i < num_axis; ++i) {
begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
}

for (int64_t i = end_vec.size(); i < num_axis; ++i) {
end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
}

for (int64_t i = 0; i < num_axis; ++i) {
int64_t stride_v = stride_vec[i];
int64_t begin_v = begin_vec[i];
int64_t end_v = end_vec[i];

if ((stride_v == 1 && begin_v == 0 && end_v == max_range) ||
(stride_v == -1 && begin_v == max_range && end_v == 0)) {
// Quick path, do not slice this dimension.
oshape[i] = dshape[i];
continue;
}
// Normal path, require the shape to be concrete integer.
// Require concrete integer as symbolic inference of min/max
// can get complicated and not very helpful.
const int64_t* p_dim_size = tir::as_const_int(dshape[i]);
if (!p_dim_size) {
oshape[i] = dshape[i];
continue;
}
int64_t dim_size = p_dim_size[0];
begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v;
end_v = (end_v < 0) ? dim_size + end_v : end_v;

int64_t slice_range, step;
if (stride_v < 0) {
if (end_v < -1) end_v = -1;
ICHECK_LE(end_v, begin_v) << "strided_slice get empty slice at axis " << i;
begin_v = std::min(dim_size - 1, begin_v);
slice_range = begin_v - end_v;
step = -stride_v;
} else {
if (begin_v < 0) begin_v = 0;
ICHECK_GE(stride_v, 0);
ICHECK_LE(begin_v, end_v) << "strided_slice get invalid slice at axis " << i;
end_v = std::min(dim_size, end_v);
slice_range = end_v - begin_v;
step = stride_v;
}
oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step);
}
} else {
auto axes = param->axes.value();
for (int64_t i = 0; i < num_axis; ++i) {
oshape[i] = dshape[i];
}
for (int64_t i = 0; i < axes.size(); ++i) {
int64_t stride_v = stride_vec[i];
int64_t begin_v = begin_vec[i];
int64_t end_v = end_vec[i];

if ((stride_v == 1 && begin_v == 0 && end_v == max_range) ||
(stride_v == -1 && begin_v == max_range && end_v == 0)) {
// Quick path, do not slice this dimension.
continue;
}
// Normal path, require the shape to be concrete integer.
// Require concrete integer as symbolic inference of min/max
// can get complicated and not very helpful.
const int64_t* p_dim_size = tir::as_const_int(dshape[axes[i]]);
if (!p_dim_size) {
continue;
}
int64_t dim_size = p_dim_size[0];
begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v;
end_v = (end_v < 0) ? dim_size + end_v : end_v;

int64_t slice_range, step;
if (stride_v < 0) {
if (end_v < -1) end_v = -1;
ICHECK_LE(end_v, begin_v) << "strided_slice get empty slice at axis " << i;
begin_v = std::min(dim_size - 1, begin_v);
slice_range = begin_v - end_v;
step = -stride_v;
} else {
if (begin_v < 0) begin_v = 0;
ICHECK_GE(stride_v, 0);
ICHECK_LE(begin_v, end_v) << "strided_slice get invalid slice at axis " << i;
end_v = std::min(dim_size, end_v);
slice_range = end_v - begin_v;
step = stride_v;
}
oshape[axes[i]] = tir::make_const(dshape[axes[i]].dtype(), (slice_range + step - 1) / step);
}
}
const size_t src_tensor_dim = static_cast<size_t>(data->shape.size());
Array<Integer> axes;
if (param->axes) {
axes = param->axes.value();
} else {
ICHECK(param->begin) << "strided_slice recieved invalid begin " << param->begin;
ICHECK(param->end) << "strided_slice recieved invalid end " << param->end;
ICHECK(param->strides) << "strided_slice recieved invalid strides " << param->strides;
}
for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
}
auto begin = param->begin.value();
auto end = param->end.value();
auto strides = param->strides.value();
ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size())
<< "Axes, begin, end, and strides must have the same length";
auto oshape =
topi::StridedSliceOutputShape(data->shape, begin, end, strides, axes, param->slice_mode);
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
Expand Down

0 comments on commit fbb099c

Please sign in to comment.