diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index ba82bcdff00e6..342919420f4b8 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -127,7 +127,7 @@ struct SliceLikeAttrs : public tvm::AttrsNode { Array axes; TVM_DECLARE_ATTRS(SliceLikeAttrs, "relay.attrs.SliceLikeAttrs") { - TVM_ATTR_FIELD(axes).set_default(Array()) + TVM_ATTR_FIELD(axes) .describe("List of axes on which input data will be sliced according to the " "corresponding size of the second input. By default will slice " "on all axes. Negative axes mean counting in reverse."); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index def59a6ab0e9e..d09e605790609 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1020,7 +1020,7 @@ bool SliceLikeRel(const Array& types, const Array target_shape = target->shape; std::vector&& oshape = AsVector(dshape); - if (!param->axes.defined() || param->axes.size() == 0) { + if (!param->axes.defined()) { for (size_t i = 0; i < dshape.size(); ++i) { if (i < target_shape.size()) { oshape[i] = target_shape[i]; @@ -1030,6 +1030,7 @@ bool SliceLikeRel(const Array& types, } } } else { + CHECK(param->axes.size() != 0) << "Axes cannot be empty."; for (Integer i : param->axes) { if (reporter->Assert(i < make_const(Int(64), 0))) { i += make_const(Int(64), dshape.size()); diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 89dbefef1d857..1709a9a2f9825 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -30,11 +30,11 @@ def verify_slice_like(data, slice_like, axes, output, dtype="float32"): def test_slice_like(): d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") - verify_slice_like(data=(d1, d2, d3), slice_like=(1, 2, 3), axes=[], output=(1, 2, 3)) - verify_slice_like(data=(1, 2, 3), slice_like=(d1, d2, d3), axes=[], output=(d1, d2, d3)) + verify_slice_like(data=(d1, d2, d3), slice_like=(1, 2, 3), axes=None, output=(1, 2, 3)) + verify_slice_like(data=(1, 2, 3), slice_like=(d1, d2, d3), axes=None, output=(d1, d2, d3)) verify_slice_like(data=(d2, d3, d4), slice_like=(d1, d2, d3), axes=(1,2), output=(d2, d2, d3)) - verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=[], output=(1, 2, 3)) - verify_slice_like(data=(3, 4, 5), slice_like=(1, 2), axes=[], output=(1, 2, 5)) + verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=None, output=(1, 2, 3)) + verify_slice_like(data=(3, 4, 5), slice_like=(1, 2), axes=None, output=(1, 2, 5)) verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=(1, 2), output=(3, 2, 3)) verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=(-1, -3), output=(1, 4, 3)) verify_slice_like(data=(1, 3, 224, 224),