Skip to content

Commit

Permalink
Fix bug: correct the output shape of aten::index.Tensor (#1314)
Browse files Browse the repository at this point in the history
* support multiple indices for aten::index.Tensor

Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>

* fix: correct output shape of aten::index.Tensor

Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>

Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>
  • Loading branch information
ruoqianguo authored Aug 28, 2022
1 parent ce67ceb commit d69651c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
3 changes: 2 additions & 1 deletion core/conversion/converters/impl/select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ auto select_registrations TORCHTRT_UNUSED =
.pattern(
{"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// refer to https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py#L4627
auto in = args[0].ITensorOrFreeze(ctx);
auto ts = args[1].IValue()->toListRef();

Expand Down Expand Up @@ -471,7 +472,7 @@ auto select_registrations TORCHTRT_UNUSED =
}
}
auto concat_final_shape_layer =
ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
ctx->net->addConcatenation(concat_final_tensors.data(), concat_final_tensors.size());
auto unfold_advanced_shuffle_layer = ctx->net->addShuffle(*shuffle_out);
unfold_advanced_shuffle_layer->setInput(1, *concat_final_shape_layer->getOutput(0));
reshape_output = unfold_advanced_shuffle_layer->getOutput(0);
Expand Down
31 changes: 31 additions & 0 deletions tests/core/conversion/converters/test_select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,37 @@ TEST(Converters, ATenIndexTensorNoneIdx0Idx1ConvertsCorrectly) {
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenIndexTensorIdxsNoneConvertsCorrectly) {
  // Regression test for the aten::index.Tensor output-shape fix: indexes a
  // (4, 8, 8, 4) tensor as x[idx0, idx1, idx2, :] — three advanced (tensor)
  // indices followed by a trailing None — and checks that the TensorRT
  // conversion matches the TorchScript interpreter's result.
  const auto source_ir = R"IR(
graph(%x.1 : Tensor,
%index0 : Tensor,
%index1 : Tensor,
%index2 : Tensor):
%5 : NoneType = prim::Constant()
%18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2, %5)
%19 : Tensor = aten::index(%x.1, %18)
return (%19))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_g.get());

  // Constant (4, 13, 1) index tensors. The JIT reference path takes int64
  // indices; the TensorRT path is fed int32 copies of the same values.
  auto input = at::randint(1, 10, {4, 8, 8, 4}, {at::kCUDA});
  auto idx0 = at::full({4, 13, 1}, 1, {at::kCUDA}).to(torch::kLong);
  auto idx1 = at::full({4, 13, 1}, 2, {at::kCUDA}).to(torch::kLong);
  auto idx2 = at::full({4, 13, 1}, 3, {at::kCUDA}).to(torch::kLong);

  // Reference result via the TorchScript interpreter.
  auto jit_params = torch_tensorrt::core::ir::get_static_params(parsed_g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(parsed_g, jit_params, {input, idx0, idx1, idx2});

  // Converted result via the TensorRT engine (int32 index tensors).
  auto trt_params = torch_tensorrt::core::ir::get_static_params(parsed_g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(
      parsed_g,
      trt_params,
      {input, idx0.to(torch::kInt32), idx1.to(torch::kInt32), idx2.to(torch::kInt32)});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenUnbindConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
Expand Down

0 comments on commit d69651c

Please sign in to comment.