Skip to content

Commit

Permalink
Support condition bound awareness in compact buffer and get block acc…
Browse files Browse the repository at this point in the history
…ess region
  • Loading branch information
wrongtest-intellif committed Oct 26, 2021
1 parent 75a8fa1 commit 74a1425
Show file tree
Hide file tree
Showing 13 changed files with 521 additions and 109 deletions.
10 changes: 9 additions & 1 deletion include/tvm/arith/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,20 @@ IntSet UnionLowerBound(const Array<IntSet>& sets);
Array<IntSet> UnionRegionLowerBound(const Array<Array<IntSet>>& nd_int_sets);

/*!
* \brief Create an union set of all sets
* \brief Create an intersected set of all sets
* \param sets The sets to be intersected
* \return the set after intersected
*/
IntSet Intersect(const Array<IntSet>& sets);

/*!
* \brief Create a difference set of two sets, possibly relaxed
* \param a The first set.
* \param b The second set.
* \return The result set.
*/
IntSet Difference(const IntSet& a, const IntSet& b);

/*!
* \brief Analyze the region with affine map, given the domain of variables and their predicate
* \param region The region to be analyzed
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/arith/int_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,25 @@ def single_point(point):
The result set.
"""
return _ffi_api.intset_single_point(point)

@staticmethod
def difference(a, b):
"""Create a difference set of two sets, possibly relaxed
Parameters
----------
a : IntSet
The lhs oprand of set difference.
b : IntSet
The rhs oprand of set difference.
Returns
----------
result : IntSet
The result set, possibly relaxed.
"""
return _ffi_api.IntSetDifference(a, b)


@tvm._ffi.register_object("arith.IntervalSet")
Expand Down
24 changes: 24 additions & 0 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,25 @@ IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
}

IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
if (a->IsEmpty()) return b;
if (b->IsEmpty()) return a;
PrimExpr max_value = max(a->max_value, b->max_value);
PrimExpr min_value = min(a->min_value, b->min_value);
return IntervalSet(min_value, max_value);
}

IntervalSet Difference(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
PrimExpr upper_min = (b->max_value->dtype.is_int() || b->max_value->dtype.is_uint())
? b->max_value + 1
: b->max_value;
IntervalSet upper = Intersect(analyzer, a, IntervalSet(upper_min, pos_inf()));
PrimExpr lower_max = (b->min_value->dtype.is_int() || b->min_value->dtype.is_uint())
? b->min_value - 1
: b->min_value;
IntervalSet lower = Intersect(analyzer, a, IntervalSet(neg_inf(), lower_max));
return Union(analyzer, lower, upper);
}

// type traits
template <typename OP>
struct is_logical_op {
Expand Down Expand Up @@ -725,6 +739,13 @@ IntSet Intersect(const Array<IntSet>& sets) {
return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value));
}

IntSet Difference(const IntSet& a, const IntSet& b) {
IntervalSet interval_a = ToIntervalSet(a);
IntervalSet interval_b = ToIntervalSet(b);
Analyzer ana;
return Difference(&ana, interval_a, interval_b);
}

Map<Var, IntSet> ConvertDomMap(const Map<IterVar, IntSet>& dom_map) {
Map<Var, IntSet> dmap;
for (auto kv : dom_map) {
Expand Down Expand Up @@ -898,6 +919,9 @@ TVM_REGISTER_GLOBAL("arith.EstimateRegionLowerBound")
TVM_REGISTER_GLOBAL("arith.PosInf").set_body_typed([]() { return SymbolicLimits::pos_inf_; });
TVM_REGISTER_GLOBAL("arith.NegInf").set_body_typed([]() { return SymbolicLimits::neg_inf_; });
TVM_REGISTER_GLOBAL("arith.UnionLowerBound").set_body_typed(UnionLowerBound);
TVM_REGISTER_GLOBAL("arith.IntSetDifference").set_body_typed([](const IntSet& a, const IntSet& b) {
return Difference(a, b);
});

} // namespace arith
} // namespace tvm
10 changes: 10 additions & 0 deletions src/arith/interval_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,16 @@ TVM_DLL IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b);
*/
TVM_DLL IntervalSet Intersect(Analyzer* analzyer, IntervalSet a, IntervalSet b);

/*!
* \brief Create difference of two IntervalSets, which is the minimal interval
* set covering all of the integers belong to a but not belong to b.
* \param analzyer The analyzer for simplification analysis.
* \param a The first set.
* \param b The second set.
* \return The result set.
*/
TVM_DLL IntervalSet Difference(Analyzer* analzyer, IntervalSet a, IntervalSet b);

} // namespace arith
} // namespace tvm

Expand Down
36 changes: 36 additions & 0 deletions src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,14 @@ class BlockReadWriteDetector : public StmtExprVisitor {
void UpdateOpaque(const Var& buffer_var);

void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const BlockRealizeNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitExpr_(const BufferLoadNode* op) override;
void VisitExpr_(const LoadNode* op) override;
void VisitExpr_(const VarNode* op) override;
void VisitExpr_(const CallNode* op) override;
};

void BlockReadWriteDetector::operator()(const Stmt& stmt) {
Expand Down Expand Up @@ -154,6 +156,40 @@ void BlockReadWriteDetector::VisitStmt_(const ForNode* op) {
dom_map_.erase(op->loop_var.get());
}

void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
VisitExpr(op->condition);
Map<Var, Range> bounds = GetVarBoundsFromCondition(op->condition, dom_map_);
{
// Visit then branch
With<ConditionalBoundsContext> ctx(&dom_map_, bounds, true);
StmtExprVisitor::VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
// Visit else branch
With<ConditionalBoundsContext> ctx(&dom_map_, bounds, false);
StmtExprVisitor::VisitStmt(op->else_case);
}
}

void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
if (op->op.same_as(builtin::if_then_else())) {
VisitExpr(op->args[0]);
Map<Var, Range> bounds = GetVarBoundsFromCondition(op->args[0], dom_map_);
{
// Visit then branch
With<ConditionalBoundsContext> ctx(&dom_map_, bounds, true);
StmtExprVisitor::VisitExpr(op->args[1]);
}
{
// Visit else branch
With<ConditionalBoundsContext> ctx(&dom_map_, bounds, false);
StmtExprVisitor::VisitExpr(op->args[2]);
}
return;
}
StmtExprVisitor::VisitExpr_(op);
}

void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
UpdateOpaque(op->buffer_var);
StmtVisitor::VisitStmt_(op);
Expand Down
Loading

0 comments on commit 74a1425

Please sign in to comment.