Skip to content

Commit

Permalink
[Meta Schedule][M3a] Traced Schedule (#8623)
Browse files Browse the repository at this point in the history
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
  • Loading branch information
6 people authored Aug 9, 2021
1 parent 00bed97 commit 3145867
Show file tree
Hide file tree
Showing 21 changed files with 625 additions and 295 deletions.
23 changes: 20 additions & 3 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define TVM_TIR_SCHEDULE_SCHEDULE_H_

#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -95,13 +96,15 @@ class ScheduleNode : public runtime::Object {
virtual ~ScheduleNode() = default;

static constexpr const char* _type_key = "tir.Schedule";
TVM_DECLARE_BASE_OBJECT_INFO(ScheduleNode, runtime::Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, runtime::Object);

public:
/*! \brief Get the IRModule associated with this schedule. */
virtual IRModule mod() const { return state()->mod; }
/*! \return The internal state of scheduling */
virtual ScheduleState state() const = 0;
/*! \return The internally maintained trace of scheduling program execution */
virtual Optional<Trace> trace() const = 0;
/*!
* \brief Returns a copy of the schedule, including both its state and its symbol table,
* guaranteeing that
Expand Down Expand Up @@ -288,7 +291,7 @@ class Schedule : public runtime::ObjectRef {
/*!
* \brief Construct a concrete TensorIR schedule from an IRModule
* \param mod The IRModule to be scheduled
* \param debug_mode Do extra correctness checking after the class creation
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
* \return The concrete schedule created
Expand All @@ -297,8 +300,22 @@ class Schedule : public runtime::ObjectRef {
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode,
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mask,
ScheduleErrorRenderLevel error_render_level);
/*!
* \brief Construct a traced concrete TensorIR schedule from an IRModule
* \param mod The IRModule to be scheduled
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
* \return The concrete schedule created
* \sa ScheduleDebugMask
* \note The checks performed include:
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Traced(IRModule mod, int debug_mask,
ScheduleErrorRenderLevel error_render_level);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
};

Expand Down
12 changes: 6 additions & 6 deletions include/tvm/tir/schedule/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ enum ScheduleDebugMask : uint32_t {
* 2) The sref tree of schedulable statements (indicated by the srefs)
* 3) The dependency information of each block scope (block_info)
* 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref)
* 5) A debug flag, if set, extra checking is enabled (debug_mode)
* 5) A debug flag, if set, extra checking is enabled (debug_mask)
*/
class ScheduleStateNode : public Object {
public:
Expand All @@ -99,13 +99,13 @@ class ScheduleStateNode : public Object {
* and each time after calling the Replace method.
* \sa ScheduleDebugMask
*/
int debug_mode;
int debug_mask;

void VisitAttrs(AttrVisitor* v) {
v->Visit("mod", &mod);
// `block_info` is not visited
// `stmt2ref` is not visited
v->Visit("debug_mode", &debug_mode);
v->Visit("debug_mask", &debug_mask);
}
/*!
* \brief Replace the part of the AST, as being pointed to by `src_sref`,
Expand All @@ -129,7 +129,7 @@ class ScheduleStateNode : public Object {
TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt,
const Map<Block, Block>& block_sref_reuse);
/*!
* \brief Trigger the verification according to the `debug_mode` bitmask.
* \brief Trigger the verification according to the `debug_mask` bitmask.
* 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree.
* 2) If the bitmask `kVerifyCachedFlags` is on, verify the correctness of `affine_binding`,
* `region_cover` and `stage_pipeline`
Expand Down Expand Up @@ -186,10 +186,10 @@ class ScheduleState : public ObjectRef {
/*!
* \brief Construct a schedule state from an IRModule
* \param mod The IRModule to be scheduled
* \param debug_mode Do extra correctness checking after the class creation
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
*/
TVM_DLL explicit ScheduleState(IRModule mod, int debug_mode = 0);
TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0);

/*! \return The mutable pointer to the ScheduleStateNode */
ScheduleStateNode* get() const { return static_cast<ScheduleStateNode*>(data_.get()); }
Expand Down
96 changes: 54 additions & 42 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""The TensorIR schedule class"""
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
Expand All @@ -25,7 +24,8 @@
from tvm.tir import Block, For, IntImm, PrimFunc

from . import _ffi_api
from .state import ScheduleState, StmtSRef
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
from .trace import Trace


@register_error
Expand Down Expand Up @@ -63,7 +63,20 @@ def __init__(self) -> None:
RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # pylint: disable=invalid-name

# Update to `Literal["detail", "fast", "none"]` once upgraded to python3.8
ERROR_RENDER_LEVEL_CANDIDATES = Union[str] # pylint: disable=invalid-name
_ERROR_RENDER_LEVEL: Dict[str, int] = {
"detail": 0,
"fast": 1,
"none": 2,
}


def _parse_error_render_level(error_render_level: str) -> int:
if error_render_level not in _ERROR_RENDER_LEVEL:
raise ValueError(
'error_render_level can be "detail", "fast", or "none", but got: '
+ f"{error_render_level}"
)
return _ERROR_RENDER_LEVEL.get(error_render_level)


@_register_object("tir.Schedule")
Expand All @@ -81,73 +94,77 @@ class Schedule(Object):
Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html
"""

ERROR_RENDER_LEVEL = {
"detail": 0,
"fast": 1,
"none": 2,
}

def __init__(
self,
mod: Union[PrimFunc, IRModule],
*,
debug_mode: Union[bool, int] = False,
error_render_level: ERROR_RENDER_LEVEL_CANDIDATES = "detail",
debug_mask: Union[str, int] = "none",
error_render_level: str = "detail",
) -> None:
"""Construct a concrete TensorIR schedule from an IRModule or a PrimFunc
"""Construct a TensorIR schedule class from an IRModule
Parameters
----------
mod : Union[PrimFunc, IRModule]
The IRModule or PrimFunc to be scheduled
debug_mode : Union[bool, int]
debug_mask : Union[str, int]
Do extra correctness checking after the class creation and each time
scheduling primitive
after calling the Replace method.
Possible choices of `debug_mask`:
1) "all" - Turn on all the checks
2) "none" - Turn off all the checks
3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
error_render_level : str = "detail"
The level of error rendering. Choices: "detail", "fast", "none".
"detail": Render a detailed error message, with the TIR and error locations printed
"fast: Show a simple error message without rendering or string manipulation
"none": Do not show any error message.
- "detail": Render a detailed error message, with the TIR and error locations printed
- "fast: Show a simple error message without rendering or string manipulation
- "none": Do not show any error message.
Note
----
The checks performed includes:
1) VerifySRefTree
2) VerifyCachedFlags
"""
if isinstance(mod, PrimFunc):
mod = IRModule({"main": mod})
if isinstance(debug_mode, bool):
if debug_mode:
debug_mode = -1
else:
debug_mode = 0
if not isinstance(debug_mode, int):
raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}")
if error_render_level not in Schedule.ERROR_RENDER_LEVEL:
raise ValueError(
'error_render_level can be "detail", "fast", or "none", but got: '
+ f"{error_render_level}"
)
# call the constructor
self.__init_handle_by_constructor__(
_ffi_api.ConcreteSchedule, # type: ignore # pylint: disable=no-member
mod,
debug_mode,
Schedule.ERROR_RENDER_LEVEL.get(error_render_level),
_ffi_api.TracedSchedule, # type: ignore # pylint: disable=no-member
_parse_mod(mod),
_parse_debug_mask(debug_mask),
_parse_error_render_level(error_render_level),
)

@staticmethod
def _create_non_traced(
mod: Union[PrimFunc, IRModule],
*,
debug_mask: Union[str, int] = "none",
error_render_level: str = "detail",
) -> "Schedule":
"""Construct a non-traced TensorIR schedule class from an IRModule."""
return _ffi_api.ConcreteSchedule( # type: ignore # pylint: disable=no-member
_parse_mod(mod),
_parse_debug_mask(debug_mask),
_parse_error_render_level(error_render_level),
)

########## Utilities ##########

@property
def mod(self) -> IRModule:
"""Returns the AST of the module being scheduled"""
return _ffi_api.ScheduleModule(self) # type: ignore # pylint: disable=no-member
return _ffi_api.ScheduleGetMod(self) # type: ignore # pylint: disable=no-member

@property
def state(self) -> ScheduleState:
"""Returns the ScheduleState in the current schedule class"""
return _ffi_api.ScheduleGetState(self) # type: ignore # pylint: disable=no-member

@property
def trace(self) -> Optional[Trace]:
"""Returns the internally maintained trace of scheduling program execution"""
return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint: disable=no-member

def copy(self) -> "Schedule":
"""Returns a copy of the schedule, including both the state and the symbol table,
* guaranteeing that
Expand Down Expand Up @@ -702,8 +719,3 @@ def after_rfactor(a: ty.handle, b: ty.handle) -> None:
def enter_postproc(self) -> None:
"""A no-op that marks the start of postprocessing phase of scheduling"""
_ffi_api.ScheduleEnterPostproc(self) # type: ignore # pylint: disable=no-member


