From e6b1aa150317e5da142a5bafbdfb55aeac6e7a7e Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 1 Jun 2022 19:24:04 -0700 Subject: [PATCH] [IRBuilder] Misc fix-ups for the python binding (#39) * [IRBuilder] Misc fix-ups for the python binding * addon --- python/tvm/script/builder/_ffi_api.py | 2 +- python/tvm/script/builder/builder.py | 31 +++++++------ python/tvm/script/builder/frame.py | 11 +++-- python/tvm/script/builder/tir/__init__.py | 14 +++--- python/tvm/script/builder/tir/_ffi_api.py | 6 +-- python/tvm/script/builder/tir/axis.py | 9 ++-- python/tvm/script/builder/tir/base.py | 3 +- python/tvm/script/builder/tir/block_frame.py | 7 ++- python/tvm/script/builder/tir/for_frame.py | 32 +++++++------ .../tvm/script/builder/tir/prim_func_frame.py | 16 +++---- python/tvm/script/builder/tir/var.py | 13 ++++-- src/script/builder/frame.cc | 9 +--- src/script/builder/tir/base.cc | 30 ------------ src/script/builder/tir/block_frame.cc | 4 -- src/script/builder/tir/for_frame.cc | 46 +++++++++---------- src/script/builder/tir/for_frame.h | 10 ++-- src/script/builder/tir/prim_func_frame.cc | 9 ++-- .../python/unittest/test_tvmscript_builder.py | 4 +- 18 files changed, 108 insertions(+), 148 deletions(-) diff --git a/python/tvm/script/builder/_ffi_api.py b/python/tvm/script/builder/_ffi_api.py index ec20ad798f80..3410494ded4d 100644 --- a/python/tvm/script/builder/_ffi_api.py +++ b/python/tvm/script/builder/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs for tvm.script.builder""" import tvm._ffi -tvm._ffi._init_api("script.builder", __name__) +tvm._ffi._init_api("script.builder", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/builder/builder.py b/python/tvm/script/builder/builder.py index 3d449ef1975a..6ec31abecb5f 100644 --- a/python/tvm/script/builder/builder.py +++ b/python/tvm/script/builder/builder.py @@ -15,44 +15,47 @@ # specific language governing permissions and limitations # under the License. """TVM Script IR Builder""" -from typing import List -from tvm._ffi import register_object as _register_object -from .frame import Frame +from typing import List, TypeVar +from tvm._ffi import register_object as _register_object from tvm.runtime import Object from . import _ffi_api - -from typing import TypeVar +from .frame import Frame @_register_object("script.builder.Builder") class Builder(Object): def __init__(self) -> None: - self.__init_handle_by_constructor__(_ffi_api.Builder) + self.__init_handle_by_constructor__( + _ffi_api.Builder # pylint: disable=no-member # type: ignore + ) def __enter__(self) -> "Builder": - _ffi_api.BuilderEnter(self) + _ffi_api.BuilderEnter(self) # pylint: disable=no-member # type: ignore return self - def __exit__(self, ptype, value, trace) -> None: - _ffi_api.BuilderExit(self) + def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument + _ffi_api.BuilderExit(self) # pylint: disable=no-member # type: ignore @staticmethod - def current(self) -> "Builder": - return _ffi_api.BuilderCurrent(self) + def current() -> "Builder": + return _ffi_api.BuilderCurrent() # pylint: disable=no-member # type: ignore def get(self) -> Frame: - return _ffi_api.BuilderGet(self) + return _ffi_api.BuilderGet(self) # pylint: disable=no-member # type: ignore DefType = TypeVar("DefType", bound=Object) def def_(name: str, var: DefType) -> DefType: - return _ffi_api.Def(name, var) + return _ffi_api.Def(name, var) # pylint: disable=no-member # type: ignore -def def_many(names: List[str], vars: List[DefType]) -> List[DefType]: +def def_many( + names: List[str], + vars: List[DefType], # pylint: disable=redefine-builtin +) -> List[DefType]: assert len(names) == len(vars) return [def_(name, var) for name, var in zip(names, vars)] diff --git a/python/tvm/script/builder/frame.py b/python/tvm/script/builder/frame.py index 7f6ac8972fcf..faf0c4271d7d 100644 --- a/python/tvm/script/builder/frame.py +++ b/python/tvm/script/builder/frame.py @@ -16,7 +16,6 @@ # under the License. """TVM Script Frames""" from tvm._ffi import register_object as _register_object - from tvm.runtime import Object from . import _ffi_api @@ -25,14 +24,16 @@ @_register_object("script.builder.Frame") class Frame(Object): def __enter__(self) -> "Frame": - _ffi_api.FrameEnter(self) + _ffi_api.FrameEnter(self) # pylint: disable=no-member # type: ignore return self - def __exit__(self, ptype, value, trace) -> None: - _ffi_api.FrameExit(self) + def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument + _ffi_api.FrameExit(self) # pylint: disable=no-member # type: ignore @_register_object("script.builder.IRModuleFrame") class IRModuleFrame(Frame): def __init__(self) -> None: - self.__init_handle_by_constructor__(_ffi_api.IRModuleFrame) + self.__init_handle_by_constructor__( + _ffi_api.IRModuleFrame # pylint: disable=no-member # type: ignore + ) diff --git a/python/tvm/script/builder/tir/__init__.py b/python/tvm/script/builder/tir/__init__.py index c206e5e84059..e6e431e5bced 100644 --- a/python/tvm/script/builder/tir/__init__.py +++ b/python/tvm/script/builder/tir/__init__.py @@ -17,17 +17,17 @@ # pylint: disable=unused-import """Namespace for the TVMScript TIR Builder API.""" +from . import axis from .base import TIRFrame +from .block_frame import block from .for_frame import ( ForFrame, - serial, + grid, parallel, - vectorized, - unroll, + serial, thread_binding, - grid, + unroll, + vectorized, ) -from .prim_func_frame import prim_func, arg -from .block_frame import block +from .prim_func_frame import arg, prim_func from .var import Buffer -from . import axis diff --git a/python/tvm/script/builder/tir/_ffi_api.py b/python/tvm/script/builder/tir/_ffi_api.py index df97ad7ae7f2..4e40e7261fd3 100644 --- a/python/tvm/script/builder/tir/_ffi_api.py +++ b/python/tvm/script/builder/tir/_ffi_api.py @@ -14,9 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""FFI APIs for tvm.script.builder""" +"""FFI APIs for tvm.script.builder.tir""" import tvm._ffi -from .. import _ffi_api as _base_ffi_api - -tvm._ffi._init_api("script.builder.tir", __name__) +tvm._ffi._init_api("script.builder.tir", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/builder/tir/axis.py b/python/tvm/script/builder/tir/axis.py index 9bb3e75650b5..0371aa5bd802 100644 --- a/python/tvm/script/builder/tir/axis.py +++ b/python/tvm/script/builder/tir/axis.py @@ -16,22 +16,23 @@ # under the License. """TVM Script TIR Axis""" -from . import _ffi_api from tvm.ir import Range from tvm.tir import IterVar +from . import _ffi_api + def spatial(dom, binding, dtype="int32") -> IterVar: if not isinstance(dom, Range): dom = Range(0, dom) - return _ffi_api.AxisSpatial(dom, binding, dtype) + return _ffi_api.AxisSpatial(dom, binding, dtype) # pylint: disable=no-member # type: ignore def reduce(dom, binding, dtype="int32") -> IterVar: if not isinstance(dom, Range): dom = Range(0, dom) - return _ffi_api.AxisReduce(dom, binding, dtype) + return _ffi_api.AxisReduce(dom, binding, dtype) # pylint: disable=no-member # type: ignore def remap(kinds, bindings, dtype="int32") -> IterVar: - return _ffi_api.AxisRemap(kinds, bindings, dtype) + return _ffi_api.AxisRemap(kinds, bindings, dtype) # pylint: disable=no-member # type: ignore diff --git a/python/tvm/script/builder/tir/base.py b/python/tvm/script/builder/tir/base.py index e19c47b0a478..9159c8d914b5 100644 --- a/python/tvm/script/builder/tir/base.py +++ b/python/tvm/script/builder/tir/base.py @@ -17,10 +17,9 @@ """TVM Script TIR Frame""" from tvm._ffi import register_object as _register_object -from . import _ffi_api from ..frame import Frame @_register_object("script.builder.tir.TIRFrame") class TIRFrame(Frame): - pass + ... diff --git a/python/tvm/script/builder/tir/block_frame.py b/python/tvm/script/builder/tir/block_frame.py index c90b0fac87c8..aa447a409e89 100644 --- a/python/tvm/script/builder/tir/block_frame.py +++ b/python/tvm/script/builder/tir/block_frame.py @@ -16,16 +16,15 @@ # under the License. """TVM Script TIR Block Frame""" from tvm._ffi import register_object as _register_object -from .base import TIRFrame - from . import _ffi_api +from .base import TIRFrame @_register_object("script.builder.tir.BlockFrame") class BlockFrame(TIRFrame): - pass + ... def block(name) -> BlockFrame: - return _ffi_api.BlockFrame(name) + return _ffi_api.BlockFrame(name) # pylint: disable=no-member # type: ignore diff --git a/python/tvm/script/builder/tir/for_frame.py b/python/tvm/script/builder/tir/for_frame.py index 13b2599ae233..051565492882 100644 --- a/python/tvm/script/builder/tir/for_frame.py +++ b/python/tvm/script/builder/tir/for_frame.py @@ -15,42 +15,44 @@ # specific language governing permissions and limitations # under the License. """TVM Script TIR For Frame""" -from tvm._ffi import register_object as _register_object +from typing import List +from tvm._ffi import register_object as _register_object from tvm.tir import Var from . import _ffi_api -from ._ffi_api import _base_ffi_api +from .. import _ffi_api as _base_ffi_api from .base import TIRFrame -from typing import List @_register_object("script.builder.tir.ForFrame") class ForFrame(TIRFrame): def __enter__(self) -> List[Var]: - _base_ffi_api.FrameEnter(self) + _base_ffi_api.FrameEnter(self) # pylint: disable=no-member # type: ignore return self.vars -def serial(min_val, extent, attrs) -> ForFrame: - return _ffi_api.Serial(min_val, extent, attrs) +def serial(start, stop, annotations) -> ForFrame: + return _ffi_api.Serial(start, stop, annotations) # pylint: disable=no-member # type: ignore -def parallel(min_val, extent, attrs) -> ForFrame: - return _ffi_api.Parallel(min_val, extent, attrs) +def parallel(start, stop, annotations) -> ForFrame: + return _ffi_api.Parallel(start, stop, annotations) # pylint: disable=no-member # type: ignore -def vectorized(min_val, extent, attrs) -> ForFrame: - return _ffi_api.Vectorized(min_val, extent, attrs) +def vectorized(start, stop, annotations) -> ForFrame: + return _ffi_api.Vectorized(start, stop, annotations) # pylint: disable=no-member # type: ignore -def unroll(min_val, extent, attrs) -> ForFrame: - return _ffi_api.Unroll(min_val, extent, attrs) +def unroll(start, stop, annotations) -> ForFrame: + return _ffi_api.Unroll(start, stop, annotations) # pylint: disable=no-member # type: ignore -def thread_binding(min_val, extent, attrs) -> ForFrame: - return _ffi_api.ThreadBinding(min_val, extent, attrs) +def thread_binding(start, stop, thread, annotations) -> ForFrame: + return _ffi_api.ThreadBinding( # pylint: disable=no-member # type: ignore + start, stop, thread, annotations + ) def grid(*extents) -> ForFrame: - return _ffi_api.Grid(extents) + return _ffi_api.Grid(extents) # pylint: disable=no-member # type: ignore diff --git a/python/tvm/script/builder/tir/prim_func_frame.py b/python/tvm/script/builder/tir/prim_func_frame.py index 4a223af55f19..370fe361552d 100644 --- a/python/tvm/script/builder/tir/prim_func_frame.py +++ b/python/tvm/script/builder/tir/prim_func_frame.py @@ -15,26 +15,24 @@ # specific language governing permissions and limitations # under the License. """TVM Script TIR Prim Func Frame""" -from tvm._ffi import register_object as _register_object +from typing import Union -from tvm.tir.expr import Var +from tvm._ffi import register_object as _register_object from tvm.tir.buffer import Buffer - +from tvm.tir.expr import Var from . import _ffi_api from .base import TIRFrame -from typing import Union - @_register_object("script.builder.tir.PrimFuncFrame") class PrimFuncFrame(TIRFrame): - pass + ... def prim_func(name) -> PrimFuncFrame: - return _ffi_api.PrimFuncFrame(name) + return _ffi_api.PrimFuncFrame(name) # pylint: disable=no-member # type: ignore -def arg(name, arg) -> Union[Var, Buffer]: - return _ffi_api.Arg(name, arg) +def arg(name, obj) -> Union[Var, Buffer]: + return _ffi_api.Arg(name, obj) # pylint: disable=no-member # type: ignore diff --git a/python/tvm/script/builder/tir/var.py b/python/tvm/script/builder/tir/var.py index fa06ee63c14a..18a8ecd59bbe 100644 --- a/python/tvm/script/builder/tir/var.py +++ b/python/tvm/script/builder/tir/var.py @@ -15,12 +15,15 @@ # specific language governing permissions and limitations # under the License. """TVM Script TIR Buffer""" -from tvm._ffi import register_object as _register_object - -from tvm.tir.buffer import Buffer +from tvm import tir from . import _ffi_api -def Buffer(shape, dtype, name="buffer", storage_scope="") -> Buffer: - return _ffi_api.Buffer(shape, dtype, name, storage_scope) +def Buffer( # pylint: disable=invalid-name + shape, + dtype, + name="buffer", + storage_scope="", +) -> tir.Buffer: + return _ffi_api.Buffer(shape, dtype, name, storage_scope) # pylint: disable=no-member # type: ignore diff --git a/src/script/builder/frame.cc b/src/script/builder/frame.cc index 56280a0b5ec5..ab2bf7774e27 100644 --- a/src/script/builder/frame.cc +++ b/src/script/builder/frame.cc @@ -24,14 +24,9 @@ namespace tvm { namespace script { namespace builder { -void FrameNode::EnterWithScope() { - LOG(INFO) << "EnterWithScope: " << this->GetTypeKey(); - // Push to the current builder - Builder::Current()->frames.push_back(GetRef(this)); -} +void FrameNode::EnterWithScope() { Builder::Current()->frames.push_back(GetRef(this)); } void FrameNode::ExitWithScope() { - LOG(INFO) << "ExitWithScope: " << this->GetTypeKey(); for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) { (*it)(); } @@ -60,9 +55,7 @@ void IRModuleFrameNode::ExitWithScope() { TVM_REGISTER_NODE_TYPE(FrameNode); TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); - TVM_REGISTER_GLOBAL("script.builder.FrameEnter").set_body_method(&FrameNode::EnterWithScope); - TVM_REGISTER_GLOBAL("script.builder.FrameExit").set_body_method(&FrameNode::ExitWithScope); } // namespace builder diff --git a/src/script/builder/tir/base.cc b/src/script/builder/tir/base.cc index d5206c9a7348..671090c283ed 100644 --- a/src/script/builder/tir/base.cc +++ b/src/script/builder/tir/base.cc @@ -35,36 +35,6 @@ namespace tir { TVM_REGISTER_NODE_TYPE(TIRFrameNode); -void TestPOC() { - namespace T = tvm::script::builder::tir; - using namespace ::tvm::tir; - - With builder; - { - With _{T::PrimFunc_("main")}; - Buffer A = T::Arg("A", T::Buffer_({128, 128, 128}, DataType::Float(32))); - Buffer B = T::Arg("B", T::Buffer_({128, 128, 128}, DataType::Float(32))); - { - With _{T::Grid({128, 128, 128})}; - Var i = Def("i", _()->vars[0]); - Var j = Def("j", _()->vars[1]); - Var k = Def("k", _()->vars[2]); - { - With _{T::Block_("block")}; - IterVar vi = Def("vi", T::axis::Spatial(Range(0, 128), i)); - IterVar vj = Def("vj", T::axis::Spatial(Range(0, 128), j)); - IterVar vk = Def("vk", T::axis::Reduce(Range(0, 128), k)); - } - LOG(INFO) << "ForFrame:\n" << _()->stmts; - } - LOG(INFO) << "PrimFuncFrame:\n" << _()->stmts; - } - PrimFunc func = builder()->Get(); - LOG(INFO) << "func:\n" << AsTVMScript(func); -} - -TVM_REGISTER_GLOBAL("test_poc").set_body_typed(TestPOC); - } // namespace tir } // namespace builder } // namespace script diff --git a/src/script/builder/tir/block_frame.cc b/src/script/builder/tir/block_frame.cc index d892df8a1b90..f2167f557589 100644 --- a/src/script/builder/tir/block_frame.cc +++ b/src/script/builder/tir/block_frame.cc @@ -145,13 +145,9 @@ Array Remap(String kinds, Array bindings, DataType } // namespace axis TVM_REGISTER_NODE_TYPE(BlockFrameNode); - TVM_REGISTER_GLOBAL("script.builder.tir.BlockFrame").set_body_typed(Block_); - TVM_REGISTER_GLOBAL("script.builder.tir.AxisSpatial").set_body_typed(axis::Spatial); - TVM_REGISTER_GLOBAL("script.builder.tir.AxisReduce").set_body_typed(axis::Reduce); - TVM_REGISTER_GLOBAL("script.builder.tir.AxisRemap").set_body_typed(axis::Remap); } // namespace tir diff --git a/src/script/builder/tir/for_frame.cc b/src/script/builder/tir/for_frame.cc index f242453395ee..0b6d289e9fc9 100644 --- a/src/script/builder/tir/for_frame.cc +++ b/src/script/builder/tir/for_frame.cc @@ -18,6 +18,7 @@ */ #include "./for_frame.h" +#include #include namespace tvm { @@ -30,19 +31,21 @@ void ForFrameNode::ExitWithScope() { AddToParent(f_make_for_loop(vars, doms, AsStmt(stmts))); } -#define TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Method, Kind) \ - ForFrame Method(PrimExpr min, PrimExpr extent, Map attrs) { \ - using namespace tvm::tir; \ - ObjectPtr n = make_object(); \ - int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ - n->vars = {Var("v", DataType::Int(bits))}; \ - n->doms = {Range(min, extent)}; \ - n->f_make_for_loop = [attrs](Array vars, Array doms, Stmt body) { \ - ICHECK_EQ(vars.size(), 1); \ - ICHECK_EQ(doms.size(), 1); \ - return For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt, attrs); \ - }; \ - return ForFrame(n); \ +#define TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Method, Kind) \ + ForFrame Method(PrimExpr start, PrimExpr stop, Map annotations) { \ + using namespace tvm::tir; \ + PrimExpr min = start; \ + PrimExpr extent = arith::Analyzer().Simplify(stop - start); \ + ObjectPtr n = make_object(); \ + int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ + n->vars = {Var("v", DataType::Int(bits))}; \ + n->doms = {Range::FromMinExtent(min, extent)}; \ + n->f_make_for_loop = [annotations](Array vars, Array doms, Stmt body) { \ + ICHECK_EQ(vars.size(), 1); \ + ICHECK_EQ(doms.size(), 1); \ + return For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt, annotations); \ + }; \ + return ForFrame(n); \ } TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Serial, tvm::tir::ForKind::kSerial); @@ -52,18 +55,21 @@ TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Unroll, tvm::tir::ForKind::kUnrolled); #undef TVM_SCRIPT_BUILDER_TIR_FOR_CREATE -ForFrame ThreadBinding(PrimExpr min, PrimExpr extent, String thread, Map attrs) { +ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, + Map annotations) { using namespace tvm::tir; + PrimExpr min = start; + PrimExpr extent = arith::Analyzer().Simplify(stop - start); ObjectPtr n = make_object(); int bits = std::max(min.dtype().bits(), extent.dtype().bits()); n->vars = {Var("v", DataType::Int(bits))}; - n->doms = {Range(min, extent)}; - n->f_make_for_loop = [attrs, thread](Array vars, Array doms, Stmt body) -> For { + n->doms = {Range::FromMinExtent(min, extent)}; + n->f_make_for_loop = [annotations, thread](Array vars, Array doms, Stmt body) -> For { ICHECK_EQ(vars.size(), 1); ICHECK_EQ(doms.size(), 1); IterVar iter_var(Range(nullptr), NullValue(), IterVarType::kThreadIndex, thread); return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var, - attrs); + annotations); }; return ForFrame(n); } @@ -93,17 +99,11 @@ ForFrame Grid(Array extents) { } TVM_REGISTER_NODE_TYPE(ForFrameNode); - TVM_REGISTER_GLOBAL("script.builder.tir.Serial").set_body_typed(Serial); - TVM_REGISTER_GLOBAL("script.builder.tir.Parallel").set_body_typed(Parallel); - TVM_REGISTER_GLOBAL("script.builder.tir.Vectorized").set_body_typed(Vectorized); - TVM_REGISTER_GLOBAL("script.builder.tir.Unroll").set_body_typed(Unroll); - TVM_REGISTER_GLOBAL("script.builder.tir.ThreadBinding").set_body_typed(ThreadBinding); - TVM_REGISTER_GLOBAL("script.builder.tir.Grid").set_body_typed(Grid); } // namespace tir diff --git a/src/script/builder/tir/for_frame.h b/src/script/builder/tir/for_frame.h index 2bff8dcd5f5e..e4d87cd7572a 100644 --- a/src/script/builder/tir/for_frame.h +++ b/src/script/builder/tir/for_frame.h @@ -59,11 +59,11 @@ class ForFrame : public TIRFrame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode); }; -ForFrame Serial(PrimExpr min, PrimExpr extent, Map annotations); -ForFrame Parallel(PrimExpr min, PrimExpr extent, Map annotations); -ForFrame Vectorized(PrimExpr min, PrimExpr extent, Map annotations); -ForFrame Unroll(PrimExpr min, PrimExpr extent, Map annotations); -ForFrame ThreadBinding(PrimExpr min, PrimExpr extent, String thread, +ForFrame Serial(PrimExpr start, PrimExpr stop, Map annotations); +ForFrame Parallel(PrimExpr start, PrimExpr stop, Map annotations); +ForFrame Vectorized(PrimExpr start, PrimExpr stop, Map annotations); +ForFrame Unroll(PrimExpr start, PrimExpr stop, Map annotations); +ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, Map annotations); ForFrame Grid(Array extents); diff --git a/src/script/builder/tir/prim_func_frame.cc b/src/script/builder/tir/prim_func_frame.cc index 70ba8c9adae7..9d85e2193c56 100644 --- a/src/script/builder/tir/prim_func_frame.cc +++ b/src/script/builder/tir/prim_func_frame.cc @@ -73,19 +73,18 @@ tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer) { } TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode); - TVM_REGISTER_GLOBAL("script.builder.tir.PrimFuncFrame").set_body_typed(PrimFunc_); - TVM_REGISTER_GLOBAL("script.builder.tir.Arg") .set_body_typed([](String name, ObjectRef obj) -> ObjectRef { using namespace tvm::tir; if (const auto* var = obj.as()) { return Arg(name, GetRef(var)); - } else if (const auto* buffer = obj.as()) { + } + if (const auto* buffer = obj.as()) { return Arg(name, GetRef(buffer)); - } else { - LOG(FATAL) << "ValueError: Unexpected type for TIR Arg."; } + LOG(FATAL) << "ValueError: Unexpected type for TIR Arg."; + throw; }); } // namespace tir diff --git a/tests/python/unittest/test_tvmscript_builder.py b/tests/python/unittest/test_tvmscript_builder.py index 7fb4cc1a0eed..83ab995e76b7 100644 --- a/tests/python/unittest/test_tvmscript_builder.py +++ b/tests/python/unittest/test_tvmscript_builder.py @@ -21,8 +21,7 @@ def test_builder_basic(): - b = Builder() - with b: + with Builder() as b: with T.prim_func(name="main"): A = T.arg("A", T.Buffer((128, 128, 128), "float32")) B = T.arg("B", T.Buffer((128, 128, 128), "float32")) @@ -33,7 +32,6 @@ def test_builder_basic(): vj = def_("vj", T.axis.spatial(128, j)) vk = def_("vk", T.axis.reduce(128, k)) print(b.get().script()) - tvm._ffi.get_global_func("test_poc")() if __name__ == "__main__":