-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TVMScript] Support T.buffer_decl using data pointer from Let/Allocate #10099
Changes from all commits
fcab55e
0c589aa
1b70fa2
a18667c
2454682
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,6 +47,7 @@ | |
from .tir.node import Slice, BufferSlice | ||
from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler | ||
from .tir.special_stmt import SpecialStmt | ||
from .tir import ty | ||
|
||
|
||
class CallArgumentReader(object): | ||
|
@@ -447,7 +448,9 @@ def check_decorator(decorators: List[ast.Expr]) -> bool: | |
# add parameters of function | ||
for arg in node.params: | ||
# Note that this case is for T.match_buffer syntax sugar | ||
if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)): | ||
if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)) and isinstance( | ||
self.transform(arg.ty.func_name), ty.GenericBufferType | ||
): | ||
result = self.handle_match_buffer_type(arg.ty, arg.name) | ||
if not isinstance(result, buffer.Buffer): | ||
self.report_error( | ||
|
@@ -1138,6 +1141,33 @@ def transform_TypeTuple(self, node): | |
""" | ||
return [self.transform(value) for value in node.values] | ||
|
||
def transform_TypeApply(self, node): | ||
"""Visitor for Type[Type] expressions. | ||
|
||
Mostly used for ``T.Ptr`` expressions. | ||
""" | ||
func = self.transform(node.func_name) | ||
|
||
if not isinstance(func, ty.TypeGeneric) or not hasattr(func, "__getitem__"): | ||
self.report_error( | ||
f"Use of type arguments requires a type that accepts type arguments (e.g. T.Ptr), " | ||
f"but found {type(func).__name__} instead.", | ||
node.span, | ||
) | ||
|
||
param_types = [] | ||
for param in node.params: | ||
param_type = self.transform(param) | ||
if not isinstance(param_type, ty.TypeGeneric): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this for-loop imply that all params of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It doesn't, no. The parameters of At some point, I may see if there's support for renaming |
||
self.report_error(f"Expected a type but found {type(param).__name__}", param.span) | ||
|
||
param_types.append(param_type) | ||
|
||
if len(param_types) == 1: | ||
return func[param_types[0]] | ||
else: | ||
return func[param_types] | ||
|
||
def handle_match_buffer_type(self, node, buffer_name): | ||
"""special function to handle syntax sugar for match buffer. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,13 +38,26 @@ def __call__(self): | |
|
||
|
||
class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods, abstract-method | ||
"""TVM script typing class for uniform Type objects""" | ||
"""TVM script typing class for uniform Type objects | ||
|
||
Params | ||
------ | ||
vtype: Union[str, tvm.ir.Type] | ||
|
||
The IR type represented by the type annotation. If a string | ||
(e.g. "float32"), this represents a `ir.PrimType` generated | ||
from that string. If a `ir.Type` is provided, this represents | ||
the type provided. | ||
""" | ||
|
||
def __init__(self, vtype): | ||
self.type = vtype | ||
if isinstance(vtype, tvm.ir.Type): | ||
self.type = vtype | ||
else: | ||
self.type = tvm.ir.PrimType(vtype) | ||
|
||
def evaluate(self): | ||
return tvm.ir.PrimType(self.type) | ||
return self.type | ||
|
||
|
||
class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method | ||
|
@@ -54,6 +67,8 @@ class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method | |
""" | ||
|
||
def __getitem__(self, vtype): | ||
if not isinstance(vtype, TypeGeneric): | ||
raise TypeError(f"Ptr expects a type argument, but received {type(vtype).__name__}") | ||
return ConcreteType(tvm.ir.PointerType(vtype.evaluate())) | ||
|
||
|
||
|
@@ -65,6 +80,8 @@ class GenericTupleType(TypeGeneric): # pylint: disable=abstract-method | |
""" | ||
|
||
def __getitem__(self, vtypes): | ||
if isinstance(vtypes, TypeGeneric): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to be a bug fix right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's correct. If |
||
vtypes = [vtypes] | ||
return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes])) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is a
transform_TypeTuple
impl that may be useful hereThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had looked at the
transform_TypeTuple
implementation, and I don't think it's directly applicable. It assumes that there is asynr.ast.TypeTuple
node, with types innode.values
, and doesn't have a way to accept a list of types directly or to access thenode.params
of aTypeApply
node. TheTVMScriptParser.parse_arg_list
is the closest I found to the desired functionality, but it requires the node to be an intrinsic, scope handler, or special statement, and doesn't have a case for type annotations.