diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 274294ccb388e..262f41edad677 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -93,6 +93,26 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
   }
 };  // struct ReshapeAttrs
 
+/*! \brief Attributes used in MXNet-style reshape_like operators */
+struct ReshapeLikeAttrs : public tvm::AttrsNode<ReshapeLikeAttrs> {
+  int lhs_begin;
+  Integer lhs_end;  // can be None
+  int rhs_begin;
+  Integer rhs_end;  // can be None
+  TVM_DECLARE_ATTRS(ReshapeLikeAttrs, "relay.attrs.ReshapeLikeAttrs") {
+    TVM_ATTR_FIELD(lhs_begin).set_default(0).describe(
+        "The axis of the input where reshaping should begin.");
+    TVM_ATTR_FIELD(lhs_end)
+        .set_default(NullValue<Integer>())
+        .describe("The axis of the input where reshaping should end, exclusive.");
+    TVM_ATTR_FIELD(rhs_begin).set_default(0).describe(
+        "The axis of the shape_like tensor to begin taking dimensions from.");
+    TVM_ATTR_FIELD(rhs_end)
+        .set_default(NullValue<Integer>())
+        .describe("The axis of the shape_like tensor to end taking dimensions from, exclusive.");
+  }
+};  // struct ReshapeLikeAttrs
+
 struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
   Integer axis;
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 5dc2c2402c08d..2c5f046bb7e81 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -194,6 +194,11 @@ class ReshapeAttrs(Attrs):
     """Attributes for transform.reshape"""
 
 
+@tvm._ffi.register_object("relay.attrs.ReshapeLikeAttrs")
+class ReshapeLikeAttrs(Attrs):
+    """Attributes for transform.reshape_like"""
+
+
 @tvm._ffi.register_object("relay.attrs.GatherAttrs")
 class GatherAttrs(Attrs):
     """Attributes for transform.gather"""
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 17f4c02380b31..b7df6001e59e2 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -308,28 +308,55 @@ def scatter_add(data, indices, updates, axis):
     return _make.scatter_add(data, indices, updates, axis)
 
 
-def reshape_like(data, shape_like):
-    """Reshapes the input array by the size of another array.
-    For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
-    the input array into an output array with the same shape as the second input array.
+def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None):
+    """Reshapes the input tensor by the size of another tensor.
+    For an input tensor with shape ``(d0, d1, ..., d(k-1))``, `reshape_like` operation reshapes
+    the input tensor into an output tensor with the same shape as the second input tensor,
+    in particular reshaping the dimensions of `data` in `[lhs_begin, lhs_end)` using the
+    dimensions from `shape_like` in `[rhs_begin, rhs_end)`.
 
     .. note::
-        Sizes for both array should be compatible.
+        Sizes for `data` and the output tensor should be compatible.
 
     Parameters
     ----------
     data : relay.Expr
         The input data to the operator.
 
-    shape_like : tuple of int
-        The new shape. Should be compatible with the original shape.
+    shape_like : relay.Expr
+        The tensor to reshape data like. Should be compatible with the original shape on the
+        reshaped dimensions.
+
+    lhs_begin : int, optional
+        The axis of data to begin reshaping. Default is 0.
+
+    lhs_end : int or None, optional
+        The axis of data where reshaping should stop, exclusive. Default is None which reshapes
+        to the end.
+
+    rhs_begin : int, optional
+        The axis of shape_like where the target shape begins. Default is 0.
+
+    rhs_end : int or None, optional
+        The axis of shape_like where the target shape ends, exclusive. Default is None which
+        extends to the end.
 
     Returns
     -------
     ret : relay.Expr
         The computed result.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        data.shape == (1, 2, 3, 4)
+        shape_like.shape == (6, 2, 2, 3)
+
+        ret = relay.reshape_like(data, shape_like, lhs_begin=1, rhs_end=3)
+        ret.shape == (1, 6, 2, 2)
     """
-    return _make.reshape_like(data, shape_like)
+    return _make.reshape_like(data, shape_like, lhs_begin, lhs_end, rhs_begin, rhs_end)
 
 
 def take(data, indices, axis=None, mode="clip"):
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index 631ec4c0d2f55..0e1f5c560081f 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -62,6 +62,9 @@ Expr MakeRepeat(Expr data, int repeats, int axis);
 
 Expr MakeReshape(Expr data, Array<Integer> newshape);
 
+Expr MakeReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin,
+                     Integer rhs_end);
+
 Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis);
 
 Expr MakeSqueeze(Expr data, Array<Integer> axis);
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 02fd8930d3326..3ca816a6caaea 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -453,6 +453,7 @@ RELAY_REGISTER_OP("transpose")
 
 /* relay.reshape */
 TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
+TVM_REGISTER_NODE_TYPE(ReshapeLikeAttrs);
 
 Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs& attrs) {
   const auto* param = attrs.as<ReshapeAttrs>();
@@ -641,11 +642,49 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   return true;
 }
 
+Array<IndexExpr> infer_reshape_like(const Array<IndexExpr>& lhs_shape,
+                                    const Array<IndexExpr>& rhs_shape, const Attrs& attrs) {
+  const auto* like_attrs = attrs.as<ReshapeLikeAttrs>();
+  CHECK(!like_attrs->lhs_end.defined() || like_attrs->lhs_end.as<IntImmNode>())
+      << "lhs_end must be a concrete integer or None";
+  CHECK(!like_attrs->rhs_end.defined() || like_attrs->rhs_end.as<IntImmNode>())
+      << "rhs_end must be a concrete integer or None";
+
+  int64_t lhs_shape_size = static_cast<int64_t>(lhs_shape.size());
+  int64_t rhs_shape_size = static_cast<int64_t>(rhs_shape.size());
+  int64_t lhs_begin = static_cast<int64_t>(like_attrs->lhs_begin);
+  int64_t lhs_end =
+      like_attrs->lhs_end.defined() ? like_attrs->lhs_end.as<IntImmNode>()->value : lhs_shape_size;
+  int64_t rhs_begin = static_cast<int64_t>(like_attrs->rhs_begin);
+  int64_t rhs_end =
+      like_attrs->rhs_end.defined() ? like_attrs->rhs_end.as<IntImmNode>()->value : rhs_shape_size;
+
+  // handle negative axes
+  lhs_begin = lhs_begin < 0 ? lhs_begin + lhs_shape_size : lhs_begin;
+  lhs_end = lhs_end < 0 ? lhs_end + lhs_shape_size : lhs_end;
+  rhs_begin = rhs_begin < 0 ? rhs_begin + rhs_shape_size : rhs_begin;
+  rhs_end = rhs_end < 0 ? rhs_end + rhs_shape_size : rhs_end;
+
+  Array<IndexExpr> shape_like;
+  for (auto i = 0; i < lhs_begin; i++) {
+    shape_like.push_back(lhs_shape[i]);
+  }
+  for (auto i = rhs_begin; i < rhs_end; i++) {
+    shape_like.push_back(rhs_shape[i]);
+  }
+  for (auto i = lhs_end; i < lhs_shape_size; i++) {
+    shape_like.push_back(lhs_shape[i]);
+  }
+  return shape_like;
+}
+
 Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                  const Type& out_type) {
   // Quick path for reshape_like
   if (!attrs.as<ReshapeAttrs>()) {
-    return {topi::reshape(inputs[0], inputs[1]->shape)};
+    ICHECK(attrs.as<ReshapeLikeAttrs>() != nullptr);
+    auto shape_like = infer_reshape_like(inputs[0]->shape, inputs[1]->shape, attrs);
+    return {topi::reshape(inputs[0], shape_like)};
   }
 
   const auto* out_ttype = out_type.as<TensorTypeNode>();
@@ -746,6 +785,7 @@ Example::
 */
 bool ReshapeLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
+  ICHECK(attrs.as<ReshapeLikeAttrs>() != nullptr);
   ICHECK_EQ(types.size(), 3);
   const auto* data = types[0].as<TensorTypeNode>();
   if (data == nullptr) {
@@ -755,6 +795,7 @@ bool ReshapeLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (reshape_like == nullptr) {
     return false;
   }
+  auto shape_like = infer_reshape_like(data->shape, reshape_like->shape, attrs);
   // Only check When input data has static shape.
   bool is_static_shape = true;
   for (size_t i = 0; i < data->shape.size(); ++i) {
@@ -763,17 +804,24 @@ bool ReshapeLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       break;
     }
   }
+  auto output_type = TensorType(shape_like, data->dtype);
   if (is_static_shape) {
-    ICHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
+    ICHECK(reporter->AssertEQ(data->Size(), output_type->Size()))
         << "Reshape inputs size should be compatible.";
   }
-  reporter->Assign(types[2], TensorType(reshape_like->shape, data->dtype));
+  reporter->Assign(types[2], output_type);
   return true;
 }
 
-Expr MakeReshapeLike(Expr data, Expr shape_like) {
+Expr MakeReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin,
+                     Integer rhs_end) {
+  auto attrs = make_object<ReshapeLikeAttrs>();
+  attrs->lhs_begin = std::move(lhs_begin);
+  attrs->lhs_end = std::move(lhs_end);
+  attrs->rhs_begin = std::move(rhs_begin);
+  attrs->rhs_end = std::move(rhs_end);
   static const Op& op = Op::Get("reshape_like");
-  return Call(op, {data, shape_like}, Attrs(), {});
+  return Call(op, {lhs, rhs}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make.reshape_like").set_body_typed(MakeReshapeLike);
@@ -784,7 +832,15 @@ For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
 the input array into an output array with the same shape as the second input array.
 .. note::
     Sizes for both array should be compatible.
+Example::
+
+  data.shape == (1, 2, 3, 4)
+  shape_like.shape == (6, 2, 2, 3)
+
+  ret = reshape_like(data, shape_like, lhs_begin=1, rhs_end=3)
+  ret.shape == (1, 6, 2, 2)
 )code" TVM_ADD_FILELINE)
+    .set_attrs_type<ReshapeLikeAttrs>()
     .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("shape_like", "Tensor", "Shape tensor.")
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index 555391a27e4b9..8ef86e0881938 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -594,9 +594,9 @@ inline Expr LeftShift(Expr x, Expr nbit) {
   return Call(op, {x, nbit}, Attrs(), {});
 }
 
-inline Expr ReshapeLike(Expr lhs, Expr rhs) {
-  static const Op& op = Op::Get("reshape_like");
-  return Call(op, {lhs, rhs}, Attrs(), {});
+inline Expr ReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin,
+                        Integer rhs_end) {
+  return MakeReshapeLike(lhs, rhs, lhs_begin, lhs_end, rhs_begin, rhs_end);
 }
 
 inline Expr Copy(Expr data) {
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index f091856f6b7e3..90e6e870f370d 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -316,17 +316,45 @@ def test_reshape_like_infer_type():
     zz = run_infer_type(z)
     assert zz.checked_type == relay.TensorType((1, 8, 8), "float32")
 
+    # partial reshaping
+    x = relay.var("x", relay.TensorType((1, 2, 3, 4), "float32"))
+    y = relay.var("y", relay.TensorType((1, 6, 5), "float32"))
+    z = relay.reshape_like(x, y, lhs_begin=1, lhs_end=3, rhs_begin=1, rhs_end=2)
+    zz = run_infer_type(z)
+    assert zz.checked_type == relay.TensorType((1, 6, 4), "float32")
+
+    x = relay.var("x", relay.TensorType((1, 2, 3, 4), "float32"))
+    y = relay.var("y", relay.TensorType((2, 3, 4, 1, 6), "float32"))
+    z = relay.reshape_like(x, y, rhs_end=3)
+    zz = run_infer_type(z)
+    assert zz.checked_type == relay.TensorType((2, 3, 4), "float32")
+    z = relay.reshape_like(x, y, rhs_begin=2)
+    zz = run_infer_type(z)
+    assert zz.checked_type == relay.TensorType((4, 1, 6), "float32")
+
+    # symbolic partial reshaping
+    n, c, h, w = te.size_var("n"), 2, 3, te.size_var("w")
+    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
+    y = relay.var("y", relay.TensorType((5, 6), "float32"))
+    z = relay.var("z", relay.TensorType((4,), "float32"))
+    w = relay.reshape_like(x, y, lhs_end=3)
+    w = relay.reshape_like(w, z, lhs_begin=2)
+    w = run_infer_type(w)
+    assert w.checked_type == relay.TensorType((5, 6, 4), "float32")
+
 
 @tvm.testing.uses_gpu
 def test_reshape_like():
-    def verify_reshape_like(shape, oshape):
+    def verify_reshape_like(shape, oshape, shape_like=None, reshape_like_kwargs={}):
+        if shape_like is None:
+            shape_like = oshape
         x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
-        y_data = np.random.uniform(low=-1, high=1, size=oshape).astype("float32")
-        ref_res = np.reshape(x_data, y_data.shape)
+        y_data = np.random.uniform(low=-1, high=1, size=shape_like).astype("float32")
+        ref_res = np.reshape(x_data, oshape)
         x = relay.var("x", relay.TensorType(shape, "float32"))
-        y = relay.var("x", relay.TensorType(oshape, "float32"))
-        z = relay.reshape_like(x, y)
+        y = relay.var("x", relay.TensorType(shape_like, "float32"))
+        z = relay.reshape_like(x, y, **reshape_like_kwargs)
         zz = run_infer_type(z)
         assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32")
 
@@ -340,6 +368,9 @@ def verify_reshape_like(shape, oshape):
     verify_reshape_like((2, 3, 4), (1, 8, 3))
     verify_reshape_like((4, 7), (2, 7, 2))
+    verify_reshape_like(
+        (1, 2, 3, 4), (1, 6, 4), (1, 6, 5), dict(lhs_begin=1, lhs_end=3, rhs_begin=1, rhs_end=2)
+    )
 
 
 def test_take_infer_type():
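
Reviewer note: for anyone trying this patch locally, a minimal usage sketch of the new
partial-reshaping parameters (assumes a TVM build that includes this change; variable names
are illustrative only)::

    import tvm
    from tvm import relay

    # Reshape dims [1, 3) of x using dims [1, 2) of y:
    # (1, 2, 3, 4) -> (1, 6, 4), since 2 * 3 == 6.
    x = relay.var("x", shape=(1, 2, 3, 4), dtype="float32")
    y = relay.var("y", shape=(1, 6, 5), dtype="float32")
    z = relay.reshape_like(x, y, lhs_begin=1, lhs_end=3, rhs_begin=1, rhs_end=2)

    mod = tvm.IRModule.from_expr(relay.Function([x, y], z))
    mod = relay.transform.InferType()(mod)
    print(mod["main"].ret_type)  # Tensor[(1, 6, 4), float32]

Dimensions of ``data`` outside ``[lhs_begin, lhs_end)`` (axes 0 and 3 here) are carried through
unchanged, which is exactly what the partial-reshaping cases in test_reshape_like_infer_type
exercise.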