Skip to content

Commit

Permalink
add explicit_unroll_max_extent (apache#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Jun 20, 2020
1 parent 674027f commit 2f241ed
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/tir/transforms/unroll_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode<UnrollLoopConfigNode> {
int auto_max_depth;
int auto_max_extent;
int explicit_unroll;
int explicit_unroll_max_extent;

TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") {
TVM_ATTR_FIELD(auto_max_step)
Expand All @@ -57,6 +58,9 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode<UnrollLoopConfigNode> {
TVM_ATTR_FIELD(explicit_unroll)
.describe("Whether to explicitly unroll the loop instead of setting a pragma")
.set_default(true);
TVM_ATTR_FIELD(explicit_unroll_max_extent)
.describe("The maximum extent of a loop that can be unrolled explicitly (-1 means infinite)")
.set_default(32);
}
};

Expand All @@ -71,11 +75,12 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);
class LoopUnroller : public StmtExprMutator {
public:
explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent,
bool explicit_unroll)
bool explicit_unroll, int explicit_unroll_max_extent)
: auto_max_step_(auto_max_step),
auto_max_depth_(auto_max_depth),
auto_max_extent_(auto_max_extent),
explicit_unroll_(explicit_unroll) {}
explicit_unroll_(explicit_unroll),
explicit_unroll_max_extent_(explicit_unroll_max_extent) {}

Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == "pragma_auto_unroll_max_step") {
Expand Down Expand Up @@ -165,6 +170,11 @@ class LoopUnroller : public StmtExprMutator {
// For loop must have a constant integer extent
CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
if (value == 0) return Evaluate(0);
if (explicit_unroll_max_extent_ > 0 && value > explicit_unroll_max_extent_ && explicit_unroll_) {
// Do not unroll too long loops
ForType for_type = op->for_type == ForType::Unrolled ? ForType::Serial : op->for_type;
return ForNode::make(op->loop_var, op->min, op->extent, for_type, op->device_api, op->body);
}
Stmt body = op->body;
Map<Var, PrimExpr> vmap;
Array<Stmt> unrolled;
Expand Down Expand Up @@ -197,7 +207,10 @@ class LoopUnroller : public StmtExprMutator {
// max extent of loop to auto unroll
// this not not count the total steps, only count the number of loops
int auto_max_extent_;
// Whether to explicitly unroll the loop instead of setting a pragma
bool explicit_unroll_;
// The maximum extent of a loop that can be unrolled explicitly (-1 means infinite)
int explicit_unroll_max_extent_;
// Number of normal loops in scope
int normal_loop_depth_{0};
// number of unrolled cases in current scope.
Expand All @@ -210,7 +223,7 @@ class LoopUnroller : public StmtExprMutator {

Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) {
Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent,
cfg->explicit_unroll)(stmt);
cfg->explicit_unroll, cfg->explicit_unroll_max_extent)(stmt);
if (!ret.same_as(stmt)) {
return ConvertSSA(ret);
} else {
Expand Down
24 changes: 24 additions & 0 deletions tests/python/unittest/test_tir_transform_unroll_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,31 @@ def test_unroll_single_count_loops():
ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
assert ret == stmt

def test_unroll_explicitly_max_extent():
n = 64
A = te.placeholder((n,), name='A')
B = te.compute((n,), lambda *i: A(*i), name='B')
s = te.create_schedule(B.op)
s = s.normalize()
dom_map = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))

with tvm.transform.PassContext(config={
"tir.UnrollLoop": {"explicit_unroll_max_extent": n-1}
}):
ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
assert tvm.ir.structural_equal(ret, stmt)

with tvm.transform.PassContext(config={
"tir.UnrollLoop": {"explicit_unroll_max_extent": n}
}):
ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
assert not tvm.ir.structural_equal(ret, stmt)


if __name__ == "__main__":
test_unroll_loop()
test_unroll_fake_loop()
test_unroll_single_count_loops()
test_unroll_explicitly_max_extent()

0 comments on commit 2f241ed

Please sign in to comment.