Skip to content

Commit

Permalink
[M3a] Traced Schedule (#423)
Browse files Browse the repository at this point in the history
* [M3a] Traced Schedule


Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
  • Loading branch information
junrushao and MasterJH5574 authored Aug 2, 2021
1 parent 7653972 commit 8e0b7c2
Show file tree
Hide file tree
Showing 12 changed files with 489 additions and 155 deletions.
19 changes: 18 additions & 1 deletion include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define TVM_TIR_SCHEDULE_SCHEDULE_H_

#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -95,13 +96,15 @@ class ScheduleNode : public runtime::Object {
virtual ~ScheduleNode() = default;

static constexpr const char* _type_key = "tir.Schedule";
TVM_DECLARE_BASE_OBJECT_INFO(ScheduleNode, runtime::Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, runtime::Object);

public:
/*! \brief Get the IRModule associated with this schedule. */
virtual IRModule mod() const { return state()->mod; }
/*! \return The internal state of scheduling */
virtual ScheduleState state() const = 0;
/*! \return The internally maintained trace of scheduling program execution */
virtual Optional<Trace> trace() const = 0;
/*!
* \brief Returns a copy of the schedule, including both its state and its symbol table,
* guaranteeing that
Expand Down Expand Up @@ -299,6 +302,20 @@ class Schedule : public runtime::ObjectRef {
*/
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode,
ScheduleErrorRenderLevel error_render_level);
/*!
* \brief Construct a traced concrete TensorIR schedule from an IRModule
* \param mod The IRModule to be scheduled
* \param debug_mode Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
* \return The concrete schedule created
* \sa ScheduleDebugMask
* \note The checks performed includes:
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Traced(IRModule mod, int debug_mode,
ScheduleErrorRenderLevel error_render_level);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
};

Expand Down
99 changes: 64 additions & 35 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""The TensorIR schedule class"""
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
Expand All @@ -26,6 +25,7 @@

from . import _ffi_api
from .state import ScheduleState, StmtSRef
from .trace import Trace


@register_error
Expand Down Expand Up @@ -63,7 +63,37 @@ def __init__(self) -> None:
RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # pylint: disable=invalid-name

# Update to `Literal["detail", "fast", "none"]` once upgraded to python3.8
ERROR_RENDER_LEVEL_CANDIDATES = Union[str] # pylint: disable=invalid-name
_ERROR_RENDER_LEVEL = {
"detail": 0,
"fast": 1,
"none": 2,
}


def _preprocess_constructor_arguments(
mod: Union[PrimFunc, IRModule],
debug_mode: Union[bool, int] = False,
error_render_level: str = "detail",
) -> Tuple[IRModule, int, int]:
# preprocess `mod`
if isinstance(mod, PrimFunc):
mod = IRModule({"main": mod})
# preprocess `debug_mode`
if isinstance(debug_mode, bool):
if debug_mode:
debug_mode = -1
else:
debug_mode = 0
if not isinstance(debug_mode, int):
raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}")
# preprocess `error_render_level`
if error_render_level not in _ERROR_RENDER_LEVEL:
raise ValueError(
'error_render_level can be "detail", "fast", or "none", but got: '
+ f"{error_render_level}"
)
error_render_level_int: int = _ERROR_RENDER_LEVEL.get(error_render_level)
return mod, debug_mode, error_render_level_int


@_register_object("tir.Schedule")
Expand All @@ -81,20 +111,14 @@ class Schedule(Object):
Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html
"""

ERROR_RENDER_LEVEL = {
"detail": 0,
"fast": 1,
"none": 2,
}

def __init__(
self,
mod: Union[PrimFunc, IRModule],
*,
debug_mode: Union[bool, int] = False,
error_render_level: ERROR_RENDER_LEVEL_CANDIDATES = "detail",
error_render_level: str = "detail",
) -> None:
"""Construct a concrete TensorIR schedule from an IRModule or a PrimFunc
"""Construct a TensorIR schedule class from an IRModule
Parameters
----------
Expand All @@ -115,39 +139,49 @@ def __init__(
1) VerifySRefTree
2) VerifyCachedFlags
"""
if isinstance(mod, PrimFunc):
mod = IRModule({"main": mod})
if isinstance(debug_mode, bool):
if debug_mode:
debug_mode = -1
else:
debug_mode = 0
if not isinstance(debug_mode, int):
raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}")
if error_render_level not in Schedule.ERROR_RENDER_LEVEL:
raise ValueError(
'error_render_level can be "detail", "fast", or "none", but got: '
+ f"{error_render_level}"
)
# call the constructor
self.__init_handle_by_constructor__(
_ffi_api.ConcreteSchedule, # type: ignore # pylint: disable=no-member
mod,
debug_mode,
Schedule.ERROR_RENDER_LEVEL.get(error_render_level),
_ffi_api.TracedSchedule, # type: ignore # pylint: disable=no-member
*_preprocess_constructor_arguments(
mod,
debug_mode,
error_render_level,
),
)

@staticmethod
def _create_non_traced(
mod: Union[PrimFunc, IRModule],
*,
debug_mode: Union[bool, int] = False,
error_render_level: str = "detail",
) -> "Schedule":
"""Construct a non-traced TensorIR schedule class from an IRModule."""
return _ffi_api.ConcreteSchedule( # type: ignore # pylint: disable=no-member
*_preprocess_constructor_arguments(
mod,
debug_mode,
error_render_level,
)
)

########## Utilities ##########

@property
def mod(self) -> IRModule:
"""Returns the AST of the module being scheduled"""
return _ffi_api.ScheduleModule(self) # type: ignore # pylint: disable=no-member
return _ffi_api.ScheduleGetMod(self) # type: ignore # pylint: disable=no-member

@property
def state(self) -> ScheduleState:
"""Returns the ScheduleState in the current schedule class"""
return _ffi_api.ScheduleGetState(self) # type: ignore # pylint: disable=no-member

@property
def trace(self) -> Optional[Trace]:
"""Returns the internally maintained trace of scheduling program execution"""
return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint: disable=no-member

def copy(self) -> "Schedule":
"""Returns a copy of the schedule, including both the state and the symbol table,
* guaranteeing that
Expand Down Expand Up @@ -702,8 +736,3 @@ def after_rfactor(a: ty.handle, b: ty.handle) -> None:
def enter_postproc(self) -> None:
"""A no-op that marks the start of postprocessing phase of scheduling"""
_ffi_api.ScheduleEnterPostproc(self) # type: ignore # pylint: disable=no-member


@_register_object("tir.ConcreteSchedule")
class ConcreteSchedule(Schedule):
"""A concrete schedule class of TensorIR. Do not use directly, use tvm.tir.Schedule instead."""
62 changes: 62 additions & 0 deletions python/tvm/tir/schedule/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Testing utilities for the TensorIR schedule API"""
from typing import Union

from tvm import tir
from tvm.ir import IRModule, structural_equal
from tvm.tir import PrimFunc
from tvm.tir.schedule import Trace


def verify_trace_roundtrip(
sch: tir.Schedule,
mod: Union[PrimFunc, IRModule],
*,
debug_mod: Union[bool, int] = True,
) -> tir.Schedule:
"""Serialize a traced schedule to JSON, then replay the JSON trace by applying to
a fresh new schedule, verifying the reproducibility of scheduling.
Parameters
----------
sch : tir.Schedule
The traced TensorIR schedule to be verified
mod : Union[PrimFunc, IRModule]
The IRModule or PrimFunc to construct the fresh new schedule
debug_mode : Union[bool, int]
Do extra correctness checking after the class creation and each time
after calling the Replace method.
Possible choices of `debug_mode`:
1) True - Turn on all the checks
2) False - Turn off all the checks
3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
"""
# Step 1. Serialize the trace to JSON
trace = sch.trace
assert trace is not None
json_obj = trace.as_json()
# Step 2. Apply the JSON trace to a new schedule, then check if it reproduces the scheduling
new_sch = tir.Schedule(mod=mod, debug_mode=debug_mod)
Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
assert structural_equal(new_sch.mod, sch.mod)
# Step 3. Check the consistency of the text format between the old and new traces
py_repr = "\n".join(trace.as_python())
new_py_repr = "\n".join(new_sch.trace.as_python())
assert py_repr == new_py_repr
# Step 4. Return the new schedule in case it could be useful
return new_sch
8 changes: 2 additions & 6 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symb
Schedule ConcreteScheduleNode::Copy() const {
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
n->error_render_level_ = this->error_render_level_;
this->Copy(&n->state_, &n->symbol_table_);
n->analyzer_ = std::make_unique<arith::Analyzer>();
ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
n->analyzer_ = std::make_unique<arith::Analyzer>(); // new analyzer needed because it is stateful
return Schedule(std::move(n));
}

Expand Down Expand Up @@ -376,9 +376,5 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {

/******** Schedule: blockize & tensorize ********/

