From c53818488218cf3dedaa1dd44f8971f08085d40c Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Sat, 6 Aug 2022 23:34:02 -0700 Subject: [PATCH 1/3] fix: Handle different datatypes for elementwise ops Signed-off-by: Dheeraj Peri --- core/conversion/converters/converter_util.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/conversion/converters/converter_util.cpp b/core/conversion/converters/converter_util.cpp index 745261589e..59c0f026fe 100644 --- a/core/conversion/converters/converter_util.cpp +++ b/core/conversion/converters/converter_util.cpp @@ -116,6 +116,11 @@ nvinfer1::ILayer* add_elementwise( std::swap(self, other); swapSelfOther = false; } + if (self->getType() > other->getType()) { + self = castITensor(ctx, self, other->getType()); + } else if (self->getType() < other->getType()) { + other = castITensor(ctx, other, self->getType()); + } auto ele = ctx->net->addElementWise(*self, *other, op); ele->setName(name.c_str()); return ele; From afe6aa56df603c1d2894751d335434c3d6fb1569 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Sat, 6 Aug 2022 23:44:03 -0700 Subject: [PATCH 2/3] fix: Fix datatype mismatch issue in elementwise ops Signed-off-by: Dheeraj Peri --- core/conversion/converters/converter_util.cpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/core/conversion/converters/converter_util.cpp b/core/conversion/converters/converter_util.cpp index 59c0f026fe..ee8a5894f7 100644 --- a/core/conversion/converters/converter_util.cpp +++ b/core/conversion/converters/converter_util.cpp @@ -116,10 +116,16 @@ nvinfer1::ILayer* add_elementwise( std::swap(self, other); swapSelfOther = false; } - if (self->getType() > other->getType()) { - self = castITensor(ctx, self, other->getType()); - } else if (self->getType() < other->getType()) { - other = castITensor(ctx, other, self->getType()); + + // Two types are compatible if they are the same type or are both in the set {kFLOAT, kHALF} + auto fp32_and_fp16 = (self->getType() == nvinfer1::DataType::kFLOAT) && (other->getType() == nvinfer1::DataType::kHALF); + auto fp16_and_fp32 = (other->getType() == nvinfer1::DataType::kFLOAT) && (self->getType() == nvinfer1::DataType::kHALF); + if (!fp32_and_fp16 && !fp16_and_fp32){ + if (self->getType() > other->getType()) { + self = castITensor(ctx, self, other->getType()); + } else if (self->getType() < other->getType()) { + other = castITensor(ctx, other, self->getType()); + } } auto ele = ctx->net->addElementWise(*self, *other, op); ele->setName(name.c_str()); From 9194d19fc14b62c486cc3af497f05797b737ed1c Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 9 Aug 2022 12:42:14 -0700 Subject: [PATCH 3/3] chore: refactor dtype handling in elementwise ops Signed-off-by: Dheeraj Peri --- core/conversion/converters/converter_util.cpp | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/core/conversion/converters/converter_util.cpp b/core/conversion/converters/converter_util.cpp index 7bca1a9ac0..68bfa933c1 100644 --- a/core/conversion/converters/converter_util.cpp +++ b/core/conversion/converters/converter_util.cpp @@ -65,13 +65,6 @@ nvinfer1::ILayer* add_elementwise( nvinfer1::ITensor* self, nvinfer1::ITensor* other, const std::string& name) { - if (self->getType() == nvinfer1::DataType::kFLOAT && other->getType() == nvinfer1::DataType::kINT32) { - LOG_DEBUG("Type mismatch, casting other to " << self->getType()); - other = castITensor(ctx, other, self->getType()); - } else if (self->getType() == nvinfer1::DataType::kINT32 && other->getType() == nvinfer1::DataType::kFLOAT) { - LOG_DEBUG("Type mismatch, casting self to " << other->getType()); - self = castITensor(ctx, self, other->getType()); - } // ensure self to have larger number of dimension bool swapSelfOther = false; if (self->getDimensions().nbDims < other->getDimensions().nbDims) { @@ -129,8 +122,10 @@ nvinfer1::ILayer* add_elementwise( auto fp16_and_fp32 = (other->getType() == nvinfer1::DataType::kFLOAT) && (self->getType() == nvinfer1::DataType::kHALF); if (!fp32_and_fp16 && !fp16_and_fp32){ if (self->getType() > other->getType()) { + LOG_DEBUG("Type mismatch in node : " << name << ", casting self to " << other->getType()); self = castITensor(ctx, self, other->getType()); } else if (self->getType() < other->getType()) { + LOG_DEBUG("Type mismatch in node : " << name << ", casting other to " << self->getType()); other = castITensor(ctx, other, self->getType()); } }