Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Jan 25, 2022
1 parent 2bc9ef1 commit daf6727
Showing 1 changed file with 196 additions and 119 deletions.
315 changes: 196 additions & 119 deletions src/tir/transforms/memhammer_tensorcore_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

namespace tvm {
namespace tir {

/*!
* \brief Tile the 2 innermost loops to extent=16. This helps further tensor core rewrite.
* \param stmt The stmt
Expand All @@ -33,34 +34,57 @@ std::pair<Stmt, For> TileWmmaBlock(Stmt stmt) {
loops.push_back(loop);
body = loop->body;
}
arith::Analyzer analyzer;
PrimExpr extent_last1 = loops[loops.size() - 1]->extent,
extent_last2 = loops[loops.size() - 2]->extent;

if (!analyzer.CanProve(floormod(extent_last1, 16) == 0) ||
!analyzer.CanProve(floormod(extent_last2, 16) == 0)) {
return std::make_pair(stmt, For());
}
std::vector<Var> new_loop_vars;
Array<PrimExpr> factor{floordiv(extent_last2, 16), floordiv(extent_last1, 16), 16, 16};
new_loop_vars.reserve(4);
for (int i = 0; i < 2; i++) {
new_loop_vars.push_back(loops[loops.size() - 2]->loop_var.copy_with_suffix(std::to_string(i)));
new_loop_vars.push_back(loops[loops.size() - 1]->loop_var.copy_with_suffix(std::to_string(i)));
int n = loops.size();
PrimExpr extent_last1 = loops[n - 1]->extent;
PrimExpr extent_last2 = loops[n - 2]->extent;
{
arith::Analyzer analyzer;
if (!analyzer.CanProveEqual(floormod(extent_last1, 16), 0) ||
!analyzer.CanProveEqual(floormod(extent_last2, 16), 0)) {
return std::make_pair(stmt, For());
}
}
Map<Var, PrimExpr> substitue_value;
substitue_value.Set(loops[loops.size() - 2]->loop_var, new_loop_vars[0] * 16 + new_loop_vars[2]);
substitue_value.Set(loops[loops.size() - 1]->loop_var, new_loop_vars[1] * 16 + new_loop_vars[3]);
body = Substitute(body, substitue_value);
for (int i = static_cast<int>(new_loop_vars.size()) - 1; i >= 0; i--) {
body = For(new_loop_vars[i], 0, factor[i], ForKind::kSerial, body);
Var new_loop_vars[4] = {
/*0:*/ loops[n - 2]->loop_var.copy_with_suffix("_0"),
/*1:*/ loops[n - 1]->loop_var.copy_with_suffix("_0"),
/*2:*/ loops[n - 2]->loop_var.copy_with_suffix("_1"),
/*3:*/ loops[n - 1]->loop_var.copy_with_suffix("_1"),
};
body = Substitute(std::move(body),
Map<Var, PrimExpr>{
{loops[n - 2]->loop_var, new_loop_vars[0] * 16 + new_loop_vars[2]},
{loops[n - 1]->loop_var, new_loop_vars[1] * 16 + new_loop_vars[3]},
});
{
PrimExpr factor[4] = {
/*0:*/ floordiv(extent_last2, 16), //
/*1:*/ floordiv(extent_last1, 16), //
/*3:*/ 16, //
/*4:*/ 16, //
};
body = For(new_loop_vars[3], 0, factor[3], ForKind::kSerial, std::move(body));
body = For(new_loop_vars[2], 0, factor[2], ForKind::kSerial, std::move(body));
body = For(new_loop_vars[1], 0, factor[1], ForKind::kSerial, std::move(body));
body = For(new_loop_vars[0], 0, factor[0], ForKind::kSerial, std::move(body));
}
For compute_location = Downcast<For>(body);
for (int i = static_cast<int>(loops.size()) - 3; i >= 0; i--) {
body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, body,
for (int i = n - 3; i >= 0; i--) {
body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, std::move(body),
loops[i]->thread_binding, loops[i]->annotations);
}
return std::make_pair(body, compute_location);
return {body, compute_location};
}