@_register_object("tir.ConcreteSchedule")
class ConcreteSchedule(Schedule):
"""A concrete schedule class of TensorIR. Do not use directly, use tvm.tir.Schedule instead."""
57 changes: 35 additions & 22 deletions python/tvm/tir/schedule/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@


class ScheduleDebugMask(IntEnum):
"""The bitmask of the `debug_mode` flag in the ScheduleState class.
"""The bitmask of the `debug_mask` flag in the ScheduleState class.
If the `debug_mode` flag has a certain bit on, then the correpsonding
verification pass will be conducted. For example, if `(debug_mode & VERIFY_SREF_TREE) != 0`,
If the `debug_mask` flag has a certain bit on, then the correpsonding
verification pass will be conducted. For example, if `(debug_mask & VERIFY_SREF_TREE) != 0`,
then the correctness of the sref tree will be verified after each schedule instruction.
Attributes
Expand All @@ -49,6 +49,27 @@ class ScheduleDebugMask(IntEnum):
VERIFY_CACHED_FLAGS = 2


def _parse_mod(mod: Union[PrimFunc, IRModule]) -> IRModule:
if isinstance(mod, PrimFunc):
mod = IRModule({"main": mod})
if not isinstance(mod, IRModule):
raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}")
return mod


def _parse_debug_mask(debug_mask: Union[str, int]) -> int:
if isinstance(debug_mask, str):
if debug_mask == "all":
debug_mask = ScheduleDebugMask.VERIFY_SREF_TREE | ScheduleDebugMask.VERIFY_CACHED_FLAGS
elif debug_mask == "none":
debug_mask = 0
else:
raise ValueError(f"Unrecognizable `debug_mask`: {debug_mask}")
if isinstance(debug_mask, bool) or not isinstance(debug_mask, int):
raise TypeError(f"`debug_mask` should be integer or boolean, but gets: {debug_mask}")
return debug_mask


