diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index dcb770b9a563..1ed29cf29d6e 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -38,7 +38,7 @@ @register_parser def add_compile_parser(subparsers): - """ Include parser for 'compile' subcommand """ + """Include parser for 'compile' subcommand""" parser = subparsers.add_parser("compile", help="compile a model") parser.set_defaults(func=drive_compile) @@ -214,9 +214,11 @@ def compile_model( for codegen_from_cli in extra_targets: codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"]) partition_function = codegen["pass_pipeline"] - mod = partition_function(mod, params, **codegen_from_cli["opts"]) + mod, codegen_config = partition_function(mod, params, **codegen_from_cli["opts"]) if codegen["config_key"] is not None: - config[codegen["config_key"]] = codegen_from_cli["opts"] + config[codegen["config_key"]] = ( + codegen_config if codegen_config else codegen_from_cli["opts"] + ) if tuning_records and os.path.exists(tuning_records): logger.debug("tuning records file provided: %s", tuning_records) diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py index ac1a41a0c4a9..b98fa1574275 100644 --- a/python/tvm/driver/tvmc/composite_target.py +++ b/python/tvm/driver/tvmc/composite_target.py @@ -26,7 +26,7 @@ from tvm.relay.op.contrib.ethosn import partition_for_ethosn from tvm.relay.op.contrib.bnns import partition_for_bnns from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai - +from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt from .common import TVMCException @@ -61,6 +61,10 @@ "config_key": "relay.ext.vitis_ai.options", "pass_pipeline": partition_for_vitis_ai, }, + "tensorrt": { + "config_key": "relay.ext.tensorrt.options", + "pass_pipeline": partition_for_tensorrt, + }, } diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 16c02335c8a0..9f558454b0ae 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -330,6 +330,20 @@ def test_compile_tflite_module_with_external_codegen_vitis_ai(tflite_mobilenet_v assert type(tvmc_package.params) is bytearray assert os.path.exists(dumps_path) +def test_compile_tflite_module_with_external_codegen_tensorrt(tflite_mobilenet_v1_0_25_128): + pytest.importorskip("tflite") + + tvmc_model = tvmc.load(tflite_mobilenet_v1_0_25_128) + tvmc_package = tvmc.compiler.compile_model( + tvmc_model, target="tensorrt, llvm", dump_code="relay") + dumps_path = tvmc_package.package_path + ".relay" + + # check for output types + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) @mock.patch("tvm.relay.build") @mock.patch("tvm.driver.tvmc.composite_target.get_codegen_by_target") diff --git a/tests/python/driver/tvmc/test_composite_target.py b/tests/python/driver/tvmc/test_composite_target.py index 0a0b45eeb970..0fd48810f3e3 100644 --- a/tests/python/driver/tvmc/test_composite_target.py +++ b/tests/python/driver/tvmc/test_composite_target.py @@ -35,6 +35,7 @@ def test_get_codegen_names(): assert "ethos-n77" in names assert "vitis-ai" in names + assert "tensorrt" in names assert len(names) > 0