From 189f2fe0662576e57f7dd87b381a1c89ac2ac6f9 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Fri, 3 Jul 2020 20:01:41 -0400 Subject: [PATCH 1/2] [opt] Flatten if(0) and if(1) --- .../unreachable_code_elimination.cpp | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/taichi/transforms/unreachable_code_elimination.cpp b/taichi/transforms/unreachable_code_elimination.cpp index 440df20680b44..cb4627b35e67d 100644 --- a/taichi/transforms/unreachable_code_elimination.cpp +++ b/taichi/transforms/unreachable_code_elimination.cpp @@ -33,6 +33,7 @@ class UnreachableCodeEliminator : public BasicStmtVisitor { using BasicStmtVisitor::visit; bool modified; UselessContinueEliminator useless_continue_eliminator; + DelayedIRModifier modifier; UnreachableCodeEliminator() : modified(false) { allow_undefined_visitor = true; @@ -87,16 +88,20 @@ class UnreachableCodeEliminator : public BasicStmtVisitor { if (if_stmt->cond->is() && if_stmt->cond->width() == 1) { if (if_stmt->cond->as()->val[0].equal_value(0)) { // if (0) - if (if_stmt->true_statements) { - if_stmt->true_statements = nullptr; - modified = true; - } + modifier.insert_before( + if_stmt, + VecStatement(std::move(if_stmt->false_statements->statements))); + modifier.erase(if_stmt); + modified = true; + return; } else { // if (1) - if (if_stmt->false_statements) { - if_stmt->false_statements = nullptr; - modified = true; - } + modifier.insert_before( + if_stmt, + VecStatement(std::move(if_stmt->true_statements->statements))); + modifier.erase(if_stmt); + modified = true; + return; } } if (if_stmt->true_statements) @@ -110,6 +115,7 @@ class UnreachableCodeEliminator : public BasicStmtVisitor { while (true) { UnreachableCodeEliminator eliminator; node->accept(&eliminator); + eliminator.modifier.modify_ir(); if (eliminator.modified || eliminator.useless_continue_eliminator.modified) { modified = true; From 7d02f56e30a56eedfbb3a6b8085d03625c36c946 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Sat, 4 Jul 2020 01:01:20 -0400 Subject: [PATCH 2/2] Fix tests --- .../transforms/unreachable_code_elimination.cpp | 16 ++++++++++------ tests/python/test_while.py | 2 +- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/taichi/transforms/unreachable_code_elimination.cpp b/taichi/transforms/unreachable_code_elimination.cpp index cb4627b35e67d..186da4c8f3301 100644 --- a/taichi/transforms/unreachable_code_elimination.cpp +++ b/taichi/transforms/unreachable_code_elimination.cpp @@ -88,17 +88,21 @@ class UnreachableCodeEliminator : public BasicStmtVisitor { if (if_stmt->cond->is() && if_stmt->cond->width() == 1) { if (if_stmt->cond->as()->val[0].equal_value(0)) { // if (0) - modifier.insert_before( - if_stmt, - VecStatement(std::move(if_stmt->false_statements->statements))); + if (if_stmt->false_statements) { + modifier.insert_before( + if_stmt, + VecStatement(std::move(if_stmt->false_statements->statements))); + } modifier.erase(if_stmt); modified = true; return; } else { // if (1) - modifier.insert_before( - if_stmt, - VecStatement(std::move(if_stmt->true_statements->statements))); + if (if_stmt->true_statements) { + modifier.insert_before( + if_stmt, + VecStatement(std::move(if_stmt->true_statements->statements))); + } modifier.erase(if_stmt); modified = true; return; diff --git a/tests/python/test_while.py b/tests/python/test_while.py index 4ef25e3173050..a6971b7fe5346 100644 --- a/tests/python/test_while.py +++ b/tests/python/test_while.py @@ -38,4 +38,4 @@ def func(): ret[None] = s func() - print(ret[None]) + assert ret[None] == 55