refactor: Upgrading to LibTorch 1.5.0 (CUDA 10.2, cuDNN 7.6.5, TensorRT 7.0.0)

- Closes #42
- Issue #1 is back, unknown root cause, will follow up with the PyTorch
Team
- Closes #14: The default build now requires users to grab the tarballs
from the NVIDIA website to support hermetic builds, may look at some
methods to smooth this out later. The old method is still available
- New operators need to be implemented to support MobileNet in 1.5.0
(blocks merge into master)

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed Apr 29, 2020
1 parent 36d27da commit a51c7b6
Showing 48 changed files with 459 additions and 233 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -24,3 +24,4 @@ cpp/ptq/datasets/data/
tests/accuracy/datasets/data/*
._.DS_Store
*.tar.gz
+*.tgz
81 changes: 71 additions & 10 deletions README.md
@@ -2,7 +2,7 @@

> Ahead of Time (AOT) compiling for PyTorch JIT
-TRTorch is a compiler for PyTorch/TorchScript, targeting NVIDIA GPUs via NVIDIA's TensorRT Deep Learning Optimizer and Runtime. Unlike PyTorch's Just-In-Time (JIT) compiler, TRTorch is an Ahead-of-Time (AOT) compiler, meaning that before you deploy your TorchScript code, you go through an explicit compile step to convert a standard TorchScript program into a module targeting a TensorRT engine. TRTorch operates as a PyTorch extension and compiles modules that integrate into the JIT runtime seamlessly. After compilation, using the optimized graph should feel no different from running a TorchScript module. You also have access to TensorRT's suite of configurations at compile time, so you are able to specify operating precision (FP32/F16) and other settings for your module.
+TRTorch is a compiler for PyTorch/TorchScript, targeting NVIDIA GPUs via NVIDIA's TensorRT Deep Learning Optimizer and Runtime. Unlike PyTorch's Just-In-Time (JIT) compiler, TRTorch is an Ahead-of-Time (AOT) compiler, meaning that before you deploy your TorchScript code, you go through an explicit compile step to convert a standard TorchScript program into a module targeting a TensorRT engine. TRTorch operates as a PyTorch extension and compiles modules that integrate into the JIT runtime seamlessly. After compilation, using the optimized graph should feel no different from running a TorchScript module. You also have access to TensorRT's suite of configurations at compile time, so you are able to specify operating precision (FP32/F16/INT8) and other settings for your module.
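
In practice the workflow is: load a TorchScript module, compile it ahead of time against a set of input shapes, then call it like any other module. A minimal sketch, based on the usage example further down in this README (field and function names follow the public `trtorch.h` header; treat the shapes and settings here as illustrative):

```c++
// Minimal AOT sketch; assumes the public trtorch.h API shown later in this README.
#include "torch/script.h"
#include "trtorch/trtorch.h"

int main() {
    // A standard TorchScript module produced by torch.jit.trace / torch.jit.script
    auto mod = torch::jit::load("model.ts");

    // Explicit compile step: describe the input shapes and (optionally) the
    // operating precision, then convert the program into a TensorRT-backed module
    auto settings = trtorch::ExtraInfo({{1, 3, 224, 224}});
    settings.op_precision = torch::kFloat;
    auto trt_mod = trtorch::CompileGraph(mod, settings);

    // After compilation the module is used exactly like any other TorchScript module
    auto in = torch::randn({1, 3, 224, 224}, torch::kCUDA);
    auto out = trt_mod.forward({in});
    return 0;
}
```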

More Information / System Architecture:

@@ -35,28 +35,89 @@ auto results = trt_mod.forward({in_tensor});
| Platform | Support |
| -------- | ------- |
| Linux AMD64 / GPU | **Supported** |
-| Linux aarch64 / GPU | **Planned/Possible with Native Compilation and small modifications to the build system** |
+| Linux aarch64 / GPU | **Planned/Possible with Native Compilation but untested** |
| Linux aarch64 / DLA | **Planned/Possible with Native Compilation but untested** |
| Windows / GPU | - |
| Linux ppc64le / GPU | - |
### Dependencies
-- Libtorch 1.4.0
-- CUDA 10.1
-- cuDNN 7.6
-- TensorRT 6.0.1
+- Libtorch 1.5.0
+- CUDA 10.2
+- cuDNN 7.6.5
+- TensorRT 7.0.0
## Prebuilt Binaries
Releases: https://github.com/NVIDIA/TRTorch/releases
## Compiling TRTorch
Install TensorRT, CUDA and cuDNN on the system before starting to compile.
### Installing Dependencies
You need to start by having CUDA installed on the system; Libtorch will automatically be pulled for you by bazel. Then you have two options.
#### 1. Building using cuDNN & TensorRT tarball distributions
> This is recommended so as to build TRTorch hermetically, and it ensures that any bugs are not caused by version issues
> Make sure when running TRTorch that these versions of the libraries are prioritized in your `$LD_LIBRARY_PATH`
1. You need to download the tarball distributions of TensorRT and cuDNN from the NVIDIA website.
- https://developer.nvidia.com/cudnn
- https://developer.nvidia.com/tensorrt
2. Place these files in a directory (the directories `third_party/distdir/[x86_64-linux-gnu | aarch64-linux-gnu]` exist for this purpose; an example layout is shown below)
3. Compile using:
``` shell
bazel build //:libtrtorch --compilation_mode opt --distdir third_party/distdir/[x86_64-linux-gnu | aarch64-linux-gnu]
```
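
For reference, a populated distdir might look like the following. This is illustrative only: the exact tarball names depend on the versions you download; the ones shown here match the cuDNN 7.6.5 and TensorRT 7.0.0 archives referenced in the `WORKSPACE`.

``` shell
$ ls third_party/distdir/x86_64-linux-gnu/
cudnn-10.2-linux-x64-v7.6.5.32.tgz
TensorRT-7.0.0.11.Ubuntu-18.04.x86_64-gnu.cuda-10.2.cudnn7.6.tar.gz
```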

#### 2. Building using locally installed cuDNN & TensorRT

> If you find bugs and you compiled using this method, please disclose it in the issue
> (an `ldd` dump, like the example shown after step 3 below, would be nice too)
1. Install TensorRT, CUDA and cuDNN on the system before starting to compile.
2. In `WORKSPACE` comment out
```py
# Downloaded distributions to use with --distdir
http_archive(
    name = "cudnn",
    urls = ["<URL>",],

    build_file = "@//third_party/cudnn/archive:BUILD",
    sha256 = "<TAR SHA256>",
    strip_prefix = "cuda"
)

http_archive(
    name = "tensorrt",
    urls = ["<URL>",],

    build_file = "@//third_party/tensorrt/archive:BUILD",
    sha256 = "<TAR SHA256>",
    strip_prefix = "TensorRT-<VERSION>"
)
```
and uncomment
```py
# Locally installed dependencies
new_local_repository(
    name = "cudnn",
    path = "/usr/",
    build_file = "@//third_party/cudnn/local:BUILD"
)

new_local_repository(
    name = "tensorrt",
    path = "/usr/",
    build_file = "@//third_party/tensorrt/local:BUILD"
)
```
3. Compile using:
``` shell
-bazel build //:libtrtorch --compilation_mode=opt
+bazel build //:libtrtorch --compilation_mode opt
```
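
If you do hit a bug with a build against locally installed libraries, one way to produce the `ldd` dump mentioned above is something like the following (the path to the built library is illustrative and depends on where your Bazel output lands):

``` shell
# Illustrative: check which TensorRT / cuDNN / Torch libraries the build actually resolved against
ldd bazel-bin/libtrtorch.so | grep -E "nvinfer|cudnn|torch|cudart"
```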

### Debug build
@@ -84,9 +84,9 @@ Thanks for wanting to contribute! There are two main ways to handle supporting a

### In my application?

-> The Node Converter Registry is not exposed in the top level API but you can try using the internal headers shipped with the tarball.
+> The Node Converter Registry is not exposed in the top level API but in the internal headers shipped with the tarball.
-You can register a converter for your op using the NodeConverterRegistry inside your application.
+You can register a converter for your op using the `NodeConverterRegistry` inside your application.
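
As a rough sketch of what that can look like (assuming the internal converter headers shipped in the tarball are on your include path; the namespaces, helper names, and the `my_ops::relu_like` schema below are illustrative and may differ between releases):

```c++
// Sketch only: not the canonical registration API, just the general shape of it.
#include "NvInfer.h"
#include "core/conversion/converters/converters.h"

namespace {
using namespace trtorch::core::conversion;

// Map a (hypothetical) op schema to a lambda that emits the equivalent TensorRT layers
auto my_op_registration = converters::RegisterNodeConversionPatterns()
    .pattern({
        "my_ops::relu_like(Tensor self) -> (Tensor)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, converters::args& args) -> bool {
            auto in = args[0].ITensor();  // ITensor produced for the node's input
            auto relu = ctx->net->addActivation(*in, nvinfer1::ActivationType::kRELU);
            // Associate the node's output Value with the TensorRT output tensor
            ctx->AssociateValueAndTensor(n->outputs()[0], relu->getOutput(0));
            return true;
        }
    });
}  // namespace
```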

## Structure of the repo

54 changes: 37 additions & 17 deletions WORKSPACE
@@ -16,15 +16,6 @@ py_repositories()
load("@rules_python//python:pip.bzl", "pip_repositories", "pip_import")
pip_repositories()

-http_archive(
-    name = "libtorch",
-    build_file = "@//third_party/libtorch:BUILD",
-    strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/cu101/libtorch-cxx11-abi-shared-with-deps-1.4.0.zip"],
-    sha256 = "f214bfde532877aa5d4e0803e51a28fa8edd97b6a44b6615f75a70352b6b542e"
-)

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "rules_pkg",
url = "https://github.com/bazelbuild/rules_pkg/releases/download/0.2.4/rules_pkg-0.2.4.tar.gz",
@@ -34,24 +25,53 @@ http_archive(
load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")
rules_pkg_dependencies()

 # CUDA should be installed on the system locally
 new_local_repository(
     name = "cuda",
-    path = "/usr/local/cuda-10.1/targets/x86_64-linux/",
+    path = "/usr/local/cuda-10.2/targets/x86_64-linux/",
     build_file = "@//third_party/cuda:BUILD",
 )

-new_local_repository(
+http_archive(
+    name = "libtorch",
+    build_file = "@//third_party/libtorch:BUILD",
+    strip_prefix = "libtorch",
+    urls = ["https://download.pytorch.org/libtorch/cu102/libtorch-cxx11-abi-shared-with-deps-1.5.0.zip"],
+    sha256 = "0efdd4e709ab11088fa75f0501c19b0e294404231442bab1d1fb953924feb6b5"
+)
+
+# Downloaded distributions to use with --distdir
+http_archive(
     name = "cudnn",
-    path = "/usr/",
-    build_file = "@//third_party/cudnn:BUILD"
+    urls = ["https://developer.nvidia.com/compute/machine-learning/cudnn/secure/7.6.5.32/Production/10.2_20191118/cudnn-10.2-linux-x64-v7.6.5.32.tgz",],
+
+    build_file = "@//third_party/cudnn/archive:BUILD",
+    sha256 = "600267f2caaed2fd58eb214ba669d8ea35f396a7d19b94822e6b36f9f7088c20",
+    strip_prefix = "cuda"
 )

-new_local_repository(
-    name = "tensorrt",
-    path = "/usr/",
-    build_file = "@//third_party/tensorrt:BUILD"
+http_archive(
+    name = "tensorrt",
+    urls = ["https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/7.0/7.0.0.11/tars/TensorRT-7.0.0.11.Ubuntu-18.04.x86_64-gnu.cuda-10.2.cudnn7.6.tar.gz",],
+
+    build_file = "@//third_party/tensorrt/archive:BUILD",
+    sha256 = "c7d73b2585b18aae68b740249efa8c8ba5ae852abe9a023720595432a8eb4efd",
+    strip_prefix = "TensorRT-7.0.0.11"
+)

+## Locally installed dependencies
+# new_local_repository(
+#    name = "cudnn",
+#    path = "/usr/",
+#    build_file = "@//third_party/cudnn/local:BUILD"
+#)
+
+# new_local_repository(
+#    name = "tensorrt",
+#    path = "/usr/",
+#    build_file = "@//third_party/tensorrt/local:BUILD"
+#)

git_repository(
name = "googletest",
remote = "https://github.com/google/googletest",
21 changes: 13 additions & 8 deletions core/compiler.cpp
@@ -7,12 +7,11 @@

#include "ATen/core/function_schema.h"

#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/pass_manager.h"
#include "torch/csrc/jit/frontend/function_schema_parser.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/passes/pass_manager.h"
#include "torch/csrc/jit/passes/lower_graph.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/script/module.h"
#include "torch/csrc/jit/script/function_schema_parser.h"

#include "core/util/prelude.h"
#include "core/compiler.h"
@@ -42,25 +41,31 @@ c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::str

void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
auto schema = execution::GetEngineFunctionSchema(uid);
auto num_io = execution::GetEngineIO(uid);

auto self = g->addInput("self.1");
self->setType(mod.type());
std::vector<torch::jit::Value*> graph_inputs;

auto id_val = g->insertConstant(uid);

std::vector<torch::jit::Value*> engine_inputs;
engine_inputs.push_back(id_val);

for (uint64_t i = 0; i < num_io.first; i++) {
auto in_val = g->addInput("");
in_val->setType(c10::TensorType::get());
graph_inputs.push_back(in_val);
engine_inputs.push_back(in_val);
}

auto engine_node = g->create(c10::Symbol::fromQualString(schema.name()), torch::jit::ArrayRef<torch::jit::Value*>(graph_inputs), num_io.second);
auto engine_node = g->create(c10::Symbol::fromQualString("trt::execute_engine"), torch::jit::ArrayRef<torch::jit::Value*>(engine_inputs), num_io.second);
g->block()->appendNode(engine_node);

for (auto o : engine_node->outputs()) {
g->registerOutput(o);
}

LOG_DEBUG(*g << "(AddEngineToGraph)\n");

return;
}

2 changes: 1 addition & 1 deletion core/compiler.h
@@ -1,7 +1,7 @@
#pragma once

#include <vector>
#include "torch/csrc/jit/script/module.h"
#include "torch/csrc/jit/api/module.h"
#include "core/conversion/conversion.h"

namespace trtorch {
4 changes: 1 addition & 3 deletions core/conversion/conversion.cpp
@@ -14,9 +14,7 @@ namespace conversion {
bool isNodeConversionBlacklisted(const torch::jit::Node* n);

bool OpSupported(const torch::jit::Node* n) {
-bool evalable = evaluators::shouldEvalAtConversionTime(n);
-bool convertable = converters::node_is_convertable(n);
-return evalable || convertable;
+return evaluators::shouldEvalAtConversionTime(n) || converters::node_is_convertable(n);
}

c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::jit::Node* n, int level=0, int limit=10) {
2 changes: 1 addition & 1 deletion core/conversion/conversion.h
@@ -3,7 +3,7 @@
#include <map>

#include "NvInfer.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/ir/ir.h"
#include "core/conversion/conversionctx/ConversionCtx.h"

namespace torch {
4 changes: 2 additions & 2 deletions core/conversion/conversion_blacklist.cpp
@@ -1,12 +1,12 @@
#include <string>
#include <unordered_set>

#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/ir/ir.h"

namespace trtorch {
namespace core {
namespace conversion {

const std::unordered_set<std::string>& get_non_convertable_nodes() {
// Set of nodes that should not invoke a converter or evaluator
static std::unordered_set<std::string> nonconvertable_nodes = {
2 changes: 1 addition & 1 deletion core/conversion/conversionctx/ConversionCtx.h
@@ -5,7 +5,7 @@
#include <memory>

//#include "ATen/ATen.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/ir/ir.h"
#include "NvInfer.h"

#include "core/util/prelude.h"
2 changes: 1 addition & 1 deletion core/conversion/converters/NodeConverterRegistry.cpp
@@ -1,6 +1,6 @@
#include "core/util/prelude.h"
#include "core/conversion/converters/converters.h"
#include "torch/csrc/jit/script/function_schema_parser.h"
#include "torch/csrc/jit/frontend/function_schema_parser.h"

namespace trtorch {
namespace core {
2 changes: 1 addition & 1 deletion core/conversion/converters/converters.h
@@ -3,7 +3,7 @@
#include <string>
#include <map>

#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/jit/runtime/custom_operator.h"
#include "ATen/core/function_schema.h"

#include "core/conversion/conversionctx/ConversionCtx.h"
6 changes: 3 additions & 3 deletions core/conversion/evaluators/NodeEvaluatorRegistry.cpp
@@ -1,7 +1,7 @@
#include <unordered_map>

#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/constants.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/constants.h"
#include "ATen/core/functional.h"
#include "ATen/core/ivalue.h"
#include "ATen/core/List.h"
@@ -41,7 +41,7 @@ class NodeEvaluatorRegistry {
return true;
}
}

private:
EvaluatorLUT evaluator_lut_;
};
4 changes: 2 additions & 2 deletions core/conversion/evaluators/evaluators.h
@@ -3,7 +3,7 @@
#include <string>
#include <map>

#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/ir/ir.h"

namespace trtorch {
namespace core {
@@ -19,7 +19,7 @@ typedef std::map<const torch::jit::Value*, const torch::jit::IValue*> kwargs;
// when writing evaluators
typedef std::function<c10::optional<torch::jit::IValue>(const torch::jit::Node*, const kwargs&)> NodeEvaluator;

-struct EvalRegistration {
+struct EvalRegistration {
torch::jit::NodeKind kind;
NodeEvaluator evaluator;
};
4 changes: 2 additions & 2 deletions core/conversion/evaluators/prim.cpp
@@ -1,5 +1,5 @@
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/constants.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/constants.h"
#include "ATen/core/functional.h"
#include "ATen/core/ivalue.h"
#include "ATen/core/List.h"
3 changes: 2 additions & 1 deletion core/execution/BUILD
@@ -14,7 +14,8 @@ cc_library(
"@tensorrt//:nvinfer",
"@libtorch//:libtorch",
"//core/util:prelude"
-]
+],
+alwayslink = True,
)

load("@rules_pkg//:pkg.bzl", "pkg_tar")
