Skip to content

Commit

Permalink
[TVMScript] Updated buffer_var printing
Browse files Browse the repository at this point in the history
LetStmt and AllocateNode can both be used to generate handles that are
used in Buffer objects.  In these cases, the Buffer declarations must
go after the handle declaration, not in the function header.
  • Loading branch information
Lunderberg committed Jan 28, 2022
1 parent 66dac85 commit f9041dc
Showing 1 changed file with 99 additions and 26 deletions.
125 changes: 99 additions & 26 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,57 @@ enum class ExprPrecedence : int {
kUnknown = 7,
};

/*! \brief Utility used for identifying usage of a buffer_var
*
* \details Find the Buffer object that corresponds to a variable or
* allocation, based on the BufferLoad/BufferStore instances that
* occur within the allocation's body.
*/
class BufferUsageFinder : public StmtExprVisitor {
public:
static void FindUsage(Map<Var, Array<Buffer>>& usage, Stmt body) {
BufferUsageFinder visitor(usage);
visitor.VisitStmt(body);
}

void VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
if (!usage_.count(var)) {
usage_.Set(var, {});
}
}

void VisitExpr_(const BufferLoadNode* op) final {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const BufferStoreNode* op) final {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitStmt_(op);
}

private:
BufferUsageFinder(Map<Var, Array<Buffer>>& usage) : usage_(usage) {}

void VisitBuffer(const Buffer& buffer) {
if (buffers_visited_.count(buffer.get())) {
return;
}
buffers_visited_.insert(buffer.get());

Array<Buffer> arr = usage_.Get(buffer->data).value_or({});
arr.push_back(buffer);
usage_.Set(buffer->data, arr);
}

// The search result.
Map<Var, Array<Buffer>>& usage_;
// The buffers that have been visited so far, to avoid duplicate
// entries in the search result.
std::unordered_set<const BufferNode*> buffers_visited_;
};

/*!
* \brief The printer for TVMScript
* \details The printer obtain the precedence of the top-level operation when printing each
Expand Down Expand Up @@ -138,6 +189,14 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
* 3. The iter range is equal to loop range
*/
std::vector<std::pair<IterVar, PrimExpr>> block_var_remaps_;
/*!
* \brief Map from variables to the buffers they are used in.
*
* Used for identifying buffers that should be declared after the
* LetStmt or Allocate that generates their data pointer, rather
* than in the header.
*/
Map<Var, Array<Buffer>> buffer_var_usage_;

Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override;
Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override;
Expand Down Expand Up @@ -201,6 +260,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc PrintRange(const RangeNode* op);
Doc PrintArray(const ArrayNode* op);
Doc PrintBuffer(const BufferNode* op);
Doc PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body);
Doc AllocBufferDeclaration(const Buffer& buf);
Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value);
Doc PrintBlockVarRemaps();
Expand Down Expand Up @@ -830,11 +890,13 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) {
Doc doc;
if (current_num_ != num_child_ - 1) {
doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body)
<< PrintBody(op->body));
} else {
if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get());
doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value)
<< Doc::NewLine() << PrintBody(op->body);
<< Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body)
<< PrintBody(op->body);
}
return doc;
}
Expand Down Expand Up @@ -923,33 +985,30 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {

Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
var_not_in_headers_.insert(op->buffer_var.get());
Doc doc;

auto storage_scope = GetPtrStorageScope(op->buffer_var);
Doc func_call;
func_call << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype)
<< ", " << Print(storage_scope);
if (!is_one(op->condition)) {
func_call << ", " << Print(op->condition);
}
if (!op->annotations.empty()) {
func_call << ", annotations={";
func_call << PrintAnnotations(op->annotations);
func_call << "}";
}
func_call << ")";

Doc doc;
if (current_num_ != num_child_ - 1) {
doc << "with " << tir_prefix_ << ".allocate(" << Print(op->extents) << ", "
<< PrintDType(op->dtype) << ", " << Print(storage_scope);
if (!is_one(op->condition)) {
doc << ", " << Print(op->condition);
}
if (!op->annotations.empty()) {
doc << ", annotations={";
doc << PrintAnnotations(op->annotations);
doc << "}";
}
doc << ") as " << Print(op->buffer_var) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
doc << "with " << func_call << " as " << Print(op->buffer_var) << ":";
doc << Doc::Indent(4, Doc::NewLine()
<< PrintNonHeaderBufferDeclarations(op->buffer_var, op->body)
<< PrintBody(op->body));
} else {
doc << Print(op->buffer_var) << " = " << tir_prefix_ << ".allocate(" << Print(op->extents)
<< ", " << PrintDType(op->dtype) << ", " << Print(storage_scope);
if (!is_one(op->condition)) {
doc << ", " << Print(op->condition);
}
if (!op->annotations.empty()) {
doc << ", annotations={";
doc << PrintAnnotations(op->annotations);
doc << "}";
}
doc << ")" << Doc::NewLine() << PrintBody(op->body);
doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine()
<< PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) << PrintBody(op->body);
}
TryDeallocVar(op->buffer_var);
return doc;
Expand Down Expand Up @@ -1458,6 +1517,20 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer);
}

Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body) {
if (!buffer_var_usage_.count(buffer_var)) {
BufferUsageFinder::FindUsage(buffer_var_usage_, body);
}
Array<Buffer> buffer_usage = buffer_var_usage_.Get(buffer_var).value_or({});
Doc decls;
for (const auto& buf_usage : buffer_usage) {
decls << Print(buf_usage) << " = " << tir_prefix_ << ".buffer_decl("
<< memo_buf_decl_[buf_usage] << ")" << Doc::NewLine();
buf_not_in_headers_.insert(buf_usage.get());
}
return decls;
}

Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) {
Doc doc;
if (op->region.size() == 0) {
Expand Down

0 comments on commit f9041dc

Please sign in to comment.