/******** FFI ********/

TVM_REGISTER_NODE_TYPE(ConcreteScheduleNode);

} // namespace tir
} // namespace tvm
6 changes: 2 additions & 4 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,17 @@ class ConcreteScheduleNode : public ScheduleNode {

public:
void VisitAttrs(tvm::AttrVisitor* v) {
// `error_render_level_` is not visited
// `state_` is not visited
// `error_render_level_` is not visited
// `symbol_table_` is not visited
// `analyzer_` is not visitied
}

virtual ~ConcreteScheduleNode() = default;

static constexpr const char* _type_key = "tir.ConcreteSchedule";
TVM_DECLARE_BASE_OBJECT_INFO(ConcreteScheduleNode, ScheduleNode);

public:
ScheduleState state() const final { return state_; }
Optional<Trace> trace() const override { return NullOpt; }
Schedule Copy() const override;

public:
Expand Down
9 changes: 8 additions & 1 deletion src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ TVM_REGISTER_NODE_TYPE(BlockRVNode);
TVM_REGISTER_NODE_TYPE(LoopRVNode);
TVM_REGISTER_OBJECT_TYPE(ScheduleNode);

TVM_REGISTER_GLOBAL("tir.schedule.ScheduleModule") //
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetMod") //
.set_body_method<Schedule>(&ScheduleNode::mod);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") //
.set_body_method<Schedule>(&ScheduleNode::state);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") //
.set_body_method<Schedule>(&ScheduleNode::trace);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") //
.set_body_method<Schedule>(&ScheduleNode::Seed);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") //
Expand All @@ -62,6 +64,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule")
return Schedule::Concrete(mod, debug_mode,
static_cast<ScheduleErrorRenderLevel>(error_render_level));
});
TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule")
.set_body_typed([](IRModule mod, int debug_mode, int error_render_level) -> Schedule {
return Schedule::Traced(mod, debug_mode,
static_cast<ScheduleErrorRenderLevel>(error_render_level));
});

/******** (FFI) Lookup random variables ********/

Expand Down
Loading

0 comments on commit 8e0b7c2

Please sign in to comment.