Skip to content

Commit

Permalink
feat(//py): [to_backend] adding device specification support for
Browse files Browse the repository at this point in the history
to_backend

Also fixes nested dictionary bug reported in #286

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Jan 22, 2021
1 parent 72bf74b commit 6eeba1c
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 37 deletions.
19 changes: 7 additions & 12 deletions py/trtorch/_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,6 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
assert isinstance(compile_spec["strict_types"], bool)
info.strict_types = compile_spec["strict_types"]

if "allow_gpu_fallback" in compile_spec:
assert isinstance(compile_spec["allow_gpu_fallback"], bool)
info.allow_gpu_fallback = compile_spec["allow_gpu_fallback"]

if "device" in compile_spec:
info.device = _parse_device(compile_spec["device"])

Expand All @@ -177,7 +173,7 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
return info


def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.CompileSpec:
"""
Utility to create a formated spec dictionary for using the PyTorch TensorRT backend
Expand Down Expand Up @@ -235,14 +231,13 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
ir.set_max(i.max)
backend_spec.append_input_range(ir)

for i in parsed_spec.device:
ir = torch.classes.tensorrt.Device()
ir.set_device_type(i.device_type)
ir.set_gpu_id(i.gpu_id)
ir.set_dla_core(i.dla_core)
ir.set_allow_gpu_fallback(i.allow_gpu_fallback)
backend_spec.set_device(ir)
d = torch.classes.tensorrt.Device()
d.set_device_type(int(parsed_spec.device.device_type))
d.set_gpu_id(parsed_spec.device.gpu_id)
d.set_dla_core(parsed_spec.device.dla_core)
d.set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback)

backend_spec.set_device(d)
backend_spec.set_op_precision(int(parsed_spec.op_precision))
backend_spec.set_refit(parsed_spec.refit)
backend_spec.set_debug(parsed_spec.debug)
Expand Down
24 changes: 18 additions & 6 deletions py/trtorch/csrc/register_tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,34 @@
namespace trtorch {
namespace backend {
namespace {
void RegisterTRTCompileSpec() {

#define ADD_FIELD_GET_SET_REGISTRATION(registry, class_name, field_name) \
(registry).def("set_" #field_name, &class_name::set_##field_name); \
(registry).def("get_" #field_name, &class_name::get_##field_name);

void RegisterTRTCompileSpec() {
static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
torch::class_<trtorch::pyapi::InputRange>("tensorrt", "InputRange").def(torch::init<>());
torch::class_<trtorch::pyapi::InputRange>("tensorrt", "InputRange").def(torch::init<>());

ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, min);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, opt);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, max);

static auto TRTORCH_UNUSED TRTDeviceTSRegistration =
torch::class_<trtorch::pyapi::Device>("tensorrt", "Device").def(torch::init<>());

ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, device_type);
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, gpu_id);
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core);
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback);


static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
.def(torch::init<>())
.def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);
torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
.def(torch::init<>())
.def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
.def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);

ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);
Expand All @@ -30,6 +41,7 @@ void RegisterTRTCompileSpec() {
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, num_avg_timing_iters);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, workspace_size);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, max_batch_size);

}

struct TRTTSRegistrations {
Expand Down
2 changes: 1 addition & 1 deletion py/trtorch/csrc/tensorrt_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue processed_mod, c10::
auto method = mod.get_method(method_name);
auto g = method.graph();

auto raw_spec = it->value().toGenericDict().at(it->key()).toCustomClass<trtorch::pyapi::CompileSpec>();
auto raw_spec = it->value().toCustomClass<trtorch::pyapi::CompileSpec>();
LOG_DEBUG(raw_spec->stringify());
auto cfg = raw_spec->toInternalCompileSpec();
auto convert_cfg = std::move(cfg.convert_info);
Expand Down
30 changes: 17 additions & 13 deletions py/trtorch/csrc/tensorrt_classes.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ namespace pyapi {
return field_name; \
}

// TODO: Make this error message more informative
#define ADD_ENUM_GET_SET(field_name, type, max_val) \
void set_##field_name(int64_t val) { \
TRTORCH_CHECK(val < max_val, "Invalid enum value for field"); \
field_name = static_cast<type>(val); \
} \
int64_t get_##field_name() { \
return static_cast<int64_t>(field_name); \
}

struct InputRange : torch::CustomClassHolder {
std::vector<int64_t> min;
std::vector<int64_t> opt;
Expand Down Expand Up @@ -59,7 +69,7 @@ struct Device : torch::CustomClassHolder {
allow_gpu_fallback(false) // allow_gpu_fallback
{}

ADD_FIELD_GET_SET(device_type, DeviceType);
ADD_ENUM_GET_SET(device_type, DeviceType, 1);
ADD_FIELD_GET_SET(gpu_id, int64_t);
ADD_FIELD_GET_SET(dla_core, int64_t);
ADD_FIELD_GET_SET(allow_gpu_fallback, bool);
Expand All @@ -77,28 +87,22 @@ enum class EngineCapability : int8_t {
std::string to_str(EngineCapability value);
nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value);

// TODO: Make this error message more informative
#define ADD_ENUM_GET_SET(field_name, type, max_val) \
void set_##field_name(int64_t val) { \
TRTORCH_CHECK(val < max_val, "Invalid enum value for field"); \
field_name = static_cast<type>(val); \
} \
int64_t get_##field_name() { \
return static_cast<int64_t>(field_name); \
}

struct CompileSpec : torch::CustomClassHolder {
core::CompileSpec toInternalCompileSpec();
std::string stringify();
void appendInputRange(const c10::intrusive_ptr<InputRange>& ir) {
input_ranges.push_back(*ir);
}

ADD_ENUM_GET_SET(op_precision, DataType, 3);
void setDeviceIntrusive(const c10::intrusive_ptr<Device>& d) {
device = *d;
}

ADD_ENUM_GET_SET(op_precision, DataType, 2);
ADD_FIELD_GET_SET(refit, bool);
ADD_FIELD_GET_SET(debug, bool);
ADD_FIELD_GET_SET(strict_types, bool);
ADD_ENUM_GET_SET(capability, EngineCapability, 3);
ADD_ENUM_GET_SET(capability, EngineCapability, 2);
ADD_FIELD_GET_SET(num_min_timing_iters, int64_t);
ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t);
ADD_FIELD_GET_SET(workspace_size, int64_t);
Expand Down
3 changes: 2 additions & 1 deletion tests/py/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ py_test(
] + select({
":aarch64_linux": [
"test_api_dla.py"
]
],
"//conditions:default" : []
}),
deps = [
requirement("torchvision")
Expand Down
11 changes: 7 additions & 4 deletions tests/py/test_to_backend_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ def setUp(self):
"refit": False,
"debug": False,
"strict_types": False,
"allow_gpu_fallback": True,
"device_type": "gpu",
"device": {
"device_type": trtorch.DeviceType.GPU,
"gpu_id": 0,
"allow_gpu_fallback": True
},
"capability": trtorch.EngineCapability.default,
"num_min_timing_iters": 2,
"num_avg_timing_iters": 1,
Expand All @@ -29,14 +32,14 @@ def setUp(self):
}

def test_to_backend_lowering(self):
trt_mod = torch._C._jit_to_tensorrt(self.scripted_model._c, {"forward": self.spec})
trt_mod = torch._C._jit_to_backend("tensorrt", self.scripted_model, self.spec)
same = (trt_mod.forward(self.input) - self.scripted_model(self.input)).abs().max()
self.assertTrue(same < 2e-3)


def test_suite():
suite = unittest.TestSuite()
suite.addTest(TestToBackendLowering.parametrize(TestToBackendLowering, model=models.mobilenet_v2(pretrained=True)))
suite.addTest(TestToBackendLowering.parametrize(TestToBackendLowering, model=models.resnet18(pretrained=True)))

return suite

Expand Down

0 comments on commit 6eeba1c

Please sign in to comment.