diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp index 5877e13210..6f4bbb9d67 100644 --- a/core/conversion/converters/impl/select.cpp +++ b/core/conversion/converters/impl/select.cpp @@ -736,8 +736,22 @@ auto select_registrations TORCHTRT_UNUSED = {"aten::where.self(Tensor condition, Tensor self, Tensor other) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto condition = args[0].ITensorOrFreeze(ctx); + auto condition_nbDims = condition->getDimensions().nbDims; auto x = args[1].ITensorOrFreeze(ctx); + auto x_nbDims = x->getDimensions().nbDims; auto y = args[2].ITensorOrFreeze(ctx); + auto y_nbDims = y->getDimensions().nbDims; + + // Get maximum rank of all input tensors + auto max_nbDims = std::max(condition_nbDims, std::max(x_nbDims, y_nbDims)); + + // TensorRT requires all inputs to Select layers to have the same rank, so for each + // tensor input, ensure that its rank is equal to the maximum number of dimensions + // If not, left-pad the tensor dimension with 1s until the max rank is achieved + condition = + addPadding(ctx, n, condition, max_nbDims, /*bool trailing =*/false, /*bool use_zeros =*/false); + x = addPadding(ctx, n, x, max_nbDims, /*bool trailing =*/false, /*bool use_zeros =*/false); + y = addPadding(ctx, n, y, max_nbDims, /*bool trailing =*/false, /*bool use_zeros =*/false); auto layer = ctx->net->addSelect(*condition, *x, *y); diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp index 40c5f11843..0e007271ec 100644 --- a/tests/core/conversion/converters/test_select.cpp +++ b/tests/core/conversion/converters/test_select.cpp @@ -1224,3 +1224,35 @@ TEST(Converters, WhereConvertsCorrectly) { ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } + +TEST(Converters, WhereConvertsMismatchedShapesCorrectly) { + const auto graph = R"IR( + graph(%condition : Tensor, + 
%x : Tensor, + %y : Tensor): + %out : Tensor = aten::where(%condition, %x, %y) + return (%out))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + + torch::jit::parseIR(graph, g.get()); + + // As per Torch behavior, the input Tensors are expected to be broadcasted + // along their respective dimension in the largest-rank Tensor provided + auto condition = at::randint(0, 2, {7, 5}, {at::kCUDA}).to(torch::kBool); + auto x = at::randn({2, 7, 5}, {at::kCUDA}); + auto y = at::randn({5}, {at::kCUDA}); + + auto jit_condition = at::clone(condition); + auto jit_x = at::clone(x); + auto jit_y = at::clone(y); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_condition, jit_x, jit_y}); + + auto trt_condition = at::clone(condition); + auto trt_x = at::clone(x); + auto trt_y = at::clone(y); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_condition, trt_x, trt_y}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +}