From 36aa777eacd8426a850d08b528e9addcd36a4894 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 24 May 2021 06:41:50 +0900 Subject: [PATCH] refactoring slice with axes --- include/tvm/topi/nn.h | 2 +- include/tvm/topi/transform.h | 234 ++++++++++++------------------- src/relay/op/tensor/transform.cc | 15 +- src/target/target_kind.cc | 1 + src/topi/transform.cc | 6 +- 5 files changed, 104 insertions(+), 154 deletions(-) diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 29c3156ab5d6..d3328c59afb4 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -619,7 +619,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, out = reshape(out, r_p_shape); // Crop the start and end of dimensions of out - Array begin_idx, end_idx, strides; + Array begin_idx, end_idx, strides; for (size_t i = 0; i < r_p_shape.size(); ++i) { strides.push_back(Integer(1)); if (i > 0 && i <= num_block_dims) { diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index a81ac691dadd..28daff9da139 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -593,54 +593,54 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b name, tag); } -/*! - * \brief strided_slice of a tensor - * - * \param x The input tensor - * \param begin The indices to begin with in the slicing - * \param end Indicies indicating end of the slice - * \param strides Specifies the stride values, it can be negative - * in that case, the input tensor will be reversed in that particular axis - * \param slice_mode Specifies the slice mode - * \param name The name of the operation - * \param tag The tag to mark the operation - * - * \return A Tensor whose op member is the split operation - */ -inline Tensor strided_slice(const Tensor& x, const Array& begin, - const Array& end, const Array& strides, - std::string slice_mode = "end", std::string name = "T_strided_slice", - std::string tag = kInjective) { - size_t src_tensor_dim = static_cast(x->shape.size()); - // Quick path for dynamic shape strided slice. - // This is for ease of use to dynamice strided slice in topi. - bool is_static = IsConstIntArray(x->shape); - is_static &= IsConstIntArray(begin); - is_static &= IsConstIntArray(end); - is_static &= IsConstIntArray(strides); - +inline Tensor strided_slice_dynamic_input(const Tensor& input, const Array& begin, + const Array& end, const Array& strides, + std::string slice_mode = "end", + std::string name = "T_strided_slice_dynamic_input", + std::string tag = kInjective) { + size_t src_tensor_dim = input->shape.size(); + ICHECK(begin.size() == src_tensor_dim) + << "for dynamic inputs, len(begin) must equal the input dimension"; Array out_shape; - if (!is_static) { - ICHECK_EQ(strides.size(), src_tensor_dim); - for (size_t i = 0; i < src_tensor_dim; ++i) { - out_shape.push_back(indexdiv(end[i] - begin[i], strides[i])); + for (size_t i = 0; i < src_tensor_dim; ++i) { + out_shape.push_back(tvm::tir::Var("dim")); + } + Array begin_expr, end_expr, strides_expr; + for (size_t i = 0; i < src_tensor_dim; ++i) { + int64_t begin_i = begin[i]->value; + if (begin_i < 0) { + begin_i += topi::detail::GetConstInt(input->shape[i]); } - return te::compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides[i] + begin[i]); - } - return x(real_indices); - }, - name, tag); + begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i)); + strides_expr.push_back( + tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), + (i < strides.size() ? strides[i]->value : 1))); } + return te::compute( + out_shape, + [&](const Array& indices) { + Array real_indices; + for (size_t i = 0; i < src_tensor_dim; ++i) { + real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); + } + return input(real_indices); + }, + std::string{"T_strided_slice_dynamic_input"}, std::string{topi::kInjective}); +} + +inline Tensor strided_slice_with_axes(const Tensor& x, const Array& begin, + const Array& end, const Array& strides, + const Array& axes, std::string slice_mode = "end", + std::string name = "T_strided_slice_dynamic_input", + std::string tag = kInjective) { + size_t src_tensor_dim = x->shape.size(); + + ICHECK(axes.size() <= src_tensor_dim); + ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()); - // Setup the ranges. // NOTE: this code duplicates the shape inference logic relay.op // Consider to refactor in the future. - std::vector stride_vec(src_tensor_dim, 1); + std::vector stride_vec(strides.size(), 1); for (size_t i = 0; i < strides.size(); ++i) { ICHECK(strides[i].defined()); stride_vec[i] = GetConstInt(strides[i]); @@ -657,9 +657,6 @@ inline Tensor strided_slice(const Tensor& x, const Array& begin, begin_vec.push_back(GetConstInt(begin[i])); } } - for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) { - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } std::vector end_vec; for (size_t i = 0; i < end.size(); ++i) { @@ -678,16 +675,17 @@ inline Tensor strided_slice(const Tensor& x, const Array& begin, end_vec.push_back(GetConstInt(end[i])); } } - for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } + // Compute Array begin_expr; Array strides_expr; - for (size_t i = 0; i < src_tensor_dim; ++i) { + Array out_shape; + for (size_t i = 0; i < axes.size(); ++i) { int64_t begin_range = stride_vec[i] < 0 ? -1 : 0; - int64_t dim_i = GetConstInt(x->shape[i]); + ICHECK(x->shape[axes[i]]->IsInstance()) + << "Input shape at axis " << axes[i] << " is not static"; + int64_t dim_i = GetConstInt(x->shape[axes[i]]); int64_t end_range = stride_vec[i] < 0 ? dim_i - 1 : dim_i; // transform negative indices to positive value, clips on the correct range auto index_canonicalization = [dim_i, begin_range, end_range](int64_t index) { @@ -713,116 +711,60 @@ inline Tensor strided_slice(const Tensor& x, const Array& begin, out_shape.push_back(slice_size); } - return compute( + return te::compute( out_shape, - [&](const Array& indices) { + [&](const Array& indices) { Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); + for (size_t i = 0; i < src_tensor_dim; ++i) real_indices.push_back(indices[i]); + for (size_t i = 0; i < axes.size(); ++i) { + PrimExpr ind = indices[axes[i]] * strides[i] + begin_expr[i]; + real_indices.Set(axes[i], ind); } return x(real_indices); }, - name, tag); + std::string{"T_strided_slice_with_axes"}, std::string{topi::kInjective}); } +/*! + * \brief strided_slice of a tensor + * + * \param x The input tensor + * \param begin The indices to begin with in the slicing + * \param end Indicies indicating end of the slice + * \param strides Specifies the stride values, it can be negative + * in that case, the input tensor will be reversed in that particular axis + * \param slice_mode Specifies the slice mode + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the split operation + */ inline Tensor strided_slice(const Tensor& x, const Array& begin, const Array& end, const Array& strides, std::string slice_mode = "end", std::string name = "T_strided_slice", std::string tag = kInjective) { - Array begin_expr, end_expr, strides_expr; - for (size_t i = 0; i < begin.size(); ++i) { - begin_expr.push_back(begin[i]); - } - for (size_t i = 0; i < end.size(); ++i) { - end_expr.push_back(end[i]); - } - for (size_t i = 0; i < strides.size(); ++i) { - strides_expr.push_back(strides[i]); - } - return strided_slice(x, begin_expr, end_expr, strides_expr, slice_mode, name, tag); -} - -inline Tensor strided_slice_dynamic_input(const Tensor& input, const Array& begin, - const Array& end, const Array& strides, - std::string slice_mode = "end", - std::string name = "T_strided_slice_dynamic_input", - std::string tag = kInjective) { - size_t src_tensor_dim = input->shape.size(); - ICHECK(begin.size() == src_tensor_dim) - << "for dynamic inputs, len(begin) must equal the input dimension"; - Array out_shape; - for (size_t i = 0; i < src_tensor_dim; ++i) { - out_shape.push_back(tvm::tir::Var("dim")); - } - Array begin_expr, end_expr, strides_expr; - for (size_t i = 0; i < src_tensor_dim; ++i) { - int64_t begin_i = begin[i]->value; - if (begin_i < 0) { - begin_i += topi::detail::GetConstInt(input->shape[i]); - } - begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i)); - strides_expr.push_back( - tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), - (i < strides.size() ? strides[i]->value : 1))); - } - return te::compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); - } - return input(real_indices); - }, - std::string{"T_strided_slice_dynamic_input"}, std::string{topi::kInjective}); -} - -inline Tensor strided_slice_with_axes(const Tensor& input, const Array& begin, - const Array& end, const Array& strides, - const Array& axes, std::string slice_mode = "end", - std::string name = "T_strided_slice_dynamic_input", - std::string tag = kInjective) { - size_t src_tensor_dim = input->shape.size(); + size_t src_tensor_dim = static_cast(x->shape.size()); + Array axes; + for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i); + Array begin_full(begin); + Array end_full(end); + Array strides_full(strides); - ICHECK(axes.size() <= src_tensor_dim); - ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()); + const IntImm one = IntImm(DataType::Int(64), 1); + const IntImm zero = IntImm(DataType::Int(64), 1); + const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits::max()); - Array out_shape; - for (size_t i = 0; i < src_tensor_dim; ++i) { - out_shape.push_back(input->shape[i]); + for (size_t i = strides_full.size(); i < src_tensor_dim; ++i) { + strides_full.push_back(one); } - Array begin_expr; - for (size_t i = 0; i < axes.size(); ++i) { - auto idim = input->shape[axes[i]]; - auto b = tvm::if_then_else(begin[i] < 0, begin[i] + idim, begin[i]); - auto e = tvm::if_then_else(end[i] < 0, end[i] + idim, end[i]); - auto s = strides[i]->value; - PrimExpr range; - if (s < 0) { - b = tvm::min(b, idim - 1); - e = tvm::if_then_else(e < -1, -1, e); - range = b - e; - s = -s; - } else { - b = tvm::if_then_else(b < 0, 0, b); - e = tvm::min(e, idim); - range = e - b; - } - PrimExpr odim = indexdiv(range + tvm::PrimExpr(static_cast(s - 1)), s); - out_shape.Set(axes[i], cast(out_shape[i].dtype(), odim)); - begin_expr.push_back(b); + for (size_t i = begin.size(); i < src_tensor_dim; ++i) { + begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range); } - return te::compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) real_indices.push_back(indices[i]); - for (size_t i = 0; i < axes.size(); ++i) { - PrimExpr ind = indices[axes[i]] * strides[i] + begin_expr[i]; - real_indices.Set(axes[i], ind); - } - return input(real_indices); - }, - std::string{"T_strided_slice_with_axes"}, std::string{topi::kInjective}); + for (size_t i = end.size(); i < src_tensor_dim; ++i) { + end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range); + } + + return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name, + tag); } /*! diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 5dc6147a239b..a1943e8cbfe2 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3072,16 +3072,21 @@ Array SliceLikeCompute(const Attrs& attrs, const Array& ICHECK(param != nullptr); Array src_shape = inputs[0]->shape; Array target_shape = inputs[1]->shape; - Array begin_idx, end_idx, strides; + Array begin_idx, end_idx, strides; for (size_t i = 0; i < src_shape.size(); ++i) { begin_idx.push_back(0); strides.push_back(1); } - end_idx = Array(src_shape); + for (auto s : src_shape) { + ICHECK(s->IsInstance()) << "slice_like does not support dynamic input shape"; + end_idx.push_back(topi::GetConstInt(s)); + } if (!param->axes.defined()) { for (size_t i = 0; i < src_shape.size(); ++i) { if (i < target_shape.size()) { - end_idx.Set(i, target_shape[i]); + ICHECK(target_shape[i]->IsInstance()) + << "slice_like does not support dynamic output shape"; + end_idx.Set(i, topi::GetConstInt(target_shape[i])); ICHECK_LE(topi::GetConstInt(end_idx[i]), topi::GetConstInt(src_shape[i])) << "End index of axis " << i << " exceeds input shape: " << topi::GetConstInt(end_idx[i]) << " vs " @@ -3093,7 +3098,9 @@ Array SliceLikeCompute(const Attrs& attrs, const Array& if (axis < 0) { axis = static_cast(src_shape.size()) + axis; } - end_idx.Set(axis, target_shape[axis]); + ICHECK(target_shape[axis]->IsInstance()) + << "slice_like does not support dynamic output shape"; + end_idx.Set(axis, topi::GetConstInt(target_shape[axis])); ICHECK_LE(topi::GetConstInt(end_idx[axis]), topi::GetConstInt(src_shape[axis])) << "End index of axis " << axis << " exceeds input shape: " << topi::GetConstInt(end_idx[axis]) << " vs " diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 08e998e0f035..cb0ca439a398 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -347,6 +347,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .set_default_keys({"vulkan", "gpu"}) .set_attrs_preprocessor(UpdateVulkanAttrs); + TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) .add_attr_option("system-lib") .add_attr_option("max_num_threads", Integer(256)) diff --git a/src/topi/transform.cc b/src/topi/transform.cc index dfea643217d6..dd7962bdb1cf 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -174,9 +174,9 @@ TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) { - Array begin = args[1]; - Array end = args[2]; - Array strides = args[3]; + Array begin = args[1]; + Array end = args[2]; + Array strides = args[3]; *rv = strided_slice(args[0], begin, end, strides, args[4]); });