Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][OP] Support MXNet-style attributes for reshape_like #6851

Merged
merged 5 commits into from
Nov 6, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,26 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
}
}; // struct ReshapeAttrs

/*! \brief Attributes used in MXNet-style reshape_like operators */
struct ReshapeLikeAttrs : public tvm::AttrsNode<ReshapeLikeAttrs> {
  int lhs_begin;    // first axis of the input to be replaced; may be negative (counted from the end)
  Integer lhs_end;  // can be None; one-past-last input axis to replace (None => input rank)
  int rhs_begin;    // first axis of shape_like to take dimensions from; may be negative
  Integer rhs_end;  // can be None; one-past-last shape_like axis to take (None => shape_like rank)
  TVM_DECLARE_ATTRS(ReshapeLikeAttrs, "relay.attrs.ReshapeLikeAttrs") {
    TVM_ATTR_FIELD(lhs_begin).set_default(0).describe(
        "The axis of the input where reshaping should begin.");
    TVM_ATTR_FIELD(lhs_end)
        .set_default(NullValue<Integer>())
        .describe("The axis of the input where reshaping should end, exclusive.");
    TVM_ATTR_FIELD(rhs_begin).set_default(0).describe(
        "The axis of the shape_like tensor to begin taking dimensions from.");
    TVM_ATTR_FIELD(rhs_end)
        .set_default(NullValue<Integer>())
        .describe("The axis of the shape_like tensor to end taking dimensions from, exclusive.");
  }
};  // struct ReshapeLikeAttrs

struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
Integer axis;

Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ class ReshapeAttrs(Attrs):
"""Attributes for transform.reshape"""


@tvm._ffi.register_object("relay.attrs.ReshapeLikeAttrs")
class ReshapeLikeAttrs(Attrs):
    """Attributes for transform.reshape_like (lhs_begin/lhs_end/rhs_begin/rhs_end axes)."""


@tvm._ffi.register_object("relay.attrs.GatherAttrs")
class GatherAttrs(Attrs):
"""Attributes for transform.gather"""
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def scatter_add(data, indices, updates, axis):
return _make.scatter_add(data, indices, updates, axis)


def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None):
    """Reshapes the input array by the size of another array.

    For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` reshapes
    the input array into an output array with the same shape as the second input
    array. Axes outside the ``[lhs_begin, lhs_end)`` range of `data` are kept
    as-is; the axes inside that range are replaced by the
    ``[rhs_begin, rhs_end)`` axes of `shape_like` (MXNet-style partial
    reshaping). Negative axis values count from the end, numpy-style.

    Parameters
    ----------
    data : relay.Expr
        The input data to the operator.
    shape_like : relay.Expr
        The tensor whose shape (or a slice of it) is used as the target shape.
    lhs_begin : int, optional
        The axis of `data` where reshaping should begin. Default is 0.
    lhs_end : int or None, optional
        The axis of `data` where reshaping should end, exclusive.
        None (default) means reshape up to the last axis of `data`.
    rhs_begin : int, optional
        The axis of `shape_like` to begin taking dimensions from. Default is 0.
    rhs_end : int or None, optional
        The axis of `shape_like` to end taking dimensions from, exclusive.
        None (default) means take dimensions through the last axis.

    Returns
    -------
    ret : relay.Expr
        The computed result.

    Note
    ----
    Sizes for `data` and the resulting shape must be compatible.
    """
    return _make.reshape_like(data, shape_like, lhs_begin, lhs_end, rhs_begin, rhs_end)


def take(data, indices, axis=None, mode="clip"):
Expand Down
3 changes: 3 additions & 0 deletions src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ Expr MakeRepeat(Expr data, int repeats, int axis);

Expr MakeReshape(Expr data, Array<Integer> newshape);

Expr MakeReshapeLike(Expr lhs, Expr rhs, int64_t lhs_begin, Integer lhs_end, int64_t rhs_begin,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why {lhs,rhs}_begin is int64_t and {lhs,rhs}_end is Integer?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did this because the beginning index (in both cases) must always be an integer, but the end index can be None which means I must use a nullable Integer wrapper. I could make everything Integer and check that beginning is always defined. I did feel a bit weird using int64_t directly since everything else seem to use int but the value wrapped by Integer is int64_t so that's why I chose it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that the Attrs I defined uses int and not int64_t so I'll probably just use int since other code mostly uses it.

Integer rhs_end);

Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis);

Expr MakeSqueeze(Expr data, Array<Integer> axis);
Expand Down
55 changes: 50 additions & 5 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ RELAY_REGISTER_OP("transpose")

/* relay.reshape */
TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
TVM_REGISTER_NODE_TYPE(ReshapeLikeAttrs);

Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs& attrs) {
const auto* param = attrs.as<ReshapeAttrs>();
Expand Down Expand Up @@ -641,11 +642,45 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}

