Skip to content

Commit

Permalink
Merge pull request #1286 from mfeliz-cruise/michael.feliz/square
Browse files Browse the repository at this point in the history
feat: Add support for aten::square
  • Loading branch information
peri044 authored Aug 19, 2022
2 parents 09a857f + 6b77b72 commit abf3d47
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
12 changes: 12 additions & 0 deletions core/conversion/converters/impl/element_wise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,18 @@ auto element_wise_registrations TORCHTRT_UNUSED =
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern(
{"aten::square(Tensor self) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, self, util::node_info(n));
TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);

mul->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], mul->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern(
{"aten::mul.Tensor(Tensor self, Tensor other) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
Expand Down
8 changes: 8 additions & 0 deletions tests/core/conversion/converters/test_element_wise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ TEST(Converters, ATenMulConvertsCorrectly) {
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
}

TEST(Converters, ATenSquareConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%1 : Tensor = aten::square(%0)
return (%1))IR";
pointwise_test_helper(graph, true);
}

TEST(Converters, ATenMulWithScalarConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
Expand Down

0 comments on commit abf3d47

Please sign in to comment.