Skip to content

Commit

Permalink
ast.Expr and concise scope (apache#54)
Browse files Browse the repository at this point in the history
* `ast.Expr` and concise scope

* add `BufferStore`
  • Loading branch information
cyx-6 authored and Hzfengsy committed Jul 27, 2022
1 parent aeadf59 commit c69e44b
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 1 deletion.
16 changes: 16 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,22 @@ struct PackedFuncValueConverter<PrimExpr> {
}
};

template <>
struct PackedFuncValueConverter<Array<PrimExpr>> {
static Array<PrimExpr> From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) return Array<PrimExpr>(nullptr);
Array<ObjectRef> vals = val.AsObjectRef<Array<ObjectRef>>();
Array<PrimExpr> exprs;
for (const ObjectRef& v : vals) {
TVMValue value;
value.v_handle = const_cast<void*>(static_cast<const void*>(v.get()));
exprs.push_back(
PackedFuncValueConverter<PrimExpr>::From(TVMArgValue(value, kTVMObjectHandle)));
}
return exprs;
}
};

template <>
struct PackedFuncValueConverter<tvm::Integer> {
static tvm::Integer From(const TVMPODValue_& val) {
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/script/builder/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,6 @@ def __enter__(self) -> "Frame":

def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
_ffi_api.FrameExit(self) # pylint: disable=no-member # type: ignore

def add_callback(self, callback) -> None: # pylint: disable=unused-argument
_ffi_api.FrameAddCallback(self, callback) # pylint: disable=no-member # type: ignore
3 changes: 3 additions & 0 deletions python/tvm/script/parse/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def visit_With(self, node: doc.With) -> Any: # pylint: disable=invalid-name
def visit_Assign(self, node: doc.Assign) -> Any: # pylint: disable=invalid-name
return _dispatch(self, "Assign")(self, node)

def visit_Expr(self, node: doc.Expr) -> Any:
return _dispatch(self, "Expr")(self, node) # pylint: disable=invalid-name


def _handle_function(self: Parser, node: doc.FunctionDef) -> None:
if not node.decorator_list:
Expand Down
16 changes: 15 additions & 1 deletion python/tvm/script/parse/tir/tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from .. import dispatch, doc
from ..parser import Parser

from functools import partial


@dispatch.register(token="tir", type_name="For")
def visit_for(self: Parser, node: doc.For) -> None:
Expand All @@ -44,7 +46,14 @@ def visit_assign(self: Parser, node: doc.Assign) -> None:
self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
lhs = node.targets[0]
rhs = self.eval_expr(node.value)
self.eval_assign(target=lhs, source=rhs)
if isinstance(rhs, Frame):
rhs.add_callback(partial(rhs.__exit__, None, None, None))
res = rhs.__enter__()
self.eval_assign(target=lhs, source=res)
elif isinstance(lhs, doc.Subscript):
T.buffer_store(self.eval_expr(lhs.value), rhs, self.eval_expr(lhs.slice))
else:
self.eval_assign(target=lhs, source=rhs)


@dispatch.register(token="tir", type_name="With")
Expand Down Expand Up @@ -101,3 +110,8 @@ def visit_tvm_annotation(self: Parser, node: doc.expr):
if callable(annotation):
annotation = annotation()
return annotation


@dispatch.register(token="tir", type_name="Expr")
def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
self.eval_expr(node.value)
9 changes: 9 additions & 0 deletions src/script/builder/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,18 @@ void FrameNode::ExitWithScope() {
Builder::Current()->frames.pop_back();
}

void FrameNode::AddCallback(runtime::TypedPackedFunc<void()> callback) {
if (Builder::Current()->frames.empty()) {
LOG(FATAL) << "ValueError: No frames in Builder to add callback";
}
Builder::Current()->frames.back()->callbacks.push_back(callback);
}

TVM_REGISTER_NODE_TYPE(FrameNode);
TVM_REGISTER_GLOBAL("script.builder.FrameEnter").set_body_method<Frame>(&FrameNode::EnterWithScope);
TVM_REGISTER_GLOBAL("script.builder.FrameExit").set_body_method<Frame>(&FrameNode::ExitWithScope);
TVM_REGISTER_GLOBAL("script.builder.FrameAddCallback")
.set_body_method<Frame>(&FrameNode::AddCallback);

} // namespace builder
} // namespace script
Expand Down
2 changes: 2 additions & 0 deletions src/script/builder/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class FrameNode : public runtime::Object {
virtual ~FrameNode() = default;
virtual void EnterWithScope();
virtual void ExitWithScope();

void AddCallback(runtime::TypedPackedFunc<void()> callback);
};

class Frame : public runtime::ObjectRef {
Expand Down
31 changes: 31 additions & 0 deletions tests/python/tvmscript/test_parse_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
import tvm
from tvm.ir import structural_equal
from tvm.script.builder import ir as I
from tvm.script.builder import tir as T

Expand All @@ -19,6 +20,8 @@ def elementwise(
vi = T.axis.S(128, i + 1)
vj = T.axis.S(128, j + 20)
vk = T.axis.R(128, k - i)
A[vi + 1, vj] = 0
B[vi, vj, vk] = 1

# pylint: enable=unused-argument,unused-variable,invalid-name

Expand Down Expand Up @@ -79,9 +82,37 @@ def elementwise() -> None:
vj = T.axis.S(128, vvv[10] + 20)


def test_parse_concise_scope():
# pylint: disable=unused-argument,unused-variable,invalid-name
@T.prim_func
def concise_scope(
A: T.handle,
) -> None:
A_local = T.allocate([64], "float32", "local")
B_local = T.allocate([64], "float32", "local")
C_local = T.allocate([64], "float32", "local")
T.evaluate(1)
T.evaluate(2)
T.evaluate(3)

@T.prim_func
def normal_scope(
A: T.handle,
) -> None:
with T.allocate([64], "float32", "local") as A_local:
with T.allocate([64], "float32", "local") as B_local:
with T.allocate([64], "float32", "local") as C_local:
T.evaluate(1)
T.evaluate(2)
T.evaluate(3)

assert structural_equal(normal_scope, concise_scope)


if __name__ == "__main__":
test_parse_elementwise()
test_parse_skip()
test_parse_class()
test_parse_atomic()
test_parse_report_error()
test_parse_concise_scope()

0 comments on commit c69e44b

Please sign in to comment.