Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Support T.buffer_decl using data pointer from Let/Allocate #10099

Merged
merged 5 commits into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from .tir.node import Slice, BufferSlice
from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
from .tir.special_stmt import SpecialStmt
from .tir import ty


class CallArgumentReader(object):
Expand Down Expand Up @@ -447,7 +448,9 @@ def check_decorator(decorators: List[ast.Expr]) -> bool:
# add parameters of function
for arg in node.params:
# Note that this case is for T.match_buffer syntax sugar
if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)):
if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)) and isinstance(
self.transform(arg.ty.func_name), ty.GenericBufferType
):
result = self.handle_match_buffer_type(arg.ty, arg.name)
if not isinstance(result, buffer.Buffer):
self.report_error(
Expand Down Expand Up @@ -1138,6 +1141,33 @@ def transform_TypeTuple(self, node):
"""
return [self.transform(value) for value in node.values]

def transform_TypeApply(self, node):
"""Visitor for Type[Type] expressions.

Mostly used for ``T.Ptr`` expressions.
"""
func = self.transform(node.func_name)

if not isinstance(func, ty.TypeGeneric) or not hasattr(func, "__getitem__"):
self.report_error(
f"Use of type arguments requires a type that accepts type arguments (e.g. T.Ptr), "
f"but found {type(func).__name__} instead.",
node.span,
)

param_types = []
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is a transform_TypeTuple impl that may be useful here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had looked at the transform_TypeTuple implementation, and I don't think it's directly applicable. It assumes that there is a synr.ast.TypeTuple node, with types in node.values, and doesn't have a way to accept a list of types directly or to access the node.params of a TypeApply node. The TVMScriptParser.parse_arg_list is the closest I found to the desired functionality, but it requires the node to be an intrinsic, scope handler, or special statement, and doesn't have a case for type annotations.

for param in node.params:
param_type = self.transform(param)
if not isinstance(param_type, ty.TypeGeneric):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this for-loop imply that all params of TypeApply call will be transformed into GenericPtrType? If so can we specify that here as well instead of ty.TypeGeneric

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't, no. The parameters of TypeApply typically aren't GenericPtrType. For example, T.Ptr[T.int32], the T.int32 parameter is a ConcreteType.

At some point, I may see if there's support for renaming ty.TypeGeneric to ty.Type. As it is, ty.ConcreteType is a subclass of ty.TypeGeneric, which doesn't make very much sense to me as there since it require any generic parameters in the user-supplied tvmscript.

self.report_error(f"Expected a type but found {type(param).__name__}", param.span)

param_types.append(param_type)

if len(param_types) == 1:
return func[param_types[0]]
else:
return func[param_types]

def handle_match_buffer_type(self, node, buffer_name):
"""special function to handle syntax sugar for match buffer.

Expand Down
23 changes: 20 additions & 3 deletions python/tvm/script/tir/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,26 @@ def __call__(self):


class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods, abstract-method
"""TVM script typing class for uniform Type objects"""
"""TVM script typing class for uniform Type objects

Params
------
vtype: Union[str, tvm.ir.Type]

The IR type represented by the type annotation. If a string
(e.g. "float32"), this represents a `ir.PrimType` generated
from that string. If a `ir.Type` is provided, this represents
the type provided.
"""

def __init__(self, vtype):
self.type = vtype
if isinstance(vtype, tvm.ir.Type):
self.type = vtype
else:
self.type = tvm.ir.PrimType(vtype)

def evaluate(self):
return tvm.ir.PrimType(self.type)
return self.type


class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method
Expand All @@ -54,6 +67,8 @@ class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method
"""

def __getitem__(self, vtype):
if not isinstance(vtype, TypeGeneric):
raise TypeError(f"Ptr expects a type argument, but received {type(vtype).__name__}")
return ConcreteType(tvm.ir.PointerType(vtype.evaluate()))


Expand All @@ -65,6 +80,8 @@ class GenericTupleType(TypeGeneric): # pylint: disable=abstract-method
"""

def __getitem__(self, vtypes):
if isinstance(vtypes, TypeGeneric):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be a bug fix right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's correct. If transform_TypeApply was to be able to handle both T.Ptr and T.Tuple cases, then I wanted to make sure that they both accepted the same types of arguments. The two options I considered were (a) always passing a tuple of parameter types even if there is only 1, or (b) passing a bare type when there is only 1 and otherwise passing a tuple of parameter types. Previously, T.Ptr implicitly followed convention (b), while T.Tuple followed convention (a). Option (b) matches python's subscripting syntax, so that's the one that I chose.

vtypes = [vtypes]
return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes]))


