Commit b47e926

Merge pull request #658 from guoruoqian/fix_runtime_thread_safety
Make TRTorch runtime thread safe
2 parents: 13293b3 + dcf14a7

File tree

4 files changed: 101 additions, 0 deletions

core/runtime/register_trt_op.cpp
core/runtime/runtime.h
tests/cpp/BUILD
tests/cpp/test_runtime_thread_safety.cpp

core/runtime/register_trt_op.cpp

Lines changed: 3 additions & 0 deletions
@@ -112,6 +112,9 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   }

   c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(inputs[0].device().index());
+
+  // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex for it.
+  std::unique_lock<std::mutex> lock(compiled_engine->mu);
   compiled_engine->exec_ctx->enqueueV2(gpu_handles.data(), stream, nullptr);

   return outputs;
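For context: the comment added above reflects that a single nvinfer1::IExecutionContext is not safe to drive from multiple threads at once, so the patch holds a per-engine lock for the duration of the enqueue. Below is a minimal, self-contained sketch of the same locking pattern; the Engine struct and its enqueue method are hypothetical stand-ins for TRTEngine and enqueueV2, not TRTorch's actual API.

#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Hypothetical stand-in for TRTEngine: the mutex lives on the object
// whose underlying call (enqueueV2 in the real code) is not thread safe.
struct Engine {
  std::mutex mu;
  void enqueue(int thread_id) {
    // Same pattern as the patch: the lock is released automatically
    // when it goes out of scope at the end of the call.
    std::unique_lock<std::mutex> lock(mu);
    std::cout << "thread " << thread_id << " enqueued\n";
  }
};

int main() {
  Engine engine;  // one engine shared by many threads
  std::vector<std::thread> threads;
  for (int i = 0; i < 4; i++) {
    threads.emplace_back([&engine, i] { engine.enqueue(i); });
  }
  for (auto& t : threads) {
    t.join();
  }
  return 0;
}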

core/runtime/runtime.h

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 #pragma once
 #include <map>
 #include <memory>
+#include <mutex>
 #include <utility>
 #include "ATen/core/function_schema.h"
 #include "NvInfer.h"
@@ -47,6 +48,7 @@ struct TRTEngine : torch::CustomClassHolder {
   std::shared_ptr<nvinfer1::IExecutionContext> exec_ctx;
   std::pair<uint64_t, uint64_t> num_io;
   std::string name;
+  std::mutex mu;
   CudaDevice device_info;

   std::unordered_map<uint64_t, uint64_t> in_binding_map;
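A note on the design choice here: since mu is a member of TRTEngine rather than a global lock, serialization is per engine instance. Concurrent forward calls on the same compiled engine queue up at enqueueV2, while independent engines remain free to run concurrently.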

tests/cpp/BUILD

Lines changed: 13 additions & 0 deletions
@@ -14,6 +14,7 @@ test_suite(
         ":test_default_input_types",
         ":test_compiled_modules",
         ":test_modules_as_engines",
+        ":test_runtime_thread_safety",
         ":test_multiple_registered_engines",
         ":test_serialization",
         ":test_module_fallback",
@@ -27,6 +28,7 @@ test_suite(
         ":test_default_input_types",
         ":test_compiled_modules",
         ":test_modules_as_engines",
+        ":test_runtime_thread_safety",
         ":test_multiple_registered_engines",
         ":test_serialization",
         ":test_module_fallback",
@@ -95,6 +97,17 @@ cc_test(
     timeout="long"
 )

+cc_test(
+    name = "test_runtime_thread_safety",
+    srcs = ["test_runtime_thread_safety.cpp"],
+    data = [
+        "//tests/modules:jit_models",
+    ],
+    deps = [
+        ":cpp_api_test",
+    ]
+)
+
 cc_test(
     name = "test_module_fallback",
     srcs = ["test_module_fallback.cpp"],
tests/cpp/test_runtime_thread_safety.cpp

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+#include <string>
+#include <thread>
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/script.h"
+#include "trtorch/trtorch.h"
+
+void run_infer(
+    int thread_id,
+    torch::jit::Module& mod,
+    torch::jit::Module& trt_mod,
+    const std::vector<torch::jit::IValue> inputs,
+    const std::vector<torch::jit::IValue> inputs_trt,
+    std::vector<torch::jit::IValue>& out_vec,
+    std::vector<torch::jit::IValue>& trt_out_vec) {
+  int count = 10;
+  while (count-- > 0) {
+    out_vec[thread_id] = mod.forward(inputs);
+    trt_out_vec[thread_id] = trt_mod.forward(inputs_trt);
+  }
+}
+
+TEST(CppAPITests, RuntimeThreadSafety) {
+  std::string path = "tests/modules/resnet50_traced.jit.pt";
+  torch::jit::Module mod;
+  try {
+    // Deserialize the ScriptModule from a file using torch::jit::load().
+    mod = torch::jit::load(path);
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+  }
+  mod.eval();
+  mod.to(torch::kCUDA);
+
+  torch::Tensor in_jit = at::randint(5, {1, 3, 224, 224}, torch::kCUDA).to(torch::kFloat);
+  torch::Tensor in_trt = in_jit.clone().to(torch::kFloat);
+
+  std::vector<torch::jit::IValue> inputs_jit;
+  std::vector<torch::jit::IValue> inputs_trt;
+  inputs_jit.push_back(in_jit.clone());
+  inputs_trt.push_back(in_trt.clone());
+
+  std::vector<trtorch::CompileSpec::Input> input_ranges;
+  for (auto in : inputs_trt) {
+    input_ranges.push_back({std::vector<int64_t>{1, 3, 224, 224},
+                            std::vector<int64_t>{1, 3, 224, 224},
+                            std::vector<int64_t>{16, 3, 224, 224},
+                            torch::kFloat});
+  }
+  auto compile_settings = trtorch::CompileSpec(input_ranges);
+
+  // FP32 execution
+  compile_settings.enabled_precisions = {torch::kFloat};
+  compile_settings.strict_types = true;
+  auto trt_mod = trtorch::CompileGraph(mod, compile_settings);
+  std::cout << "trtorch::CompileGraph" << std::endl;
+
+  int num_threads = 10;
+  std::vector<torch::jit::IValue> out_vec(num_threads), trt_out_vec(num_threads);
+  std::vector<std::thread> threads;
+  for (int i = 0; i < num_threads; i++) {
+    threads.push_back(std::thread(
+        run_infer,
+        i,
+        std::ref(mod),
+        std::ref(trt_mod),
+        inputs_jit,
+        inputs_trt,
+        std::ref(out_vec),
+        std::ref(trt_out_vec)));
+  }
+
+  for (int i = 0; i < num_threads; i++) {
+    threads[i].join();
+  }
+
+  bool flag = true;
+  for (int i = 0; i < num_threads; i++) {
+    bool f = trtorch::tests::util::almostEqual(out_vec[i].toTensor(), trt_out_vec[i].toTensor(), 1e-2);
+    flag = flag && f;
+  }
+  ASSERT_TRUE(flag);
+}
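In short, the test spawns 10 threads, each running 10 forward passes through both the original TorchScript module and the TRTorch-compiled module, then asserts that every thread's final TorchScript and TensorRT outputs agree within a 1e-2 tolerance. Without the mutex, concurrent enqueueV2 calls on the shared execution context could race and produce corrupted outputs.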
