Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(//core/converters): Handle dtype mismatch in elementwise ops #1238

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions core/conversion/converters/converter_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,6 @@ nvinfer1::ILayer* add_elementwise(
nvinfer1::ITensor* self,
nvinfer1::ITensor* other,
const std::string& name) {
if (self->getType() == nvinfer1::DataType::kFLOAT && other->getType() == nvinfer1::DataType::kINT32) {
LOG_DEBUG("Type mismatch, casting other to " << self->getType());
other = castITensor(ctx, other, self->getType());
} else if (self->getType() == nvinfer1::DataType::kINT32 && other->getType() == nvinfer1::DataType::kFLOAT) {
LOG_DEBUG("Type mismatch, casting self to " << other->getType());
self = castITensor(ctx, self, other->getType());
}
// ensure self to have larger number of dimension
bool swapSelfOther = false;
if (self->getDimensions().nbDims < other->getDimensions().nbDims) {
Expand Down Expand Up @@ -123,6 +116,19 @@ nvinfer1::ILayer* add_elementwise(
std::swap(self, other);
swapSelfOther = false;
}

// Two types are compatible if they are the same type or are both in the set {kFLOAT, kHALF}
auto fp32_and_fp16 = (self->getType() == nvinfer1::DataType::kFLOAT) && (other->getType() == nvinfer1::DataType::kHALF);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not handle INT32 which was causing tests to fail earlier

auto fp16_and_fp32 = (other->getType() == nvinfer1::DataType::kFLOAT) && (self->getType() == nvinfer1::DataType::kHALF);
if (!fp32_and_fp16 && !fp16_and_fp32){
if (self->getType() > other->getType()) {
LOG_DEBUG("Type mismatch in node : " << name << ", casting self to " << other->getType());
self = castITensor(ctx, self, other->getType());
} else if (self->getType() < other->getType()) {
LOG_DEBUG("Type mismatch in node : " << name << ", casting other to " << self->getType());
other = castITensor(ctx, other, self->getType());
}
}
auto ele = ctx->net->addElementWise(*self, *other, op);
ele->setName(name.c_str());
return ele;
Expand Down