[Dynamic] M2 for S3: Compute Inline (#173)
jinhongyii authored and tqchen committed May 25, 2023
1 parent b7ed202 commit 8a4e746
Showing 4 changed files with 542 additions and 112 deletions.
94 changes: 54 additions & 40 deletions src/tir/analysis/block_access_region_detector.cc
@@ -37,8 +37,10 @@ namespace tir {
*/
class BlockReadWriteDetector : public StmtExprVisitor {
public:
explicit BlockReadWriteDetector(const Map<Var, Buffer>& buffer_var_map)
: buffer_var_map_(buffer_var_map) {}
explicit BlockReadWriteDetector(const Array<Buffer>& alloc_buffers,
const Map<Var, Buffer>& buffer_var_map)
: buffer_var_map_(buffer_var_map),
alloc_buffers_(alloc_buffers.begin(), alloc_buffers.end()) {}

/*! \brief Return read regions of the block */
Array<BufferRegion> CollectReads(
@@ -80,6 +82,8 @@ class BlockReadWriteDetector : public StmtExprVisitor {
std::unordered_map<const VarNode*, MatchBufferRegion> match_buffers_;
/*! \brief Internal analyzer. */
arith::Analyzer ana_;
/*! \brief The alloc buffers of the current block */
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> alloc_buffers_;

/*!
* \brief Update read/write buffers and regions with provided buffer and region
@@ -147,11 +151,13 @@ Array<BufferRegion> BlockReadWriteDetector::CollectOpaques() {
void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef<Var>(op)); }

void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
std::vector<arith::IntSet> relaxed_region;
for (const PrimExpr& index : op->indices) {
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_));
if (!alloc_buffers_.count(op->buffer)) {
std::vector<arith::IntSet> relaxed_region;
for (const PrimExpr& index : op->indices) {
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_));
}
Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
}
Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
ExprVisitor::VisitExpr_(op);
}
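The guard added above skips buffers allocated inside the block itself: a block-local allocation is private to the block, so its accesses must not surface in the block's read/write regions. Below is a minimal standalone sketch of the same filtering pattern, using toy types rather than the TVM API:

// Illustrative sketch only (toy types, not part of this commit): record a
// read only when the buffer is not allocated inside the block.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct ToyDetector {
  std::unordered_set<std::string> alloc_buffers;  // block-local allocations
  std::vector<std::string> read_buffers;          // collected external reads

  void VisitLoad(const std::string& buffer) {
    if (!alloc_buffers.count(buffer)) {
      read_buffers.push_back(buffer);  // external buffer: part of block reads
    }
    // otherwise: block-local, invisible outside the block, so skipped
  }
};

int main() {
  ToyDetector det;
  det.alloc_buffers = {"local_cache"};
  det.VisitLoad("A");            // recorded
  det.VisitLoad("local_cache");  // filtered out
  for (const auto& b : det.read_buffers) std::cout << b << "\n";  // prints: A
}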

@@ -184,20 +190,22 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
auto it = buffer_var_map_.find(GetRef<Var>(buffer_var));
if (it != buffer_var_map_.end()) {
const Buffer& buffer = (*it).second;
const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
const Region& region = buffer_region->region;
std::vector<arith::IntSet> int_set;
int_set.reserve(region.size());
for (const Range& range : region) {
int_set.push_back(arith::EvalSet(range, dom_map_));
}
// read access, write access or opaque access
if ((access_mask->value & 1) && (access_mask->value & 2)) {
Update(&opaque_buffers_, &opaque_regions_, buffer, int_set);
} else if (access_mask->value & 1) {
Update(&read_buffers_, &read_regions_, buffer, int_set);
} else if (access_mask->value & 2) {
Update(&writes_buffers_, &write_regions_, buffer, int_set);
if (!alloc_buffers_.count(buffer)) {
const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
const Region& region = buffer_region->region;
std::vector<arith::IntSet> int_set;
int_set.reserve(region.size());
for (const Range& range : region) {
int_set.push_back(arith::EvalSet(range, dom_map_));
}
// read access, write access or opaque access
if ((access_mask->value & 1) && (access_mask->value & 2)) {
Update(&opaque_buffers_, &opaque_regions_, buffer, int_set);
} else if (access_mask->value & 1) {
Update(&read_buffers_, &read_regions_, buffer, int_set);
} else if (access_mask->value & 2) {
Update(&writes_buffers_, &write_regions_, buffer, int_set);
}
}
}
} else {
@@ -223,11 +231,13 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
}
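For reference, the access-mask argument handled in the branch above (as used by TVM's tvm_access_ptr intrinsic) encodes the access kind bitwise: bit 0 marks a read, bit 1 a write, and both bits together an opaque read-write access. A small standalone sketch of that decoding (illustrative only):

// Illustrative sketch only: decode a tvm_access_ptr-style access mask.
#include <cstdint>
#include <iostream>

const char* DecodeAccessMask(int64_t mask) {
  const bool reads = mask & 1;   // bit 0: read access
  const bool writes = mask & 2;  // bit 1: write access
  if (reads && writes) return "opaque (read + write)";
  if (reads) return "read";
  if (writes) return "write";
  return "none";
}

int main() {
  std::cout << DecodeAccessMask(1) << "\n";  // read
  std::cout << DecodeAccessMask(2) << "\n";  // write
  std::cout << DecodeAccessMask(3) << "\n";  // opaque (read + write)
}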

void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
std::vector<arith::IntSet> relaxed_region;
for (const PrimExpr& index : op->indices) {
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_));
if (!alloc_buffers_.count(op->buffer)) {
std::vector<arith::IntSet> relaxed_region;
for (const PrimExpr& index : op->indices) {
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_));
}
Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
}
Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
StmtVisitor::VisitStmt_(op);
}

@@ -238,24 +248,28 @@ void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) {
vmap[op->block->iter_vars[i]->var.get()] = op->iter_values[i];
}
for (const auto& read : op->block->reads) {
std::vector<arith::IntSet> relaxed_region;
for (const auto& range : read->region) {
relaxed_region.push_back(
arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent(
Substitute(range->min, vmap), Substitute(range->extent, vmap))),
dom_map_));
if (!alloc_buffers_.count(read->buffer)) {
std::vector<arith::IntSet> relaxed_region;
for (const auto& range : read->region) {
relaxed_region.push_back(
arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent(
Substitute(range->min, vmap), Substitute(range->extent, vmap))),
dom_map_));
}
Update(&read_buffers_, &read_regions_, read->buffer, relaxed_region);
}
Update(&read_buffers_, &read_regions_, read->buffer, relaxed_region);
}
for (const auto& write : op->block->writes) {
std::vector<arith::IntSet> relaxed_region;
for (const auto& range : write->region) {
relaxed_region.push_back(
arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent(
Substitute(range->min, vmap), Substitute(range->extent, vmap))),
dom_map_));
if (!alloc_buffers_.count(write->buffer)) {
std::vector<arith::IntSet> relaxed_region;
for (const auto& range : write->region) {
relaxed_region.push_back(
arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent(
Substitute(range->min, vmap), Substitute(range->extent, vmap))),
dom_map_));
}
Update(&writes_buffers_, &write_regions_, write->buffer, relaxed_region);
}
Update(&writes_buffers_, &write_regions_, write->buffer, relaxed_region);
}
}

@@ -351,7 +365,7 @@ void BlockReadWriteDetector::UpdateOpaque(const Var& buffer_var) {

Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map) {
BlockReadWriteDetector detector(buffer_var_map);
BlockReadWriteDetector detector(block->alloc_buffers, buffer_var_map);
detector(block);
Array<BufferRegion> writes = detector.CollectWrites();
std::unordered_set<const BufferNode*> excluded_buffers;
@@ -368,7 +382,7 @@ Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,

Array<Array<BufferRegion>> GetBlockReadWriteRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map) {
BlockReadWriteDetector detector(buffer_var_map);
BlockReadWriteDetector detector(block->alloc_buffers, buffer_var_map);
detector(block);
Array<BufferRegion> opaques = detector.CollectOpaques();
std::unordered_set<const BufferNode*> excluded_buffers;
138 changes: 66 additions & 72 deletions src/tir/schedule/primitive/compute_inline.cc
@@ -22,9 +22,8 @@ namespace tvm {
namespace tir {

static const char kErrBodyInline[] = R"(The body of the inlined block should be in form of
'A[i, j, k, ...] = f(i, j, k, ...)',
where the indices on the left are distinct atomic variables,
and there should be no variables other than the index variables)";
'A[f(i, j, k, ...)] = g(i, j, k, ...)',
where the index mapping f on the left is bijective and affine.)";
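For intuition (example not taken from this commit): with producer iterators i and j each ranging over [0, 4), the store A[4*i + j] uses the index map

    f(i, j) = 4*i + j,    f : [0, 4) x [0, 4) -> [0, 16),

which is bijective and affine, with inverse i = floor(p / 4), j = p mod 4. A map such as f(i, j) = i + j would be rejected because it is not injective.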

static const char kErrBodyReverseInline[] = R"(The body of the inlined block should be in form of
`B[...] = g(i, j, k, A[f(i, j, k, ...)] ...)`,
@@ -284,31 +283,6 @@ class BaseInliner : public StmtExprMutator {
return std::move(tgt_block);
}

/*!
* \brief Count the number of undefined variables that are not used
* as buffer objects.
*
* This is used to determine whether inlining or reverse inlining is
* possible. The only undefined variables present should be the
* load/store indices, or buffer access based on those indices.
*
* \param stmt The statement in which to count undefined variables
*/
static int GetNumUndefinedNonpointerVars(const Stmt& stmt) {
auto undefined_vars = UndefinedVars(stmt, {});
// Buffer pointers and the inlined indices are allowed, but no
// other variables may appear in the inlined block.
int num_nonpointer_vars = 0;
for (const auto& var : undefined_vars) {
bool is_pointer = var->dtype.is_handle() && var->type_annotation.defined() &&
var->type_annotation.as<PointerTypeNode>();
if (!is_pointer) {
num_nonpointer_vars++;
}
}
return num_nonpointer_vars;
}

private:
/*!
* \brief Add the buffers in the block signature to the `buffer_var_map_`,
@@ -406,7 +380,7 @@ class BaseInliner : public StmtExprMutator {
/*! \brief Maps a buffer's data field to itself */
Map<Var, Buffer> buffer_var_map_;
/*! \brief The indices used for indexing the buffer to be inlined */
std::vector<const VarNode*> idx_vars_;
std::vector<Var> idx_vars_;
/*! \brief The mapping to substitute index variables to PrimExprs */
std::unordered_map<const VarNode*, PrimExpr> idx_sub_;

@@ -443,10 +417,62 @@ class ComputeInliner : public BaseInliner {
return false;
}

int n_vars = GetNumUndefinedNonpointerVars(GetRef<Stmt>(inlined_store_));
if (!UpdateAndCheckIndexVars(inlined_store_->indices, n_vars)) {
// Fast path for the trivial case:
// check whether the store indices are exactly the block iters.
store_value_ = inlined_store_->value;
size_t num_iters = producer_block->iter_vars.size();
size_t buffer_ndim = inlined_store_->indices.size();
if (num_iters == buffer_ndim) {
std::vector<Var> idx_vars;
idx_vars.reserve(num_iters);
for (size_t i = 0; i < num_iters; ++i) {
const IterVar& iter = producer_block->iter_vars[i];
const PrimExpr& e = inlined_store_->indices[i];
if (e.same_as(iter->var) ||
(analyzer_.CanProveEqual(e, 0) && analyzer_.CanProveEqual(iter->dom->min, 0) &&
analyzer_.CanProveEqual(iter->dom->extent, 1))) {
idx_vars.push_back(iter->var);
} else {
break;
}
}
if (idx_vars.size() == num_iters) {
// match success
idx_vars_ = std::move(idx_vars);
return true;
}
}
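The fast path above accepts each store index either when it is literally the matching block iterator, or when it provably equals 0 while the iterator's domain is the single point [0, 1). A standalone toy version of the same check, with provability reduced to constant comparison (illustrative only, not the TVM API):

// Illustrative sketch only: trivial-index match against block iterators.
#include <cassert>
#include <string>
#include <vector>

// Toy iterator: a variable name plus a [min, min + extent) domain.
struct ToyIter { std::string var; int min; int extent; };
// Toy index: either a variable reference or an integer constant.
struct ToyIndex { bool is_var; std::string var; int value; };

// Every index must be the matching iter var, or a constant 0 paired with a
// unit-extent iterator whose domain starts at 0.
bool TrivialIndicesMatch(const std::vector<ToyIter>& iters,
                         const std::vector<ToyIndex>& indices) {
  if (iters.size() != indices.size()) return false;
  for (size_t i = 0; i < iters.size(); ++i) {
    const ToyIter& it = iters[i];
    const ToyIndex& idx = indices[i];
    bool same_var = idx.is_var && idx.var == it.var;
    bool unit_iter = !idx.is_var && idx.value == 0 && it.min == 0 && it.extent == 1;
    if (!same_var && !unit_iter) return false;
  }
  return true;
}

int main() {
  std::vector<ToyIter> iters = {{"i", 0, 32}, {"j", 0, 1}};
  assert(TrivialIndicesMatch(iters, {{true, "i", 0}, {false, "", 0}}));   // A[i, 0]
  assert(!TrivialIndicesMatch(iters, {{true, "j", 0}, {true, "i", 0}}));  // A[j, i]
}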

// If the mapping for the store indices is non-trivial,
// check that the mapping from producer iter vars to store indices is bijective affine.
Map<Var, Range> producer_iter_doms;
for (const auto& iter : producer_block->iter_vars) {
producer_iter_doms.Set(iter->var, iter->dom);
}
auto res = arith::DetectIterMap(
/*indices=*/inlined_store_->indices,
/*input_iters=*/producer_iter_doms,
/*predicate=*/true,
/*check_level=*/arith::IterMapLevel::Bijective,
/*analyzer=*/&analyzer_,
/*simplify_trivial_iterators=*/false);
if (res->indices.empty()) {
// Failure: indices of BufferStore are not bijective affine
return false;
}
idx_vars_.resize(buffer_ndim);
for (size_t i = 0; i < idx_vars_.size(); ++i) {
idx_vars_[i] = Var("ph_" + std::to_string(i), inlined_store_->indices[i].dtype());
}
auto inverse_iter_map = arith::InverseAffineIterMap(
res->indices, Array<PrimExpr>(idx_vars_.begin(), idx_vars_.end()));
for (const auto& iter : producer_block->iter_vars) {
if (is_const_int(iter->dom->min) && analyzer_.CanProveEqual(iter->dom->extent, 1)) {
// unit-extent iterators may be absent from the inverse map; bind them to their min
inverse_iter_map.Set(iter->var, iter->dom->min);
}
}
store_value_ = Substitute(store_value_, inverse_iter_map);
return true;
}
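When the store indices are non-trivial but bijective affine, fresh placeholders ph_0, ph_1, ... stand in for the indices, InverseAffineIterMap rewrites each producer iterator in terms of those placeholders, and the substituted store value becomes the inlined expression. A hedged numeric sketch of the one-dimensional fused case p = i * 4 + j (toy arithmetic, not the TVM API):

// Illustrative sketch only: invert the bijective affine index p = i * 4 + j
// (j in [0, 4)) and check the inlined value matches the stored value.
#include <cassert>

int InlinedValue(int p) {
  const int i = p / 4;  // inverse map for the outer iterator
  const int j = p % 4;  // inverse map for the inner iterator
  return i + 10 * j;    // producer store value f(i, j), now in terms of p
}

int main() {
  // Producer block: A[i * 4 + j] = i + 10 * j, with i in [0, 4), j in [0, 4).
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 4; ++j)
      assert(InlinedValue(i * 4 + j) == i + 10 * j);
  return 0;
}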

@@ -464,45 +490,7 @@ class ComputeInliner : public BaseInliner {

PrimExpr ReplaceInlinedBuffer(BufferLoad load) {
SetIndexSubstitution(load->indices);
return Substitute(inlined_store_->value, idx_sub_);
}

/*!
* \brief Check if the indices are atomic distinct variables and the access is n-dimensional.
* If so, set `self->idx_vars_` properly.
* \param indices The indices to be extracted
* \param expected_ndim The expected ndim of the access
* \return A boolean flag indicating if the check is successful
*/
bool UpdateAndCheckIndexVars(const Array<PrimExpr>& indices, int expected_ndim) {
int n = indices.size();
if (n != expected_ndim) {
// Failure: dimension mismatch
return false;
}
std::vector<const VarNode*> result;
result.reserve(n);
for (const PrimExpr& i : indices) {
if (const auto* var = i.as<VarNode>()) {
result.push_back(var);
} else {
// Failure: indexing expression is not a variable
return false;
}
}
using DistinctSet = std::unordered_set<const VarNode*>;
int n_distinct = DistinctSet(result.begin(), result.end()).size();
if (n != n_distinct) {
// Failure: indexing variables are not distinct
return false;
}
if (idx_vars_.empty()) {
idx_vars_ = std::move(result);
} else if (!support::ArrayWithSameContent(idx_vars_, result)) {
// Failure: indexing variables are not consistent in different BufferLoads
return false;
}
return true;
return Substitute(store_value_, idx_sub_);
}

/*!
@@ -512,11 +500,17 @@ class ComputeInliner : public BaseInliner {
void SetIndexSubstitution(const Array<PrimExpr>& indices) {
ICHECK_EQ(indices.size(), idx_vars_.size());
int n = idx_vars_.size();
idx_sub_.reserve(n);
for (int i = 0; i < n; ++i) {
idx_sub_[idx_vars_[i]] = indices[i];
idx_sub_[idx_vars_[i].get()] = indices[i];
}
}
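At each consumer BufferLoad, SetIndexSubstitution binds idx_vars_[i] to the consumer's i-th load index, and ReplaceInlinedBuffer then substitutes those bindings into store_value_. A toy string-level analog of that binding step (illustrative only):

// Illustrative sketch only: bind each placeholder index variable to the
// corresponding index expression of the consumer's load.
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

std::unordered_map<std::string, std::string> MakeIndexSub(
    const std::vector<std::string>& idx_vars,
    const std::vector<std::string>& load_indices) {
  assert(idx_vars.size() == load_indices.size());
  std::unordered_map<std::string, std::string> sub;
  for (size_t i = 0; i < idx_vars.size(); ++i) sub[idx_vars[i]] = load_indices[i];
  return sub;
}

int main() {
  // Consumer reads A[x + 1, y]; the producer's placeholders are ph_0, ph_1.
  auto sub = MakeIndexSub({"ph_0", "ph_1"}, {"x + 1", "y"});
  assert(sub.at("ph_0") == "x + 1");  // ph_0 in store_value_ becomes x + 1
  assert(sub.at("ph_1") == "y");
}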

/*! \brief The arithmetic analyzer */
arith::Analyzer analyzer_;
/*! \brief The store value to be inlined. If the producer's store indices are
trivial, it is expressed in terms of the producer block's iter vars;
otherwise it is in terms of the placeholder vars for the store indices. */
PrimExpr store_value_;
};

/*!
Expand Down
Loading
