Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Add TIR While node #7425

Merged
merged 29 commits into from
Mar 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2fad3f1
add while node
masahi Feb 6, 2021
4382dbf
update visitors
masahi Feb 6, 2021
81ddea4
binary search lowering works
masahi Feb 6, 2021
322d025
llvm codegen working
masahi Feb 6, 2021
45647b1
cuda codegen working
masahi Feb 7, 2021
5592b25
nms updated to use while loop
masahi Feb 7, 2021
7896302
add missing upper bound check too
masahi Feb 7, 2021
f041f85
add mandelbrot test
masahi Feb 7, 2021
b53150c
add gpu mandel
masahi Feb 7, 2021
265cffc
rename test
masahi Feb 9, 2021
0b6a93b
run black
masahi Feb 9, 2021
3dfe189
add doc
masahi Feb 9, 2021
6834cc3
add collatz test
masahi Feb 9, 2021
92d9add
add while + vectorize test
masahi Feb 9, 2021
e56d570
simplify bin search
masahi Feb 9, 2021
ef64278
Add special case visit method to storage_access.cc
masahi Feb 10, 2021
ff86f16
disallow while loop inside vectorized loop
masahi Feb 10, 2021
3f77a9e
disallow trivial condition since we do not have break
masahi Feb 10, 2021
220c7eb
error out in CoprocSync for now
masahi Feb 10, 2021
3817b5a
error out LiftAttrScope for now
masahi Feb 10, 2021
384ac45
add placeholder to inject_vpthread
masahi Feb 10, 2021
da3ca49
refactor to use MakeAttach
masahi Feb 13, 2021
626c7ff
handle WhileNode in InplaceOpVerifier
masahi Feb 13, 2021
0fadb47
error out in InjectVirtualThread
masahi Feb 13, 2021
45818ea
try handle WhileNode in StoragePlanRewriter
masahi Feb 13, 2021
f442ecc
remove WhileNode visitor from storage rewrite
masahi Mar 2, 2021
3012876
add while loop storage rewrite test
masahi Mar 2, 2021
c3af5ae
update tests
masahi Mar 2, 2021
35b8e28
move test_vectorize_while_fail to test_tir_transform_vectorize.py
masahi Mar 2, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,53 @@ class For : public Stmt {
TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
};

/*!
* \brief A While loop
*
* \code
*
* while (condition)
* body
*
* \endcode
*/
class WhileNode : public StmtNode {
public:
/*! \brief The termination condition. */
PrimExpr condition;
/*! \brief The body of the while loop. */
Stmt body;

void VisitAttrs(AttrVisitor* v) {
v->Visit("condition", &condition);
v->Visit("body", &body);
v->Visit("span", &span);
}

bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const {
return equal.DefEqual(condition, other->condition) && equal.DefEqual(body, other->body);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(condition);
hash_reduce.DefHash(body);
}

static constexpr const char* _type_key = "tir.While";
TVM_DECLARE_FINAL_OBJECT_INFO(WhileNode, StmtNode);
};

/*!
* \brief Managed reference to WhileNode.
* \sa WhileNode
*/
class While : public Stmt {
public:
TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode);
};

/*!
* \brief A prefetch hint for a buffer
*/
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
Expand All @@ -111,6 +112,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode);
IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode);
IR_STMT_FUNCTOR_DISPATCH(ForNode);
IR_STMT_FUNCTOR_DISPATCH(WhileNode);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need checks through the current passes, per my comment

IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
Expand Down Expand Up @@ -152,6 +154,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const WhileNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
Expand Down Expand Up @@ -245,6 +248,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const IfThenElseNode* op) override;
Stmt VisitStmt_(const LetStmtNode* op) override;
Stmt VisitStmt_(const ForNode* op) override;
Stmt VisitStmt_(const WhileNode* op) override;
Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const StoreNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Expand Down
29 changes: 29 additions & 0 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,35 @@ def _exit_cb():

return WithScope(loop_var, _exit_cb)

def while_loop(self, condition):
"""Create a while loop scope.

Parameters
----------
condition : Expr
The termination condition.

Returns
-------
loop_scope : With.Scope of Var
The while scope.

Examples
--------
.. code-block:: python

ib = tvm.tir.ir_builder.create()
iterations = ib.allocate("int32", (1,), name="iterations", scope="local")
with ib.while_loop(iterations[0] < 10):
iterations[0] += 1
"""
self._seq_stack.append([])

