From 8a23866b00d768eb2d8737f709a53629a4f209da Mon Sep 17 00:00:00 2001 From: Rohan Mukherjee Date: Tue, 27 Oct 2020 23:27:22 +0000 Subject: [PATCH] Changes in CheckReshapeOnly to support TupleTypes as input This arises insed ManifestAllocPass inside relay.vm.compile --- python/tvm/relay/transform/memory_alloc.py | 5 +++++ tests/python/relay/test_vm.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index f611c1cc14c1..66528c861788 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -84,6 +84,11 @@ def visit_call(self, call): for arg in call.args: self.visit(arg) + def visit_var(self, var): + var_type = var.checked_type + if not isinstance(var_type, ty.TensorType): + self.reshape_only = False + def is_reshape_only(func): """Check if the primitive function contains only reshape ops.""" diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 038b5c5ed9e1..55bafe0ec8ea 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -754,5 +754,21 @@ def test_vm_reshape_tensor(): check_result([x_np, y_np], x_np.reshape([8, 2, 8]), mod) +def test_vm_reshape_tuple(x_shape=(1, 4, 2), y_shape=(1, 2, 10)): + tup = relay.var( + "tup", + type_annotation=relay.TupleType([relay.TensorType(x_shape), relay.TensorType(y_shape)]), + ) + out = relay.reshape(relay.TupleGetItem(tup, 0), (1, -1)) + f = relay.Function([tup], out) + + x_data = np.random.uniform(size=x_shape).astype("float32") + y_data = np.random.uniform(size=y_shape).astype("float32") + + for tgt, ctx in tvm.testing.enabled_targets(): + res = veval(f, (x_data, y_data)) + tvm.testing.assert_allclose(res.asnumpy(), np.reshape(x_data, (1, -1))) + + if __name__ == "__main__": pytest.main([__file__])