
Merge pull request #288 from NVIDIA/to_backend_device
Adding the new device API, fixing a nested dict issue in the existing compile phase, and adding a new lowering pass for batch norm
narendasan authored Jan 25, 2021
2 parents b787c5e + 86bb5b7 commit 20022d4
Showing 13 changed files with 167 additions and 56 deletions.
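For orientation before the per-file diffs, a minimal sketch of the user-facing change (the model choice and input shape are placeholders, not part of the commit):

    import torch
    import torchvision.models as models
    import trtorch

    script_model = torch.jit.script(models.resnet18(pretrained=True)).eval().cuda()

    spec = {
        "forward":
            trtorch.TensorRTCompileSpec({
                "input_shapes": [[1, 3, 224, 224]],
                # Device settings now live in a nested "device" dict instead of
                # top-level "device_type"/"allow_gpu_fallback" keys
                "device": {
                    "device_type": trtorch.DeviceType.GPU,
                    "gpu_id": 0,
                    "dla_core": 0,
                    "allow_gpu_fallback": True
                },
            })
    }

    # The generic to_backend entry point replaces torch._C._jit_to_tensorrt
    trt_model = torch._C._jit_to_backend("tensorrt", script_model, spec)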
2 changes: 1 addition & 1 deletion .bazelversion
@@ -1 +1 @@
3.7.0
4.0.0
6 changes: 3 additions & 3 deletions WORKSPACE
@@ -79,10 +79,10 @@ http_archive(

http_archive(
name = "tensorrt",
urls = ["https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/7.2.1/tars/TensorRT-7.2.1.6.Ubuntu-18.04.x86_64-gnu.cuda-11.0.cudnn8.0.tar.gz",],
urls = ["https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/7.2.2/tars/TensorRT-7.2.2.3.Ubuntu-18.04.x86_64-gnu.cuda-11.0.cudnn8.0.tar.gz",],
build_file = "@//third_party/tensorrt/archive:BUILD",
sha256 = "8def6b03b0c8c3751f560df21b3e99668ae05aab5140b1d38b8e51e4a0ffbbb8",
strip_prefix = "TensorRT-7.2.1.6"
strip_prefix = "TensorRT-7.2.2.3",
sha256 = "b5c325e38e1d92ce1ce92ca8b54ede9c224bf128c9a53eb0b9022f1ee4313ee0"
)

####################################################################################
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
@@ -40,6 +40,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
passes::Conv2DToConvolution(g);
passes::Conv3DToConvolution(g);
passes::FuseAddMMBranches(g);
passes::RemoveBNDimCheck(g);
torch::jit::EliminateCommonSubexpression(g);
// torch::jit::UnrollLoops(g);
torch::jit::EliminateCommonSubexpression(g);
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
@@ -18,6 +18,7 @@ cc_library(
"exception_elimination.cpp",
"fuse_addmm_branches.cpp",
"fuse_flatten_linear.cpp",
"remove_bn_dim_check.cpp",
"remove_contiguous.cpp",
"remove_dropout.cpp",
"remove_to.cpp",
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
@@ -12,6 +12,7 @@ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveTo(std::shared_ptr<torch::jit::Graph> graph);
89 changes: 89 additions & 0 deletions core/lowering/passes/remove_bn_dim_check.cpp
@@ -0,0 +1,89 @@
#include "torch/csrc/jit/ir/alias_analysis.h"
#include "torch/csrc/jit/jit_log.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/guard_elimination.h"
#include "torch/csrc/jit/passes/peephole.h"
#include "torch/csrc/jit/runtime/graph_executor.h"

#include "core/util/prelude.h"

#include <vector>

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {
namespace {
using namespace torch::jit;
struct BNDimCheckRemoval {
  BNDimCheckRemoval(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {}

  void run() {
    findBNDimCheckNodes(graph_->block());
    torch::jit::EliminateDeadCode(graph_);
    LOG_GRAPH("Post batch norm dim check removal: " << *graph_);
  }

 private:
  bool isBNDimCheckNodes(Node* n) {
    /// Check if this Node hosts a pattern like so:
    /// %290 : bool = aten::ne(%289, %9)
    /// = prim::If(%290)
    ///   block0():
    ///     %291 : str = aten::format(%10, %289)
    ///     = prim::RaiseException(%291)
    ///     -> ()
    ///   block1():
    ///     -> ()

    if (n->blocks().size() != 2) {
      return false;
    }
    auto arm1 = n->blocks()[0];
    auto arm2 = n->blocks()[1];
    if (arm1->outputs().size() != 0 || arm2->outputs().size() != 0) {
      // Make sure that the node doesn't actually produce any Values that are
      // used by other nodes
      return false;
    }

    auto arm1_start = arm1->nodes().begin();

    // block0 must be exactly aten::format -> prim::RaiseException -> prim::Return;
    // any single mismatch disqualifies the node, hence || rather than &&
    if ((*arm1_start)->kind() != c10::Symbol::fromQualString("aten::format") ||
        (*(++arm1_start))->kind() != prim::RaiseException || (*(++arm1_start))->kind() != prim::Return) {
      return false;
    }

    if ((*(arm2->nodes().begin()))->kind() != prim::Return) {
      // Make sure that block1 is solely the return
      return false;
    }

    return true;
  }

  void findBNDimCheckNodes(Block* b) {
    for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
      auto n = *it;
      if (n->kind() == prim::If && isBNDimCheckNodes(n)) {
        LOG_GRAPH("Found that node " << *n << " is a batch norm dim check node (EliminateChecks)" << std::endl);
        it.destroyCurrent();
      }
    }
  }

  std::shared_ptr<Graph> graph_;
};
} // namespace

void RemoveBNDimCheck(std::shared_ptr<Graph> graph) {
  BNDimCheckRemoval bndcr(std::move(graph));
  bndcr.run();
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
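For context, the guard this pass removes is emitted when scripting batch norm modules; a quick way to see the pattern (exact node names vary with the PyTorch version, so treat this as illustrative):

    import torch
    import torch.nn as nn

    class BNNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.bn = nn.BatchNorm2d(16)

        def forward(self, x):
            return self.bn(x)

    scripted = torch.jit.script(BNNet())
    # Look for an aten::ne feeding a prim::If whose first block formats a
    # string and raises -- that is the dim check RemoveBNDimCheck strips.
    print(scripted.graph)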
37 changes: 21 additions & 16 deletions docsrc/tutorials/use_from_pytorch.rst
@@ -32,31 +32,36 @@ at the documentation for the TRTorch ``TensorRTCompileSpec`` API.
.. code-block:: python
spec = {
"forward": trtorch.TensorRTCompileSpec({
"input_shapes": [[1, 3, 300, 300]],
"op_precision": torch.half,
"refit": False,
"debug": False,
"strict_types": False,
"allow_gpu_fallback": True,
"device_type": "gpu",
"capability": trtorch.EngineCapability.default,
"num_min_timing_iters": 2,
"num_avg_timing_iters": 1,
"max_batch_size": 0,
})
}
"forward":
trtorch.TensorRTCompileSpec({
"input_shapes": [[1, 3, 300, 300]],
"op_precision": torch.half,
"refit": False,
"debug": False,
"strict_types": False,
"device": {
"device_type": trtorch.DeviceType.GPU,
"gpu_id": 0,
"dla_core": 0,
"allow_gpu_fallback": True
},
"capability": trtorch.EngineCapability.default,
"num_min_timing_iters": 2,
"num_avg_timing_iters": 1,
"max_batch_size": 0,
})
}
Now, to compile with TRTorch, provide the target module and the spec dictionary to ``torch._C._jit_to_backend``

.. code-block:: python
trt_model = torch._C._jit_to_tensorrt(script_model._c, spec)
trt_model = torch._C._jit_to_backend("tensorrt", script_model, spec)
To run the model, explicitly call the method you want to execute (as opposed to calling the module itself, as you would in standard PyTorch)

.. code-block:: python
input = torch.randn((1, 3, 300, 300).to("cuda").to(torch.half)
input = torch.randn((1, 3, 300, 300)).to("cuda").to(torch.half)
print(trt_model.forward(input))
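A quick way to sanity-check the compiled module against the original, mirroring the tolerance check in ``tests/py/test_to_backend_api.py`` (assuming ``script_model`` was also moved to CUDA and FP16):

    script_out = script_model(input)
    trt_out = trt_model.forward(input)
    print((trt_out - script_out).abs().max())  # the test suite expects < 2e-3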
27 changes: 11 additions & 16 deletions py/trtorch/_compile_spec.py
@@ -147,10 +147,6 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
assert isinstance(compile_spec["strict_types"], bool)
info.strict_types = compile_spec["strict_types"]

if "allow_gpu_fallback" in compile_spec:
assert isinstance(compile_spec["allow_gpu_fallback"], bool)
info.allow_gpu_fallback = compile_spec["allow_gpu_fallback"]

if "device" in compile_spec:
info.device = _parse_device(compile_spec["device"])

@@ -177,7 +173,7 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
return info


def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.CompileSpec:
"""
Utility to create a formated spec dictionary for using the PyTorch TensorRT backend
@@ -199,10 +195,10 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
} # Dynamic input shape for input #2
],
"device": {
"device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
"gpu_id": 0, # Target gpu id to run engine (Use Xavier as gpu id for DLA)
"dla_core": 0, # (DLA only) Target dla core id to run engine
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
"device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
"gpu_id": 0, # Target gpu id to run engine (Use Xavier as gpu id for DLA)
"dla_core": 0, # (DLA only) Target dla core id to run engine
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
},
"op_precision": torch.half, # Operating precision set to FP16
"refit": False, # enable refit
@@ -235,14 +231,13 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
ir.set_max(i.max)
backend_spec.append_input_range(ir)

for i in parsed_spec.device:
ir = torch.classes.tensorrt.Device()
ir.set_device_type(i.device_type)
ir.set_gpu_id(i.gpu_id)
ir.set_dla_core(i.dla_core)
ir.set_allow_gpu_fallback(i.allow_gpu_fallback)
backend_spec.set_device(ir)
d = torch.classes.tensorrt.Device()
d.set_device_type(int(parsed_spec.device.device_type))
d.set_gpu_id(parsed_spec.device.gpu_id)
d.set_dla_core(parsed_spec.device.dla_core)
d.set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback)

backend_spec.set_device(d)
backend_spec.set_op_precision(int(parsed_spec.op_precision))
backend_spec.set_refit(parsed_spec.refit)
backend_spec.set_debug(parsed_spec.debug)
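``_parse_device`` itself is outside this hunk; a plausible sketch of what it does with the nested dict, with names and defaults that are assumptions rather than the committed implementation:

    def _parse_device(device_dict):
        assert isinstance(device_dict, dict)
        info = trtorch._C.Device()  # hypothetical internal struct with matching fields
        if "device_type" in device_dict:
            info.device_type = device_dict["device_type"]
        info.gpu_id = device_dict.get("gpu_id", 0)
        info.dla_core = device_dict.get("dla_core", 0)
        info.allow_gpu_fallback = device_dict.get("allow_gpu_fallback", False)
        return info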
12 changes: 11 additions & 1 deletion py/trtorch/csrc/register_tensorrt_classes.cpp
@@ -3,22 +3,32 @@
namespace trtorch {
namespace backend {
namespace {
void RegisterTRTCompileSpec() {

#define ADD_FIELD_GET_SET_REGISTRATION(registry, class_name, field_name) \
(registry).def("set_" #field_name, &class_name::set_##field_name); \
(registry).def("get_" #field_name, &class_name::get_##field_name);

void RegisterTRTCompileSpec() {
static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
torch::class_<trtorch::pyapi::InputRange>("tensorrt", "InputRange").def(torch::init<>());

ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, min);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, opt);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, max);

static auto TRTORCH_UNUSED TRTDeviceTSRegistration =
torch::class_<trtorch::pyapi::Device>("tensorrt", "Device").def(torch::init<>());

ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, device_type);
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, gpu_id);
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core);
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback);

static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
.def(torch::init<>())
.def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
.def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);

ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
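Once the extension library is loaded, the TorchBind classes registered above are reachable from Python; a small sketch (that 0 is the integer value of ``DeviceType::kGPU`` is an assumption):

    import torch
    import trtorch  # importing the package loads the library that runs RegisterTRTCompileSpec()

    d = torch.classes.tensorrt.Device()
    d.set_device_type(0)  # setters take the enum's underlying int64 value
    d.set_gpu_id(0)
    d.set_dla_core(0)
    d.set_allow_gpu_fallback(False)
    print(d.get_gpu_id(), d.get_allow_gpu_fallback())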
2 changes: 1 addition & 1 deletion py/trtorch/csrc/tensorrt_backend.cpp
@@ -46,7 +46,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue processed_mod, c10::
auto method = mod.get_method(method_name);
auto g = method.graph();

auto raw_spec = it->value().toGenericDict().at(it->key()).toCustomClass<trtorch::pyapi::CompileSpec>();
auto raw_spec = it->value().toCustomClass<trtorch::pyapi::CompileSpec>();
LOG_DEBUG(raw_spec->stringify());
auto cfg = raw_spec->toInternalCompileSpec();
auto convert_cfg = std::move(cfg.convert_info);
30 changes: 17 additions & 13 deletions py/trtorch/csrc/tensorrt_classes.h
@@ -17,6 +17,16 @@ namespace pyapi {
return field_name; \
}

// TODO: Make this error message more informative
#define ADD_ENUM_GET_SET(field_name, type, max_val) \
void set_##field_name(int64_t val) { \
TRTORCH_CHECK(val >= 0 && val <= max_val, "Invalid enum value for field"); \
field_name = static_cast<type>(val); \
} \
int64_t get_##field_name() { \
return static_cast<int64_t>(field_name); \
}

struct InputRange : torch::CustomClassHolder {
std::vector<int64_t> min;
std::vector<int64_t> opt;
@@ -59,7 +69,7 @@ struct Device : torch::CustomClassHolder {
allow_gpu_fallback(false) // allow_gpu_fallback
{}

ADD_FIELD_GET_SET(device_type, DeviceType);
ADD_ENUM_GET_SET(device_type, DeviceType, static_cast<int64_t>(DeviceType::kDLA));
ADD_FIELD_GET_SET(gpu_id, int64_t);
ADD_FIELD_GET_SET(dla_core, int64_t);
ADD_FIELD_GET_SET(allow_gpu_fallback, bool);
@@ -77,28 +87,22 @@ enum class EngineCapability : int8_t {
std::string to_str(EngineCapability value);
nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value);

// TODO: Make this error message more informative
#define ADD_ENUM_GET_SET(field_name, type, max_val) \
void set_##field_name(int64_t val) { \
TRTORCH_CHECK(val < max_val, "Invalid enum value for field"); \
field_name = static_cast<type>(val); \
} \
int64_t get_##field_name() { \
return static_cast<int64_t>(field_name); \
}

struct CompileSpec : torch::CustomClassHolder {
core::CompileSpec toInternalCompileSpec();
std::string stringify();
void appendInputRange(const c10::intrusive_ptr<InputRange>& ir) {
input_ranges.push_back(*ir);
}

ADD_ENUM_GET_SET(op_precision, DataType, 3);
void setDeviceIntrusive(const c10::intrusive_ptr<Device>& d) {
device = *d;
}

ADD_ENUM_GET_SET(op_precision, DataType, static_cast<int64_t>(DataType::kChar));
ADD_FIELD_GET_SET(refit, bool);
ADD_FIELD_GET_SET(debug, bool);
ADD_FIELD_GET_SET(strict_types, bool);
ADD_ENUM_GET_SET(capability, EngineCapability, 3);
ADD_ENUM_GET_SET(capability, EngineCapability, static_cast<int64_t>(EngineCapability::kSAFE_DLA));
ADD_FIELD_GET_SET(num_min_timing_iters, int64_t);
ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t);
ADD_FIELD_GET_SET(workspace_size, int64_t);
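The bounds check in ``ADD_ENUM_GET_SET`` should surface in Python as an exception on out-of-range values; a sketch, assuming ``TRTORCH_CHECK`` failures propagate as ``RuntimeError``:

    d = torch.classes.tensorrt.Device()
    try:
        d.set_device_type(42)  # outside [0, kDLA], so TRTORCH_CHECK rejects it
    except RuntimeError as err:
        print("rejected out-of-range enum:", err)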
3 changes: 2 additions & 1 deletion tests/py/BUILD
@@ -17,7 +17,8 @@ py_test(
] + select({
":aarch64_linux": [
"test_api_dla.py"
]
],
"//conditions:default" : []
}),
deps = [
requirement("torchvision")
12 changes: 8 additions & 4 deletions tests/py/test_to_backend_api.py
@@ -19,8 +19,12 @@ def setUp(self):
"refit": False,
"debug": False,
"strict_types": False,
"allow_gpu_fallback": True,
"device_type": "gpu",
"device": {
"device_type": trtorch.DeviceType.GPU,
"gpu_id": 0,
"dla_core": 0,
"allow_gpu_fallback": True
},
"capability": trtorch.EngineCapability.default,
"num_min_timing_iters": 2,
"num_avg_timing_iters": 1,
@@ -29,14 +33,14 @@ def setUp(self):
}

def test_to_backend_lowering(self):
trt_mod = torch._C._jit_to_tensorrt(self.scripted_model._c, {"forward": self.spec})
trt_mod = torch._C._jit_to_backend("tensorrt", self.scripted_model, self.spec)
same = (trt_mod.forward(self.input) - self.scripted_model(self.input)).abs().max()
self.assertTrue(same < 2e-3)


def test_suite():
suite = unittest.TestSuite()
suite.addTest(TestToBackendLowering.parametrize(TestToBackendLowering, model=models.mobilenet_v2(pretrained=True)))
suite.addTest(TestToBackendLowering.parametrize(TestToBackendLowering, model=models.resnet18(pretrained=True)))

return suite

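To run the suite directly with the stock ``unittest`` runner (the commit's file already defines ``test_suite()``):

    import unittest

    if __name__ == "__main__":
        unittest.TextTestRunner(verbosity=2).run(test_suite())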