def _exit_cb():
self.emit(_stmt.While(condition, self._pop_seq()))

return WithScope(None, _exit_cb)

def if_scope(self, cond):
"""Create an if scope.

Expand Down
25 changes: 25 additions & 0 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,31 @@ def __init__(
)


@tvm._ffi.register_object("tir.While")
class While(Stmt):
"""While node.

Parameters
----------
condition : PrimExpr
The termination condition.

body : Stmt
The body statement.

span : Optional[Span]
The location of this itervar in the source code.
"""

def __init__(self, condition, body, span=None):
self.__init_handle_by_constructor__(
_ffi_api.While,
condition,
body,
span,
)


@tvm._ffi.register_object("tir.Store")
class Store(Stmt):
"""Store node.
Expand Down
28 changes: 17 additions & 11 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def nms_inner_loop(ib, j):
offset_j = j * 4
num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx)

with ib.for_range(0, num_iter_per_thread) as _k:
with ib.for_range(0, num_iter_per_thread, name="_k") as _k:
k = j + 1 + _k * nthread_tx + tx
offset_k = k * 4

Expand Down Expand Up @@ -555,16 +555,22 @@ def nms_inner_loop(ib, j):

with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
# Apply nms
with ib.for_range(0, nkeep) as j:
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out_scores[i, j] > -1.0):
with ib.if_scope(max_output_size > 0):
# No need to do more iteration if we have already reached max_output_size
# boxes
# TODO(masahi): Add TIR while loop to realize early exit from the outer loop
with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
nms_inner_loop(ib, j)
with ib.else_scope():
with ib.if_scope(max_output_size > 0):
# No need to do more iteration if we have already reached max_output_size boxes
box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local")
box_idx[0] = 0
with ib.while_loop(
tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size)
):
# Proceed to the inner loop if the box with id box_idx is still valid
with ib.if_scope(out_scores[i, box_idx[0]] > -1.0):
nms_inner_loop(ib, box_idx[0])
box_idx[0] += 1

with ib.else_scope():
with ib.for_range(0, nkeep, name="j") as j:
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out_scores[i, j] > -1.0):
nms_inner_loop(ib, j)

with ib.if_scope(tx + 0 == 0):
Expand Down
1 change: 1 addition & 0 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const SeqStmtNode* op) override;
Doc VisitStmt_(const EvaluateNode* op) override;
Doc VisitStmt_(const ForNode* op) override;
Doc VisitStmt_(const WhileNode* op) override;
Doc VisitStmt_(const PrefetchNode* op) override;
Doc VisitStmtDefault_(const Object* op) override;

Expand Down
7 changes: 7 additions & 0 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,13 @@ Doc TIRTextPrinter::VisitStmt_(const ForNode* op) {
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const WhileNode* op) {
Doc doc;
doc << "while (" << Print(op->condition) << ")";
doc << PrintBody(op->body);
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) {
Doc doc;
doc << "prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")";
Expand Down
14 changes: 14 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,20 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body);
}

void CodeGenLLVM::VisitStmt_(const WhileNode* op) {
using llvm::BasicBlock;
BasicBlock* while_cond = BasicBlock::Create(*ctx_, "while_cond", function_);
BasicBlock* while_body = BasicBlock::Create(*ctx_, "while_body", function_);
BasicBlock* while_merge = BasicBlock::Create(*ctx_, "while_merge", function_);
builder_->CreateBr(while_cond);
builder_->SetInsertPoint(while_cond);
builder_->CreateCondBr(MakeValue(op->condition), while_body, while_merge);
builder_->SetInsertPoint(while_body);
this->VisitStmt(op->body);
builder_->CreateBr(while_cond);
builder_->SetInsertPoint(while_merge);
}

void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
using llvm::BasicBlock;
llvm::Value* cond = MakeValue(op->condition);
Expand Down
1 change: 1 addition & 0 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
// stmt
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const WhileNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AttrStmtNode* op) override;
Expand Down
11 changes: 10 additions & 1 deletion src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,6 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
ICHECK(is_one(op->predicate)) << "Predicated store is not supported";
arith::PVar<PrimExpr> base;