Array<Range> RelaxIndices(const Array<PrimExpr>& indices, const Array<PrimExpr>& shape,
const Map<Var, arith::IntSet>& var_dom) {
Array<arith::IntSet> int_set = arith::EvalSet(indices, var_dom);
int ndim = int_set.size();
Array<Range> region;
region.reserve(ndim);
for (int i = 0; i < ndim; ++i) {
region.push_back(int_set[i].CoverRange(Range::FromMinExtent(0, shape[i])));
};
return region;
}

/*!
Expand All @@ -69,67 +93,95 @@ std::pair<Stmt, For> TileWmmaBlock(Stmt stmt) {
* \return The stmt after rewrite
*/
Stmt RewriteWmmaLoad(Stmt stmt) {
Array<MatchBufferRegion> match_buffers;
using arith::IntSet;
const DataType dtype = DataType::Float(16);
const DataType int32 = DataType::Int(32);

Stmt body = stmt;
Map<Var, Range> var_range;
std::vector<const ForNode*> loops;
while (const ForNode* loop = body.as<ForNode>()) {
loops.push_back(loop);
body = loop->body;
}
for (int i = 1; i <= 2; i++) {
const ForNode* loop = loops[loops.size() - i];
var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
}
int n = loops.size();

Map<Var, IntSet> var_dom{
{loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)},
{loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)},
};
// TODO: the assumption that the RHS of BufferStore is BufferLoad may not be accurate
const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode);
const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode);
Buffer src_buffer = buf_load->buffer;
Buffer tgt_buffer = buf_store->buffer;

DataType dtype = DataType::Float(16);
Var new_src_var("src", PointerType(PrimType(dtype), src_buffer.scope()));
DataType int32 = DataType::Int(32);
Buffer new_src_buffer(new_src_var, dtype, {Integer(16), Integer(16)},
{Var("s1", int32), Var("s0", int32)}, Var("src_elem_offset", int32), "src",
128, 16, kDefault);
auto read_int_set = arith::EvalSet(buf_load->indices, AsIntSet(var_range));
Array<Range> read_region;
for (int i = 0; i < static_cast<int>(read_int_set.size()); i++) {
read_region.push_back(
read_int_set[i].CoverRange(Range::FromMinExtent(0, src_buffer->shape[i])));
}
match_buffers.push_back(MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)));
Var new_tgt_var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope()));
Buffer new_tgt_buffer(new_tgt_var, dtype, {Integer(16), Integer(16)}, {},
Var("tgt_elem_offset", int32), "tgt", 128, 16, kDefault);
auto write_int_set = arith::EvalSet(buf_store->indices, AsIntSet(var_range));
Array<Range> write_region;
for (int i = 0; i < static_cast<int>(write_int_set.size()); i++) {
write_region.push_back(
write_int_set[i].CoverRange(Range::FromMinExtent(0, tgt_buffer->shape[i])));
}
match_buffers.push_back(
MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)));

PrimExpr frag_index = floordiv(new_tgt_buffer->elem_offset, 256) +
floordiv(floormod(new_tgt_buffer->elem_offset, 256), 16);

auto new_src_pointer = Call(
runtime::DataType::Handle(), builtin::tvm_access_ptr(),
{TypeAnnotation(new_src_buffer->dtype), new_src_buffer->data, new_src_buffer->elem_offset,
new_src_buffer->strides[new_src_buffer->strides.size() - 2] * 16, 1});

