Skip to content

Commit

Permalink
[SparseTIR][Schedule] SparseBlockRV, GetSparseBlock, SparseReorder (a…
Browse files Browse the repository at this point in the history
…pache#23)

* Test initialization

* Fix a stupid bug of ReprPrinter

* Add SparseBlockRV

* Schedule: GetSparseBlock

* Schedule: Reorder
  • Loading branch information
MasterJH5574 committed Dec 22, 2021
1 parent 4f55ec3 commit 0f7f083
Show file tree
Hide file tree
Showing 14 changed files with 1,082 additions and 174 deletions.
50 changes: 50 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <tvm/support/random_engine.h>
#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>
#include <tvm/tir/sparse.h>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -85,6 +86,27 @@ using ExprRV = PrimExpr;

using ExprRVNode = PrimExprNode;

/**************** Random variable: SparseBlockRV ****************/

/*! \brief A random variable that evaluates to a TensorIR sparse block */
class SparseBlockRVNode : public runtime::Object {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "tir.SparseBlockRV";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockRVNode, runtime::Object);
};

/*!
* \brief Managed reference to SparseBlockRVNode
* \sa SparseBlockRVNode
*/
class SparseBlockRV : public runtime::ObjectRef {
public:
/*! \brief Constructor */
TVM_DLL SparseBlockRV();
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SparseBlockRV, runtime::ObjectRef, SparseBlockRVNode);
};

/**************** The Schedule class ****************/

