diff --git a/core/conversion/converters/impl/unary.cpp b/core/conversion/converters/impl/unary.cpp index fa4e88fa5e..c78602963c 100644 --- a/core/conversion/converters/impl/unary.cpp +++ b/core/conversion/converters/impl/unary.cpp @@ -49,6 +49,21 @@ auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern } }}); +auto reciprocal_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( + {"aten::reciprocal(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensorOrFreeze(ctx); + if (in->getType() == nvinfer1::DataType::kINT32) { + // pytorch implicitly casts to float for aten::reciprocal(int) + in = castITensor(ctx, in, nvinfer1::DataType::kFLOAT); + } + auto unary_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kRECIP); + TORCHTRT_CHECK(unary_layer, "Unable to create recip layer from node: " << *n); + unary_layer->setName(util::node_info(n).c_str()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], unary_layer->getOutput(0)); + LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); + return true; + }}); + #define convert(unary, trt_type) \ auto unary##_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( \ {"aten::" #unary "(Tensor self) -> Tensor", \ @@ -74,7 +89,6 @@ convert(sinh, kSINH); convert(tan, kTAN); convert(atan, kATAN); convert(floor, kFLOOR); -convert(reciprocal, kRECIP); convert(log, kLOG); convert(ceil, kCEIL); convert(sqrt, kSQRT); diff --git a/tests/core/conversion/converters/test_unary.cpp b/tests/core/conversion/converters/test_unary.cpp index 1d40c3c94b..06f092ff36 100644 --- a/tests/core/conversion/converters/test_unary.cpp +++ b/tests/core/conversion/converters/test_unary.cpp @@ -31,6 +31,22 @@ TEST(Converters, ATenAbsIntConvertsCorrectly) { ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0])); } +TEST(Converters, 
ATenReciprocalIntConvertsCorrectly) { + const auto graph = gen_test_graph("reciprocal"); + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::tensor({-1, 1, -2, 2, -3, 3}, {at::kCUDA}).to(torch::kInt32); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0])); +} + #define test_unary(unary, name) \ TEST(Converters, ATen##name##ConvertsCorrectly) { \ const auto graph = gen_test_graph(#unary); \