@register_object("tir.ScheduleState")
class ScheduleState(Object):
"""The state of scheduling, which exposes a `Replace` method as
Expand All @@ -59,52 +80,44 @@ class ScheduleState(Object):
2) The sref tree of schedulable statements (indicated by the srefs)
3) The dependency information of each block scope (block_info)
4) A reverse mapping from the AST nodes to that in the sref tree (get_sref)
5) A debug flag, if set, extra checking is enabled (debug_mode)
5) A debug flag, if set, extra checking is enabled (debug_mask)
Parameters
----------
mod : IRModule
The AST of the module being scheduled
debug_mode : int
debug_mask : int
Do extra correctness checking after the object construction
and each time after calling the Replace method.
"""

mod: IRModule
debug_mode: int
debug_mask: int

def __init__(
self,
mod: Union[PrimFunc, IRModule],
debug_mode: Union[bool, int] = False,
*,
debug_mask: Union[str, int] = "none",
) -> None:
"""Construct a schedule state from an IRModule or a PrimFunc
Parameters
----------
mod : Union[PrimFunc, IRModule]
The IRModule or PrimFunc to be scheduled
debug_mode : Union[bool, int]
debug_mask : Union[str, int]
Do extra correctness checking after the class creation and each time
after calling the Replace method.
Possible choices of `debug_mode`:
1) True - Turn on all the checks
2) False - Turn off all the checks
Possible choices of `debug_mask`:
1) "all" - Turn on all the checks
2) "none" - Turn off all the checks
3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
"""
if isinstance(mod, PrimFunc):
mod = IRModule({"main": mod})
if isinstance(debug_mode, bool):
if debug_mode:
debug_mode = -1
else:
debug_mode = 0
if not isinstance(debug_mode, int):
raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}")
self.__init_handle_by_constructor__(
_ffi_api.ScheduleState, # type: ignore # pylint: disable=no-member
mod,
debug_mode,
_parse_mod(mod),
_parse_debug_mask(debug_mask),
)

def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]:
Expand Down
Loading

0 comments on commit 3145867

Please sign in to comment.