Stmt wmma_body = Evaluate(
Call(runtime::DataType::Handle(), builtin::tvm_load_matrix_sync(),
{new_tgt_buffer->data, 16, 16, 16, frag_index, new_src_pointer,
new_src_buffer->strides[new_src_buffer->strides.size() - 2], StringImm("row_major")}));
wmma_body = BlockRealize(
{}, Bool(true),
Block({}, {BufferRegion(src_buffer, read_region)}, {BufferRegion(tgt_buffer, write_region)},
"wmma_load", wmma_body, NullOpt, {}, match_buffers, {}));
for (int i = static_cast<int>(loops.size()) - 3; i >= 0; i--) {
wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, wmma_body,
loops[i]->thread_binding, loops[i]->annotations);
Buffer new_src_buffer(
/*data=*/Var("src", PointerType(PrimType(dtype), src_buffer.scope())),
/*dtype=*/dtype,
/*shape=*/{Integer(16), Integer(16)},
/*strides=*/{Var("s1", int32), Var("s0", int32)},
/*elem_offset=*/Var("src_elem_offset", int32),
/*name=*/"src",
/*data_alignment=*/128,
/*offset_factor=*/16,
/*buffer_type=*/kDefault);
Buffer new_tgt_buffer(
/*data=*/Var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope())),
/*dtype=*/dtype,
/*shape=*/{Integer(16), Integer(16)},
/*strides=*/{},
/*elem_offset=*/Var("tgt_elem_offset", int32),
/*name=*/"tgt",
/*data_alignment=*/128,
/*offset_factor=*/16,
/*buffer_type=*/kDefault);
Array<Range> read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom);
Array<Range> write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom);
Stmt wmma_body = BlockRealize(
/*iter_values=*/{},
/*predicate=*/Bool(true),
Block(
/*iter_vars=*/{},
/*reads=*/{BufferRegion(src_buffer, read_region)},
/*writes=*/{BufferRegion(tgt_buffer, write_region)},
/*name_hint=*/"wmma_load",
/*body=*/
Evaluate(Call(
/*data=*/runtime::DataType::Handle(),
/*op=*/builtin::tvm_load_matrix_sync(),
{
/*0:*/ new_tgt_buffer->data,
/*1:*/ 16,
/*2:*/ 16,
/*3:*/ 16,
/*4:*/ floordiv(new_tgt_buffer->elem_offset, 256) +
floordiv(floormod(new_tgt_buffer->elem_offset, 256), 16),
/*5:*/
Call(
/*dtype=*/runtime::DataType::Handle(),
/*op=*/builtin::tvm_access_ptr(),
/*args=*/
{
/*0:*/ TypeAnnotation(new_src_buffer->dtype),
/*1:*/ new_src_buffer->data,
/*2:*/ new_src_buffer->elem_offset,
/*3:*/ new_src_buffer->strides[new_src_buffer->strides.size() - 2] * 16,
/*4:*/ 1,
}),
/*6:*/ new_src_buffer->strides[new_src_buffer->strides.size() - 2],
/*7:*/ StringImm("row_major"),
})),
/*init=*/NullOpt,
/*alloc_buffers=*/{},
/*match_buffers=*/
{
/*0:*/ MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)),
/*1:*/ MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)),
},
/*annotations=*/{}));
for (int i = n - 3; i >= 0; i--) {
wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind,
std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations);
}
return wmma_body;
}
Expand All @@ -140,65 +192,90 @@ Stmt RewriteWmmaLoad(Stmt stmt) {
* \return The stmt after rewrite
*/
Stmt RewriteWmmaStore(Stmt stmt) {
Array<MatchBufferRegion> match_buffers;
using arith::IntSet;
const DataType dtype = DataType::Float(32);
const DataType int32 = DataType::Int(32);

Stmt body = stmt;
Map<Var, Range> var_range;
std::vector<const ForNode*> loops;
while (const ForNode* loop = body.as<ForNode>()) {
loops.push_back(loop);
body = loop->body;
}
for (int i = 1; i <= 2; i++) {
const ForNode* loop = loops[loops.size() - i];
var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
}
int n = loops.size();

Map<Var, IntSet> var_dom{
{loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)},
{loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)},
};
// TODO: the assumption that the RHS of BufferStore is BufferLoad may not be accurate
const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode);
const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode);
Buffer src_buffer = buf_load->buffer;
Buffer tgt_buffer = buf_store->buffer;

DataType dtype = DataType::Float(32);
DataType int32 = DataType::Int(32);
Var new_src_var("src", PointerType(PrimType(dtype), src_buffer.scope()));
Buffer new_src_buffer(new_src_var, dtype, {Integer(16), Integer(16)}, {},
Var("src_elem_offset", int32), "src", 128, 16, kDefault);
auto read_int_set = arith::EvalSet(buf_load->indices, AsIntSet(var_range));
Array<Range> read_region;
for (int i = 0; i < static_cast<int>(read_int_set.size()); i++) {
read_region.push_back(
read_int_set[i].CoverRange(Range::FromMinExtent(0, src_buffer->shape[i])));
}
match_buffers.push_back(MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)));
Var new_tgt_var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope()));
Buffer new_tgt_buffer(new_tgt_var, dtype, {Integer(16), Integer(16)},
{Var("s1", int32), Var("s0", int32)}, Var("tgt_elem_offset", int32), "tgt",
128, 16, kDefault);
auto write_int_set = arith::EvalSet(buf_store->indices, AsIntSet(var_range));
Array<Range> write_region;
for (int i = 0; i < static_cast<int>(write_int_set.size()); i++) {
write_region.push_back(
write_int_set[i].CoverRange(Range::FromMinExtent(0, tgt_buffer->shape[i])));
}
match_buffers.push_back(
MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)));

