Skip to content

Commit

Permalink
[BugFix][TVMScript] Fix printer for dependent loops (apache#9506)
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 authored and ylc committed Jan 7, 2022
1 parent 4939149 commit 900a575
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
18 changes: 16 additions & 2 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,27 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
/*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */
Doc PrintLoopStack();
/*!
* \brief Print all simple loops in stack into one line using tir_prefix_.grid().
* \brief Check whether a loop satisfies:
* 1. the loop is serial;
* 2. the loop has no annotation;
* 3. the loop starts from 0;
* 4. there is no optional information.
* \param for_op the for node to be checked
* \return A boolean indicating whether the input loop satisfies the above conditions
*/
bool IsSimpleLoop(const ForNode* for_op) {
return for_op->kind == ForKind::kSerial && for_op->annotations.empty() &&
is_zero(for_op->min) && !ContainsOptionalInfo(GetRef<Stmt>(for_op));
}
/*!
* \brief Check whether the `min` or `extent` of a loop depends on previous loops
* \param for_op The loop to be checked
* \return A boolean indicating whether the input loop depends on previous loops
*/
bool DependOnPrevLoops(const ForNode* for_op) {
auto f_check = [&var_map = this->loop_var_map_](const VarNode* v) { return var_map.count(v); };
return UsesVar(for_op->min, f_check) || UsesVar(for_op->extent, f_check);
}

/*!
* \brief Print additional info about expr in comment.
Expand Down Expand Up @@ -895,7 +909,7 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
bool simple_loop = IsSimpleLoop(op);
if (simple_loop) simple_loop_stack_.push_back(GetRef<For>(op));
// It is a loop that can be compressed, let the loops below print it out
if (simple_loop && body != nullptr && IsSimpleLoop(body)) {
if (simple_loop && body != nullptr && IsSimpleLoop(body) && !DependOnPrevLoops(body)) {
doc << Print(GetRef<For>(body));
TryDeallocVar(op->loop_var);
loop_var_map_.erase(op->loop_var.get());
Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3176,5 +3176,19 @@ def test_div_mod():
assert isinstance(func.body[3].value, tvm.tir.Mod)


@T.prim_func
def loop_extent_dependent(a: T.handle) -> None:
A = T.match_buffer(a, [], dtype="int32")
for i in T.serial(0, 128):
for j in T.serial(0, i):
A[()] = A[()] + j


def test_loop_extent_dependent():
func = loop_extent_dependent
rt_func = tvm.script.from_source(func.script(show_meta=True))
tvm.ir.assert_structural_equal(func, rt_func, True)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 900a575

Please sign in to comment.