Skip to content

Commit

Permalink
finish
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy committed Jul 29, 2021
1 parent f7d79cf commit cd511b2
Show file tree
Hide file tree
Showing 18 changed files with 244 additions and 85 deletions.
6 changes: 3 additions & 3 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,21 @@ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
# match_buffers of the block,
# which bind a sub-region of source buffer into a new buffer
D = tir.match_buffer_region(C[vi, vj])
D = tir.match_buffer(C[vi, vj], ())
# init part of the block, executed when all reduce axes are the beginning value
with tir.init():
C[vi, vj] = tir.float32(0)
# block body
CC[0, 0] = A[vi, vk] * B[vj, vk]
D[0, 0] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0]
D[()] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0]
"""

alloc_buffers: List[Buffer] = []
"""List[Buffer]: list of tir.alloc_buffer statements in the block signature"""
match_buffers: List[MatchBufferRegion] = []
"""List[MatchBufferRegion]: list of tir.match_buffer_region statements in the block signature"""
"""List[MatchBufferRegion]: list of tir.match_buffer statements in the block signature"""
iter_bindings: Mapping[Var, PrimExpr] = {}
"""Mapping[Var, PrimExpr]: map of block iter var to its values"""
reads: Optional[List[BufferSlice]] = None
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ def LowerMatchBuffer():
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerMatchBuffer()
return _ffi_api.LowerMatchBuffer() # type: ignore


def FlattenBuffer():
Expand Down
4 changes: 2 additions & 2 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -598,8 +598,8 @@ Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) {
<< Print(alloc_buf->shape) << ")" << Doc::NewLine();
}
for (const auto& match_buf : block_op->match_buffers) {
body << AllocBuf(match_buf->buffer) << " = match_buffer_region(" << Print(match_buf->source)
<< ")" << Doc::NewLine();
body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")"
<< Doc::NewLine();
}
if (block_op->init.defined()) {
Doc init_block;
Expand Down
30 changes: 11 additions & 19 deletions src/tir/analysis/buffer_access_lca_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,16 @@ class LCADetector : public StmtExprVisitor {
buffer_var_map_.emplace(buf->data.get(), buf.get());
}

const ScopeInfo* parent_scope = ancestor_scopes_.back();
auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);

ancestor_scopes_.push_back(current_scope);
// Update match_buffers
for (const MatchBufferRegion& match_buffer : op->match_buffers) {
const Buffer& target_buffer = match_buffer->buffer;
buffer_var_map_.emplace(target_buffer->data.get(), target_buffer.get());

const Buffer& source_buffer = match_buffer->source->buffer;
auto it = match_buffers_.find(source_buffer.get());
if (it != match_buffers_.end()) {
match_buffers_[target_buffer.get()] = it->second;
} else {
match_buffers_[target_buffer.get()] = source_buffer.get();
}
UpdateBufferLCA(match_buffer->source->buffer.get());
match_buffers_.insert(match_buffer->buffer.get());
}

const ScopeInfo* parent_scope = ancestor_scopes_.back();
auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);
ancestor_scopes_.push_back(current_scope);
StmtExprVisitor::VisitStmt_(op);
ancestor_scopes_.pop_back();
}
Expand Down Expand Up @@ -144,12 +137,11 @@ class LCADetector : public StmtExprVisitor {
}

void UpdateBufferLCA(const BufferNode* buffer) {
auto it = match_buffers_.find(buffer);
if (it != match_buffers_.end()) {
buffer = it->second;
if (match_buffers_.find(buffer) == match_buffers_.end()) {
// Ignore buffers created by block match_buffer
const ScopeInfo*& lca = buffer_lca_[buffer];
lca = LowestCommonAncestor(lca, ancestor_scopes_.back());
}
const ScopeInfo*& lca = buffer_lca_[buffer];
lca = LowestCommonAncestor(lca, ancestor_scopes_.back());
}

static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) {
Expand Down Expand Up @@ -184,7 +176,7 @@ class LCADetector : public StmtExprVisitor {
/*! \brief The map from Buffer data to the Buffer. */
std::unordered_map<const VarNode*, const BufferNode*> buffer_var_map_ = {};
/*! \brief The match buffers inside blocks. */
std::unordered_map<const BufferNode*, const BufferNode*> match_buffers_ = {};
std::unordered_set<const BufferNode*> match_buffers_ = {};
/*! \brief Internal arena. */
support::Arena arena_;
};
Expand Down
24 changes: 10 additions & 14 deletions src/tir/ir/script/script_complete.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,12 @@ namespace tir {
/*! \brief Generate surrounding loops automatically */
class ScriptCompleter : public StmtMutator {
public:
explicit ScriptCompleter(Map<Var, Buffer>* buffer_var_map, bool contain_root)
: buffer_var_map_(buffer_var_map), contain_root_(contain_root) {}
explicit ScriptCompleter(Map<Var, Buffer>* buffer_var_map) : buffer_var_map_(buffer_var_map) {}
/*! \brief Whether the stmt contains at least one block. */
bool contains_block = false;

private:
Map<Var, Buffer>* buffer_var_map_;
bool contain_root_;
bool visited_root_ = false;
Stmt VisitStmt_(const BlockRealizeNode* op) override {
contains_block = true;
Stmt body = StmtMutator::VisitStmt_(op);
Expand All @@ -65,17 +62,23 @@ class ScriptCompleter : public StmtMutator {
}

Stmt VisitStmt_(const BlockNode* op) override {
bool is_root_block = contain_root_ && !visited_root_;
visited_root_ = true;
// Buffers allocated in the block can be accessed by its body.
for (const auto& alloc_buffer : op->alloc_buffers) {
buffer_var_map_->Set(alloc_buffer->data, alloc_buffer);
}
for (const auto& match_buffer : op->match_buffers) {
const Buffer& target_buffer = match_buffer->buffer;
buffer_var_map_->Set(target_buffer->data, target_buffer);
}
Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
// Remove buffers allocated inside block to detect its access region
for (const auto& alloc_buffer : op->alloc_buffers) {
buffer_var_map_->erase(alloc_buffer->data);
}
for (const auto& match_buffer : op->match_buffers) {
const Buffer& target_buffer = match_buffer->buffer;
buffer_var_map_->erase(target_buffer->data);
}
// Get access detection mask
// 0 for provided region, 1 and 3 for need detect read, 2 and 3 for need detect write
int mask = 0;
Expand All @@ -85,13 +88,6 @@ class ScriptCompleter : public StmtMutator {
}
// ignore root block or blocks which already have reads/writes regions
if (mask != 0) {
if (op->iter_vars.empty()) {
// non-root opaque block is not allowed
CHECK(is_root_block)
<< "ValueError: Can not auto detect buffer access region for an opaque block. Please "
"annotate the access region manually.";
return std::move(block);
}
auto access_region = GetBlockAccessRegion(block, *buffer_var_map_);
const Array<BufferRegion>& reads = access_region[0];
const Array<BufferRegion>& writes = access_region[1];
Expand Down Expand Up @@ -122,7 +118,7 @@ PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates) {
}
bool contain_root = root_allocates.empty() && func->body->IsInstance<BlockRealizeNode>() &&
Downcast<BlockRealize>(func->body)->block->iter_vars.empty();
ScriptCompleter script_completer(&buffer_var_map, contain_root);
ScriptCompleter script_completer(&buffer_var_map);
// generate surrounding loops automatically
Stmt res = script_completer(func->body);
// generate root block automatically
Expand Down
8 changes: 4 additions & 4 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -697,9 +697,9 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) {
const Buffer& source_buffer = source->buffer;
arith::Analyzer analyzer;
// Check scope and dtype
CHECK_EQ(buffer->scope, source_buffer->scope)
<< "MatchBuffer " << buffer << " scope mismatch:" << buffer->scope << " vs. "
<< source_buffer->scope;
CHECK_EQ(buffer.scope(), source_buffer.scope())
<< "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << " vs. "
<< source_buffer.scope();
CHECK_EQ(buffer->dtype, source_buffer->dtype)
<< "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << " vs. "
<< source_buffer->dtype;
Expand Down Expand Up @@ -798,7 +798,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MatchBufferRegionNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MatchBufferRegionNode*>(node.get());
p->PrintIndent();
p->stream << op->buffer->name << " = match_buffer_region(";
p->stream << op->buffer->name << " = match_buffer(";
p->Print(op->source);
p->stream << ")\n";
});
Expand Down
8 changes: 4 additions & 4 deletions src/tir/transforms/lower_match_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ class MatchBufferLower : public StmtExprMutator {
const Buffer& source_buffer = source->buffer;

// Step.1.1. Check scope & dtype
ICHECK_EQ(buffer->scope, source_buffer->scope)
<< "MatchBuffer " << buffer << " scope mismatch:" << buffer->scope << "vs."
<< source_buffer->scope;
ICHECK_EQ(buffer.scope(), source_buffer.scope())
<< "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << "vs."
<< source_buffer.scope();
ICHECK_EQ(buffer->dtype, source_buffer->dtype)
<< "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << "vs."
<< source_buffer->dtype;
Expand Down Expand Up @@ -251,4 +251,4 @@ TVM_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchB
} // namespace transform

} // namespace tir
} // namespace tvm
} // namespace tvm
4 changes: 2 additions & 2 deletions tests/python/integration/test_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
)


@tvm.testing.requires_cuda
def test_gemm_tensorcore():
dev = tvm.device("cuda", 0)
a_np = np.random.uniform(size=(1024, 1024)).astype("float16")
Expand All @@ -310,7 +311,6 @@ def test_gemm_tensorcore():
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
c = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev)
print(tvm.script.asscript(tvm.lower(tensorcore_gemm, simple_mode=True)))
f = tvm.build(tensorcore_gemm, target="cuda", name="dense")
f(a, b, c)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
Expand All @@ -324,4 +324,4 @@ def test_gemm_tensorcore():


if __name__ == "__main__":
test_gemm_tensorcore()
test_gemm_tensorcore()
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,11 @@ def match_buffer_func(a: ty.handle, b: ty.handle) -> None:
with tir.block([8, 8], "block") as [vi, vj]:
tir.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16])
tir.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
AA = tir.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16))
B0 = tir.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4))
B1 = tir.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8))
with tir.block([16, 16], "AAA") as [i, j]:
AAA = tir.match_buffer(AA[i, j], ())
AAA[()] = 1.0
AA = tir.match_buffer(A[i, j], ())
AA[()] = 1.0
tir.evaluate(B0.data)
tir.evaluate(B1.data)

Expand Down
35 changes: 35 additions & 0 deletions tests/python/unittest/test_tir_analysis_get_block_access_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,25 @@ def match_buffer_func() -> None:
tir.evaluate(B1.data)


# TVMScript fixture: a "root" block that allocates two 16x16 float32 buffers (A, B)
# and contains two blocks declared with an empty iter-var list (tir.block([])),
# one per loop level. Every block states its read/write regions explicitly.
# NOTE(review): indentation is not preserved in this view; the nesting (i-loop ->
# block -> j-loop -> block) is presumed from the statement order -- confirm.
@tvm.script.tir
def opaque_block_func() -> None:
with tir.block([], "root"):
A = tir.alloc_buffer((16, 16), "float32")
B = tir.alloc_buffer((16, 16), "float32")
# The root block declares empty read and write regions.
tir.reads([])
tir.writes([])
# Read/write regions need to be added manually to avoid triggering the block access region detector
for i in range(0, 16):
# Block without iter vars: reads row i of A and writes row i of B.
with tir.block([]):
tir.reads(A[i, 0:16])
tir.writes([B[i, 0:16]])
for j in range(0, 16):
# Block without iter vars: reads A[i, j] and writes B[i, j].
with tir.block([]):
tir.reads(A[i, j])
tir.writes(B[i, j])
B[i, j] = A[i, j] + 1.0


def test_block_access_region_detector():
block = func.body.block.body.block
alloc_buffers = func.body.block.alloc_buffers
Expand All @@ -76,6 +95,21 @@ def test_block_access_region_detector():
)


def test_opaque_block():
    """Check that detected access regions match the annotated ones.

    For the two nested blocks of ``opaque_block_func``, the regions returned by
    ``tir.analysis.get_block_access_region`` must be structurally equal to the
    ``reads`` / ``writes`` declared on each block.
    """
    root = opaque_block_func.body.block

    # Map every buffer allocated by the root block from its data var to the buffer.
    var_to_buffer = {}
    for allocated in root.alloc_buffers:
        var_to_buffer[allocated.data] = allocated

    def check_regions(blk):
        # Index 0 holds the detected reads, index 1 the detected writes.
        detected = tir.analysis.get_block_access_region(blk, var_to_buffer)
        tvm.ir.assert_structural_equal(blk.reads, detected[0])
        tvm.ir.assert_structural_equal(blk.writes, detected[1])

    # Outer block: reached through the root block's body.
    outer_block = root.body.body.block
    check_regions(outer_block)

    # Inner block: reached through the outer block's body.
    inner_block = outer_block.body.body.block
    check_regions(inner_block)


def test_match_buffer():
root_block = match_buffer_func.body.block
block = root_block.body.body.body.block
Expand All @@ -97,4 +131,5 @@ def test_match_buffer():

if __name__ == "__main__":
test_block_access_region_detector()
test_opaque_block()
test_match_buffer()
Loading

0 comments on commit cd511b2

Please sign in to comment.