Expand Down
126 changes: 100 additions & 26 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,58 @@ enum class ExprPrecedence : int {
kUnknown = 7,
};

/*! \brief Utility used for identifying usage of a buffer_var
*
* \details Find the Buffer object that corresponds to a variable or
* allocation, based on the BufferLoad/BufferStore instances that
* occur within the allocation's body.
*/
class BufferUsageFinder : public StmtExprVisitor {
public:
static Map<Var, Array<Buffer>> FindUsage(Map<Var, Array<Buffer>> usage, Stmt body) {
BufferUsageFinder visitor(std::move(usage));
visitor.VisitStmt(body);
return std::move(visitor.usage_);
}

void VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
if (!usage_.count(var)) {
usage_.Set(var, {});
}
}

void VisitExpr_(const BufferLoadNode* op) final {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const BufferStoreNode* op) final {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitStmt_(op);
}

private:
explicit BufferUsageFinder(Map<Var, Array<Buffer>> usage) : usage_(usage) {}

void VisitBuffer(const Buffer& buffer) {
if (buffers_visited_.count(buffer.get())) {
return;
}
buffers_visited_.insert(buffer.get());

Array<Buffer> arr = usage_.Get(buffer->data).value_or({});
arr.push_back(buffer);
usage_.Set(buffer->data, arr);
}

// The search result.
Map<Var, Array<Buffer>> usage_;
// The buffers that have been visited so far, to avoid duplicate
// entries in the search result.
std::unordered_set<const BufferNode*> buffers_visited_;
};

/*!
* \brief The printer for TVMScript
* \details The printer obtain the precedence of the top-level operation when printing each
Expand Down Expand Up @@ -138,6 +190,14 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
* 3. The iter range is equal to loop range
*/
std::vector<std::pair<IterVar, PrimExpr>> block_var_remaps_;
/*!
* \brief Map from variables to the buffers they are used in.
*
* Used for identifying buffers that should be declared after the
* LetStmt or Allocate that generates their data pointer, rather
* than in the header.
*/
Map<Var, Array<Buffer>> buffer_var_usage_;

Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override;
Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override;
Expand Down Expand Up @@ -201,6 +261,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc PrintRange(const RangeNode* op);
Doc PrintArray(const ArrayNode* op);
Doc PrintBuffer(const BufferNode* op);
Doc PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body);
Doc AllocBufferDeclaration(const Buffer& buf);
Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value);
Doc PrintBlockVarRemaps();
Expand Down Expand Up @@ -830,11 +891,13 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) {
Doc doc;
if (current_num_ != num_child_ - 1) {
doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body)
<< PrintBody(op->body));
} else {
if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get());
doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value)
<< Doc::NewLine() << PrintBody(op->body);
<< Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(op->var, op->body) << PrintBody(op->body);
}
return doc;
}
Expand Down Expand Up @@ -923,33 +986,30 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {

Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
var_not_in_headers_.insert(op->buffer_var.get());
Doc doc;

auto storage_scope = GetPtrStorageScope(op->buffer_var);
Doc func_call;
func_call << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype)
<< ", " << Print(storage_scope);
if (!is_one(op->condition)) {
func_call << ", " << Print(op->condition);
}
if (!op->annotations.empty()) {
func_call << ", annotations={";
func_call << PrintAnnotations(op->annotations);
func_call << "}";
}
func_call << ")";

