Commit a3472de: fix tests
vinx13 committed Jul 12, 2022 (1 parent: 564bc89)
Showing 9 changed files with 73 additions and 63 deletions.
@@ -173,6 +173,7 @@ def __init__(
         reuse_read: Optional[ReuseType] = None,
         reuse_write: Optional[ReuseType] = None,
     ) -> None:
+        print(intrin_group)
         self.__init_handle_by_constructor__(
             _ffi_api.ScheduleRuleMultiLevelTilingTensorCore,  # type: ignore # pylint: disable=no-member
             intrin_group,
2 changes: 1 addition & 1 deletion python/tvm/tir/tensor_intrin/cuda.py
@@ -865,7 +865,7 @@ def get_wmma_intrin_group(
     return {
         "init": init_intrins[out_dtype],
         "load_a": load_a_intrins[in_dtype],
-        "load_b": load_b_intrins,
+        "load_b": load_b_intrins[in_dtype],
         "compute": compute_intrins[out_dtype],
         "store": store_intrins[out_dtype],
     }
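The one-line fix above is the substantive bug fix of this commit: get_wmma_intrin_group returned the whole load_b_intrins table instead of the entry for the requested input dtype. A quick sanity check from Python, assuming the signature of this era, get_wmma_intrin_group(store_scope, in_dtype, out_dtype, trans_b); the dtype values are illustrative:

    from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group

    group = get_wmma_intrin_group(
        store_scope="shared", in_dtype="float16", out_dtype="float32", trans_b=True
    )
    # Every entry must be an intrin *name* (str); before the fix, "load_b"
    # was the whole dtype-keyed dict, so the later TensorIntrin::Get lookup
    # in the C++ schedule rule failed.
    assert set(group) == {"init", "load_a", "load_b", "compute", "store"}
    assert all(isinstance(name, str) for name in group.values())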
58 changes: 22 additions & 36 deletions src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -106,16 +106,6 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
   Optional<LoopRV> TransformWithTensorIntrin(TensorCoreStateNode* state,
                                              const String& intrin_name) const;
 
-  using BufferTypeIndex = std::pair<tir::BufferIndexType, int>;
-
-  /*!
-   * \brief Extract buffer index and its type from block reads/writes
-   * \param block_sref The sref to the block to extract
-   * \return The mapping from buffer to its type and and index
-   */
-  std::unordered_map<tir::Buffer, BufferTypeIndex, ObjectPtrHash, ObjectPtrEqual>
-  ExtractBufferIndex(const tir::StmtSRef& block_sref) const;
-
   /*!
    * \brief Tile, blockize and annotate for tensorization with the given intrin
    * \param block_rv The block to be tensorized
@@ -231,31 +221,15 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
   return {state};
 }
 
-std::unordered_map<tir::Buffer, MultiLevelTilingTensorCoreNode::BufferTypeIndex, ObjectPtrHash,
-                   ObjectPtrEqual>
-MultiLevelTilingTensorCoreNode::ExtractBufferIndex(const tir::StmtSRef& block_sref) const {
-  const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
-  // Collect buffer info before
-  std::unordered_map<tir::Buffer, BufferTypeIndex, ObjectPtrHash, ObjectPtrEqual> buffer_index_info;
-  for (int i = 0; i < static_cast<int>(block->reads.size()); ++i) {
-    buffer_index_info[block->reads[i]->buffer] = {tir::BufferIndexType::kRead, i};
-  }
-  for (int i = 0; i < static_cast<int>(block->writes.size()); ++i) {
-    buffer_index_info[block->writes[i]->buffer] = {tir::BufferIndexType::kWrite, i};
-  }
-  return buffer_index_info;
-}
-
 Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
     TensorCoreStateNode* state, const String& intrin_name) const {
   BlockRV block_rv = state->block_rv;
   tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv);
 
-  std::unordered_map<tir::Buffer, BufferTypeIndex, ObjectPtrHash, ObjectPtrEqual>
-      buffer_index_info = ExtractBufferIndex(block_sref);
-
   // Add reindex stages
   const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  // Hold the reference of the block before reindex
+  const tir::Block block_before_reindex = GetRef<tir::Block>(block);
   if (block->reads.size() != 2 || block->writes.size() != 1) {
     // only matmul-like computation is allowed
     return NullOpt;
@@ -320,17 +294,28 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
     return tir::IndexMap(sub_index_map_src, sub_index_map_tgt);
   };
 
-  for (const auto& it : buffer_index_info) {
-    const tir::Buffer& lhs_buffer = it.first;
-    const tir::BufferIndexType buffer_type = it.second.first;
-    int buffer_index = it.second.second;
+  std::unordered_set<tir::Buffer, ObjectPtrHash, ObjectPtrEqual> visited_buffers;
+
+  auto f_transform_buffer_layout = [&](tir::BufferIndexType index_type, int buffer_index) {
+    const tir::Buffer& lhs_buffer = tir::GetNthAccessBuffer(
+        state->sch->state(), block_before_reindex, buffer_index, index_type);
+    if (visited_buffers.count(lhs_buffer)) {
+      return;
+    }
+    visited_buffers.insert(lhs_buffer);
     // Refresh block pointer (block sref is not invalidated)
     block = TVM_SREF_TO_BLOCK(block, block_sref);
-    const tir::BufferRegion& reindexed_buffer_region = buffer_type == tir::BufferIndexType::kRead
-                                                           ? block->reads[buffer_index]
-                                                           : block->writes[buffer_index];
+    const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion(
+        state->sch->state(), GetRef<tir::Block>(block), buffer_index, index_type);
     auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region);
-    state->sch->TransformLayout(state->block_rv, buffer_index, buffer_type, sub_index_map);
-  }
+    state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map);
+  };
+
+  for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) {
+    f_transform_buffer_layout(tir::BufferIndexType::kRead, i);
+  }
+  for (int i = 0, n = block_before_reindex->writes.size(); i < n; ++i) {
+    f_transform_buffer_layout(tir::BufferIndexType::kWrite, i);
+  }
 
   // Transform the layout of current block and reindex blocks
@@ -374,6 +359,7 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
   auto f_initialize_intrin = [&intrin_group](String key_name, String* intrin_name) {
     CHECK(intrin_group.count(key_name)) << "ValueError: " << key_name << " is not set.";
     *intrin_name = intrin_group.at(key_name);
+    LOG(INFO) << key_name;
     // Check the existence of the intrin
     tir::TensorIntrin::Get(*intrin_name);
   };
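The C++ changes above drop the eagerly built buffer-to-(type, index) map and instead look buffers up by index against the block as it was before reindexing, deduplicating through visited_buffers because one buffer can back several access regions. A minimal pure-Python model of the new iteration order (names are illustrative, not TVM APIs):

    def transform_block_buffers(reads, writes, transform):
        """Visit read regions, then write regions, transforming each
        distinct buffer only once, like f_transform_buffer_layout above."""
        visited = set()
        for index_type, buffers in (("read", reads), ("write", writes)):
            for i, buf in enumerate(buffers):
                if buf in visited:
                    continue  # already transformed via another region
                visited.add(buf)
                transform(buf, i, index_type)

    # "A" appears as both read 0 and write 0 but is transformed only once:
    transform_block_buffers(["A", "B"], ["A"], lambda b, i, t: print(t, i, b))
    # prints: read 0 A, then read 1 B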
18 changes: 16 additions & 2 deletions src/tir/schedule/analysis.h
@@ -22,6 +22,7 @@
 #include <tvm/arith/analyzer.h>
 #include <tvm/ir/op.h>
 #include <tvm/tir/index_map.h>
+#include <tvm/tir/schedule/schedule.h>
 #include <tvm/tir/schedule/state.h>
 
 #include <tuple>
@@ -422,11 +423,24 @@ struct ProducerConsumerSplit {
  * \param self The schedule state.
  * \param block The queried block.
  * \param n The index of the queried buffer.
- * \param is_write A boolean flag to indicate querying write buffer or read buffer.
+ * \param index_type The type of the buffer index, kRead or kWrite.
  * \return The buffer of the n-th read/write region of the block.
  * \throw ScheduleError If the buffer index is out of bound.
  */
-Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write);
+Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n,
+                          BufferIndexType index_type);
+
+/*!
+ * \brief Get the n-th read or write buffer region of the given block.
+ * \param self The schedule state.
+ * \param block The queried block.
+ * \param n The index of the queried buffer.
+ * \param index_type The type of the buffer index, kRead or kWrite.
+ * \return The n-th read/write region of the block.
+ * \throw ScheduleError If the buffer index is out of bound.
+ */
+BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& block, int n,
+                                      BufferIndexType index_type);
 
 /*!
  * \brief Find the defining site of the buffer in the given block and its ancestors
30 changes: 19 additions & 11 deletions src/tir/schedule/analysis/analysis.cc
@@ -1142,17 +1142,19 @@ ProducerConsumerSplit ProducerConsumerSplit::Find(
 
 /******** Block-buffer relation ********/
 
-Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write) {
+BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& block, int n,
+                                      BufferIndexType index_type) {
   class BufferIndexOutOfRangeError : public ScheduleError {
    public:
-    explicit BufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index, bool is_write)
+    explicit BufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index,
+                                        BufferIndexType index_type)
         : mod_(std::move(mod)),
           block_(std::move(block)),
          buffer_index_(buffer_index),
-          is_write_(is_write) {}
+          index_type_(index_type) {}
 
     String FastErrorString() const final {
-      if (is_write_) {
+      if (index_type_ == BufferIndexType::kWrite) {
         return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
                "range "
                "[0, num_write_regions) where `num_write_regions` is the number of buffer regions "
@@ -1167,9 +1169,9 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n,
 
     String DetailRenderTemplate() const final {
       std::ostringstream os;
-      size_t num = is_write_ ? block_->writes.size() : block_->reads.size();
-      std::string access_type = is_write_ ? "write" : "read";
-      os << "The block {0} has " << num << " " << access_type
+      size_t num =
+          index_type_ == BufferIndexType::kWrite ? block_->writes.size() : block_->reads.size();
+      os << "The block {0} has " << num << " " << BufferIndexType2Str(index_type_)
          << " regions, so `buffer_index` is required to be in [0, " << num
          << "). However, the input `buffer_index` is " << buffer_index_
         << ", which is out of the expected range.";
@@ -1183,15 +1185,21 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n,
     IRModule mod_;
     Block block_;
     int buffer_index_;
-    bool is_write_;
+    BufferIndexType index_type_;
   };
 
-  const Array<BufferRegion>& access_region = is_write ? block->writes : block->reads;
+  const Array<BufferRegion>& access_region =
+      index_type == BufferIndexType::kWrite ? block->writes : block->reads;
 
   if (n < 0 || static_cast<int>(access_region.size()) <= n) {
-    throw BufferIndexOutOfRangeError(self->mod, block, n, is_write);
+    throw BufferIndexOutOfRangeError(self->mod, block, n, index_type);
   }
-  return access_region[n]->buffer;
+  return access_region[n];
+}
+
+Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n,
+                          BufferIndexType index_type) {
+  return GetNthAccessBufferRegion(self, block, n, index_type)->buffer;
+}
 
 std::pair<Optional<StmtSRef>, bool> GetBufferDefiningSite(const StmtSRef& block_sref,
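The refactored GetNthAccessBufferRegion keeps the same bounds check, so an out-of-range index still surfaces in Python as tvm.tir.ScheduleError, with the message rendered by BufferIndexOutOfRangeError above. A self-contained sketch of triggering it through cache_read (the elemwise func is illustrative, in TVMScript syntax of this era):

    import pytest
    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def elemwise(a: T.handle, b: T.handle) -> None:
        A = T.match_buffer(a, (16,), "float32")
        B = T.match_buffer(b, (16,), "float32")
        for i in T.serial(16):
            with T.block("B"):
                vi = T.axis.spatial(16, i)
                B[vi] = A[vi] + T.float32(1)

    sch = tvm.tir.Schedule(elemwise)
    with pytest.raises(tvm.tir.ScheduleError):
        # Only read region 0 exists, so index 1 is out of range.
        sch.cache_read(sch.get_block("B"), 1, "shared")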
5 changes: 3 additions & 2 deletions src/tir/schedule/primitive/block_annotate.cc
@@ -240,7 +240,7 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind
                   int factor, int offset) {
   const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
   Buffer buffer =
-      GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, /*is_write=*/true);
+      GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, BufferIndexType::kWrite);
   StorageAlignInvalidFactorError::Check(self->mod, factor);
   axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis);
   NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer);
@@ -275,7 +275,8 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind
 void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
               const String& storage_scope) {
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
-  Buffer buffer = GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, true);
+  Buffer buffer =
+      GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, BufferIndexType::kWrite);
 
   // Step 1. If `storage_scope` equals the original storage scope of the buffer, just return.
   if (buffer.scope() == storage_scope) {
6 changes: 3 additions & 3 deletions src/tir/schedule/primitive/cache_read_write.cc
@@ -981,7 +981,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
   // Step 1. Check index, getting the target buffer and the parent scope
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   Buffer read_buffer =
-      GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, /*is_write=*/false);
+      GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, BufferIndexType::kRead);
   StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
   const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
 
@@ -1052,7 +1052,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
   // Step 1. Checking index, getting the target buffer and the parent scope
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   Buffer write_buffer =
-      GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, /*is_write=*/true);
+      GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, BufferIndexType::kWrite);
   StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
 
   // Step 2. Creating CacheStageInfo
@@ -1095,7 +1095,7 @@ StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_inde
   const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
   Block block = GetRef<Block>(block_ptr);
   Buffer buffer =
-      GetNthAccessBuffer(self, block, buffer_index, buffer_index_type == BufferIndexType::kWrite);
+      GetNthAccessBuffer(self, block, buffer_index, buffer_index_type);
   StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
   arith::Analyzer analyzer;
 
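These three primitives back sch.cache_read, sch.cache_write, and sch.reindex on the Python side; the change here is purely in how the buffer index type is passed. A hedged usage sketch against a toy matmul (the module and dtypes are illustrative, in TVMScript syntax of this era):

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (128, 128), "float16")
        B = T.match_buffer(b, (128, 128), "float16")
        C = T.match_buffer(c, (128, 128), "float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("C"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + T.cast(A[vi, vk], "float32") * T.cast(B[vk, vj], "float32")

    sch = tvm.tir.Schedule(matmul)
    block = sch.get_block("C")
    sch.cache_read(block, 0, "shared")                  # read index 0  -> kRead
    sch.cache_write(block, 0, "shared")                 # write index 0 -> kWrite
    b = sch.reindex(block=block, buffer=("read", 1))    # tuple form    -> kRead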
4 changes: 2 additions & 2 deletions src/tir/schedule/primitive/layout_transformation.cc
@@ -137,7 +137,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
   const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
   Buffer old_buffer =
       GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index,
-                         buffer_index_type == BufferIndexType::kRead ? false : true);
+                         buffer_index_type);
   Optional<StmtSRef> defining_site_sref;
   bool is_alloc;
   std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, old_buffer);
@@ -492,7 +492,7 @@ void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer
                       BufferIndexType buffer_index_type, const Array<IntImm>& axis_separators) {
   const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
   Buffer old_buffer = GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index,
-                                         buffer_index_type == BufferIndexType::kWrite);
+                                         buffer_index_type);
   Optional<StmtSRef> defining_site_sref;
   bool is_alloc;
   std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, old_buffer);
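On the Python surface nothing changes: transform_layout already takes the buffer as a ("read"/"write", index) pair, as the traces in the test file below show, and that pair is exactly the (buffer_index, BufferIndexType) the C++ side now passes through without a boolean conversion. A minimal sketch (the copy func is illustrative):

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def copy(a: T.handle, b: T.handle) -> None:
        A = T.match_buffer(a, (64, 64), "float32")
        B = T.match_buffer(b, (64, 64), "float32")
        for i, j in T.grid(64, 64):
            with T.block("copy"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj]

    sch = tvm.tir.Schedule(copy)
    block = sch.get_block("copy")
    # ("read", 0) selects A via BufferIndexType::kRead,
    # ("write", 0) selects B via BufferIndexType::kWrite.
    sch.transform_layout(block=block, buffer=("read", 0), index_map=lambda i, j: (j, i))
    sch.transform_layout(block=block, buffer=("write", 0), index_map=lambda i, j: (j, i))

The expected traces in the test hunks below change only in ordering: the read regions are now laid out before the write region, matching the reads-then-writes loops introduced in TransformWithTensorIntrin above.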
@@ -606,9 +606,9 @@ def test_cuda_tensor_core_matmul_relu():
 b2 = sch.reindex(block=b0, buffer=("write", 0))
 b3 = sch.reindex(block=b0, buffer=("read", 0))
 b4 = sch.reindex(block=b0, buffer=("read", 1))
-sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
-sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ))
 sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
+sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ))
+sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
 sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
 sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
 sch.transform_block_layout(block=b4, index_map=lambda i, j, k: (i, j, k, ))
@@ -736,9 +736,9 @@ def test_cuda_tensor_core_matmul_relu_global():
 b1 = sch.reindex(block=b0, buffer=("write", 0))
 b2 = sch.reindex(block=b0, buffer=("read", 0))
 b3 = sch.reindex(block=b0, buffer=("read", 1))
-sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
-sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ))
 sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
+sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ))
+sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
 sch.transform_block_layout(block=b1, index_map=lambda i, j, k: (i, j, k, ))
 sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
 sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
@@ -863,9 +863,9 @@ def test_cuda_tensor_core_conv2d():
 b1 = sch.reindex(block=b0, buffer=("write", 0))
 b2 = sch.reindex(block=b0, buffer=("read", 0))
 b3 = sch.reindex(block=b0, buffer=("read", 1))
-sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda h, w, co: (((h*16) + w), co, ))
-sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda co, rh, rw, rc: ((((rh*96) + (rw*32)) + rc), co, ))
 sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda h, w, rh, rw, rc: (((h*16) + w), (((rh*96) + (rw*32)) + rc), ))
+sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda co, rh, rw, rc: ((((rh*96) + (rw*32)) + rc), co, ))
+sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda h, w, co: (((h*16) + w), co, ))
 sch.transform_block_layout(block=b1, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), ))
 sch.transform_block_layout(block=b2, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), ))
 sch.transform_block_layout(block=b3, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), ))
