Skip to content

Commit

Permalink
[TVMScript] IRModule TVMScript Parser.
Browse files Browse the repository at this point in the history
This PR adds the TVMScript parser/ir_builder support based on the
blockbuilder.  This commit contains the non-relax portions from
apache#13932.

Co-authored-by: Ruihang Lai <ruihangl@cs.cmu.edu>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Tianqi Chen <tianqi.tchen@gmail.com>
Co-authored-by: Yuchen Jin <yuchenj@cs.washington.edu>
Co-authored-by: Steven S. Lyubomirsky <slyubomirsky@gmail.com>
Co-authored-by: Yong Wu <yongcale@gmail.com>
  • Loading branch information
7 people authored and Lunderberg committed Apr 5, 2023
1 parent 7a73254 commit bd0d81a
Show file tree
Hide file tree
Showing 16 changed files with 245 additions and 34 deletions.
11 changes: 8 additions & 3 deletions include/tvm/script/ir_builder/ir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,17 @@ namespace ir {
*/
class IRModuleFrameNode : public IRBuilderFrameNode {
public:
Array<GlobalVar> global_vars;
Array<BaseFunc> functions;
/*! \brief A map from string names to global variables that ensures global uniqueness. */
Map<String, GlobalVar> global_var_map;
/*!
* \brief A map from GlobalVar to all global functions.
* \note Only defined functions are in the map, while declared functions are not included.
*/
Map<GlobalVar, BaseFunc> functions;

void VisitAttrs(tvm::AttrVisitor* v) {
IRBuilderFrameNode::VisitAttrs(v);
v->Visit("global_vars", &global_vars);
v->Visit("global_vars", &global_var_map);
v->Visit("functions", &functions);
}

Expand Down
17 changes: 17 additions & 0 deletions include/tvm/script/ir_builder/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,23 @@ namespace ir {
*/
TVM_DLL IRModuleFrame IRModule();

/*!
* \brief Declare a Function without given the specific function implementation.
* \note It is usually used in cross-function call. And we can specify the function by `DefFunction`
* \param func_name The function unique name.
* \param func_signature A Function w/o body, which used to specify the function signature
* (i.e. func params and func return type/shape).
* \return The corresponding GlobalVar.
*/
TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature);

/*!
* \brief Define the function which is declared before.
* \param func_name The function unique name.
* \param func The given function implementation
*/
TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func);

} // namespace ir
} // namespace ir_builder
} // namespace script
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/script/ir_builder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ def __enter__(self) -> "IRBuilderFrame":
_ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member
return self

def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
_ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member
def __exit__(self, exc_type, exc_value, trace) -> None: # pylint: disable=unused-argument
if exc_type is None and exc_value is None:
# Do not execute `FrameExit` if the with scope exits because of exceptions
_ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member

def add_callback(self, callback: Callable[[], None]) -> None:
"""Add a callback method invoked when exiting the with-scope.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/ir_builder/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
# under the License.
"""Package tvm.script.ir_builder.ir"""
from .frame import IRModuleFrame
from .ir import ir_module
from .ir import decl_function, def_function, ir_module
45 changes: 45 additions & 0 deletions python/tvm/script/ir_builder/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,54 @@
# under the License.
"""Package tvm.script.ir_builder.ir.ir"""

from tvm.ir import BaseFunc, GlobalVar

from . import _ffi_api
from .frame import IRModuleFrame


def ir_module() -> IRModuleFrame:
"""Start a ir_module frame.
Returns
-------
frame: IRModuleFrame
The constructed frame.
"""
return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint: disable=no-member


def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar:
"""Declare a Function without given the specific function implementation.
Parameters
----------
func_name : str
The function unique name.
func_signature: Optional[BaseFunc]
A Function w/o body, which used to specify the function signature
(i.e. func params and func return type/shape).
Note
----
It is usually used in cross-function call. And we can specify the function by `DefFunction`
Returns
-------
gv : GlobalVar
The corresponding GlobalVar.
"""

return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member
func_name, func_signature
)


def def_function(func_name: str, func: BaseFunc) -> None:
"""Define the function which is declared before.
Parameters
----------
func_name : str
The function unique name.
func: BaseFunc
The given function implementation
"""
return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member
2 changes: 1 addition & 1 deletion python/tvm/script/parser/core/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel)
level : diagnostics.DiagnosticLevel
The diagnostic level.
"""
lineno = node.lineno or self.source.start_line
lineno = node.lineno or 1
col_offset = node.col_offset or self.source.start_column
end_lineno = node.end_lineno or lineno
end_col_offset = node.end_col_offset or col_offset
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/parser/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def _visit(self, node: doc.AST) -> Any:
else:
value = self._eval_expr(node.__class__(**fields))
except Exception as e: # pylint: disable=broad-except,invalid-name
self.parser.report_error(node, str(e))
self.parser.report_error(node, e)
return self._add_intermediate_result(value)

