[RELAY]reshape_like (apache#1950)
siju-samuel authored and Wei Chen committed Feb 19, 2019
1 parent 19d364d commit a1aeffa
Showing 6 changed files with 115 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
@@ -78,6 +78,7 @@ This level enables additional math and transform operators.
tvm.relay.ones
tvm.relay.ones_like
tvm.relay.reshape
tvm.relay.reshape_like
tvm.relay.copy
tvm.relay.transpose
tvm.relay.floor
@@ -189,6 +190,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.abs
.. autofunction:: tvm.relay.negative
.. autofunction:: tvm.relay.reshape
.. autofunction:: tvm.relay.reshape_like
.. autofunction:: tvm.relay.copy
.. autofunction:: tvm.relay.transpose
.. autofunction:: tvm.relay.take
5 changes: 5 additions & 0 deletions include/tvm/relay/type.h
@@ -82,6 +82,11 @@ class TensorTypeNode : public BaseTensorTypeNode {
v->Visit("span", &span);
}

/*! \brief Return the product of elements in the shape.
 *  \return (d_1 * d_2 * ... * d_n) if the shape is (d_1, d_2, ..., d_n), and 1 if the shape is empty.
 */
TVM_DLL IndexExpr Size() const;

TVM_DLL static TensorType make(Array<IndexExpr> shape, DataType dtype);

/*! \brief Construct a scalar containing elements of dtype. */
23 changes: 23 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -142,6 +142,29 @@ def reshape(data, newshape):
return _make.reshape(data, list(newshape))


def reshape_like(data, shape_like):
    """Reshapes the input array to the shape of another array.

    For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` reshapes
    the input array into an output array with the same shape as the second input array.

    .. note::
        The sizes of both arrays must be compatible, i.e. they must contain the
        same total number of elements.

    Parameters
    ----------
    data : relay.Expr
        The input data to the operator.

    shape_like : relay.Expr
        The tensor whose shape is used for the output. Its size must be
        compatible with the size of ``data``.

    Returns
    -------
    ret : relay.Expr
        The computed result.
    """
    return _make.reshape_like(data, shape_like)


def take(data, indices, axis=None):
"""Take elements from an array along an axis.
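As a quick usage illustration of the new Python API (not part of this commit), here is a minimal sketch, assuming the Relay Python interface available at this commit (`relay.var`, `relay.ir_pass.infer_type`):

from tvm import relay

# Reshape a (2, 3, 4) tensor to the shape of a (4, 6) tensor; both hold 24 elements.
x = relay.var("x", relay.TensorType((2, 3, 4), "float32"))
y = relay.var("y", relay.TensorType((4, 6), "float32"))
z = relay.reshape_like(x, y)

# Type inference should assign z the shape of y.
zz = relay.ir_pass.infer_type(z)
print(zz.checked_type)  # TensorType([4, 6], float32)

The new test at the bottom of this commit exercises the same path, including a case with symbolic input dimensions.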
12 changes: 12 additions & 0 deletions src/relay/ir/type.cc
@@ -22,6 +22,18 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
return TensorTypeNode::make({}, dtype);
}

IndexExpr TensorTypeNode::Size() const {
  if (shape.size() == 0) {
    return make_const(Int(64), 1);
  }

  IndexExpr size = shape[0];
  for (size_t i = 1; i < shape.size(); ++i) {
    size *= shape[i];
  }
  return size;
}

TVM_REGISTER_NODE_TYPE(TensorTypeNode);

TVM_REGISTER_API("relay._make.TensorType")
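To make the semantics of the new `TensorTypeNode::Size()` helper concrete, here is a rough Python equivalent (a sketch, not part of the commit): the element count is the product of all shape dimensions, 1 for a rank-0 (scalar) shape, and symbolic dimensions stay symbolic in the resulting expression:

import tvm

def element_count(shape):
    # Mirrors TensorTypeNode::Size(): product of all dimensions, 1 for an empty shape.
    size = 1
    for dim in shape:
        size = size * dim  # tvm expressions overload *, so symbolic dims are handled
    return size

n = tvm.var("n")
print(element_count((n, 2, 3)))  # symbolic product, equivalent to 6*n
print(element_count(()))         # 1 for a scalar shape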
56 changes: 56 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -377,6 +377,62 @@ Example::
.set_support_level(3)
.add_type_rel("Reshape", ReshapeRel);


/*!
 * \brief ReshapeLikeRel User-defined type constraint function.
 * \param types The types of the input tensors and the output tensor.
 * \param num_inputs Number of input types in the args.
 * \param attrs The additional attributes of the operator.
 * \param reporter The reporter to report the solution to.
 * \return False if the relation has not been resolved; it might be resolved later.
 *         True if this relation has been resolved.
 */
bool ReshapeLikeRel(const Array<Type>& types,
                    int num_inputs,
                    const Attrs& attrs,
                    const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    return false;
  }
  const auto* reshape_like = types[1].as<TensorTypeNode>();
  if (reshape_like == nullptr) {
    return false;
  }
  CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
      << "Reshape input sizes should be compatible.";
  reporter->Assign(types[2], TensorTypeNode::make(reshape_like->shape, data->dtype));
  return true;
}


Expr MakeReshapeLike(Expr data,
                     Expr shape_like) {
  static const Op& op = Op::Get("reshape_like");
  return CallNode::make(op, {data, shape_like}, Attrs(), {});
}


TVM_REGISTER_API("relay.op._make.reshape_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeReshapeLike, args, rv);
});


RELAY_REGISTER_OP("reshape_like")
.describe(R"code(Reshapes the input array to the shape of another array.

For an input array with shape ``(d1, d2, ..., dk)``, the `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.

.. note::
    The sizes of both arrays must be compatible, i.e. they must contain the same
    total number of elements.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("shape_like", "Tensor", "The tensor whose shape is used as the output shape.")
.set_support_level(3)
.add_type_rel("ReshapeLike", ReshapeLikeRel);


// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);

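To show the constraint that `ReshapeLikeRel` enforces, here is a hedged sketch (not part of the commit): when the element counts of the two inputs differ, the `AssertEQ` check above fails during type inference, which is assumed to surface as a TVM error on the Python side (caught broadly here):

from tvm import relay

x = relay.var("x", relay.TensorType((1, 2, 3), "float32"))  # 6 elements
y = relay.var("y", relay.TensorType((1, 7), "float32"))     # 7 elements
try:
    relay.ir_pass.infer_type(relay.reshape_like(x, y))
    print("unexpected: incompatible sizes were accepted")
except Exception:
    print("type inference rejected the incompatible reshape_like, as expected")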
17 changes: 17 additions & 0 deletions tests/python/relay/test_op_level3.py
@@ -88,6 +88,22 @@ def test_reshape_infer_type():
(n, t, 2000), "float32")


def test_reshape_like():
    # concrete shape
    x = relay.var("x", relay.TensorType((1, 2, 3), "float32"))
    y = relay.var("y", relay.TensorType((1, 6), "float32"))
    z = relay.reshape_like(x, y)
    zz = relay.ir_pass.infer_type(z)
    assert zz.checked_type == relay.TensorType((1, 6), "float32")

    # symbolic shape
    n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w")
    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
    y = relay.var("y", relay.TensorType((1, 8, 8), "float32"))
    z = relay.reshape_like(x, y)
    zz = relay.ir_pass.infer_type(z)
    assert zz.checked_type == relay.TensorType((1, 8, 8), "float32")


def test_take_infer_type():
def verify_take(dshape, indices_shape, oshape, axis=None):
@@ -187,6 +203,7 @@ def test_infer_type_leaky_relu():
test_clip_type()
test_transpose_infer_type()
test_reshape_infer_type()
test_reshape_like()
test_take_infer_type()
test_full()
test_full_like()
