diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 615ce90383dd..aaa5442eede3 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -28,12 +28,111 @@ namespace script { namespace ir_builder { namespace tir { +using tvm::tir::Buffer; +using tvm::tir::Var; + +/*! + * \brief The buffer declaration function. + * \param shape The type of the buffer prior to flattening. + * \param dtype The data type in the content of the buffer. + * \param buffer_name The name of the buffer. + * \param data The pointer to the head of the data. + * \param strides The strides of each dimension. + * \param elem_offset The offset in terms of number of dtype elements (including lanes). + * \param storage_scope The optional storage scope of buffer data pointer. + * \param align The alignment requirement of data pointer in bytes. + * \param offset_factor The factor of elem_offset field. + * \param buffer_type The buffer type. + * \param axis_separators The separators between input axes when generating flattened output axes. + * \return The declared buffer. + */ +Buffer BufferDecl(Array shape, DataType dtype, String buffer_name, Optional data, + Optional> strides, Optional elem_offset, + String storage_scope, int align, int offset_factor, String buffer_type, + Optional> axis_separators); + /*! * \brief The primitive function statement. * \return The PrimFuncFrame. */ PrimFuncFrame PrimFunc(); +/*! + * \brief The PrimFunc variable arguments adding function. + * \param name The name of the variable. + * \param var The variable argument. + * \return The variable. + */ +Var Arg(String name, Var var); + +/*! + * \brief The PrimFunc buffer arguments adding function. + * \param name The name of the buffer. + * \param buffer The buffer argument. + * \return The buffer. + */ +Buffer Arg(String name, Buffer buffer); + +/*! + * \brief The PrimFunc naming statement. + * \param name The name of the PrimFunc. + */ +void FuncName(String name); + +/*! + * \brief The PrimFunc annotation statement. + * \param attrs The annotations of the PrimFunc. + */ +void FuncAttrs(Map attrs); + +/*! + * \brief The PrimFunc return type statement. + * \param ret_type The return type of the PrimFunc. + * \return The return type. + */ +Type FuncRet(Type ret_type); + +/*! + * \brief The buffer match statement. + * \param param The parameter of the PrimFunc to match. + * \param shape The type of the buffer prior to flattening. + * \param dtype The data type in the content of the buffer. + * \param data The pointer to the head of the data. + * \param strides The strides of each dimension. + * \param elem_offset The offset in terms of number of dtype elements (including lanes). + * \param storage_scope The optional storage scope of buffer data pointer. + * \param align The alignment requirement of data pointer in bytes. + * \param offset_factor The factor of elem_offset field. + * \param buffer_type The buffer type. + * \param axis_separators The separators between input axes when generating flattened output axes. + * \return The matched buffer. + */ +Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype = DataType::Float(32), + Optional data = NullOpt, Array strides = {}, + PrimExpr elem_offset = PrimExpr(), String storage_scope = "global", + int align = -1, int offset_factor = 0, String buffer_type = "default", + Array axis_separators = {}); + +/*! + * \brief The pre-flattened buffer statement. + * \param postflattened_buffer The original buffer to be flattened. + * \param shape The type of the buffer prior to flattening. + * \param dtype The data type in the content of the buffer. + * \param data The pointer to the head of the data. + * \param strides The strides of each dimension. + * \param elem_offset The offset in terms of number of dtype elements (including lanes). + * \param storage_scope The optional storage scope of buffer data pointer. + * \param align The alignment requirement of data pointer in bytes. + * \param offset_factor The factor of elem_offset field. + * \param buffer_type The buffer type. + * \param axis_separators The separators between input axes when generating flattened output axes. + */ +void PreflattenedBuffer(Buffer postflattened_buffer, Array shape, + DataType dtype = DataType::Float(32), Optional data = NullOpt, + Array strides = {}, PrimExpr elem_offset = PrimExpr(), + String storage_scope = "global", int align = -1, int offset_factor = 0, + String buffer_type = "default", Array axis_separators = {}); + /*! * \brief The block declaration statement. * \param name The name of the block. @@ -48,6 +147,33 @@ BlockFrame Block(String name, bool no_realize = false); */ void Evaluate(PrimExpr value); +#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ + inline PrimExpr FuncName(Optional expr = NullOpt) { \ + DataType dtype = DType; \ + return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \ + } + +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x4, DataType::Int(32, 4)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x8, DataType::Int(32, 8)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x16, DataType::Int(32, 16)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool()); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle()); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void()); + +#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST + } // namespace tir } // namespace ir_builder } // namespace script diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 7ba2f6df9418..63fd1291f4bc 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -17,11 +17,89 @@ # pylint: disable=missing-docstring """IRBuilder for TIR""" -from tvm.tir import PrimExpr, StringImm +from numbers import Integral +from typing import Any, Dict, List, Optional, Union, Tuple + +from tvm.ir import Type +from tvm.tir import ( + Buffer, + BufferLoad, + BufferRegion, + PrimExpr, + StringImm, + Var, +) from . import _ffi_api, frame +def buffer_decl( + shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], + dtype: str = "float32", + data: Var = None, + strides: List[PrimExpr] = None, + elem_offset: PrimExpr = None, + scope: str = "", + align: int = 0, + offset_factor: int = 0, + buffer_type: str = "", + axis_separators: List[int] = None, +) -> Buffer: + """The buffer declaration function. + + Parameters + ---------- + shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + The type of the buffer prior to flattening. + + dtype : str + The data type in the content of the buffer. + + data : Var + The pointer to the head of the data. + + strides : List[PrimExpr] + The strides of each dimension. + + elem_offset : PrimExpr + The offset in terms of number of dtype elements (including lanes). + + scope : str + The optional storage scope of buffer data pointer. + + align : int + The alignment requirement of data pointer in bytes. + + offset_factor : int + The factor of elem_offset field. + + buffer_type : str + The buffer type. + + axis_separators : List[int] + The separators between input axes when generating flattened output axes. + + Returns + ------- + res : Buffer + The declared buffer. + """ + shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + return _ffi_api.BufferDecl( # pylint: disable=no-member # type: ignore + shape, + dtype, + "", + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + ) + + def prim_func() -> frame.PrimFuncFrame: """The primitive function statement. @@ -33,6 +111,220 @@ def prim_func() -> frame.PrimFuncFrame: return _ffi_api.PrimFunc() # pylint: disable=no-member # type: ignore +def arg(name: str, obj: Union[Var, Buffer]) -> Union[Var, Buffer]: + """The PrimFunc arguments adding function. + + Parameters + ---------- + name : str + The name of the argument. + + var : Union[Var, Buffer] + The argument of Var or Buffer. + + Returns + ------- + res : Union[Var, Buffer] + The argument. + """ + return _ffi_api.Arg(name, obj) # pylint: disable=no-member # type: ignore + + +def func_name(name: str) -> None: + """The PrimFunc naming statement. + + Parameters + ---------- + name : str + The name of the PrimFunc. + """ + _ffi_api.FuncName(name) # pylint: disable=no-member # type: ignore + + +def func_attr(attrs: Dict[str, Any]) -> None: + """The PrimFunc annotation statement. + + Parameters + ---------- + attrs : Dict[str, Any] + The annotations of the PrimFunc. + """ + _ffi_api.FuncAttrs(attrs) # pylint: disable=no-member # type: ignore + + +def func_ret(ret_type: Type) -> Type: + """The PrimFunc return type statement. + + Parameters + ---------- + ret_type : Type + The return type of the PrimFunc. + + Returns + ------- + res : Type + The return type. + """ + return _ffi_api.FuncRet(ret_type) # pylint: disable=no-member # type: ignore + + +def match_buffer( + param: Union[Var, BufferLoad, BufferRegion], + shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], + dtype: str = "float32", + data: Var = None, + strides: List[PrimExpr] = None, + elem_offset: PrimExpr = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", + axis_separators: List[int] = None, +) -> Buffer: + """The buffer match function. + + Note + ---- + This function will perform different behavior, depending on the type of param. + If the param is a var in function parameter, it will create a buffer from DLTensor. + Else if the param is a subregion of other buffers, then create a subregion match inside a block. + + Example + ------- + Match buffer from function parameter + .. code-block:: python + A = T.match_buffer(a, (128, 128), dtype="float32") + + Match buffer from Buffer subregion + .. code-block:: python + A = T.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32") + + Parameters + ---------- + param : Union[Var, BufferLoad, BufferRegion] + The parameter of the PrimFunc to match. + + shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + The type of the buffer prior to flattening. + + dtype : str + The data type in the content of the buffer. + + data : Var + The pointer to the head of the data. + + strides : List[PrimExpr] + The strides of each dimension. + + elem_offset : PrimExpr + The offset in terms of number of dtype elements (including lanes). + + scope : str + The optional storage scope of buffer data pointer. + + align : int + The alignment requirement of data pointer in bytes. + + offset_factor : int + The factor of elem_offset field. + + buffer_type : str + The buffer type. + + axis_separators : List[int] + The separators between input axes when generating flattened output axes. + + Returns + ------- + res : Buffer + The matched buffer. + """ + shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + if strides is None: + strides = [] + return _ffi_api.MatchBuffer( # pylint: disable=no-member # type: ignore + param, + shape, + dtype, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + ) + + +def preflattened_buffer( + postflattened: Buffer, + shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], + dtype: str = "float32", + data: Var = None, + strides: List[PrimExpr] = None, + elem_offset: PrimExpr = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", + axis_separators: List[int] = None, +) -> None: + """The pre-flattened buffer statement. + + Parameters + ---------- + postflattened : Buffer + The original buffer to be flattened. + + shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + The type of the buffer prior to flattening. + + dtype : str + The data type in the content of the buffer. + + data : Var + The pointer to the head of the data. + + strides : List[PrimExpr] + The strides of each dimension. + + elem_offset : PrimExpr + The offset in terms of number of dtype elements (including lanes). + + scope : str + The optional storage scope of buffer data pointer. + + align : int + The alignment requirement of data pointer in bytes. + + offset_factor : int + The factor of elem_offset field. + + buffer_type : str + The buffer type. + + axis_separators : List[int] + The separators between input axes when generating flattened output axes. + """ + shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + if strides is None: + strides = [] + _ffi_api.PreflattenedBuffer( # pylint: disable=no-member # type: ignore + postflattened, + shape, + dtype, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + ) + + def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame: """The block declaration statement. @@ -65,11 +357,344 @@ def evaluate(value: PrimExpr) -> None: return _ffi_api.Evaluate(value) # pylint: disable=no-member # type: ignore +def int8(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type int8 or cast expression to type int8. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type int8 or casted expression with type int8. + """ + return _ffi_api.Int8(expr) # pylint: disable=no-member # type: ignore + + +def int16(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type int16 or cast expression to type int16. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type int16 or casted expression with type int16. + """ + return _ffi_api.Int16(expr) # pylint: disable=no-member # type: ignore + + +def int32(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type int32 or cast expression to type int32. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type int32 or casted expression with type int32. + """ + return _ffi_api.Int32(expr) # pylint: disable=no-member # type: ignore + + +def int64(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type int64 or cast expression to type int64. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type int64 or casted expression with type int64. + """ + return _ffi_api.Int64(expr) # pylint: disable=no-member # type: ignore + + +def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type uint8 or cast expression to type uint8. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type uint8 or casted expression with type uint8. + """ + return _ffi_api.UInt8(expr) # pylint: disable=no-member # type: ignore + + +def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type uint16 or cast expression to type uint16. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type uint16 or casted expression with type uint16. + """ + return _ffi_api.UInt16(expr) # pylint: disable=no-member # type: ignore + + +def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type uint32 or cast expression to type uint32. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type uint32 or casted expression with type uint32. + """ + return _ffi_api.UInt32(expr) # pylint: disable=no-member # type: ignore + + +def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type uint64 or cast expression to type uint64. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type uint64 or casted expression with type uint64. + """ + return _ffi_api.UInt64(expr) # pylint: disable=no-member # type: ignore + + +def float8(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type float8 or cast expression to type float8. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type float8 or casted expression with type float8. + """ + return _ffi_api.Float8(expr) # pylint: disable=no-member # type: ignore + + +def float16(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type float16 or cast expression to type float16. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type float16 or casted expression with type float16. + """ + return _ffi_api.Float16(expr) # pylint: disable=no-member # type: ignore + + +def float32(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type float32 or cast expression to type float32. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type float32 or casted expression with type float32. + """ + return _ffi_api.Float32(expr) # pylint: disable=no-member # type: ignore + + +def float64(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type float64 or cast expression to type float64. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type float64 or casted expression with type float64. + """ + return _ffi_api.Float64(expr) # pylint: disable=no-member # type: ignore + + +def int32x4(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type int32x4 or cast expression to type int32x4. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type int32x4 or casted expression with type int32x4. + """ + return _ffi_api.Int32x4(expr) # pylint: disable=no-member # type: ignore + + +def int32x8(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type int32x8 or cast expression to type int32x8. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type int32x8 or casted expression with type int32x8. + """ + return _ffi_api.Int32x8(expr) # pylint: disable=no-member # type: ignore + + +def int32x16(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type int32x16 or cast expression to type int32x16. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type int32x16 or casted expression with type int32x16. + """ + return _ffi_api.Int32x16(expr) # pylint: disable=no-member # type: ignore + + +def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type boolean or cast expression to type boolean. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type boolean or casted expression with type boolean. + """ + return _ffi_api.Boolean(expr) # pylint: disable=no-member # type: ignore + + +def handle(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type handle or cast expression to type handle. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type handle or casted expression with type handle. + """ + return _ffi_api.Handle(expr) # pylint: disable=no-member # type: ignore + + +def void(expr: Optional[PrimExpr] = None) -> PrimExpr: + """Construct a new tir.Var with type void or cast expression to type void. + + Parameters + ---------- + expr: PrimExpr + The expression to be cast. + + Returns + ------- + res : PrimExpr + The new tir.Var with type void or casted expression with type void. + """ + return _ffi_api.Void(expr) # pylint: disable=no-member # type: ignore + + +def var(dtype, name="") -> Var: + """Construct a new tir.Var. + + Parameters + ---------- + dtype: str + The dtype of the Var. + + name: str + The name of the Var. + + Returns + ------- + res : Var + The result tir.Var. + """ + return Var(name, dtype) # pylint: disable=no-member # type: ignore + + # pylint: enable=invalid-name __all__ = [ + "buffer_decl", + "prim_func", + "arg", + "func_name", + "func_attr", + "func_ret", + "match_buffer", + "preflattened_buffer", "block", "evaluate", - "prim_func", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float8", + "float16", + "float32", + "float64", + "int32x4", + "int32x8", + "int32x16", + "boolean", + "handle", + "void", + "var", ] diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 4c2679ae6b56..e2c1218a7e87 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -28,6 +28,30 @@ namespace tir { using tvm::tir::IterVar; +Buffer BufferDecl(Array shape, DataType dtype, String buffer_name, Optional data, + Optional> strides, Optional elem_offset, + String storage_scope, int align, int offset_factor, String buffer_type, + Optional> axis_separators) { + Var buffer_data; + if (!data.defined()) { + DataType storage_dtype = dtype; + if (storage_dtype == DataType::Bool()) { + storage_dtype = DataType::Int(8); + } + buffer_data = tvm::tir::Var(buffer_name, PointerType(PrimType(storage_dtype), storage_scope)); + } else { + buffer_data = data.value(); + } + if (!elem_offset.defined() && offset_factor) { + DataType shape_dtype = shape.empty() ? DataType::Int(32) : shape[0]->dtype; + elem_offset = tvm::tir::Var("elem_offset", shape_dtype); + } + return Buffer(buffer_data, dtype, shape, strides.value_or(Array()), + elem_offset.value_or(PrimExpr()), buffer_name, align, offset_factor, + (buffer_type == "auto_broadcast") ? tvm::tir::kAutoBroadcast : tvm::tir::kDefault, + axis_separators.value_or(Array())); +} + PrimFuncFrame PrimFunc() { ObjectPtr n = make_object(); n->name = NullOpt; @@ -41,6 +65,98 @@ PrimFuncFrame PrimFunc() { return PrimFuncFrame(n); } +Var Arg(String name, Var var) { + PrimFuncFrame frame = FindPrimFuncFrame("T.Arg"); + details::Namer::Name(var, name); + frame->args.push_back(var); + return var; +} + +Buffer Arg(String name, Buffer buffer) { + PrimFuncFrame frame = FindPrimFuncFrame("T.Arg"); + details::Namer::Name(buffer, name); + Var handle(buffer->name + "_handle", DataType::Handle()); + frame->args.push_back(handle); + frame->buffer_map.Set(handle, buffer); + return buffer; +} + +void FuncName(String name) { + PrimFuncFrame frame = FindPrimFuncFrame("T.func_name"); + if (frame->name.defined()) { + LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is " << frame->name.value(); + } + frame->name = name; +} + +void FuncAttrs(Map attrs) { + using namespace tvm::tir; + PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr"); + if (frame->attrs.defined()) { + LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is " << frame->attrs; + } + frame->attrs = attrs; +} + +tvm::Type FuncRet(tvm::Type ret_type) { + PrimFuncFrame frame = FindPrimFuncFrame("T.ret_type"); + if (frame->ret_type.defined()) { + LOG(FATAL) << "ValueError: Duplicate prim func return type, previous one is " + << frame->ret_type.value(); + } + frame->ret_type = ret_type; + return ret_type; +} + +Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype, Optional data, + Array strides, PrimExpr elem_offset, String storage_scope, int align, + int offset_factor, String buffer_type_str, Array axis_separators) { + Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align, + offset_factor, buffer_type_str, axis_separators); + if (const auto* var = param.as()) { + PrimFuncFrame frame = FindPrimFuncFrame("T.match_buffer"); + Var v = GetRef(var); + for (auto const& arg : frame->args) { + if (arg.same_as(v)) { + frame->buffer_map.Set(v, buffer); + return buffer; + } + } + LOG(FATAL) << "ValueError: Can not bind non-input param to buffer."; + } else if (const auto* buffer_load = param.as()) { + BlockFrame frame = FindBlockFrame("T.match_buffer"); + frame->match_buffers.push_back(tvm::tir::MatchBufferRegion( + buffer, BufferRegionFromLoad(GetRef(buffer_load)))); + } else if (const auto* buffer_region = param.as()) { + BlockFrame frame = FindBlockFrame("T.match_buffer"); + frame->match_buffers.push_back( + tvm::tir::MatchBufferRegion(buffer, GetRef(buffer_region))); + } else { + LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer."; + } + return buffer; +} + +void PreflattenedBuffer(Buffer postflattened_buffer, Array shape, DataType dtype, + Optional data, Array strides, PrimExpr elem_offset, + String storage_scope, int align, int offset_factor, String buffer_type_str, + Array axis_separators) { + PrimFuncFrame frame = FindPrimFuncFrame("T.preflattened_buffer"); + for (auto const& p : frame->buffer_map) { + if (p.second.same_as(postflattened_buffer)) { + String buffer_name(postflattened_buffer->name + "_preflatten"); + Buffer buffer = + BufferDecl(shape, dtype, buffer_name, data.value_or(p.second->data), strides, elem_offset, + storage_scope, align, offset_factor, buffer_type_str, axis_separators); + details::Namer::Name(buffer, buffer_name); + frame->preflattened_buffer_map.Set(p.first, buffer); + return; + } + } + LOG(FATAL) << "ValueError: postflattened buffer " << postflattened_buffer->name + << " does not exist."; +} + BlockFrame Block(String name, bool no_realize) { ObjectPtr n = make_object(); n->name = name; @@ -58,9 +174,87 @@ BlockFrame Block(String name, bool no_realize) { } void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); } + +using tvm::script::ir_builder::details::Namer; + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + tvm::tir::BufferNode* buffer = + const_cast(node.as()); + buffer->name = name; + Namer::Name(buffer->data, name); + int n = buffer->strides.size(); + for (int i = 0; i < n; ++i) { + PrimExpr e = buffer->strides[i]; + if (const tvm::tir::VarNode* v = e.as()) { + Namer::Name(GetRef(v), name + "_s" + std::to_string(i)); + } + } + }); + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using namespace tvm::tir; + SizeVarNode* var = const_cast(node.as()); + var->name_hint = name; + }); + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using namespace tvm::tir; + VarNode* var = const_cast(node.as()); + var->name_hint = name; + }); + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using namespace tvm::tir; + IterVarNode* var = const_cast(node.as()); + Namer::Name(var->var, name); + }); + +TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferDecl").set_body_typed(BufferDecl); + TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Arg") + .set_body_typed([](String name, ObjectRef obj) -> ObjectRef { + using namespace tvm::tir; + if (const auto* var = obj.as()) { + return Arg(name, GetRef(var)); + } + if (const auto* buffer = obj.as()) { + return Arg(name, GetRef(buffer)); + } + LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey(); + throw; + }); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncName").set_body_typed(FuncName); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncRet").set_body_typed(FuncRet); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer); + TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); + +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int16").set_body_typed(Int16); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32").set_body_typed(Int32); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int64").set_body_typed(Int64); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt8").set_body_typed(UInt8); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt16").set_body_typed(UInt16); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt32").set_body_typed(UInt32); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt64").set_body_typed(UInt64); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8").set_body_typed(Float8); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float16").set_body_typed(Float16); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float32").set_body_typed(Float32); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float64").set_body_typed(Float64); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x4").set_body_typed(Int32x4); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x8").set_body_typed(Int32x8); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x16").set_body_typed(Int32x16); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void); } // namespace tir } // namespace ir_builder } // namespace script diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index 4f8b3f77c6e1..c29fae1c65e9 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -28,6 +28,10 @@ namespace script { namespace ir_builder { namespace tir { +/*! + * \brief Add tir Stmt to the top frame in IRBuilder frame stack. + * \param stmt The Stmt. + */ inline void AddToParent(tvm::tir::Stmt stmt) { IRBuilder builder = IRBuilder::Current(); if (builder->frames.empty()) { @@ -40,6 +44,11 @@ inline void AddToParent(tvm::tir::Stmt stmt) { } } +/*! + * \brief Convert array of tir Stmt to single Stmt. + * \param stmt The array of Stmt. + * \return The SeqStmt. + */ inline tvm::tir::Stmt AsStmt(const Array& stmt) { using namespace tvm::tir; if (stmt.empty()) { @@ -51,6 +60,11 @@ inline tvm::tir::Stmt AsStmt(const Array& stmt) { } } +/*! + * \brief Check whether the top frame in IRBuilder frame stack is PrimFuncFrame. + * \param method The method name to be printed when throwing exception. + * \return The top frame of PrimFuncFrame. + */ inline PrimFuncFrame FindPrimFuncFrame(const String& method) { if (Optional frame = IRBuilder::Current()->GetLastFrame()) { return frame.value(); @@ -60,6 +74,11 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) { throw; } +/*! + * \brief Check whether the top frame in IRBuilder frame stack is BlockFrame. + * \param method The method name to be printed when throwing exception. + * \return The top frame of BlockFrame. + */ inline BlockFrame FindBlockFrame(const String& method) { if (Optional frame = IRBuilder::Current()->GetLastFrame()) { return frame.value(); @@ -69,6 +88,19 @@ inline BlockFrame FindBlockFrame(const String& method) { throw; } +/*! + * \brief Convert BufferLoad to BufferRegion. + * \param buffer_load The BufferLoad. + * \return The converted BufferRegion. + */ +inline tvm::tir::BufferRegion BufferRegionFromLoad(tvm::tir::BufferLoad buffer_load) { + Array ranges; + for (const PrimExpr& index : buffer_load->indices) { + ranges.push_back(Range::FromMinExtent(index, IntImm(index->dtype, 1))); + } + return tvm::tir::BufferRegion(buffer_load->buffer, ranges); +} + } // namespace tir } // namespace ir_builder } // namespace script diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index 85080c7c65fc..5c93e99909d9 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -25,7 +25,7 @@ from tvm.ir.base import assert_structural_equal -def test_ir_builder_tir_primfunc(): +def test_ir_builder_tir_primfunc_base(): with IRBuilder() as ib: with T.prim_func(): T.evaluate(0) @@ -45,6 +45,48 @@ def test_ir_builder_tir_primfunc(): assert_structural_equal(prim_func_actual, prim_func_expected, map_free_vars=True) +def test_ir_builder_tir_primfunc_complete(): + with IRBuilder() as ib: + with T.prim_func(): + T.arg("a", T.handle()) + T.arg("b", T.var("int64")) + T.arg("c", T.buffer_decl((128, 128), "float32")) + d = T.arg("d", T.handle()) + e = T.arg("e", T.buffer_decl((1024,), "int8")) + T.func_attr({"key": "value"}) + T.func_ret(tvm.ir.PrimType("int64")) + buffer_d = T.match_buffer(d, (64, 64), "int64") + T.preflattened_buffer(e, (32, 32), "int8", data=e.data) + T.evaluate(0) + # the prim_func generated by IRBuilder + prim_func_actual = ib.get() + + # the expected prim_func + c_handle, c_buffer = tir.Var("c_handle", "handle"), tir.decl_buffer( + (128, 128), "float32", name="c" + ) + d_handle, d_buffer = tir.Var("d", "handle"), tir.decl_buffer((64, 64), "int64", name="d") + e_handle, e_buffer = tir.Var("e_handle", "handle"), tir.decl_buffer((1024,), "int8", name="e") + prim_func_expected = tir.PrimFunc( + params=[ + tir.Var("a", "handle"), + tir.Var("b", "int64"), + c_handle, + d_handle, + e_handle, + ], + body=tir.Evaluate(0), + ret_type=tvm.ir.PrimType("int64"), + buffer_map={c_handle: c_buffer, d_handle: d_buffer, e_handle: e_buffer}, + preflattened_buffer_map={ + e_handle: tir.decl_buffer((32, 32), "int8", name="e_preflatten", data=e_buffer.data) + }, + attrs=tvm.ir.make_node("DictAttrs", key="value"), + ) + # Check if the generated ir is expected + assert_structural_equal(prim_func_actual, prim_func_expected, map_free_vars=True) + + def test_ir_builder_tir_block(): with IRBuilder() as ib: with T.block("block"):