[MemHammer][Refactor] Code Review #15

Merged: 5 commits, Jan 25, 2022
2 changes: 1 addition & 1 deletion include/tvm/tir/var.h
@@ -255,7 +255,7 @@ class IterVarNode : public Object {
IterVarType iter_type;
/*!
* \brief additional tag on the iteration variable,
* set this if this is binded already to a known thread tag.
* set this if this is bound already to a known thread tag.
*/
String thread_tag;
/*!
79 changes: 39 additions & 40 deletions src/tir/transforms/memhammer_coalesce.cc
@@ -17,17 +17,18 @@
* under the License.
*/
#include "../../runtime/thread_storage_scope.h"
#include "memhammer_rewrite_rule.h"
#include "./memhammer_rewrite_rule.h"

namespace tvm {
namespace tir {

/*!
* \brief Fuse consecutive loops
* \param stmt the outer-most loop
* \param body the outer-most loop
* \return the fused loop
*/
Stmt FuseNestLoops(const Stmt& stmt) {
Stmt FuseNestLoops(Stmt body) {
std::vector<const ForNode*> loops;
Stmt body = stmt;
while (const ForNode* loop = body.as<ForNode>()) {
loops.push_back(loop);
body = loop->body;
@@ -52,9 +53,8 @@ Stmt FuseNestLoops(const Stmt& stmt) {
for (int i = 0; i < n; i++) {
fused_extent *= loops[i]->extent;
}
Stmt new_stmt = Substitute(body, f_substitute);
new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt);
return new_stmt;
return For(fused_var, 0, fused_extent, ForKind::kSerial,
Substitute(std::move(body), f_substitute));
}
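
Note: as a reference for the substitution built above, fusing loops with extents e0..e(n-1) maps the nest onto a single variable, and each original index is recovered with div/mod. A standalone sketch with example values of my own (plain ints stand in for PrimExpr; not part of this diff):

```cpp
#include <cassert>
#include <vector>

// Recover the original loop indices from the fused index, assuming row-major
// fusion as in FuseNestLoops (innermost index varies fastest).
std::vector<int> UnfuseIndex(int fused, const std::vector<int>& extents) {
  std::vector<int> idx(extents.size());
  for (int i = static_cast<int>(extents.size()) - 1; i >= 0; --i) {
    idx[i] = fused % extents[i];
    fused /= extents[i];
  }
  return idx;
}

int main() {
  // extents {4, 8} fuse into a single loop of extent 32; fused = 27 -> (3, 3)
  std::vector<int> idx = UnfuseIndex(27, {4, 8});
  assert(idx[0] == 3 && idx[1] == 3);  // since 27 == 3 * 8 + 3
}
```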

/*!
@@ -65,72 +65,71 @@ Stmt FuseNestLoops(const Stmt& stmt) {
* \return The stmt after transformation
*/
Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) {
Stmt body = stmt;
const ForNode* loop = body.as<ForNode>();
PrimExpr vector_bytes = constraints.vector_bytes;
PrimExpr threadIdx_x = constraints.thread_extent.Get("threadIdx.x").value_or(Integer(1));
PrimExpr threadIdx_y = constraints.thread_extent.Get("threadIdx.y").value_or(Integer(1));
PrimExpr threadIdx_z = constraints.thread_extent.Get("threadIdx.z").value_or(Integer(1));
PrimExpr tot_threads = threadIdx_x * threadIdx_y * threadIdx_z;
PrimExpr data_bits = constraints.data_bits;
PrimExpr vector_len = max(1, vector_bytes * 8 / data_bits);
if (!loop || !is_zero(indexmod(loop->extent, (vector_len * tot_threads)))) {
LOG(FATAL) << "the number of elements must be a multiple of thread num";
}
PrimExpr outer_loop_extent = indexdiv(loop->extent, tot_threads * vector_len);
Array<PrimExpr> factors{outer_loop_extent};
std::vector<std::string> thread_axis;
const ForNode* loop = TVM_TYPE_AS(loop, stmt, ForNode);
int loop_extent = Downcast<Integer>(loop->extent)->value;
int vector_bytes = constraints.vector_bytes;
int data_bits = constraints.data_bits;
int vector_len = std::max(1, vector_bytes * 8 / data_bits);
int tot_threads = 1;
// generate thread binding loops
if (!is_one(threadIdx_z)) {
factors.push_back(threadIdx_z);
std::vector<int> factors{-1};
std::vector<std::string> thread_axis;
if (Optional<Integer> o_t = constraints.thread_extent.Get("threadIdx.z")) {
int t = o_t.value()->value;
tot_threads *= t;
factors.push_back(t);
thread_axis.push_back("threadIdx.z");
}
if (!is_one(threadIdx_y)) {
factors.push_back(threadIdx_y);
if (Optional<Integer> o_t = constraints.thread_extent.Get("threadIdx.y")) {
int t = o_t.value()->value;
tot_threads *= t;
factors.push_back(t);
thread_axis.push_back("threadIdx.y");
}
if (!is_one(threadIdx_x)) {
factors.push_back(threadIdx_x);
if (Optional<Integer> o_t = constraints.thread_extent.Get("threadIdx.x")) {
int t = o_t.value()->value;
tot_threads *= t;
factors.push_back(t);
thread_axis.push_back("threadIdx.x");
}
// generate vectorized loop
factors.push_back(vector_len);
// generate outer loop
ICHECK_EQ(loop_extent % (tot_threads * vector_len), 0);
factors[0] = loop_extent / (tot_threads * vector_len);
// create new loop vars
int n = factors.size();
std::vector<Var> new_loop_vars;
new_loop_vars.reserve(n);
for (int i = 0; i < n; i++) {
new_loop_vars.push_back(loop->loop_var.copy_with_suffix("_" + std::to_string(i)));
}

// substitute fused loop var with new loop vars
PrimExpr substitute_value = 0;
for (int i = 0; i < n; i++) {
substitute_value *= factors[i];
substitute_value += new_loop_vars[i];
}
body = Substitute(loop->body, [&](const Var& v) -> Optional<PrimExpr> {
// Construct the new loop nest
Stmt body = Substitute(loop->body, [&](const Var& v) -> Optional<PrimExpr> {
if (v.same_as(loop->loop_var)) {
return substitute_value;
} else {
return NullOpt;
}
});

For new_loop = For(new_loop_vars.back(), 0, vector_len, ForKind::kVectorized, body);

body = For(new_loop_vars.back(), 0, vector_len, ForKind::kVectorized, std::move(body));
for (int i = n - 2; i >= 1; i--) {
new_loop =
For(new_loop_vars[i], 0, factors[i], ForKind::kThreadBinding, std::move(new_loop),
IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, thread_axis[i - 1]));
body = For(new_loop_vars[i], 0, factors[i], ForKind::kThreadBinding, std::move(body),
IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, thread_axis[i - 1]));
}

new_loop = For(new_loop_vars[0], 0, outer_loop_extent, ForKind::kSerial, new_loop);
return std::move(new_loop);
return For(new_loop_vars[0], 0, factors[0], ForKind::kSerial, std::move(body));
}
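
Note: a quick sanity check of the split arithmetic above, with made-up numbers (a copy loop of extent 4096, 32-bit elements, 16-byte vectors, threadIdx.x = 128 and threadIdx.y = 2); this is a reviewer sketch, not code from the PR:

```cpp
#include <algorithm>
#include <cassert>

int main() {
  int loop_extent = 4096, vector_bytes = 16, data_bits = 32;
  int vector_len = std::max(1, vector_bytes * 8 / data_bits);   // 4 lanes per vectorized access
  int tot_threads = 128 * 2;                                    // threadIdx.x * threadIdx.y
  assert(loop_extent % (tot_threads * vector_len) == 0);        // mirrors the ICHECK_EQ above
  int outer_extent = loop_extent / (tot_threads * vector_len);  // factors[0]
  assert(outer_extent == 4);  // nest: 4 (serial) x 2 (ty) x 128 (tx) x 4 (vectorized)
}
```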

Stmt CoalescedAccess::Rewrite(const Stmt& stmt, const ConstraintSet& constraints,
OutputSet* output) const {
Stmt after_fuse = FuseNestLoops(stmt);
Stmt after_split = SplitBindVectorize(after_fuse, constraints);
Stmt after_split = SplitBindVectorize(std::move(after_fuse), constraints);
return after_split;
}

55 changes: 23 additions & 32 deletions src/tir/transforms/memhammer_intermediate_stage.cc
@@ -41,7 +41,7 @@ Stmt CopyLoopChain(const std::vector<const ForNode*> loops, const Stmt& inner_bo
* \return a pair. The first is the transformed stmt.
* The second is the lowest thread binding loop.
*/
std::pair<Stmt, For> LiftThreadBindingLoops(Stmt stmt) {
std::pair<Stmt, Optional<For>> LiftThreadBindingLoops(Stmt stmt) {
std::vector<const ForNode*> normal_loops;
std::vector<const ForNode*> thread_binding_loops;
Stmt body = stmt;
Expand All @@ -53,11 +53,10 @@ std::pair<Stmt, For> LiftThreadBindingLoops(Stmt stmt) {
}
body = loop->body;
}
body = CopyLoopChain(normal_loops, body);
For compute_location;
body = CopyLoopChain(thread_binding_loops, body,
body = CopyLoopChain(normal_loops, std::move(body));
For compute_location{nullptr};
body = CopyLoopChain(thread_binding_loops, std::move(body),
static_cast<int>(thread_binding_loops.size()) - 1, &compute_location);

return std::make_pair(body, compute_location);
}
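
Note: to make the intent of the lifting concrete, a minimal sketch of the reordering with made-up loop labels (the real pass works on TIR For nodes and copies them via CopyLoopChain):

```cpp
#include <string>
#include <vector>

struct LoopDesc {
  std::string name;
  bool thread_binding;
};

// Before: for i (serial) { for ty (threadBinding) { for tx (threadBinding) { body } } }
// After:  for ty { for tx { for i { body } } }   // thread-binding loops hoisted outermost
std::vector<LoopDesc> LiftThreadBindings(const std::vector<LoopDesc>& nest) {
  std::vector<LoopDesc> normal, bound;
  for (const LoopDesc& l : nest) (l.thread_binding ? bound : normal).push_back(l);
  bound.insert(bound.end(), normal.begin(), normal.end());
  return bound;  // outer-to-inner order of the rewritten nest
}
```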

@@ -85,7 +84,7 @@ class IndexPatternFinder : public ExprVisitor {
* \param rewrite_indices The access indices after rank promotion
* \return The new buffer shape after rank promotion.
*/
static Array<Array<PrimExpr>> getRankPromotedShape(Array<PrimExpr> indices,
static Array<Array<PrimExpr>> GetRankPromotedShape(Array<PrimExpr> indices,
const Map<Var, Range>& var_range,
Array<PrimExpr>* rewrite_indices) {
Map<Var, arith::IntSet> var_dom = AsIntSet(var_range);
@@ -169,7 +168,7 @@ class RankPromoter : public StmtExprMutator {
/*!
* \brief Flatten the buffer shape like performing inverse rank promotion.
* For example, [[i0, i1], [j0, j1]] to [i0 * i1, j0 * j1]
* \param new_shape The buffer shape in the special form as returned by getRankPromotedShape
* \param new_shape The buffer shape in the special form as returned by GetRankPromotedShape
* \return The buffer shape after flatten
*/
static Array<PrimExpr> FlattenNewShape(const Array<Array<PrimExpr>>& new_shape) {
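
Note: the flattening described in the comment above, as a standalone sketch with plain ints standing in for PrimExpr (example values are mine):

```cpp
#include <cassert>
#include <vector>

// Each inner group collapses to the product of its extents,
// e.g. {{2, 3}, {4, 5}} -> {6, 20}, matching [[i0, i1], [j0, j1]] -> [i0*i1, j0*j1].
std::vector<int> FlattenShape(const std::vector<std::vector<int>>& new_shape) {
  std::vector<int> flat;
  for (const std::vector<int>& group : new_shape) {
    int prod = 1;
    for (int extent : group) prod *= extent;
    flat.push_back(prod);
  }
  return flat;
}

int main() {
  std::vector<int> flat = FlattenShape({{2, 3}, {4, 5}});
  assert(flat.size() == 2 && flat[0] == 6 && flat[1] == 20);
}
```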
@@ -271,7 +270,7 @@ class RankPromoter : public StmtExprMutator {
/*!
* \brief Rewrite the indices after performing buffer rank promotion +
* buffer compacting + buffer flattening.
* \param indices The origina indices
* \param indices The original indices
* \return The indices after these transformations
*/
Array<PrimExpr> ConvertIndices(const Array<PrimExpr>& indices) {
@@ -290,20 +289,9 @@ class RankPromoter : public StmtExprMutator {
Array<Range> relaxed_region_;
};

/*!
* \brief Insert a cache stage to the compute location
* \param stmt the stmt
* \param is_write_cache whether to write a read cache or write cache
* \param storage_scope the storage scope of the new cache
* \param compute_location the compute location.
* \param outer_loops the outer loops of this stmt
* \param alloc_buffer the new cache block
* \return a pair. The first is the stmt after transformation.
* The second is the SeqStmt that contains 2 stages (one original and another inserted).
*/
std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope,
For compute_location, const Array<For>& outer_loops,
Buffer* alloc_buffer) {
Optional<For> compute_location,
const Array<For>& outer_loops, Buffer* alloc_buffer) {
Stmt body = stmt;
std::vector<const ForNode*> loops;
bool need_relax = !compute_location.defined();
@@ -340,22 +328,26 @@ std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String
}

const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode);
// TODO: the assumption that the RHS of BufferStore is BufferLoad may not be accurate
const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode);
Buffer orig_buffer = is_write_cache ? buf_store->buffer : buf_load->buffer;
Array<PrimExpr> indices = is_write_cache ? buf_store->indices : buf_load->indices;
// Step 1.2 get the new shape and new access indices after rank promotion
Array<PrimExpr> rewrite_indices;
Array<Array<PrimExpr>> new_shape =
IndexPatternFinder::getRankPromotedShape(indices, all_var_range, &rewrite_indices);
IndexPatternFinder::GetRankPromotedShape(indices, all_var_range, &rewrite_indices);
// Step 2. relax the access region after rank promotion
Region relaxed_region;
auto relax_var_intset = AsIntSet(relax_var_range);
arith::Analyzer analyzer;
analyzer.Bind(all_var_range);
for (const PrimExpr& index : rewrite_indices) {
auto int_set = arith::EvalSet(index, relax_var_intset);
relaxed_region.push_back(
Range::FromMinExtent(int_set.min(), analyzer.Simplify(int_set.max() - int_set.min() + 1)));
Array<Range> relaxed_region;
relaxed_region.reserve(rewrite_indices.size());
{
Map<Var, arith::IntSet> relax_var_intset = AsIntSet(relax_var_range);
for (const PrimExpr& index : rewrite_indices) {
arith::IntSet int_set = arith::EvalSet(index, relax_var_intset);
relaxed_region.push_back(Range::FromMinExtent(
int_set.min(), analyzer.Simplify(int_set.max() - int_set.min() + 1)));
}
}
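// Reviewer note (illustration only, not part of the change): relaxing a single
// var v over [0, 4) in an index `base + v` gives int_set = [base, base + 3],
// so the range pushed above is Range::FromMinExtent(base, 4). `base` and `v`
// are hypothetical names.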
// Step 3. generate the data copy bodies
// preparation work
@@ -378,8 +370,7 @@ std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String
}
// Step 3.1 create a buffer for the cache
Buffer new_buffer = WithScope(orig_buffer, storage_scope);
BufferNode* buffer_ptr = new_buffer.CopyOnWrite();
buffer_ptr->shape = RankPromoter::FlattenNewShape(relaxed_new_shape);
new_buffer.CopyOnWrite()->shape = RankPromoter::FlattenNewShape(relaxed_new_shape);
*alloc_buffer = new_buffer;
Array<PrimExpr> real_orig_buf_indices =
RankPromoter::RewriteBackIndex(orig_buf_indices, new_shape);
@@ -404,7 +395,7 @@ std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String
// Step 3.3 rewrite the original body to load from cache
Stmt rewrite_body;
if (compute_location.defined()) {
rewrite_body = compute_location->body;
rewrite_body = compute_location.value()->body;
} else {
rewrite_body = stmt;
}
@@ -423,7 +414,7 @@ std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String
Stmt CreateLocalStage::Rewrite(const Stmt& stmt, const ConstraintSet& constraints,
OutputSet* output) const {
Stmt body;
For compute_location;
Optional<For> compute_location;
std::tie(body, compute_location) = LiftThreadBindingLoops(std::move(stmt));
Buffer cache_buffer;
Stmt after_caching = InsertCacheStage(body, false, "local", compute_location,