Skip to content

Commit

Permalink
fix: fix aten::sub.scalar operator
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
  • Loading branch information
peri044 committed Jul 29, 2021
1 parent 2760b8d commit 9a09514
Showing 1 changed file with 9 additions and 16 deletions.
25 changes: 9 additions & 16 deletions core/conversion/converters/impl/element_wise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,23 +191,16 @@ auto element_wise_registrations TRTORCH_UNUSED =
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].unwrapToScalar().to<float>();
auto alpha = args[2].unwrapToScalar().to<float>();
auto scaled_val = other * alpha;

auto rhs = other * alpha;
if (1 != rhs) {
auto rhs_tensor = tensor_to_const(ctx, torch::tensor({rhs}));
auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, rhs_tensor, util::node_info(n));
TRTORCH_CHECK(sub, "Unable to create sub layer from node: " << *n);
sub->setName(util::node_info(n).c_str());
LOG_DEBUG("Output tensor shape: " << sub->getOutput(0)->getDimensions());
ctx->AssociateValueAndTensor(n->outputs()[0], sub->getOutput(0));
return true;
} else {
LOG_DEBUG("Nothing to be done this layer, passing through input");
LOG_DEBUG("Output tensor shape: " << self->getDimensions());

ctx->AssociateValueAndTensor(n->outputs()[0], self);
return true;
}
auto scaled_other_tensor = tensor_to_const(ctx, torch::tensor({scaled_val}));
auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, scaled_other_tensor, util::node_info(n));
TRTORCH_CHECK(sub, "Unable to create sub layer from node: " << *n);
sub->setName(util::node_info(n).c_str());
LOG_DEBUG("Output tensor shape: " << sub->getOutput(0)->getDimensions());
ctx->AssociateValueAndTensor(n->outputs()[0], sub->getOutput(0));

return true;
}})
.pattern({"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar "
"alpha=1) -> (Tensor(a!))",
Expand Down

0 comments on commit 9a09514

Please sign in to comment.