Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Fix opaque access in buffer locator pass and match_buffer in region detector #8855

Merged
merged 5 commits into from
Aug 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,11 @@ void BlockReadWriteDetector::operator()(const Stmt& stmt) {
ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey();
for (const MatchBufferRegion& match_buffer : block->match_buffers) {
const Var& target_var = match_buffer->buffer->data;
match_buffers_[target_var.get()] = match_buffer;
buffer_var_map_.Set(target_var, match_buffer->buffer);
const Var& source_var = match_buffer->source->buffer->data;
if (buffer_var_map_.find(source_var) != buffer_var_map_.end()) {
match_buffers_[target_var.get()] = match_buffer;
buffer_var_map_.Set(target_var, match_buffer->buffer);
}
}
StmtExprVisitor::operator()(stmt);
}
Expand Down
39 changes: 27 additions & 12 deletions src/tir/transforms/plan_update_buffer_allocation_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ class BufferAllocationLocator : public StmtExprMutator {

Stmt VisitStmt_(const BlockNode* op) final {
ICHECK(!op->init.defined());
bool is_root = is_root_;
is_root_ = false;
Array<Buffer> alloc_buffers;
auto it = alloc_buffers_.find(op);
if (it != alloc_buffers_.end()) {
Expand All @@ -85,11 +83,23 @@ class BufferAllocationLocator : public StmtExprMutator {
buffer_data_to_buffer_.Set(buf->data, buf);
}
}
for (const MatchBufferRegion match_buffer : op->match_buffers) {
const Var& target_var = match_buffer->buffer->data;
const Var& source_var = match_buffer->source->buffer->data;
ICHECK(buffer_data_to_buffer_.count(source_var));
buffer_data_to_buffer_.Set(target_var, match_buffer->buffer);
}
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<BlockNode>();
ICHECK(op != nullptr);

// Ignore buffer allocated inside the block when getting access region.
// No longer consider buffers created by match_buffer inside the block when updating access
// region.
for (const MatchBufferRegion match_buffer : op->match_buffers) {
const Var& target_var = match_buffer->buffer->data;
buffer_data_to_buffer_.erase(target_var);
}
// No longer consider buffers allocated inside the block when updating access region.
if (it != alloc_buffers_.end()) {
for (const Buffer& buf : it->second) {
buffer_data_to_buffer_.erase(buf->data);
Expand All @@ -98,12 +108,9 @@ class BufferAllocationLocator : public StmtExprMutator {

ObjectPtr<BlockNode> n = CopyOnWrite(op);
n->alloc_buffers = std::move(alloc_buffers);
// The read/write regions of root block are always empty.
if (!is_root) {
// Recalculate block access region
CollectReadWrite(GetRef<Block>(op), &n->reads, &n->writes);
}

// Erase buffers allocated inside the block from the access region.
n->reads = RemoveRedundantBufferRegion(n->reads);
n->writes = RemoveRedundantBufferRegion(n->writes);
return Stmt(n);
}

Expand All @@ -127,8 +134,18 @@ class BufferAllocationLocator : public StmtExprMutator {
return std::move(realize);
}

/*!
 * \brief Filter an access-region list down to the buffers that are currently tracked.
 * \param region The buffer regions to filter.
 * \return The regions, in their original order, whose buffer data var is a key of
 *         buffer_data_to_buffer_.
 */
Array<BufferRegion> RemoveRedundantBufferRegion(const Array<BufferRegion>& region) const {
  Array<BufferRegion> kept;
  for (const BufferRegion& candidate : region) {
    // Regions on buffers missing from buffer_data_to_buffer_ are dropped.
    const bool is_tracked = buffer_data_to_buffer_.count(candidate->buffer->data) != 0;
    if (!is_tracked) {
      continue;
    }
    kept.push_back(candidate);
  }
  return kept;
}

void CollectReadWrite(const Block& block, Array<BufferRegion>* reads,
Array<BufferRegion>* writes) {
Array<BufferRegion>* writes) const {
Array<Array<BufferRegion>> access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
*reads = access[0];
*writes = access[1];
Expand All @@ -142,8 +159,6 @@ class BufferAllocationLocator : public StmtExprMutator {
std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
/*! \brief The buffer already allocated during recursive visiting. */
Map<Var, Buffer> buffer_data_to_buffer_;
/*! \brief Indicates whether the block is the root block. */
bool is_root_{true};
};

PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
Expand Down
21 changes: 15 additions & 6 deletions tests/python/unittest/test_tir_analysis_get_block_access_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,20 +114,29 @@ def test_match_buffer():
root_block = match_buffer_func.body.block
block = root_block.body.body.body.block
block_inner = block.body[0].body.body.block
alloc_buffers = func.body.block.alloc_buffers
alloc_buffers = match_buffer_func.body.block.alloc_buffers
buffer_var_map = {buf.data: buf for buf in alloc_buffers}

# Check inner block AAA
ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map)
tvm.ir.assert_structural_equal(block_inner.reads, ret[0])
tvm.ir.assert_structural_equal(block_inner.writes, ret[1])

# Check block
ret = tir.analysis.get_block_access_region(block, buffer_var_map)
tvm.ir.assert_structural_equal(block.writes, ret[1])
# B is opaque access
tvm.ir.assert_structural_equal(block.reads, ret[2])

# Check inner block AAA without updating buffer_var_map
ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map)
# Since AA is not in the buffer_var_map, region of AA will not be collected.
tvm.ir.assert_structural_equal([], ret[1])

# Check inner block AAA
for match_buffer in block.match_buffers:
target_buffer = match_buffer.buffer
buffer_var_map[target_buffer.data] = target_buffer

ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map)
tvm.ir.assert_structural_equal(block_inner.reads, ret[0])
tvm.ir.assert_structural_equal(block_inner.writes, ret[1])


if __name__ == "__main__":
test_block_access_region_detector()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,63 @@ def transformed_match_buffer_func() -> None:
C1[()] = 0


# Input fixture for test_opaque_access: A_cache is allocated at function scope and is
# touched both through its raw `.data` handle (extern call) and through indexed access.
@tvm.script.tir
def opaque_access(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, [1024])
    B = tir.match_buffer(b, [1024])
    # Allocated outside every loop and block in this version.
    A_cache = tir.alloc_buffer([1024])
    for i in tir.serial(0, 8):
        with tir.block([8]) as [vi]:
            with tir.block([8]) as [v]:
                tir.bind(v, vi)
                # Regions are declared explicitly; the body below passes only the
                # `.data` handles of A_cache and A to the extern call.
                tir.reads([A[(v * 128) : ((v * 128) + 128)]])
                tir.writes([A_cache[(v * 128) : ((v * 128) + 128)]])
                tir.evaluate(
                    tir.call_extern(
                        "test",
                        A_cache.data,
                        (v * 128),
                        128,
                        A.data,
                        (v * 128),
                        128,
                        dtype="float32",
                    )
                )
            for j in tir.serial(0, 128):
                with tir.block([1024]) as [v]:
                    tir.bind(v, ((vi * 128) + j))
                    tir.reads([A_cache[v]])
                    tir.writes([B[v]])
                    # Element-wise copy from the cache into B.
                    B[v] = A_cache[v]


# Counterpart fixture passed to `_check` together with `opaque_access`; presumably the
# expected result of the pass under test (confirm against `_check`).
@tvm.script.tir
def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, [1024])
    B = tir.match_buffer(b, [1024])
    for i in tir.serial(0, 8):
        with tir.block([8]) as [vi]:
            # The outer block declares regions on A and B only; A_cache is not listed.
            tir.reads(A[vi * 128 : vi * 128 + 128])
            tir.writes(B[vi * 128 : vi * 128 + 128])
            # In this version A_cache is allocated inside the outer block.
            A_cache = tir.alloc_buffer([1024])
            with tir.block([8]) as [v]:
                tir.bind(v, vi)
                tir.reads([A[v * 128 : v * 128 + 128]])
                tir.writes([A_cache[v * 128 : v * 128 + 128]])
                # Only the `.data` handles of A_cache and A are passed to the extern call.
                tir.evaluate(
                    tir.call_extern(
                        "test", A_cache.data, v * 128, 128, A.data, v * 128, 128, dtype="float32"
                    )
                )
            for j in tir.serial(0, 128):
                with tir.block([1024]) as [v]:
                    tir.bind(v, ((vi * 128) + j))
                    tir.reads([A_cache[v]])
                    tir.writes([B[v]])
                    # Element-wise copy from the cache into B.
                    B[v] = A_cache[v]


def test_elementwise():
    """Run `_check` on `element_func` paired with `transformed_element_func`."""
    _check(element_func, transformed_element_func)

Expand All @@ -149,6 +206,10 @@ def test_match_buffer_allocation():
_check(match_buffer_func, transformed_match_buffer_func)


def test_opaque_access():
    """Run `_check` on `opaque_access` paired with `transformed_opaque_access`."""
    _check(opaque_access, transformed_opaque_access)


def test_lower_te():
x = te.placeholder((1,))
y = te.compute((1,), lambda i: x[i] + 2)
Expand All @@ -164,4 +225,5 @@ def test_lower_te():
test_elementwise()
test_locate_buffer_allocation()
test_match_buffer_allocation()
test_opaque_access()
test_lower_te()