Skip to content

Commit

Permalink
Merge pull request #319 from NVIDIA/mul_scalar
Browse files Browse the repository at this point in the history
Add mul.scalar converter
  • Loading branch information
narendasan authored Feb 23, 2021
2 parents f453ede + e0ce9f3 commit 34f84df
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
15 changes: 15 additions & 0 deletions core/conversion/converters/impl/element_wise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,21 @@ auto element_wise_registrations TRTORCH_UNUSED =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);

mul->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], mul->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern({"aten::mul.Scalar(Tensor self, Scalar other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// TODO: Remove with functionalization
auto self = args[0].ITensorOrFreeze(ctx);
auto otherScalar = args[1].unwrapToScalar().to<float>();
auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
auto mul =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);

mul->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], mul->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
Expand Down
9 changes: 9 additions & 0 deletions tests/core/conversion/converters/test_element_wise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ TEST(Converters, ATenMulConvertsCorrectly) {
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
}

TEST(Converters, ATenMulWithScalarConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%scalar : float = prim::Constant[value=2.4]()
%1 : Tensor = aten::mul(%0, %scalar)
return (%1))IR";
pointwise_test_helper(graph, true);
}

TEST(Converters, ATenDivConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
Expand Down

0 comments on commit 34f84df

Please sign in to comment.