Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[TVMScript] IRBuilder methods for Block (apache#12815)
Browse files Browse the repository at this point in the history
This PR introduces remaining IRBuilder methods for `Block`.

Co-authored-by: yongwww <yongcale@gmail.com>
  • Loading branch information
2 people authored and xinetzone committed Nov 25, 2022
1 parent 6e3e643 commit 856bdcd
Show file tree
Hide file tree
Showing 10 changed files with 442 additions and 52 deletions.
35 changes: 35 additions & 0 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,41 @@ class BlockFrame : public TIRFrame {
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode);
};

/*!
* \brief A frame that represents the block initialization statment.
*
* \sa BlockInitFrame
*/
class BlockInitFrameNode : public TIRFrameNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); }

static constexpr const char* _type_key = "script.ir_builder.tir.BlockInitFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(BlockInitFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when entering RAII scope.
* \sa tvm::support::With
*/
void EnterWithScope() final;
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to BlockInitFrameNode.
*
* \sa BlockInitFrameNode
*/
class BlockInitFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame, BlockInitFrameNode);
};

/*!
* \brief A frame that represents the for loop.
*
Expand Down
49 changes: 49 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,55 @@ void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
*/
BlockFrame Block(String name, bool no_realize = false);

/*!
* \brief The block initialization statement.
* \return The BlockInitFrame.
*/
BlockInitFrame Init();

/*!
* \brief The block predicate statement.
* \param predicate The predicate condition.
*/
void Where(PrimExpr predicate);

/*!
* \brief The block buffer region reading statement.
* \param buffer_slices The array of buffer regions to read.
*/
void Reads(Array<ObjectRef> buffer_slices);

/*!
* \brief The block buffer region writing statement.
* \param buffer_slices The array of buffer regions to write.
*/
void Writes(Array<ObjectRef> buffer_slices);

/*!
* \brief The block annotation statement.
* \param attrs The annotation of the block.
*/
void BlockAttrs(Map<String, ObjectRef> attrs);

/*!
* \brief The buffer allocation function.
* \param shape The type of the buffer prior to flattening.
* \param dtype The data type in the content of the buffer.
* \param data The pointer to the head of the data.
* \param strides The strides of each dimension.
* \param elem_offset The offset in terms of number of dtype elements (including lanes).
* \param storage_scope The optional storage scope of buffer data pointer.
* \param align The alignment requirement of data pointer in bytes.
* \param offset_factor The factor of elem_offset field.
* \param buffer_type The buffer type.
* \param axis_separators The separators between input axes when generating flattened output axes.
* \return The allocated buffer.
*/
Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1,
int offset_factor = 0, String buffer_type = "default",
Array<IntImm> axis_separators = {});
namespace axis {

/*!
Expand Down
18 changes: 9 additions & 9 deletions python/tvm/script/ir_builder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ class IRBuilderFrame(_Object):
"""

def __enter__(self) -> "IRBuilderFrame":
_ffi_api.IRBuilderFrameEnter(self) # pylint: disable=no-member # type: ignore
_ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member
return self

def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
_ffi_api.IRBuilderFrameExit(self) # pylint: disable=no-member # type: ignore
_ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member

def add_callback(self, callback: Callable[[], None]) -> None:
"""Add a callback method invoked when exiting the with-scope.
Expand All @@ -75,7 +75,7 @@ def add_callback(self, callback: Callable[[], None]) -> None:
callback : Callable[[], None]
The callback method to be invoked.
"""
_ffi_api.IRBuilderFrameAddCallback( # pylint: disable=no-member # type: ignore
_ffi_api.IRBuilderFrameAddCallback( # type: ignore[attr-defined] # pylint: disable=no-member
self, callback
)

Expand Down Expand Up @@ -104,7 +104,7 @@ class IRBuilder(_Object):
def __init__(self) -> None:
"""Construct an IRBuilder."""
self.__init_handle_by_constructor__(
_ffi_api.IRBuilder # pylint: disable=no-member # type: ignore
_ffi_api.IRBuilder # type: ignore[attr-defined] # pylint: disable=no-member
)

def __enter__(self) -> "IRBuilder":
Expand All @@ -119,11 +119,11 @@ def __enter__(self) -> "IRBuilder":
with IRBuilder() as builder:
assert IRBuilder.current() == builder
"""
_ffi_api.IRBuilderEnter(self) # pylint: disable=no-member # type: ignore
_ffi_api.IRBuilderEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member
return self

def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
_ffi_api.IRBuilderExit(self) # pylint: disable=no-member # type: ignore
_ffi_api.IRBuilderExit(self) # type: ignore[attr-defined] # pylint: disable=no-member

@staticmethod
def current() -> "IRBuilder":
Expand All @@ -134,11 +134,11 @@ def current() -> "IRBuilder":
builder : IRBuilder
The current IRBuilder.
"""
return _ffi_api.IRBuilderCurrent() # pylint: disable=no-member # type: ignore
return _ffi_api.IRBuilderCurrent() # type: ignore[attr-defined] # pylint: disable=no-member

def get(self) -> _Object:
"""Get the constructed IR."""
return _ffi_api.IRBuilderGet(self) # pylint: disable=no-member # type: ignore
return _ffi_api.IRBuilderGet(self) # type: ignore[attr-defined] # pylint: disable=no-member

@staticmethod
def name(s: str, v: Any) -> Any:
Expand All @@ -156,7 +156,7 @@ def name(s: str, v: Any) -> Any:
v : Any
The same object with the name set.
"""
return _ffi_api.IRBuilderName(s, v) # pylint: disable=no-member # type: ignore
return _ffi_api.IRBuilderName(s, v) # type: ignore[attr-defined] # pylint: disable=no-member

@staticmethod
def name_many( # pylint: disable=invalid-name
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/ir_builder/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@


def ir_module() -> IRModuleFrame:
return _ffi_api.IRModule() # pylint: disable=no-member # type: ignore
return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint: disable=no-member
7 changes: 6 additions & 1 deletion python/tvm/script/ir_builder/tir/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,13 @@ class BlockFrame(TIRFrame):
...


@_register_object("script.ir_builder.tir.BlockInitFrame")
class BlockInitFrame(TIRFrame):
...


@_register_object("script.ir_builder.tir.ForFrame")
class ForFrame(TIRFrame):
def __enter__(self) -> Union[Var, List[Var]]:
def __enter__(self) -> Union[Var, List[Var]]: # type: ignore[override]
super().__enter__()
return self.vars if len(self.vars) > 1 else self.vars[0]
Loading

0 comments on commit 856bdcd

Please sign in to comment.