Skip to content

Commit

Permalink
[TIR] add loop partition hint pragma (apache#9121)
Browse files Browse the repository at this point in the history
* add loop partition hint pragma

* fix uninitialized var

* fix to remove hint at last

* use tir compare for loop partition testcase
  • Loading branch information
wrongtest-intellif authored and ylc committed Jan 7, 2022
1 parent bb37a8c commit 8260e95
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 27 deletions.
6 changes: 6 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1339,6 +1339,12 @@ constexpr const char* hand_threaded = "hand_threaded";
* if (mask & 2) the write region should be detected.
*/
constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access";

/*!
* \brief Mark that the loop should be partitioned.
*
* The node of the attribute is expected to be the loop's Var or IterVar.
* LoopPartition always treats a loop variable carrying this hint as a
* partition candidate (for loops with constant bounds only when
* partition_const_loop is set) and strips the attribute when the pass is done.
*/
constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
106 changes: 80 additions & 26 deletions src/tir/transforms/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,13 @@ class CandidateSelector final : public StmtExprVisitor {
void VisitStmt_(const ForNode* op) final {
// partition const loop when sets partition_const_loop_
if (!is_const_int(op->min) || !is_const_int(op->extent) || partition_const_loop_) {
// always treat var with hint to be partitioned
const VarNode* var = op->loop_var.get();
if (partition_hint_vars.count(var)) {
candidates.insert(GetRef<Stmt>(op));
StmtExprVisitor::VisitStmt_(op);
return;
}
record_.insert({var, false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var) && !no_split_) {
Expand All @@ -117,6 +123,12 @@ class CandidateSelector final : public StmtExprVisitor {
Var var = iv->var;
runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
if ((scope.rank == 0) && (!is_const_int(op->value) || partition_const_loop_)) {
// always treat var with hint to be partitioned
if (partition_hint_vars.count(var.get())) {
candidates.insert(GetRef<Stmt>(op));
StmtExprVisitor::VisitStmt_(op);
return;
}
record_.insert({var.get(), false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var.get()) && !no_split_) {
Expand All @@ -125,6 +137,15 @@ class CandidateSelector final : public StmtExprVisitor {
record_.erase(var.get());
return;
}
} else if (op->attr_key == attr::pragma_loop_partition_hint) {
const VarNode* var = nullptr;
if (op->node->IsInstance<VarNode>()) {
var = op->node.as<VarNode>();
} else if (op->node->IsInstance<IterVarNode>()) {
var = op->node.as<IterVarNode>()->var.get();
}
ICHECK(var);
partition_hint_vars.insert(var);
}
StmtExprVisitor::VisitStmt_(op);
}
Expand Down Expand Up @@ -162,6 +183,7 @@ class CandidateSelector final : public StmtExprVisitor {
}

std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> candidates;
std::unordered_set<const VarNode*> partition_hint_vars;

private:
bool in_likely_{false};
Expand All @@ -170,15 +192,28 @@ class CandidateSelector final : public StmtExprVisitor {
std::unordered_map<const VarNode*, VarIsUsed> record_;
};

// Generates a PartitionFinder::VisitExpr_ override for the comparison node type
// OpNodeT. When the loop variable being partitioned carries a partition hint,
// the whole comparison is handed to DeduceCondition and its operands are not
// visited; otherwise the default StmtExprVisitor traversal is used.
#define DEFINE_PARTITION_FINDER_VISIT_CMP_OP(OpNodeT) \
  void VisitExpr_(const OpNodeT* op) final {          \
    if (!has_partition_hint_) {                       \
      StmtExprVisitor::VisitExpr_(op);                \
    } else {                                          \
      DeduceCondition(GetRef<PrimExpr>(op));          \
    }                                                 \
  }

// Populate partitions data structure, i.e., for a specific variable,
// find an interval in which each condition
// (currently, "likely" conditions) has fixed true or false value
// find an interval in which each condition has fixed true or false value
class PartitionFinder : public StmtExprVisitor {
public:
explicit PartitionFinder(Var current_var,
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map)
: current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) {
const std::unordered_map<const VarNode*, IntSet>& relax_map,
bool has_partition_hint)
: current_var_(current_var),
has_partition_hint_(has_partition_hint),
hint_map_(hint_map),
relax_map_(relax_map) {
for (const auto& kv : hint_map) {
out_vars_.insert(kv.first);
}
Expand Down Expand Up @@ -218,33 +253,43 @@ class PartitionFinder : public StmtExprVisitor {

// Visits a call expression. For a likely(cond) intrinsic, deduces for
// current_var_ the intervals over which cond is provably true / provably false
// and records them in `partitions`. Every other call gets the default
// StmtExprVisitor traversal.
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::likely())) {
PrimExpr cond = op->args[0];
if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) {
// For cond, find out the interval, if exists, in which we can prove that cond is
// true. Also find the interval, if exists, in which we can prove that cond is
// false.
IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is true within interval
partitions[{cond, true}] = interval;
}
PrimExpr inverse_cond = InverseCond(cond);
if (inverse_cond.defined()) {
IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is false within interval
partitions[{cond, false}] = interval;
}
}
}
DeduceCondition(op->args[0]);
} else {
StmtExprVisitor::VisitExpr_(op);
}
}

DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GENode);
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GTNode);
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LENode);
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LTNode);
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(EQNode);
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(NENode);

Partition partitions;