if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
Expand Down Expand Up @@ -899,6 +898,16 @@ void CodeGenC::VisitStmt_(const ForNode* op) {
stream << "}\n";
}

void CodeGenC::VisitStmt_(const WhileNode* op) {
PrintIndent();
stream << "while (" << PrintExpr(op->condition) << ") {\n";
int while_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(while_scope);
PrintIndent();
stream << "}\n";
}

void CodeGenC::VisitStmt_(const IfThenElseNode* op) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
Expand Down
1 change: 1 addition & 0 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const WhileNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AttrStmtNode* op) override;
Expand Down
32 changes: 32 additions & 0 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,38 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "}\n";
});

// While
While::While(PrimExpr condition, Stmt body, Span span) {
ICHECK(condition.defined());
ICHECK(condition.dtype().is_scalar());
ICHECK(condition.as<tir::IntImmNode>() == nullptr) << "The condition should not be trivial.";
ICHECK(body.defined());

ObjectPtr<WhileNode> node = make_object<WhileNode>();
node->condition = std::move(condition);
node->body = std::move(body);
node->span = std::move(span);
data_ = std::move(node);
}

TVM_REGISTER_GLOBAL("tir.While").set_body_typed([](PrimExpr condition, Stmt body, Span span) {
return While(condition, body, span);
});

TVM_REGISTER_NODE_TYPE(WhileNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<WhileNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const WhileNode*>(node.get());
p->PrintIndent();
p->stream << "while(" << op->condition << "){\n";
p->indent += 2;
p->Print(op->body);
p->indent -= 2;
p->PrintIndent();
p->stream << "}\n";
});

// Store
Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, Span span) {
ICHECK(value.defined());
Expand Down
18 changes: 18 additions & 0 deletions src/tir/ir/stmt_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ void StmtVisitor::VisitStmt_(const ForNode* op) {
this->VisitStmt(op->body);
}

void StmtVisitor::VisitStmt_(const WhileNode* op) {
this->VisitExpr(op->condition);
this->VisitStmt(op->body);
}

void StmtVisitor::VisitStmt_(const AllocateNode* op) {
VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); });
this->VisitStmt(op->body);
Expand Down Expand Up @@ -283,6 +288,19 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) {
}
}

Stmt StmtMutator::VisitStmt_(const WhileNode* op) {
PrimExpr condition = this->VisitExpr(op->condition);
Stmt body = this->VisitStmt(op->body);
if (condition.same_as(op->condition) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->condition = std::move(condition);
n->body = std::move(body);
return Stmt(n);
}
}

Stmt StmtMutator::VisitStmt_(const AllocateNode* op) {
Array<PrimExpr> extents = Internal::Mutate(this, op->extents);
Stmt body = this->VisitStmt(op->body);
Expand Down
5 changes: 5 additions & 0 deletions src/tir/transforms/coproc_sync.cc
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,11 @@ class CoProcInstDepDetector : public StmtVisitor {
}
}

void VisitStmt_(const WhileNode* op) final {
// TODO(masahi): Do we need a special handling for While nodes?
LOG(FATAL) << "WhileNode not supported in CoProcSync.";
}

// insert before is stored in reverse order
// the first element is closest to the node.
std::unordered_map<const Object*, std::vector<Stmt> > insert_before_;
Expand Down
7 changes: 7 additions & 0 deletions src/tir/transforms/inject_virtual_thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,13 @@ class VTInjector : public StmtExprMutator {
}
}

// While
Stmt VisitStmt_(const WhileNode* op) final {
// TODO(masahi): What should we do for While nodes?
LOG(FATAL) << "WhileNode in InjectVirtualThread not supported yet";
return Stmt();
}

// Seq
Stmt VisitStmt_(const SeqStmtNode* op) final {
ICHECK_EQ(max_loop_depth_, 0);
Expand Down
6 changes: 6 additions & 0 deletions src/tir/transforms/lift_attr_scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,12 @@ class AttrScopeLifter : public StmtMutator {
}
}

Stmt VisitStmt_(const WhileNode* op) final {
// TODO(masahi): Do we need a special handling for While nodes?
LOG(FATAL) << "WhileNode not supported in LiftAttrScope.";
return Stmt();
}

private:
// value comparison that also compares content of int constant
static bool ValueSame(const PrimExpr& a, const PrimExpr& b) {
Expand Down
Loading