class Schedule;
Expand Down Expand Up @@ -143,6 +165,12 @@ class ScheduleNode : public runtime::Object {
* \return The corresponding expr
*/
virtual PrimExpr Get(const ExprRV& expr_rv) const = 0;
/*!
* \brief Get the sparse block corresponding to the specific random variable
* \param sp_block_rv The random variable to be looked up
* \return SparseBlock The corresponding sparse block
*/
virtual SparseBlock Get(const SparseBlockRV& sp_block_rv) const = 0;
/*!
* \brief Get the block sref corresponding to the specific BlockRV
* \param block_rv The BlockRV to be looked up
Expand Down Expand Up @@ -188,6 +216,11 @@ class ScheduleNode : public runtime::Object {
* \param expr_rv The random variable to be removed
*/
virtual void RemoveRV(const ExprRV& expr_rv) = 0;
/*!
* \brief Remove an sparse block random variable from the symbol table
* \param sp_block_rv The random variable to be removed
*/
virtual void RemoveRV(const SparseBlockRV& sp_block_rv) = 0;

public:
/******** Schedule: Sampling ********/
Expand Down Expand Up @@ -495,6 +528,23 @@ class ScheduleNode : public runtime::Object {
/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
/******** Schedule: SparseTIR schedules ********/
/*!
* \brief Retrieve a sparse block in a specific function with its name
* \param name The name of the sparse block to be retrieved
* \param func_name The name of the function
* \return The sparse block retrieved
* \note Indexing error is raised if 0 or multiple blocks exist with the specific name
*/
virtual SparseBlockRV GetSparseBlock(const String& name, const String& func_name = "main") = 0;
/*!
* \brief Reorder a list of sparse iterators. It requires the new order to not break the iterator
* dependency.
* \param block The block to be transformed
* \param new_order The new order of the sparse iterators, whose length should equal to the number
* of the input block's sparse iterators
*/
virtual void SparseReorder(const SparseBlockRV& block_rv, const Array<SpIterVar>& new_order) = 0;
};

/*!
Expand Down
74 changes: 67 additions & 7 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object, String
from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc
from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc, SparseBlock
from tvm.tir.sparse import SpIterVar

from . import _ffi_api
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
Expand Down Expand Up @@ -56,12 +57,23 @@ def __init__(self) -> None:
)


@_register_object("tir.SparseBlockRV")
class SparseBlockRV(Object):
"""A random variable that refers to a sparse block"""

def __init__(self) -> None:
"""Construct a new SparseBlockRV."""
self.__init_handle_by_constructor__(
_ffi_api.SparseBlockRV # type: ignore # pylint: disable=no-member
)


# It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370
# This feature is not supported until python 3.10:
# https://docs.python.org/3.10/whatsnew/3.10.html#pep-613-typealias
ExprRV = Union[PrimExpr] # A random variable that evaluates to an integer

RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # pylint: disable=invalid-name
RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV, SparseBlockRV] # pylint: disable=invalid-name

# Update to `Literal["detail", "fast", "none"]` once upgraded to python3.8
_ERROR_RENDER_LEVEL: Dict[str, int] = {
Expand Down Expand Up @@ -227,7 +239,7 @@ def show(self, rand_var: RAND_VAR_TYPE) -> str:
Parameters
----------
rand_var : Union[ExprRV, BlockRV, LoopRV]
rand_var : Union[ExprRV, BlockRV, LoopRV, SparseBlockRV]
The random variable to be evaluated
Returns
Expand All @@ -243,22 +255,23 @@ def show(self, rand_var: RAND_VAR_TYPE) -> str:
def get(
self,
rand_var_or_sref: Union[RAND_VAR_TYPE, StmtSRef],
) -> Optional[Union[int, Block, For]]:
) -> Optional[Union[int, Block, For, SparseBlock]]:
"""Returns:
- the corresponding Block that a BlockRV evaluates to;
- the corresponding For that a LoopRV evaluates to;
- the corresponding integer that a ExprRV evaluates to;
- the corresponding SparseBlock that a SparseBlockRV evaluates to;
- the corresponding Block that a block sref points to;
- the corresponding For that a loop sref points to;
Parameters
----------
rand_var_or_sref : Union[ExprRV, BlockRV, LoopRV, StmtSRef]
rand_var_or_sref : Union[ExprRV, BlockRV, LoopRV, SparseBlockRV, StmtSRef]
The random variable / sref to be evaluated
Returns
-------
result : Optional[Union[int, Block, For]]
result : Optional[Union[int, Block, For, SparseBlock]]
The corresponding result
"""
if isinstance(rand_var_or_sref, StmtSRef):
Expand Down Expand Up @@ -296,7 +309,7 @@ def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None:
Parameters
----------
rand_var : Union[BlockRV, LoopRV, ExprRV]
rand_var : Union[BlockRV, LoopRV, ExprRV, SparseBlockRV]
The random variable to be removed
"""
return _ffi_api.ScheduleRemoveRV(self, rand_var) # type: ignore # pylint: disable=no-member
Expand Down Expand Up @@ -1862,3 +1875,50 @@ def after_unannotate(a: T.handle, b: T.handle) -> None:
def enter_postproc(self) -> None:
"""A no-op that marks the start of postprocessing phase of scheduling"""
_ffi_api.ScheduleEnterPostproc(self) # type: ignore # pylint: disable=no-member

########## Schedule: SparseTIR schedules ##########

def get_sparse_block(
self,
name: str,
func_name: str = "main",
) -> SparseBlock:
"""Retrieve a sparse block in a specific function with its name
Parameters
----------
name : str
The name of the sparse block
func_name : str = "main"
The name of the function
Returns
-------
block : SparseBlockRV
The sparse block retrieved
IndexError is raised if 0 or multiple blocks exist with the specific name.
"""
return _ffi_api.ScheduleGetSparseBlock( # type: ignore # pylint: disable=no-member
self,
name,
func_name,
)

def sparse_reorder(self, block: SparseBlockRV, new_order: List[SpIterVar]) -> None:
"""Reorder a list of sparse iterators. It requires the new order to not break the iterator
dependency.
Parameters
----------
block : SparseBlockRV
The queried sparse block
new_order : List[SpIterVar]
The The new order of the sparse iterators, whose length should equal to the number
of the input block's sparse iterators
"""
return _ffi_api.ScheduleSparseReorder( # type: ignore # pylint: disable=no-member
self,
block,
new_order,
)
2 changes: 1 addition & 1 deletion src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ TVM_REGISTER_NODE_TYPE(SparseBufferStoreNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SparseBufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferStoreNode*>(node.get());
auto* op = static_cast<const SparseBufferStoreNode*>(node.get());
p->PrintIndent();
p->stream << op->buffer->name << "[";
for (size_t i = 0; i < op->indices.size(); ++i) {
Expand Down
Loading

0 comments on commit 0f7f083

Please sign in to comment.