From a9f33e4d6645d6d87844f38d4c6a03257e45204b Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Sun, 3 May 2020 18:43:44 -0700 Subject: [PATCH] fix(//core/conversion/converters/impl/element_wise): Fix broadcast support Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- .../converters/impl/element_wise.cpp | 30 +++++++++++-------- tests/accuracy/accuracy_test.h | 1 + tests/accuracy/test_fp16_accuracy.cpp | 4 +-- tests/accuracy/test_fp32_accuracy.cpp | 4 +-- tests/accuracy/test_int8_accuracy.cpp | 4 +-- 5 files changed, 24 insertions(+), 19 deletions(-) diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp index 80d5cdbe93..375e7a2d8f 100644 --- a/core/conversion/converters/impl/element_wise.cpp +++ b/core/conversion/converters/impl/element_wise.cpp @@ -8,18 +8,22 @@ namespace converters { namespace impl { namespace { -nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOperation op, nvinfer1::ITensor* self, nvinfer1::ITensor* other, float scalar=1) { +nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOperation op, nvinfer1::ITensor* self, nvinfer1::ITensor* other, const std::string& name, float scalar=1) { auto self_dims = self->getDimensions(); + auto self_dims_vec = util::toVec(self_dims); auto other_dims = other->getDimensions(); + auto other_dims_vec = util::toVec(other_dims); + auto other_batch = other_dims_vec[0]; - TRTORCH_CHECK(util::volume(self_dims) == util::volume(other_dims), "Found inputs to elementwise operation do not have the same number of elements:\n Found: self " << self_dims << " other " << other_dims); + // TODO: Proper broadcast check + TRTORCH_CHECK(util::volume(self_dims) == util::volume(other_dims) || util::volume(self_dims) == util::volume(other_dims) / other_batch, "Found inputs to elementwise operation do not have the same number of elements or is not broadcastable:\n Found: self " << self_dims << " other " << other_dims); if (self_dims != other_dims) { LOG_DEBUG("Input shape dont match inserting shuffle layers to reshape to " << self_dims); - auto other_shuffle = ctx->net->addShuffle(*other); - other_shuffle->setReshapeDimensions(self_dims); - other_shuffle->setName(std::string("[Reshape other to " + util::toStr(self_dims) + ']').c_str()); - other = other_shuffle->getOutput(0); + auto self_shuffle = ctx->net->addShuffle(*self); + self_shuffle->setReshapeDimensions(util::toDimsPad(self_dims_vec, other_dims_vec.size())); + self_shuffle->setName(std::string("[Reshape self to " + util::toStr(self_dims) + " for broadcasting (" + name + ")]").c_str()); + self = self_shuffle->getOutput(0); } @@ -72,7 +76,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns() auto self = args[0].ITensor(); auto other = args[1].ITensor(); auto scalar = args[2].unwrapToScalar().to(); - auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, scalar); + auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, util::node_info(n), scalar); TRTORCH_CHECK(add, "Unable to create add layer from node: " << *n); @@ -89,7 +93,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns() auto self = args[0].ITensor(); auto other = args[1].ITensor(); auto scalar = args[2].unwrapToScalar().to(); - auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, scalar); + auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, util::node_info(n), scalar); TRTORCH_CHECK(add, "Unable to create add layer from node: " << *n); @@ -106,7 +110,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns() auto self = args[0].ITensor(); auto other = args[1].ITensor(); auto scalar = args[2].unwrapToScalar().to(); - auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, other, scalar); + auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, other, util::node_info(n), scalar); TRTORCH_CHECK(sub, "Unable to create sub layer from node: " << *n); @@ -122,7 +126,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns() // Should implement self / other auto self = args[0].ITensor(); auto other = args[1].ITensor(); - auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other); + auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n)); TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n); @@ -138,7 +142,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns() // TODO: Remove with functionalization auto self = args[0].ITensor(); auto other = args[1].ITensor(); - auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other); + auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n)); TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n); @@ -154,7 +158,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns() // Should implement self * other auto self = args[0].ITensor(); auto other = args[1].ITensor(); - auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other); + auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n)); TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n); @@ -170,7 +174,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns() // TODO: Remove with functionalization auto self = args[0].ITensor(); auto other = args[1].ITensor(); - auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other); + auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n)); TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n); diff --git a/tests/accuracy/accuracy_test.h b/tests/accuracy/accuracy_test.h index 229608de6d..28cfe9cdca 100644 --- a/tests/accuracy/accuracy_test.h +++ b/tests/accuracy/accuracy_test.h @@ -20,6 +20,7 @@ class AccuracyTests std::cerr << "error loading the model\n"; return; } + mod.eval(); } void TearDown() { diff --git a/tests/accuracy/test_fp16_accuracy.cpp b/tests/accuracy/test_fp16_accuracy.cpp index 7ebcc8b0fb..6de40a6c31 100644 --- a/tests/accuracy/test_fp16_accuracy.cpp +++ b/tests/accuracy/test_fp16_accuracy.cpp @@ -24,7 +24,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { jit_total += targets.sizes()[0]; jit_correct += torch::sum(torch::eq(predictions, targets)); } - torch::Tensor jit_accuracy = jit_correct / jit_total; + torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100; std::vector> input_shape = {{32, 3, 32, 32}}; auto extra_info = trtorch::ExtraInfo({input_shape}); @@ -45,7 +45,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { trt_correct += torch::sum(torch::eq(predictions, targets)); } - torch::Tensor trt_accuracy = trt_correct / trt_total; + torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100; ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3)); } diff --git a/tests/accuracy/test_fp32_accuracy.cpp b/tests/accuracy/test_fp32_accuracy.cpp index b014340e82..d3d8bddb96 100644 --- a/tests/accuracy/test_fp32_accuracy.cpp +++ b/tests/accuracy/test_fp32_accuracy.cpp @@ -24,7 +24,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { jit_total += targets.sizes()[0]; jit_correct += torch::sum(torch::eq(predictions, targets)); } - torch::Tensor jit_accuracy = jit_correct / jit_total; + torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100; std::vector> input_shape = {{32, 3, 32, 32}}; auto extra_info = trtorch::ExtraInfo({input_shape}); @@ -45,7 +45,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { trt_correct += torch::sum(torch::eq(predictions, targets)); } - torch::Tensor trt_accuracy = trt_correct / trt_total; + torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100; ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3)); } diff --git a/tests/accuracy/test_int8_accuracy.cpp b/tests/accuracy/test_int8_accuracy.cpp index 07d399c96d..aa4824948a 100644 --- a/tests/accuracy/test_int8_accuracy.cpp +++ b/tests/accuracy/test_int8_accuracy.cpp @@ -54,7 +54,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { jit_total += targets.sizes()[0]; jit_correct += torch::sum(torch::eq(predictions, targets)); } - torch::Tensor jit_accuracy = jit_correct / jit_total; + torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100; // Compile Graph auto trt_mod = trtorch::CompileGraph(mod, extra_info); @@ -72,7 +72,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { trt_total += targets.sizes()[0]; trt_correct += torch::sum(torch::eq(predictions, targets)).item().toFloat(); } - torch::Tensor trt_accuracy = trt_correct / trt_total; + torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100; ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3)); }