diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 027b3fe1df5f8..e02dd64780efd 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -310,6 +310,7 @@ struct StridedSliceAttrs : public tvm::AttrsNode { Optional> end; Optional> strides; std::string slice_mode; + Optional> axes; TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") { TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive"); @@ -324,6 +325,7 @@ struct StridedSliceAttrs : public tvm::AttrsNode { "size - The input strides will be ignored, input end in this mode indicates the size" "of a slice starting at the location specified by begin. If end[i] is -1," "all remaining elements in that dimension are included in the slice"); + TVM_ATTR_FIELD(axes).describe("TODO"); } }; diff --git a/include/tvm/topi/detail/strided_slice.h b/include/tvm/topi/detail/strided_slice.h new file mode 100644 index 0000000000000..e13082fa28094 --- /dev/null +++ b/include/tvm/topi/detail/strided_slice.h @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file strided_slice.h + * \brief Utility functions for strided_slice op + */ +#ifndef TVM_TOPI_DETAIL_STRIDED_SLICE_H_ +#define TVM_TOPI_DETAIL_STRIDED_SLICE_H_ + +#include + +#include +#include + +#include "constant_utils.h" + +namespace tvm { +namespace topi { +namespace detail { + +using namespace tvm::te; + +inline int64_t CanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) { + int64_t begin_range = stride < 0 ? -1 : 0; + int64_t end_range = stride < 0 ? extent - 1 : extent; + if (index < 0) { + index += extent; + } + return std::min(std::max(index, begin_range), end_range); +} + +inline std::tuple, std::vector, std::vector> ConvertToVec( + const Array& begin, const Array& end, const Array& strides, + std::string slice_mode) { + std::vector stride_vec(strides.size(), 1); + if (slice_mode == "end") { + for (size_t i = 0; i < strides.size(); ++i) { + ICHECK(strides[i].defined()); + stride_vec[i] = GetConstInt(strides[i]); + } + } + const int64_t max_range = std::numeric_limits::max(); + std::vector begin_vec; + for (size_t i = 0; i < begin.size(); ++i) { + if (!begin[i].defined()) { + // value=None + begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); + } else { + begin_vec.push_back(GetConstInt(begin[i])); + } + } + std::vector end_vec; + for (size_t i = 0; i < end.size(); ++i) { + // allow end to be None + if (!end[i].defined()) { + end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } else if (slice_mode == "size") { + int64_t end_val = GetConstInt(end[i]); + if (end_val < 0) { + end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } else { + end_vec.push_back(begin_vec[i] + end_val); + } + } else { + end_vec.push_back(GetConstInt(end[i])); + } + } + return std::make_tuple(begin_vec, end_vec, stride_vec); +} + +inline Array StridedSliceCanonicalizeBegin(const Array& ishape, + const std::vector& begin, + const std::vector& strides, + const Array& axes, DataType dtype, + std::string slice_mode = "end") { + Array begin_expr; + for (size_t i = 0; i < axes.size(); ++i) { + if (ishape[axes[i]]->IsInstance()) { + int64_t dim_i = GetConstInt(ishape[axes[i]]); + int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]); + begin_expr.push_back(make_const(dtype, begin_i)); + } else { + auto idim = ishape[axes[i]]; + auto b_expr = make_const(dtype, begin[i]); + PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr; + auto s = strides[i]; + if (s < 0) { + b = tvm::min(b, idim - 1); + } else { + b = tvm::if_then_else(b < 0, 0, b); + } + begin_expr.push_back(b); + } + } + return begin_expr; +} + +inline Array StridedSliceOutputShape(const Array& ishape, + const std::vector& begin, + const std::vector& end, + const std::vector& strides, + const Array& axes, std::string slice_mode, + const Array& begin_canonicalized, + bool use_any = false) { + const size_t src_tensor_dim = ishape.size(); + Array out_shape; + for (size_t i = 0; i < src_tensor_dim; ++i) { + out_shape.push_back(ishape[i]); + } + + for (size_t i = 0; i < axes.size(); ++i) { + if (ishape[axes[i]]->IsInstance()) { + const int64_t dim_i = GetConstInt(ishape[axes[i]]); + ICHECK(begin_canonicalized[i]->IsInstance()); + int64_t begin_i = GetConstInt(begin_canonicalized[i]); + int64_t end_i = CanonicalizeIndex(end[i], dim_i, strides[i]); + int interval = std::abs(end_i - begin_i); + int slice_size = + static_cast((interval + std::abs(strides[i]) - 1) / std::abs(strides[i])); + ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) + << ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i; + out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size))); + } else if (use_any) { + out_shape.Set(axes[i], tvm::tir::Any()); + } else { + out_shape.Set(axes[i], tvm::tir::Var("dim", out_shape[i]->dtype)); + } + } + + return out_shape; +} + +} // namespace detail +} // namespace topi +} // namespace tvm +#endif // TVM_TOPI_DETAIL_STRIDED_SLICE_H_ diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 29c3156ab5d61..d3328c59afb47 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -619,7 +619,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, out = reshape(out, r_p_shape); // Crop the start and end of dimensions of out - Array begin_idx, end_idx, strides; + Array begin_idx, end_idx, strides; for (size_t i = 0; i < r_p_shape.size(); ++i) { strides.push_back(Integer(1)); if (i > 0 && i <= num_block_dims) { diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 36acc7376c7c2..f82a7329bf9c6 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -27,8 +27,10 @@ #include #include #include +#include #include #include +#include #include #include @@ -39,7 +41,6 @@ #include #include -#include "detail/broadcast.h" namespace tvm { namespace topi { @@ -550,6 +551,50 @@ inline Array split(const Tensor& x, Array split_indices, int a return result; } +inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begin, + const Array& end, const Array& strides, + std::string name = "T_dynamic_strided_slice", + std::string tag = kInjective) { + const size_t src_tensor_dim = static_cast(x->shape.size()); + ICHECK_LE(begin.size(), src_tensor_dim); + ICHECK_LE(end.size(), src_tensor_dim); + ICHECK_LE(strides.size(), src_tensor_dim); + ICHECK_EQ(begin.size(), end.size()); + ICHECK_EQ(begin.size(), strides.size()); + + const size_t num_slice_axes = begin.size(); + Array out_shape; + + for (size_t i = 0; i < num_slice_axes; ++i) { + auto d = indexdiv(end[i] - begin[i], strides[i]); + if (d->IsInstance()) { + // Preserve static dimension if possible + out_shape.push_back(d); + } else { + out_shape.push_back(tvm::tir::Var("dim")); + } + } + + for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) { + out_shape.push_back(x->shape[i]); + } + + return te::compute( + out_shape, + [&](const Array& indices) { + Array real_indices; + for (int32_t i = 0; i < num_slice_axes; ++i) { + real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1)); + } + // keep input dim + for (int32_t i = num_slice_axes; i < src_tensor_dim; ++i) { + real_indices.push_back(indices[i]); + } + return x(real_indices); + }, + name, tag); +} + /*! * \brief strided_slice of a tensor with dynamic begin/end/stride * @@ -567,26 +612,72 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b const te::Tensor& end, const te::Tensor& strides, std::string name = "T_strided_slice_dynamic", std::string tag = topi::kInjective) { - int64_t src_tensor_dim = x->shape.size(); - Array out_shape; const int64_t num_dynamic_axes = begin->shape[0].as()->value; + ICHECK_EQ(end->shape[0].as()->value, num_dynamic_axes); + ICHECK_EQ(strides->shape[0].as()->value, num_dynamic_axes); + + Array begin_expr, end_expr, strides_expr; for (int64_t i = 0; i < num_dynamic_axes; ++i) { - out_shape.push_back(tvm::tir::Var("dim")); - } - for (int64_t i = num_dynamic_axes; i < src_tensor_dim; ++i) { - out_shape.push_back(x->shape[i]); + auto i64_ind = IntImm(DataType::Int(64), i); + begin_expr.push_back(begin(i64_ind)); + end_expr.push_back(end(i64_ind)); + strides_expr.push_back(strides(i64_ind)); } + return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag); +} + +inline Array StridedSliceOutputShape( + const Array& ishape, const Array& begin, const Array& end, + const Array& strides, const Array& axes, const std::string& slice_mode) { + ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()); + std::vector begin_vec, end_vec, strides_vec; + std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); + auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes, + begin[0]->dtype, slice_mode); + return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode, + begin_canonicalized, true); +} + +/*! + * \brief strided_slice of a tensor + * + * \param x The input tensor + * \param begin The indices to begin with in the slicing + * \param end Indicies indicating end of the slice + * \param strides Specifies the stride values, it can be negative + * in that case, the input tensor will be reversed in that particular axis + * \param slice_mode Specifies the slice mode + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the split operation + */ +inline Tensor strided_slice_with_axes(const Tensor& x, const Array& begin, + const Array& end, const Array& strides, + const Array& axes, std::string slice_mode = "end", + std::string name = "T_strided_slice_with_axes", + std::string tag = kInjective) { + const size_t src_tensor_dim = x->shape.size(); + ICHECK(axes.size() <= src_tensor_dim); + ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()); + + std::vector begin_vec, end_vec, strides_vec; + std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); + + auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, axes, + begin[0]->dtype, slice_mode); + auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, axes, + slice_mode, begin_expr); + return te::compute( out_shape, - [&](const Array& indices) { + [&](const Array& indices) { Array real_indices; - // dynamic slicing - for (int32_t i = 0; i < num_dynamic_axes; ++i) { - real_indices.push_back(indices[i] * strides(i) + tvm::min(begin(i), x->shape[i] - 1)); - } - // keep input dim - for (int32_t i = num_dynamic_axes; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i]); + for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]); + for (size_t i = 0; i < axes.size(); ++i) { + auto stride = make_const(strides[i].dtype(), strides_vec[i]); + PrimExpr ind = indices[axes[i]] * stride + begin_expr[i]; + real_indices.Set(axes[i], ind); } return x(real_indices); }, @@ -607,122 +698,32 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b * * \return A Tensor whose op member is the split operation */ -inline Tensor strided_slice(const Tensor& x, const Array& begin, - const Array& end, const Array& strides, - std::string slice_mode = "end", std::string name = "T_strided_slice", - std::string tag = kInjective) { +inline Tensor strided_slice(const Tensor& x, const Array& begin, const Array& end, + const Array& strides, std::string slice_mode = "end", + std::string name = "T_strided_slice", std::string tag = kInjective) { size_t src_tensor_dim = static_cast(x->shape.size()); - // Quick path for dynamic shape strided slice. - // This is for ease of use to dynamice strided slice in topi. - bool is_static = IsConstIntArray(x->shape); - is_static &= IsConstIntArray(begin); - is_static &= IsConstIntArray(end); - is_static &= IsConstIntArray(strides); - - Array out_shape; - if (!is_static) { - ICHECK_EQ(strides.size(), src_tensor_dim); - for (size_t i = 0; i < src_tensor_dim; ++i) { - out_shape.push_back(indexdiv(end[i] - begin[i], strides[i])); - } - return te::compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides[i] + begin[i]); - } - return x(real_indices); - }, - name, tag); - } + Array axes; + for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i); + Array begin_full(begin); + Array end_full(end); + Array strides_full(strides); - // Setup the ranges. - // NOTE: this code duplicates the shape inference logic relay.op - // Consider to refactor in the future. - std::vector stride_vec(src_tensor_dim, 1); - for (size_t i = 0; i < strides.size(); ++i) { - ICHECK(strides[i].defined()); - stride_vec[i] = GetConstInt(strides[i]); - } - - const int64_t max_range = std::numeric_limits::max(); + const IntImm one = IntImm(DataType::Int(64), 1); + const IntImm zero = IntImm(DataType::Int(64), 0); + const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits::max()); - std::vector begin_vec; - for (size_t i = 0; i < begin.size(); ++i) { - if (!begin[i].defined()) { - // value=None - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } else { - begin_vec.push_back(GetConstInt(begin[i])); - } - } - for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) { - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); + for (size_t i = strides.size(); i < src_tensor_dim; ++i) { + strides_full.push_back(one); } - - std::vector end_vec; - for (size_t i = 0; i < end.size(); ++i) { - // allow end to be None - - if (!end[i].defined()) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } else if (slice_mode == "size") { - int64_t end_val = GetConstInt(end[i]); - if (end_val < 0) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } else { - end_vec.push_back(begin_vec[i] + end_val); - } - } else { - end_vec.push_back(GetConstInt(end[i])); - } + for (size_t i = begin.size(); i < src_tensor_dim; ++i) { + begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range); } - for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } - // Compute - Array begin_expr; - Array strides_expr; - - for (size_t i = 0; i < src_tensor_dim; ++i) { - int64_t begin_range = stride_vec[i] < 0 ? -1 : 0; - int64_t dim_i = GetConstInt(x->shape[i]); - int64_t end_range = stride_vec[i] < 0 ? dim_i - 1 : dim_i; - // transform negative indices to positive value, clips on the correct range - auto index_canonicalization = [dim_i, begin_range, end_range](int64_t index) { - if (index < 0) { - index += dim_i; - } - return std::min(std::max(index, begin_range), end_range); - }; - - int64_t begin_i = index_canonicalization(begin_vec[i]); - int64_t end_i = index_canonicalization(end_vec[i]); - - int interval = std::abs(end_i - begin_i); - int slice_size = - static_cast((interval + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i])); - ICHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) - << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] - << "] is invalid for axis=" << i; - - begin_expr.push_back(make_const(begin[0].dtype(), begin_i)); - strides_expr.push_back( - make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), stride_vec[i])); - out_shape.push_back(slice_size); + for (size_t i = end.size(); i < src_tensor_dim; ++i) { + end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range); } - return compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); - } - return x(real_indices); - }, - name, tag); + return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name, + tag); } /*! diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 896e8af999219..12e6daaa73671 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1341,7 +1341,31 @@ def _impl_v10(cls, inputs, attr, params): axes = inputs[3] steps = inputs[4] - data_rank = len(infer_shape(inputs[0])) + ishape = infer_shape(inputs[0]) + data_rank = len(ishape) + + def has_static_axes(): + return ( + isinstance(axes, _expr.Constant) + and isinstance(starts, _expr.Constant) + and isinstance(ends, _expr.Constant) + and (steps is None or isinstance(steps, _expr.Constant)) + ) + + # Update the starts and ends according to axes if required. + if axes is not None and has_static_axes(): + axes_np = axes.data.asnumpy().astype("int64") + begin_np = starts.data.asnumpy().astype("int64") + end_np = ends.data.asnumpy().astype("int64") + if steps is None: + strides_np = np.ones_like(begin_np).astype("int64") + else: + strides_np = steps.data.asnumpy().astype("int64") + + if all([isinstance(ishape[i], int) for i in axes_np]): + return _op.strided_slice( + inputs[0], list(begin_np), list(end_np), list(strides_np), axes=list(axes_np) + ) # Update the starts and ends according to axes if required. if axes is not None: @@ -3118,6 +3142,7 @@ def _get_convert_map(opset): "NonZero": NonZero.get_converter(opset), "Range": Range.get_converter(opset), "CumSum": CumSum.get_converter(opset), + "EyeLike": EyeLike.get_converter(opset), # defs/control_flow "Loop": Loop.get_converter(opset), "If": If.get_converter(opset), diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 94c413b6df6cb..13c035047dfee 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -244,16 +244,69 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice return out +@script +def _strided_slice_shape_func_with_axes(data_shape, begin, end, strides, slice_mode, axes): + ndim = data_shape.shape[0] + out = output_tensor((ndim,), "int64") + for i in const_range(ndim): + out[i] = data_shape[i] + + for i in const_range(len(axes)): + axis = int64(axes[i]) + cbegin = int64(0) + cend = int64(data_shape[axis]) + cstride = int64(1) + if len(strides) > i: + cstride = int64(strides[i]) + if len(begin) > i: + cbegin = int64(begin[i]) + if cbegin < 0: + cbegin += int64(data_shape[axis]) + if len(end) <= i: + cend = int64(data_shape[axis]) + elif slice_mode != 0: + cstride = int64(1) + if end[i] < 0: + cend = int64(data_shape[axis]) + else: + cend = cbegin + int64(end[i]) + else: + if end[i] > data_shape[i]: + cend = int64(data_shape[axis]) + elif end[i] < -data_shape[i]: + cend = int64(-1) + else: + cend = int64(end[i]) + if cend < 0: + cend += int64(data_shape[axis]) + assert cstride != 0, "Strides can't be zero." + if cstride < 0: + slice_range = cbegin - cend + step = -cstride + else: + slice_range = cend - cbegin + step = cstride + + out[axis] = int64(ceil_div(slice_range, step)) + return out + + @_reg.register_shape_func("strided_slice", False) def strided_slice_shape_func(attrs, inputs, _): """ Shape func for strided_slice """ slice_mode = convert(0 if attrs.slice_mode == "end" else 1) + if attrs.axes is None: + return [ + _strided_slice_shape_func_input_shape( + inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode + ) + ] return [ - _strided_slice_shape_func_input_shape( - inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode - ) + _strided_slice_shape_func_with_axes( + inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode, attrs.axes + ) ] diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 74fb44fc2232a..828f84831aa38 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -867,7 +867,7 @@ def split(data, indices_or_sections, axis=0): return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size) -def strided_slice(data, begin, end, strides=None, slice_mode="end"): +def strided_slice(data, begin, end, strides=None, slice_mode="end", axes=None): """Strided slice of an array. Parameters @@ -892,6 +892,9 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): the size of a slice starting at the location specified by begin. If end[i] is -1, all remaining elements in that dimension are included in the slice. + axes : List[int] + TODO + Returns ------- ret : relay.Expr @@ -917,7 +920,7 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): begin = _make.where(begin < cast_like(const(0), begin), begin + ishape_slice, begin) begin = _make.where(begin >= ishape_slice, ishape_slice, begin) return _dyn_make.strided_slice(data, begin, end, strides, slice_mode) - return _make.strided_slice(data, begin, end, strides, slice_mode) + return _make.strided_slice(data, begin, end, strides, slice_mode, axes) def strided_set(data, v, begin, end, strides=None): diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b43972d686cc4..4372f88b9500d 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -490,6 +490,9 @@ class VMFunctionCompiler : ExprFunctor { argument_registers.push_back(reg->second); } + // Extract functions attrs + op_attrs[op_index] = func->attrs->dict; + Emit(Instruction::InvokePacked(op_index, argument_registers.size(), outputs.size(), argument_registers)); } diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index bbfef5883e3df..089d7cebc9c03 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -78,7 +78,7 @@ Expr MakeStack(Expr data, int axis); Expr MakeTranspose(Expr data, Array axes); Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides, - String slice_mode); + String slice_mode, Optional> axes=NullValue>()); Expr MakeTile(Expr data, Array reps); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 10fe5e543dfc7..56240e6da1b62 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2445,99 +2445,40 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr return false; } - auto dshape = data->shape; - int64_t num_axis = dshape.size(); - - // calculate output shape - std::vector oshape(num_axis); - if (param->begin && param->end && param->strides) { - // stride will be set as 1 if slice mode is enabled - std::vector stride_vec(num_axis, 1); - if (param->slice_mode == "end") { - for (size_t i = 0; i < param->strides.value().size(); ++i) { - ICHECK(param->strides.value()[i].defined()); - stride_vec[i] = param->strides.value()[i]->value; - } - } - const int64_t max_range = std::numeric_limits::max(); - std::vector begin_vec; - for (size_t i = 0; i < param->begin.value().size(); ++i) { - if (!param->begin.value()[i].defined()) { - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } else { - begin_vec.push_back(param->begin.value()[i]->value); - } - } - for (int64_t i = begin_vec.size(); i < num_axis; ++i) { - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } + ICHECK(param->begin) << "strided_slice recieved invalid begin " << param->begin; + ICHECK(param->end) << "strided_slice recieved invalid end " << param->end; + ICHECK(param->strides) << "strided_slice recieved invalid strides " << param->strides; + + auto begin = param->begin.value(); + auto end = param->end.value(); + auto strides = param->strides.value(); + + const size_t src_tensor_dim = static_cast(data->shape.size()); + Array axes; + if (param->axes) { + axes = param->axes.value(); + ICHECK(axes.size() == begin.size() && axes.size() == end.size() && + axes.size() == strides.size()) + << "axes, begin, end, and strides must have the same length"; + } else { + for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i); - std::vector end_vec; - for (size_t i = 0; i < param->end.value().size(); ++i) { - // allow end to be None - if (!param->end.value()[i].defined()) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } else if (param->slice_mode == "size") { - if (param->end.value()[i]->value < 0) { - end_vec.push_back(max_range); - } else { - end_vec.push_back(begin_vec[i] + param->end.value()[i]->value); - } - } else if (param->slice_mode == "end") { - end_vec.push_back(param->end.value()[i]->value); - } else { - LOG(FATAL) << "Unsupported slice mode: " << param->slice_mode; - } + const IntImm one = IntImm(DataType::Int(64), 1); + const IntImm zero = IntImm(DataType::Int(64), 0); + const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits::max()); + + for (size_t i = strides.size(); i < src_tensor_dim; ++i) { + strides.push_back(one); } - for (int64_t i = end_vec.size(); i < num_axis; ++i) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + for (size_t i = begin.size(); i < src_tensor_dim; ++i) { + begin.push_back(topi::GetConstInt(strides[i]) > 0 ? zero : max_range); } - - for (int64_t i = 0; i < num_axis; ++i) { - int64_t stride_v = stride_vec[i]; - int64_t begin_v = begin_vec[i]; - int64_t end_v = end_vec[i]; - - if ((stride_v == 1 && begin_v == 0 && end_v == max_range) || - (stride_v == -1 && begin_v == max_range && end_v == 0)) { - // Quick path, do not slice this dimension. - oshape[i] = dshape[i]; - continue; - } - // Normal path, require the shape to be concrete integer. - // Require concrete integer as symbolic inference of min/max - // can get complicated and not very helpful. - const int64_t* p_dim_size = tir::as_const_int(dshape[i]); - if (!p_dim_size) { - oshape[i] = dshape[i]; - continue; - } - int64_t dim_size = p_dim_size[0]; - begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v; - end_v = (end_v < 0) ? dim_size + end_v : end_v; - - int64_t slice_range, step; - if (stride_v < 0) { - if (end_v < -1) end_v = -1; - ICHECK_LE(end_v, begin_v) << "strided_slice get empty slice at axis " << i; - begin_v = std::min(dim_size - 1, begin_v); - slice_range = begin_v - end_v; - step = -stride_v; - } else { - if (begin_v < 0) begin_v = 0; - ICHECK_GE(stride_v, 0); - ICHECK_LE(begin_v, end_v) << "strided_slice get invalid slice at axis " << i; - end_v = std::min(dim_size, end_v); - slice_range = end_v - begin_v; - step = stride_v; - } - oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step); + for (size_t i = end.size(); i < src_tensor_dim; ++i) { + end.push_back(topi::GetConstInt(strides[i]) < 0 ? zero : max_range); } - } else { - ICHECK(param->begin) << "strided_slice recieved invalid begin " << param->begin; - ICHECK(param->end) << "strided_slice recieved invalid end " << param->end; - ICHECK(param->strides) << "strided_slice recieved invalid strides " << param->strides; } + auto oshape = + topi::StridedSliceOutputShape(data->shape, begin, end, strides, axes, param->slice_mode); reporter->Assign(types[1], TensorType(oshape, data->dtype)); return true; } @@ -2596,78 +2537,130 @@ Array> StridedSliceInferCorrectLayout(const Attrs& attrs, // Not support NHW4c -> NCHW return {{Layout::Undef()}, {Layout::Undef()}}; } else { - for (size_t i = 0; i < new_layout_name.size(); ++i) { - auto index = layout.IndexOf(new_layout[i]); - if (index == -1) { - return {{Layout::Undef()}, {Layout::Undef()}}; + if (params->axes) { + auto axes = params->axes.value(); + Array new_axes; + + for (size_t i = 0; i < axes.size(); ++i) { + auto old_idx = axes[i]; + auto new_idx = new_layout.IndexOf(layout[old_idx]); + new_begin.push_back(begin[i]); + new_end.push_back(end[i]); + new_strides.push_back(strides[i]); + new_axes.push_back(new_idx); } + params->axes = new_axes; - size_t new_index = static_cast(index); - int64_t bg, ed, st; - if (strides.defined() && new_index < strides.size() && strides[new_index].defined()) { - st = strides[new_index]->value; - } else { - st = 1; - } - if (new_index < begin.size() && begin[new_index].defined()) { - bg = begin[new_index]->value; - } else { - bg = 0; - } - if (new_index < end.size() && end[new_index].defined()) { - ed = end[new_index]->value; - } else { - ed = shape[new_index].as()->value; - } + } else { + for (size_t i = 0; i < new_layout_name.size(); ++i) { + auto index = layout.IndexOf(new_layout[i]); + if (index == -1) { + return {{Layout::Undef()}, {Layout::Undef()}}; + } - new_begin.push_back(IntImm(begin[0]->dtype, bg)); - new_end.push_back(IntImm(end[0]->dtype, ed)); - new_strides.push_back(IntImm(strides[0]->dtype, st)); + size_t new_index = static_cast(index); + int64_t bg, ed, st; + if (strides.defined() && new_index < strides.size() && strides[new_index].defined()) { + st = strides[new_index]->value; + } else { + st = 1; + } + if (new_index < begin.size() && begin[new_index].defined()) { + bg = begin[new_index]->value; + } else { + bg = 0; + } + if (new_index < end.size() && end[new_index].defined()) { + ed = end[new_index]->value; + } else { + ed = shape[new_index].as()->value; + } + + new_begin.push_back(IntImm(begin[0]->dtype, bg)); + new_end.push_back(IntImm(end[0]->dtype, ed)); + new_strides.push_back(IntImm(strides[0]->dtype, st)); + } } + params->begin = new_begin; params->end = new_end; params->strides = new_strides; layout = new_layout; } } else { - for (size_t i = 0; i < begin.size(); i++) { - const LayoutAxis& axis = layout[i]; - if (!axis.IsPrimal()) { - // original layout that contains splitted axes is not supported - return {{Layout::Undef()}, {Layout::Undef()}}; - } - auto factor = new_layout.FactorOf(axis); - if (factor == -1) { - new_begin.push_back(IntImm(begin[i]->dtype, begin[i])); - new_end.push_back(IntImm(end[i]->dtype, end[i])); - } else { - if (strides.defined() && i < strides.size()) { - auto stride = strides[i]; - // arbitrary stride is not supported - if (stride.defined() && stride->value != 1) { + if (params->axes) { + auto axes = params->axes.value(); + Array new_axes; + + for (size_t i = 0; i < axes.size(); ++i) { + auto old_idx = axes[i]; + auto new_idx = new_layout.IndexOf(layout[old_idx]); + new_axes.push_back(new_idx); + + const LayoutAxis& axis = layout[old_idx]; + if (!axis.IsPrimal()) { + // original layout that contains splitted axes is not supported + return {{Layout::Undef()}, {Layout::Undef()}}; + } + + auto factor = new_layout.FactorOf(axis); + + if (factor == -1) { + new_begin.push_back(begin[i]); + new_end.push_back(end[i]); + } else { + int64_t bg = begin[i]; + int64_t ed = end[i]; + if (bg % factor || ed % factor) { + // transform to original layout return {{Layout::Undef()}, {Layout::Undef()}}; } + new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor))); + new_end.push_back(IntImm(end[0]->dtype, (ed / factor))); } - int64_t bg = begin[i].defined() ? begin[i]->value : 0; - int64_t ed; - if (!end[i].defined()) { - ed = shape[i].as()->value; - } else if (params->slice_mode == "size") { - if (end[i]->value < 0) { + } + params->axes = new_axes; + + } else { + for (size_t i = 0; i < begin.size(); i++) { + const LayoutAxis& axis = layout[i]; + if (!axis.IsPrimal()) { + // original layout that contains splitted axes is not supported + return {{Layout::Undef()}, {Layout::Undef()}}; + } + auto factor = new_layout.FactorOf(axis); + if (factor == -1) { + new_begin.push_back(IntImm(begin[i]->dtype, begin[i])); + new_end.push_back(IntImm(end[i]->dtype, end[i])); + } else { + if (strides.defined() && i < strides.size()) { + auto stride = strides[i]; + // arbitrary stride is not supported + if (stride.defined() && stride->value != 1) { + return {{Layout::Undef()}, {Layout::Undef()}}; + } + } + int64_t bg = begin[i].defined() ? begin[i]->value : 0; + int64_t ed; + if (!end[i].defined()) { ed = shape[i].as()->value; + } else if (params->slice_mode == "size") { + if (end[i]->value < 0) { + ed = shape[i].as()->value; + } else { + ed = bg + end[i]->value; + } } else { - ed = bg + end[i]->value; + ed = end[i]->value; } - } else { - ed = end[i]->value; - } - if (bg % factor || ed % factor) { - // transform to original layout - return {{Layout::Undef()}, {Layout::Undef()}}; + if (bg % factor || ed % factor) { + // transform to original layout + return {{Layout::Undef()}, {Layout::Undef()}}; + } + new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor))); + new_end.push_back(IntImm(end[0]->dtype, (ed / factor))); } - new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor))); - new_end.push_back(IntImm(end[0]->dtype, (ed / factor))); } } @@ -2683,63 +2676,27 @@ Array StridedSliceCompute(const Attrs& attrs, const Array(); ICHECK(param != nullptr); - Array begin, end, strides; - Array begin_expr, end_expr, strides_expr; - begin = param->begin.value(); - end = param->end.value(); - strides = param->strides.value(); - if (IsDynamic(out_type)) { - auto input = inputs[0]; - size_t src_tensor_dim = input->shape.size(); - ICHECK(begin.size() == src_tensor_dim) - << "for dynamic inputs, len(begin) must equal the input dimension"; - Array out_shape; - for (size_t i = 0; i < src_tensor_dim; ++i) { - out_shape.push_back(tvm::tir::Var("dim")); - } - for (size_t i = 0; i < src_tensor_dim; ++i) { - int64_t begin_i = begin[i]->value; - if (begin_i < 0) { - begin_i += topi::detail::GetConstInt(input->shape[i]); - } - begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i)); - strides_expr.push_back( - tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), - (i < strides.size() ? strides[i]->value : 1))); - } - return Array{te::compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); - } - return input(real_indices); - }, - std::string{"T_strided_slice_dynamic"}, std::string{topi::kInjective})}; - } else { - for (size_t i = 0; i < begin.size(); ++i) { - begin_expr.push_back(begin[i]); - } - for (size_t i = 0; i < end.size(); ++i) { - end_expr.push_back(end[i]); - } - for (size_t i = 0; i < strides.size(); ++i) { - strides_expr.push_back(strides[i]); - } + ICHECK(param->begin && param->end && param->strides); + Array begin = param->begin.value(); + Array end = param->end.value(); + Array strides = param->strides.value(); + if (param->axes) { + auto axes = param->axes.value(); + return Array{ + topi::strided_slice_with_axes(inputs[0], begin, end, strides, axes, param->slice_mode)}; } - return Array{ - topi::strided_slice(inputs[0], begin_expr, end_expr, strides_expr, param->slice_mode)}; + return Array{topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)}; } // Positional relay function to create StridedSlice operator used by frontend FFI. Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides, - String slice_mode) { + String slice_mode, Optional> axes) { auto attrs = make_object(); attrs->begin = std::move(begin); attrs->end = std::move(end); attrs->strides = std::move(strides); attrs->slice_mode = slice_mode; + attrs->axes = std::move(axes); static const Op& op = Op::Get("strided_slice"); return Call(op, {data}, Attrs(attrs), {}); } @@ -3057,16 +3014,21 @@ Array SliceLikeCompute(const Attrs& attrs, const Array& ICHECK(param != nullptr); Array src_shape = inputs[0]->shape; Array target_shape = inputs[1]->shape; - Array begin_idx, end_idx, strides; + Array begin_idx, end_idx, strides; for (size_t i = 0; i < src_shape.size(); ++i) { begin_idx.push_back(0); strides.push_back(1); } - end_idx = Array(src_shape); + for (auto s : src_shape) { + ICHECK(s->IsInstance()) << "slice_like does not support dynamic input shape"; + end_idx.push_back(topi::GetConstInt(s)); + } if (!param->axes.defined()) { for (size_t i = 0; i < src_shape.size(); ++i) { if (i < target_shape.size()) { - end_idx.Set(i, target_shape[i]); + ICHECK(target_shape[i]->IsInstance()) + << "slice_like does not support dynamic output shape"; + end_idx.Set(i, topi::GetConstInt(target_shape[i])); ICHECK_LE(topi::GetConstInt(end_idx[i]), topi::GetConstInt(src_shape[i])) << "End index of axis " << i << " exceeds input shape: " << topi::GetConstInt(end_idx[i]) << " vs " @@ -3078,7 +3040,9 @@ Array SliceLikeCompute(const Attrs& attrs, const Array& if (axis < 0) { axis = static_cast(src_shape.size()) + axis; } - end_idx.Set(axis, target_shape[axis]); + ICHECK(target_shape[axis]->IsInstance()) + << "slice_like does not support dynamic output shape"; + end_idx.Set(axis, topi::GetConstInt(target_shape[axis])); ICHECK_LE(topi::GetConstInt(end_idx[axis]), topi::GetConstInt(src_shape[axis])) << "End index of axis " << axis << " exceeds input shape: " << topi::GetConstInt(end_idx[axis]) << " vs " diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 99126b1591435..0a7795d600fe1 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -105,6 +105,10 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& fun } std::unordered_map metrics; + + ICHECK(exec_->op_attrs.find(packed_index) != exec_->op_attrs.end()) + << packed_index_map_[packed_index] << " not found in op attrs"; + auto& op_attrs = exec_->op_attrs.at(packed_index); for (auto p : op_attrs) { if (std::string(p.first).find("layout") != std::string::npos) { diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 0bce3bbc7f538..7c6e491dcc268 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -174,11 +174,26 @@ TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = strided_slice(args[0], args[1], args[2], args[3], args[4]); + Tensor x = args[0]; + Array begin = args[1]; + Array end = args[2]; + Array strides = args[3]; + std::string slice_mode = args[4]; + if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides)) { + Array begin_static = args[1]; + Array end_static = args[2]; + Array strides_static = args[3]; + *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode); + } else { + *rv = dynamic_strided_slice(x, begin, end, strides); + } }); TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = dynamic_strided_slice(args[0], args[1], args[2], args[3]); + te::Tensor begin = args[1]; + te::Tensor end = args[2]; + te::Tensor strides = args[3]; + *rv = dynamic_strided_slice(args[0], begin, end, strides); }); TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 8016e435618ad..74b8ec51e1fa2 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -1032,7 +1032,7 @@ def verify_any_strided_slice( mod = tvm.IRModule() data = relay.var("data", shape=data_shape, dtype="float32") if const_attrs: - data = relay.var("data", shape=data_np_shape, dtype="float32") + data = relay.var("data", shape=data_shape, dtype="float32") begin = relay.const(np_begin) end = relay.const(np_end) strides = relay.const(np_strides) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index c49e3de626622..b7d102d2b4d45 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -483,9 +483,8 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)) verify((3, 4, 3), [1, 1, 0], [4, 4, 4], None, (2, 3, 3)) verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) - # TODO(mbrookhart): fix static strided_slice with dynamic input and negative begin - # verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) - # verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) + verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) + verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) verify( (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False ) @@ -534,6 +533,7 @@ def verify(dshape, begin, end, strides, vshape, test_ref=True): if __name__ == "__main__": test_strided_slice() + test_dyn_strided_slice() test_strided_set() test_binary_op() test_cmp_type() diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 3031c55379ae8..5c2793c607a9f 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -770,6 +770,61 @@ def expected(): ) +@tvm.testing.uses_gpu +def test_alter_layout_strided_slice_axes_nhwc(): + """Test rewriting strided_slice with axes during alter_iop_layout""" + + def before(): + x = relay.var("x", shape=(1, 28, 28, 32)) + weight = relay.var("weight", shape=(3, 3, 32, 32)) + y = relay.nn.conv2d( + x, + weight, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.strided_slice(y, begin=[0, 16], end=[1, 32], strides=[1, 1], axes=[0, 3]) + y = relay.Function(analysis.free_vars(y), y) + return y + + def alter_conv2d(attrs, inputs, tinfos, out_type): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs["data_layout"] = "NHWC4c" + return relay.nn.conv2d(data, weight, **new_attrs) + + def expected(): + x = relay.var("x", shape=(1, 28, 28, 32)) + weight = relay.var("weight", shape=(3, 3, 32, 32)) + x = relay.layout_transform(x, "NHWC", "NHWC4c") + y = relay.op.nn.conv2d( + x, + weight, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC4c", + kernel_layout="HWIO", + ) + y = relay.strided_slice(y, begin=[0, 4], end=[1, 8], strides=[1, 1], axes=[0, 3]) + y = relay.layout_transform(y, "NHWC4c", "NHWC") + y = relay.Function(analysis.free_vars(y), y) + return y + + with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): + a = run_opt_pass(before(), transform.AlterOpLayout()) + b = run_opt_pass(expected(), transform.InferType()) + + mod_before = tvm.IRModule() + mod_new = tvm.IRModule() + mod_before["main"] = a + mod_new["main"] = b + assert tvm.ir.structural_equal(mod_before, mod_new) + + def test_alter_layout_depthwise_conv2d(): """Test depthwise_conv2d operator""" @@ -1298,3 +1353,4 @@ def expected(): test_alter_layout_nhwc_int8_aarch64() test_alter_op_with_global_var() test_alter_op_dense() + test_alter_layout_strided_slice_axes_nhwc() diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index dd2dc979a7316..5f3d754284b59 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -1235,6 +1235,49 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_conv_strided_slice_axes_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 28, 28, 32)) + weight = relay.var("weight", shape=(3, 3, 32, 32)) + y = relay.nn.conv2d( + x, + weight, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.strided_slice(y, begin=[0, 16], end=[1, 33], strides=[1, 1], axes=[0, 3]) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 28, 28, 32)) + weight = relay.var("weight", shape=(3, 3, 32, 32)) + weight = relay.layout_transform(weight, "HWIO", "OIHW") + x = relay.layout_transform(x, "NHWC", "NCHW") + y = relay.nn.conv2d( + x, + weight, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) + y = relay.strided_slice(y, begin=[0, 16], end=[1, 33], strides=[1, 1], axes=[0, 1]) + + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(analysis.free_vars(y), y) + return y + + a = run_opt_pass(before(), transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + def test_conv_roi_pool_convert_layout(): def before(): x = relay.var("x", shape=(1, 64, 56, 56)) @@ -1784,3 +1827,4 @@ def expected(): test_convert_with_config() test_conv_squeeze_convert_layout() test_conv_reduce_convert_layout() + test_conv_strided_slice_axes_convert_layout()