From 6eeba1cd3bb5ee8e5c23410bd28bb61b221fe96f Mon Sep 17 00:00:00 2001
From: Naren Dasan <naren@narendasan.com>
Date: Thu, 21 Jan 2021 19:33:12 -0800
Subject: [PATCH] feat(//py): [to_backend] adding device specification support for to_backend

Also fixes nested dictionary bug reported in #286

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
---
 py/trtorch/_compile_spec.py                   | 19 +++++------
 py/trtorch/csrc/register_tensorrt_classes.cpp | 24 +++++++++++----
 py/trtorch/csrc/tensorrt_backend.cpp          |  2 +-
 py/trtorch/csrc/tensorrt_classes.h            | 30 +++++++++++--------
 tests/py/BUILD                                |  3 +-
 tests/py/test_to_backend_api.py               | 11 ++++---
 6 files changed, 52 insertions(+), 37 deletions(-)

diff --git a/py/trtorch/_compile_spec.py b/py/trtorch/_compile_spec.py
index 9ee90e82cf..329d8e1036 100644
--- a/py/trtorch/_compile_spec.py
+++ b/py/trtorch/_compile_spec.py
@@ -147,10 +147,6 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
         assert isinstance(compile_spec["strict_types"], bool)
         info.strict_types = compile_spec["strict_types"]
 
-    if "allow_gpu_fallback" in compile_spec:
-        assert isinstance(compile_spec["allow_gpu_fallback"], bool)
-        info.allow_gpu_fallback = compile_spec["allow_gpu_fallback"]
-
     if "device" in compile_spec:
         info.device = _parse_device(compile_spec["device"])
 
@@ -177,7 +173,7 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
     return info
 
 
-def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
+def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.CompileSpec:
     """
     Utility to create a formated spec dictionary for using the PyTorch TensorRT backend
 
@@ -235,14 +231,13 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
             ir.set_max(i.max)
             backend_spec.append_input_range(ir)
 
-    for i in parsed_spec.device:
-        ir = torch.classes.tensorrt.Device()
-        ir.set_device_type(i.device_type)
-        ir.set_gpu_id(i.gpu_id)
-        ir.set_dla_core(i.dla_core)
-        ir.set_allow_gpu_fallback(i.allow_gpu_fallback)
-        backend_spec.set_device(ir)
+    d = torch.classes.tensorrt.Device()
+    d.set_device_type(int(parsed_spec.device.device_type))
+    d.set_gpu_id(parsed_spec.device.gpu_id)
+    d.set_dla_core(parsed_spec.device.dla_core)
+    d.set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback)
 
+    backend_spec.set_device(d)
     backend_spec.set_op_precision(int(parsed_spec.op_precision))
     backend_spec.set_refit(parsed_spec.refit)
     backend_spec.set_debug(parsed_spec.debug)
diff --git a/py/trtorch/csrc/register_tensorrt_classes.cpp b/py/trtorch/csrc/register_tensorrt_classes.cpp
index 4f1da24d7a..d7b422c360 100644
--- a/py/trtorch/csrc/register_tensorrt_classes.cpp
+++ b/py/trtorch/csrc/register_tensorrt_classes.cpp
@@ -3,23 +3,34 @@
 namespace trtorch {
 namespace backend {
 namespace {
-void RegisterTRTCompileSpec() {
+
 #define ADD_FIELD_GET_SET_REGISTRATION(registry, class_name, field_name) \
   (registry).def("set_" #field_name, &class_name::set_##field_name);     \
   (registry).def("get_" #field_name, &class_name::get_##field_name);
 
+void RegisterTRTCompileSpec() {
   static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
-      torch::class_<trtorch::pyapi::InputRange>("tensorrt", "InputRange").def(torch::init<>());
+      torch::class_<trtorch::pyapi::InputRange>("tensorrt", "InputRange").def(torch::init<>());
 
   ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, min);
   ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, opt);
   ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, max);
 
+  static auto TRTORCH_UNUSED TRTDeviceTSRegistration =
+      torch::class_<trtorch::pyapi::Device>("tensorrt", "Device").def(torch::init<>());
+
+  ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, device_type);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, gpu_id);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback);
+
+
   static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
-      torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
-          .def(torch::init<>())
-          .def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
-          .def("__str__", &trtorch::pyapi::CompileSpec::stringify);
+      torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
+          .def(torch::init<>())
+          .def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
+          .def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
+          .def("__str__", &trtorch::pyapi::CompileSpec::stringify);
 
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);
@@ -30,6 +41,7 @@ void RegisterTRTCompileSpec() {
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, num_avg_timing_iters);
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, workspace_size);
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, max_batch_size);
+
 }
 
 struct TRTTSRegistrations {
diff --git a/py/trtorch/csrc/tensorrt_backend.cpp b/py/trtorch/csrc/tensorrt_backend.cpp
index 0dca942b42..3734594d7a 100644
--- a/py/trtorch/csrc/tensorrt_backend.cpp
+++ b/py/trtorch/csrc/tensorrt_backend.cpp
@@ -46,7 +46,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue processed_mod, c10::
     auto method = mod.get_method(method_name);
     auto g = method.graph();
 
-    auto raw_spec = it->value().toGenericDict().at(it->key()).toCustomClass<trtorch::pyapi::CompileSpec>();
+    auto raw_spec = it->value().toCustomClass<trtorch::pyapi::CompileSpec>();
     LOG_DEBUG(raw_spec->stringify());
     auto cfg = raw_spec->toInternalCompileSpec();
     auto convert_cfg = std::move(cfg.convert_info);
diff --git a/py/trtorch/csrc/tensorrt_classes.h b/py/trtorch/csrc/tensorrt_classes.h
index a3d787f638..c2852532c3 100644
--- a/py/trtorch/csrc/tensorrt_classes.h
+++ b/py/trtorch/csrc/tensorrt_classes.h
@@ -17,6 +17,16 @@ namespace pyapi {
     return field_name;          \
   }
 
+// TODO: Make this error message more informative
+#define ADD_ENUM_GET_SET(field_name, type, max_val)               \
+  void set_##field_name(int64_t val) {                            \
+    TRTORCH_CHECK(val < max_val, "Invalid enum value for field"); \
+    field_name = static_cast<type>(val);                          \
+  }                                                               \
+  int64_t get_##field_name() {                                    \
+    return static_cast<int64_t>(field_name);                      \
+  }
+
 struct InputRange : torch::CustomClassHolder {
   std::vector<int64_t> min;
   std::vector<int64_t> opt;
@@ -59,7 +69,7 @@ struct Device : torch::CustomClassHolder {
         allow_gpu_fallback(false) // allow_gpu_fallback
   {}
 
-  ADD_FIELD_GET_SET(device_type, DeviceType);
+  ADD_ENUM_GET_SET(device_type, DeviceType, 1);
   ADD_FIELD_GET_SET(gpu_id, int64_t);
   ADD_FIELD_GET_SET(dla_core, int64_t);
   ADD_FIELD_GET_SET(allow_gpu_fallback, bool);
@@ -77,16 +87,6 @@ enum class EngineCapability : int8_t {
 std::string to_str(EngineCapability value);
 nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value);
 
-// TODO: Make this error message more informative
-#define ADD_ENUM_GET_SET(field_name, type, max_val)               \
-  void set_##field_name(int64_t val) {                            \
-    TRTORCH_CHECK(val < max_val, "Invalid enum value for field"); \
-    field_name = static_cast<type>(val);                          \
-  }                                                               \
-  int64_t get_##field_name() {                                    \
-    return static_cast<int64_t>(field_name);                      \
-  }
-
 struct CompileSpec : torch::CustomClassHolder {
   core::CompileSpec toInternalCompileSpec();
   std::string stringify();
@@ -94,11 +94,15 @@ struct CompileSpec : torch::CustomClassHolder {
     input_ranges.push_back(*ir);
   }
 
-  ADD_ENUM_GET_SET(op_precision, DataType, 3);
+  void setDeviceIntrusive(const c10::intrusive_ptr<Device>& d) {
+    device = *d;
+  }
+
+  ADD_ENUM_GET_SET(op_precision, DataType, 2);
   ADD_FIELD_GET_SET(refit, bool);
   ADD_FIELD_GET_SET(debug, bool);
   ADD_FIELD_GET_SET(strict_types, bool);
-  ADD_ENUM_GET_SET(capability, EngineCapability, 3);
+  ADD_ENUM_GET_SET(capability, EngineCapability, 2);
   ADD_FIELD_GET_SET(num_min_timing_iters, int64_t);
   ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t);
   ADD_FIELD_GET_SET(workspace_size, int64_t);
diff --git a/tests/py/BUILD b/tests/py/BUILD
index 2dcbd494ea..d8798f0175 100644
--- a/tests/py/BUILD
+++ b/tests/py/BUILD
@@ -17,7 +17,8 @@ py_test(
     ] + select({
         ":aarch64_linux": [
            "test_api_dla.py"
-        ]
+        ],
+        "//conditions:default" : []
    }),
    deps = [
        requirement("torchvision")
diff --git a/tests/py/test_to_backend_api.py b/tests/py/test_to_backend_api.py
index 72c22582bd..9c2e4c4a40 100644
--- a/tests/py/test_to_backend_api.py
+++ b/tests/py/test_to_backend_api.py
@@ -19,8 +19,11 @@ def setUp(self):
             "refit": False,
             "debug": False,
             "strict_types": False,
-            "allow_gpu_fallback": True,
-            "device_type": "gpu",
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+                "allow_gpu_fallback": True
+            },
             "capability": trtorch.EngineCapability.default,
             "num_min_timing_iters": 2,
             "num_avg_timing_iters": 1,
@@ -29,14 +32,14 @@ def setUp(self):
         }
 
     def test_to_backend_lowering(self):
-        trt_mod = torch._C._jit_to_tensorrt(self.scripted_model._c, {"forward": self.spec})
+        trt_mod = torch._C._jit_to_backend("tensorrt", self.scripted_model, self.spec)
         same = (trt_mod.forward(self.input) - self.scripted_model(self.input)).abs().max()
         self.assertTrue(same < 2e-3)
 
 
 def test_suite():
     suite = unittest.TestSuite()
-    suite.addTest(TestToBackendLowering.parametrize(TestToBackendLowering, model=models.mobilenet_v2(pretrained=True)))
+    suite.addTest(TestToBackendLowering.parametrize(TestToBackendLowering, model=models.resnet18(pretrained=True)))
     return suite