Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Add additional termination condition to For node to enable While loop like feature #7385

Closed
wants to merge 16 commits into from
10 changes: 8 additions & 2 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,9 @@ class ForNode : public StmtNode {
ForKind kind;
/*! \brief The body of the for loop. */
Stmt body;
/*! \brief The additional termination condition of the for loop. */
Optional<PrimExpr> test;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful to have an RFC discussion, since different strategies for changing the IR can have different implications.


/*!
* \brief Only valid when kind == ForKind::kThreadBinding
* The context thread that this loop variable bounds to.
Expand All @@ -823,6 +826,7 @@ class ForNode : public StmtNode {
v->Visit("extent", &extent);
v->Visit("kind", &kind);
v->Visit("body", &body);
v->Visit("test", &test);
v->Visit("thread_binding", &thread_binding);
v->Visit("annotations", &annotations);
v->Visit("span", &span);
Expand All @@ -831,7 +835,8 @@ class ForNode : public StmtNode {
bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) &&
equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) &&
equal(thread_binding, other->thread_binding) && equal(annotations, other->annotations);
equal(test, other->test) && equal(thread_binding, other->thread_binding) &&
equal(annotations, other->annotations);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -840,6 +845,7 @@ class ForNode : public StmtNode {
hash_reduce(extent);
hash_reduce(kind);
hash_reduce(body);
hash_reduce(test);
hash_reduce(thread_binding);
hash_reduce(annotations);
}
Expand All @@ -855,7 +861,7 @@ class ForNode : public StmtNode {
class For : public Stmt {
public:
TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
Optional<IterVar> thread_binding = NullOpt,
Optional<PrimExpr> test = NullOpt, Optional<IterVar> thread_binding = NullOpt,
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
Expand Down
11 changes: 9 additions & 2 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def scope_attr(self, node, attr_key, value):
value = op.max(1, value)
self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))

def for_range(self, begin, end, name="i", dtype="int32", kind="serial"):
def for_range(self, begin, end, name="i", test=None, dtype="int32", kind="serial"):
"""Create a for iteration scope.

Parameters
Expand All @@ -221,6 +221,9 @@ def for_range(self, begin, end, name="i", dtype="int32", kind="serial"):
The name of iteration variable, if no input names,
using typical index names i, j, k, then i_nidx

test : Expr, optional
The additional termination condition.

dtype : str, optional
The data type of iteration variable.

Expand Down Expand Up @@ -248,6 +251,10 @@ def for_range(self, begin, end, name="i", dtype="int32", kind="serial"):
loop_var = _expr.Var(name, dtype=dtype)
extent = end if begin == 0 else (end - begin)

if test is not None:
msg = "A general termination condition is only supported for a serial loop."
assert kind == "serial", msg

def _exit_cb():
if kind == "serial":
kind_id = _stmt.ForKind.SERIAL
Expand All @@ -259,7 +266,7 @@ def _exit_cb():
kind_id = _stmt.ForKind.UNROLLED
else:
raise ValueError("Unknown kind")
self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq()))
self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq(), test))

return WithScope(loop_var, _exit_cb)

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __init__(
extent,
kind,
body,
test=None,
thread_binding=None,
annotations=None,
span=None,
Expand All @@ -149,6 +150,7 @@ def __init__(
extent,
kind,
body,
test,
thread_binding,
annotations,
span,
Expand Down
22 changes: 12 additions & 10 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,16 +541,18 @@ def nms_inner_loop(ib, j):

with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
# Apply nms
with ib.for_range(0, nkeep) as j:
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out_scores[i, j] > -1.0):
with ib.if_scope(max_output_size > 0):
# No need to do more iteration if we have already reached max_output_size
# boxes
# TODO(masahi): Add TIR while loop to realize early exit from the outer loop
with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
nms_inner_loop(ib, j)
with ib.else_scope():
with ib.if_scope(max_output_size > 0):
# No need to do more iteration if we have already reached max_output_size
# boxes
with ib.for_range(0, nkeep, test=(num_valid_boxes_local[0] < max_output_size)) as j:
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out_scores[i, j] > -1.0):
nms_inner_loop(ib, j)

with ib.else_scope():
with ib.for_range(0, nkeep) as j:
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out_scores[i, j] > -1.0):
nms_inner_loop(ib, j)

with ib.if_scope(tx + 0 == 0):
Expand Down
7 changes: 6 additions & 1 deletion src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,12 @@ inline const char* ForKind2String(ForKind t) {
Doc TIRTextPrinter::VisitStmt_(const ForNode* op) {
Doc doc;
doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", "
<< Print(op->min + op->extent) << ")";
<< Print(op->min + op->extent);
if (op->test) {
doc << ", (" << Print(op->test.value()) << "))";
} else {
doc << ")";
}
if (op->kind != ForKind::kSerial) {
doc << " " << Doc::StrLiteral(ForKind2String(op->kind));
}
Expand Down
7 changes: 4 additions & 3 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
CodeGenLLVM::VisitStmt_(op);
} else if (op->kind == ForKind::kParallel) {
if (parallel_env_.penv == nullptr) {
CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body,
CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body, op->test,
op->thread_binding, op->annotations),
0);
} else {
Expand All @@ -996,13 +996,14 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
parallel_env_.in_parallel_loop = true;
if (parallel_env_.stride_pattern) {
CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task),
op->loop_var, op->body);
op->loop_var, op->body, op->test);
} else {
PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task;
PrimExpr begin = min(task_id * step, op->extent);
PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent);
CreateSerialFor(MakeValue(begin), MakeValue(end),
llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body);
llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body,
op->test);
}
parallel_env_.in_parallel_loop = false;
++parallel_env_.parallel_loop_count;
Expand Down
15 changes: 11 additions & 4 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
}

void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride,
const Var& loop_var, const Stmt& body) {
const Var& loop_var, const Stmt& body, Optional<PrimExpr> test) {
using llvm::BasicBlock;
BasicBlock* pre_block = builder_->GetInsertBlock();
BasicBlock* for_begin = BasicBlock::Create(*ctx_, "for_begin", function_);
Expand All @@ -673,8 +673,14 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va
loop_value->addIncoming(begin, pre_block);
ICHECK(!var_map_.count(loop_var.get()));
var_map_[loop_var.get()] = loop_value;
builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), for_body, for_end,
md_very_likely_branch_);

llvm::Value* less_than = CreateLT(loop_var.dtype(), loop_value, end);
llvm::Value* cond = less_than;
if (test) {
cond = builder_->CreateAnd(less_than, MakeValue(test.value()));
}
builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);

builder_->SetInsertPoint(for_body);
this->VisitStmt(body);
var_map_.erase(loop_var.get());
Expand Down Expand Up @@ -1325,7 +1331,8 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
ICHECK(op->kind == ForKind::kSerial);
}
CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body);
llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body,
op->test);
}

void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
// Create serial for
void CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride,
const Var& loop_var, const Stmt& body);
const Var& loop_var, const Stmt& body, Optional<PrimExpr> test);
// add alias information.
void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index);
// The IRBuilder.
Expand Down
7 changes: 6 additions & 1 deletion src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,12 @@ void CodeGenC::VisitStmt_(const ForNode* op) {
ICHECK(is_zero(op->min));
stream << "for (";
PrintType(op->loop_var.dtype(), stream);
stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n";
stream << ' ' << vid << " = 0; " << vid << " < " << extent;
if (op->test) {
std::string test = PrintExpr(op->test.value());
stream << " && (" << test << ")";
}
stream << "; ++" << vid << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
Expand Down
6 changes: 3 additions & 3 deletions src/te/operation/hybrid_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map<IterVar, Range
rmap[op->loop_var.get()] = indexdiv(parent, extent);
body = tir::Substitute(body, rmap);
under_outer = false;
return For(parent->var, PrimExpr(0), extent * op->extent, op->kind, body,
return For(parent->var, PrimExpr(0), extent * op->extent, op->kind, body, op->test,
op->thread_binding, op->annotations);
} else if (under_outer) {
Stmt body = this->VisitStmt(op->body);
Expand Down Expand Up @@ -332,7 +332,7 @@ Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map<IterVar,
return AttrStmt(iter_var, "thread_extent", op->extent, body);
} else {
return For(op->loop_var, op->min, op->extent, IterVarTypeToForKind(attr->iter_type),
op->body, op->thread_binding, op->annotations);
op->body, op->test, op->thread_binding, op->annotations);
}
}
return StmtMutator::VisitStmt_(op);
Expand Down Expand Up @@ -414,7 +414,7 @@ Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map<IterVar, Range>
kind = IterVarTypeToForKind(stage->iter_var_attrs[target]->iter_type);
}
const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
return For(target->var, range->min, range->extent, kind, body, op->thread_binding,
return For(target->var, range->min, range->extent, kind, body, op->test, op->thread_binding,
op->annotations);
}
};
Expand Down
4 changes: 2 additions & 2 deletions src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -968,8 +968,8 @@ class TensorCoreIRMutator : public StmtExprMutator {
scaled_extent_value = ori_extent_value / scale_factor;
}
PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value);
stmt = For(op->loop_var, op->min, scaled_extent, op->kind, op->body, op->thread_binding,
op->annotations);
stmt = For(op->loop_var, op->min, scaled_extent, op->kind, op->body, op->test,
op->thread_binding, op->annotations);
}
}
return stmt;
Expand Down
8 changes: 5 additions & 3 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

// For
For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
Optional<IterVar> thread_binding, Map<String, ObjectRef> annotations, Span span) {
Optional<PrimExpr> test, Optional<IterVar> thread_binding,
Map<String, ObjectRef> annotations, Span span) {
ICHECK(min.defined());
ICHECK(extent.defined());
ICHECK(min.dtype().is_scalar());
Expand All @@ -143,16 +144,17 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
node->extent = std::move(extent);
node->kind = kind;
node->body = std::move(body);
node->test = std::move(test);
node->thread_binding = std::move(thread_binding);
node->annotations = std::move(annotations);
node->span = std::move(span);
data_ = std::move(node);
}

TVM_REGISTER_GLOBAL("tir.For").set_body_typed(
[](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body,
[](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body, Optional<PrimExpr> test,
Optional<IterVar> thread_binding, Optional<Map<String, ObjectRef>> annotations, Span span) {
return For(loop_var, min, extent, static_cast<ForKind>(kind), body, thread_binding,
return For(loop_var, min, extent, static_cast<ForKind>(kind), body, test, thread_binding,
annotations.value_or(Map<String, ObjectRef>()), span);
});

Expand Down
11 changes: 11 additions & 0 deletions src/tir/ir/stmt_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ void StmtVisitor::VisitStmt_(const ForNode* op) {
this->VisitExpr(op->min);
this->VisitExpr(op->extent);
this->VisitStmt(op->body);
if (op->test) {
this->VisitExpr(op->test.value());
}
}

void StmtVisitor::VisitStmt_(const AllocateNode* op) {
Expand Down Expand Up @@ -168,13 +171,21 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) {
PrimExpr min = this->VisitExpr(op->min);
PrimExpr extent = this->VisitExpr(op->extent);
Stmt body = this->VisitStmt(op->body);
Optional<PrimExpr> test = NullOpt;
if (op->test) {
test = this->VisitExpr(op->test.value());
}

if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->min = std::move(min);
n->extent = std::move(extent);
n->body = std::move(body);
if (test) {
n->test = std::move(test);
}
return Stmt(n);
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/tir/transforms/hoist_if_then_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ class HoistCandidateSelector final : public StmtExprVisitor {
HoistCandidateSelector() { InitRecorder(); }

void VisitStmt_(const ForNode* op) final {
if (op->test) {
// Do not hoist if this is a while loop
return;
}
// If already recording complete,
// then stop tracing
if (RecordingComplete()) {
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class IRConvertSSA final : public StmtExprMutator {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back();
op = stmt.as<ForNode>();
return For(new_var, op->min, op->extent, op->kind, op->body, op->thread_binding,
return For(new_var, op->min, op->extent, op->kind, op->body, op->test, op->thread_binding,
op->annotations);
} else {
defined_.insert(v.get());
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/narrow_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class DataTypeRewriter : public StmtExprMutator {
PrimExpr e = VisitExpr(op->loop_var);
Var var = Downcast<Var>(e);
return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->kind, op->body,
op->thread_binding, op->annotations);
op->test, op->thread_binding, op->annotations);
}

Stmt VisitStmt_(const AttrStmtNode* op) final {
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ class StoragePlanRewriter : public StmtExprMutator {
auto& svec = attach_map_[op];
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body),
return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body), op->test,
op->thread_binding, op->annotations);
} else {
return StmtExprMutator::VisitStmt_(op);
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/unroll_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class LoopUnroller : public StmtExprMutator {
} else {
if (auto_unroll) {
if (op->kind != ForKind::kUnrolled) {
return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body,
return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body, op->test,
op->thread_binding, op->annotations);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
return For(op->loop_var, op->min, extent, op->kind, body, op->thread_binding,
return For(op->loop_var, op->min, extent, op->kind, body, op->test, op->thread_binding,
op->annotations);
}
}
Expand Down
Loading