def _eval_lambda(self, node: doc.Lambda) -> Any:
Expand Down
50 changes: 35 additions & 15 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def context():
return context()


def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument
pass


class VarTableFrame:
"""The variable table frame.
A frame of variable table stores the variables created in one block or scope.
Expand Down Expand Up @@ -260,6 +264,17 @@ def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
node = self.diag.source.as_ast()
self.visit(node)

def get_dispatch_token(self, node: doc.FunctionDef) -> str:
if not isinstance(node, doc.FunctionDef):
self.report_error(node, "Only can get dispatch token for function.")
if not node.decorator_list:
self.report_error(node, "Function must be decorated")
# TODO: only the last decorator is parsed
decorator = self.eval_expr(node.decorator_list[-1])
if not hasattr(decorator, "dispatch_token"):
self.report_error(node, "The parser does not understand the decorator")
return decorator.dispatch_token

def with_dispatch_token(self, token: str):
"""Add a new dispatching token as with statement.
Expand Down Expand Up @@ -389,6 +404,8 @@ def report_error(
# Only take the last line of the error message
if isinstance(err, TVMError):
msg = list(filter(None, str(err).split("\n")))[-1]
elif isinstance(err, KeyError):
msg = "KeyError: " + str(err)
else:
msg = str(err)
self.diag.error(node, msg)
Expand Down Expand Up @@ -458,30 +475,33 @@ def visit_tvm_annotation(self, node: doc.expr) -> Any:
"""
return _dispatch(self, "tvm_annotation")(self, node)

def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name
"""The general function definition visiting method.
def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=invalid-name
"""The general function definition visit method.
Parameters
----------
node : doc.FunctionDef
The doc AST function definition node.
Returns
-------
res : Any
The visiting result.
The doc FunctionDef node.
"""
if not node.decorator_list:
self.report_error(node, "Function must be decorated")
# TODO: only the last decorator is parsed
decorator = self.eval_expr(node.decorator_list[-1])
if not hasattr(decorator, "dispatch_token"):
self.report_error(node, "The parser does not understand the decorator")
token = decorator.dispatch_token
token = self.get_dispatch_token(node)
current_token = self.dispatch_tokens[-1]
func = dispatch.get(token=token, type_name="FunctionDef", default=None)
if func is None:
self.report_error(node, "The parser does not understand the decorator")
pre_func = dispatch.get(
token=current_token, type_name="pre_token_switch", default=_do_nothing
)
post_func = dispatch.get(
token=current_token, type_name="post_token_switch", default=_do_nothing
)
pre_func(self, node)
_dispatch_wrapper(func)(self, node)
post_func(self, node)

def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None:
token = self.get_dispatch_token(node)
with self.with_dispatch_token(token):
_dispatch(self, "tvm_declare_function")(self, node)

def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name
"""The general class definition visiting method.
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/script/parser/ir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
node : doc.ClassDef
The doc AST class definition node.
"""

with self.var_table.with_frame():
with I.ir_module():
for stmt in node.body:
if isinstance(stmt, doc.FunctionDef):
self.visit_tvm_declare_function(stmt)
with self.with_dispatch_token("ir"):
self.visit_body(node.body)

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __getitem__(self, keys) -> Buffer:
return self(keys)
if len(keys) >= 2 and not isinstance(keys[1], str):
return self(keys)
return self(*keys) # pylint: disable=no-member # type: ignore
return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member


class PtrProxy:
Expand All @@ -93,7 +93,7 @@ class PtrProxy:
def __call__(self, dtype, storage_scope="global"):
if callable(dtype):
dtype = dtype().dtype
return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore
return ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member

@deprecated("T.Ptr[...]", "T.handle(...)")
def __getitem__(self, keys):
Expand Down
26 changes: 26 additions & 0 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tvm.ir import PrimType
from tvm.tir import Buffer, IterVar, PrimExpr, Var

from ...ir_builder import ir as I
from ...ir_builder import tir as T
from ...ir_builder.base import IRBuilder
from ...ir_builder.base import IRBuilderFrame as Frame
Expand Down Expand Up @@ -473,3 +474,28 @@ def visit_return(self: Parser, node: doc.Return) -> None:
The doc AST return node.
"""
self.report_error(node, "Return is not allowed.")


@dispatch.register(token="tir", type_name="tvm_declare_function")
def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None:
"""The function declaration step for tir
Parameters
----------
self : Parser
The visiting parser.
node : doc.Return
The doc AST return node.
"""

ret_type = None
if node.returns is not None:
ret_type = self.eval_expr(node.returns)
if callable(ret_type):
ret_type = PrimType(ret_type().dtype)

# Only ret_type is needed for func_signature.
func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type)
global_var = I.decl_function(node.name, func_signature)
self.var_table.add(node.name, global_var)
12 changes: 8 additions & 4 deletions src/script/ir_builder/ir/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ namespace ir_builder {
namespace ir {

void IRModuleFrameNode::ExitWithScope() {
ICHECK_EQ(functions.size(), global_vars.size());
int n = functions.size();
Map<GlobalVar, BaseFunc> func_map;
for (int i = 0; i < n; ++i) {
func_map.Set(global_vars[i], functions[i]);
CHECK_EQ(functions.size(), global_var_map.size())
<< "All functions must be defined in the IRModule. Got " << global_var_map.size()
<< "declared function(s), but only " << functions.size() << "defined function(s).";
for (const auto& kv : functions) {
const GlobalVar& gv = kv.first;
const BaseFunc& func = kv.second;
CHECK(func.defined()) << "ValueError: function " << gv->name_hint << " is not defined";
func_map.Set(gv, func);
}
IRBuilder builder = IRBuilder::Current();
ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
Expand Down
32 changes: 31 additions & 1 deletion src/script/ir_builder/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,49 @@
#include <tvm/runtime/registry.h>
#include <tvm/script/ir_builder/ir/ir.h>

#include "./utils.h"

namespace tvm {
namespace script {
namespace ir_builder {
namespace ir {

IRModuleFrame IRModule() {
ObjectPtr<IRModuleFrameNode> n = make_object<IRModuleFrameNode>();
n->global_vars.clear();
n->global_var_map.clear();
n->functions.clear();
return IRModuleFrame(n);
}

GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) {
IRModuleFrame frame = FindModuleFrame("I.DeclFunction");
CHECK(!frame->global_var_map.count(func_name))
<< "ValueError: function " << func_name << " already exists";
GlobalVar gv = GlobalVar(func_name);
CHECK(frame->functions.find(gv) == frame->functions.end())
<< "ValueError: function " << func_name << " has already been defined.";
frame->global_var_map.Set(func_name, gv);
if (func_signature.defined()) {
frame->functions.Set(gv, func_signature);
}
return gv;
}

void DefFunction(const String& func_name, const BaseFunc& func) {
IRModuleFrame frame = FindModuleFrame("I.DefFunction");
auto it = frame->global_var_map.find(func_name);
CHECK(it != frame->global_var_map.end())
<< "ValueError: function " << func_name << " does not exist, please declare it first.";
const GlobalVar& gv = (*it).second;
frame->functions.Set(gv, func);
if (func->checked_type_.defined()) {
gv->checked_type_ = func->checked_type_;
}
}

TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule);
TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction);
TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction);

} // namespace ir
} // namespace ir_builder
Expand Down
Loading

0 comments on commit bd0d81a

Please sign in to comment.