From df294de4d7ec9172921c67f6a71f1116376eb7f3 Mon Sep 17 00:00:00 2001 From: Michael Feliz <104801882+mfeliz-cruise@users.noreply.github.com> Date: Thu, 27 Apr 2023 17:07:10 -0700 Subject: [PATCH] feat: Add converter support for logical_and (#1856) --- .../converters/impl/element_wise.cpp | 23 +++++++++++++++++++ .../converters/test_comparators.cpp | 12 +++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp index a86307c682..ac781697b7 100644 --- a/core/conversion/converters/impl/element_wise.cpp +++ b/core/conversion/converters/impl/element_wise.cpp @@ -810,6 +810,29 @@ auto element_wise_registrations TORCHTRT_UNUSED = return true; }}) .pattern( + {"aten::logical_and(Tensor self, Tensor other) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + // torch.logical_and autocasts inputs to bool + auto input_as_bool = [&](int idx) { + auto x = args[idx].ITensorOrFreeze(ctx); + if (x->getType() != nvinfer1::DataType::kBOOL) { + x = castITensor( + ctx, x, nvinfer1::DataType::kBOOL, (util::node_info(n) + "_bool_" + str(idx)).c_str()); + } + return x; + }; + auto self = input_as_bool(0); + auto other = input_as_bool(1); + + auto and_layer = + add_elementwise(ctx, nvinfer1::ElementWiseOperation::kAND, self, other, util::node_info(n) + "_and"); + TORCHTRT_CHECK(and_layer, "Unable to create and layer from node: " << *n); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], and_layer->getOutput(0)); + + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); + return true; + }}) + .pattern( {"aten::atan2(Tensor self, Tensor other) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { // Element-wise divide input Tensors, apply atan unary, apply quadrant correction diff --git a/tests/core/conversion/converters/test_comparators.cpp b/tests/core/conversion/converters/test_comparators.cpp index 0107fa6837..c82ca959e2 100644 --- a/tests/core/conversion/converters/test_comparators.cpp +++ b/tests/core/conversion/converters/test_comparators.cpp @@ -134,4 +134,14 @@ TEST(Converters, ATenMinConvertsCorrectly) { pointwise_test_helper(graph, false, false, {4}, {3, 4}); pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}); pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}); -} \ No newline at end of file +} + +TEST(Converters, ATenLogicalAndConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, %1 : Tensor): + %2 : Tensor = aten::logical_and(%0, %1) + return (%2))IR"; + pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}, false, at::kBool, at::kBool); + pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}, false, at::kInt, at::kBool); + pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}, false, at::kInt, at::kInt); +}