Skip to content

Commit

Permalink
feat: Add converter support for logical_and (#1856)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfeliz-cruise authored and bowang007 committed Apr 29, 2023
1 parent e7333a6 commit df294de
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
23 changes: 23 additions & 0 deletions core/conversion/converters/impl/element_wise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,29 @@ auto element_wise_registrations TORCHTRT_UNUSED =
return true;
}})
.pattern(
{"aten::logical_and(Tensor self, Tensor other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// torch.logical_and autocasts inputs to bool
auto input_as_bool = [&](int idx) {
auto x = args[idx].ITensorOrFreeze(ctx);
if (x->getType() != nvinfer1::DataType::kBOOL) {
x = castITensor(
ctx, x, nvinfer1::DataType::kBOOL, (util::node_info(n) + "_bool_" + str(idx)).c_str());
}
return x;
};
auto self = input_as_bool(0);
auto other = input_as_bool(1);

auto and_layer =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kAND, self, other, util::node_info(n) + "_and");
TORCHTRT_CHECK(and_layer, "Unable to create and layer from node: " << *n);
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], and_layer->getOutput(0));

LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern(
{"aten::atan2(Tensor self, Tensor other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// Element-wise divide input Tensors, apply atan unary, apply quadrant correction
Expand Down
12 changes: 11 additions & 1 deletion tests/core/conversion/converters/test_comparators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,14 @@ TEST(Converters, ATenMinConvertsCorrectly) {
pointwise_test_helper(graph, false, false, {4}, {3, 4});
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
}
}

TEST(Converters, ATenLogicalAndConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%2 : Tensor = aten::logical_and(%0, %1)
return (%2))IR";
pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}, false, at::kBool, at::kBool);
pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}, false, at::kInt, at::kBool);
pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}, false, at::kInt, at::kInt);
}

0 comments on commit df294de

Please sign in to comment.