diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index cc97a94a1406..d091342a5e4a 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -303,6 +303,7 @@ struct StridedSliceAttrs : public tvm::AttrsNode { Optional> end; Optional> strides; std::string slice_mode; + Optional> axes; TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") { TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive"); @@ -317,6 +318,7 @@ struct StridedSliceAttrs : public tvm::AttrsNode { "size - The input strides will be ignored, input end in this mode indicates the size" "of a slice starting at the location specified by begin. If end[i] is -1," "all remaining elements in that dimension are included in the slice"); + TVM_ATTR_FIELD(axes).describe("TODO"); } }; diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 3f876f401b3c..c9d7b11f62dd 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1343,6 +1343,24 @@ def _impl_v10(cls, inputs, attr, params): data_rank = len(infer_shape(inputs[0])) + def has_static_axes(): + return (isinstance(axes, _expr.Constant) and + isinstance(starts, _expr.Constant) and + isinstance(ends, _expr.Constant) and + (steps is None or isinstance(steps, _expr.Constant))) + + # Update the starts and ends according to axes if required. + if axes is not None and has_static_axes(): + axes_np = axes.data.asnumpy().astype("int64") + begin_np = starts.data.asnumpy().astype("int64") + end_np = ends.data.asnumpy().astype("int64") + if steps is None: + strides_np = np.ones_like(begin_np).astype("int64") + else: + strides_np = steps.data.asnumpy().astype("int64") + + return _op.strided_slice(inputs[0], list(begin_np), list(end_np), list(strides_np), axes=list(axes_np)) + # Update the starts and ends according to axes if required. if axes is not None: data_shape = shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 412acb4cea17..37b384dcfc31 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -244,16 +244,69 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice return out +@script +def _strided_slice_shape_func_with_axes(data_shape, begin, end, strides, slice_mode, axes): + ndim = data_shape.shape[0] + out = output_tensor((ndim,), "int64") + for i in const_range(ndim): + out[i] = data_shape[i] + + for i in const_range(len(axes)): + axis = int64(axes[i]) + cbegin = int64(0) + cend = int64(data_shape[axis]) + cstride = int64(1) + if len(strides) > i: + cstride = int64(strides[i]) + if len(begin) > i: + cbegin = int64(begin[i]) + if cbegin < 0: + cbegin += int64(data_shape[axis]) + if len(end) <= i: + cend = int64(data_shape[axis]) + elif slice_mode != 0: + cstride = int64(1) + if end[i] < 0: + cend = int64(data_shape[axis]) + else: + cend = cbegin + int64(end[i]) + else: + if end[i] > data_shape[i]: + cend = int64(data_shape[axis]) + elif end[i] < -data_shape[i]: + cend = int64(-1) + else: + cend = int64(end[i]) + if cend < 0: + cend += int64(data_shape[axis]) + assert cstride != 0, "Strides can't be zero." + if cstride < 0: + slice_range = cbegin - cend + step = -cstride + else: + slice_range = cend - cbegin + step = cstride + + out[axis] = int64(ceil_div(slice_range, step)) + return out + + @_reg.register_shape_func("strided_slice", False) def strided_slice_shape_func(attrs, inputs, _): """ Shape func for strided_slice """ slice_mode = convert(0 if attrs.slice_mode == "end" else 1) + if len(attrs.axes) == 0: + return [ + _strided_slice_shape_func_input_shape( + inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode + ) + ] return [ - _strided_slice_shape_func_input_shape( - inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode - ) + _strided_slice_shape_func_with_axes( + inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode, attrs.axes + ) ] diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index c87f545c138a..af605b486bdf 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -867,7 +867,7 @@ def split(data, indices_or_sections, axis=0): return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size) -def strided_slice(data, begin, end, strides=None, slice_mode="end"): +def strided_slice(data, begin, end, strides=None, slice_mode="end", axes=None): """Strided slice of an array. Parameters @@ -892,6 +892,9 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): the size of a slice starting at the location specified by begin. If end[i] is -1, all remaining elements in that dimension are included in the slice. + axes : List[int] + TODO + Returns ------- ret : relay.Expr @@ -917,7 +920,7 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): begin = _make.where(begin < cast_like(const(0), begin), begin + ishape_slice, begin) begin = _make.where(begin >= ishape_slice, ishape_slice, begin) return _dyn_make.strided_slice(data, begin, end, strides, slice_mode) - return _make.strided_slice(data, begin, end, strides, slice_mode) + return _make.strided_slice(data, begin, end, strides, slice_mode, axes) def strided_set(data, v, begin, end, strides=None): diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index bbfef5883e3d..089d7cebc9c0 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -78,7 +78,7 @@ Expr MakeStack(Expr data, int axis); Expr MakeTranspose(Expr data, Array axes); Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides, - String slice_mode); + String slice_mode, Optional> axes=NullValue>()); Expr MakeTile(Expr data, Array reps); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index bf45a412050f..1279e0acde9f 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2451,6 +2451,7 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr // calculate output shape std::vector oshape(num_axis); if (param->begin && param->end && param->strides) { + const bool has_axes(param->axes); // stride will be set as 1 if slice mode is enabled std::vector stride_vec(num_axis, 1); if (param->slice_mode == "end") { @@ -2468,9 +2469,6 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr begin_vec.push_back(param->begin.value()[i]->value); } } - for (int64_t i = begin_vec.size(); i < num_axis; ++i) { - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } std::vector end_vec; for (size_t i = 0; i < param->end.value().size(); ++i) { @@ -2489,49 +2487,99 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr LOG(FATAL) << "Unsupported slice mode: " << param->slice_mode; } } - for (int64_t i = end_vec.size(); i < num_axis; ++i) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } - for (int64_t i = 0; i < num_axis; ++i) { - int64_t stride_v = stride_vec[i]; - int64_t begin_v = begin_vec[i]; - int64_t end_v = end_vec[i]; + if (!has_axes) { + for (int64_t i = begin_vec.size(); i < num_axis; ++i) { + begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); + } - if ((stride_v == 1 && begin_v == 0 && end_v == max_range) || - (stride_v == -1 && begin_v == max_range && end_v == 0)) { - // Quick path, do not slice this dimension. - oshape[i] = dshape[i]; - continue; + for (int64_t i = end_vec.size(); i < num_axis; ++i) { + end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } + + for (int64_t i = 0; i < num_axis; ++i) { + int64_t stride_v = stride_vec[i]; + int64_t begin_v = begin_vec[i]; + int64_t end_v = end_vec[i]; + + if ((stride_v == 1 && begin_v == 0 && end_v == max_range) || + (stride_v == -1 && begin_v == max_range && end_v == 0)) { + // Quick path, do not slice this dimension. + oshape[i] = dshape[i]; + continue; + } + // Normal path, require the shape to be concrete integer. + // Require concrete integer as symbolic inference of min/max + // can get complicated and not very helpful. + const int64_t* p_dim_size = tir::as_const_int(dshape[i]); + if (!p_dim_size) { + oshape[i] = dshape[i]; + continue; + } + int64_t dim_size = p_dim_size[0]; + begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v; + end_v = (end_v < 0) ? dim_size + end_v : end_v; + + int64_t slice_range, step; + if (stride_v < 0) { + if (end_v < -1) end_v = -1; + ICHECK_LE(end_v, begin_v) << "strided_slice get empty slice at axis " << i; + begin_v = std::min(dim_size - 1, begin_v); + slice_range = begin_v - end_v; + step = -stride_v; + } else { + if (begin_v < 0) begin_v = 0; + ICHECK_GE(stride_v, 0); + ICHECK_LE(begin_v, end_v) << "strided_slice get invalid slice at axis " << i; + end_v = std::min(dim_size, end_v); + slice_range = end_v - begin_v; + step = stride_v; + } + oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step); } - // Normal path, require the shape to be concrete integer. - // Require concrete integer as symbolic inference of min/max - // can get complicated and not very helpful. - const int64_t* p_dim_size = tir::as_const_int(dshape[i]); - if (!p_dim_size) { + } else { + auto axes = param->axes.value(); + for (int64_t i = 0; i < num_axis; ++i) { oshape[i] = dshape[i]; - continue; } - int64_t dim_size = p_dim_size[0]; - begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v; - end_v = (end_v < 0) ? dim_size + end_v : end_v; - - int64_t slice_range, step; - if (stride_v < 0) { - if (end_v < -1) end_v = -1; - ICHECK_LE(end_v, begin_v) << "strided_slice get empty slice at axis " << i; - begin_v = std::min(dim_size - 1, begin_v); - slice_range = begin_v - end_v; - step = -stride_v; - } else { - if (begin_v < 0) begin_v = 0; - ICHECK_GE(stride_v, 0); - ICHECK_LE(begin_v, end_v) << "strided_slice get invalid slice at axis " << i; - end_v = std::min(dim_size, end_v); - slice_range = end_v - begin_v; - step = stride_v; + for (int64_t i = 0; i < axes.size(); ++i) { + int64_t stride_v = stride_vec[i]; + int64_t begin_v = begin_vec[i]; + int64_t end_v = end_vec[i]; + + if ((stride_v == 1 && begin_v == 0 && end_v == max_range) || + (stride_v == -1 && begin_v == max_range && end_v == 0)) { + // Quick path, do not slice this dimension. + continue; + } + // Normal path, require the shape to be concrete integer. + // Require concrete integer as symbolic inference of min/max + // can get complicated and not very helpful. + const int64_t* p_dim_size = tir::as_const_int(dshape[axes[i]]); + if (!p_dim_size) { + continue; + } + int64_t dim_size = p_dim_size[0]; + begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v; + end_v = (end_v < 0) ? dim_size + end_v : end_v; + + int64_t slice_range, step; + if (stride_v < 0) { + if (end_v < -1) end_v = -1; + ICHECK_LE(end_v, begin_v) << "strided_slice get empty slice at axis " << i; + begin_v = std::min(dim_size - 1, begin_v); + slice_range = begin_v - end_v; + step = -stride_v; + } else { + if (begin_v < 0) begin_v = 0; + ICHECK_GE(stride_v, 0); + ICHECK_LE(begin_v, end_v) << "strided_slice get invalid slice at axis " << i; + end_v = std::min(dim_size, end_v); + slice_range = end_v - begin_v; + step = stride_v; + } + oshape[axes[i]] = tir::make_const(dshape[axes[i]].dtype(), (slice_range + step - 1) / step); } - oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step); } } else { ICHECK(param->begin) << "strided_slice recieved invalid begin " << param->begin; @@ -2542,6 +2590,218 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr return true; } +// Array> StridedSliceInferCorrectLayout(const Attrs& attrs, +// const Array& new_in_layouts, +// const Array& old_in_layouts, +// const Array& old_in_types) +// { +// Array> old_in_shapes; +// for (auto old_in_t : old_in_types) { +// ICHECK(old_in_t.as()); +// old_in_shapes.push_back(old_in_t.as()->shape); +// } + +// ICHECK(old_in_layouts.defined()); +// ICHECK_GE(old_in_layouts.size(), 1); +// ICHECK(old_in_shapes.defined()); +// ICHECK_GE(old_in_shapes.size(), 1); + +// auto layout = old_in_layouts[0]; +// if (layout.defined() && new_in_layouts.defined()) { +// ICHECK_GE(new_in_layouts.size(), 1); +// auto new_layout = new_in_layouts[0]; +// auto shape = old_in_shapes[0]; + +// // NOTE: Discard "const" qualifier here. +// auto* params = const_cast(attrs.as()); +// ICHECK(params != nullptr); +// Array begin, end, strides; +// if (params->begin && params->end && params->strides) { +// for (Integer i : params->strides.value()) { +// ICHECK(i.defined()); +// strides.push_back(params->slice_mode == "size" ? 1 : i->value); +// } + +// for (Integer i : params->begin.value()) { +// ICHECK(i.defined()); +// begin.push_back(i->value); +// } +// for (Integer i : params->end.value()) { +// ICHECK(i.defined()); +// end.push_back(i->value); +// } +// } +// auto axes = params->axes; + +// Array new_begin, new_end, new_strides; + +// // Handles layout conversion like NHWC -> NCHW +// auto old_layout_name = layout.name(); +// auto new_layout_name = new_layout.name(); + +// if (old_layout_name.rfind(new_layout_name, 0) != 0 && +// new_layout_name.rfind(old_layout_name, 0) != 0) { +// if (old_layout_name.size() != new_layout_name.size()) { +// // Not support NHW4c -> NCHW +// return {{Layout::Undef()}, {Layout::Undef()}}; +// } else { +// if (params->axes) { +// auto axes = params->axes.value(); +// Array new_axes(axes); +// std::vector axes_map(old_layout_name.size(), -1); +// for (size_t i = 0; i < axes.size(); ++i) { +// axes_map[axes[i]] = i; +// } +// LOG(INFO) << "old layout: " << old_layout_name; +// LOG(INFO) << "new layout: " << new_layout_name; +// for (size_t i = 0; i < new_layout_name.size(); ++i) { +// auto index = layout.IndexOf(new_layout[i]); +// if (index == -1) { +// return {{Layout::Undef()}, {Layout::Undef()}}; +// } + +// size_t new_index = static_cast(index); +// if (axes_map[new_index] != -1) { +// ICHECK(strides[axes_map[new_index]].defined()); +// new_axes.Set(axes_map[new_index], new_index); +// new_begin.push_back(begin[axes_map[new_index]]->value); +// new_end.push_back(end[axes_map[new_index]]->value); +// new_strides.push_back(strides[axes_map[new_index]]->value); +// } +// } +// params->axes = new_axes; +// } else { +// for (size_t i = 0; i < new_layout_name.size(); ++i) { +// auto index = layout.IndexOf(new_layout[i]); +// if (index == -1) { +// return {{Layout::Undef()}, {Layout::Undef()}}; +// } + +// size_t new_index = static_cast(index); +// int64_t bg, ed, st; +// if (strides.defined() && new_index < strides.size() && strides[new_index].defined()) +// { +// st = strides[new_index]->value; +// } else { +// st = 1; +// } +// if (new_index < begin.size() && begin[new_index].defined()) { +// bg = begin[new_index]->value; +// } else { +// bg = 0; +// } +// if (new_index < end.size() && end[new_index].defined()) { +// ed = end[new_index]->value; +// } else { +// ed = shape[new_index].as()->value; +// } + +// new_begin.push_back(bg); +// new_end.push_back(ed); +// new_strides.push_back(st); +// } +// } +// params->begin = new_begin; +// params->end = new_end; +// params->strides = new_strides; +// layout = new_layout; +// } +// } else { +// if (params->axes) { +// auto axes = params->axes.value(); +// LOG(INFO) << "old layout: " << old_layout_name; +// LOG(INFO) << "new layout: " << new_layout_name; +// for (size_t i = 0; i < axes.size(); i++) { +// const LayoutAxis& axis = layout[axes[i]]; +// if (!axis.IsPrimal()) { +// // original layout that contains splitted axes is not supported +// return {{Layout::Undef()}, {Layout::Undef()}}; +// } +// auto factor = new_layout.FactorOf(axis); +// if (factor == -1) { +// new_begin.push_back(begin[i]); +// new_end.push_back(end[i]); +// } else { +// if (strides.defined() && i < strides.size()) { +// auto stride = strides[i]; +// // arbitrary stride is not supported +// if (stride.defined() && stride->value != 1) { +// return {{Layout::Undef()}, {Layout::Undef()}}; +// } +// } +// int64_t bg = begin[i].defined() ? begin[i]->value : 0; +// int64_t ed; +// if (!end[i].defined()) { +// ed = shape[axes[i]].as()->value; +// } else if (params->slice_mode == "size") { +// if (end[i]->value < 0) { +// ed = shape[axes[i]].as()->value; +// } else { +// ed = bg + end[i]->value; +// } +// } else { +// ed = end[i]->value; +// } + +// if (bg % factor || ed % factor) { +// // transform to original layout +// return {{Layout::Undef()}, {Layout::Undef()}}; +// } +// new_begin.push_back(tvm::Integer(bg / factor)); +// new_end.push_back(tvm::Integer(ed / factor)); +// } +// } +// } else { +// for (size_t i = 0; i < begin.size(); i++) { +// const LayoutAxis& axis = layout[i]; +// if (!axis.IsPrimal()) { +// // original layout that contains splitted axes is not supported +// return {{Layout::Undef()}, {Layout::Undef()}}; +// } +// auto factor = new_layout.FactorOf(axis); +// if (factor == -1) { +// new_begin.push_back(begin[i]); +// new_end.push_back(end[i]); +// } else { +// if (strides.defined() && i < strides.size()) { +// auto stride = strides[i]; +// // arbitrary stride is not supported +// if (stride.defined() && stride->value != 1) { +// return {{Layout::Undef()}, {Layout::Undef()}}; +// } +// } +// int64_t bg = begin[i].defined() ? begin[i]->value : 0; +// int64_t ed; +// if (!end[i].defined()) { +// ed = shape[i].as()->value; +// } else if (params->slice_mode == "size") { +// if (end[i]->value < 0) { +// ed = shape[i].as()->value; +// } else { +// ed = bg + end[i]->value; +// } +// } else { +// ed = end[i]->value; +// } + +// if (bg % factor || ed % factor) { +// // transform to original layout +// return {{Layout::Undef()}, {Layout::Undef()}}; +// } +// new_begin.push_back(tvm::Integer(bg / factor)); +// new_end.push_back(tvm::Integer(ed / factor)); +// } +// } +// } + +// layout = new_layout; +// params->begin = new_begin; +// params->end = new_end; +// } +// } +// return {{layout}, {layout}}; +// } + Array> StridedSliceInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, @@ -2688,7 +2948,53 @@ Array StridedSliceCompute(const Attrs& attrs, const Arraybegin.value(); end = param->end.value(); strides = param->strides.value(); - if (IsDynamic(out_type)) { + if (param->axes) { + auto axes = param->axes.value(); + auto input = inputs[0]; + size_t src_tensor_dim = input->shape.size(); + + ICHECK(axes.size() <= src_tensor_dim); + ICHECK(axes.size() == begin.size() && axes.size() == end.size() && + axes.size() == strides.size()); + + Array out_shape; + for (size_t i = 0; i < src_tensor_dim; ++i) { + out_shape.push_back(input->shape[i]); + } + Array begin_expr; + for (size_t i = 0; i < axes.size(); ++i) { + auto idim = input->shape[axes[i]]; + auto b = tvm::if_then_else(begin[i] < 0, begin[i] + idim, begin[i]); + auto e = tvm::if_then_else(end[i] < 0, end[i] + idim, end[i]); + auto s = strides[i]->value; + PrimExpr range; + if (s < 0) { + b = tvm::min(b, idim - 1); + e = tvm::if_then_else(e < -1, -1, e); + range = b - e; + s = -s; + } else { + b = tvm::if_then_else(b < 0, 0, b); + e = tvm::min(e, idim); + range = e - b; + } + PrimExpr odim = indexdiv(range + tvm::PrimExpr(static_cast(s - 1)), s); + out_shape.Set(axes[i], cast(out_shape[i].dtype(), odim)); + begin_expr.push_back(b); + } + return Array{te::compute( + out_shape, + [&](const Array& indices) { + Array real_indices; + for (size_t i = 0; i < src_tensor_dim; ++i) real_indices.push_back(indices[i]); + for (size_t i = 0; i < axes.size(); ++i) { + PrimExpr ind = indices[axes[i]] * strides[i] + begin_expr[i]; + real_indices.Set(axes[i], ind); + } + return input(real_indices); + }, + std::string{"T_strided_slice_with_axes"}, std::string{topi::kInjective})}; + } else if (IsDynamic(out_type)) { auto input = inputs[0]; size_t src_tensor_dim = input->shape.size(); ICHECK(begin.size() == src_tensor_dim) @@ -2734,12 +3040,13 @@ Array StridedSliceCompute(const Attrs& attrs, const Array begin, Array end, Array strides, - String slice_mode) { + String slice_mode, Optional> axes) { auto attrs = make_object(); attrs->begin = std::move(begin); attrs->end = std::move(end); attrs->strides = std::move(strides); attrs->slice_mode = slice_mode; + attrs->axes = std::move(axes); static const Op& op = Op::Get("strided_slice"); return Call(op, {data}, Attrs(attrs), {}); }