diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index eb1ad8bd622b..c2a973b8bbf0 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -550,7 +550,25 @@ inline Array split(const Tensor& x, Array split_indices, int a return result; } -// inline te::Tensor strided_slice_compute_common() {} +inline te::Tensor strided_slice_compute_common(const te::Tensor& x, + const Array& out_shape, + const Array& begin, + const Array& strides, + const Array& axes, const std::string& name, + const std::string& tag) { + return te::compute( + out_shape, + [&](const Array& indices) { + Array real_indices; + for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]); + for (size_t i = 0; i < axes.size(); ++i) { + PrimExpr ind = indices[axes[i]] * strides[i] + begin[i]; + real_indices.Set(axes[i], ind); + } + return x(real_indices); + }, + name, tag); +} /*! * \brief strided_slice of a tensor with dynamic begin/end/stride @@ -597,8 +615,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begin, const Array& end, const Array& strides, - std::string slice_mode = "end", - std::string name = "T_strided_slice", + std::string name = "T_dynamic_strided_slice", std::string tag = kInjective) { size_t src_tensor_dim = static_cast(x->shape.size()); ICHECK_EQ(begin.size(), src_tensor_dim); @@ -606,27 +623,20 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begi ICHECK_EQ(strides.size(), src_tensor_dim); Array out_shape; + Array axes; for (size_t i = 0; i < src_tensor_dim; ++i) { out_shape.push_back(indexdiv(end[i] - begin[i], strides[i])); + axes.push_back(i); } - return te::compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides[i] + begin[i]); - } - return x(real_indices); - }, - name, tag); + return strided_slice_compute_common(x, out_shape, begin, strides, axes, name, tag); } -inline Tensor strided_slice_dynamic_input(const Tensor& input, const Array& begin, +inline Tensor strided_slice_dynamic_input(const Tensor& x, const Array& begin, const Array& end, const Array& strides, std::string slice_mode = "end", std::string name = "T_strided_slice_dynamic_input", std::string tag = kInjective) { - size_t src_tensor_dim = input->shape.size(); + size_t src_tensor_dim = x->shape.size(); ICHECK(begin.size() == src_tensor_dim) << "for dynamic inputs, len(begin) must equal the input dimension"; Array out_shape; @@ -634,26 +644,19 @@ inline Tensor strided_slice_dynamic_input(const Tensor& input, const Array begin_expr, end_expr, strides_expr; + Array axes; for (size_t i = 0; i < src_tensor_dim; ++i) { int64_t begin_i = begin[i]->value; if (begin_i < 0) { - begin_i += topi::detail::GetConstInt(input->shape[i]); + begin_i += topi::detail::GetConstInt(x->shape[i]); } begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i)); strides_expr.push_back( tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), (i < strides.size() ? strides[i]->value : 1))); + axes.push_back(i); } - return te::compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); - } - return input(real_indices); - }, - std::string{"T_strided_slice_dynamic_input"}, std::string{topi::kInjective}); + return strided_slice_compute_common(x, out_shape, begin_expr, strides_expr, axes, name, tag); } inline Tensor strided_slice_with_axes(const Tensor& x, const Array& begin, @@ -689,7 +692,6 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array& beg std::vector end_vec; for (size_t i = 0; i < end.size(); ++i) { // allow end to be None - if (!end[i].defined()) { end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); } else if (slice_mode == "size") { @@ -740,19 +742,7 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array& beg make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), stride_vec[i])); out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size))); } - - return te::compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) real_indices.push_back(indices[i]); - for (size_t i = 0; i < axes.size(); ++i) { - PrimExpr ind = indices[axes[i]] * strides_expr[i] + begin_expr[i]; - real_indices.Set(axes[i], ind); - } - return x(real_indices); - }, - name, tag); + return strided_slice_compute_common(x, out_shape, begin_expr, strides_expr, axes, name, tag); } /*! diff --git a/src/topi/transform.cc b/src/topi/transform.cc index e30daf3f3503..e19b6da11064 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -189,7 +189,7 @@ TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* *rv = strided_slice_dynamic_input(x, begin_static, end_static, strides_static, slice_mode); } } else { - *rv = dynamic_strided_slice(x, begin, end, strides, slice_mode); + *rv = dynamic_strided_slice(x, begin, end, strides); } });