Skip to content

Commit

Permalink
[TIR] Make loop unrolling in LoopPartition optional (apache#6823)
Browse files Browse the repository at this point in the history
* [TIR] Make loop unrolling in LoopPartition optional

For certain analysis/tensorization, it can be useful
to keep the loop structure when partitioning loops.
The current behaviour removes For loops of length 1.
This change introduces the option to preserve these
loops with the 'unroll' flag.
  • Loading branch information
mbaret authored and trevor-m committed Dec 4, 2020
1 parent ca940fb commit 680f1dc
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
21 changes: 15 additions & 6 deletions src/tir/transforms/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ namespace tir {

struct LoopPartitionConfigNode : public tvm::AttrsNode<LoopPartitionConfigNode> {
bool partition_const_loop;
bool no_unroll_loop_with_extent_one;

TVM_DECLARE_ATTRS(LoopPartitionConfigNode, "tir.transform.LoopPartitionConfig") {
TVM_ATTR_FIELD(partition_const_loop).describe("Split constant loop").set_default(false);
TVM_ATTR_FIELD(no_unroll_loop_with_extent_one)
.describe("Don't unroll loops with extent 1")
.set_default(false);
}
};

Expand Down Expand Up @@ -334,8 +338,9 @@ class ThreadPartitionInserter : public StmtMutator {
// likely conditions
class LoopPartitioner : public StmtMutator {
public:
explicit LoopPartitioner(bool partition_const_loop)
: selector(CandidateSelector(partition_const_loop)) {}
explicit LoopPartitioner(bool partition_const_loop, bool no_unroll_loop_with_extent_one)
: selector(CandidateSelector(partition_const_loop)),
no_unroll_loop_with_extent_one_(no_unroll_loop_with_extent_one) {}

Stmt VisitAndMutate(Stmt stmt) {
selector(stmt);
Expand Down Expand Up @@ -402,6 +407,7 @@ class LoopPartitioner : public StmtMutator {
std::unordered_map<const VarNode*, IntSet> relax_map_;
arith::Analyzer analyzer_;
CandidateSelector selector;
bool no_unroll_loop_with_extent_one_;
};

// Returns an interval (in the first component) in which all the conditions
Expand Down Expand Up @@ -596,7 +602,8 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt body) {
const ForNode* for_node = static_cast<const ForNode*>(node);
ICHECK(for_node);
if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) {
if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1)) &&
!no_unroll_loop_with_extent_one_) {
// If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
} else {
Expand All @@ -617,8 +624,9 @@ class RemoveLikelyTags : public StmtExprMutator {
}
};

Stmt LoopPartition(Stmt stmt, bool partition_const_loop) {
stmt = LoopPartitioner(partition_const_loop).VisitAndMutate(std::move(stmt));
Stmt LoopPartition(Stmt stmt, bool partition_const_loop, bool no_unroll_loop_with_extent_one) {
stmt = LoopPartitioner(partition_const_loop, no_unroll_loop_with_extent_one)
.VisitAndMutate(std::move(stmt));
stmt = RemoveLikelyTags()(std::move(stmt));
return stmt;
}
Expand All @@ -632,7 +640,8 @@ Pass LoopPartition() {
if (!cfg.defined()) {
cfg = AttrsWithDefaultValues<LoopPartitionConfig>();
}
n->body = LoopPartition(std::move(n->body), cfg.value()->partition_const_loop);
n->body = LoopPartition(std::move(n->body), cfg.value()->partition_const_loop,
cfg.value()->no_unroll_loop_with_extent_one);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {});
Expand Down
28 changes: 28 additions & 0 deletions tests/python/unittest/test_tir_transform_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,34 @@ def test_const_loop():
assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))


def test_no_unroll_loop():
n = 21
A = te.placeholder((n,), name="A")
B = te.placeholder((n,), name="B")

T = te.compute((n,), lambda i: A[i] + B[i])
s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)

bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
with tvm.transform.PassContext(
config={
"tir.LoopPartition": {
"partition_const_loop": True,
"no_unroll_loop_with_extent_one": True,
}
}
):
mod = tvm.tir.transform.LoopPartition()(mod)
mod = tvm.tir.transform.Simplify()(mod)
stmt = tvm.tir.transform.RemoveNoOp()(mod)["main"].body

assert sum(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.For))) == 4


def test_multi_loop():
ib = tvm.tir.ir_builder.create()
m = te.size_var("m")
Expand Down

0 comments on commit 680f1dc

Please sign in to comment.