Skip to content

Commit

Permalink
cherry pick #3191 from main to release/2.5
Browse files Browse the repository at this point in the history
  • Loading branch information
lanluo-nvidia committed Oct 14, 2024
1 parent 348d21a commit ec5037d
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 15 deletions.
4 changes: 1 addition & 3 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,6 @@ TRTEngine::TRTEngine(
<< get_current_platform() << ")");
this->target_platform = target_platform;

this->cudagraph_mempool_id = at::cuda::graph_pool_handle();

this->hardware_compatible = hardware_compatible;
auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible);
TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
Expand Down Expand Up @@ -320,4 +318,4 @@ void TRTEngine::verify_serialization_fmt(const std::vector<std::string>& seriali

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
} // namespace torch_tensorrt
3 changes: 1 addition & 2 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ struct TRTEngine : torch::CustomClassHolder {
std::vector<at::Tensor> input_buffers = {};
std::vector<at::Tensor> output_buffers = {};
std::string shape_key;
at::cuda::MempoolId_t cudagraph_mempool_id;

// TODO: Implement a call method
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
Expand All @@ -104,4 +103,4 @@ struct TRTEngine : torch::CustomClassHolder {

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
} // namespace torch_tensorrt
6 changes: 2 additions & 4 deletions core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,6 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
// Create a new stream if the engine stream is the default stream
compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
} else {
compiled_engine->engine_stream = compiled_engine->caller_stream;
}

// nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex for it.
Expand All @@ -333,7 +331,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
if (need_cudagraphs_record) {
// If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph
c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream;
compiled_engine->cudagraph.capture_begin(compiled_engine->cudagraph_mempool_id);
compiled_engine->cudagraph.capture_begin();
compiled_engine->exec_ctx->enqueueV3(recording_stream);
compiled_engine->cudagraph.capture_end();

Expand Down Expand Up @@ -370,4 +368,4 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
} // namespace torch_tensorrt
10 changes: 4 additions & 6 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from tempfile import tempdir
from typing import Any, Dict, List, Optional, Sequence, Tuple

import tensorrt as trt
import torch
import torch_tensorrt
from torch.nn import Module
Expand All @@ -19,8 +20,6 @@
multi_gpu_device_check,
)

import tensorrt as trt

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -372,8 +371,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
or self._engine_stream is None
):
self._engine_stream = torch.cuda.Stream()
else:
self._engine_stream = self._caller_stream

self._engine_stream.wait_stream(self._caller_stream)

Expand Down Expand Up @@ -464,7 +461,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
if new_shape_key != self.shape_key:
logger.debug(f"Resetting Cudagraph on new shape key {new_shape_key}")
self.shape_key = new_shape_key
self.cudagraph.reset() # type: ignore
if self.cudagraph:
self.cudagraph.reset()
return False

return True
return True

0 comments on commit ec5037d

Please sign in to comment.