Doc doc;
if (current_num_ != num_child_ - 1) {
doc << "with " << tir_prefix_ << ".allocate(" << Print(op->extents) << ", "
<< PrintDType(op->dtype) << ", " << Print(storage_scope);
if (!is_one(op->condition)) {
doc << ", " << Print(op->condition);
}
if (!op->annotations.empty()) {
doc << ", annotations={";
doc << PrintAnnotations(op->annotations);
doc << "}";
}
doc << ") as " << Print(op->buffer_var) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
doc << "with " << func_call << " as " << Print(op->buffer_var) << ":";
doc << Doc::Indent(4, Doc::NewLine()
<< PrintNonHeaderBufferDeclarations(op->buffer_var, op->body)
<< PrintBody(op->body));
} else {
doc << Print(op->buffer_var) << " = " << tir_prefix_ << ".allocate(" << Print(op->extents)
<< ", " << PrintDType(op->dtype) << ", " << Print(storage_scope);
if (!is_one(op->condition)) {
doc << ", " << Print(op->condition);
}
if (!op->annotations.empty()) {
doc << ", annotations={";
doc << PrintAnnotations(op->annotations);
doc << "}";
}
doc << ")" << Doc::NewLine() << PrintBody(op->body);
doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) << PrintBody(op->body);
}
TryDeallocVar(op->buffer_var);
return doc;
Expand Down Expand Up @@ -1458,6 +1518,20 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer);
}

Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body) {
if (!buffer_var_usage_.count(buffer_var)) {
buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), body);
}
Array<Buffer> buffer_usage = buffer_var_usage_.Get(buffer_var).value_or({});
Doc decls;
for (const auto& buf_usage : buffer_usage) {
decls << Print(buf_usage) << " = " << tir_prefix_ << ".buffer_decl("
<< memo_buf_decl_[buf_usage] << ")" << Doc::NewLine();
buf_not_in_headers_.insert(buf_usage.get());
}
return decls;
}

Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) {
Doc doc;
if (op->region.size() == 0) {
Expand Down
47 changes: 47 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3255,5 +3255,52 @@ def test_root_attr():
tvm.ir.assert_structural_equal(func, rt_func, True)


@T.prim_func
def func_T_ptr_let_statement(
args: T.handle, arg_type_ids_handle: T.Ptr[T.int32], num_args: T.int32
) -> None:
# The T.Ptr declaration in the parameter list should parse
# correctly, and should be usable as the data pointer in a buffer.
arg_type_ids = T.buffer_decl([2], dtype="int32", data=arg_type_ids_handle)

arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle")
arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle")

# Functions that return a "handle" can be assigned to a T.Ptr
# variable. A variable annotated with T.Ptr still has dtype of
# T.handle, but has type annotation as a pointer type.
A_data: T.Ptr[T.float32] = T.tvm_struct_get(arg0, 0, 1, dtype="handle")

# The buffer declaration has a data pointer defined earlier in
# this function. It should only be defined after the data pointer
# has been defined, and should not be hoisted into the header of
# the function as other buffer_decl statements can be.
A = T.buffer_decl([1024], dtype="float32", data=A_data)
B_data: T.Ptr[T.float32] = T.tvm_struct_get(arg1, 0, 1, dtype="handle")
B = T.buffer_decl([1024], dtype="float32", data=B_data)

B[0] = A[0]


def test_T_ptr_let_statement():
func = func_T_ptr_let_statement
rt_func = tvm.script.from_source(func.script(show_meta=True))
tvm.ir.assert_structural_equal(func, rt_func, True)


@T.prim_func
def func_T_ptr_allocate() -> None:
A_data: T.Ptr[T.float32] = T.allocate([1024], "float32", "global")
A = T.buffer_decl([1024], dtype="float32", data=A_data)

A[0] = 0.0


def test_T_ptr_allocate():
func = func_T_ptr_allocate
rt_func = tvm.script.from_source(func.script(show_meta=True))
tvm.ir.assert_structural_equal(func, rt_func, True)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))