fixed output shape, refactored version working
masahi committed May 27, 2021
1 parent d2538ae commit 80442f8
Showing 2 changed files with 12 additions and 8 deletions.
include/tvm/topi/transform.h (18 changes: 11 additions & 7 deletions)
@@ -550,6 +550,8 @@ inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int a
   return result;
 }

+// inline te::Tensor strided_slice_compute_common() {}
+
 /*!
  * \brief strided_slice of a tensor with dynamic begin/end/stride
  *
@@ -657,7 +659,7 @@ inline Tensor strided_slice_dynamic_input(const Tensor& input, const Array<Integ
 inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
                                       const Array<Integer>& end, const Array<Integer>& strides,
                                       const Array<Integer>& axes, std::string slice_mode = "end",
-                                      std::string name = "T_strided_slice_dynamic_input",
+                                      std::string name = "T_strided_slice_with_axes",
                                       std::string tag = kInjective) {
   size_t src_tensor_dim = x->shape.size();

@@ -703,10 +705,12 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& beg
   }

   // Compute
-  Array<PrimExpr> begin_expr;
-  Array<PrimExpr> strides_expr;
-
   Array<PrimExpr> out_shape;
+  for (size_t i = 0; i < src_tensor_dim; ++i) {
+    out_shape.push_back(x->shape[i]);
+  }
+
+  Array<PrimExpr> begin_expr, strides_expr;
   for (size_t i = 0; i < axes.size(); ++i) {
     int64_t begin_range = stride_vec[i] < 0 ? -1 : 0;
     ICHECK(x->shape[axes[i]]->IsInstance<tvm::IntImmNode>())
@@ -734,7 +738,7 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& beg
     begin_expr.push_back(make_const(begin[0].dtype(), begin_i));
     strides_expr.push_back(
         make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), stride_vec[i]));
-    out_shape.push_back(slice_size);
+    out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size)));
   }

   return te::compute(
@@ -743,12 +747,12 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& beg
         Array<PrimExpr> real_indices;
         for (size_t i = 0; i < src_tensor_dim; ++i) real_indices.push_back(indices[i]);
         for (size_t i = 0; i < axes.size(); ++i) {
-          PrimExpr ind = indices[axes[i]] * strides[i] + begin_expr[i];
+          PrimExpr ind = indices[axes[i]] * strides_expr[i] + begin_expr[i];
           real_indices.Set(axes[i], ind);
         }
         return x(real_indices);
       },
-      std::string{"T_strided_slice_with_axes"}, std::string{topi::kInjective});
+      name, tag);
 }

 /*!
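A minimal sketch of the behavior the out_shape change targets, assuming a build of this branch (relay.strided_slice with an axes argument, as used by the ONNX frontend below): when axes is given, the output keeps the input's full rank and only the listed axes shrink, whereas the previous code pushed only the sliced extents into out_shape.

# Sketch under the assumptions above; not part of the commit.
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
# Slice only axis 2 to the range [0, 112); axes 0, 1 and 3 are left untouched.
y = relay.strided_slice(x, begin=[0], end=[112], strides=[1], axes=[2])
mod = tvm.IRModule.from_expr(relay.Function([x], y))

data = np.random.uniform(size=(1, 3, 224, 224)).astype("float32")
out = relay.create_executor("graph", mod=mod).evaluate()(data)
print(out.shape)  # expected (1, 3, 112, 224): unsliced axes keep their extent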
python/tvm/relay/frontend/onnx.py (2 changes: 1 addition & 1 deletion)
@@ -1359,7 +1359,7 @@ def has_static_axes():
             else:
                 strides_np = steps.data.asnumpy().astype("int64")

-            # return _op.strided_slice(inputs[0], list(begin_np), list(end_np), list(strides_np), axes=list(axes_np))
+            return _op.strided_slice(inputs[0], list(begin_np), list(end_np), list(strides_np), axes=list(axes_np))

         # Update the starts and ends according to axes if required.
         if axes is not None:
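The frontend change re-enables the early return for an ONNX Slice whose starts/ends/axes/steps are all static constants, forwarding them directly to strided_slice with axes instead of normalizing begin/end over every input dimension. A rough, hypothetical sketch of that mapping (standalone helper, not the frontend's actual code):

from tvm import relay

def static_onnx_slice(data, starts, ends, axes, steps):
    # starts/ends/axes/steps are plain Python lists of ints here, mirroring the
    # begin_np/end_np/axes_np/strides_np arrays in the converter above.
    return relay.strided_slice(
        data, begin=list(starts), end=list(ends), strides=list(steps), axes=list(axes)
    )

x = relay.var("x", shape=(8, 16, 32), dtype="float32")
# ONNX Slice(starts=[0], ends=[8], axes=[2], steps=[2]) -> output shape (8, 16, 4)
y = static_onnx_slice(x, starts=[0], ends=[8], axes=[2], steps=[2])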
