diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 36b93ee5d39f..a0ea8f2e60a3 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -363,7 +363,7 @@ RELAY_REGISTER_OP("stack") .set_attrs_type_key("relay.attrs.StackAttrs") .set_num_inputs(1) .add_argument("data", "Tensor", "The input list of tensors.") -.set_support_level(1) +.set_support_level(3) .add_type_rel("Stack", StackRel) .set_attr("FTVMCompute", StackCompute) .set_attr("TOpPattern", kInjective); @@ -1109,7 +1109,7 @@ RELAY_REGISTER_OP("repeat") .set_num_inputs(1) .set_attrs_type_key("relay.attrs.Repeat") .add_argument("data", "Tensor", "The input tensor.") -.set_support_level(1) +.set_support_level(3) .add_type_rel("Repeat", RepeatRel) .set_attr("FTVMCompute", RepeatCompute) .set_attr("TOpPattern", kBroadcast); @@ -1134,9 +1134,15 @@ bool TileRel(const Array& types, const size_t ndim = data->shape.size(); const Array& reps = param->reps; // check dimension match - CHECK(!reps.defined()) + CHECK(reps.defined()) << "repetition array is not defined. data.ndim = " << ndim; const size_t rndim = reps.size(); + for (size_t i = 0; i < rndim; ++i) { + if (const tvm::ir::IntImm* val = reps[i].as()) { + CHECK_GT(val->value, 0) + << "Tile reps value should always be larger than 0, but get: " << val->value; + } + } size_t tndim = (ndim > rndim) ? ndim : rndim; // re-construct data shape or reps shape std::vector data_shape; @@ -1158,6 +1164,10 @@ bool TileRel(const Array& types, } else { for (size_t i = 0; i < rndim; ++i) reps_shape.emplace_back(reps[i]); + for (size_t i = 0; i < (rndim - ndim); ++i) + data_shape.emplace_back(1); + for (size_t i = 0; i < ndim; ++i) + data_shape.emplace_back(data->shape[i]); } std::vector oshape; oshape.reserve(tndim); @@ -1199,7 +1209,7 @@ RELAY_REGISTER_OP("tile") .set_num_inputs(1) .set_attrs_type_key("relay.attrs.Tile") .add_argument("data", "Tensor", "The input tensor.") -.set_support_level(1) +.set_support_level(3) .add_type_rel("Tile", TileRel) .set_attr("FTVMCompute", TileCompute) .set_attr("TOpPattern", kBroadcast); diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index ed6a79e82b3f..10ace54e8b12 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -491,6 +491,62 @@ def verify_arange(start, stop, step): verify_arange(20, 1, -1) verify_arange(20, 1, -1.5) +def test_tile(): + def verify_tile(dshape, reps): + x = relay.var("x", relay.TensorType(dshape, "float32")) + z = relay.tile(x, reps=reps) + + func = relay.Function([x], z) + x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32") + ref_res = np.tile(x_data, reps=reps) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + verify_tile((2, 3, 4), (3, 2, 1)) + verify_tile((2, 3, 4), (1, 2)) + verify_tile((2, 3), (3, 2, 1)) + +def test_repeat(): + def verify_repeat(dshape, repeats, axis): + x = relay.Var("x", relay.TensorType(dshape, "float32")) + func = relay.Function([x], relay.repeat(x, repeats, axis)) + data = np.random.uniform(size=dshape).astype("float32") + ref_res = np.repeat(data, repeats, axis) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + verify_repeat((3,), 2, 0) + verify_repeat((3, 10), 2, -1) + verify_repeat((3, 2, 4), 3, 1) + +def test_stack(): + def verify_stack(dshapes, axis): + y = [] + for shape in dshapes: + y.append(relay.var("input", relay.TensorType(shape, "float32"))) + x = relay.Tuple(y) + z = relay.stack(x, axis=axis) + + func = relay.Function(y, z) + x_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes] + ref_res = np.stack(x_data, axis=axis) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(*x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + verify_stack([(2,), (2,), (2,)], -1) + verify_stack([(2,), (2,), (2,)], 0) + verify_stack([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1) + verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1) + + def test_reverse(): def verify_reverse(dshape, axis): @@ -536,3 +592,6 @@ def verify_reverse(dshape, axis): test_split_infer_type() test_arange() test_reverse() + test_stack() + test_tile() + test_repeat() diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index ad557f0fcbfe..59c1090480c2 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -29,7 +29,7 @@ def check_device(device): check_device(device) -def verify_tranpose(in_shape, axes): +def verify_transpose(in_shape, axes): A = tvm.placeholder(shape=in_shape, name="A") B = topi.transpose(A, axes) def check_device(device): @@ -40,7 +40,7 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): s = topi.generic.schedule_injective(B) - foo = tvm.build(s, [A, B], device, name="tranpose") + foo = tvm.build(s, [A, B], device, name="transpose") data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype) out_npy = data_npy.transpose(axes) data_nd = tvm.nd.array(data_npy, ctx) @@ -416,10 +416,10 @@ def test_expand_dims(): verify_expand_dims((3, 10), (1, 3, 10), -3, 1) -def test_tranpose(): - verify_tranpose((3, 10, 2), (1, 0, 2)) - verify_tranpose((3, 10, 5), (2, 0, 1)) - verify_tranpose((3, 10), None) +def test_transpose(): + verify_transpose((3, 10, 2), (1, 0, 2)) + verify_transpose((3, 10, 5), (2, 0, 1)) + verify_transpose((3, 10), None) def test_reshape(): @@ -595,7 +595,7 @@ def check_device(device): test_strided_slice() test_concatenate() test_stack() - test_tranpose() + test_transpose() test_expand_dims() test_reshape() test_squeeze()