From ec5037de9a6c1f3fd63e37122ce790146cc6b13a Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 14 Oct 2024 11:13:50 -0700 Subject: [PATCH] cherry pick #3191 from main to release/2.5 --- core/runtime/TRTEngine.cpp | 4 +--- core/runtime/TRTEngine.h | 3 +-- core/runtime/execute_engine.cpp | 6 ++---- .../dynamo/runtime/_PythonTorchTensorRTModule.py | 10 ++++------ 4 files changed, 8 insertions(+), 15 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index c2b9e6c35d..986bb11ecc 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -73,8 +73,6 @@ TRTEngine::TRTEngine( << get_current_platform() << ")"); this->target_platform = target_platform; - this->cudagraph_mempool_id = at::cuda::graph_pool_handle(); - this->hardware_compatible = hardware_compatible; auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible); TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine"); @@ -320,4 +318,4 @@ void TRTEngine::verify_serialization_fmt(const std::vector& seriali } // namespace runtime } // namespace core -} // namespace torch_tensorrt +} // namespace torch_tensorrt \ No newline at end of file diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index ebd5645d59..91ba46206b 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -81,7 +81,6 @@ struct TRTEngine : torch::CustomClassHolder { std::vector input_buffers = {}; std::vector output_buffers = {}; std::string shape_key; - at::cuda::MempoolId_t cudagraph_mempool_id; // TODO: Implement a call method // c10::List Run(c10::List inputs); @@ -104,4 +103,4 @@ struct TRTEngine : torch::CustomClassHolder { } // namespace runtime } // namespace core -} // namespace torch_tensorrt +} // namespace torch_tensorrt \ No newline at end of file diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 494431d3cb..7e6a20753f 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -305,8 +305,6 @@ std::vector execute_engine(std::vector inputs, c10::intr if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) { // Create a new stream if the engine stream is the default stream compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id); - } else { - compiled_engine->engine_stream = compiled_engine->caller_stream; } // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex for it. @@ -333,7 +331,7 @@ std::vector execute_engine(std::vector inputs, c10::intr if (need_cudagraphs_record) { // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream; - compiled_engine->cudagraph.capture_begin(compiled_engine->cudagraph_mempool_id); + compiled_engine->cudagraph.capture_begin(); compiled_engine->exec_ctx->enqueueV3(recording_stream); compiled_engine->cudagraph.capture_end(); @@ -370,4 +368,4 @@ std::vector execute_engine(std::vector inputs, c10::intr } // namespace runtime } // namespace core -} // namespace torch_tensorrt +} // namespace torch_tensorrt \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index f74c239550..eb22680cfb 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -5,6 +5,7 @@ from tempfile import tempdir from typing import Any, Dict, List, Optional, Sequence, Tuple +import tensorrt as trt import torch import torch_tensorrt from torch.nn import Module @@ -19,8 +20,6 @@ multi_gpu_device_check, ) -import tensorrt as trt - logger = logging.getLogger(__name__) @@ -372,8 +371,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . or self._engine_stream is None ): self._engine_stream = torch.cuda.Stream() - else: - self._engine_stream = self._caller_stream self._engine_stream.wait_stream(self._caller_stream) @@ -464,7 +461,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: if new_shape_key != self.shape_key: logger.debug(f"Resetting Cudagraph on new shape key {new_shape_key}") self.shape_key = new_shape_key - self.cudagraph.reset() # type: ignore + if self.cudagraph: + self.cudagraph.reset() return False - return True + return True \ No newline at end of file