diff --git a/python/tvm/script/builder/_ffi_api.py b/python/tvm/script/builder/_ffi_api.py
index ec20ad798f80..3410494ded4d 100644
--- a/python/tvm/script/builder/_ffi_api.py
+++ b/python/tvm/script/builder/_ffi_api.py
@@ -17,4 +17,4 @@
"""FFI APIs for tvm.script.builder"""
import tvm._ffi
-tvm._ffi._init_api("script.builder", __name__)
+tvm._ffi._init_api("script.builder", __name__) # pylint: disable=protected-access
diff --git a/python/tvm/script/builder/builder.py b/python/tvm/script/builder/builder.py
index 3d449ef1975a..6ec31abecb5f 100644
--- a/python/tvm/script/builder/builder.py
+++ b/python/tvm/script/builder/builder.py
@@ -15,44 +15,47 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script IR Builder"""
-from typing import List
-from tvm._ffi import register_object as _register_object
-from .frame import Frame
+from typing import List, TypeVar
+from tvm._ffi import register_object as _register_object
from tvm.runtime import Object
from . import _ffi_api
-
-from typing import TypeVar
+from .frame import Frame
@_register_object("script.builder.Builder")
class Builder(Object):
def __init__(self) -> None:
- self.__init_handle_by_constructor__(_ffi_api.Builder)
+ self.__init_handle_by_constructor__(
+ _ffi_api.Builder # pylint: disable=no-member # type: ignore
+ )
def __enter__(self) -> "Builder":
- _ffi_api.BuilderEnter(self)
+ _ffi_api.BuilderEnter(self) # pylint: disable=no-member # type: ignore
return self
- def __exit__(self, ptype, value, trace) -> None:
- _ffi_api.BuilderExit(self)
+ def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
+ _ffi_api.BuilderExit(self) # pylint: disable=no-member # type: ignore
@staticmethod
- def current(self) -> "Builder":
- return _ffi_api.BuilderCurrent(self)
+ def current() -> "Builder":
+ return _ffi_api.BuilderCurrent() # pylint: disable=no-member # type: ignore
def get(self) -> Frame:
- return _ffi_api.BuilderGet(self)
+ return _ffi_api.BuilderGet(self) # pylint: disable=no-member # type: ignore
DefType = TypeVar("DefType", bound=Object)
def def_(name: str, var: DefType) -> DefType:
- return _ffi_api.Def(name, var)
+ return _ffi_api.Def(name, var) # pylint: disable=no-member # type: ignore
-def def_many(names: List[str], vars: List[DefType]) -> List[DefType]:
+def def_many(
+ names: List[str],
+ vars: List[DefType], # pylint: disable=redefine-builtin
+) -> List[DefType]:
assert len(names) == len(vars)
return [def_(name, var) for name, var in zip(names, vars)]
diff --git a/python/tvm/script/builder/frame.py b/python/tvm/script/builder/frame.py
index 7f6ac8972fcf..faf0c4271d7d 100644
--- a/python/tvm/script/builder/frame.py
+++ b/python/tvm/script/builder/frame.py
@@ -16,7 +16,6 @@
# under the License.
"""TVM Script Frames"""
from tvm._ffi import register_object as _register_object
-
from tvm.runtime import Object
from . import _ffi_api
@@ -25,14 +24,16 @@
@_register_object("script.builder.Frame")
class Frame(Object):
def __enter__(self) -> "Frame":
- _ffi_api.FrameEnter(self)
+ _ffi_api.FrameEnter(self) # pylint: disable=no-member # type: ignore
return self
- def __exit__(self, ptype, value, trace) -> None:
- _ffi_api.FrameExit(self)
+ def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
+ _ffi_api.FrameExit(self) # pylint: disable=no-member # type: ignore
@_register_object("script.builder.IRModuleFrame")
class IRModuleFrame(Frame):
def __init__(self) -> None:
- self.__init_handle_by_constructor__(_ffi_api.IRModuleFrame)
+ self.__init_handle_by_constructor__(
+ _ffi_api.IRModuleFrame # pylint: disable=no-member # type: ignore
+ )
diff --git a/python/tvm/script/builder/tir/__init__.py b/python/tvm/script/builder/tir/__init__.py
index c206e5e84059..e6e431e5bced 100644
--- a/python/tvm/script/builder/tir/__init__.py
+++ b/python/tvm/script/builder/tir/__init__.py
@@ -17,17 +17,17 @@
# pylint: disable=unused-import
"""Namespace for the TVMScript TIR Builder API."""
+from . import axis
from .base import TIRFrame
+from .block_frame import block
from .for_frame import (
ForFrame,
- serial,
+ grid,
parallel,
- vectorized,
- unroll,
+ serial,
thread_binding,
- grid,
+ unroll,
+ vectorized,
)
-from .prim_func_frame import prim_func, arg
-from .block_frame import block
+from .prim_func_frame import arg, prim_func
from .var import Buffer
-from . import axis
diff --git a/python/tvm/script/builder/tir/_ffi_api.py b/python/tvm/script/builder/tir/_ffi_api.py
index df97ad7ae7f2..4e40e7261fd3 100644
--- a/python/tvm/script/builder/tir/_ffi_api.py
+++ b/python/tvm/script/builder/tir/_ffi_api.py
@@ -14,9 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""FFI APIs for tvm.script.builder"""
+"""FFI APIs for tvm.script.builder.tir"""
import tvm._ffi
-from .. import _ffi_api as _base_ffi_api
-
-tvm._ffi._init_api("script.builder.tir", __name__)
+tvm._ffi._init_api("script.builder.tir", __name__) # pylint: disable=protected-access
diff --git a/python/tvm/script/builder/tir/axis.py b/python/tvm/script/builder/tir/axis.py
index 9bb3e75650b5..0371aa5bd802 100644
--- a/python/tvm/script/builder/tir/axis.py
+++ b/python/tvm/script/builder/tir/axis.py
@@ -16,22 +16,23 @@
# under the License.
"""TVM Script TIR Axis"""
-from . import _ffi_api
from tvm.ir import Range
from tvm.tir import IterVar
+from . import _ffi_api
+
def spatial(dom, binding, dtype="int32") -> IterVar:
if not isinstance(dom, Range):
dom = Range(0, dom)
- return _ffi_api.AxisSpatial(dom, binding, dtype)
+ return _ffi_api.AxisSpatial(dom, binding, dtype) # pylint: disable=no-member # type: ignore
def reduce(dom, binding, dtype="int32") -> IterVar:
if not isinstance(dom, Range):
dom = Range(0, dom)
- return _ffi_api.AxisReduce(dom, binding, dtype)
+ return _ffi_api.AxisReduce(dom, binding, dtype) # pylint: disable=no-member # type: ignore
def remap(kinds, bindings, dtype="int32") -> IterVar:
- return _ffi_api.AxisRemap(kinds, bindings, dtype)
+ return _ffi_api.AxisRemap(kinds, bindings, dtype) # pylint: disable=no-member # type: ignore
diff --git a/python/tvm/script/builder/tir/base.py b/python/tvm/script/builder/tir/base.py
index e19c47b0a478..9159c8d914b5 100644
--- a/python/tvm/script/builder/tir/base.py
+++ b/python/tvm/script/builder/tir/base.py
@@ -17,10 +17,9 @@
"""TVM Script TIR Frame"""
from tvm._ffi import register_object as _register_object
-from . import _ffi_api
from ..frame import Frame
@_register_object("script.builder.tir.TIRFrame")
class TIRFrame(Frame):
- pass
+ ...
diff --git a/python/tvm/script/builder/tir/block_frame.py b/python/tvm/script/builder/tir/block_frame.py
index c90b0fac87c8..aa447a409e89 100644
--- a/python/tvm/script/builder/tir/block_frame.py
+++ b/python/tvm/script/builder/tir/block_frame.py
@@ -16,16 +16,15 @@
# under the License.
"""TVM Script TIR Block Frame"""
from tvm._ffi import register_object as _register_object
-from .base import TIRFrame
-
from . import _ffi_api
+from .base import TIRFrame
@_register_object("script.builder.tir.BlockFrame")
class BlockFrame(TIRFrame):
- pass
+ ...
def block(name) -> BlockFrame:
- return _ffi_api.BlockFrame(name)
+ return _ffi_api.BlockFrame(name) # pylint: disable=no-member # type: ignore
diff --git a/python/tvm/script/builder/tir/for_frame.py b/python/tvm/script/builder/tir/for_frame.py
index 13b2599ae233..051565492882 100644
--- a/python/tvm/script/builder/tir/for_frame.py
+++ b/python/tvm/script/builder/tir/for_frame.py
@@ -15,42 +15,44 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR For Frame"""
-from tvm._ffi import register_object as _register_object
+from typing import List
+from tvm._ffi import register_object as _register_object
from tvm.tir import Var
from . import _ffi_api
-from ._ffi_api import _base_ffi_api
+from .. import _ffi_api as _base_ffi_api
from .base import TIRFrame
-from typing import List
@_register_object("script.builder.tir.ForFrame")
class ForFrame(TIRFrame):
def __enter__(self) -> List[Var]:
- _base_ffi_api.FrameEnter(self)
+ _base_ffi_api.FrameEnter(self) # pylint: disable=no-member # type: ignore
return self.vars
-def serial(min_val, extent, attrs) -> ForFrame:
- return _ffi_api.Serial(min_val, extent, attrs)
+def serial(start, stop, annotations) -> ForFrame:
+ return _ffi_api.Serial(start, stop, annotations) # pylint: disable=no-member # type: ignore
-def parallel(min_val, extent, attrs) -> ForFrame:
- return _ffi_api.Parallel(min_val, extent, attrs)
+def parallel(start, stop, annotations) -> ForFrame:
+ return _ffi_api.Parallel(start, stop, annotations) # pylint: disable=no-member # type: ignore
-def vectorized(min_val, extent, attrs) -> ForFrame:
- return _ffi_api.Vectorized(min_val, extent, attrs)
+def vectorized(start, stop, annotations) -> ForFrame:
+ return _ffi_api.Vectorized(start, stop, annotations) # pylint: disable=no-member # type: ignore
-def unroll(min_val, extent, attrs) -> ForFrame:
- return _ffi_api.Unroll(min_val, extent, attrs)
+def unroll(start, stop, annotations) -> ForFrame:
+ return _ffi_api.Unroll(start, stop, annotations) # pylint: disable=no-member # type: ignore
-def thread_binding(min_val, extent, attrs) -> ForFrame:
- return _ffi_api.ThreadBinding(min_val, extent, attrs)
+def thread_binding(start, stop, thread, annotations) -> ForFrame:
+ return _ffi_api.ThreadBinding( # pylint: disable=no-member # type: ignore
+ start, stop, thread, annotations
+ )
def grid(*extents) -> ForFrame:
- return _ffi_api.Grid(extents)
+ return _ffi_api.Grid(extents) # pylint: disable=no-member # type: ignore
diff --git a/python/tvm/script/builder/tir/prim_func_frame.py b/python/tvm/script/builder/tir/prim_func_frame.py
index 4a223af55f19..370fe361552d 100644
--- a/python/tvm/script/builder/tir/prim_func_frame.py
+++ b/python/tvm/script/builder/tir/prim_func_frame.py
@@ -15,26 +15,24 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR Prim Func Frame"""
-from tvm._ffi import register_object as _register_object
+from typing import Union
-from tvm.tir.expr import Var
+from tvm._ffi import register_object as _register_object
from tvm.tir.buffer import Buffer
-
+from tvm.tir.expr import Var
from . import _ffi_api
from .base import TIRFrame
-from typing import Union
-
@_register_object("script.builder.tir.PrimFuncFrame")
class PrimFuncFrame(TIRFrame):
- pass
+ ...
def prim_func(name) -> PrimFuncFrame:
- return _ffi_api.PrimFuncFrame(name)
+ return _ffi_api.PrimFuncFrame(name) # pylint: disable=no-member # type: ignore
-def arg(name, arg) -> Union[Var, Buffer]:
- return _ffi_api.Arg(name, arg)
+def arg(name, obj) -> Union[Var, Buffer]:
+ return _ffi_api.Arg(name, obj) # pylint: disable=no-member # type: ignore
diff --git a/python/tvm/script/builder/tir/var.py b/python/tvm/script/builder/tir/var.py
index fa06ee63c14a..18a8ecd59bbe 100644
--- a/python/tvm/script/builder/tir/var.py
+++ b/python/tvm/script/builder/tir/var.py
@@ -15,12 +15,15 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR Buffer"""
-from tvm._ffi import register_object as _register_object
-
-from tvm.tir.buffer import Buffer
+from tvm import tir
from . import _ffi_api
-def Buffer(shape, dtype, name="buffer", storage_scope="") -> Buffer:
- return _ffi_api.Buffer(shape, dtype, name, storage_scope)
+def Buffer( # pylint: disable=invalid-name
+ shape,
+ dtype,
+ name="buffer",
+ storage_scope="",
+) -> tir.Buffer:
+ return _ffi_api.Buffer(shape, dtype, name, storage_scope) # pylint: disable=no-member # type: ignore
diff --git a/src/script/builder/frame.cc b/src/script/builder/frame.cc
index 56280a0b5ec5..ab2bf7774e27 100644
--- a/src/script/builder/frame.cc
+++ b/src/script/builder/frame.cc
@@ -24,14 +24,9 @@ namespace tvm {
namespace script {
namespace builder {
-void FrameNode::EnterWithScope() {
- LOG(INFO) << "EnterWithScope: " << this->GetTypeKey();
- // Push to the current builder
- Builder::Current()->frames.push_back(GetRef(this));
-}
+void FrameNode::EnterWithScope() { Builder::Current()->frames.push_back(GetRef(this)); }
void FrameNode::ExitWithScope() {
- LOG(INFO) << "ExitWithScope: " << this->GetTypeKey();
for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) {
(*it)();
}
@@ -60,9 +55,7 @@ void IRModuleFrameNode::ExitWithScope() {
TVM_REGISTER_NODE_TYPE(FrameNode);
TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);
-
TVM_REGISTER_GLOBAL("script.builder.FrameEnter").set_body_method(&FrameNode::EnterWithScope);
-
TVM_REGISTER_GLOBAL("script.builder.FrameExit").set_body_method(&FrameNode::ExitWithScope);
} // namespace builder
diff --git a/src/script/builder/tir/base.cc b/src/script/builder/tir/base.cc
index d5206c9a7348..671090c283ed 100644
--- a/src/script/builder/tir/base.cc
+++ b/src/script/builder/tir/base.cc
@@ -35,36 +35,6 @@ namespace tir {
TVM_REGISTER_NODE_TYPE(TIRFrameNode);
-void TestPOC() {
- namespace T = tvm::script::builder::tir;
- using namespace ::tvm::tir;
-
- With builder;
- {
- With _{T::PrimFunc_("main")};
- Buffer A = T::Arg("A", T::Buffer_({128, 128, 128}, DataType::Float(32)));
- Buffer B = T::Arg("B", T::Buffer_({128, 128, 128}, DataType::Float(32)));
- {
- With _{T::Grid({128, 128, 128})};
- Var i = Def("i", _()->vars[0]);
- Var j = Def("j", _()->vars[1]);
- Var k = Def("k", _()->vars[2]);
- {
- With _{T::Block_("block")};
- IterVar vi = Def("vi", T::axis::Spatial(Range(0, 128), i));
- IterVar vj = Def("vj", T::axis::Spatial(Range(0, 128), j));
- IterVar vk = Def("vk", T::axis::Reduce(Range(0, 128), k));
- }
- LOG(INFO) << "ForFrame:\n" << _()->stmts;
- }
- LOG(INFO) << "PrimFuncFrame:\n" << _()->stmts;
- }
- PrimFunc func = builder()->Get();
- LOG(INFO) << "func:\n" << AsTVMScript(func);
-}
-
-TVM_REGISTER_GLOBAL("test_poc").set_body_typed(TestPOC);
-
} // namespace tir
} // namespace builder
} // namespace script
diff --git a/src/script/builder/tir/block_frame.cc b/src/script/builder/tir/block_frame.cc
index d892df8a1b90..f2167f557589 100644
--- a/src/script/builder/tir/block_frame.cc
+++ b/src/script/builder/tir/block_frame.cc
@@ -145,13 +145,9 @@ Array Remap(String kinds, Array bindings, DataType
} // namespace axis
TVM_REGISTER_NODE_TYPE(BlockFrameNode);
-
TVM_REGISTER_GLOBAL("script.builder.tir.BlockFrame").set_body_typed(Block_);
-
TVM_REGISTER_GLOBAL("script.builder.tir.AxisSpatial").set_body_typed(axis::Spatial);
-
TVM_REGISTER_GLOBAL("script.builder.tir.AxisReduce").set_body_typed(axis::Reduce);
-
TVM_REGISTER_GLOBAL("script.builder.tir.AxisRemap").set_body_typed(axis::Remap);
} // namespace tir
diff --git a/src/script/builder/tir/for_frame.cc b/src/script/builder/tir/for_frame.cc
index f242453395ee..0b6d289e9fc9 100644
--- a/src/script/builder/tir/for_frame.cc
+++ b/src/script/builder/tir/for_frame.cc
@@ -18,6 +18,7 @@
*/
#include "./for_frame.h"
+#include
#include
namespace tvm {
@@ -30,19 +31,21 @@ void ForFrameNode::ExitWithScope() {
AddToParent(f_make_for_loop(vars, doms, AsStmt(stmts)));
}
-#define TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Method, Kind) \
- ForFrame Method(PrimExpr min, PrimExpr extent, Map attrs) { \
- using namespace tvm::tir; \
- ObjectPtr n = make_object(); \
- int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \
- n->vars = {Var("v", DataType::Int(bits))}; \
- n->doms = {Range(min, extent)}; \
- n->f_make_for_loop = [attrs](Array vars, Array doms, Stmt body) { \
- ICHECK_EQ(vars.size(), 1); \
- ICHECK_EQ(doms.size(), 1); \
- return For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt, attrs); \
- }; \
- return ForFrame(n); \
+#define TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Method, Kind) \
+ ForFrame Method(PrimExpr start, PrimExpr stop, Map annotations) { \
+ using namespace tvm::tir; \
+ PrimExpr min = start; \
+ PrimExpr extent = arith::Analyzer().Simplify(stop - start); \
+ ObjectPtr n = make_object(); \
+ int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \
+ n->vars = {Var("v", DataType::Int(bits))}; \
+ n->doms = {Range::FromMinExtent(min, extent)}; \
+ n->f_make_for_loop = [annotations](Array vars, Array doms, Stmt body) { \
+ ICHECK_EQ(vars.size(), 1); \
+ ICHECK_EQ(doms.size(), 1); \
+ return For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt, annotations); \
+ }; \
+ return ForFrame(n); \
}
TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Serial, tvm::tir::ForKind::kSerial);
@@ -52,18 +55,21 @@ TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Unroll, tvm::tir::ForKind::kUnrolled);
#undef TVM_SCRIPT_BUILDER_TIR_FOR_CREATE
-ForFrame ThreadBinding(PrimExpr min, PrimExpr extent, String thread, Map attrs) {
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
+ Map annotations) {
using namespace tvm::tir;
+ PrimExpr min = start;
+ PrimExpr extent = arith::Analyzer().Simplify(stop - start);
ObjectPtr n = make_object();
int bits = std::max(min.dtype().bits(), extent.dtype().bits());
n->vars = {Var("v", DataType::Int(bits))};
- n->doms = {Range(min, extent)};
- n->f_make_for_loop = [attrs, thread](Array vars, Array doms, Stmt body) -> For {
+ n->doms = {Range::FromMinExtent(min, extent)};
+ n->f_make_for_loop = [annotations, thread](Array vars, Array doms, Stmt body) -> For {
ICHECK_EQ(vars.size(), 1);
ICHECK_EQ(doms.size(), 1);
IterVar iter_var(Range(nullptr), NullValue(), IterVarType::kThreadIndex, thread);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var,
- attrs);
+ annotations);
};
return ForFrame(n);
}
@@ -93,17 +99,11 @@ ForFrame Grid(Array extents) {
}
TVM_REGISTER_NODE_TYPE(ForFrameNode);
-
TVM_REGISTER_GLOBAL("script.builder.tir.Serial").set_body_typed(Serial);
-
TVM_REGISTER_GLOBAL("script.builder.tir.Parallel").set_body_typed(Parallel);
-
TVM_REGISTER_GLOBAL("script.builder.tir.Vectorized").set_body_typed(Vectorized);
-
TVM_REGISTER_GLOBAL("script.builder.tir.Unroll").set_body_typed(Unroll);
-
TVM_REGISTER_GLOBAL("script.builder.tir.ThreadBinding").set_body_typed(ThreadBinding);
-
TVM_REGISTER_GLOBAL("script.builder.tir.Grid").set_body_typed(Grid);
} // namespace tir
diff --git a/src/script/builder/tir/for_frame.h b/src/script/builder/tir/for_frame.h
index 2bff8dcd5f5e..e4d87cd7572a 100644
--- a/src/script/builder/tir/for_frame.h
+++ b/src/script/builder/tir/for_frame.h
@@ -59,11 +59,11 @@ class ForFrame : public TIRFrame {
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode);
};
-ForFrame Serial(PrimExpr min, PrimExpr extent, Map annotations);
-ForFrame Parallel(PrimExpr min, PrimExpr extent, Map annotations);
-ForFrame Vectorized(PrimExpr min, PrimExpr extent, Map annotations);
-ForFrame Unroll(PrimExpr min, PrimExpr extent, Map annotations);
-ForFrame ThreadBinding(PrimExpr min, PrimExpr extent, String thread,
+ForFrame Serial(PrimExpr start, PrimExpr stop, Map annotations);
+ForFrame Parallel(PrimExpr start, PrimExpr stop, Map annotations);
+ForFrame Vectorized(PrimExpr start, PrimExpr stop, Map annotations);
+ForFrame Unroll(PrimExpr start, PrimExpr stop, Map annotations);
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
Map annotations);
ForFrame Grid(Array extents);
diff --git a/src/script/builder/tir/prim_func_frame.cc b/src/script/builder/tir/prim_func_frame.cc
index 70ba8c9adae7..9d85e2193c56 100644
--- a/src/script/builder/tir/prim_func_frame.cc
+++ b/src/script/builder/tir/prim_func_frame.cc
@@ -73,19 +73,18 @@ tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer) {
}
TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
-
TVM_REGISTER_GLOBAL("script.builder.tir.PrimFuncFrame").set_body_typed(PrimFunc_);
-
TVM_REGISTER_GLOBAL("script.builder.tir.Arg")
.set_body_typed([](String name, ObjectRef obj) -> ObjectRef {
using namespace tvm::tir;
if (const auto* var = obj.as()) {
return Arg(name, GetRef(var));
- } else if (const auto* buffer = obj.as()) {
+ }
+ if (const auto* buffer = obj.as()) {
return Arg(name, GetRef(buffer));
- } else {
- LOG(FATAL) << "ValueError: Unexpected type for TIR Arg.";
}
+ LOG(FATAL) << "ValueError: Unexpected type for TIR Arg.";
+ throw;
});
} // namespace tir
diff --git a/tests/python/unittest/test_tvmscript_builder.py b/tests/python/unittest/test_tvmscript_builder.py
index 7fb4cc1a0eed..83ab995e76b7 100644
--- a/tests/python/unittest/test_tvmscript_builder.py
+++ b/tests/python/unittest/test_tvmscript_builder.py
@@ -21,8 +21,7 @@
def test_builder_basic():
- b = Builder()
- with b:
+ with Builder() as b:
with T.prim_func(name="main"):
A = T.arg("A", T.Buffer((128, 128, 128), "float32"))
B = T.arg("B", T.Buffer((128, 128, 128), "float32"))
@@ -33,7 +32,6 @@ def test_builder_basic():
vj = def_("vj", T.axis.spatial(128, j))
vk = def_("vk", T.axis.reduce(128, k))
print(b.get().script())
- tvm._ffi.get_global_func("test_poc")()
if __name__ == "__main__":