diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index bd23773976..e208377843 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -20,6 +20,7 @@ #define TVM_TIR_SCHEDULE_SCHEDULE_H_ #include +#include namespace tvm { namespace tir { @@ -95,13 +96,15 @@ class ScheduleNode : public runtime::Object { virtual ~ScheduleNode() = default; static constexpr const char* _type_key = "tir.Schedule"; - TVM_DECLARE_BASE_OBJECT_INFO(ScheduleNode, runtime::Object); + TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, runtime::Object); public: /*! \brief Get the IRModule associated with this schedule. */ virtual IRModule mod() const { return state()->mod; } /*! \return The internal state of scheduling */ virtual ScheduleState state() const = 0; + /*! \return The internally maintained trace of scheduling program execution */ + virtual Optional trace() const = 0; /*! * \brief Returns a copy of the schedule, including both its state and its symbol table, * guaranteeing that @@ -288,7 +291,7 @@ class Schedule : public runtime::ObjectRef { /*! * \brief Construct a concrete TensorIR schedule from an IRModule * \param mod The IRModule to be scheduled - * \param debug_mode Do extra correctness checking after the class creation + * \param debug_mask Do extra correctness checking after the class creation * and each time after calling the Replace method. * \param error_render_level The level of error rendering * \return The concrete schedule created @@ -297,8 +300,22 @@ class Schedule : public runtime::ObjectRef { * 1) VerifySRefTree * 2) VerifyCachedFlags */ - TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode, + TVM_DLL static Schedule Concrete(IRModule mod, int debug_mask, ScheduleErrorRenderLevel error_render_level); + /*! 
+ * \brief Construct a traced concrete TensorIR schedule from an IRModule + * \param mod The IRModule to be scheduled + * \param debug_mask Do extra correctness checking after the class creation + * and each time after calling the Replace method. + * \param error_render_level The level of error rendering + * \return The concrete schedule created + * \sa ScheduleDebugMask + * \note The checks performed include: + * 1) VerifySRefTree + * 2) VerifyCachedFlags + */ + TVM_DLL static Schedule Traced(IRModule mod, int debug_mask, + ScheduleErrorRenderLevel error_render_level); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); }; diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index 077bf938f4..7cd1b00c15 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -80,7 +80,7 @@ enum ScheduleDebugMask : uint32_t { * 2) The sref tree of schedulable statements (indicated by the srefs) * 3) The dependency information of each block scope (block_info) * 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref) - * 5) A debug flag, if set, extra checking is enabled (debug_mode) + * 5) A debug flag, if set, extra checking is enabled (debug_mask) */ class ScheduleStateNode : public Object { public: @@ -99,13 +99,13 @@ class ScheduleStateNode : public Object { * and each time after calling the Replace method. * \sa ScheduleDebugMask */ - int debug_mode; + int debug_mask; void VisitAttrs(AttrVisitor* v) { v->Visit("mod", &mod); // `block_info` is not visited // `stmt2ref` is not visited - v->Visit("debug_mode", &debug_mode); + v->Visit("debug_mask", &debug_mask); } /*! * \brief Replace the part of the AST, as being pointed to by `src_sref`, @@ -129,7 +129,7 @@ class ScheduleStateNode : public Object { TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt, const Map& block_sref_reuse); /*! 
- * \brief Trigger the verification according to the `debug_mode` bitmask. + * \brief Trigger the verification according to the `debug_mask` bitmask. * 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree. * 2) If the bitmask `kVerifyCachedFlags` is on, verify the correctness of `affine_binding`, * `region_cover` and `stage_pipeline` @@ -186,10 +186,10 @@ class ScheduleState : public ObjectRef { /*! * \brief Construct a schedule state from an IRModule * \param mod The IRModule to be scheduled - * \param debug_mode Do extra correctness checking after the class creation + * \param debug_mask Do extra correctness checking after the class creation * and each time after calling the Replace method. */ - TVM_DLL explicit ScheduleState(IRModule mod, int debug_mode = 0); + TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0); /*! \return The mutable pointer to the ScheduleStateNode */ ScheduleStateNode* get() const { return static_cast(data_.get()); } diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 22c08398df..4bbb5b9b15 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -14,9 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=unused-import """The TensorIR schedule class""" -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error @@ -25,7 +24,8 @@ from tvm.tir import Block, For, IntImm, PrimFunc from . 
import _ffi_api -from .state import ScheduleState, StmtSRef +from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod +from .trace import Trace @register_error @@ -63,7 +63,20 @@ def __init__(self) -> None: RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # pylint: disable=invalid-name # Update to `Literal["detail", "fast", "none"]` once upgraded to python3.8 -ERROR_RENDER_LEVEL_CANDIDATES = Union[str] # pylint: disable=invalid-name +_ERROR_RENDER_LEVEL: Dict[str, int] = { + "detail": 0, + "fast": 1, + "none": 2, +} + + +def _parse_error_render_level(error_render_level: str) -> int: + if error_render_level not in _ERROR_RENDER_LEVEL: + raise ValueError( + 'error_render_level can be "detail", "fast", or "none", but got: ' + + f"{error_render_level}" + ) + return _ERROR_RENDER_LEVEL.get(error_render_level) @_register_object("tir.Schedule") @@ -81,33 +94,31 @@ class Schedule(Object): Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html """ - ERROR_RENDER_LEVEL = { - "detail": 0, - "fast": 1, - "none": 2, - } - def __init__( self, mod: Union[PrimFunc, IRModule], *, - debug_mode: Union[bool, int] = False, - error_render_level: ERROR_RENDER_LEVEL_CANDIDATES = "detail", + debug_mask: Union[str, int] = "none", + error_render_level: str = "detail", ) -> None: - """Construct a concrete TensorIR schedule from an IRModule or a PrimFunc + """Construct a TensorIR schedule class from an IRModule Parameters ---------- mod : Union[PrimFunc, IRModule] The IRModule or PrimFunc to be scheduled - debug_mode : Union[bool, int] + debug_mask : Union[str, int] Do extra correctness checking after the class creation and each time - scheduling primitive + after calling the Replace method. 
+ Possible choices of `debug_mask`: + 1) "all" - Turn on all the checks + 2) "none" - Turn off all the checks + 3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask error_render_level : str = "detail" The level of error rendering. Choices: "detail", "fast", "none". - "detail": Render a detailed error message, with the TIR and error locations printed - "fast: Show a simple error message without rendering or string manipulation - "none": Do not show any error message. + - "detail": Render a detailed error message, with the TIR and error locations printed + - "fast": Show a simple error message without rendering or string manipulation + - "none": Do not show any error message. Note ---- @@ -115,25 +126,26 @@ def __init__( 1) VerifySRefTree 2) VerifyCachedFlags """ - if isinstance(mod, PrimFunc): - mod = IRModule({"main": mod}) - if isinstance(debug_mode, bool): - if debug_mode: - debug_mode = -1 - else: - debug_mode = 0 - if not isinstance(debug_mode, int): - raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}") - if error_render_level not in Schedule.ERROR_RENDER_LEVEL: - raise ValueError( - 'error_render_level can be "detail", "fast", or "none", but got: ' - + f"{error_render_level}" - ) + # call the constructor self.__init_handle_by_constructor__( - _ffi_api.ConcreteSchedule, # type: ignore # pylint: disable=no-member - mod, - debug_mode, - Schedule.ERROR_RENDER_LEVEL.get(error_render_level), + _ffi_api.TracedSchedule, # type: ignore # pylint: disable=no-member + _parse_mod(mod), + _parse_debug_mask(debug_mask), + _parse_error_render_level(error_render_level), + ) + + @staticmethod + def _create_non_traced( + mod: Union[PrimFunc, IRModule], + *, + debug_mask: Union[str, int] = "none", + error_render_level: str = "detail", + ) -> "Schedule": + """Construct a non-traced TensorIR schedule class from an IRModule.""" + return _ffi_api.ConcreteSchedule( # type: ignore # pylint: disable=no-member + 
_parse_mod(mod), + _parse_debug_mask(debug_mask), + _parse_error_render_level(error_render_level), ) ########## Utilities ########## @@ -141,13 +153,18 @@ def __init__( @property def mod(self) -> IRModule: """Returns the AST of the module being scheduled""" - return _ffi_api.ScheduleModule(self) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleGetMod(self) # type: ignore # pylint: disable=no-member @property def state(self) -> ScheduleState: """Returns the ScheduleState in the current schedule class""" return _ffi_api.ScheduleGetState(self) # type: ignore # pylint: disable=no-member + @property + def trace(self) -> Optional[Trace]: + """Returns the internally maintained trace of scheduling program execution""" + return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint: disable=no-member + def copy(self) -> "Schedule": """Returns a copy of the schedule, including both the state and the symbol table, * guaranteeing that @@ -702,8 +719,3 @@ def after_rfactor(a: ty.handle, b: ty.handle) -> None: def enter_postproc(self) -> None: """A no-op that marks the start of postprocessing phase of scheduling""" _ffi_api.ScheduleEnterPostproc(self) # type: ignore # pylint: disable=no-member - - -@_register_object("tir.ConcreteSchedule") -class ConcreteSchedule(Schedule): - """A concrete schedule class of TensorIR. Do not use directly, use tvm.tir.Schedule instead.""" diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py index cc2415f150..a371897abd 100644 --- a/python/tvm/tir/schedule/state.py +++ b/python/tvm/tir/schedule/state.py @@ -31,10 +31,10 @@ class ScheduleDebugMask(IntEnum): - """The bitmask of the `debug_mode` flag in the ScheduleState class. + """The bitmask of the `debug_mask` flag in the ScheduleState class. - If the `debug_mode` flag has a certain bit on, then the correpsonding - verification pass will be conducted. 
For example, if `(debug_mode & VERIFY_SREF_TREE) != 0`, + If the `debug_mask` flag has a certain bit on, then the corresponding + verification pass will be conducted. For example, if `(debug_mask & VERIFY_SREF_TREE) != 0`, then the correctness of the sref tree will be verified after each schedule instruction. Attributes @@ -49,6 +49,27 @@ class ScheduleDebugMask(IntEnum): VERIFY_CACHED_FLAGS = 2 +def _parse_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: + if isinstance(mod, PrimFunc): + mod = IRModule({"main": mod}) + if not isinstance(mod, IRModule): + raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") + return mod + + +def _parse_debug_mask(debug_mask: Union[str, int]) -> int: + if isinstance(debug_mask, str): + if debug_mask == "all": + debug_mask = ScheduleDebugMask.VERIFY_SREF_TREE | ScheduleDebugMask.VERIFY_CACHED_FLAGS + elif debug_mask == "none": + debug_mask = 0 + else: + raise ValueError(f"Unrecognizable `debug_mask`: {debug_mask}") + if isinstance(debug_mask, bool) or not isinstance(debug_mask, int): + raise TypeError(f"`debug_mask` should be integer or boolean, but gets: {debug_mask}") + return debug_mask + + + @register_object("tir.ScheduleState") class ScheduleState(Object): """The state of scheduling, which exposes a `Replace` method as @@ -59,24 +80,25 @@ class ScheduleState(Object): 2) The sref tree of schedulable statements (indicated by the srefs) 3) The dependency information of each block scope (block_info) 4) A reverse mapping from the AST nodes to that in the sref tree (get_sref) - 5) A debug flag, if set, extra checking is enabled (debug_mode) + 5) A debug flag, if set, extra checking is enabled (debug_mask) Parameters ---------- mod : IRModule The AST of the module being scheduled - debug_mode : int + debug_mask : int Do extra correctness checking after the object construction and each time after calling the Replace method. 
""" mod: IRModule - debug_mode: int + debug_mask: int def __init__( self, mod: Union[PrimFunc, IRModule], - debug_mode: Union[bool, int] = False, + *, + debug_mask: Union[str, int] = "none", ) -> None: """Construct a schedule state from an IRModule or a PrimFunc @@ -84,27 +106,18 @@ def __init__( ---------- mod : Union[PrimFunc, IRModule] The IRModule or PrimFunc to be scheduled - debug_mode : Union[bool, int] + debug_mask : Union[str, int] Do extra correctness checking after the class creation and each time after calling the Replace method. - Possible choices of `debug_mode`: - 1) True - Turn on all the checks - 2) False - Turn off all the checks + Possible choices of `debug_mask`: + 1) "all" - Turn on all the checks + 2) "none" - Turn off all the checks 3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask """ - if isinstance(mod, PrimFunc): - mod = IRModule({"main": mod}) - if isinstance(debug_mode, bool): - if debug_mode: - debug_mode = -1 - else: - debug_mode = 0 - if not isinstance(debug_mode, int): - raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}") self.__init_handle_by_constructor__( _ffi_api.ScheduleState, # type: ignore # pylint: disable=no-member - mod, - debug_mode, + _parse_mod(mod), + _parse_debug_mask(debug_mask), ) def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]: diff --git a/python/tvm/tir/schedule/testing.py b/python/tvm/tir/schedule/testing.py new file mode 100644 index 0000000000..66ede31f41 --- /dev/null +++ b/python/tvm/tir/schedule/testing.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Testing utilities for the TensorIR schedule API""" +from typing import Union + +from tvm import tir +from tvm.ir import IRModule, structural_equal +from tvm.tir import PrimFunc +from tvm.tir.schedule import Trace + + +def verify_trace_roundtrip( + sch: tir.Schedule, + mod: Union[PrimFunc, IRModule], + *, + debug_mask: Union[str, int] = "all", +) -> tir.Schedule: + """Serialize a traced schedule to JSON, then replay the JSON trace by applying to + a fresh new schedule, verifying the reproducibility of scheduling. + + Parameters + ---------- + sch : tir.Schedule + The traced TensorIR schedule to be verified + mod : Union[PrimFunc, IRModule] + The IRModule or PrimFunc to construct the fresh new schedule + debug_mask : Union[str, int] + Do extra correctness checking after the class creation and each time + after calling the Replace method. + Possible choices of `debug_mask`: + 1) "all" - Turn on all the checks + 2) "none" - Turn off all the checks + 3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask + """ + # Step 1. Serialize the trace to JSON + trace = sch.trace + assert trace is not None + json_obj = trace.as_json() + # Step 2. Apply the JSON trace to a new schedule, then check if it reproduces the scheduling + new_sch = tir.Schedule(mod=mod, debug_mask=debug_mask) + Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch) + assert structural_equal(new_sch.mod, sch.mod) + # Step 3. 
Check the consistency of the text format between the old and new traces + py_repr = "\n".join(trace.as_python()) + new_py_repr = "\n".join(new_sch.trace.as_python()) + assert py_repr == new_py_repr + # Step 4. Return the new schedule in case it could be useful + return new_sch diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index df2f068151..610628c6d8 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -21,10 +21,10 @@ namespace tvm { namespace tir { -Schedule Schedule::Concrete(IRModule mod, int debug_mode, +Schedule Schedule::Concrete(IRModule mod, int debug_mask, ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); - n->state_ = ScheduleState(mod, debug_mode); + n->state_ = ScheduleState(mod, debug_mask); n->error_render_level_ = error_render_level; n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); @@ -50,7 +50,7 @@ class ScheduleCopier { n->mod = src_state->mod; n->block_info = copier.Copy(src_state->block_info); n->stmt2ref = copier.Copy(src_state->stmt2ref); - n->debug_mode = src_state->debug_mode; + n->debug_mask = src_state->debug_mask; *new_state = ScheduleState(std::move(n)); *new_symbol_table = copier.Copy(self->symbol_table_); } @@ -182,8 +182,8 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symb Schedule ConcreteScheduleNode::Copy() const { ObjectPtr n = make_object(); n->error_render_level_ = this->error_render_level_; - this->Copy(&n->state_, &n->symbol_table_); - n->analyzer_ = std::make_unique(); + ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); + n->analyzer_ = std::make_unique(); // new analyzer needed because it is stateful return Schedule(std::move(n)); } @@ -376,9 +376,5 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { /******** Schedule: blockize & tensorize ********/ -/******** FFI ********/ - -TVM_REGISTER_NODE_TYPE(ConcreteScheduleNode); - } // 
namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index c44ec05d66..ec0dd07924 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -46,19 +46,17 @@ class ConcreteScheduleNode : public ScheduleNode { public: void VisitAttrs(tvm::AttrVisitor* v) { - // `error_render_level_` is not visited // `state_` is not visited + // `error_render_level_` is not visited // `symbol_table_` is not visited // `analyzer_` is not visitied } virtual ~ConcreteScheduleNode() = default; - static constexpr const char* _type_key = "tir.ConcreteSchedule"; - TVM_DECLARE_BASE_OBJECT_INFO(ConcreteScheduleNode, ScheduleNode); - public: ScheduleState state() const final { return state_; } + Optional trace() const override { return NullOpt; } Schedule Copy() const override; public: diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index eda6ac27d2..3232a3344e 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -44,10 +44,12 @@ TVM_REGISTER_NODE_TYPE(BlockRVNode); TVM_REGISTER_NODE_TYPE(LoopRVNode); TVM_REGISTER_OBJECT_TYPE(ScheduleNode); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleModule") // +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetMod") // .set_body_method(&ScheduleNode::mod); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") // .set_body_method(&ScheduleNode::state); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") // + .set_body_method(&ScheduleNode::trace); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // .set_body_method(&ScheduleNode::Seed); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // @@ -58,10 +60,15 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV(); }); TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); }); TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") - 
.set_body_typed([](IRModule mod, int debug_mode, int error_render_level) -> Schedule { - return Schedule::Concrete(mod, debug_mode, + .set_body_typed([](IRModule mod, int debug_mask, int error_render_level) -> Schedule { + return Schedule::Concrete(mod, debug_mask, static_cast(error_render_level)); }); +TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule") + .set_body_typed([](IRModule mod, int debug_mask, int error_render_level) -> Schedule { + return Schedule::Traced(mod, debug_mask, + static_cast(error_render_level)); + }); /******** (FFI) Lookup random variables ********/ diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 8f0284f290..6dd09680e9 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -170,13 +170,13 @@ class StateCreator : private StmtVisitor { * \brief The entry function * \param self The schedule state to be completed */ - static ObjectPtr Create(IRModule mod, int debug_mode) { + static ObjectPtr Create(IRModule mod, int debug_mask) { ObjectPtr n = make_object(); ScheduleStateNode* self = n.get(); // Set `n->mod` n->mod = std::move(mod); - // Set `n->debug_mode` - n->debug_mode = debug_mode; + // Set `n->debug_mask` + n->debug_mask = debug_mask; // Set `n->stmt2ref` and `n->block_info` StateCreator creator(self); for (const auto& kv : n->mod->functions) { @@ -411,9 +411,9 @@ class StateCreator : private StmtVisitor { /**************** Constructor ****************/ -ScheduleState::ScheduleState(IRModule mod, int debug_mode) { - CHECK_GE(debug_mode, -1) << "ValueError: negative `debug_mode` other than -1 is not supported"; - data_ = StateCreator::Create(mod, debug_mode); +ScheduleState::ScheduleState(IRModule mod, int debug_mask) { + CHECK_GE(debug_mask, -1) << "ValueError: negative `debug_mask` other than -1 is not supported"; + data_ = StateCreator::Create(mod, debug_mask); } /**************** Replace ****************/ @@ -836,7 +836,7 @@ class ChildReplacer : private StmtMutator { void 
ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt, const Map& _block_sref_reuse) { - if (this->debug_mode != 0) { + if (this->debug_mask != 0) { const StmtNode* src_stmt = _src_sref->stmt; bool input_correct = (src_stmt->IsInstance() && tgt_stmt->IsInstance()) || @@ -990,8 +990,8 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ new_map->at(g_var) = std::move(ref_new_func); this->mod = GetRef(new_mod); } - uint32_t flag = (debug_mode != -1) // - ? static_cast(debug_mode) // + uint32_t flag = (debug_mask != -1) // + ? static_cast(debug_mask) // : std::numeric_limits::max(); if (flag & ScheduleDebugMask::kVerifySRefTree) { VerifySRefTree(GetRef(this)); @@ -999,9 +999,9 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ } void ScheduleStateNode::DebugVerify() const { - ICHECK_GE(debug_mode, -1); - uint32_t flag = (debug_mode != -1) // - ? static_cast(debug_mode) // + ICHECK_GE(debug_mask, -1); + uint32_t flag = (debug_mask != -1) // + ? 
static_cast(debug_mask) // : std::numeric_limits::max(); if (flag & ScheduleDebugMask::kVerifySRefTree) { VerifySRefTree(GetRef(this)); @@ -1033,8 +1033,8 @@ TVM_DLL Array GetCachedFlags(const ScheduleState& self, const StmtSRef& bl TVM_REGISTER_NODE_TYPE(ScheduleStateNode); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState") - .set_body_typed([](IRModule mod, int debug_mode) -> ScheduleState { - return ScheduleState(mod, debug_mode); + .set_body_typed([](IRModule mod, int debug_mask) -> ScheduleState { + return ScheduleState(mod, debug_mask); }); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope") .set_body_method(&ScheduleStateNode::GetBlockScope); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc new file mode 100644 index 0000000000..d664d7f6ce --- /dev/null +++ b/src/tir/schedule/traced_schedule.cc @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "./traced_schedule.h" + +namespace tvm { +namespace tir { + +Schedule Schedule::Traced(IRModule mod, int debug_mask, + ScheduleErrorRenderLevel error_render_level) { + ObjectPtr n = make_object(); + n->state_ = ScheduleState(mod, debug_mask); + n->error_render_level_ = error_render_level; + n->symbol_table_ = {}; + n->analyzer_ = std::make_unique(); + n->trace_ = Trace(); + return Schedule(std::move(n)); +} + +Schedule TracedScheduleNode::Copy() const { + ObjectPtr n = make_object(); + n->error_render_level_ = this->error_render_level_; + ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); + n->analyzer_ = std::make_unique(); // new analyzer needed because it is stateful + n->trace_ = Trace(this->trace_->insts, this->trace_->decisions); + return Schedule(std::move(n)); +} + +/******** Schedule: Sampling ********/ + +/******** Schedule: Get blocks & loops ********/ + +BlockRV TracedScheduleNode::GetBlock(const String& name, const String& func_name) { + BlockRV result = ConcreteScheduleNode::GetBlock(name, func_name); + + static const InstructionKind& kind = InstructionKind::Get("GetBlock"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{}, + /*attrs=*/{name, func_name}, + /*outputs=*/{result})); + return result; +} + +Array TracedScheduleNode::GetLoops(const BlockRV& block_rv) { + Array results = ConcreteScheduleNode::GetLoops(block_rv); + + static const InstructionKind& kind = InstructionKind::Get("GetLoops"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + +/******** Schedule: Transform loops ********/ + +LoopRV TracedScheduleNode::Fuse(const Array& loop_rvs) { + LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs); + + static const InstructionKind& kind = InstructionKind::Get("Fuse"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rvs.begin(), loop_rvs.end()}, + 
/*attrs=*/{}, + /*outputs=*/{result})); + return result; +} + +Array TracedScheduleNode::Split(const LoopRV& loop_rv, + const Array>& factor_rvs) { + Array results = ConcreteScheduleNode::Split(loop_rv, factor_rvs); + + std::vector inputs; + inputs.reserve(1 + factor_rvs.size()); + inputs.push_back(loop_rv); + for (const ObjectRef& obj : factor_rvs) { + inputs.push_back(obj); + } + + static const InstructionKind& kind = InstructionKind::Get("Split"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/inputs, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + +/******** Schedule: Manipulate ForKind ********/ + +/******** Schedule: Insert cache stages ********/ + +/******** Schedule: Compute location ********/ + +void TracedScheduleNode::ComputeInline(const BlockRV& block_rv) { + ConcreteScheduleNode::ComputeInline(block_rv); + + static const InstructionKind& kind = InstructionKind::Get("ComputeInline"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{})); +} + +void TracedScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { + ConcreteScheduleNode::ReverseComputeInline(block_rv); + + static const InstructionKind& kind = InstructionKind::Get("ReverseComputeInline"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{})); +} + +/******** Schedule: Reduction ********/ + +BlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { + BlockRV result = ConcreteScheduleNode::RFactor(loop_rv, factor_axis); + static const InstructionKind& kind = InstructionKind::Get("RFactor"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{Integer(factor_axis)}, + /*outputs=*/{result})); + return result; +} + +/******** Schedule: Blockize & Tensorize ********/ + +/******** Schedule: Annotation ********/ + +/******** Schedule: Misc ********/ + +void 
TracedScheduleNode::EnterPostproc() { + ConcreteScheduleNode::EnterPostproc(); + static const InstructionKind& kind = InstructionKind::Get("EnterPostproc"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{}, + /*attrs=*/{}, + /*outputs=*/{})); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h new file mode 100644 index 0000000000..b4518cbba8 --- /dev/null +++ b/src/tir/schedule/traced_schedule.h @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#ifndef TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_ +#define TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_ + +#include "./concrete_schedule.h" + +namespace tvm { +namespace tir { + +class TracedScheduleNode : public ConcreteScheduleNode { + friend class Schedule; + + protected: + Trace trace_; + + public: + void VisitAttrs(tvm::AttrVisitor* v) { + // `state_` is not visited + // `error_render_level_` is not visited + // `symbol_table_` is not visited + // `analyzer_` is not visitied + // `trace_` is not visited + } + + ~TracedScheduleNode() = default; + + public: + Optional trace() const final { return trace_; } + Schedule Copy() const final; + + public: + /******** Schedule: Sampling ********/ + + /******** Schedule: Get blocks & loops ********/ + BlockRV GetBlock(const String& name, const String& func_name = "main") final; + Array GetLoops(const BlockRV& block_rv) final; + /******** Schedule: Transform loops ********/ + LoopRV Fuse(const Array& loop_rvs) final; + Array Split(const LoopRV& loop_rv, const Array>& factor_rvs) final; + /******** Schedule: Manipulate ForKind ********/ + /******** Schedule: Insert cache stages ********/ + /******** Schedule: Compute location ********/ + void ComputeInline(const BlockRV& block_rv) final; + void ReverseComputeInline(const BlockRV& block_rv) final; + /******** Schedule: Reduction ********/ + BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final; + /******** Schedule: Blockize & Tensorize ********/ + /******** Schedule: Annotation ********/ + /******** Schedule: Misc ********/ + void EnterPostproc() final; +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_ diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 52dc4ccd9f..2fdafe08e6 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -27,7 +27,7 @@ def test_unique_name(): B = te.compute((16, 16), lambda 
x, y: A[x, y] * 2, name="main") C = te.compute((16, 16), lambda x, y: B[x, y] + 1, name="main") func = te.create_prim_func([A, C]) - s = tir.Schedule(func, debug_mode=True) + s = tir.Schedule(func, debug_mask="all") assert isinstance(s.get_sref(s.get_block("main")), tir.schedule.StmtSRef) assert isinstance(s.get_sref(s.get_block("main_1")), tir.schedule.StmtSRef) @@ -36,7 +36,7 @@ def _check_workload(te_workload, tir_workload): func = te.create_prim_func(te_workload()) tvm.ir.assert_structural_equal(func, tir_workload) # make sure that we can create schedule from the func - s = tir.Schedule(func, debug_mode=True) + s = tir.Schedule(func, debug_mask="all") assert s diff --git a/tests/python/unittest/test_tir_schedule_block_scope.py b/tests/python/unittest/test_tir_schedule_block_scope.py index ced8d78ff1..f66dca30d9 100644 --- a/tests/python/unittest/test_tir_schedule_block_scope.py +++ b/tests/python/unittest/test_tir_schedule_block_scope.py @@ -84,7 +84,7 @@ def f_visit(node): def test_elementwise_dependency(): - s = tir.ScheduleState(elementwise, debug_mode=True) + s = tir.ScheduleState(elementwise, debug_mask="all") root = _get_block(s, "root") block_b = _get_block(s, "B") block_c = _get_block(s, "C") @@ -101,7 +101,7 @@ def test_elementwise_dependency(): def test_matmul_dependency(): - s = tir.ScheduleState(matmul, debug_mode=True) + s = tir.ScheduleState(matmul, debug_mask="all") root = _get_block(s, "root") init = _get_block(s, "init") update = _get_block(s, "update") @@ -126,7 +126,7 @@ def test_matmul_dependency(): def test_war_dependency(): - s = tir.ScheduleState(war_dependency, debug_mode=True) + s = tir.ScheduleState(war_dependency, debug_mask="all") root = _get_block(s, "root") block_c = _get_block(s, "C") block_b = _get_block(s, "B") diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index d6934c6f40..ea322920b8 100644 --- 
a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -21,6 +21,7 @@ import tvm from tvm import tir from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip # pylint: disable=no-member,invalid-name,unused-variable @@ -223,34 +224,37 @@ def elementwise_multi_loads_inlined(a: ty.handle, c: ty.handle) -> None: def test_compute_inline_elementwise(): - sch = tir.Schedule(elementwise, debug_mode=True) + sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) assert sch.get(block_c).name_hint == "C" + verify_trace_roundtrip(sch=sch, mod=elementwise) def test_compute_inline_under_loop(): - sch = tir.Schedule(elementwise_under_loop, debug_mode=True) + sch = tir.Schedule(elementwise_under_loop, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) assert sch.get(block_c).name_hint == "C" + verify_trace_roundtrip(sch=sch, mod=elementwise_under_loop) def test_compute_inline_as_dce(): - sch = tir.Schedule(elementwise_standalone, debug_mode=True) + sch = tir.Schedule(elementwise_standalone, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_standalone_dce, sch.mod["main"]) assert sch.get(block_c).name_hint == "C" + verify_trace_roundtrip(sch=sch, mod=elementwise_standalone) def test_compute_inline_multi_consumer(): - sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mode=True) + sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") block_d = sch.get_block("D") @@ -258,101 +262,107 @@ def test_compute_inline_multi_consumer(): 
tvm.ir.assert_structural_equal(elementwise_multi_consumer_inlined, sch.mod["main"]) assert sch.get(block_c).name_hint == "C" assert sch.get(block_d).name_hint == "D" + verify_trace_roundtrip(sch=sch, mod=elementwise_multi_producer_consumer) def test_compute_inline_fail_multi_writer(): - sch = tir.Schedule(fail_multi_reader_writer, debug_mode=True, error_render_level="detail") + sch = tir.Schedule(fail_multi_reader_writer, debug_mask="all") block_b = sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) def test_reverse_compute_inline_elementwise(): - sch = tir.Schedule(elementwise, debug_mode=True) + sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) assert sch.get(block_b).name_hint == "B" + verify_trace_roundtrip(sch=sch, mod=elementwise) def test_reverse_compute_inline_under_loop(): - sch = tir.Schedule(elementwise_under_loop, debug_mode=True) + sch = tir.Schedule(elementwise_under_loop, debug_mask="all") block_b = sch.get_block("B") block_c = sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) assert sch.get(block_b).name_hint == "B" + verify_trace_roundtrip(sch=sch, mod=elementwise_under_loop) def test_reverse_compute_inline_fail_as_dce(): - sch = tir.Schedule(elementwise_standalone, debug_mode=True) + sch = tir.Schedule(elementwise_standalone, debug_mask="all") block_b = sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_b) def test_reverse_compute_inline_fail_multi_producer(): - sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mode=True) + sch = tir.Schedule(elementwise_multi_producer_consumer, debug_mask="all") block_d = sch.get_block("D") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_d) def 
test_reverse_compute_inline_fail_multi_reader(): - sch = tir.Schedule(fail_multi_reader_writer, debug_mode=True) + sch = tir.Schedule(fail_multi_reader_writer, debug_mask="all") block_c = sch.get_block("C") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_c) def test_reverse_compute_multi_reverse_loads(): - sch = tir.Schedule(elementwise_multi_reverse_loads, debug_mode=True) + sch = tir.Schedule(elementwise_multi_reverse_loads, debug_mask="all") block_c = sch.get_block("C") sch.reverse_compute_inline(block_c) tvm.ir.assert_structural_equal(elementwise_multi_reverse_loads_inlined, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_loads) def test_reverse_compute_fail_multi_reverse_loads(): - sch = tir.Schedule(elementwise_multi_loads, debug_mode=True) + sch = tir.Schedule(elementwise_multi_loads, debug_mask="all") block_c = sch.get_block("C") with pytest.raises(tvm.tir.ScheduleError): sch.reverse_compute_inline(block_c) def test_opaque_access_load(): - sch = tir.Schedule(opaque_access_load, debug_mode=True) + sch = tir.Schedule(opaque_access_load, debug_mask="all") block_b = sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) def test_opaque_access_store(): - sch = tir.Schedule(opaque_access_store, debug_mode=True) + sch = tir.Schedule(opaque_access_store, debug_mask="all") block_b = sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) def test_buffer_matched(): - sch = tir.Schedule(buffer_matched, debug_mode=True) + sch = tir.Schedule(buffer_matched, debug_mask="all") block_b = sch.get_block("B") with pytest.raises(tvm.tir.ScheduleError): sch.compute_inline(block_b) def test_compute_inline_predicate(): - sch = tir.Schedule(elementwise_predicate, debug_mode=True) + sch = tir.Schedule(elementwise_predicate, debug_mask="all") block_b = sch.get_block("B") sch.compute_inline(block_b) 
tvm.ir.assert_structural_equal(elementwise_predicate_inlined, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_predicate) def test_compute_inline_multi_loads(): - sch = tir.Schedule(elementwise_multi_loads, debug_mode=True) + sch = tir.Schedule(elementwise_multi_loads, debug_mask="all") block_b = sch.get_block("B") sch.compute_inline(block_b) tvm.ir.assert_structural_equal(elementwise_multi_loads_inlined, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_multi_loads) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_schedule_error.py b/tests/python/unittest/test_tir_schedule_error.py index 6f56eb5988..6fcd0dc2ae 100644 --- a/tests/python/unittest/test_tir_schedule_error.py +++ b/tests/python/unittest/test_tir_schedule_error.py @@ -42,7 +42,7 @@ def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: def test_tir_schedule_error_detail(): - sch = tir.Schedule(matmul, debug_mode=True, error_render_level="detail") + sch = tir.Schedule(matmul, debug_mask="all", error_render_level="detail") with pytest.raises(tir.ScheduleError) as excinfo: sch.get_block("wrong_name") (msg,) = excinfo.value.args @@ -50,7 +50,7 @@ def test_tir_schedule_error_detail(): def test_tir_schedule_error_fast(): - sch = tir.Schedule(matmul, debug_mode=True, error_render_level="fast") + sch = tir.Schedule(matmul, debug_mask="all", error_render_level="fast") with pytest.raises(tir.ScheduleError) as excinfo: sch.get_block("wrong_name") (msg,) = excinfo.value.args @@ -58,7 +58,7 @@ def test_tir_schedule_error_fast(): def test_tir_schedule_error_none(): - sch = tir.Schedule(matmul, debug_mode=True, error_render_level="none") + sch = tir.Schedule(matmul, debug_mask="all", error_render_level="none") with pytest.raises(tir.ScheduleError) as excinfo: sch.get_block("wrong_name") (msg,) = excinfo.value.args diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index 
b285f72ca5..6b4ac23503 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring import sys import numpy as np @@ -22,8 +23,9 @@ import tvm.testing from tvm import tir from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip -# pylint: disable=no-member,invalid-name,unused-variable,missing-function-docstring,missing-module-docstring +# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @tvm.script.tir @@ -452,198 +454,144 @@ def multiple_reduction_blocks_rfactor(a: ty.handle, f: ty.handle) -> None: F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj] -# pylint: enable=no-member,invalid-name,unused-variable +# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg def test_reduction_rfactor_matmul(): - s = tir.Schedule(transformed_matmul, debug_mode=True) - C = s.get_block("update") - _, _, _, _, kii = s.get_loops(C) + s = tir.Schedule(transformed_matmul, debug_mask="all") + _, _, _, _, kii = s.get_loops(s.get_block("update")) rf_block = s.rfactor(kii, 0) tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) - - func = tvm.build(s.mod["main"], target="llvm") - a_np = np.random.uniform(size=(128, 128)).astype("float32") - b_np = np.random.uniform(size=(128, 128)).astype("float32") - a = tvm.nd.array(a_np) - b = tvm.nd.array(b_np) - c = tvm.nd.array(np.zeros((128, 128), dtype="float32")) - func(a, b, c) - c_np = np.matmul(a_np, b_np.T) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-4) + verify_trace_roundtrip(s, mod=transformed_matmul) def test_reduction_rfactor_square_sum(): - s = tir.Schedule(square_sum, 
debug_mode=True) - C = s.get_block("C") - _, _, j = s.get_loops(C) + s = tir.Schedule(square_sum, debug_mask="all") + _, _, j = s.get_loops(s.get_block("C")) rf_block = s.rfactor(j, 1) tvm.ir.assert_structural_equal(s.mod["main"], square_sum_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) - - func = tvm.build(s.mod["main"], target="llvm") - a_np = np.random.uniform(size=(16, 256, 256)).astype("float32") - a = tvm.nd.array(a_np) - c = tvm.nd.array(np.zeros((16,), dtype="float32")) - func(a, c) - c_np = np.sum(a_np * a_np, axis=(1, 2)) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-4) + verify_trace_roundtrip(s, mod=square_sum) def test_reduction_rfactor_square_sum_square_root(): - s = tir.Schedule(transformed_square_sum_square_root, debug_mode=True) - C = s.get_block("C") - _, _, fi = s.get_loops(C) - rf_block = s.rfactor(fi, 0) + s = tir.Schedule(transformed_square_sum_square_root, debug_mask="all") + _, _, f_i = s.get_loops(s.get_block("C")) + rf_block = s.rfactor(f_i, 0) tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) - - func = tvm.build(s.mod["main"], target="llvm") - a_np = np.random.uniform(size=(16, 256, 256)).astype("float32") - a = tvm.nd.array(a_np) - d = tvm.nd.array(np.zeros((16,), dtype="float32")) - func(a, d) - d_np = np.sqrt(np.sum(a_np * a_np, axis=(1, 2))) - tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-4, atol=1e-4) + verify_trace_roundtrip(s, mod=transformed_square_sum_square_root) def test_reduction_rfactor_loop_multiple_children(): - s = tir.Schedule(matmul_loop_multiple_children, debug_mode=True) - C = s.get_block("C") - k, _, _ = s.get_loops(C) + s = tir.Schedule(matmul_loop_multiple_children, debug_mask="all") + k, _, _ = s.get_loops(s.get_block("C")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_not_stage_pipeline(): - s = tir.Schedule(matmul_not_stage_pipeline, 
debug_mode=True) - C = s.get_block("C") - _, _, k = s.get_loops(C) + s = tir.Schedule(matmul_not_stage_pipeline, debug_mask="all") + _, _, k = s.get_loops(s.get_block("C")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_not_reduction_block1(): - s = tir.Schedule(element_wise, debug_mode=True) - B = s.get_block("B") - i, _ = s.get_loops(B) + s = tir.Schedule(element_wise, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(i, 0) def test_reduction_rfactor_not_reduction_block2(): - s = tir.Schedule(rowsum_not_quasi_affine, debug_mode=True) - B = s.get_block("B") - _, k = s.get_loops(B) + s = tir.Schedule(rowsum_not_quasi_affine, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_not_reduction_block3(): - s = tir.Schedule(rowsum_not_dominant, debug_mode=True) - B = s.get_block("B") - _, k = s.get_loops(B) + s = tir.Schedule(rowsum_not_dominant, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_not_serial_loop(): - s = tir.Schedule(rowsum_not_serial, debug_mode=True) - B = s.get_block("B") - _, k = s.get_loops(B) + s = tir.Schedule(rowsum_not_serial, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_not_same_buffer_access(): - s = tir.Schedule(matmul_not_same_buffer_access, debug_mode=True) - C = s.get_block("C") - _, _, k = s.get_loops(C) + s = tir.Schedule(matmul_not_same_buffer_access, debug_mask="all") + _, _, k = s.get_loops(s.get_block("C")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) -def test_reduction_rfactor_factor_axis_range(): - s = tir.Schedule(transformed_matmul, debug_mode=True) - C = s.get_block("update") - _, _, _, _, kii = s.get_loops(C) +def 
test_reduction_rfactor_factor_axis_range_fail(): + s = tir.Schedule(transformed_matmul, debug_mask="all") + _, _, _, _, kii = s.get_loops(s.get_block("update")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(kii, 3) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(kii, -4) + +def test_reduction_rfactor_factor_axis_range(): + s = tir.Schedule(transformed_matmul, debug_mask="all") + _, _, _, _, kii = s.get_loops(s.get_block("update")) rf_block = s.rfactor(kii, -3) tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) - - func = tvm.build(s.mod["main"], target="llvm") - a_np = np.random.uniform(size=(128, 128)).astype("float32") - b_np = np.random.uniform(size=(128, 128)).astype("float32") - a = tvm.nd.array(a_np) - b = tvm.nd.array(b_np) - c = tvm.nd.array(np.zeros((128, 128), dtype="float32")) - func(a, b, c) - c_np = np.matmul(a_np, b_np.T) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-4) + verify_trace_roundtrip(s, mod=transformed_matmul) def test_reduction_rfactor_wrong_reduce_pattern1(): - s = tir.Schedule(rowsum_wrong_reduce_pattern1, debug_mode=True) - B = s.get_block("B") - _, k = s.get_loops(B) + s = tir.Schedule(rowsum_wrong_reduce_pattern1, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_wrong_reduce_pattern2(): - s = tir.Schedule(rowsum_wrong_reduce_pattern2, debug_mode=True) - B = s.get_block("B") - _, k = s.get_loops(B) + s = tir.Schedule(rowsum_wrong_reduce_pattern2, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k, 0) def test_reduction_rfactor_wrong_loops1(): - s = tir.Schedule(rowsum, debug_mode=True) - B = s.get_block("B") - i, _ = s.get_loops(B) + s = tir.Schedule(rowsum, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(i, 0) def 
test_reduction_rfactor_wrong_loops2(): - s = tir.Schedule(rowsum_transformed, debug_mode=True) - B = s.get_block("B") - _, _, ki = s.get_loops(B) + s = tir.Schedule(rowsum_transformed, debug_mask="all") + _, _, k_i = s.get_loops(s.get_block("B")) with pytest.raises(tvm.tir.ScheduleError): - s.rfactor(ki, 0) + s.rfactor(k_i, 0) def test_reduction_rfactor_zero_dim(): - s = tir.Schedule(rowsum_zero_dim, debug_mode=True) - B = s.get_block("B") - (k,) = s.get_loops(B) + s = tir.Schedule(rowsum_zero_dim, debug_mask="all") + (k,) = s.get_loops(s.get_block("B")) s.rfactor(k, 0) tvm.ir.assert_structural_equal(s.mod["main"], rowsum_zero_dim_rfactor) + verify_trace_roundtrip(s, mod=rowsum_zero_dim) + - func = tvm.build(s.mod["main"], target="llvm") - a_np = np.random.uniform(size=(128,)).astype("float32") - a = tvm.nd.array(a_np) - b = tvm.nd.array(np.array(1, dtype="float32")) - func(a, b) - b_np = np.array(np.sum(a_np)) - tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-4, atol=1e-4) - - -def test_reduction_rfactor_outermost_loop_multiple_children(): - s = tir.Schedule(multiple_reduction_blocks, debug_mode=True) - D = s.get_block("D") - E = s.get_block("E") - F = s.get_block("F") - _, _, k2o, k2i = s.get_loops(D) - _, _, k3o, k3i = s.get_loops(E) - _, _, k4o, k4i = s.get_loops(F) +def test_reduction_rfactor_outermost_loop_multiple_children_fail(): # pylint: disable=invalid-name + s = tir.Schedule(multiple_reduction_blocks, debug_mask="all") + _, _, k2o, k2i = s.get_loops(s.get_block("D")) + _, _, k3o, k3i = s.get_loops(s.get_block("E")) + _, _, k4o, k4i = s.get_loops(s.get_block("F")) with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k2o, 0) with pytest.raises(tvm.tir.ScheduleError): @@ -657,18 +605,13 @@ def test_reduction_rfactor_outermost_loop_multiple_children(): with pytest.raises(tvm.tir.ScheduleError): s.rfactor(k4i, 0) - C = s.get_block("C") - i, j1, k1o, k1i = s.get_loops(C) + +def test_reduction_rfactor_outermost_loop_multiple_children(): # pylint: 
disable=invalid-name + s = tir.Schedule(multiple_reduction_blocks, debug_mask="all") + _, _, k1o, _ = s.get_loops(s.get_block("C")) s.rfactor(k1o, 2) tvm.ir.assert_structural_equal(s.mod["main"], multiple_reduction_blocks_rfactor) - - func = tvm.build(s.mod["main"], target="llvm") - a_np = np.random.uniform(size=(16, 16, 16)).astype("float32") - a = tvm.nd.array(a_np) - f = tvm.nd.array(np.zeros((16, 16), dtype="float32")) - func(a, f) - f_np = np.sum(a_np, axis=2) * 4369 - tvm.testing.assert_allclose(f.numpy(), f_np, rtol=1e-4, atol=1e-4) + verify_trace_roundtrip(s, mod=multiple_reduction_blocks) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 9ac15b8c19..2284f9d996 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -21,6 +21,7 @@ import tvm from tvm import tir from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip # pylint: disable=no-member,invalid-name,unused-variable @@ -322,55 +323,59 @@ def opaque_access_split(a: ty.handle, b: ty.handle) -> None: def test_fuse(): - sch = tir.Schedule(elementwise, debug_mode=True) + sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") i, j, k = sch.get_loops(block_b) sch.fuse(i, j, k) tvm.ir.assert_structural_equal(elementwise_fused, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) def test_split(): - sch = tir.Schedule(elementwise, debug_mode=True) + sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") i, j, k = sch.get_loops(block_b) sch.split(i, factors=[2, 1, 64]) sch.split(j, factors=[4, 32]) sch.split(k, factors=[16, 8]) tvm.ir.assert_structural_equal(elementwise_split_case0, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) def test_split_with_inferred_factor(): - sch = tir.Schedule(elementwise, debug_mode=True) + 
sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") i, j, k = sch.get_loops(block_b) sch.split(i, factors=[None, 1, 64]) sch.split(j, factors=[2, None, 64]) sch.split(k, factors=[2, 1, None]) tvm.ir.assert_structural_equal(elementwise_split_case1, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) def test_split_with_predicate(): - sch = tir.Schedule(elementwise, debug_mode=True) + sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") i, j, k = sch.get_loops(block_b) sch.split(i, factors=[1000, 2, 3]) sch.split(j, factors=[None, 129]) sch.split(k, factors=[3, None]) tvm.ir.assert_structural_equal(elementwise_split_with_predicate, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) def test_fuse_fail_not_only_child(): - sch = tir.Schedule(elementwise_with_seq, debug_mode=True) + sch = tir.Schedule(elementwise_with_seq, debug_mask="all") block_b = sch.get_block("B") - i, j, k = sch.get_loops(block_b) + _, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.fuse(j, k) def test_fuse_split_fail_with_annotation(): - sch = tir.Schedule(elementwise_with_anno, debug_mode=True) + sch = tir.Schedule(elementwise_with_anno, debug_mask="all") block_b = sch.get_block("B") - i, j, k = sch.get_loops(block_b) + _, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.fuse(j, k) with pytest.raises(tvm.tir.ScheduleError): @@ -378,9 +383,9 @@ def test_fuse_split_fail_with_annotation(): def test_fuse_split_fail_not_start_with_zero(): - sch = tir.Schedule(elementwise_with_anno, debug_mode=True) + sch = tir.Schedule(elementwise_with_anno, debug_mask="all") block_b = sch.get_block("B") - i, j, k = sch.get_loops(block_b) + _, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.fuse(j, k) with pytest.raises(tvm.tir.ScheduleError): @@ -388,15 +393,16 @@ def test_fuse_split_fail_not_start_with_zero(): def test_fuse_with_opaque_block(): - 
sch = tir.Schedule(elementwise_with_opaque_block, debug_mode=True) + sch = tir.Schedule(elementwise_with_opaque_block, debug_mask="all") block_opaque = sch.get_block("opaque") i, j, k = sch.get_loops(block_opaque) sch.fuse(i, j, k) tvm.ir.assert_structural_equal(elementwise_fuse_with_opaque_block, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_with_opaque_block) def test_fuse_with_opaque_access(): - sch = tir.Schedule(opaque_access, debug_mode=True) + sch = tir.Schedule(opaque_access, debug_mask="all") block_a = sch.get_block("A") i, j = sch.get_loops(block_a) sch.fuse(i, j) @@ -404,31 +410,34 @@ def test_fuse_with_opaque_access(): i, j = sch.get_loops(block_b) sch.fuse(i, j) tvm.ir.assert_structural_equal(opaque_access_fused, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=opaque_access) def test_split_with_opaque_block(): - sch = tir.Schedule(elementwise_with_opaque_block, debug_mode=True) + sch = tir.Schedule(elementwise_with_opaque_block, debug_mask="all") block_opaque = sch.get_block("opaque") - i, j, k = sch.get_loops(block_opaque) + i, _, _ = sch.get_loops(block_opaque) sch.split(i, factors=[None, 16]) tvm.ir.assert_structural_equal(elementwise_split_with_opaque_block, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_with_opaque_block) def test_split_with_opaque_access(): - sch = tir.Schedule(opaque_access, debug_mode=True) + sch = tir.Schedule(opaque_access, debug_mask="all") block_a = sch.get_block("A") - i, j = sch.get_loops(block_a) + _, j = sch.get_loops(block_a) sch.split(j, factors=[None, 4]) block_b = sch.get_block("B") - i, j = sch.get_loops(block_b) + _, j = sch.get_loops(block_b) sch.split(j, factors=[None, 4]) tvm.ir.assert_structural_equal(opaque_access_split, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=opaque_access) def test_fuse_split_fail_with_thread_binding(): - sch = tir.Schedule(elementwise_with_thread_binding, debug_mode=True) + sch = tir.Schedule(elementwise_with_thread_binding, 
debug_mask="all") block_b = sch.get_block("B") - i, j, k = sch.get_loops(block_b) + _, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.fuse(j, k) with pytest.raises(tvm.tir.ScheduleError): @@ -436,19 +445,21 @@ def test_fuse_split_fail_with_thread_binding(): def test_fuse_symbolic(): - sch = tir.Schedule(elementwise_symbolic, debug_mode=True) + sch = tir.Schedule(elementwise_symbolic, debug_mask="all") block_b = sch.get_block("B") i, j, k = sch.get_loops(block_b) sch.fuse(i, j, k) tvm.ir.assert_structural_equal(elementwise_symbolic_fused, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_symbolic) def test_split_symbolic(): - sch = tir.Schedule(elementwise_symbolic, debug_mode=True) + sch = tir.Schedule(elementwise_symbolic, debug_mask="all") block_b = sch.get_block("B") - i, j, k = sch.get_loops(block_b) + _, _, k = sch.get_loops(block_b) sch.split(k, factors=[10, None]) tvm.ir.assert_structural_equal(elementwise_symbolic_split, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_symbolic) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_schedule_state.py b/tests/python/unittest/test_tir_schedule_state.py index ca2ee796a2..856d6a5c17 100644 --- a/tests/python/unittest/test_tir_schedule_state.py +++ b/tests/python/unittest/test_tir_schedule_state.py @@ -78,7 +78,7 @@ def block_in_opaque_block(a: ty.handle, b: ty.handle) -> None: def replace_ir_builder(deep_copy=False, realize=False): new_func = tvm.script.from_source(tvm.script.asscript(elementwise)) - s = tir.ScheduleState(new_func, debug_mode=True) + s = tir.ScheduleState(new_func, debug_mask="all") target = tvm.tir.Block( iter_vars=[], reads=[], @@ -106,7 +106,7 @@ def replace_ir_builder_module(deep_copy=False, realize=False): new_func = tvm.script.from_source(tvm.script.asscript(elementwise)) other_func = tvm.script.from_source(tvm.script.asscript(elementwise)) mod = IRModule(functions={"main": new_func, "other": other_func}) - s 
= tir.ScheduleState(mod, debug_mode=True) + s = tir.ScheduleState(mod, debug_mask="all") target = tvm.tir.Block( iter_vars=[], reads=[], @@ -132,7 +132,7 @@ def replace_ir_builder_module(deep_copy=False, realize=False): def replace_ir_builder_with_opaque(): func = tvm.script.from_source(tvm.script.asscript(block_in_opaque_block)) - s = tir.ScheduleState(func, debug_mode=True) + s = tir.ScheduleState(func, debug_mask="all") gc.collect() return s @@ -292,7 +292,7 @@ def test_replace_root_copy3(): def test_replace_block_remap(): func = elementwise - s = tir.ScheduleState(func, debug_mode=True) + s = tir.ScheduleState(func, debug_mask="all") # The target stmt target = matmul.body.block.body.body.body[0].block sref = s.get_sref(s.mod["main"].body.block.body[0].body.body.block) diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index f77ec0318e..075b6cd689 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -282,7 +282,7 @@ def f_visit(node): def test_elementwise(): - s = tir.ScheduleState(elementwise, debug_mode=True) + s = tir.ScheduleState(elementwise, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( affine_binding=True, @@ -303,7 +303,7 @@ def test_elementwise(): def test_matmul(): - s = tir.ScheduleState(matmul, debug_mode=True) + s = tir.ScheduleState(matmul, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "init")) == CachedFlags( affine_binding=True, @@ -324,7 +324,7 @@ def test_matmul(): def test_block_in_opaque_block(): - s = tir.ScheduleState(block_in_opaque_block, debug_mode=True) + s = tir.ScheduleState(block_in_opaque_block, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( affine_binding=True, @@ -355,7 
+355,7 @@ def test_block_in_opaque_block(): def test_write_after_read(): - s = tir.ScheduleState(write_after_read, debug_mode=True) + s = tir.ScheduleState(write_after_read, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( affine_binding=True, @@ -376,7 +376,7 @@ def test_write_after_read(): def test_loop_carried_dependency(): - s = tir.ScheduleState(loop_carried_dependency, debug_mode=True) + s = tir.ScheduleState(loop_carried_dependency, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( affine_binding=True, @@ -397,7 +397,7 @@ def test_loop_carried_dependency(): def test_concatenate_multi_producer_covered(): # pylint: disable=invalid-name - s = tir.ScheduleState(concatenate_multi_producer, debug_mode=True) + s = tir.ScheduleState(concatenate_multi_producer, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "A_0")) == CachedFlags( affine_binding=True, @@ -423,7 +423,7 @@ def test_concatenate_multi_producer_covered(): # pylint: disable=invalid-name def test_concatenate_multi_producer_uncovered(): # pylint: disable=invalid-name - s = tir.ScheduleState(concatenate_multi_producer_uncovered, debug_mode=True) + s = tir.ScheduleState(concatenate_multi_producer_uncovered, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "A_0")) == CachedFlags( affine_binding=True, @@ -449,7 +449,7 @@ def test_concatenate_multi_producer_uncovered(): # pylint: disable=invalid-name def test_lca_at_loop(): - s = tir.ScheduleState(lca_at_loop, debug_mode=True) + s = tir.ScheduleState(lca_at_loop, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags( affine_binding=True, @@ -470,7 +470,7 @@ def test_lca_at_loop(): def test_multi_producer_consumer(): - s = tir.ScheduleState(multi_producer_consumer, debug_mode=True) + 
s = tir.ScheduleState(multi_producer_consumer, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "A_0")) == CachedFlags( affine_binding=True, @@ -496,7 +496,7 @@ def test_multi_producer_consumer(): def test_elementwise_affine_producer(): - s = tir.ScheduleState(elementwise_affine_producer, debug_mode=True) + s = tir.ScheduleState(elementwise_affine_producer, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( affine_binding=True, @@ -517,7 +517,7 @@ def test_elementwise_affine_producer(): def test_subblock(): - s = tir.ScheduleState(elementwise_subblock, debug_mode=True) + s = tir.ScheduleState(elementwise_subblock, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( affine_binding=True, @@ -543,7 +543,7 @@ def test_subblock(): def test_subblock_uncovered(): - s = tir.ScheduleState(elementwise_subblock_uncovered, debug_mode=True) + s = tir.ScheduleState(elementwise_subblock_uncovered, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( affine_binding=True, @@ -569,7 +569,7 @@ def test_subblock_uncovered(): def test_thread_binding(): - s = tir.ScheduleState(bound_to_thread, debug_mode=True) + s = tir.ScheduleState(bound_to_thread, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( affine_binding=True, @@ -590,7 +590,7 @@ def test_thread_binding(): def test_equal_ranked_threads(): - s = tir.ScheduleState(equal_ranked_threads, debug_mode=True) + s = tir.ScheduleState(equal_ranked_threads, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( affine_binding=True, @@ -611,7 +611,7 @@ def test_equal_ranked_threads(): def test_warp_memory(): - s = tir.ScheduleState(warp_memory, debug_mode=True) + s = 
tir.ScheduleState(warp_memory, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( affine_binding=True, @@ -632,7 +632,7 @@ def test_warp_memory(): def test_warp_memory_negative(): - s = tir.ScheduleState(warp_memory_negative, debug_mode=True) + s = tir.ScheduleState(warp_memory_negative, debug_mask="all") # pylint: disable=protected-access assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags( affine_binding=True, diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py index cafc6fe1d2..da7b096ade 100644 --- a/tests/python/unittest/test_tir_schedule_trace.py +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -169,7 +169,7 @@ def test_trace_construct_pop_2(): def test_trace_apply_to_schedule(): trace = _make_trace_2(BlockRV()) - sch = tir.Schedule(elementwise, debug_mode=True) + sch = tir.Schedule(elementwise, debug_mask="all") trace.apply_to_schedule(sch, remove_postproc=False, decision_provider=None) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) @@ -232,7 +232,7 @@ def test_trace_simplified_2(): def test_apply_json_to_schedule_1(): trace = _make_trace_2(BlockRV()) json_obj = trace.as_json() - sch = tir.Schedule(elementwise, debug_mode=True) + sch = tir.Schedule(elementwise, debug_mask="all") Trace.apply_json_to_schedule(json_obj, sch) tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index 07658978db..dcaeaaad61 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -19,9 +19,11 @@ import pytest import tvm + from tvm import tir from tvm.ir import IRModule from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip # pylint: 
disable=no-member,invalid-name,unused-variable @@ -47,8 +49,8 @@ def test_tir_schedule_creation(): # - Schedule.__init__ for PrimFunc and IRModule # - Schedule.mod # - Schedule.state - sch_1 = tir.Schedule(matmul, debug_mode=True) - sch_2 = tir.Schedule(IRModule({"main": matmul}), debug_mode=True) + sch_1 = tir.Schedule(matmul, debug_mask="all") + sch_2 = tir.Schedule(IRModule({"main": matmul}), debug_mask="all") assert sch_1.mod["main"].same_as(sch_2.mod["main"]) assert sch_1.state.mod["main"].same_as(sch_2.state.mod["main"]) @@ -58,7 +60,7 @@ def test_tir_schedule_get_block(): # - Schedule.get_block # - Schedule.get_sref # - Schedule.get - sch = tir.Schedule(matmul, debug_mode=True) + sch = tir.Schedule(matmul, debug_mask="all") block_rv = sch.get_block(name="update") block_sref = sch.get_sref(block_rv) block = sch.get(block_rv) @@ -72,7 +74,7 @@ def test_tir_schedule_get_loops(): # Tests: # - Schedule.get_loops # - Schedule.get - sch = tir.Schedule(matmul, debug_mode=True) + sch = tir.Schedule(matmul, debug_mask="all") block_rv = sch.get_block(name="update") i, j, k = sch.get_loops(block_rv) assert sch.get(i).loop_var.name == "i" @@ -80,10 +82,10 @@ def test_tir_schedule_get_loops(): assert sch.get(k).loop_var.name == "k" -def test_tir_schedule_copy(): +def test_tir_schedule_copy_1(): # Tests: # - Schedule.copy - sch_1 = tir.Schedule(matmul, debug_mode=True) + sch_1 = tir.Schedule(matmul, debug_mask="all") block_rv = sch_1.get_block(name="update") i, j, k = sch_1.get_loops(block_rv) assert sch_1.get(i).loop_var.name == "i" @@ -97,10 +99,40 @@ def test_tir_schedule_copy(): assert sch_2.get(k).loop_var.name == "k" +def test_tir_schedule_copy_2(): + sch = tir.Schedule(mod=matmul, debug_mask="all") + i, j, k = sch.get_loops(sch.get_block("update")) + sch_copy = sch.copy() + assert not sch.get_sref(i).same_as(sch_copy.get_sref(i)) + assert not sch.get_sref(j).same_as(sch_copy.get_sref(j)) + assert not sch.get_sref(k).same_as(sch_copy.get_sref(k)) + assert 
sch.get_sref(i).stmt.same_as(sch_copy.get_sref(i).stmt) + assert sch.get_sref(j).stmt.same_as(sch_copy.get_sref(j).stmt) + assert sch.get_sref(k).stmt.same_as(sch_copy.get_sref(k).stmt) + i_0, i_1 = sch.split(i, factors=[None, 64]) + j_0, j_1 = sch_copy.split(j, factors=[None, 32]) + + assert sch.get_sref(i_0).stmt.extent == 2 + assert sch.get_sref(i_1).stmt.extent == 64 + with pytest.raises(IndexError): + sch_copy.get_sref(i_0) + with pytest.raises(IndexError): + sch_copy.get_sref(i_1) + + with pytest.raises(IndexError): + sch.get_sref(j_0) + with pytest.raises(IndexError): + sch.get_sref(j_1) + assert sch_copy.get_sref(j_0).stmt.extent == 4 + assert sch_copy.get_sref(j_1).stmt.extent == 32 + verify_trace_roundtrip(sch, mod=matmul) + verify_trace_roundtrip(sch_copy, mod=matmul) + + def test_tir_schedule_remove_rv(): # Tests: # - Schedule.remove_rv - sch = tir.Schedule(matmul, debug_mode=True) + sch = tir.Schedule(matmul, debug_mask="all") block_rv = sch.get_block(name="update") assert sch.get(block_rv).name_hint == "update" sch.remove_rv(block_rv)