#include <iostream>
#include <string>
#include <thread>
#include <vector>

#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/script.h"
#include "trtorch/trtorch.h"

8
+ void run_infer (
9
+ int thread_id,
10
+ torch::jit::Module& mod,
11
+ torch::jit::Module& trt_mod,
12
+ const std::vector<torch::jit::IValue> inputs,
13
+ const std::vector<torch::jit::IValue> inputs_trt,
14
+ std::vector<torch::jit::IValue>& out_vec,
15
+ std::vector<torch::jit::IValue>& trt_out_vec) {
16
+ int count = 10 ;
17
+ while (count-- > 0 ) {
18
+ out_vec[thread_id] = mod.forward (inputs);
19
+ trt_out_vec[thread_id] = trt_mod.forward (inputs_trt);
20
+ }
21
+ }
22
+
23
+ TEST (CppAPITests, RuntimeThreadSafety) {
24
+ std::string path = " tests/modules/resnet50_traced.jit.pt" ;
25
+ torch::jit::Module mod;
26
+ try {
27
+ // Deserialize the ScriptModule from a file using torch::jit::load().
28
+ mod = torch::jit::load (path);
29
+ } catch (const c10::Error& e) {
30
+ std::cerr << " error loading the model\n " ;
31
+ }
32
+ mod.eval ();
33
+ mod.to (torch::kCUDA );
34
+
35
+ torch::Tensor in_jit = at::randint (5 , {1 , 3 , 224 , 224 }, torch::kCUDA ).to (torch::kFloat );
36
+ torch::Tensor in_trt = in_jit.clone ().to (torch::kFloat );
37
+
38
+ std::vector<torch::jit::IValue> inputs_jit;
39
+ std::vector<torch::jit::IValue> inputs_trt;
40
+ inputs_jit.push_back (in_jit.clone ());
41
+ inputs_trt.push_back (in_trt.clone ());
42
+
43
+ std::vector<trtorch::CompileSpec::Input> input_ranges;
44
+ for (auto in : inputs_trt) {
45
+ input_ranges.push_back ({std::vector<int64_t >{1 , 3 , 224 , 224 },
46
+ std::vector<int64_t >{1 , 3 , 224 , 224 },
47
+ std::vector<int64_t >{16 , 3 , 224 , 224 },
48
+ torch::kFloat });
49
+ }
50
+ auto compile_settings = trtorch::CompileSpec (input_ranges);
51
+
52
+ // FP32 execution
53
+ compile_settings.enabled_precisions = {torch::kFloat };
54
+ compile_settings.strict_types = true ;
55
+ auto trt_mod = trtorch::CompileGraph (mod, compile_settings);
56
+ std::cout << " trtorch::CompileGraph" << std::endl;
57
+
58
+ int num_threads = 10 ;
59
+ std::vector<torch::jit::IValue> out_vec (num_threads), trt_out_vec (num_threads);
60
+ std::vector<std::thread> threads;
61
+ for (int i = 0 ; i < num_threads; i++) {
62
+ threads.push_back (std::thread (
63
+ run_infer,
64
+ i,
65
+ std::ref (mod),
66
+ std::ref (trt_mod),
67
+ inputs_jit,
68
+ inputs_trt,
69
+ std::ref (out_vec),
70
+ std::ref (trt_out_vec)));
71
+ }
72
+
73
+ for (int i = 0 ; i < num_threads; i++) {
74
+ threads[i].join ();
75
+ }
76
+
77
+ bool flag = true ;
78
+ for (int i = 0 ; i < num_threads; i++) {
79
+ bool f = trtorch::tests::util::almostEqual (out_vec[i].toTensor (), trt_out_vec[i].toTensor (), 1e-2 );
80
+ flag = flag && f;
81
+ }
82
+ ASSERT_TRUE (flag);
83
+ }
0 commit comments