PrimExpr frag_index = floordiv(new_src_buffer->elem_offset, 256) +
floordiv(floormod(new_src_buffer->elem_offset, 256), 16);
Buffer new_src_buffer(/*data=*/Var("src", PointerType(PrimType(dtype), src_buffer.scope())),
/*dtype=*/dtype,
/*shape=*/{Integer(16), Integer(16)},
/*strides=*/{},
/*elem_offset=*/Var("src_elem_offset", int32),
/*name=*/"src",
/*data_alignment=*/128,
/*offset_factor=*/16,
/*buffer_type=*/kDefault);
Buffer new_tgt_buffer(/*data=*/Var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope())),
/*dtype=*/dtype,
/*shape=*/{Integer(16), Integer(16)},
/*strides=*/{Var("s1", int32), Var("s0", int32)},
/*elem_offset=*/Var("tgt_elem_offset", int32),
/*name=*/"tgt",
/*data_alignment=*/128,
/*offset_factor=*/16,
/*buffer_type=*/kDefault);

auto new_tgt_pointer = Call(runtime::DataType::Handle(), builtin::tvm_access_ptr(),
{TypeAnnotation(new_tgt_buffer->dtype), new_tgt_buffer->data,
new_tgt_buffer->elem_offset, new_tgt_buffer->strides[0] * 16, 2});
Array<Range> read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom);
Array<Range> write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom);

Stmt wmma_body = Evaluate(Call(runtime::DataType::Handle(), builtin::tvm_store_matrix_sync(),
{new_src_buffer->data, 16, 16, 16, frag_index, new_tgt_pointer,
new_tgt_buffer->strides[0], StringImm("row_major")}));
wmma_body = BlockRealize(
{}, Bool(true),
Block({}, {BufferRegion(src_buffer, read_region)}, {BufferRegion(tgt_buffer, write_region)},
"wmma_store", wmma_body, NullOpt, {}, match_buffers, {}));
for (int i = static_cast<int>(loops.size()) - 3; i >= 0; i--) {
wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, wmma_body,
loops[i]->thread_binding, loops[i]->annotations);
Stmt wmma_body = BlockRealize(
/*iter_values=*/{}, //
/*predicate=*/Bool(true),
Block(/*iter_vars=*/{},
/*reads=*/{BufferRegion(src_buffer, read_region)},
/*writes=*/{BufferRegion(tgt_buffer, write_region)},
/*name_hint=*/"wmma_store",
Evaluate(Call(
/*data=*/runtime::DataType::Handle(),
/*op=*/builtin::tvm_store_matrix_sync(),
{/*0:*/ new_src_buffer->data,
/*1:*/ 16,
/*2:*/ 16,
/*3:*/ 16,
/*4:*/ floordiv(new_src_buffer->elem_offset, 256) +
floordiv(floormod(new_src_buffer->elem_offset, 256), 16),
/*5:*/
Call(
/*data=*/runtime::DataType::Handle(),
/*op=*/builtin::tvm_access_ptr(),
{
/*0:*/ TypeAnnotation(new_tgt_buffer->dtype),
/*1:*/ new_tgt_buffer->data,
/*2:*/ new_tgt_buffer->elem_offset,
/*3:*/ new_tgt_buffer->strides[0] * 16,
/*4:*/ 2,
}),
/*6:*/ new_tgt_buffer->strides[0],
/*7:*/ StringImm("row_major")})),
/*init=*/NullOpt,
/*alloc_buffers=*/{},
/*match_buffers=*/
{
MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)),
MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)),
},
/*annotations=*/{}));
for (int i = n - 3; i >= 0; i--) {
wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind,
std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations);
}
return wmma_body;
}
Expand Down

0 comments on commit daf6727

Please sign in to comment.