/*!
 * \brief Compute the output shape of reshape_like.
 *
 * The result is lhs_shape[0:lhs_begin] ++ rhs_shape[rhs_begin:rhs_end] ++
 * lhs_shape[lhs_end:], i.e. the lhs axes in [lhs_begin, lhs_end) are replaced
 * by the rhs axes in [rhs_begin, rhs_end). Negative axes count from the end.
 *
 * \param lhs_shape shape of the tensor being reshaped.
 * \param rhs_shape shape of the shape_like tensor.
 * \param attrs must hold ReshapeLikeAttrs.
 * \return the inferred output shape.
 */
Array<PrimExpr> infer_reshape_like(const Array<PrimExpr>& lhs_shape,
                                   const Array<PrimExpr>& rhs_shape, const Attrs& attrs) {
  const auto* like_attrs = attrs.as<ReshapeLikeAttrs>();
  // Guard against a null deref if the op was constructed with the wrong attrs.
  ICHECK(like_attrs != nullptr) << "reshape_like requires ReshapeLikeAttrs";
  ICHECK(!like_attrs->lhs_end.defined() || like_attrs->lhs_end.as<IntImmNode>())
      << "lhs_end must be a concrete integer or None";
  ICHECK(!like_attrs->rhs_end.defined() || like_attrs->rhs_end.as<IntImmNode>())
      << "rhs_end must be a concrete integer or None";

  const int64_t lhs_shape_size = static_cast<int64_t>(lhs_shape.size());
  const int64_t rhs_shape_size = static_cast<int64_t>(rhs_shape.size());

  // A None end index means "through the last axis" of the respective tensor.
  int64_t lhs_begin = like_attrs->lhs_begin;
  int64_t lhs_end =
      like_attrs->lhs_end.defined() ? like_attrs->lhs_end.as<IntImmNode>()->value : lhs_shape_size;
  int64_t rhs_begin = like_attrs->rhs_begin;
  int64_t rhs_end =
      like_attrs->rhs_end.defined() ? like_attrs->rhs_end.as<IntImmNode>()->value : rhs_shape_size;

  // Normalize negative axes (numpy-style, counted from the end).
  if (lhs_begin < 0) lhs_begin += lhs_shape_size;
  if (lhs_end < 0) lhs_end += lhs_shape_size;
  if (rhs_begin < 0) rhs_begin += rhs_shape_size;
  if (rhs_end < 0) rhs_end += rhs_shape_size;

  // Validate the normalized ranges up front so bad attrs fail with a clear
  // message instead of indexing out of bounds below.
  ICHECK(0 <= lhs_begin && lhs_begin <= lhs_end && lhs_end <= lhs_shape_size)
      << "invalid lhs range [" << lhs_begin << ", " << lhs_end << ") for rank " << lhs_shape_size;
  ICHECK(0 <= rhs_begin && rhs_begin <= rhs_end && rhs_end <= rhs_shape_size)
      << "invalid rhs range [" << rhs_begin << ", " << rhs_end << ") for rank " << rhs_shape_size;

  Array<PrimExpr> shape_like;
  for (int64_t i = 0; i < lhs_begin; ++i) {
    shape_like.push_back(lhs_shape[i]);
  }
  for (int64_t i = rhs_begin; i < rhs_end; ++i) {
    shape_like.push_back(rhs_shape[i]);
  }
  for (int64_t i = lhs_end; i < lhs_shape_size; ++i) {
    shape_like.push_back(lhs_shape[i]);
  }
  return shape_like;
}

Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
// Quick path for reshape_like
if (!attrs.as<ReshapeAttrs>()) {
return {topi::reshape(inputs[0], inputs[1]->shape)};
ICHECK(attrs.as<ReshapeLikeAttrs>() != nullptr);
auto shape_like = infer_reshape_like(inputs[0]->shape, inputs[1]->shape, attrs);
return {topi::reshape(inputs[0], shape_like)};
}

const auto* out_ttype = out_type.as<TensorTypeNode>();
Expand Down Expand Up @@ -746,6 +781,7 @@ Example::
*/
bool ReshapeLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
ICHECK(attrs.as<ReshapeLikeAttrs>() != nullptr);
ICHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
Expand All @@ -755,6 +791,7 @@ bool ReshapeLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
if (reshape_like == nullptr) {
return false;
}
auto shape_like = infer_reshape_like(data->shape, reshape_like->shape, attrs);
// Only check When input data has static shape.
bool is_static_shape = true;
for (size_t i = 0; i < data->shape.size(); ++i) {
Expand All @@ -763,17 +800,24 @@ bool ReshapeLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
break;
}
}
auto output_type = TensorType(shape_like, data->dtype);
if (is_static_shape) {
ICHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
ICHECK(reporter->AssertEQ(data->Size(), output_type->Size()))
<< "Reshape inputs size should be compatible.";
}
reporter->Assign(types[2], TensorType(reshape_like->shape, data->dtype));
reporter->Assign(types[2], output_type);
return true;
}

