diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 9966a3ef1fd2..3612bb81a6bc 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -39,4 +39,3 @@ tune_tir, ) from .tune_context import TuneContext -from . import tensor_intrin diff --git a/python/tvm/meta_schedule/postproc/rewrite_vnni.py b/python/tvm/meta_schedule/postproc/rewrite_vnni.py new file mode 100644 index 000000000000..b4de67184f61 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_vnni.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that tensorize VNNI related components.""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc +import tvm.tir.tensor_intrin + + +@register_object("meta_schedule.RewriteVNNI") +class RewriteVNNI(Postproc): + """A postprocessor that tensorize VNNI related components.""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteVNNI, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 0148bd0b4243..708acf6aa9a2 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -25,10 +25,9 @@ import tvm.tir from tvm.runtime import Object, String -from tvm import te from tvm.target import Target from tvm.ir import Span -from tvm.tir import IntImm, IterVar +from tvm.tir import IntImm, IterVar, Var from .node import BufferSlice from .utils import buffer_slice_to_region @@ -800,7 +799,7 @@ def var(dtype, span): self.context.report_error( f"VarDef expected assign to only one var, but got {names}", span ) - v = te.var(names[0], dtype, span=span) + v = Var(names[0], dtype, span=span) self.context.update_symbol(v.name, v, self.node) super().__init__(var, def_symbol=True) @@ -821,7 +820,7 @@ def buffer_var(dtype, storage_scope, span): f"VarDef expected assign to only one var, but got {names}", span ) ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) - v = te.var(names[0], ptr_type, span=span) + v = Var(names[0], ptr_type, span=span) self.context.update_symbol(v.name, v, self.node) super().__init__(buffer_var, def_symbol=True) @@ -841,7 +840,7 @@ def env_thread(env_name, span): self.context.report_error( f"VarDef expected assign to only one var, but got {names}", span ) - v = te.var(names[0], span=span) + v = Var(names[0], span=span) self.context.func_var_env_dict[v] = env_name self.context.update_symbol(v.name, v, self.node) diff --git a/python/tvm/meta_schedule/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py similarity index 100% rename from python/tvm/meta_schedule/tensor_intrin/__init__.py rename to python/tvm/tir/tensor_intrin/__init__.py diff --git a/python/tvm/meta_schedule/tensor_intrin/vnni.py b/python/tvm/tir/tensor_intrin/vnni.py similarity index 98% rename from python/tvm/meta_schedule/tensor_intrin/vnni.py rename to python/tvm/tir/tensor_intrin/vnni.py index 674f007dac2a..9393aa8bb7b6 100644 --- a/python/tvm/meta_schedule/tensor_intrin/vnni.py +++ b/python/tvm/tir/tensor_intrin/vnni.py @@ -16,7 +16,6 @@ # under the License. from tvm import tir from tvm.script import tir as T -from tvm.script.registry import register @T.prim_func