diff --git a/src/relay/backend/aot/aot_lower_main.cc b/src/relay/backend/aot/aot_lower_main.cc index a97f7bcaa64a..5bfe7cb8e749 100644 --- a/src/relay/backend/aot/aot_lower_main.cc +++ b/src/relay/backend/aot/aot_lower_main.cc @@ -417,7 +417,7 @@ class AOTMainLowerer : public MixedModeVisitor { * runner function needs to be legalized by the LegalizePackedCalls pass. */ tir::PrimFunc CreateMainFunc(String mod_name) { - tir::Stmt body = tir::SeqStmt(stmts_); + tir::Stmt body = tir::SeqStmt::Flatten(stmts_); // Allocate the sids std::unordered_map allocated; std::vector> sids_to_allocate; @@ -717,7 +717,7 @@ class AOTMainLowerer : public MixedModeVisitor { {tvm::tir::StringImm(device_hook_name), context}))); device_hooks.push_back(device_hook); } - return tir::SeqStmt(device_hooks); + return tir::SeqStmt::Flatten(device_hooks); } /*! diff --git a/src/tir/contrib/ethosu/passes.cc b/src/tir/contrib/ethosu/passes.cc index 369c4adc8536..5968febe9af0 100644 --- a/src/tir/contrib/ethosu/passes.cc +++ b/src/tir/contrib/ethosu/passes.cc @@ -144,7 +144,7 @@ class HoistAllocatesMutator : public StmtExprMutator { for (auto it = allocates_.rbegin(); it != allocates_.rend(); it++) { Allocate current_alloc = *it; if (it != allocates_.rbegin()) { - new_main_func_body = SeqStmt({new_main_func_body}); + new_main_func_body = SeqStmt::Flatten(new_main_func_body); } new_main_func_body = Allocate(current_alloc->buffer_var, current_alloc->dtype, current_alloc->extents, diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 727b07c2e488..30b1bc78247a 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -295,7 +295,7 @@ TEST(IRF, StmtMutator) { Stmt body2 = Evaluate(1); auto* extentptr = body.as()->extents.get(); // construct a recursive SeqStmt. - body = SeqStmt({body}); + body = SeqStmt({body, body2}); auto bref = body; body = SeqStmt({body, body2}); body = v(std::move(body));