diff --git a/taichi/transforms/unreachable_code_elimination.cpp b/taichi/transforms/unreachable_code_elimination.cpp index 440df20680b44..186da4c8f3301 100644 --- a/taichi/transforms/unreachable_code_elimination.cpp +++ b/taichi/transforms/unreachable_code_elimination.cpp @@ -33,6 +33,7 @@ class UnreachableCodeEliminator : public BasicStmtVisitor { using BasicStmtVisitor::visit; bool modified; UselessContinueEliminator useless_continue_eliminator; + DelayedIRModifier modifier; UnreachableCodeEliminator() : modified(false) { allow_undefined_visitor = true; @@ -87,16 +88,24 @@ class UnreachableCodeEliminator : public BasicStmtVisitor { if (if_stmt->cond->is() && if_stmt->cond->width() == 1) { if (if_stmt->cond->as()->val[0].equal_value(0)) { // if (0) - if (if_stmt->true_statements) { - if_stmt->true_statements = nullptr; - modified = true; + if (if_stmt->false_statements) { + modifier.insert_before( + if_stmt, + VecStatement(std::move(if_stmt->false_statements->statements))); } + modifier.erase(if_stmt); + modified = true; + return; } else { // if (1) - if (if_stmt->false_statements) { - if_stmt->false_statements = nullptr; - modified = true; + if (if_stmt->true_statements) { + modifier.insert_before( + if_stmt, + VecStatement(std::move(if_stmt->true_statements->statements))); } + modifier.erase(if_stmt); + modified = true; + return; } } if (if_stmt->true_statements) @@ -110,6 +119,7 @@ class UnreachableCodeEliminator : public BasicStmtVisitor { while (true) { UnreachableCodeEliminator eliminator; node->accept(&eliminator); + eliminator.modifier.modify_ir(); if (eliminator.modified || eliminator.useless_continue_eliminator.modified) { modified = true; diff --git a/tests/python/test_while.py b/tests/python/test_while.py index 4ef25e3173050..a6971b7fe5346 100644 --- a/tests/python/test_while.py +++ b/tests/python/test_while.py @@ -38,4 +38,4 @@ def func(): ret[None] = s func() - print(ret[None]) + assert ret[None] == 55