feat(//core/converters): Add power layer conversion support and minor README edits

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
peri044 committed Oct 5, 2020
1 parent e4a4574 commit a801506
Showing 4 changed files with 88 additions and 19 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -115,10 +115,10 @@ then you have two options.
1. You need to download the tarball distributions of TensorRT and cuDNN from the NVIDIA website.
- https://developer.nvidia.com/cudnn
- https://developer.nvidia.com/tensorrt
2. Place these files in a directory (the directories `third_party/distdir/[x86_64-linux-gnu | aarch64-linux-gnu]` exist for this purpose)
2. Place these files in a directory (the directories `third_party/dist_dir/[x86_64-linux-gnu | aarch64-linux-gnu]` exist for this purpose)
3. Compile using:
``` shell
bazel build //:libtrtorch --compilation_mode opt --distdir third_party/distdir/[x86_64-linux-gnu | aarch64-linux-gnu]
bazel build //:libtrtorch --compilation_mode opt --distdir third_party/dist_dir/[x86_64-linux-gnu | aarch64-linux-gnu]
```

#### 2. Building using locally installed cuDNN & TensorRT
@@ -175,7 +175,7 @@ bazel build //:libtrtorch --compilation_mode=dbg

### Native compilation on NVIDIA Jetson AGX
``` shell
bazel build //:libtrtorch --distdir third_party/distdir/aarch64-linux-gnu
bazel build //:libtrtorch --distdir third_party/dist_dir/aarch64-linux-gnu
```
> Note: Please refer to the [installation](docs/tutorials/installation.html) instructions for prerequisites
45 changes: 45 additions & 0 deletions core/conversion/converters/impl/element_wise.cpp
@@ -1,3 +1,4 @@
#include <torch/torch.h>
#include "core/util/prelude.h"
#include "core/conversion/converters/converters.h"

@@ -180,6 +181,50 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
mul->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], mul->getOutput(0));

LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}
}).pattern({
"aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// TODO: Remove with functionalization
auto self = args[0].ITensorOrFreeze(ctx);
auto exponent = args[1].ITensorOrFreeze(ctx);
auto pow = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPOW, self, exponent, util::node_info(n));
TRTORCH_CHECK(pow, "Unable to create Power layer from node: " << *n);

pow->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], pow->getOutput(0));

LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}
}).pattern({
"aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
auto exponentScalar = args[1].unwrapToScalar().to<float>();

// Calculate size of the input and define an exponent tensor of the same size
int volume = 1;
for (int i = 0; i < self->getDimensions().nbDims; i++) {
volume = volume * (self->getDimensions().d[i]);
}

// Create a torch tensor with constant exponent values
LOG_DEBUG("Broadcasting the exponent in power layer");
torch::Tensor exponentBlob = torch::full({volume}, exponentScalar);

// Create a corresponding constant layer in TRT and get the layer output.
auto weights = converters::Weights(ctx, exponentBlob);
auto exponentTensor = ctx->net->addConstant(self->getDimensions(), weights.data)->getOutput(0);

auto pow = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPOW, self, exponentTensor, util::node_info(n));
TRTORCH_CHECK(pow, "Unable to create Power layer from node: " << *n);

pow->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], pow->getOutput(0));

LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}
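
The two new patterns cover the `aten::pow` overloads TorchScript emits for tensor and scalar exponents; the scalar overload is handled by broadcasting the exponent into a TensorRT constant with the input's dimensions. Below is a minimal, illustrative sketch of Python code that should exercise both converters when compiled through TRTorch — the module, shapes, and values are placeholders, and the compile-spec keys mirror the `py/README.md` example further down in this commit.

```python
import torch
import trtorch

class PowExample(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.pow(x, y)  # should lower to aten::pow.Tensor_Tensor
        b = x ** 2           # should lower to aten::pow.Tensor_Scalar
        return a + b

scripted = torch.jit.script(PowExample().eval().cuda())

# Compile-spec keys follow the py/README.md example in this commit;
# the input shapes are illustrative only.
trt_mod = trtorch.compile(scripted, {
    "input_shapes": [[1, 3, 8, 8], [1, 3, 8, 8]],
    "op_precision": torch.float,
})

x = torch.rand(1, 3, 8, 8).cuda()
y = torch.full((1, 3, 8, 8), 2.0, device="cuda")
print(trt_mod(x, y).shape)
```

Note that the scalar-exponent path materializes a constant tensor with the same volume as the input, which is a simple way to reuse the elementwise kPOW layer at the cost of some extra constant memory.
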
2 changes: 1 addition & 1 deletion py/README.md
@@ -22,7 +22,7 @@ traced_model = torch.jit.trace(model, [data])

# Compile module
compiled_trt_model = trtorch.compile(model, {
"input_shape": [data.shape],
"input_shapes": [data.shape],
"op_precision": torch.half, # Run in FP16
})

54 changes: 39 additions & 15 deletions tests/core/converters/test_element_wise.cpp
@@ -4,32 +4,39 @@
#include "tests/util/util.h"
#include "core/compiler.h"

void pointwise_test_helper(std::string graph_ir) {
void pointwise_test_helper(std::string graph_ir, bool singleInput) {
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph_ir, &*g);

auto in0 = at::randint(1, 5, {5}, {at::kCUDA});
auto in1 = at::randint(1, 5, {5}, {at::kCUDA});

// singleInput case is enabled when elementwise operation is performed
// with an input and a constant embedded in graph
std::vector<at::Tensor> torch_inputs;
torch_inputs.push_back(at::randint(1, 5, {5}, {at::kCUDA}));
if (!singleInput) {
torch_inputs.push_back(at::randint(1, 5, {5}, {at::kCUDA}));
}
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in0, in1});
auto jit_results = trtorch::tests::util::RunGraph(g, params, torch_inputs);

std::vector<at::Tensor> trt_inputs;
for (auto in : torch_inputs) {
trt_inputs.push_back(at::clone(in));
}

in0 = at::clone(in0);
in1 = at::clone(in1);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in0, in1});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, trt_inputs);

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}



TEST(Converters, ATenAddConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%2 : int = prim::Constant[value=1]()
%3 : Tensor = aten::add(%0, %1, %2)
return (%3))IR";
pointwise_test_helper(graph);
pointwise_test_helper(graph, false);
}


@@ -39,7 +46,7 @@ TEST(Converters, ATenAddConvertsCorrectly) {
// %2 : int = prim::Constant[value=2]()
// %3 : Tensor = aten::add(%0, %1, %2)
// return (%3))IR";
// pointwise_test_helper(graph);
// pointwise_test_helper(graph, false);
// }

TEST(Converters, ATenSubConvertsCorrectly) {
@@ -48,7 +55,7 @@ TEST(Converters, ATenSubConvertsCorrectly) {
%2 : int = prim::Constant[value=1]()
%3 : Tensor = aten::sub(%0, %1, %2)
return (%3))IR";
pointwise_test_helper(graph);
pointwise_test_helper(graph, false);
}

// TEST(Converters, ATenSubWithScaleConvertsCorrectly) {
@@ -57,21 +64,38 @@ TEST(Converters, ATenSubConvertsCorrectly) {
// %2 : float = prim::Constant[value=0.5]()
// %3 : Tensor = aten::add(%0, %1, %2)
// return (%3))IR";
// pointwise_test_helper(graph);
// pointwise_test_helper(graph, false);
// }

TEST(Converters, ATenMulConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%2 : Tensor = aten::mul(%0, %1)
return (%2))IR";
pointwise_test_helper(graph);
pointwise_test_helper(graph, false);
}

TEST(Converters, ATenDivConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%2 : Tensor = aten::div(%0, %1)
return (%2))IR";
pointwise_test_helper(graph);
pointwise_test_helper(graph, false);
}

TEST(Converters, ATenPowTensorConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor, %x2.1 : Tensor):
%3 : Tensor = aten::pow(%x.1, %x2.1)
return (%3))IR";
pointwise_test_helper(graph, false);
}

TEST(Converters, ATenPowScalarConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=2]()
%3 : Tensor = aten::pow(%x.1, %2)
return (%3))IR";
pointwise_test_helper(graph, true);
}
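
If the IR strings in these tests look unfamiliar, one way to see where they come from is to script the equivalent Python functions and print their TorchScript graphs; the sketch below is illustrative only and is not part of the test suite or this commit.

```python
import torch

@torch.jit.script
def pow_scalar(x: torch.Tensor) -> torch.Tensor:
    return torch.pow(x, 2)

@torch.jit.script
def pow_tensor(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.pow(x, y)

# Each printed graph should contain an aten::pow node analogous to the IR
# parsed by ATenPowScalarConvertsCorrectly and ATenPowTensorConvertsCorrectly.
print(pow_scalar.graph)
print(pow_tensor.graph)
```
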
