diff --git a/core/compiler.cpp b/core/compiler.cpp
index 6cb24b0641..db20003640 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -46,7 +46,7 @@ c10::FunctionSchema GenerateGraphSchema(
 void AddEngineToGraph(
     torch::jit::script::Module mod,
     std::shared_ptr<torch::jit::Graph>& g,
-    std::string& serialized_engine) {
+    const std::string& serialized_engine) {
   auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
   // Get required metadata about the engine out
   auto num_io = engine_ptr->num_io;
@@ -173,9 +173,9 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
   return new_mod;
 }
 
-torch::jit::script::Module EmbedEngineInNewModule(std::string& engine) {
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine) {
   std::ostringstream engine_id;
-  engine_id << reinterpret_cast<int*>(&engine);
+  engine_id << reinterpret_cast<const int*>(&engine);
   torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
   auto new_g = std::make_shared<torch::jit::Graph>();
   AddEngineToGraph(new_mod, new_g, engine);
diff --git a/core/compiler.h b/core/compiler.h
index 586ecf9662..a7d16c6b8d 100644
--- a/core/compiler.h
+++ b/core/compiler.h
@@ -19,7 +19,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
 
 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);
 
-torch::jit::script::Module EmbedEngineInNewModule(std::string& engine);
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine);
 
 void set_device(const int gpu_id);
diff --git a/cpp/api/include/trtorch/trtorch.h b/cpp/api/include/trtorch/trtorch.h
index 4a9659cfe8..b033b45629 100644
--- a/cpp/api/include/trtorch/trtorch.h
+++ b/cpp/api/include/trtorch/trtorch.h
@@ -485,14 +485,15 @@ TRTORCH_API std::string ConvertGraphToTRTEngine(
  * @brief Take a previously created TensorRT engine and embed it
  * in a TorchScript module
  *
- * @param engine: std::string - Precompiled serialized TensorRT engine
+ * @param engine: std::string - Pre-built serialized TensorRT engine
  *
- * Takes a prebuilt serialized TensorRT engine and embeds it in a TorchScript
- * graph. Registers the engine as the forward method of the module
+ * Takes a pre-built serialized TensorRT engine and embeds it in a TorchScript
+ * module. Registers execution of the engine as the forward method of the module
+ * Forward is defined as: forward(Tensor[]) -> Tensor[]
 *
 * @return: A new module targeting a TensorRT engine
 */
-TRTORCH_API torch::jit::Module EmbedEngineInNewModule(std::string& engine);
+TRTORCH_API torch::jit::Module EmbedEngineInNewModule(const std::string& engine);
 
 /**
  * @brief Set gpu device id
diff --git a/cpp/api/src/trtorch.cpp b/cpp/api/src/trtorch.cpp
index a00631aa5f..1a5083fc90 100644
--- a/cpp/api/src/trtorch.cpp
+++ b/cpp/api/src/trtorch.cpp
@@ -31,7 +31,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module
   return core::CompileGraph(module, to_internal_compile_spec(info));
 }
 
-torch::jit::Module EmbedEngineInNewModule(std::string& engine) {
+torch::jit::Module EmbedEngineInNewModule(const std::string& engine) {
   return core::EmbedEngineInNewModule(engine);
 }
diff --git a/py/trtorch/_compiler.py b/py/trtorch/_compiler.py
index 65c91732e6..183644a065 100644
--- a/py/trtorch/_compiler.py
+++ b/py/trtorch/_compiler.py
@@ -124,6 +124,26 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
     return trtorch._C.convert_graph_to_trt_engine(module._c, method_name, _parse_compile_spec(compile_spec))
 
 
+def embed_engine_in_new_module(serialized_engine: bytes) -> torch.jit.ScriptModule:
+    """Takes a pre-built serialized TensorRT engine and embeds it within a TorchScript module
+
+    Takes a pre-built serialized TensorRT engine (as bytes) and embeds it within a TorchScript module.
+    Registers the forward method to execute the TensorRT engine with the function signature:
+
+        forward(Tensor[]) -> Tensor[]
+
+    Module can be saved with engine embedded with torch.jit.save and moved / loaded according to TRTorch portability rules
+
+    Args:
+        serialized_engine (bytes): Serialized TensorRT engine from either TRTorch or TensorRT APIs
+
+    Returns:
+        torch.jit.ScriptModule: New TorchScript module with engine embedded
+    """
+    cpp_mod = trtorch._C.embed_engine_in_new_module(serialized_engine)
+    return torch.jit._recursive.wrap_cpp_module(cpp_mod)
+
+
 def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool:
     """Checks to see if a method is fully supported by TRTorch
diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp
index cb3d1d4e39..74c38f5d73 100644
--- a/py/trtorch/csrc/trtorch_py.cpp
+++ b/py/trtorch/csrc/trtorch_py.cpp
@@ -119,6 +119,10 @@ bool CheckMethodOperatorSupport(const torch::jit::Module& module, const std::str
   return core::CheckMethodOperatorSupport(module, method_name);
 }
 
+torch::jit::Module EmbedEngineInNewModule(const py::bytes& engine) {
+  return core::EmbedEngineInNewModule(engine);
+}
+
 std::string get_build_info() {
   auto info = core::util::get_build_info();
   return info;
@@ -270,6 +274,10 @@ PYBIND11_MODULE(_C, m) {
   m.def(
       "check_method_op_support",
       &trtorch::pyapi::CheckMethodOperatorSupport,
      "Takes a module and a method name and checks if the method graph contains purely convertible operators");
+  m.def(
+      "embed_engine_in_new_module",
+      &trtorch::pyapi::EmbedEngineInNewModule,
+      "Takes a serialized TensorRT engine and wraps it in the forward method of a new TorchScript module");
   m.def("get_build_info", &get_build_info, "Returns build info about the compiler as a string");
   m.def("_get_logging_prefix", &logging::get_logging_prefix, "Get the current prefix for the logging output");
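For reference, a minimal usage sketch of the new Python API; the model, input shape, and compile spec mirror the new tests below and are illustrative rather than part of the patch:

    import torch
    import torchvision.models as models
    import trtorch

    # Script a model and lower its forward method to a serialized TensorRT engine
    ts_model = torch.jit.script(models.resnet18(pretrained=True).eval().to("cuda"))
    compile_spec = {
        "input_shapes": [(1, 3, 224, 224)],
        "device": {
            "device_type": trtorch.DeviceType.GPU,
            "gpu_id": 0,
            "dla_core": 0,
            "allow_gpu_fallback": False,
            "disable_tf32": False
        }
    }
    trt_engine = trtorch.convert_method_to_trt_engine(ts_model, "forward", compile_spec)

    # Wrap the serialized engine in a fresh TorchScript module; execution of the
    # engine is registered as the module's forward (forward(Tensor[]) -> Tensor[])
    trt_mod = trtorch.embed_engine_in_new_module(trt_engine)
    out = trt_mod(torch.randn((1, 3, 224, 224)).to("cuda"))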
diff --git a/tests/modules/test_modules_as_engines.cpp b/tests/modules/test_modules_as_engines.cpp
index 2d1bcaba75..5fb1cf5862 100644
--- a/tests/modules/test_modules_as_engines.cpp
+++ b/tests/modules/test_modules_as_engines.cpp
@@ -16,7 +16,7 @@ TEST_P(ModuleTests, ModuleAsEngineIsClose) {
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5));
 }
 
-TEST_P(ModuleTests, ModuleToModuleIsClose) {
+TEST_P(ModuleTests, ModuleToEngineToModuleIsClose) {
   std::vector<at::Tensor> inputs;
   std::vector<torch::jit::IValue> inputs_ivalues;
   for (auto in_shape : input_shapes) {
diff --git a/tests/py/BUILD b/tests/py/BUILD
index 510b3f681e..65a424466e 100644
--- a/tests/py/BUILD
+++ b/tests/py/BUILD
@@ -30,7 +30,7 @@ py_test(
     srcs = [
         "test_ptq_dataloader_calibrator.py",
         "model_test_case.py"
-    ]
+    ],
     deps = [
         requirement("torchvision")
     ]
@@ -43,7 +43,7 @@ py_test(
     srcs = [
         "test_ptq_trt_calibrator.py",
         "model_test_case.py"
-    ]
+    ],
     deps = [
         requirement("torchvision")
     ]
@@ -56,8 +56,6 @@ py_test(
         "test_multi_gpu.py",
         "model_test_case.py"
     ],
-    "//conditions:default" : []
-    }),
     deps = [
         requirement("torchvision")
     ]
@@ -74,12 +72,23 @@ py_test(
     ]
 )
 
+py_test(
+    name = "test_trt_intercompatability",
+    srcs = [
+        "test_trt_intercompatability.py",
+        "model_test_case.py"
+    ],
+    deps = [
+        requirement("torchvision")
+    ]
+)
+
 py_test(
     name = "test_ptq_to_backend",
     srcs = [
         "test_ptq_to_backend.py",
         "model_test_case.py"
-    ]
+    ],
     deps = [
         requirement("torchvision")
     ]
diff --git a/tests/py/test_api.py b/tests/py/test_api.py
index a21385f6e1..f28ef9cf5b 100644
--- a/tests/py/test_api.py
+++ b/tests/py/test_api.py
@@ -45,6 +45,27 @@ def test_compile_script(self):
         same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
         self.assertTrue(same < 2e-3)
 
+class TestPTtoTRTtoPT(ModelTestCase):
+    def setUp(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.ts_model = torch.jit.script(self.model)
+
+    def test_pt_to_trt_to_pt(self):
+        compile_spec = {
+            "input_shapes": [self.input.shape],
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+                "dla_core": 0,
+                "allow_gpu_fallback": False,
+                "disable_tf32": False
+            }
+        }
+
+        trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
+        trt_mod = trtorch.embed_engine_in_new_module(trt_engine)
+        same = (trt_mod(self.input) - self.ts_model(self.input)).abs().max()
+        self.assertTrue(same < 2e-3)
 
 class TestCheckMethodOpSupport(unittest.TestCase):
 
@@ -59,13 +80,13 @@ def test_check_support(self):
 class TestLoggingAPIs(unittest.TestCase):
 
     def test_logging_prefix(self):
-        new_prefix = "TEST"
+        new_prefix = "Python API Test: "
         trtorch.logging.set_logging_prefix(new_prefix)
         logging_prefix = trtorch.logging.get_logging_prefix()
         self.assertEqual(new_prefix, logging_prefix)
 
     def test_reportable_log_level(self):
-        new_level = trtorch.logging.Level.Warning
+        new_level = trtorch.logging.Level.Error
         trtorch.logging.set_reportable_log_level(new_level)
         level = trtorch.logging.get_reportable_log_level()
         self.assertEqual(new_level, level)
@@ -78,10 +99,11 @@ def test_is_colored_output_on(self):
 
 def test_suite():
     suite = unittest.TestSuite()
+    suite.addTest(unittest.makeSuite(TestLoggingAPIs))
     suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet18(pretrained=True)))
     suite.addTest(TestCompile.parametrize(TestCompile, model=models.mobilenet_v2(pretrained=True)))
+    suite.addTest(TestPTtoTRTtoPT.parametrize(TestPTtoTRTtoPT, model=models.mobilenet_v2(pretrained=True)))
     suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
-    suite.addTest(unittest.makeSuite(TestLoggingAPIs))
 
     return suite
diff --git a/tests/py/test_trt_intercompatability.py b/tests/py/test_trt_intercompatability.py
new file mode 100644
index 0000000000..bc54ceafb2
--- /dev/null
+++ b/tests/py/test_trt_intercompatability.py
@@ -0,0 +1,51 @@
+import unittest
+import trtorch
+import torch
+import torchvision.models as models
+import tensorrt as trt
+
+from model_test_case import ModelTestCase
+
+
+class TestPyTorchToTRTEngine(ModelTestCase):
+    def setUp(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda:0")
+        self.ts_model = torch.jit.script(self.model)
+
+    def test_pt_to_trt(self):
+        compile_spec = {
+            "input_shapes": [self.input.shape],
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+                "dla_core": 0,
+                "allow_gpu_fallback": False,
+                "disable_tf32": False
+            }
+        }
+
+        trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
+
+        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
+        with trt.Runtime(TRT_LOGGER) as rt:
+            engine = rt.deserialize_cuda_engine(trt_engine)
+            with engine.create_execution_context() as ctx:
+                out = torch.empty(size=tuple(engine.get_binding_shape(1))).to("cuda:0")
+                bindings = [self.input.contiguous().data_ptr(), out.contiguous().data_ptr()]
+                ctx.execute_async(batch_size=1, bindings=bindings, stream_handle=torch.cuda.current_stream(device='cuda:0').cuda_stream)
+                same = (out - self.ts_model(self.input)).abs().max()
+                self.assertTrue(same < 2e-3)
+
+
+def test_suite():
+    suite = unittest.TestSuite()
+    suite.addTest(TestPyTorchToTRTEngine.parametrize(TestPyTorchToTRTEngine, model=models.resnet18(pretrained=True)))
+
+    return suite
+
+
+suite = test_suite()
+
+runner = unittest.TextTestRunner()
+result = runner.run(suite)
+
+exit(int(not result.wasSuccessful()))
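Neither new test exercises the torch.jit.save path described in the embed_engine_in_new_module docstring. A minimal sketch of that flow, continuing from the usage sketch above (trt_mod comes from embed_engine_in_new_module; the file name is illustrative):

    import torch

    # Persist the module with the engine embedded, then reload it as an ordinary
    # TorchScript module. Per the docstring this is subject to TRTorch's
    # portability rules (compatible GPU and TensorRT version on the loading machine).
    torch.jit.save(trt_mod, "trt_engine_mod.ts")
    reloaded = torch.jit.load("trt_engine_mod.ts")
    out = reloaded(torch.randn((1, 3, 224, 224)).to("cuda"))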