diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index a15190665949..1c84304fb0e7 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -43,6 +43,7 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { int auto_max_depth; int auto_max_extent; int explicit_unroll; + int explicit_unroll_max_extent; TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") { TVM_ATTR_FIELD(auto_max_step) @@ -57,6 +58,9 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { TVM_ATTR_FIELD(explicit_unroll) .describe("Whether to explicitly unroll the loop instead of setting a pragma") .set_default(true); + TVM_ATTR_FIELD(explicit_unroll_max_extent) + .describe("The maximum extent of a loop that can be unrolled explicitly (-1 means infinite)") + .set_default(32); } }; @@ -71,11 +75,12 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig); class LoopUnroller : public StmtExprMutator { public: explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent, - bool explicit_unroll) + bool explicit_unroll, int explicit_unroll_max_extent) : auto_max_step_(auto_max_step), auto_max_depth_(auto_max_depth), auto_max_extent_(auto_max_extent), - explicit_unroll_(explicit_unroll) {} + explicit_unroll_(explicit_unroll), + explicit_unroll_max_extent_(explicit_unroll_max_extent) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_auto_unroll_max_step") { @@ -165,6 +170,11 @@ class LoopUnroller : public StmtExprMutator { // For loop must have a constant integer extent CHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; if (value == 0) return Evaluate(0); + if (explicit_unroll_max_extent_ > 0 && value > explicit_unroll_max_extent_ && explicit_unroll_) { + // Do not unroll too long loops + ForType for_type = op->for_type == ForType::Unrolled ? ForType::Serial : op->for_type; + return ForNode::make(op->loop_var, op->min, op->extent, for_type, op->device_api, op->body); + } Stmt body = op->body; Map vmap; Array unrolled; @@ -197,7 +207,10 @@ class LoopUnroller : public StmtExprMutator { // max extent of loop to auto unroll // this not not count the total steps, only count the number of loops int auto_max_extent_; + // Whether to explicitly unroll the loop instead of setting a pragma bool explicit_unroll_; + // The maximum extent of a loop that can be unrolled explicitly (-1 means infinite) + int explicit_unroll_max_extent_; // Number of normal loops in scope int normal_loop_depth_{0}; // number of unrolled cases in current scope. @@ -210,7 +223,7 @@ class LoopUnroller : public StmtExprMutator { Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) { Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent, - cfg->explicit_unroll)(stmt); + cfg->explicit_unroll, cfg->explicit_unroll_max_extent)(stmt); if (!ret.same_as(stmt)) { return ConvertSSA(ret); } else { diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py index 68639940bb05..12c686634548 100644 --- a/tests/python/unittest/test_tir_transform_unroll_loop.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -110,7 +110,31 @@ def test_unroll_single_count_loops(): ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body assert ret == stmt +def test_unroll_explicitly_max_extent(): + n = 64 + A = te.placeholder((n,), name='A') + B = te.compute((n,), lambda *i: A(*i), name='B') + s = te.create_schedule(B.op) + s = s.normalize() + dom_map = tvm.te.schedule.InferBound(s) + stmt = tvm.te.schedule.ScheduleOps(s, dom_map) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + + with tvm.transform.PassContext(config={ + "tir.UnrollLoop": {"explicit_unroll_max_extent": n-1} + }): + ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body + assert tvm.ir.structural_equal(ret, stmt) + + with tvm.transform.PassContext(config={ + "tir.UnrollLoop": {"explicit_unroll_max_extent": n} + }): + ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body + assert not tvm.ir.structural_equal(ret, stmt) + + if __name__ == "__main__": test_unroll_loop() test_unroll_fake_loop() test_unroll_single_count_loops() + test_unroll_explicitly_max_extent()