Skip to content

Commit

Permalink
Avoid a couple more cases with flatten-able SeqStmt
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Apr 11, 2023
1 parent 0a3d3e0 commit a5eff13
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/relay/backend/aot/aot_lower_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ class AOTMainLowerer : public MixedModeVisitor {
* runner function needs to be legalized by the LegalizePackedCalls pass.
*/
tir::PrimFunc CreateMainFunc(String mod_name) {
tir::Stmt body = tir::SeqStmt(stmts_);
tir::Stmt body = tir::SeqStmt::Flatten(stmts_);
// Allocate the sids
std::unordered_map<int, bool> allocated;
std::vector<std::pair<int64_t, int64_t>> sids_to_allocate;
Expand Down Expand Up @@ -717,7 +717,7 @@ class AOTMainLowerer : public MixedModeVisitor {
{tvm::tir::StringImm(device_hook_name), context})));
device_hooks.push_back(device_hook);
}
return tir::SeqStmt(device_hooks);
return tir::SeqStmt::Flatten(device_hooks);
}

/*!
Expand Down
2 changes: 1 addition & 1 deletion src/tir/contrib/ethosu/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class HoistAllocatesMutator : public StmtExprMutator {
for (auto it = allocates_.rbegin(); it != allocates_.rend(); it++) {
Allocate current_alloc = *it;
if (it != allocates_.rbegin()) {
new_main_func_body = SeqStmt({new_main_func_body});
new_main_func_body = SeqStmt::Flatten(new_main_func_body);
}
new_main_func_body =
Allocate(current_alloc->buffer_var, current_alloc->dtype, current_alloc->extents,
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/ir_functor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ TEST(IRF, StmtMutator) {
Stmt body2 = Evaluate(1);
auto* extentptr = body.as<AllocateNode>()->extents.get();
// construct a recursive SeqStmt.
body = SeqStmt({body});
body = SeqStmt({body, body2});
auto bref = body;
body = SeqStmt({body, body2});
body = v(std::move(body));
Expand Down

0 comments on commit a5eff13

Please sign in to comment.