diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp index ef901a62aa..96249a644f 100644 --- a/core/conversion/converters/impl/shuffle.cpp +++ b/core/conversion/converters/impl/shuffle.cpp @@ -18,8 +18,10 @@ static auto shuffle_registrations TRTORCH_UNUSED = auto end_dim = args[2].unwrapToInt(); auto in_shape = util::toVec(in->getDimensions()); std::vector out_shape; - if (ctx->input_is_dynamic) { + if (ctx->input_is_dynamic && in_shape[0] != -1) { out_shape = std::vector({in_shape[0], -1}); + } else if (ctx->input_is_dynamic && in_shape[0] == -1) { + out_shape = std::vector({-1, -1 * std::accumulate(std::begin(in_shape), std::end(in_shape), 1, std::multiplies())}); } else { out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes().vec(); } diff --git a/tests/core/conversion/converters/test_pooling.cpp b/tests/core/conversion/converters/test_pooling.cpp index 949e23c27b..dee8001bb7 100644 --- a/tests/core/conversion/converters/test_pooling.cpp +++ b/tests/core/conversion/converters/test_pooling.cpp @@ -402,7 +402,7 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectlyWithDynamicInput) { auto trt_in = at::clone(in); params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); + auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, false); ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } diff --git a/tests/core/conversion/converters/test_shuffle.cpp b/tests/core/conversion/converters/test_shuffle.cpp index 5338e2ecab..9b9d43a4ed 100644 --- a/tests/core/conversion/converters/test_shuffle.cpp +++ b/tests/core/conversion/converters/test_shuffle.cpp @@ -186,7 +186,31 @@ TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicInput) { in = at::clone(in); params = trtorch::core::conversion::get_named_params(g->inputs(), {}); - auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in}); + auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in}, false); + auto trt = trt_results[0].reshape_as(jit_results[0]); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + + +TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicBatch) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=0]() + %2 : int = prim::Constant[value=1]() + %3 : Tensor = aten::flatten(%0, %1, %2) + return (%3))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + auto in = at::randint(0, 5, {2, 3}, {at::kCUDA}); + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + in = at::clone(in); + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); + auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in}, true); auto trt = trt_results[0].reshape_as(jit_results[0]); ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); diff --git a/tests/util/run_graph_engine.cpp b/tests/util/run_graph_engine.cpp index 167fbe79c0..e39fd80f5f 100644 --- a/tests/util/run_graph_engine.cpp +++ b/tests/util/run_graph_engine.cpp @@ -23,19 +23,29 @@ std::vector toInputRanges(std::vector return std::move(a); } -std::vector toInputRangesDynamic(std::vector ten) { +std::vector toInputRangesDynamic(std::vector ten, bool dynamic_batch) { std::vector a; for (auto i : ten) { auto opt = core::util::toVec(i.sizes()); - std::vector min_range(opt); - std::vector max_range(opt); + if (dynamic_batch) { + std::vector min_range(opt); + std::vector max_range(opt); - min_range[1] = ceil(opt[1] / 2.0); - max_range[1] = 2 * opt[1]; + min_range[0] = ceil(opt[0] / 2.0); + max_range[0] = 2 * opt[0]; - a.push_back(core::conversion::InputRange(min_range, opt, max_range)); + a.push_back(core::conversion::InputRange(min_range, opt, max_range)); + } else { + std::vector min_range(opt); + std::vector max_range(opt); + + min_range[1] = ceil(opt[1] / 2.0); + max_range[1] = 2 * opt[1]; + + a.push_back(core::conversion::InputRange(min_range, opt, max_range)); + } } return std::move(a); @@ -63,9 +73,10 @@ std::vector RunGraphEngine( std::vector RunGraphEngineDynamic( std::shared_ptr& g, core::conversion::GraphParams& named_params, - std::vector inputs) { + std::vector inputs, + bool dynamic_batch) { LOG_DEBUG("Running TRT version"); - auto in = toInputRangesDynamic(inputs); + auto in = toInputRangesDynamic(inputs, dynamic_batch); auto info = core::conversion::ConversionInfo(in); info.engine_settings.workspace_size = 1 << 20; std::string eng = core::conversion::ConvertBlockToEngine(g->block(), info, named_params); diff --git a/tests/util/util.h b/tests/util/util.h index bf58d0ac1c..09f9281d63 100644 --- a/tests/util/util.h +++ b/tests/util/util.h @@ -35,7 +35,8 @@ std::vector RunGraphEngine( std::vector RunGraphEngineDynamic( std::shared_ptr& g, core::conversion::GraphParams& named_params, - std::vector inputs); + std::vector inputs, + bool dynamic_batch); // Run the forward method of a module and return results torch::jit::IValue RunModuleForward(torch::jit::Module& mod, std::vector inputs);