Skip to content

Commit

Permalink
add support for aten::reciprocal(int) (#1308)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfeliz-cruise authored Sep 8, 2022
1 parent b25738e commit 096fd41
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
16 changes: 15 additions & 1 deletion core/conversion/converters/impl/unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,21 @@ auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern
}
}});

auto reciprocal_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
{"aten::reciprocal(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensorOrFreeze(ctx);
if (in->getType() == nvinfer1::DataType::kINT32) {
// pytorch implicitly casts to float for aten::reciprocal(int)
in = castITensor(ctx, in, nvinfer1::DataType::kFLOAT);
}
auto unary_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kRECIP);
TORCHTRT_CHECK(unary_layer, "Unable to create recip layer from node: " << *n);
unary_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], unary_layer->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
return true;
}});

#define convert(unary, trt_type) \
auto unary##_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( \
{"aten::" #unary "(Tensor self) -> Tensor", \
Expand All @@ -74,7 +89,6 @@ convert(sinh, kSINH);
convert(tan, kTAN);
convert(atan, kATAN);
convert(floor, kFLOOR);
convert(reciprocal, kRECIP);
convert(log, kLOG);
convert(ceil, kCEIL);
convert(sqrt, kSQRT);
Expand Down
16 changes: 16 additions & 0 deletions tests/core/conversion/converters/test_unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,22 @@ TEST(Converters, ATenAbsIntConvertsCorrectly) {
ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0]));
}

TEST(Converters, ATenReciprocalIntConvertsCorrectly) {
const auto graph = gen_test_graph("reciprocal");
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::tensor({-1, 1, -2, 2, -3, 3}, {at::kCUDA}).to(torch::kInt32);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

in = at::clone(in);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0]));
}

#define test_unary(unary, name) \
TEST(Converters, ATen##name##ConvertsCorrectly) { \
const auto graph = gen_test_graph(#unary); \
Expand Down

0 comments on commit 096fd41

Please sign in to comment.