diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py
index db9684d02ac9..afdea9712342 100644
--- a/python/tvm/relay/op/contrib/tensorrt.py
+++ b/python/tvm/relay/op/contrib/tensorrt.py
@@ -615,7 +615,6 @@ def layout_transform_annotate_fn(expr):  # pylint: disable=unused-variable
 @_register_external_dynamic_check_func("reshape")
 def reshape_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if reshape is supported by TensorRT."""
-
     attrs, args = expr.attrs, expr.args
     if args[0].checked_type.dtype != "float32":
         logger.info("Only float32 inputs are supported for TensorRT.")
@@ -629,23 +628,23 @@ def reshape_annotate_fn(expr):  # pylint: disable=unused-variable
         if len(new_shape) == 0 or len(shape) == 0:
             logger.info("reshape: Can't reshape to or from scalar.")
             return False
-
         dynamic_reshape = any([isinstance(x, tvm.tir.expr.Any) for x in shape])
 
         if dynamic_reshape:
             # Make sure that the batch dim is unmodified.
             if int(new_shape[0]) < 0:
-                for shape_val, new_shape_val in enumerate(shape[1:], new_shape[1:]):
+                for shape_val, new_shape_val in zip(shape[1:], new_shape[1:]):
                     if not (
-                        isinstance(shape_val, int)
-                        and isinstance(new_shape_val, int)
+                        isinstance(shape_val, (int, tvm.tir.expr.IntImm))
+                        and isinstance(new_shape_val, (int, tvm.tir.expr.IntImm))
                         and int(shape_val) == int(new_shape_val)
                     ):
                         return False
             elif int(new_shape[0]) > 0:
+                # Currently we only allow dim[0] to be Any, so this branch will always be False
                 if not (
-                    isinstance(shape[0], int)
-                    and isinstance(new_shape[0], int)
+                    isinstance(shape[0], (int, tvm.tir.expr.IntImm))
+                    and isinstance(new_shape[0], (int, tvm.tir.expr.IntImm))
                     and int(shape[0]) == int(new_shape[0])
                 ):
                     return False
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index bd8d92eedb4c..7ddc4e762cfd 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -27,6 +27,7 @@
 from tvm.contrib import graph_runtime, utils
 from tvm.runtime.vm import VirtualMachine
 from tvm.relay import Any, GlobalVar, transform
+from tvm.relay.expr_functor import ExprVisitor
 from typing import Dict, Tuple, Union
 from tvm.contrib.download import download
 from tvm.relay.op.contrib import tensorrt
@@ -631,6 +632,106 @@ def get_graph(x_shape, new_shape):
     run_and_verify_func(get_graph((1, 1, 2, 3), (1, 6)))
 
 
+class AreOpsOnGraph(ExprVisitor):
+    """
+    Visits the graph recursively and checks if it contains ops in the op_list
+    """
+
+    def __init__(self, op_list):
+        ExprVisitor.__init__(self)
+        self.op_list = op_list
+        self.on_graph = False
+
+    def visit_call(self, call):
+        if isinstance(call.op, tvm.tir.op.Op):
+            if str(call.op) in self.op_list:
+                self.on_graph = True
+
+        return super().visit_call(call)
+
+    def are_ops_on_graph(self, subgraph) -> bool:
+        """
+        This function recursively visits the graph and checks if op_list ops are on the graph.
+        """
+        self.visit(subgraph)
+        return self.on_graph
+
+
+def are_ops_on_trt(mod, op_list):
+    for subgraph in mod.get_global_vars():
+        name = subgraph.name_hint
+        op_on_trt = False
+        op_on_tvm = True
+        if name == "main":
+            op_on_tvm = AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
+        elif mod[name].attrs and mod[name].attrs["Compiler"] == "tensorrt":
+            op_on_trt = AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
+        else:
+            op_on_tvm &= AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
+
+        if not op_on_trt or op_on_tvm:
+            return False
+
+    return True
+
+
+def test_dynamic_reshape():
+    if skip_codegen_test():
+        return
+
+    def test_run(x_data_list, x_shape, new_shape, should_offload_to_trt):
+        result_arr = [{} for _ in range(len(x_data_list))]
+        for use_trt in [True, False]:
+            x = relay.var("x", shape=x_shape, dtype="float32")
+            out = relay.reshape(x, new_shape)
+            f = relay.Function([x], out)
+            mod = tvm.IRModule()
+            mod["main"] = f
+            if use_trt:
+                mod, _ = tensorrt.partition_for_tensorrt(
+                    mod, params={}, remove_no_mac_subgraphs=False
+                )
+                assert are_ops_on_trt(mod, op_list=["reshape"]) == should_offload_to_trt
+            if not skip_runtime_test():
+                with relay.build_config(opt_level=3):
+                    relay_exec = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
+
+                for i, x_data in enumerate(x_data_list):
+                    result_arr[i][use_trt] = relay_exec.evaluate()(x_data)
+
+        if not skip_runtime_test():
+            for i in range(len(x_data_list)):
+                assert_result_dict_holds(result_arr[i])
+
+    dim_values = [1, 1, 0, 2, 3, 0, 1, 3, 2]
+    x_shape = (relay.Any(), 3, 2, 3)
+    x_data_list = [
+        np.ones([dim_value] + list(x_shape)[1:]).astype("float32") for dim_value in dim_values
+    ]
+    new_shape = (-1, 3, 2, 3)
+    should_offload_to_trt = True
+    test_run(x_data_list, x_shape, new_shape, should_offload_to_trt)
+
+    dim_values = [1, 1, 0, 2, 3, 0, 1, 3, 2]
+    x_shape = (relay.Any(), 3, 2, 3)
+    x_data_list = [
+        np.ones([dim_value] + list(x_shape)[1:]).astype("float32") for dim_value in dim_values
+    ]
+    new_shape = (-1, 1, 2, 3)
+    should_offload_to_trt = False
+    test_run(x_data_list, x_shape, new_shape, should_offload_to_trt)
+
+    dim_values = [1, 1, 0, 2, 3, 0, 1, 3, 2]
+    x_shape = (1, relay.Any(), 2, 3)
+    x_data_list = [
+        np.ones(list(x_shape[:1]) + [dim_value] + list(x_shape)[2:]).astype("float32")
+        for dim_value in dim_values
+    ]
+    new_shape = (1, -1, 2, 3)
+    should_offload_to_trt = False
+    test_run(x_data_list, x_shape, new_shape, should_offload_to_trt)
+
+
 def test_transpose():
     def get_graph(x_shape, order):
         x = relay.var("x", shape=(x_shape), dtype="float32")
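
For context, the offload rule that the new test encodes can be sketched in isolation: a dynamic reshape is sent to TensorRT only when the single dynamic dimension is the batch dimension (position 0) and every other dimension passes through unchanged. The helper below is a rough illustrative model, not part of the patch; the name reshape_offloadable and the use of None in place of relay.Any() are assumptions made for the sketch.

def reshape_offloadable(shape, new_shape):
    # Rough model of the rule exercised by test_dynamic_reshape;
    # `None` marks a dynamic (relay.Any) dimension.
    dynamic_dims = [i for i, dim in enumerate(shape) if dim is None]
    if not dynamic_dims:
        return True  # static reshapes are handled by the pre-existing checks
    if dynamic_dims != [0]:
        return False  # only the batch dim may be dynamic
    # The batch dim must map to -1 and all trailing dims must be copied unchanged.
    return new_shape[0] == -1 and all(
        int(s) == int(n) for s, n in zip(shape[1:], new_shape[1:])
    )

# Mirrors the three cases in test_dynamic_reshape:
assert reshape_offloadable([None, 3, 2, 3], (-1, 3, 2, 3))      # offloaded to TRT
assert not reshape_offloadable([None, 3, 2, 3], (-1, 1, 2, 3))  # trailing dims change
assert not reshape_offloadable([1, None, 2, 3], (1, -1, 2, 3))  # Any() is not the batch dim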