From 6eeba1cd3bb5ee8e5c23410bd28bb61b221fe96f Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Thu, 21 Jan 2021 19:33:12 -0800 Subject: [PATCH 1/8] feat(//py): [to_backend] adding device specification support for to_backend Also fixes nested dictionary bug reported in #286 Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- py/trtorch/_compile_spec.py | 19 +++++------- py/trtorch/csrc/register_tensorrt_classes.cpp | 24 +++++++++++---- py/trtorch/csrc/tensorrt_backend.cpp | 2 +- py/trtorch/csrc/tensorrt_classes.h | 30 +++++++++++-------- tests/py/BUILD | 3 +- tests/py/test_to_backend_api.py | 11 ++++--- 6 files changed, 52 insertions(+), 37 deletions(-) diff --git a/py/trtorch/_compile_spec.py b/py/trtorch/_compile_spec.py index 9ee90e82cf..329d8e1036 100644 --- a/py/trtorch/_compile_spec.py +++ b/py/trtorch/_compile_spec.py @@ -147,10 +147,6 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec: assert isinstance(compile_spec["strict_types"], bool) info.strict_types = compile_spec["strict_types"] - if "allow_gpu_fallback" in compile_spec: - assert isinstance(compile_spec["allow_gpu_fallback"], bool) - info.allow_gpu_fallback = compile_spec["allow_gpu_fallback"] - if "device" in compile_spec: info.device = _parse_device(compile_spec["device"]) @@ -177,7 +173,7 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec: return info -def TensorRTCompileSpec(compile_spec: Dict[str, Any]): +def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.CompileSpec: """ Utility to create a formated spec dictionary for using the PyTorch TensorRT backend @@ -235,14 +231,13 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]): ir.set_max(i.max) backend_spec.append_input_range(ir) - for i in parsed_spec.device: - ir = torch.classes.tensorrt.Device() - ir.set_device_type(i.device_type) - ir.set_gpu_id(i.gpu_id) - ir.set_dla_core(i.dla_core) - ir.set_allow_gpu_fallback(i.allow_gpu_fallback) - backend_spec.set_device(ir) + d = torch.classes.tensorrt.Device() + d.set_device_type(int(parsed_spec.device.device_type)) + d.set_gpu_id(parsed_spec.device.gpu_id) + d.set_dla_core(parsed_spec.device.dla_core) + d.set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback) + backend_spec.set_device(d) backend_spec.set_op_precision(int(parsed_spec.op_precision)) backend_spec.set_refit(parsed_spec.refit) backend_spec.set_debug(parsed_spec.debug) diff --git a/py/trtorch/csrc/register_tensorrt_classes.cpp b/py/trtorch/csrc/register_tensorrt_classes.cpp index 4f1da24d7a..d7b422c360 100644 --- a/py/trtorch/csrc/register_tensorrt_classes.cpp +++ b/py/trtorch/csrc/register_tensorrt_classes.cpp @@ -3,23 +3,34 @@ namespace trtorch { namespace backend { namespace { -void RegisterTRTCompileSpec() { + #define ADD_FIELD_GET_SET_REGISTRATION(registry, class_name, field_name) \ (registry).def("set_" #field_name, &class_name::set_##field_name); \ (registry).def("get_" #field_name, &class_name::get_##field_name); +void RegisterTRTCompileSpec() { static auto TRTORCH_UNUSED TRTInputRangeTSRegistration = - torch::class_("tensorrt", "InputRange").def(torch::init<>()); + torch::class_("tensorrt", "InputRange").def(torch::init<>()); ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, min); ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, opt); ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, max); + static auto TRTORCH_UNUSED TRTDeviceTSRegistration = + torch::class_("tensorrt", "Device").def(torch::init<>()); + + ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, device_type); + ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, gpu_id); + ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core); + ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback); + + static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration = - torch::class_("tensorrt", "CompileSpec") - .def(torch::init<>()) - .def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange) - .def("__str__", &trtorch::pyapi::CompileSpec::stringify); + torch::class_("tensorrt", "CompileSpec") + .def(torch::init<>()) + .def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange) + .def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive) + .def("__str__", &trtorch::pyapi::CompileSpec::stringify); ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision); ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit); @@ -30,6 +41,7 @@ void RegisterTRTCompileSpec() { ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, num_avg_timing_iters); ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, workspace_size); ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, max_batch_size); + } struct TRTTSRegistrations { diff --git a/py/trtorch/csrc/tensorrt_backend.cpp b/py/trtorch/csrc/tensorrt_backend.cpp index 0dca942b42..3734594d7a 100644 --- a/py/trtorch/csrc/tensorrt_backend.cpp +++ b/py/trtorch/csrc/tensorrt_backend.cpp @@ -46,7 +46,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue processed_mod, c10:: auto method = mod.get_method(method_name); auto g = method.graph(); - auto raw_spec = it->value().toGenericDict().at(it->key()).toCustomClass(); + auto raw_spec = it->value().toCustomClass(); LOG_DEBUG(raw_spec->stringify()); auto cfg = raw_spec->toInternalCompileSpec(); auto convert_cfg = std::move(cfg.convert_info); diff --git a/py/trtorch/csrc/tensorrt_classes.h b/py/trtorch/csrc/tensorrt_classes.h index a3d787f638..c2852532c3 100644 --- a/py/trtorch/csrc/tensorrt_classes.h +++ b/py/trtorch/csrc/tensorrt_classes.h @@ -17,6 +17,16 @@ namespace pyapi { return field_name; \ } +// TODO: Make this error message more informative +#define ADD_ENUM_GET_SET(field_name, type, max_val) \ + void set_##field_name(int64_t val) { \ + TRTORCH_CHECK(val < max_val, "Invalid enum value for field"); \ + field_name = static_cast(val); \ + } \ + int64_t get_##field_name() { \ + return static_cast(field_name); \ + } + struct InputRange : torch::CustomClassHolder { std::vector min; std::vector opt; @@ -59,7 +69,7 @@ struct Device : torch::CustomClassHolder { allow_gpu_fallback(false) // allow_gpu_fallback {} - ADD_FIELD_GET_SET(device_type, DeviceType); + ADD_ENUM_GET_SET(device_type, DeviceType, 1); ADD_FIELD_GET_SET(gpu_id, int64_t); ADD_FIELD_GET_SET(dla_core, int64_t); ADD_FIELD_GET_SET(allow_gpu_fallback, bool); @@ -77,16 +87,6 @@ enum class EngineCapability : int8_t { std::string to_str(EngineCapability value); nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value); -// TODO: Make this error message more informative -#define ADD_ENUM_GET_SET(field_name, type, max_val) \ - void set_##field_name(int64_t val) { \ - TRTORCH_CHECK(val < max_val, "Invalid enum value for field"); \ - field_name = static_cast(val); \ - } \ - int64_t get_##field_name() { \ - return static_cast(field_name); \ - } - struct CompileSpec : torch::CustomClassHolder { core::CompileSpec toInternalCompileSpec(); std::string stringify(); @@ -94,11 +94,15 @@ struct CompileSpec : torch::CustomClassHolder { input_ranges.push_back(*ir); } - ADD_ENUM_GET_SET(op_precision, DataType, 3); + void setDeviceIntrusive(const c10::intrusive_ptr& d) { + device = *d; + } + + ADD_ENUM_GET_SET(op_precision, DataType, 2); ADD_FIELD_GET_SET(refit, bool); ADD_FIELD_GET_SET(debug, bool); ADD_FIELD_GET_SET(strict_types, bool); - ADD_ENUM_GET_SET(capability, EngineCapability, 3); + ADD_ENUM_GET_SET(capability, EngineCapability, 2); ADD_FIELD_GET_SET(num_min_timing_iters, int64_t); ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t); ADD_FIELD_GET_SET(workspace_size, int64_t); diff --git a/tests/py/BUILD b/tests/py/BUILD index 2dcbd494ea..d8798f0175 100644 --- a/tests/py/BUILD +++ b/tests/py/BUILD @@ -17,7 +17,8 @@ py_test( ] + select({ ":aarch64_linux": [ "test_api_dla.py" - ] + ], + "//conditions:default" : [] }), deps = [ requirement("torchvision") diff --git a/tests/py/test_to_backend_api.py b/tests/py/test_to_backend_api.py index 72c22582bd..9c2e4c4a40 100644 --- a/tests/py/test_to_backend_api.py +++ b/tests/py/test_to_backend_api.py @@ -19,8 +19,11 @@ def setUp(self): "refit": False, "debug": False, "strict_types": False, - "allow_gpu_fallback": True, - "device_type": "gpu", + "device": { + "device_type": trtorch.DeviceType.GPU, + "gpu_id": 0, + "allow_gpu_fallback": True + }, "capability": trtorch.EngineCapability.default, "num_min_timing_iters": 2, "num_avg_timing_iters": 1, @@ -29,14 +32,14 @@ def setUp(self): } def test_to_backend_lowering(self): - trt_mod = torch._C._jit_to_tensorrt(self.scripted_model._c, {"forward": self.spec}) + trt_mod = torch._C._jit_to_backend("tensorrt", self.scripted_model, self.spec) same = (trt_mod.forward(self.input) - self.scripted_model(self.input)).abs().max() self.assertTrue(same < 2e-3) def test_suite(): suite = unittest.TestSuite() - suite.addTest(TestToBackendLowering.parametrize(TestToBackendLowering, model=models.mobilenet_v2(pretrained=True))) + suite.addTest(TestToBackendLowering.parametrize(TestToBackendLowering, model=models.resnet18(pretrained=True))) return suite From 3d14cdac6d43afe99fd8721f81bfbc406c0915c9 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Thu, 21 Jan 2021 19:37:23 -0800 Subject: [PATCH 2/8] feat(//core/lowering): Adding a new pass to handle new dim checks for batchnorm Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/lowering/lowering.cpp | 1 + core/lowering/passes/BUILD | 1 + core/lowering/passes/passes.h | 1 + core/lowering/passes/remove_bn_dim_check.cpp | 88 ++++++++++++++++++++ 4 files changed, 91 insertions(+) create mode 100644 core/lowering/passes/remove_bn_dim_check.cpp diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index ab1f5561ef..7ee86f0d4c 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -40,6 +40,7 @@ void LowerGraph(std::shared_ptr& g) { passes::Conv2DToConvolution(g); passes::Conv3DToConvolution(g); passes::FuseAddMMBranches(g); + passes::RemoveBNDimCheck(g); torch::jit::EliminateCommonSubexpression(g); // torch::jit::UnrollLoops(g); torch::jit::EliminateCommonSubexpression(g); diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD index be6f3fcf42..9d3f328e20 100644 --- a/core/lowering/passes/BUILD +++ b/core/lowering/passes/BUILD @@ -18,6 +18,7 @@ cc_library( "exception_elimination.cpp", "fuse_addmm_branches.cpp", "fuse_flatten_linear.cpp", + "remove_bn_dim_check.cpp", "remove_contiguous.cpp", "remove_dropout.cpp", "remove_to.cpp", diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h index f1c72c2aca..977533a3d6 100644 --- a/core/lowering/passes/passes.h +++ b/core/lowering/passes/passes.h @@ -12,6 +12,7 @@ void Conv3DToConvolution(std::shared_ptr& graph); void FuseAddMMBranches(std::shared_ptr graph); void FuseFlattenLinear(std::shared_ptr& graph); void EliminateExceptionOrPassPattern(std::shared_ptr graph); +void RemoveBNDimCheck(std::shared_ptr graph); void RemoveContiguous(std::shared_ptr& graph); void RemoveDropout(std::shared_ptr& graph); void RemoveTo(std::shared_ptr graph); diff --git a/core/lowering/passes/remove_bn_dim_check.cpp b/core/lowering/passes/remove_bn_dim_check.cpp new file mode 100644 index 0000000000..92e48137e4 --- /dev/null +++ b/core/lowering/passes/remove_bn_dim_check.cpp @@ -0,0 +1,88 @@ +#include "torch/csrc/jit/ir/alias_analysis.h" +#include "torch/csrc/jit/jit_log.h" +#include "torch/csrc/jit/passes/constant_propagation.h" +#include "torch/csrc/jit/passes/dead_code_elimination.h" +#include "torch/csrc/jit/passes/guard_elimination.h" +#include "torch/csrc/jit/passes/peephole.h" +#include "torch/csrc/jit/runtime/graph_executor.h" + +#include "core/util/prelude.h" + +#include + +namespace trtorch { +namespace core { +namespace lowering { +namespace passes { +namespace { +using namespace torch::jit; +struct BNDimCheckRemoval { + BNDimCheckRemoval(std::shared_ptr graph) : graph_(std::move(graph)) {} + + void run() { + findBNDimCheckNodes(graph_->block()); + torch::jit::EliminateDeadCode(graph_); + LOG_GRAPH("Post aten::addmm branch fusion: " << *graph_); + } + + private: + bool isBNDimCheckNodes(Node* n) { + /// Check if this Node hosts a pattern like so: + /// %290 : bool = aten::ne(%289, %9) + /// = prim::If(%290) + /// block0(): + /// %291 : str = aten::format(%10, %289) + /// = prim::RaiseException(%291) + /// -> () + /// block1(): + /// -> () + + if (n->blocks().size() != 2) { + return false; + } + auto arm1 = n->blocks()[0]; + auto arm2 = n->blocks()[1]; + if (arm1->outputs().size() != 0 || arm2->outputs().size() != 0) { + // Make sure that the node doesn't actually produce any Value that are + // used by other nodes + return false; + } + + auto arm1_start = arm1->nodes().begin(); + + if ((*arm1_start)->kind() != c10::Symbol::fromQualString("aten::format") && (*(++arm1_start))->kind() != prim::RaiseException && (*(++arm1_start))->kind() != prim::Return) { + // Make sure that block0 is solely just the exception and the return + return false; + } + + if ((*(arm2->nodes().begin()))->kind() != prim::Return) { + // Make sure that block1 is solely the return + return false; + } + + return true; + } + + void findBNDimCheckNodes(Block* b) { + for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) { + auto n = *it; + if (n->kind() == prim::If && isBNDimCheckNodes(n)) { + LOG_GRAPH("Found that node " << *n << " is an batch norm dim check node (EliminateChecks)" << std::endl); + it.destroyCurrent(); + } + } + } + + std::shared_ptr graph_; +}; +} // namespace + +void RemoveBNDimCheck(std::shared_ptr graph) { + BNDimCheckRemoval bndcr(std::move(graph)); + bndcr.run(); +} + +} // namespace passes +} // namespace lowering +} // namespace core +} // namespace trtorch From 0618b6bd8eea654f5b40b942467cefd6872d211e Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Thu, 21 Jan 2021 19:40:45 -0800 Subject: [PATCH 3/8] refactor!: Update bazel and trt versions BREAKING CHANGE: Version of bazel has been bumped to 4.0.0 Version of TensorRT has been bumped to 7.2.2.3 Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- .bazelversion | 2 +- WORKSPACE | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.bazelversion b/.bazelversion index 7c69a55dbb..fcdb2e109f 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -3.7.0 +4.0.0 diff --git a/WORKSPACE b/WORKSPACE index 5bc2d2ccc2..1cb4500b90 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -79,10 +79,10 @@ http_archive( http_archive( name = "tensorrt", - urls = ["https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/7.2.1/tars/TensorRT-7.2.1.6.Ubuntu-18.04.x86_64-gnu.cuda-11.0.cudnn8.0.tar.gz",], + urls = ["https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/7.2.2/tars/TensorRT-7.2.2.3.Ubuntu-18.04.x86_64-gnu.cuda-11.0.cudnn8.0.tar.gz",], build_file = "@//third_party/tensorrt/archive:BUILD", - sha256 = "8def6b03b0c8c3751f560df21b3e99668ae05aab5140b1d38b8e51e4a0ffbbb8", - strip_prefix = "TensorRT-7.2.1.6" + strip_prefix = "TensorRT-7.2.2.3", + sha256 = "b5c325e38e1d92ce1ce92ca8b54ede9c224bf128c9a53eb0b9022f1ee4313ee0" ) #################################################################################### From 57c6d46f66a2aed91135badb90485f8ca1e58578 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Thu, 21 Jan 2021 19:53:22 -0800 Subject: [PATCH 4/8] docs: Update the docs to include new device API for to_backend Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan docs: Update docs for to_backend API for new device API and new PyTorch API Changes the docs to show the new device dictionary API and how to use the new to backend api (changed from PyTorch 1.6.0) Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- docsrc/tutorials/use_from_pytorch.rst | 36 +++++++++++++++------------ 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/docsrc/tutorials/use_from_pytorch.rst b/docsrc/tutorials/use_from_pytorch.rst index 322efd29a9..985cd01617 100644 --- a/docsrc/tutorials/use_from_pytorch.rst +++ b/docsrc/tutorials/use_from_pytorch.rst @@ -32,31 +32,35 @@ at the documentation for the TRTorch ``TensorRTCompileSpec`` API. .. code-block:: python spec = { - "forward": trtorch.TensorRTCompileSpec({ - "input_shapes": [[1, 3, 300, 300]], - "op_precision": torch.half, - "refit": False, - "debug": False, - "strict_types": False, - "allow_gpu_fallback": True, - "device_type": "gpu", - "capability": trtorch.EngineCapability.default, - "num_min_timing_iters": 2, - "num_avg_timing_iters": 1, - "max_batch_size": 0, - }) - } + "forward": + trtorch.TensorRTCompileSpec({ + "input_shapes": [[1, 3, 300, 300]], + "op_precision": torch.half, + "refit": False, + "debug": False, + "strict_types": False, + "device": { + "device_type": trtorch.DeviceType.GPU, + "gpu_id": 0, + "allow_gpu_fallback": True + }, + "capability": trtorch.EngineCapability.default, + "num_min_timing_iters": 2, + "num_avg_timing_iters": 1, + "max_batch_size": 0, + }) + } Now to compile with TRTorch, provide the target module objects and the spec dictionary to ``torch._C._jit_to_tensorrt`` .. code-block:: python - trt_model = torch._C._jit_to_tensorrt(script_model._c, spec) + trt_model = torch._C._jit_to_backend("tensorrt", script_model, spec) To run explicitly call the function of the method you want to run (vs. how you can just call on the module itself in standard PyTorch) .. code-block:: python - input = torch.randn((1, 3, 300, 300).to("cuda").to(torch.half) + input = torch.randn((1, 3, 300, 300)).to("cuda").to(torch.half) print(trt_model.forward(input)) From 08156803655f1e1b9551106c7b8809d52b952d17 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Thu, 21 Jan 2021 20:33:28 -0800 Subject: [PATCH 5/8] refactor: Lint code Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/lowering/passes/remove_bn_dim_check.cpp | 3 ++- py/trtorch/csrc/register_tensorrt_classes.cpp | 16 +++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/core/lowering/passes/remove_bn_dim_check.cpp b/core/lowering/passes/remove_bn_dim_check.cpp index 92e48137e4..6116522484 100644 --- a/core/lowering/passes/remove_bn_dim_check.cpp +++ b/core/lowering/passes/remove_bn_dim_check.cpp @@ -50,7 +50,8 @@ struct BNDimCheckRemoval { auto arm1_start = arm1->nodes().begin(); - if ((*arm1_start)->kind() != c10::Symbol::fromQualString("aten::format") && (*(++arm1_start))->kind() != prim::RaiseException && (*(++arm1_start))->kind() != prim::Return) { + if ((*arm1_start)->kind() != c10::Symbol::fromQualString("aten::format") && + (*(++arm1_start))->kind() != prim::RaiseException && (*(++arm1_start))->kind() != prim::Return) { // Make sure that block0 is solely just the exception and the return return false; } diff --git a/py/trtorch/csrc/register_tensorrt_classes.cpp b/py/trtorch/csrc/register_tensorrt_classes.cpp index d7b422c360..4446ece752 100644 --- a/py/trtorch/csrc/register_tensorrt_classes.cpp +++ b/py/trtorch/csrc/register_tensorrt_classes.cpp @@ -10,27 +10,26 @@ namespace { void RegisterTRTCompileSpec() { static auto TRTORCH_UNUSED TRTInputRangeTSRegistration = - torch::class_("tensorrt", "InputRange").def(torch::init<>()); + torch::class_("tensorrt", "InputRange").def(torch::init<>()); ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, min); ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, opt); ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, max); static auto TRTORCH_UNUSED TRTDeviceTSRegistration = - torch::class_("tensorrt", "Device").def(torch::init<>()); + torch::class_("tensorrt", "Device").def(torch::init<>()); ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, device_type); ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, gpu_id); ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core); ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback); - static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration = - torch::class_("tensorrt", "CompileSpec") - .def(torch::init<>()) - .def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange) - .def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive) - .def("__str__", &trtorch::pyapi::CompileSpec::stringify); + torch::class_("tensorrt", "CompileSpec") + .def(torch::init<>()) + .def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange) + .def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive) + .def("__str__", &trtorch::pyapi::CompileSpec::stringify); ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision); ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit); @@ -41,7 +40,6 @@ void RegisterTRTCompileSpec() { ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, num_avg_timing_iters); ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, workspace_size); ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, max_batch_size); - } struct TRTTSRegistrations { From 5031324399ac720743e0f0874a37e16d4bb594e9 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Fri, 22 Jan 2021 15:39:52 -0800 Subject: [PATCH 6/8] refactor: Addressing PR comments Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- docsrc/tutorials/use_from_pytorch.rst | 1 + py/trtorch/_compile_spec.py | 8 ++++---- py/trtorch/csrc/tensorrt_classes.h | 6 +++--- tests/py/test_to_backend_api.py | 1 + 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/docsrc/tutorials/use_from_pytorch.rst b/docsrc/tutorials/use_from_pytorch.rst index 985cd01617..ac84a14b49 100644 --- a/docsrc/tutorials/use_from_pytorch.rst +++ b/docsrc/tutorials/use_from_pytorch.rst @@ -42,6 +42,7 @@ at the documentation for the TRTorch ``TensorRTCompileSpec`` API. "device": { "device_type": trtorch.DeviceType.GPU, "gpu_id": 0, + "dla_core": 0, "allow_gpu_fallback": True }, "capability": trtorch.EngineCapability.default, diff --git a/py/trtorch/_compile_spec.py b/py/trtorch/_compile_spec.py index 329d8e1036..311bee34c1 100644 --- a/py/trtorch/_compile_spec.py +++ b/py/trtorch/_compile_spec.py @@ -195,10 +195,10 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt. } # Dynamic input shape for input #2 ], "device": { - "device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA) - "gpu_id": 0, # Target gpu id to run engine (Use Xavier as gpu id for DLA) - "dla_core": 0, # (DLA only) Target dla core id to run engine - "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU + "device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA) + "gpu_id": 0, # Target gpu id to run engine (Use Xavier as gpu id for DLA) + "dla_core": 0, # (DLA only) Target dla core id to run engine + "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU }, "op_precision": torch.half, # Operating precision set to FP16 "refit": False, # enable refit diff --git a/py/trtorch/csrc/tensorrt_classes.h b/py/trtorch/csrc/tensorrt_classes.h index c2852532c3..3b8b3cdcdb 100644 --- a/py/trtorch/csrc/tensorrt_classes.h +++ b/py/trtorch/csrc/tensorrt_classes.h @@ -69,7 +69,7 @@ struct Device : torch::CustomClassHolder { allow_gpu_fallback(false) // allow_gpu_fallback {} - ADD_ENUM_GET_SET(device_type, DeviceType, 1); + ADD_ENUM_GET_SET(device_type, DeviceType, static_cast(DeviceType::kDLA)); ADD_FIELD_GET_SET(gpu_id, int64_t); ADD_FIELD_GET_SET(dla_core, int64_t); ADD_FIELD_GET_SET(allow_gpu_fallback, bool); @@ -98,11 +98,11 @@ struct CompileSpec : torch::CustomClassHolder { device = *d; } - ADD_ENUM_GET_SET(op_precision, DataType, 2); + ADD_ENUM_GET_SET(op_precision, DataType, static_cast(DataType::kChar)); ADD_FIELD_GET_SET(refit, bool); ADD_FIELD_GET_SET(debug, bool); ADD_FIELD_GET_SET(strict_types, bool); - ADD_ENUM_GET_SET(capability, EngineCapability, 2); + ADD_ENUM_GET_SET(capability, EngineCapability, static_cast(EngineCapability::kSAFE_DLA)); ADD_FIELD_GET_SET(num_min_timing_iters, int64_t); ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t); ADD_FIELD_GET_SET(workspace_size, int64_t); diff --git a/tests/py/test_to_backend_api.py b/tests/py/test_to_backend_api.py index 9c2e4c4a40..f111694bfc 100644 --- a/tests/py/test_to_backend_api.py +++ b/tests/py/test_to_backend_api.py @@ -22,6 +22,7 @@ def setUp(self): "device": { "device_type": trtorch.DeviceType.GPU, "gpu_id": 0, + "dla_core": 0, "allow_gpu_fallback": True }, "capability": trtorch.EngineCapability.default, From 6b942e5092eb83564ec5d691680dcda0859cb240 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Fri, 22 Jan 2021 15:46:22 -0800 Subject: [PATCH 7/8] fix(//py): Fix bounds for enum macros Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- py/trtorch/csrc/tensorrt_classes.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/py/trtorch/csrc/tensorrt_classes.h b/py/trtorch/csrc/tensorrt_classes.h index 3b8b3cdcdb..1ad32b3167 100644 --- a/py/trtorch/csrc/tensorrt_classes.h +++ b/py/trtorch/csrc/tensorrt_classes.h @@ -18,13 +18,13 @@ namespace pyapi { } // TODO: Make this error message more informative -#define ADD_ENUM_GET_SET(field_name, type, max_val) \ - void set_##field_name(int64_t val) { \ - TRTORCH_CHECK(val < max_val, "Invalid enum value for field"); \ - field_name = static_cast(val); \ - } \ - int64_t get_##field_name() { \ - return static_cast(field_name); \ +#define ADD_ENUM_GET_SET(field_name, type, max_val) \ + void set_##field_name(int64_t val) { \ + TRTORCH_CHECK(val >= 0 && val <= max_val, "Invalid enum value for field"); \ + field_name = static_cast(val); \ + } \ + int64_t get_##field_name() { \ + return static_cast(field_name); \ } struct InputRange : torch::CustomClassHolder { From 86bb5b7e351a25f043ffc0831ac44787fe89ea30 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Fri, 22 Jan 2021 16:49:30 -0800 Subject: [PATCH 8/8] fix(//core/lowering): fix debug message for bn dim check removal pass Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/lowering/passes/remove_bn_dim_check.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/lowering/passes/remove_bn_dim_check.cpp b/core/lowering/passes/remove_bn_dim_check.cpp index 6116522484..9b2a63b0a7 100644 --- a/core/lowering/passes/remove_bn_dim_check.cpp +++ b/core/lowering/passes/remove_bn_dim_check.cpp @@ -22,7 +22,7 @@ struct BNDimCheckRemoval { void run() { findBNDimCheckNodes(graph_->block()); torch::jit::EliminateDeadCode(graph_); - LOG_GRAPH("Post aten::addmm branch fusion: " << *graph_); + LOG_GRAPH("Post batch norm dim check removal: " << *graph_); } private: