diff --git a/core/conversion/converters/converter_util.cpp b/core/conversion/converters/converter_util.cpp index 94ac827ef4..68bfa933c1 100644 --- a/core/conversion/converters/converter_util.cpp +++ b/core/conversion/converters/converter_util.cpp @@ -65,13 +65,6 @@ nvinfer1::ILayer* add_elementwise( nvinfer1::ITensor* self, nvinfer1::ITensor* other, const std::string& name) { - if (self->getType() == nvinfer1::DataType::kFLOAT && other->getType() == nvinfer1::DataType::kINT32) { - LOG_DEBUG("Type mismatch, casting other to " << self->getType()); - other = castITensor(ctx, other, self->getType()); - } else if (self->getType() == nvinfer1::DataType::kINT32 && other->getType() == nvinfer1::DataType::kFLOAT) { - LOG_DEBUG("Type mismatch, casting self to " << other->getType()); - self = castITensor(ctx, self, other->getType()); - } // ensure self to have larger number of dimension bool swapSelfOther = false; if (self->getDimensions().nbDims < other->getDimensions().nbDims) { @@ -123,6 +116,19 @@ nvinfer1::ILayer* add_elementwise( std::swap(self, other); swapSelfOther = false; } + + // Two types are compatible if they are the same type or are both in the set {kFLOAT, kHALF} + auto fp32_and_fp16 = (self->getType() == nvinfer1::DataType::kFLOAT) && (other->getType() == nvinfer1::DataType::kHALF); + auto fp16_and_fp32 = (other->getType() == nvinfer1::DataType::kFLOAT) && (self->getType() == nvinfer1::DataType::kHALF); + if (!fp32_and_fp16 && !fp16_and_fp32){ + if (self->getType() > other->getType()) { + LOG_DEBUG("Type mismatch in node : " << name << ", casting self to " << other->getType()); + self = castITensor(ctx, self, other->getType()); + } else if (self->getType() < other->getType()) { + LOG_DEBUG("Type mismatch in node : " << name << ", casting other to " << self->getType()); + other = castITensor(ctx, other, self->getType()); + } + } auto ele = ctx->net->addElementWise(*self, *other, op); ele->setName(name.c_str()); return ele;