[TIR] Make compact buffer and get access region aware of conditions (#9372)

* Support condition bound awareness in compact buffer and get block access region

* remove intset difference usage

* fix to visit match buffer's access region

* change method to distinguish annotated opaque access regions
wrongtest-intellif authored Nov 8, 2021
1 parent a644e29 commit 811312c
Show file tree
Hide file tree
Showing 9 changed files with 486 additions and 102 deletions.
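Before the per-file diffs, here is a minimal sketch of the idea in plain C++ (a toy interval type, not the TVM arith API): inside the then-branch of `if (i < 10)`, the loop domain [0, 16) can be intersected with the bound implied by the condition, so a buffer accessed as B[i] in that branch needs only 10 elements instead of 16.

    #include <algorithm>
    #include <climits>
    #include <cstdio>

    struct Interval { int lo, hi; };  // inclusive bounds

    Interval Intersect(Interval a, Interval b) {
      return {std::max(a.lo, b.lo), std::min(a.hi, b.hi)};
    }

    int main() {
      Interval loop_dom = {0, 15};         // for (i = 0; i < 16; ++i)
      Interval then_bound = {INT_MIN, 9};  // implied by the condition i < 10
      Interval access = Intersect(loop_dom, then_bound);
      // B[i] in the then-branch touches only [0, 9]: allocate 10 elements.
      std::printf("compacted extent = %d\n", access.hi - access.lo + 1);
      return 0;
    }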
2 changes: 1 addition & 1 deletion include/tvm/arith/int_set.h
@@ -249,7 +249,7 @@ IntSet UnionLowerBound(const Array<IntSet>& sets);
Array<IntSet> UnionRegionLowerBound(const Array<Array<IntSet>>& nd_int_sets);

/*!
- * \brief Create an union set of all sets
+ * \brief Create an intersected set of all sets
* \param sets The sets to be intersected
* \return the set after intersected
*/
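The corrected \brief now matches the behavior: an n-ary intersection. A toy C++ illustration of folding pairwise intersection over a list of sets (illustrative only, not the TVM implementation):

    #include <algorithm>
    #include <climits>
    #include <cstdio>
    #include <vector>

    struct Interval { int lo, hi; };  // inclusive; lo > hi means empty

    // The intersection's lower bound is the max of all lower bounds and its
    // upper bound the min of all upper bounds.
    Interval Intersect(const std::vector<Interval>& sets) {
      Interval r{INT_MIN, INT_MAX};
      for (const Interval& s : sets) {
        r.lo = std::max(r.lo, s.lo);
        r.hi = std::min(r.hi, s.hi);
      }
      return r;
    }

    int main() {
      Interval r = Intersect({{0, 15}, {4, 20}, {-3, 9}});
      std::printf("[%d, %d]\n", r.lo, r.hi);  // [4, 9]
      return 0;
    }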
2 changes: 2 additions & 0 deletions src/arith/int_set.cc
@@ -71,6 +71,8 @@ IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
}

IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
+  if (a->IsEmpty()) return b;
+  if (b->IsEmpty()) return a;
PrimExpr max_value = max(a->max_value, b->max_value);
PrimExpr min_value = min(a->min_value, b->min_value);
return IntervalSet(min_value, max_value);
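The two added early returns keep an empty operand from polluting the min/max bounds. A minimal sketch of the same guard, with std::optional standing in for IsEmpty() (toy code, not TVM's IntervalSet):

    #include <cstdio>
    #include <optional>

    struct Interval { int lo, hi; };
    using MaybeInterval = std::optional<Interval>;  // nullopt plays the role of IsEmpty()

    MaybeInterval Union(MaybeInterval a, MaybeInterval b) {
      if (!a) return b;  // mirrors: if (a->IsEmpty()) return b;
      if (!b) return a;  // mirrors: if (b->IsEmpty()) return a;
      return Interval{a->lo < b->lo ? a->lo : b->lo, a->hi > b->hi ? a->hi : b->hi};
    }

    int main() {
      MaybeInterval r = Union(std::nullopt, Interval{2, 5});
      std::printf("[%d, %d]\n", r->lo, r->hi);  // [2, 5], unaffected by the empty operand
      return 0;
    }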
34 changes: 34 additions & 0 deletions src/tir/analysis/block_access_region_detector.cc
@@ -97,12 +97,14 @@ class BlockReadWriteDetector : public StmtExprVisitor {
void UpdateOpaque(const Var& buffer_var);

void VisitStmt_(const ForNode* op) override;
+  void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const BlockRealizeNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitExpr_(const BufferLoadNode* op) override;
void VisitExpr_(const LoadNode* op) override;
void VisitExpr_(const VarNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
};

void BlockReadWriteDetector::operator()(const Stmt& stmt) {
@@ -154,6 +156,38 @@ void BlockReadWriteDetector::VisitStmt_(const ForNode* op) {
dom_map_.erase(op->loop_var.get());
}

+void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
+  VisitExpr(op->condition);
+  {
+    // Visit then branch
+    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true);
+    StmtExprVisitor::VisitStmt(op->then_case);
+  }
+  if (op->else_case.defined()) {
+    // Visit else branch
+    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false);
+    StmtExprVisitor::VisitStmt(op->else_case);
+  }
+}
+
+void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
+  if (op->op.same_as(builtin::if_then_else())) {
+    VisitExpr(op->args[0]);
+    {
+      // Visit then branch
+      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true);
+      StmtExprVisitor::VisitExpr(op->args[1]);
+    }
+    {
+      // Visit else branch
+      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false);
+      StmtExprVisitor::VisitExpr(op->args[2]);
+    }
+    return;
+  }
+  StmtExprVisitor::VisitExpr_(op);
+}

void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
UpdateOpaque(op->buffer_var);
StmtVisitor::VisitStmt_(op);
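Both new visitors rely on With<ConditionalBoundsContext> to narrow dom_map_ while a branch is visited and to restore it afterwards. A self-contained sketch of that RAII pattern (toy code; the real ConditionalBoundsContext additionally derives the narrowed bounds from the condition expression itself):

    #include <cassert>
    #include <string>
    #include <unordered_map>

    struct Interval { int lo, hi; };
    using DomMap = std::unordered_map<std::string, Interval>;

    class ScopedBound {
     public:
      // Narrow `var` to `narrowed` for the lifetime of this object.
      ScopedBound(DomMap* dom, std::string var, Interval narrowed)
          : dom_(dom), var_(std::move(var)), saved_((*dom)[var_]) {
        (*dom_)[var_] = narrowed;
      }
      ~ScopedBound() { (*dom_)[var_] = saved_; }  // restore on scope exit

     private:
      DomMap* dom_;
      std::string var_;
      Interval saved_;
    };

    int main() {
      DomMap dom = {{"i", {0, 15}}};
      {
        // Then-branch of `if (i < 10)`: i is known to lie in [0, 9] here.
        ScopedBound ctx(&dom, "i", {0, 9});
        assert(dom.at("i").hi == 9);
      }
      assert(dom.at("i").hi == 15);  // bounds restored after the branch
      return 0;
    }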
232 changes: 137 additions & 95 deletions src/tir/transforms/compact_buffer_region.cc
@@ -23,6 +23,7 @@
*/

#include <tvm/arith/int_set.h>
+#include <tvm/arith/int_solver.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
@@ -41,16 +42,19 @@ namespace tir {
using support::NDIntSet;

 /*!
- * \brief return the region collected by NDIntSet. return the oroginal buffer shape if the
- * int_set is empty.
+ * \brief Simplify and return the region collected by NDIntSet. Return the original
+ * buffer shape if the int_set is empty.
  */
-Region NarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set,
-                                      const Array<PrimExpr>& original_shape) {
+Region SimplifyAndNarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set,
+                                                 const Array<PrimExpr>& original_shape,
+                                                 arith::Analyzer* analyzer) {
   Array<Range> result;
   result.reserve(nd_int_set.size());
   for (size_t i = 0; i < nd_int_set.size(); ++i) {
     const arith::IntSet& int_set = nd_int_set[i];
-    result.push_back(int_set.CoverRange(Range(/*begin=*/0, /*end=*/original_shape[i])));
+    Range range = int_set.CoverRange(Range(/*begin=*/0, /*end=*/original_shape[i]));
+    result.push_back(
+        Range::FromMinExtent(analyzer->Simplify(range->min), analyzer->Simplify(range->extent)));
   }
   return result;
 }
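CoverRange clips each per-dimension set to the buffer's natural bounds before the new Simplify calls tidy the min and extent. A toy version of the clipping step (illustrative only, not the TVM API):

    #include <algorithm>
    #include <cstdio>

    struct MinExtent { int min, extent; };

    // Clip an inferred inclusive interval [lo, hi] to the buffer bounds [0, dim)
    // and express the result as (min, extent), like Range::FromMinExtent.
    MinExtent CoverRange(int lo, int hi, int dim) {
      int clipped_lo = std::max(lo, 0);
      int clipped_hi = std::min(hi, dim - 1);
      return {clipped_lo, clipped_hi - clipped_lo + 1};
    }

    int main() {
      // One dimension: inferred indices 3..9 of a buffer with shape 16.
      MinExtent r = CoverRange(3, 9, 16);
      std::printf("min=%d extent=%d\n", r.min, r.extent);  // min=3 extent=7
      return 0;
    }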
@@ -85,6 +89,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor {

   void VisitStmt_(const BufferStoreNode* op) final {
     VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
+    VisitExpr(op->value);
   }

void VisitExpr_(const BufferLoadNode* op) final {
@@ -105,58 +110,91 @@

   void VisitStmt_(const ForNode* op) final {
     ancestor_loops_.push_back(op);
+    Range loop_range = Range::FromMinExtent(op->min, op->extent);
+    dom_analyzer_.Bind(op->loop_var, loop_range);
+    dom_map_.emplace(op->loop_var.get(), arith::IntSet::FromRange(loop_range));
     StmtExprVisitor::VisitStmt_(op);
+    dom_map_.erase(op->loop_var.get());
     ancestor_loops_.pop_back();
-    // The iter_dom_map is updated by post DFS order.
-    // If the union point is under the for node, the loop var will not be relaxed.
-    // If the union point is outer of the for loop, the loop var should be relaxed.
-    iter_dom_map_on_post_order_[op->loop_var.get()] =
-        arith::IntSet::FromMinExtent(op->min, op->extent);
   }

+  void VisitStmt_(const IfThenElseNode* op) final {
+    // Visit condition
+    StmtExprVisitor::VisitExpr(op->condition);
+    {
+      // Visit then branch
+      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true);
+      StmtExprVisitor::VisitStmt(op->then_case);
+    }
+    if (op->else_case.defined()) {
+      // Visit else branch
+      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false);
+      StmtExprVisitor::VisitStmt(op->else_case);
+    }
+  }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      // Visit condition
+      StmtExprVisitor::VisitExpr(op->args[0]);
+      {
+        // Visit then branch
+        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true);
+        StmtExprVisitor::VisitExpr(op->args[1]);
+      }
+      {
+        // Visit else branch
+        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false);
+        StmtExprVisitor::VisitExpr(op->args[2]);
+      }
+      return;
+    }
+    return StmtExprVisitor::VisitExpr_(op);
+  }

   void VisitStmt_(const BlockNode* op) final {
     // Step 0. Check there is no init part.
     ICHECK(!op->init.defined());
-    // Step 1. Update outer buffer access info using buffer region
+    // Step 1. Record and update current read/write region annotations
+    std::unordered_map<Buffer, std::vector<BufferRegion>, ObjectPtrHash, ObjectPtrEqual>
+        cur_access_annotations;
     for (const BufferRegion& region : op->reads) {
-      VisitBufferAccess(region);
+      cur_access_annotations[region->buffer].push_back(region);
     }
     for (const BufferRegion& region : op->writes) {
-      VisitBufferAccess(region);
+      cur_access_annotations[region->buffer].push_back(region);
     }
-    // Step 2. Update inner buffer
-    // Step 2.1. rebuild map buffer_var_in_scope
-    std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_var_in_scope;
+    for (auto& p : cur_access_annotations) {
+      auto& regions = access_annotations_[p.first];
+      p.second.swap(regions);
+    }
+    // Step 2. Record relax position of ancestor_loops_ into buffer_var_in_scope_
     for (const Buffer& buffer : op->alloc_buffers) {
-      buffer_var_in_scope.emplace(buffer->data, buffer);
+      buffer_var_in_scope_.emplace(buffer->data, std::make_pair(buffer, ancestor_loops_.size()));
     }
-    // Step 2.2 Record top stack element before recursive visiting.
-    size_t stack_top = buffer_access_stack_.size();
-
-    // Step 2.3. Update the buffer_var_in_scope_ of visitor and visit recursively
-    std::swap(buffer_var_in_scope, buffer_var_in_scope_);
+    // Step 3. Visit match buffers
+    for (const MatchBufferRegion& region : op->match_buffers) {
+      VisitBufferAccess(region->source);
+    }
+    // Step 4. Visit block body recursively
     StmtExprVisitor::VisitStmt_(op);
-    std::swap(buffer_var_in_scope, buffer_var_in_scope_);
-
-    // Step 2.4. Combine and relax access
-    std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> relaxed_region =
-        CombineAndRelax(stack_top);
-
-    // Step 2.5. Visit ancestor_loops and try to relax outer thread loops.
+    // Step 5. Recover read/write region annotations
+    for (auto& p : cur_access_annotations) {
+      auto& regions = access_annotations_[p.first];
+      if (p.second.empty()) {
+        access_annotations_.erase(p.first);
+      } else {
+        regions.swap(p.second);
+      }
+    }
+    // Step 6. Update buffer_access_region_ from relaxed_accesses_ for inner buffers.
     for (const Buffer& buffer : op->alloc_buffers) {
-      auto it = relaxed_region.find(buffer);
-      ICHECK(it != relaxed_region.end());
+      auto it = relaxed_accesses_.find(buffer);
+      ICHECK(it != relaxed_accesses_.end())
+          << buffer << " is allocated but not accessed within block scope";
       const NDIntSet& nd_int_set = it->second;
-      std::unordered_map<const VarNode*, arith::IntSet> dom_map;
-      for (const ForNode* loop : ancestor_loops_) {
-        const VarNode* loop_var = loop->loop_var.get();
-        if (NeedRelaxThread(GetRef<For>(loop), runtime::StorageScope::Create(buffer.scope()))) {
-          dom_map[loop_var] = arith::IntSet::FromMinExtent(loop->min, loop->extent);
-        }
-      }
-      NDIntSet int_set = support::NDIntSetEval(nd_int_set, dom_map);
-      buffer_access_region_[buffer] = NarrowBufferRegionFromNDIntSet(int_set, buffer->shape);
+      buffer_access_region_[buffer] =
+          SimplifyAndNarrowBufferRegionFromNDIntSet(nd_int_set, buffer->shape, &dom_analyzer_);
     }
   }
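Steps 1 and 5 scope the block's read/write annotations by swapping them in before the body is visited and swapping them back out afterwards. A toy sketch of that save/restore discipline (illustrative only, not TVM code):

    #include <cassert>
    #include <string>
    #include <unordered_map>
    #include <vector>

    using Annotations = std::unordered_map<std::string, std::vector<int>>;

    void VisitBlock(Annotations& active, Annotations cur) {
      for (auto& p : cur) p.second.swap(active[p.first]);  // install block annotations
      // ... visit the block body here, reading `active` ...
      for (auto& p : cur) {                                // restore previous state
        auto& regions = active[p.first];
        if (p.second.empty()) active.erase(p.first);
        else regions.swap(p.second);
      }
    }

    int main() {
      Annotations active;
      VisitBlock(active, {{"A", {0, 16}}});
      assert(active.empty());  // outer state is restored after the block
      return 0;
    }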

@@ -166,61 +204,54 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
     const BufferNode* buffer = buffer_region->buffer.get();
     auto it = buffer_var_in_scope_.find(buffer->data);
     if (it != buffer_var_in_scope_.end()) {
-      const Buffer& buffer = it->second;
-      const BufferAccessInfo* info =
-          arena_.make<BufferAccessInfo>(buffer, support::NDIntSetFromRegion(buffer_region->region));
-      buffer_access_stack_.push(info);
+      const Buffer& buffer = it->second.first;
+      size_t n_ancestor_loops = it->second.second;
+      NDIntSet nd_int_set = support::NDIntSetFromRegion(buffer_region->region);
+      // Step 1. Stop ancestor loop vars out of the allocation block from
+      // being relaxed unless NeedRelaxThread() is true.
+      std::vector<arith::IntSet> non_relaxed(n_ancestor_loops);
+      for (size_t i = 0; i < n_ancestor_loops; ++i) {
+        const ForNode* loop = ancestor_loops_[i];
+        const VarNode* v = loop->loop_var.get();
+        if (NeedRelaxThread(GetRef<For>(loop), runtime::StorageScope::Create(buffer.scope()))) {
+          continue;
+        }
+        auto dom_it = dom_map_.find(v);
+        ICHECK(dom_it != dom_map_.end());
+        non_relaxed[i] = dom_it->second;
+        dom_map_.erase(dom_it);
+      }
+      // Step 2. Relax the access region
+      nd_int_set = support::NDIntSetEval(nd_int_set, dom_map_);
+      // Step 3. Restore the non-relaxed ancestor loops domain
+      for (size_t i = 0; i < n_ancestor_loops; ++i) {
+        const VarNode* v = ancestor_loops_[i]->loop_var.get();
+        dom_map_.emplace(v, non_relaxed[i]);
+      }
+      // Step 4. Update relaxed_accesses_ dict
+      auto access_it = relaxed_accesses_.find(buffer);
+      if (access_it != relaxed_accesses_.end()) {
+        support::NDIntSetUnionWith(&access_it->second, nd_int_set);
+      } else {
+        relaxed_accesses_.insert(access_it, {buffer, nd_int_set});
+      }
     }
   }

   void VisitBufferVar(const Var& var) {
     auto it = buffer_var_in_scope_.find(var);
     if (it != buffer_var_in_scope_.end()) {
-      const Buffer& buffer = it->second;
-      VisitBufferAccess(BufferRegion::FullRegion(buffer));
+      const Buffer& buffer = it->second.first;
+      auto annotation_it = access_annotations_.find(buffer);
+      if (annotation_it != access_annotations_.end()) {
+        // opaque buffer has explicit accessed region annotations
+        for (const BufferRegion& region : annotation_it->second) {
+          VisitBufferAccess(region);
+        }
+      } else {
+        VisitBufferAccess(BufferRegion::FullRegion(buffer));
+      }
     }
   }

-  /*!
-   * \brief Combine buffer accesses in the sub-tree.
-   * \details The access info is stored in a stack by DFS order, so that the accesses in the
-   * sub-tree are top-n elements in the stack.
-   * \param stack_top compact the access information in `stack[stack_top:end]`.
-   */
-  std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> CombineAndRelax(
-      size_t stack_top) {
-    std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> accesses;
-    while (buffer_access_stack_.size() > stack_top) {
-      const BufferAccessInfo* info = buffer_access_stack_.top();
-      buffer_access_stack_.pop();
-      NDIntSet nd_int_set =
-          support::NDIntSetEval(info->accessed_region, iter_dom_map_on_post_order_);
-      auto it = accesses.find(info->buffer);
-      if (it != accesses.end()) {
-        support::NDIntSetUnionWith(&it->second, nd_int_set);
-      } else {
-        accesses[info->buffer] = nd_int_set;
-      }
-    }
-    return accesses;
-  }
-
-  /*!
-   * \brief Combine buffer accesses in the sub-tree and push the combined result into the stack.
-   * \details The access info is stored in a stack by DFS order, so that the accesses in the
-   * sub-tree are top-n elements in the stack.
-   * \param stack_top The top element of the stack before visiting the sub-tree.
-   */
-  std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> CombineRelaxAndPushStack(
-      size_t stack_top) {
-    std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> accesses =
-        CombineAndRelax(stack_top);
-    for (const auto& kv : accesses) {
-      const Buffer& buffer = kv.first;
-      const NDIntSet& int_set = kv.second;
-      buffer_access_stack_.push(arena_.make<BufferAccessInfo>(buffer, int_set));
-    }
-    return accesses;
-  }
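In the reworked VisitBufferAccess, loop vars outside the allocation site are pulled out of dom_map_ (Step 1), the region is relaxed over the remaining inner loops (Step 2), and the saved domains are restored (Step 3). A toy numeric sketch of why that matters for an access like B[i*4 + j] (illustrative only):

    #include <cstdio>
    #include <map>
    #include <string>

    struct Interval { int lo, hi; };  // inclusive bounds of a loop var

    // Relax the index expr i*4 + j over the vars present in `dom`; a var absent
    // from `dom` is treated as a fixed symbol and contributes zero extent.
    Interval Relax(const std::map<std::string, Interval>& dom) {
      Interval i = dom.count("i") ? dom.at("i") : Interval{0, 0};
      Interval j = dom.count("j") ? dom.at("j") : Interval{0, 0};
      return {i.lo * 4 + j.lo, i.hi * 4 + j.hi};
    }

    int main() {
      std::map<std::string, Interval> dom = {{"i", {0, 3}}, {"j", {0, 3}}};
      // Buffer allocated under loop i (n_ancestor_loops = 1): remove i first.
      Interval saved = dom.at("i");
      dom.erase("i");
      Interval inner = Relax(dom);   // extent 4: only j is relaxed
      dom.emplace("i", saved);       // restore, as in Step 3
      Interval full = Relax(dom);    // extent 16: both loops relaxed
      std::printf("inner extent=%d, full extent=%d\n",
                  inner.hi - inner.lo + 1, full.hi - full.lo + 1);
      return 0;
    }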

/*! \brief Check whether the thread binding loop should be relaxed with given storage scope. */
@@ -236,19 +267,30 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
}

   /**************** Class members ****************/
-  /*! \brief Buffer access in DFS order. */
-  std::stack<const BufferAccessInfo*> buffer_access_stack_;
   /*! \brief The loops from the current node up to the root. */
   std::vector<const ForNode*> ancestor_loops_;
-  /*! \brief The vars of the buffer allocated under the current block. */
-  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_var_in_scope_;
+  /*!
+   * \brief The vars of the buffer allocated under the current block.
+   * Map each buffer var to a (buffer_obj, n_ancestor_loops) pair, where
+   * n_ancestor_loops is the number of loops outside the current block.
+   * The ancestor_loops_[0: n_ancestor_loops] should not be relaxed when
+   * we evaluate this buffer's access regions.
+   */
+  std::unordered_map<Var, std::pair<Buffer, size_t>, ObjectPtrHash, ObjectPtrEqual>
+      buffer_var_in_scope_;
   /*! \brief The map from loop vars to their iter range. */
-  std::unordered_map<const VarNode*, arith::IntSet> iter_dom_map_on_post_order_;
+  std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
+  /*! \brief The analyzer aware of loop domains. */
+  arith::Analyzer dom_analyzer_;
+  /*! \brief The map from Buffer to its relaxed access set. */
+  std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> relaxed_accesses_;
   /*! \brief The map from Buffer to its entire access region, used for returning. */
   std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> buffer_access_region_;
-  /*! \brief Internal arena. */
-  support::Arena arena_;
+  /*! \brief The map from Buffer to its access regions annotated by current block. */
+  std::unordered_map<Buffer, std::vector<BufferRegion>, ObjectPtrHash, ObjectPtrEqual>
+      access_annotations_;
};

/*! \brief Collect storage alignment information from block annotations. */
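For context, a hedged usage sketch of the pass this file implements. It assumes TVM's C++ pass API from around v0.8 (tvm/tir/transform.h declaring CompactBufferAllocation()); verify the names against your checkout before relying on them:

    #include <tvm/ir/module.h>
    #include <tvm/tir/transform.h>

    tvm::IRModule CompactAllocations(tvm::IRModule mod) {
      // After this change, allocations guarded by if/else or tir.if_then_else
      // conditions are narrowed using the condition-implied bounds as well.
      tvm::transform::Pass pass = tvm::tir::transform::CompactBufferAllocation();
      return pass(mod);
    }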