From 205d2795d1cccf5682022b1deed46964835d5874 Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Thu, 15 Dec 2022 18:18:59 -0800 Subject: [PATCH] Fix crash when calling unbind on evaluated tensor --- core/conversion/converters/impl/select.cpp | 2 +- .../conversion/converters/test_select.cpp | 28 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp index 6f4bbb9d67..910b8f7d6d 100644 --- a/core/conversion/converters/impl/select.cpp +++ b/core/conversion/converters/impl/select.cpp @@ -16,7 +16,7 @@ namespace impl { namespace { bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list, bool unbind) { - auto in = args[0].ITensor(); + auto in = args[0].ITensorOrFreeze(ctx); auto numOutputs = 1, numRemainder = 0; std::vector sizes; diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp index 0e007271ec..e5576f0109 100644 --- a/tests/core/conversion/converters/test_select.cpp +++ b/tests/core/conversion/converters/test_select.cpp @@ -1122,6 +1122,34 @@ TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) { } } +TEST(Converters, ATenUnbindEvaluatedTensor) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : None = prim::Constant() + %3 : int[] = aten::size(%x.1) + %z.1 : Tensor = aten::zeros(%3, %2, %2, %2, %2) + %5 : int = prim::Constant[value=-1]() + %6 : Tensor[] = aten::unbind(%z.1, %5) + %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%6) + return (%o1.1, %o2.1))IR"; + + auto in = at::randint(1, 10, {2}, {at::kCUDA}); + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + for (size_t i = 0; i < jit_results.size(); i++) { + auto trt = trt_results[i]; + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i].cuda(), trt, 2e-6)); + } +} + TEST(Converters, ScatterValueConvertsCorrectly) { const auto graph = R"IR( graph(%data : Tensor,