[RELAY][OP] Support MXNet-style attributes for reshape_like (apache#6851)

* add MXNet-style reshape_like attrs support

* lint

* document, switch to int, add more tests, style

* add example usage in documentation

* fix doc formatting
altanh authored and Trevor Morris committed Dec 4, 2020
1 parent 4a0b368 commit 26e3e2a
Showing 7 changed files with 163 additions and 21 deletions.
20 changes: 20 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -93,6 +93,26 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
}
}; // struct ReshapeAttrs

/*! \brief Attributes used in MXNet-style reshape_like operators */
struct ReshapeLikeAttrs : public tvm::AttrsNode<ReshapeLikeAttrs> {
int lhs_begin;
Integer lhs_end; // can be None
int rhs_begin;
Integer rhs_end; // can be None
TVM_DECLARE_ATTRS(ReshapeLikeAttrs, "relay.attrs.ReshapeLikeAttrs") {
TVM_ATTR_FIELD(lhs_begin).set_default(0).describe(
"The axis of the input where reshaping should begin.");
TVM_ATTR_FIELD(lhs_end)
.set_default(NullValue<Integer>())
.describe("The axis of the input where reshaping should end, exclusive.");
TVM_ATTR_FIELD(rhs_begin).set_default(0).describe(
"The axis of the shape_like tensor to begin taking dimensions from.");
TVM_ATTR_FIELD(rhs_end)
.set_default(NullValue<Integer>())
.describe("The axis of the shape_like tensor to end taking dimensions from, exclusive.");
}
}; // struct ReshapeLikeAttrs

struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
Integer axis;

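The four new attributes compose the output shape by splicing a span of `shape_like` dimensions into `data`'s shape. A minimal sketch of the rule in plain Python (illustrative only, not TVM API; `composed_shape` is a hypothetical helper)::

    # keep lhs dims before lhs_begin, take rhs dims in [rhs_begin, rhs_end),
    # then keep lhs dims from lhs_end onward
    def composed_shape(lhs_shape, rhs_shape, lhs_begin, lhs_end, rhs_begin, rhs_end):
        return lhs_shape[:lhs_begin] + rhs_shape[rhs_begin:rhs_end] + lhs_shape[lhs_end:]

    assert composed_shape((1, 2, 3, 4), (6, 2, 2, 3), 1, 4, 0, 3) == (1, 6, 2, 2)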
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -194,6 +194,11 @@ class ReshapeAttrs(Attrs):
"""Attributes for transform.reshape"""


@tvm._ffi.register_object("relay.attrs.ReshapeLikeAttrs")
class ReshapeLikeAttrs(Attrs):
"""Attributes for transform.reshape_like"""


@tvm._ffi.register_object("relay.attrs.GatherAttrs")
class GatherAttrs(Attrs):
"""Attributes for transform.gather"""
43 changes: 35 additions & 8 deletions python/tvm/relay/op/transform.py
@@ -308,28 +308,55 @@ def scatter_add(data, indices, updates, axis):
return _make.scatter_add(data, indices, updates, axis)


def reshape_like(data, shape_like):
"""Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.
def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None):
"""Reshapes the input tensor by the size of another tensor.
For an input tensor with shape ``(d0, d1, ..., d(k-1))``, the `reshape_like` operation reshapes
the input tensor into an output tensor with the same shape as the second input tensor,
in particular reshaping the dimensions of `data` in `[lhs_begin, lhs_end)` using the dimensions
from `shape_like` in `[rhs_begin, rhs_end)`.
.. note::
Sizes for both array should be compatible.
Sizes for `data` and the output tensor should be compatible.
Parameters
----------
data : relay.Expr
The input data to the operator.
shape_like : tuple of int
The new shape. Should be compatible with the original shape.
shape_like : relay.Expr
The tensor to reshape data like. Should be compatible with the original shape on the
reshaped dimensions.
lhs_begin : int, optional
The axis of data to begin reshaping. Default is 0.
lhs_end : int or None, optional
The axis of data where reshaping should stop, exclusive. Default is None, which reshapes to
the end.
rhs_begin : int, optional
The axis of shape_like where the target shape begins. Default is 0.
rhs_end : int or None, optional
The axis of shape_like where the target shape ends, exclusive. Default is None, which extends
to the end.
Returns
-------
ret : relay.Expr
The computed result.
Examples
--------
.. code-block:: python
data.shape == (1, 2, 3, 4)
shape_like.shape == (6, 2, 2, 3)
ret = relay.reshape_like(data, shape_like, lhs_begin=1, rhs_end=3)
ret.shape == (1, 6, 2, 2)
"""
return _make.reshape_like(data, shape_like)
return _make.reshape_like(data, shape_like, lhs_begin, lhs_end, rhs_begin, rhs_end)


def take(data, indices, axis=None, mode="clip"):
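A quick end-to-end sketch of the new keyword arguments (assumes a TVM build that includes this patch; the executor choice and input data are illustrative)::

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", relay.TensorType((1, 2, 3, 4), "float32"))
    y = relay.var("y", relay.TensorType((6, 2, 2, 3), "float32"))
    # reshape dims [1, 4) of x using dims [0, 3) of y: (1, 2, 3, 4) -> (1, 6, 2, 2)
    z = relay.reshape_like(x, y, lhs_begin=1, rhs_end=3)
    mod = tvm.IRModule.from_expr(relay.Function([x, y], z))
    x_np = np.random.rand(1, 2, 3, 4).astype("float32")
    y_np = np.zeros((6, 2, 2, 3), "float32")
    out = relay.create_executor("graph", mod=mod).evaluate()(x_np, y_np)
    assert out.asnumpy().shape == (1, 6, 2, 2)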
3 changes: 3 additions & 0 deletions src/relay/op/make_op.h
@@ -62,6 +62,9 @@ Expr MakeRepeat(Expr data, int repeats, int axis);

Expr MakeReshape(Expr data, Array<Integer> newshape);

Expr MakeReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin,
Integer rhs_end);

Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis);

Expr MakeSqueeze(Expr data, Array<Integer> axis);
66 changes: 61 additions & 5 deletions src/relay/op/tensor/transform.cc
@@ -453,6 +453,7 @@ RELAY_REGISTER_OP("transpose")

/* relay.reshape */
TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
TVM_REGISTER_NODE_TYPE(ReshapeLikeAttrs);

Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs& attrs) {
const auto* param = attrs.as<ReshapeAttrs>();
@@ -641,11 +642,49 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}

Array<PrimExpr> infer_reshape_like(const Array<PrimExpr>& lhs_shape,
const Array<PrimExpr>& rhs_shape, const Attrs& attrs) {
const auto* like_attrs = attrs.as<ReshapeLikeAttrs>();
CHECK(!like_attrs->lhs_end.defined() || like_attrs->lhs_end.as<IntImmNode>())
<< "lhs_end must be a concrete integer or None";
CHECK(!like_attrs->rhs_end.defined() || like_attrs->rhs_end.as<IntImmNode>())
<< "rhs_end must be a concrete integer or None";

int64_t lhs_shape_size = static_cast<int64_t>(lhs_shape.size());
int64_t rhs_shape_size = static_cast<int64_t>(rhs_shape.size());
int64_t lhs_begin = static_cast<int64_t>(like_attrs->lhs_begin);
int64_t lhs_end =
like_attrs->lhs_end.defined() ? like_attrs->lhs_end.as<IntImmNode>()->value : lhs_shape_size;
int64_t rhs_begin = static_cast<int64_t>(like_attrs->rhs_begin);
int64_t rhs_end =
like_attrs->rhs_end.defined() ? like_attrs->rhs_end.as<IntImmNode>()->value : rhs_shape_size;

// handle negative axes
lhs_begin = lhs_begin < 0 ? lhs_begin + lhs_shape_size : lhs_begin;
lhs_end = lhs_end < 0 ? lhs_end + lhs_shape_size : lhs_end;
rhs_begin = rhs_begin < 0 ? rhs_begin + rhs_shape_size : rhs_begin;
rhs_end = rhs_end < 0 ? rhs_end + rhs_shape_size : rhs_end;

Array<PrimExpr> shape_like;
for (auto i = 0; i < lhs_begin; i++) {
shape_like.push_back(lhs_shape[i]);
}
for (auto i = rhs_begin; i < rhs_end; i++) {
shape_like.push_back(rhs_shape[i]);
}
for (auto i = lhs_end; i < lhs_shape_size; i++) {
shape_like.push_back(lhs_shape[i]);
}
return shape_like;
}

Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
// Quick path for reshape_like
if (!attrs.as<ReshapeAttrs>()) {
return {topi::reshape(inputs[0], inputs[1]->shape)};
ICHECK(attrs.as<ReshapeLikeAttrs>() != nullptr);
auto shape_like = infer_reshape_like(inputs[0]->shape, inputs[1]->shape, attrs);
return {topi::reshape(inputs[0], shape_like)};
}

const auto* out_ttype = out_type.as<TensorTypeNode>();
@@ -746,6 +785,7 @@ Example::
*/
bool ReshapeLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
ICHECK(attrs.as<ReshapeLikeAttrs>() != nullptr);
ICHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
@@ -755,6 +795,7 @@ bool ReshapeLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
if (reshape_like == nullptr) {
return false;
}
auto shape_like = infer_reshape_like(data->shape, reshape_like->shape, attrs);
// Only check when input data has static shape.
bool is_static_shape = true;
for (size_t i = 0; i < data->shape.size(); ++i) {
@@ -763,17 +804,24 @@ bool ReshapeLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
break;
}
}
auto output_type = TensorType(shape_like, data->dtype);
if (is_static_shape) {
ICHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
ICHECK(reporter->AssertEQ(data->Size(), output_type->Size()))
<< "Reshape inputs size should be compatible.";
}
reporter->Assign(types[2], TensorType(reshape_like->shape, data->dtype));
reporter->Assign(types[2], output_type);
return true;
}

Expr MakeReshapeLike(Expr data, Expr shape_like) {
Expr MakeReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin,
Integer rhs_end) {
auto attrs = make_object<ReshapeLikeAttrs>();
attrs->lhs_begin = std::move(lhs_begin);
attrs->lhs_end = std::move(lhs_end);
attrs->rhs_begin = std::move(rhs_begin);
attrs->rhs_end = std::move(rhs_end);
static const Op& op = Op::Get("reshape_like");
return Call(op, {data, shape_like}, Attrs(), {});
return Call(op, {lhs, rhs}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.reshape_like").set_body_typed(MakeReshapeLike);
@@ -784,7 +832,15 @@ For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation re
the input array into an output array with the same shape as the second input array.
.. note::
Sizes for both arrays should be compatible.
Example::
data.shape == (1, 2, 3, 4)
shape_like.shape == (6, 2, 2, 3)
ret = reshape_like(data, shape_like, lhs_begin=1, rhs_end=3)
ret.shape == (1, 6, 2, 2)
)code" TVM_ADD_FILELINE)
.set_attrs_type<ReshapeLikeAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("shape_like", "Tensor", "Shape tensor.")
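For reference, the shape rule implemented by `infer_reshape_like` above, mirrored in plain Python, including the None defaults and negative-axis normalization (an illustrative sketch, not part of the patch)::

    def infer_reshape_like_py(lhs_shape, rhs_shape, lhs_begin=0, lhs_end=None,
                              rhs_begin=0, rhs_end=None):
        # None mirrors the undefined Integer defaults: the span runs to the end
        lhs_end = len(lhs_shape) if lhs_end is None else lhs_end
        rhs_end = len(rhs_shape) if rhs_end is None else rhs_end
        # negative axes count from the end, as in the C++ normalization
        lhs_begin += len(lhs_shape) if lhs_begin < 0 else 0
        lhs_end += len(lhs_shape) if lhs_end < 0 else 0
        rhs_begin += len(rhs_shape) if rhs_begin < 0 else 0
        rhs_end += len(rhs_shape) if rhs_end < 0 else 0
        return lhs_shape[:lhs_begin] + rhs_shape[rhs_begin:rhs_end] + lhs_shape[lhs_end:]

    assert infer_reshape_like_py([1, 2, 3, 4], [6, 2, 2, 3],
                                 lhs_begin=1, rhs_end=3) == [1, 6, 2, 2]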
6 changes: 3 additions & 3 deletions src/relay/transforms/pattern_utils.h
@@ -594,9 +594,9 @@ inline Expr LeftShift(Expr x, Expr nbit) {
return Call(op, {x, nbit}, Attrs(), {});
}

inline Expr ReshapeLike(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("reshape_like");
return Call(op, {lhs, rhs}, Attrs(), {});
inline Expr ReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin,
Integer rhs_end) {
return MakeReshapeLike(lhs, rhs, lhs_begin, lhs_end, rhs_begin, rhs_end);
}

inline Expr Copy(Expr data) {
41 changes: 36 additions & 5 deletions tests/python/relay/test_op_level3.py
@@ -316,17 +316,45 @@ def test_reshape_like_infer_type():
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType((1, 8, 8), "float32")

# partial reshaping
x = relay.var("x", relay.TensorType((1, 2, 3, 4), "float32"))
y = relay.var("y", relay.TensorType((1, 6, 5), "float32"))
z = relay.reshape_like(x, y, lhs_begin=1, lhs_end=3, rhs_begin=1, rhs_end=2)
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType((1, 6, 4), "float32")

x = relay.var("x", relay.TensorType((1, 2, 3, 4), "float32"))
y = relay.var("y", relay.TensorType((2, 3, 4, 1, 6), "float32"))
z = relay.reshape_like(x, y, rhs_end=3)
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType((2, 3, 4), "float32")
z = relay.reshape_like(x, y, rhs_begin=2)
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType((4, 1, 6), "float32")

# symbolic partial reshaping
n, c, h, w = te.size_var("n"), 2, 3, te.size_var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.var("y", relay.TensorType((5, 6), "float32"))
z = relay.var("z", relay.TensorType((4,), "float32"))
w = relay.reshape_like(x, y, lhs_end=3)
w = relay.reshape_like(w, z, lhs_begin=2)
w = run_infer_type(w)
assert w.checked_type == relay.TensorType((5, 6, 4), "float32")


@tvm.testing.uses_gpu
def test_reshape_like():
def verify_reshape_like(shape, oshape):
def verify_reshape_like(shape, oshape, shape_like=None, reshape_like_kwargs={}):
if shape_like is None:
shape_like = oshape
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
y_data = np.random.uniform(low=-1, high=1, size=oshape).astype("float32")
ref_res = np.reshape(x_data, y_data.shape)
y_data = np.random.uniform(low=-1, high=1, size=shape_like).astype("float32")
ref_res = np.reshape(x_data, oshape)

x = relay.var("x", relay.TensorType(shape, "float32"))
y = relay.var("x", relay.TensorType(oshape, "float32"))
z = relay.reshape_like(x, y)
y = relay.var("x", relay.TensorType(shape_like, "float32"))
z = relay.reshape_like(x, y, **reshape_like_kwargs)
zz = run_infer_type(z)
assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32")

@@ -340,6 +368,9 @@ def verify_reshape_like(shape, oshape):

verify_reshape_like((2, 3, 4), (1, 8, 3))
verify_reshape_like((4, 7), (2, 7, 2))
verify_reshape_like(
(1, 2, 3, 4), (1, 6, 4), (1, 6, 5), dict(lhs_begin=1, lhs_end=3, rhs_begin=1, rhs_end=2)
)


def test_take_infer_type():
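Negative axes are normalized by the C++ inference but not exercised by the new tests; the following sketch shows what that normalization implies (untested and illustrative, assuming this patch; `run_infer_type` is the test helper used above)::

    from tvm import relay
    from tvm.relay.testing import run_infer_type

    x = relay.var("x", relay.TensorType((1, 2, 3, 4), "float32"))
    y = relay.var("y", relay.TensorType((1, 6, 5), "float32"))
    # lhs_end=-1 normalizes to axis 3, so dims [1, 3) of x are reshaped
    z = relay.reshape_like(x, y, lhs_begin=1, lhs_end=-1, rhs_begin=1, rhs_end=2)
    zz = run_infer_type(z)
    assert zz.checked_type == relay.TensorType((1, 6, 4), "float32")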
