Skip to content

Commit

Permalink
[TIR][TVMScript] Update printer / parser to make T.allocate return bu…
Browse files Browse the repository at this point in the history
…ffer var (apache#12412)

* Updated TVMScript syntax of `T.allocate` to return buffer var.

* Added syntax sugar for `T.decl_buffer`. When `data` field is not
  specified, `data` will be implicitly created via `Allocate` stmt.
  
* Updated the existing test cases. Most test cases can be updated by
  changing `T.allocate` to `T.decl_buffer`. `T.allocate` in some tests
  are updated to `T.allocate` + `T.buffer_decl`, to maintain the
  legacy behavior of allocation and implicit buffer declaration (will
  be followed up in future PR to adopt `T.decl_buffer`).
  • Loading branch information
vinx13 authored Aug 31, 2022
1 parent acbbd9f commit 0c37454
Show file tree
Hide file tree
Showing 32 changed files with 804 additions and 590 deletions.
57 changes: 30 additions & 27 deletions python/tvm/script/tir/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,17 @@ def allocate(extents, dtype, scope, condition=True, annotations=None, span=None)
scope = tvm.runtime.convert(scope)

return tvm.tir.Allocate(
self.buffer.data,
self.buffer.dtype,
self.buffer.shape,
self.buffer_var,
dtype,
extents,
condition,
self.body,
annotations=annotations,
span=span,
)

super().__init__(allocate, concise_scope=True, def_symbol=True)
self.buffer = None
self.buffer_var = None

def enter_scope(
self,
Expand All @@ -146,20 +146,15 @@ def enter_scope(
else:
raise Exception("Internal Bug")

def setup_buffer(
def setup_buffer_var(
extents, dtype, scope, condition=True, annotations=None, span: Span = None
):
"""Setup buffer object for a given type."""
self.buffer = tvm.tir.decl_buffer(
shape=extents,
dtype=dtype,
name=name,
scope=scope,
span=span,
)
"""Setup buffer var for a given type."""
buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope)
self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)

setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
context.update_symbol(name, self.buffer, node)
setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
context.update_symbol(name, self.buffer_var, node)


@register
Expand All @@ -176,7 +171,7 @@ def allocate_const(raw_data, dtype, shape, annotations=None, span=None):
list_data.append(i.value)
nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype))
n = tvm.tir.AllocateConst(
self.buffer.data,
self.buffer_var,
dtype,
shape,
nd_data,
Expand All @@ -187,7 +182,7 @@ def allocate_const(raw_data, dtype, shape, annotations=None, span=None):
return n

super().__init__(allocate_const, concise_scope=True, def_symbol=True)
self.buffer = None
self.buffer_var = None

def enter_scope(
self,
Expand All @@ -211,17 +206,13 @@ def enter_scope(
else:
raise Exception("Internal Bug")

def setup_buffer(data, dtype, shape, annotations: dict = None, span: Span = None):
def setup_buffer_var(data, dtype, shape, annotations: dict = None, span: Span = None):
"""Setup buffer var for a given type."""
self.buffer = tvm.tir.decl_buffer(
shape=shape,
dtype=dtype,
name=name,
span=span,
)
buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)

setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
context.update_symbol(name, self.buffer, node)
setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
context.update_symbol(name, self.buffer_var, node)


@register
Expand All @@ -248,7 +239,18 @@ def decl_buffer(
axis_separators=None,
span=None,
):
return tvm.tir.DeclBuffer(self.buffer, self.body, span=span)
decl_buffer = tvm.tir.DeclBuffer(self.buffer, self.body, span=span)
if data is None:
# when data is not specified, the buffer is implicitly allocated
return tvm.tir.Allocate(
self.buffer.data,
dtype,
shape,
tvm.runtime.convert(True),
decl_buffer,
span=span,
)
return decl_buffer

super().__init__(decl_buffer, concise_scope=True, def_symbol=True)

Expand Down Expand Up @@ -298,6 +300,7 @@ def setup_buffer(
offset_factor=offset_factor,
buffer_type=buffer_type,
axis_separators=axis_separators,
name=name,
span=span,
)

Expand Down
128 changes: 70 additions & 58 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,22 @@ class BufferUsageFinder : public StmtExprVisitor {
StmtExprVisitor::VisitStmt_(op);
}

void VisitStmt_(const DeclBufferNode* op) final {
buffers_declared_.insert(op->buffer.get());
StmtExprVisitor::VisitStmt_(op);
buffers_declared_.erase(op->buffer.get());
}

private:
explicit BufferUsageFinder(Map<Var, Array<Buffer>> usage) : usage_(usage) {}

void VisitBuffer(const Buffer& buffer) {
if (buffers_visited_.count(buffer.get())) {
return;
}
if (buffers_declared_.count(buffer.get())) {
return;
}
buffers_visited_.insert(buffer.get());

Array<Buffer> arr = usage_.Get(buffer->data).value_or({});
Expand All @@ -119,6 +128,9 @@ class BufferUsageFinder : public StmtExprVisitor {
// The buffers that have been visited so far, to avoid duplicate
// entries in the search result.
std::unordered_set<const BufferNode*> buffers_visited_;
// The buffers declared via `DeclBuffer`. These buffers are excluded from the result because
// T.buffer_decl shouldn't be printed for them.
std::unordered_set<const BufferNode*> buffers_declared_;
};

/*!
Expand Down Expand Up @@ -1055,58 +1067,57 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
}

namespace {
struct AllocUsage {
Buffer alloc_buffer;
Array<Buffer> aliasing_buffers;
};

template <typename AllocNode>
AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* cache_ptr) {
Map<Var, Array<Buffer>>& cache = *cache_ptr;
if (!cache.count(op->buffer_var)) {
cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
bool IsAllocateDeclBufferPattern(const AllocateNode* allocate) {
const Var& buffer_var = allocate->buffer_var;
const DeclBufferNode* decl_buffer = allocate->body.as<DeclBufferNode>();
if (!decl_buffer) {
return false;
}
Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});

auto is_exact_match = [](Buffer a, Buffer b) {
if (a->dtype != b->dtype) return false;
if (a->shape.size() != b->shape.size()) return false;

arith::Analyzer analyzer;
for (size_t i = 0; i < a->shape.size(); i++) {
if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
return false;
}
}
return true;
};

// If the buffer allocated via T.allocate is an exact match to the
// usage of the buffer later on, then that buffer is the return
// value of T.allocate, and no T.buffer_decl statement is needed.
Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0, op->buffer_var->name_hint, 0,
0, kDefault);
bool found_alloc_buf = false;
Array<Buffer> aliasing_buffers;
for (const auto& buf : buffer_usage) {
if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
alloc_buffer = buf;
found_alloc_buf = true;
} else {
aliasing_buffers.push_back(buf);
const Buffer& buffer = decl_buffer->buffer;
if (!buffer_var.same_as(buffer->data)) {
return false;
}
if (allocate->dtype != buffer->dtype) {
return false;
}
if (!is_one(allocate->condition)) {
return false;
}
if (allocate->annotations.size()) {
return false;
}
if (allocate->extents.size() != buffer->shape.size()) {
return false;
}
tir::ExprDeepEqual expr_equal;
for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
return false;
}
}

return AllocUsage{alloc_buffer, aliasing_buffers};
return true;
}

} // namespace

Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
auto usage = FindAllocateUsage(op, &buffer_var_usage_);
Buffer& alloc_buffer = usage.alloc_buffer;
Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
buf_not_in_headers_.insert(alloc_buffer.get());
var_not_in_headers_.insert(alloc_buffer->data.get());
var_not_in_headers_.insert(op->buffer_var.get());

if (!buffer_var_usage_.count(op->buffer_var)) {
buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
}
Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->buffer_var).value_or({});

if (buffer_usage.empty()) {
if (IsAllocateDeclBufferPattern(op)) {
// As a syntax sugar, we identify the pattern of Allocate and DeclBuffer and print a single
// DeclBuffer statement. It is intentionally to call `Print` instead of `PrintBody` here to
// delegate the printing of the current node to `DeclBufferNode` while maintaining the
// same value of `current_num_` and `num_child_`.
return Print(op->body);
}
}

auto storage_scope = GetPtrStorageScope(op->buffer_var);
Doc func_call;
Expand All @@ -1124,12 +1135,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {

Doc doc;
if (current_num_ != num_child_ - 1) {
doc << "with " << func_call << " as " << Print(alloc_buffer) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(aliasing_buffers)
<< PrintBody(op->body));
doc << "with " << func_call << " as " << Print(op->buffer_var) << ":";
doc << Doc::Indent(
4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body));
} else {
doc << Print(alloc_buffer) << " = " << func_call << Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(op->body);
doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body);
}
TryDeallocVar(op->buffer_var);
return doc;
Expand Down Expand Up @@ -1179,11 +1190,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
}
auto ndarray_str = ss.str();

auto usage = FindAllocateUsage(alloc, &buffer_var_usage_);
Buffer& alloc_buffer = usage.alloc_buffer;
Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
buf_not_in_headers_.insert(alloc_buffer.get());
var_not_in_headers_.insert(alloc_buffer->data.get());
var_not_in_headers_.insert(alloc->buffer_var.get());

if (!buffer_var_usage_.count(alloc->buffer_var)) {
buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), alloc->body);
}
Array<Buffer> buffer_usage = buffer_var_usage_.Get(alloc->buffer_var).value_or({});

Doc func_call;
func_call << tir_prefix_ << ".allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype)
Expand All @@ -1192,12 +1204,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
Doc doc;
var_not_in_headers_.insert(alloc->buffer_var.get());
if (current_num_ != num_child_ - 1) {
doc << "with " << func_call << " as " << Print(alloc_buffer) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(aliasing_buffers)
doc << "with " << func_call << " as " << Print(alloc->buffer_var) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage)
<< PrintBody(alloc->body));
} else {
doc << Print(alloc_buffer) << " = " << func_call << Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(aliasing_buffers) << PrintBody(alloc->body);
doc << Print(alloc->buffer_var) << " = " << func_call << Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(alloc->body);
}
return doc;
}
Expand Down
Loading

0 comments on commit 0c37454

Please sign in to comment.