Skip to content

Commit

Permalink
[TVMScript] IRBuilder methods for PrimFunc (#12755)
Browse files Browse the repository at this point in the history
This PR introduces remaining IRBuilder methods for `PrimFunc`.

Co-authored-by: yongwww <yongcale@gmail.com>
  • Loading branch information
cyx-6 and yongwww authored Sep 14, 2022
1 parent a0cbefb commit 3d7439e
Show file tree
Hide file tree
Showing 5 changed files with 1,022 additions and 3 deletions.
126 changes: 126 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,111 @@ namespace script {
namespace ir_builder {
namespace tir {

using tvm::tir::Buffer;
using tvm::tir::Var;

/*!
* \brief The buffer declaration function.
* \param shape The type of the buffer prior to flattening.
* \param dtype The data type in the content of the buffer.
* \param buffer_name The name of the buffer.
* \param data The pointer to the head of the data.
* \param strides The strides of each dimension.
* \param elem_offset The offset in terms of number of dtype elements (including lanes).
* \param storage_scope The optional storage scope of buffer data pointer.
* \param align The alignment requirement of data pointer in bytes.
* \param offset_factor The factor of elem_offset field.
* \param buffer_type The buffer type.
* \param axis_separators The separators between input axes when generating flattened output axes.
* \return The declared buffer.
*/
Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Optional<Var> data,
Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
String storage_scope, int align, int offset_factor, String buffer_type,
Optional<Array<IntImm>> axis_separators);

/*!
* \brief The primitive function statement.
* \return The PrimFuncFrame.
*/
PrimFuncFrame PrimFunc();

/*!
* \brief The PrimFunc variable arguments adding function.
* \param name The name of the variable.
* \param var The variable argument.
* \return The variable.
*/
Var Arg(String name, Var var);

/*!
* \brief The PrimFunc buffer arguments adding function.
* \param name The name of the buffer.
* \param buffer The buffer argument.
* \return The buffer.
*/
Buffer Arg(String name, Buffer buffer);

/*!
* \brief The PrimFunc naming statement.
* \param name The name of the PrimFunc.
*/
void FuncName(String name);

/*!
* \brief The PrimFunc annotation statement.
* \param attrs The annotations of the PrimFunc.
*/
void FuncAttrs(Map<String, ObjectRef> attrs);

/*!
* \brief The PrimFunc return type statement.
* \param ret_type The return type of the PrimFunc.
* \return The return type.
*/
Type FuncRet(Type ret_type);

/*!
* \brief The buffer match statement.
* \param param The parameter of the PrimFunc to match.
* \param shape The type of the buffer prior to flattening.
* \param dtype The data type in the content of the buffer.
* \param data The pointer to the head of the data.
* \param strides The strides of each dimension.
* \param elem_offset The offset in terms of number of dtype elements (including lanes).
* \param storage_scope The optional storage scope of buffer data pointer.
* \param align The alignment requirement of data pointer in bytes.
* \param offset_factor The factor of elem_offset field.
* \param buffer_type The buffer type.
* \param axis_separators The separators between input axes when generating flattened output axes.
* \return The matched buffer.
*/
Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
PrimExpr elem_offset = PrimExpr(), String storage_scope = "global",
int align = -1, int offset_factor = 0, String buffer_type = "default",
Array<IntImm> axis_separators = {});

/*!
* \brief The pre-flattened buffer statement.
* \param postflattened_buffer The original buffer to be flattened.
* \param shape The type of the buffer prior to flattening.
* \param dtype The data type in the content of the buffer.
* \param data The pointer to the head of the data.
* \param strides The strides of each dimension.
* \param elem_offset The offset in terms of number of dtype elements (including lanes).
* \param storage_scope The optional storage scope of buffer data pointer.
* \param align The alignment requirement of data pointer in bytes.
* \param offset_factor The factor of elem_offset field.
* \param buffer_type The buffer type.
* \param axis_separators The separators between input axes when generating flattened output axes.
*/
void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
DataType dtype = DataType::Float(32), Optional<Var> data = NullOpt,
Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
String storage_scope = "global", int align = -1, int offset_factor = 0,
String buffer_type = "default", Array<IntImm> axis_separators = {});

/*!
* \brief The block declaration statement.
* \param name The name of the block.
Expand All @@ -48,6 +147,33 @@ BlockFrame Block(String name, bool no_realize = false);
*/
void Evaluate(PrimExpr value);

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) { \
DataType dtype = DType; \
return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \
}

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x4, DataType::Int(32, 4));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x8, DataType::Int(32, 8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x16, DataType::Int(32, 16));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());

#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST

} // namespace tir
} // namespace ir_builder
} // namespace script
Expand Down
Loading

0 comments on commit 3d7439e

Please sign in to comment.