private:
void DeduceCondition(const PrimExpr& cond) {
// For cond, find out the interval, if exists, in which we can prove that cond is
// true. Also find the interval, if exists, in which we can prove that cond is
// false.
if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) {
IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is true within interval
partitions[{cond, true}] = interval;
}
PrimExpr inverse_cond = InverseCond(cond);
if (inverse_cond.defined()) {
IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is false within interval
partitions[{cond, false}] = interval;
}
}
}
}

PrimExpr InverseCond(const PrimExpr& cond) {
PrimExpr inverse_cond;
if (const LTNode* op = cond.as<LTNode>()) {
Expand All @@ -270,6 +315,7 @@ class PartitionFinder : public StmtExprVisitor {
}

Var current_var_;
bool has_partition_hint_;
std::unordered_set<const VarNode*> out_vars_;
std::unordered_map<const VarNode*, IntSet> hint_map_;
std::unordered_map<const VarNode*, IntSet> relax_map_;
Expand Down Expand Up @@ -472,7 +518,8 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
// include hint of var.
hint_map_.insert({var.get(), IntSet::Interval(min, max)});

PartitionFinder finder(var, hint_map_, relax_map_);
bool has_partition_hint_ = selector.partition_hint_vars.count(var.get());
PartitionFinder finder(var, hint_map_, relax_map_, has_partition_hint_);
finder(body);

hint_map_.erase(var.get());
Expand Down Expand Up @@ -601,7 +648,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b
}
}

class RemoveLikelyTags : public StmtExprMutator {
class RemoveLikelyTagsAndHints : public StmtExprMutator {
public:
PrimExpr VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::likely())) {
Expand All @@ -611,12 +658,19 @@ class RemoveLikelyTags : public StmtExprMutator {
return StmtExprMutator::VisitExpr_(op);
}
}

// Drops loop-partition-hint attributes: an AttrStmt keyed by
// attr::pragma_loop_partition_hint is replaced by its mutated body. Every
// other AttrStmt goes through the default StmtExprMutator handling.
Stmt VisitStmt_(const AttrStmtNode* op) final {
  const bool is_partition_hint = (op->attr_key == attr::pragma_loop_partition_hint);
  return is_partition_hint ? VisitStmt(op->body) : StmtExprMutator::VisitStmt_(op);
}
};

// Runs loop partitioning on `stmt` and returns the rewritten statement.
// LoopPartitioner is configured with the two flags and mutates the statement;
// afterwards the RemoveLikelyTags / RemoveLikelyTagsAndHints cleanup mutator is
// applied, the latter of which drops pragma_loop_partition_hint attributes.
Stmt LoopPartition(Stmt stmt, bool partition_const_loop, bool no_unroll_loop_with_extent_one) {
stmt = LoopPartitioner(partition_const_loop, no_unroll_loop_with_extent_one)
.VisitAndMutate(std::move(stmt));
stmt = RemoveLikelyTags()(std::move(stmt));
stmt = RemoveLikelyTagsAndHints()(std::move(stmt));
return stmt;
}

Expand Down
31 changes: 30 additions & 1 deletion tests/python/unittest/test_tir_transform_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import tvm
import tvm.testing
from tvm import te
from tvm import tir
from tvm.script import ty
import numpy


Expand Down Expand Up @@ -434,7 +436,6 @@ def test_conv_tiling():
oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
mod = tvm.tir.transform.LoopPartition()(mod)
Expand Down Expand Up @@ -538,6 +539,33 @@ def test_simple_rfactor():
assert not tvm.ir.structural_equal(stmt1.body, stmt2.body)


# Reference TIR that test_explicit_partition_hint compares against: two serial
# loops of 16 iterations each with no branch in the body. The first stores
# A.data[i] into C.data[i]; the second stores B.data[i + 16] into C.data[i + 16].
@tvm.script.tir
def partitioned_concat(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
A = tir.match_buffer(a, [16], dtype="float32")
B = tir.match_buffer(b, [16], dtype="float32")
C = tir.match_buffer(c, [32], dtype="float32")
for i in tir.serial(0, 16):
tir.store(C.data, i, tir.load("float32", A.data, i), True)
for i in tir.serial(0, 16):
tir.store(C.data, i + 16, tir.load("float32", B.data, i + 16), True)


# Checks that a loop tagged with the "loop_partition_hint" pragma is partitioned.
# A 32-element compute selects A[i] or B[i] with if_then_else(i < 16, ...); its
# axis gets the pragma, then StorageFlatten, LoopPartition (with
# partition_const_loop enabled) and Simplify are run, and the result must be
# structurally equal to partitioned_concat.
def test_explicit_partition_hint():
A = te.placeholder((16,), name="A")
B = te.placeholder((16,), name="B")
C = te.compute((32,), lambda i: te.if_then_else(i < 16, A[i], B[i]), name="C")
s = te.create_schedule(C.op)
s.normalize()
s[C].pragma(s[C].op.axis[0], "loop_partition_hint")
mod = tvm.driver.build_module.schedule_to_module(s, [A, B, C], "main", None)
with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.LoopPartition()(mod)
mod = tvm.tir.transform.Simplify()(mod)
assert tvm.ir.structural_equal(mod["main"], partitioned_concat)


if __name__ == "__main__":
test_basic()
test_const_loop()
Expand All @@ -559,3 +587,4 @@ def test_simple_rfactor():
test_double_splitting_with_indivisible_factors()
test_multilevel_splitting_with_indivisble_factors()
test_simple_rfactor()
test_explicit_partition_hint()

0 comments on commit 8260e95

Please sign in to comment.