-
Notifications
You must be signed in to change notification settings - Fork 355
/
runtime.h
76 lines (60 loc) · 2 KB
/
runtime.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#pragma once
#include <map>
#include <memory>
#include <mutex>
#include <utility>
#include "ATen/core/function_schema.h"
#include "NvInfer.h"
#include "core/runtime/RTDevice.h"
#include "core/runtime/TRTEngine.h"
#include "core/util/prelude.h"
#include "torch/custom_class.h"
namespace torch_tensorrt {
namespace core {
namespace runtime {
// Unique identifier assigned to each TRT engine instance.
using EngineID = int64_t;

// Serialization ABI version; bump whenever the serialized engine payload
// layout changes. `inline` (C++17) gives one program-wide definition —
// a plain namespace-scope `const` in a header has internal linkage and
// would create a separate copy of the string in every translation unit.
inline const std::string ABI_VERSION = "5";

// Global runtime toggles; definitions live in the runtime .cpp.
extern bool MULTI_DEVICE_SAFE_MODE;
extern bool CUDAGRAPHS_MODE;
// Index of each field within the serialized engine info tuple.
// Deliberately an unscoped enum (not `enum class`): enumerators are used
// directly as integer indices into the serialized info container.
enum SerializedInfoIndex {
  ABI_TARGET_IDX = 0, // serialization ABI version string
  NAME_IDX, // engine name
  DEVICE_IDX, // serialized target device info
  ENGINE_IDX, // serialized TensorRT engine blob
  INPUT_BINDING_NAMES_IDX, // engine input binding names
  OUTPUT_BINDING_NAMES_IDX, // engine output binding names
  HW_COMPATIBLE_IDX, // hardware-compatibility flag
  SERIALIZED_METADATA_IDX, // extra serialized metadata
  SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
};
/// Finds the device most compatible with `target_device`, preferring
/// `curr_device` on ties; empty optional if no compatible device exists.
/// `hardware_compatible` relaxes matching for HW-compatible engines.
c10::optional<RTDevice> get_most_compatible_device(
    const RTDevice& target_device,
    const RTDevice& curr_device = RTDevice(),
    bool hardware_compatible = false);

/// Returns every available device compatible with `target_device`.
std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device, bool hardware_compatible);

/// Runs `compiled_engine` on `inputs` and returns the output tensors.
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);

/// Validates the multi-GPU configuration (behavior depends on safe mode).
void multi_gpu_device_check();

// Accessors for the MULTI_DEVICE_SAFE_MODE global toggle.
bool get_multi_device_safe_mode();
void set_multi_device_safe_mode(bool multi_device_safe_mode);

// Accessors for the CUDAGRAPHS_MODE global toggle.
bool get_cudagraphs_mode();
// NOTE: parameter renamed from `multi_device_safe_mode` — that was a
// copy-paste error from the setter above; declaration-only rename, so
// callers and the out-of-line definition are unaffected.
void set_cudagraphs_mode(bool cudagraphs_mode);
// Registry of CUDA devices visible to the runtime, keyed by device id.
// Method definitions live in the runtime .cpp.
class DeviceList {
  using DeviceMap = std::unordered_map<int, RTDevice>;
  DeviceMap device_list;

 public: // single access-specifier section (was duplicated)
  // Scans and updates the list of available CUDA devices
  DeviceList();

  // Registers (or replaces) the device stored under `device_id`.
  void insert(int device_id, RTDevice cuda_device);

  // Looks up the device registered under `device_id`.
  RTDevice find(int device_id);

  // Returns a copy of the id -> device map.
  DeviceMap get_devices();

  // Human-readable dump of the registered devices (for logging).
  std::string dump_list();
};
DeviceList get_available_device_list();
const std::unordered_map<std::string, std::string>& get_dla_supported_SMs();
void set_rt_device(RTDevice& cuda_device);
// Gets the current active GPU (DLA will not show up through this)
RTDevice get_current_device();
} // namespace runtime
} // namespace core
} // namespace torch_tensorrt