Merge pull request #1029 from blchu/bitwise_not
feat (//core/conversion) : Add converter for torch.bitwise_not
peri044 authored Jun 23, 2022
2 parents 7d84caf + e699800 commit 28bce22
Showing 5 changed files with 105 additions and 1 deletion.
1 change: 1 addition & 0 deletions core/conversion/converters/BUILD
@@ -54,6 +54,7 @@ cc_library(
"NodeConverterRegistry.cpp",
"impl/activation.cpp",
"impl/batch_norm.cpp",
"impl/bitwise.cpp",
"impl/cast.cpp",
"impl/concat.cpp",
"impl/constant.cpp",
55 changes: 55 additions & 0 deletions core/conversion/converters/impl/bitwise.cpp
@@ -0,0 +1,55 @@
#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"

#include <torch/torch.h>

namespace torch_tensorrt {
namespace core {
namespace conversion {
namespace converters {
namespace impl {

auto bitwise_not_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
    {"aten::bitwise_not(Tensor self) -> Tensor",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       auto in = args[0].ITensorOrFreeze(ctx);
       nvinfer1::ILayer* out;

       if (in->getType() == nvinfer1::DataType::kINT32) {
         // Integer case: TensorRT's kNOT unary only accepts boolean tensors,
         // so build ~x from the two's-complement identity ~x = -x - 1
         auto neg_one = torch::tensor({-1}, util::TRTDataTypeToScalarType(in->getType()));
         auto neg_one_const = tensor_to_const(ctx, neg_one);
         auto neg = add_elementwise(
             ctx,
             nvinfer1::ElementWiseOperation::kPROD,
             in,
             neg_one_const,
             util::node_info(n) + std::string("_Negation"));
         TORCHTRT_CHECK(neg, "Unable to create prod layer from node: " << *n);
         out = add_elementwise(
             ctx,
             nvinfer1::ElementWiseOperation::kSUM,
             neg->getOutput(0),
             neg_one_const,
             util::node_info(n) + std::string("_SubOne"));
         TORCHTRT_CHECK(out, "Unable to create sum layer from node: " << *n);
       } else if (in->getType() == nvinfer1::DataType::kBOOL) {
         // Boolean case: a plain logical-not unary layer suffices
         out = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kNOT);
         TORCHTRT_CHECK(out, "Unable to create logical not layer from node: " << *n);
       } else {
         LOG_ERROR("Input tensor must be 32-bit integer or boolean");
         return false;
       }

       out->setName(util::node_info(n).c_str());
       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out->getOutput(0));
       LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

       return true;
     }});

} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace torch_tensorrt
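
The integer path above leans on the two's-complement identity ~x = -x - 1, realized as a kPROD (multiply by -1) followed by a kSUM (add -1), since TensorRT's kNOT unary is boolean-only. A minimal standalone libtorch sketch checking that identity (hypothetical verification code, not part of this commit):

#include <torch/torch.h>
#include <cassert>

int main() {
  // For two's-complement integers, ~x == -x - 1: the same
  // negate-then-subtract-one formulation the converter emits.
  auto x = torch::randint(-128, 128, {10}, torch::kInt32);
  auto via_identity = x * -1 + (-1);
  assert(torch::equal(torch::bitwise_not(x), via_identity));
  return 0;
}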
5 changes: 5 additions & 0 deletions tests/core/conversion/converters/BUILD
@@ -15,6 +15,10 @@ converter_test(
name = "test_batch_norm",
)

converter_test(
name = "test_bitwise",
)

converter_test(
name = "test_instance_norm",
)
@@ -136,6 +140,7 @@ test_suite(
tests = [
":test_activation",
":test_batch_norm",
":test_bitwise",
":test_instance_norm",
":test_cast",
":test_clone",
42 changes: 42 additions & 0 deletions tests/core/conversion/converters/test_bitwise.cpp
@@ -0,0 +1,42 @@
#include <string>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

std::string gen_test_graph() {
return R"IR(
graph(%0: Tensor):
%3 : Tensor = aten::bitwise_not(%0)
return (%3))IR";
}

#define test_bitwise_not(dtype)                                                       \
  TEST(Converters, ATenBitwiseNot##dtype##ConvertsCorrectly) {                        \
    const auto graph = gen_test_graph();                                              \
                                                                                      \
    auto g = std::make_shared<torch::jit::Graph>();                                   \
    torch::jit::parseIR(graph, g.get());                                              \
                                                                                      \
    at::Tensor in;                                                                    \
    if (strcmp(#dtype, "Integer") == 0)                                               \
      in = at::randint(-128, 128, {10}, {at::kCUDA}).toType(at::kInt);                \
    if (strcmp(#dtype, "Boolean") == 0)                                               \
      /* randint's upper bound is exclusive; use 2 so both 0 and 1 occur */           \
      in = at::randint(0, 2, {10}, {at::kCUDA}).toType(at::kBool);                    \
    auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});       \
    auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});        \
                                                                                      \
    in = at::clone(in);                                                               \
    params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});            \
    auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});  \
                                                                                      \
    auto jit_int = jit_results[0].toType(at::kInt);                                   \
    auto trt_int = trt_results[0].toType(at::kInt);                                   \
                                                                                      \
    ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_int, trt_int));         \
  }

test_bitwise_not(Integer);
test_bitwise_not(Boolean);

#undef test_bitwise_not
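
The tests are stamped out with a token-pasting macro: ##dtype splices the argument into the test name, while #dtype stringizes it for the runtime strcmp dispatch. As a rough sketch (hypothetical expansion, abbreviated), test_bitwise_not(Integer) produces:

TEST(Converters, ATenBitwiseNotIntegerConvertsCorrectly) {
  const auto graph = gen_test_graph();
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());
  // strcmp("Integer", "Integer") == 0, so the integer branch is taken
  at::Tensor in = at::randint(-128, 128, {10}, {at::kCUDA}).toType(at::kInt);
  // ... run through JIT and TensorRT, then compare the results as kInt ...
}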
3 changes: 2 additions & 1 deletion tests/util/run_graph_engine.cpp
@@ -4,6 +4,7 @@
#include "core/ir/ir.h"
#include "core/runtime/runtime.h"
#include "core/util/prelude.h"
#include "core/util/trt_util.h"
#include "cuda_runtime_api.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"
@@ -19,7 +20,7 @@ namespace util {
 std::vector<core::ir::Input> toInputs(std::vector<at::Tensor> ten) {
   std::vector<core::ir::Input> a;
   for (auto i : ten) {
-    a.push_back(core::ir::Input(core::util::toVec(i.sizes())));
+    a.push_back(core::ir::Input(core::util::toVec(i.sizes()), core::util::ScalarTypeToTRTDataType(i.scalar_type())));
   }
   return a;
 }
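
Previously toInputs forwarded only each tensor's shape, so the input spec fell back to the default (float) dtype; forwarding the scalar type as well lets the new boolean and 32-bit integer bitwise tests build engines with correctly typed inputs. A small usage sketch (hypothetical, assuming a CUDA tensor t):

at::Tensor t = at::randint(0, 2, {10}, {at::kCUDA}).toType(at::kBool);
// ScalarTypeToTRTDataType maps at::kBool -> nvinfer1::DataType::kBOOL
auto spec = core::ir::Input(
    core::util::toVec(t.sizes()),
    core::util::ScalarTypeToTRTDataType(t.scalar_type()));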