/*!
 * \brief Build a reshape_like call with MXNet-style partial-reshape attributes.
 * \param lhs tensor to reshape.
 * \param rhs shape_like tensor supplying the replacement dimensions.
 * \param lhs_begin first lhs axis to replace (may be negative).
 * \param lhs_end one-past-last lhs axis to replace; None means lhs rank.
 * \param rhs_begin first rhs axis to copy from (may be negative).
 * \param rhs_end one-past-last rhs axis to copy; None means rhs rank.
 */
Expr MakeReshapeLike(Expr lhs, Expr rhs, int64_t lhs_begin, Integer lhs_end, int64_t rhs_begin,
                     Integer rhs_end) {
  auto attrs = make_object<ReshapeLikeAttrs>();
  // The attrs store begin axes as int; make the narrowing explicit. std::move
  // on a trivially-copyable int64_t is a no-op, so plain assignment is used.
  attrs->lhs_begin = static_cast<int>(lhs_begin);
  attrs->lhs_end = std::move(lhs_end);
  attrs->rhs_begin = static_cast<int>(rhs_begin);
  attrs->rhs_end = std::move(rhs_end);
  static const Op& op = Op::Get("reshape_like");
  return Call(op, {lhs, rhs}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.reshape_like").set_body_typed(MakeReshapeLike);
Expand All @@ -785,6 +829,7 @@ the input array into an output array with the same shape as the second input arr
.. note::
Sizes for both array should be compatible.
)code" TVM_ADD_FILELINE)
.set_attrs_type<ReshapeLikeAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("shape_like", "Tensor", "Shape tensor.")
Expand Down
6 changes: 3 additions & 3 deletions src/relay/transforms/pattern_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,9 @@ inline Expr LeftShift(Expr x, Expr nbit) {
return Call(op, {x, nbit}, Attrs(), {});
}

/*!
 * \brief Convenience wrapper around MakeReshapeLike.
 *
 * Defaults reproduce the original two-argument ReshapeLike(lhs, rhs) — a full
 * reshape of lhs to the shape of rhs — so existing call sites keep compiling
 * unchanged while new callers may request a partial reshape.
 */
inline Expr ReshapeLike(Expr lhs, Expr rhs, int64_t lhs_begin = 0,
                        Integer lhs_end = NullValue<Integer>(), int64_t rhs_begin = 0,
                        Integer rhs_end = NullValue<Integer>()) {
  return MakeReshapeLike(lhs, rhs, lhs_begin, lhs_end, rhs_begin, rhs_end);
}

inline Expr Copy(Expr data) {
Expand Down
32 changes: 27 additions & 5 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,17 +316,36 @@ def test_reshape_like_infer_type():
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType((1, 8, 8), "float32")

# partial reshaping
x = relay.var("x", relay.TensorType((1, 2, 3, 4), "float32"))
y = relay.var("y", relay.TensorType((1, 6, 5), "float32"))
z = relay.reshape_like(x, y, lhs_begin=1, lhs_end=3, rhs_begin=1, rhs_end=2)
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType((1, 6, 4), "float32")

# symbolic partial reshaping
n, c, h, w = te.size_var("n"), 2, 3, te.size_var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.var("y", relay.TensorType((5, 6), "float32"))
z = relay.var("z", relay.TensorType((4,), "float32"))
w = relay.reshape_like(x, y, lhs_end=3)
w = relay.reshape_like(w, z, lhs_begin=2)
altanh marked this conversation as resolved.
Show resolved Hide resolved
w = run_infer_type(w)
assert w.checked_type == relay.TensorType((5, 6, 4), "float32")


@tvm.testing.uses_gpu
def test_reshape_like():
def verify_reshape_like(shape, oshape):
def verify_reshape_like(shape, oshape, shape_like=None, reshape_like_kwargs={}):
if shape_like is None:
shape_like = oshape
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
y_data = np.random.uniform(low=-1, high=1, size=oshape).astype("float32")
ref_res = np.reshape(x_data, y_data.shape)
y_data = np.random.uniform(low=-1, high=1, size=shape_like).astype("float32")
ref_res = np.reshape(x_data, oshape)

x = relay.var("x", relay.TensorType(shape, "float32"))
y = relay.var("x", relay.TensorType(oshape, "float32"))
z = relay.reshape_like(x, y)
y = relay.var("x", relay.TensorType(shape_like, "float32"))
z = relay.reshape_like(x, y, **reshape_like_kwargs)
zz = run_infer_type(z)
assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32")

Expand All @@ -340,6 +359,9 @@ def verify_reshape_like(shape, oshape):

verify_reshape_like((2, 3, 4), (1, 8, 3))
verify_reshape_like((4, 7), (2, 7, 2))
verify_reshape_like(
(1, 2, 3, 4), (1, 6, 4), (1, 6, 5), dict(lhs_begin=1, lhs_end=3, rhs_begin=1, rhs_end=2)
)


def test_take_infer_type():
Expand Down