diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h new file mode 100644 index 000000000000..48eb051d7813 --- /dev/null +++ b/include/tvm/tir/schedule/schedule.h @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_ +#define TVM_TIR_SCHEDULE_SCHEDULE_H_ + +#include + +namespace tvm { +namespace tir { + +/**************** Random variable: BlockRV ****************/ + +/*! \brief A random variable that evaluates to a TensorIR block */ +class BlockRVNode : public runtime::Object { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "tir.BlockRV"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockRVNode, runtime::Object); +}; + +/*! + * \brief Managed reference to BlockRVNode + * \sa BlockRVNode + */ +class BlockRV : public runtime::ObjectRef { + public: + /*! \brief Constructor */ + TVM_DLL BlockRV(); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BlockRV, runtime::ObjectRef, BlockRVNode); +}; + +/**************** Random variable: LoopRV ****************/ + +/*! \brief A random variable that evaluates to a TensorIR for loop */ +class LoopRVNode : public runtime::Object { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "tir.LoopRV"; + TVM_DECLARE_FINAL_OBJECT_INFO(LoopRVNode, runtime::Object); +}; + +/*! + * \brief Managed reference to LoopRVNode + * \sa LoopRVNode + */ +class LoopRV : public runtime::ObjectRef { + public: + /*! \brief Constructor */ + TVM_DLL LoopRV(); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopRV, runtime::ObjectRef, LoopRVNode); +}; + +/**************** Random variable: IntRV ****************/ + +/*! \brief An integer random variable */ +using IntRV = PrimExpr; + +using IntRVNode = PrimExprNode; + +/**************** The Schedule class ****************/ + +class Schedule; + +/*! \brief The user-facing schedule class */ +class ScheduleNode : public runtime::Object { + friend class Schedule; + + public: + virtual ~ScheduleNode() = default; + + static constexpr const char* _type_key = "tir.Schedule"; + TVM_DECLARE_BASE_OBJECT_INFO(ScheduleNode, runtime::Object); + + public: + /*! \brief Get the IRModule associated with this schedule. */ + virtual IRModule mod() const { return state()->mod; } + /*! \return The internal state of scheduling */ + virtual ScheduleState state() const = 0; + /*! + * \brief Returns a copy of the schedule, including both its state and its symbol table, + * guaranteeing that + * 1) SRef tree is completely reconstructed; + * 2) The IRModule being scheduled is not modified; + * 3) All the random variables are valid in the copy, pointing to the correpsonding sref + * reconstructed + */ + virtual Schedule Copy() const = 0; + /*! + * \brief Seed the randomness + * \param seed The new random seed, -1 if use device random, otherwise non-negative + */ + virtual void Seed(int64_t seed = -1) { + LOG(FATAL) << "ValueError: The schedule cannot be seeded because no randomness is allowed"; + } + + public: + /******** Lookup/Remove random variables ********/ + /*! + * \brief Get the block corresponding to the specific BlockRV + * \param block_rv The BlockRV to be looked up + * \return The corresponding block + */ + virtual Block Get(const BlockRV& block_rv) const = 0; + /*! + * \brief Get the for loop corresponding to the specific LoopRV + * \param loop_rv The LoopRV to be looked up + * \return The corresponding for loop + */ + virtual For Get(const LoopRV& loop_rv) const = 0; + /*! + * \brief Get the value corresponding to the specific random variable + * \param int_rv The random variable to be looked up + * \return The corresponding value + */ + virtual int64_t Get(const IntRV& int_rv) const = 0; + /*! + * \brief Get the block sref corresponding to the specific BlockRV + * \param block_rv The BlockRV to be looked up + * \return The corresponding block sref + */ + virtual StmtSRef GetSRef(const BlockRV& block_rv) const = 0; + /*! + * \brief Get the loop sref corresponding to the specific LoopRV + * \param loop_rv The LoopRV to be looked up + * \return The corresponding loop sref + */ + virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0; + /*! + * \brief Get the block/loop sref corresponding to the specific statement + * \param stmt The statement to be looked up + * \return The corresponding block/loop sref + */ + virtual StmtSRef GetSRef(const StmtNode* stmt) const; + /*! + * \brief Get the block/loop sref corresponding to the specific statement + * \param stmt The statement to be looked up + * \return The corresponding block/loop sref + */ + StmtSRef GetSRef(const Stmt& stmt) const { return this->GetSRef(stmt.get()); } + /*! + * \brief Remove a block random variable from the symbol table + * \param block_rv The random variable to be removed + */ + virtual void RemoveRV(const BlockRV& block_rv) = 0; + /*! + * \brief Remove a loop random variable from the symbol table + * \param loop_rv The random variable to be removed + */ + virtual void RemoveRV(const LoopRV& loop_rv) = 0; + /*! + * \brief Remove an integer random variable from the symbol table + * \param int_rv The random variable to be removed + */ + virtual void RemoveRV(const IntRV& int_rv) = 0; + + public: + /******** Block/Loop relation ********/ + /*! + * \brief Retrieve a block in a specific function with its name + * \param name The name of the block to be retrieved + * \param func_name The name of the function + * \return The block retrieved + * \note Indexing error is raised if 0 or multiple blocks exist with the specific name + */ + virtual BlockRV GetBlock(const String& name, const String& func_name = "main") = 0; + /*! + * \brief Get the parent loops of the block in its scope, from outer to inner + * \param block_rv The query block + * \return A list of loops above the given block in its scope, from outer to inner + */ + virtual Array GetLoops(const BlockRV& block_rv) = 0; +}; + +/*! + * \brief Managed reference to ScheduleNode + * + * A schedule is a set of transformations that change the order of computation but + * preserve the semantics of computation. Some example of schedules: + * 1) Split a loop into two; + * 2) Reorder two loops; + * 3) Inline the computation of a specific buffer into its consumer + * + * The schedule class stores auxiliary information to schedule correctly and efficiently. + * + * Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html + * + * \sa ScheduleNode + */ +class Schedule : public runtime::ObjectRef { + public: + /*! + * \brief Construct a concrete TensorIR schedule from an IRModule + * \param mod The IRModule to be scheduled + * \param debug_mode Do extra correctness checking after the class creation + * and each time after calling the Replace method. + * \return The concrete schedule created + * \sa ScheduleDebugMask + * \note The checks performed includes: + * 1) VerifySRefTree + * 2) VerifyAffineBinding + * 3) VerifyRegionCover + * 4) VerifyStagePipeline + */ + TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_SCHEDULE_H_ diff --git a/python/tvm/error.py b/python/tvm/error.py index 819f06475e0a..5502fe8e071e 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -59,6 +59,7 @@ def __init__(self, msg): register_error("TypeError", TypeError) register_error("AttributeError", AttributeError) register_error("KeyError", KeyError) +register_error("IndexError", IndexError) @register_error diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index b348da893d64..afe521a74361 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -48,7 +48,7 @@ from .op import comm_reducer, min, max, sum from .op import q_multiply_shift -from .schedule import StmtSRef, BlockScope, ScheduleState +from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule from . import schedule from . import ir_builder diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py index 21721f70b5bf..bd1f6b3ead03 100644 --- a/python/tvm/tir/schedule/__init__.py +++ b/python/tvm/tir/schedule/__init__.py @@ -19,3 +19,4 @@ from .block_scope import BlockScope, Dependency, DepKind, StmtSRef from .state import ScheduleDebugMask, ScheduleState +from .schedule import LoopRV, BlockRV, IntRV, RAND_VAR_TYPE, Schedule diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py new file mode 100644 index 000000000000..4a8482541c5b --- /dev/null +++ b/python/tvm/tir/schedule/schedule.py @@ -0,0 +1,242 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +"""The TensorIR schedule class""" +from typing import List, Optional, Union + +from tvm._ffi import register_object as _register_object +from tvm.ir import IRModule, PrimExpr +from tvm.runtime import Object +from tvm.tir import Block, For, IntImm, PrimFunc, Var + +from . import _ffi_api_schedule +from .state import ScheduleState, StmtSRef + + +@_register_object("tir.LoopRV") +class LoopRV(Object): + """A random variable that refers to a loop""" + + +@_register_object("tir.BlockRV") +class BlockRV(Object): + """A random variable that refers to a block""" + + +IntRV = PrimExpr # A random variable that evaluates to an integer + +RAND_VAR_TYPE = Union[IntRV, BlockRV, LoopRV] # pylint: disable=invalid-name + + +@_register_object("tir.Schedule") +class Schedule(Object): + """The user-facing schedule class + + A schedule is a set of transformations that change the order of computation but + preserve the semantics of computation. Some example of schedules: + 1) Split a loop into two; + 2) Reorder two loops; + 3) Inline the computation of a specific buffer into its consumer + + The schedule class stores auxiliary information to schedule correctly and efficiently. + + Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html + """ + + def __init__( + self, + func_or_mod: Union[PrimFunc, IRModule], + debug_mode: Union[bool, int] = False, + ): + """Construct a concrete TensorIR schedule from an IRModule or a PrimFunc + + Parameters + ---------- + func_or_mod : Union[PrimFunc, IRModule] + The IRModule or PrimFunc to be scheduled + debug_mode : Union[bool, int] + Do extra correctness checking after the class creation and each time + scheduling primitive + + Note + ---------- + The checks performed includes: + 1) VerifySRefTree + 2) VerifyAffineBinding + 3) VerifyRegionCover + 4) VerifyStagePipeline + """ + if isinstance(debug_mode, bool): + if debug_mode: + debug_mode = -1 + else: + debug_mode = 0 + if not isinstance(debug_mode, int): + raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}") + self.__init_handle_by_constructor__( + _ffi_api_schedule.ConcreteSchedule, # pylint: disable=no-member + func_or_mod, + debug_mode, + ) + + ########## Utilities ########## + + @property + def mod(self) -> IRModule: + """Returns the AST of the module being scheduled""" + return _ffi_api_schedule.ScheduleModule(self) # pylint: disable=no-member + + @property + def state(self) -> ScheduleState: + """Returns the ScheduleState in the current schedule class""" + return _ffi_api_schedule.ScheduleGetState(self) # pylint: disable=no-member + + def copy(self) -> "Schedule": + """Returns a copy of the schedule, including both the state and the symbol table, + * guaranteeing that + * 1) SRef tree is completely reconstructed; + * 2) The IRModule being scheduled is untouched; + * 3) All the random variables are valid in the copy, pointing to the correpsonding sref + * reconstructed + Returns + ------- + copy : Schedule + A new copy of the schedule + """ + return _ffi_api_schedule.ScheduleCopy(self) # pylint: disable=no-member + + def seed(self, seed: int) -> None: + """Seed the randomness + Parameters + ---------- + seed : int + The new random seed, -1 if use device random, otherwise non-negative + """ + return _ffi_api_schedule.ScheduleSeed(self, seed) # pylint: disable=no-member + + def show(self, rand_var: RAND_VAR_TYPE) -> str: + """Returns a string representation of the value that the random variable evaluates to + Parameters + ---------- + rand_var : Union[IntRV, BlockRV, LoopRV] + The random variable to be evaluated + Returns + ---------- + str_repr : str + The string representation + """ + return str(self.get(rand_var)) + + ########## Lookup ########## + + def get( + self, + rand_var_or_sref: Union[RAND_VAR_TYPE, StmtSRef], + ) -> Optional[Union[int, Block, For]]: + """Returns: + - the corresponding Block that a BlockRV evaluates to; + - the corresponding For that a LoopRV evaluates to; + - the corresponding integer that a IntRV evaluates to; + - the corresponding Block that a block sref points to; + - the corresponding For that a loop sref points to; + Parameters + ---------- + rand_var_or_sref : Union[IntRV, BlockRV, LoopRV, StmtSRef] + The random variable / sref to be evaluated + Returns + ---------- + result : Optional[Union[int, Block, For]] + The correpsonding result + """ + if isinstance(rand_var_or_sref, StmtSRef): + return rand_var_or_sref.stmt + result = _ffi_api_schedule.ScheduleGet(self, rand_var_or_sref) # pylint: disable=no-member + if isinstance(result, IntImm): + result = result.value + return result + + def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Optional[StmtSRef]: + """Returns the correpsonding sref to the given + 1) LoopRV + 2) BlockRV + 3) Block + 4) For + Parameters + ---------- + rand_var_or_stmt : Union[BlockRV, LoopRV, Block, For] + The random variable / sref to be evaluated + Returns + ---------- + result : Optional[StmtSRef] + The correpsonding result + """ + return _ffi_api_schedule.ScheduleGetSRef( # pylint: disable=no-member + self, rand_var_or_stmt + ) + + def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None: + """Remove a random variable from the symbol table + Parameters + ---------- + rand_var : Union[BlockRV, LoopRV, IntRV] + The random variable to be removed + """ + return _ffi_api_schedule.ScheduleRemoveRV(self, rand_var) # pylint: disable=no-member + + ########## Block/Loop relation ########## + + def get_block( + self, + name: str, + func_name: str = "main", + ) -> BlockRV: + """Retrieve a block in a specific function with its name + Parameters + ---------- + name : str + The name of the block + func_name : str = "main" + The name of the function + Returns + ---------- + block : BlockRV + The block retrieved + IndexError is raised if 0 or multiple blocks exist with the specific name. + """ + return _ffi_api_schedule.ScheduleGetBlock( # pylint: disable=no-member + self, + name, + func_name, + ) + + def get_loops(self, block: BlockRV) -> List[LoopRV]: + """Get the parent loops of the block in its scope, from outer to inner + Parameters + ---------- + block : BlockRV + The query block + Returns + ---------- + loops : List[LoopRV] + A list of loops above the given block in its scope, from outer to inner + """ + return _ffi_api_schedule.ScheduleGetLoops(self, block) # pylint: disable=no-member + + +@_register_object("tir.ConcreteSchedule") +class ConcreteSchedule(Schedule): + """A concrete schedule class of TensorIR. Do not use directly, use tvm.tir.Schedule instead.""" diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 32d9f6d4cb51..b21139d37e1f 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -33,6 +33,21 @@ namespace tir { void VerifySRefTree(const ScheduleState& self); /******** Block-loop relation ********/ +/*! + * \brief Retrieve blocks in a specific function with its name + * \param self The schedule state + * \param name The name of the blocks to be retrieved + * \param func_name The name of the function + * \return A list of blocks with the specific name + */ +Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name); +/*! + * \brief Get the parent loops of the block in its scope, from outer to inner + * \param self The schedule state + * \param block_sref The query block + * \return A list of loops above the given block in its scope, from outer to inner + */ +Array GetLoops(const StmtSRef& block_sref); /*! * \brief Get the leaf blocks of a scope where a specific block/loop is in * \param self The schedule state diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 005ff373106f..08e7ac749e0f 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -23,6 +23,40 @@ namespace tir { /******** Block-loop relation ********/ +Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name) { + struct Finder : public StmtVisitor { + explicit Finder(const ScheduleState& self, const String& name) : self_(self), name_(name) {} + + void VisitStmt_(const BlockNode* block) override { + if (block->name_hint == name_) { + auto it = self_->stmt2ref.find(block); + ICHECK(it != self_->stmt2ref.end()); + results_.push_back(it->second); + } + StmtVisitor::VisitStmt_(block); + } + + const ScheduleState& self_; + const String& name_; + Array results_; + }; + + BaseFunc func = self->mod->Lookup(func_name); + const auto* prim_func = TVM_TYPE_AS(prim_func, func, PrimFuncNode); + Finder finder(self, name); + finder(prim_func->body); + return std::move(finder.results_); +} + +Array GetLoops(const StmtSRef& block_sref) { + std::vector result; + for (StmtSRefNode* parent = block_sref->parent; parent && parent->stmt->IsInstance(); + parent = parent->parent) { + result.push_back(GetRef(parent)); + } + return {result.rbegin(), result.rend()}; +} + Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) { struct Collector : public StmtVisitor { public: diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc new file mode 100644 index 000000000000..ef12f10fa924 --- /dev/null +++ b/src/tir/schedule/concrete_schedule.cc @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./concrete_schedule.h" + +namespace tvm { +namespace tir { + +Schedule Schedule::Concrete(IRModule mod, int debug_mode) { + ObjectPtr n = make_object(); + n->state_ = ScheduleState(mod, debug_mode); + n->symbol_table_ = {}; + n->analyzer_ = std::make_unique(); + return Schedule(std::move(n)); +} + +/******** Copy ********/ + +/*! \brief Helper class to perform a deep copy of the sref tree */ +class ScheduleCopier { + using TSymbolTable = ConcreteScheduleNode::TSymbolTable; + template + using UMap = std::unordered_map; + template + using SMap = std::unordered_map; + + public: + static void Copy(const ConcreteScheduleNode* self, ScheduleState* new_state, + TSymbolTable* new_symbol_table) { + const ScheduleState& src_state = self->state_; + ScheduleCopier copier(src_state); + ObjectPtr n = make_object(); + n->mod = src_state->mod; + n->block_info = copier.Copy(src_state->block_info); + n->stmt2ref = copier.Copy(src_state->stmt2ref); + n->debug_mode = src_state->debug_mode; + *new_state = ScheduleState(std::move(n)); + *new_symbol_table = copier.Copy(self->symbol_table_); + } + + private: + /*! \brief Create the copier and properly set up the `old2new_` table */ + explicit ScheduleCopier(const ScheduleState& state) { + // Create SRef tree without parents + for (const auto& kv : state->stmt2ref) { + const StmtSRefNode* sref = kv.second.operator->(); + old2new_.emplace(sref, // the old StmtSRef + StmtSRef(/*stmt=*/sref->stmt, // the new StmtSRef + /*parent=*/nullptr, // parent is not set yet + /*seq_index=*/sref->seq_index)); + } + // Fill in the parent field + // Find out the root along the way + for (auto& kv : old2new_) { + const StmtSRefNode* parent = kv.first->parent; + StmtSRef& sref = kv.second; + sref->parent = parent ? old2new_.at(parent).get() : nullptr; + } + } + + /*! \brief Copy StmtSRef */ + StmtSRef Copy(const StmtSRef& sref) { return old2new_.at(sref.operator->()); } + + /*! \brief Copy StmtSRefNode */ + StmtSRef Copy(const StmtSRefNode* sref) { + if (old2new_.count(sref)) { + return old2new_.at(sref); + } + // Handle expired sref + return old2new_[sref] = StmtSRef(nullptr, nullptr, -1); + } + + /*! \brief Copy Array */ + Array Copy(const Array& list) { + Array result; + result.reserve(list.size()); + for (const StmtSRef& elem : list) { + result.push_back(Copy(elem)); + } + return result; + } + + /*! \brief Copy Array */ + Array Copy(const Array& list) { + Array result; + result.reserve(list.size()); + for (const Dependency& elem : list) { + result.push_back(Dependency(Copy(elem->src), Copy(elem->dst), elem->kind)); + } + return result; + } + + /*! \brief Copy SMap> */ + SMap> Copy(const SMap>& map) { + SMap> result; + result.reserve(map.size()); + for (const auto& kv : map) { + result[Copy(kv.first)] = Copy(kv.second); + } + return result; + } + + /*! \brief Copy SMap> */ + SMap> Copy(const SMap>& map) { + SMap> result; + result.reserve(map.size()); + for (const auto& kv : map) { + result[kv.first] = Copy(kv.second); + } + return result; + } + + /*! \brief Copy SMap */ + SMap Copy(const SMap& scopes) { + SMap result; + for (const auto& kv : scopes) { + const StmtSRef& old_sref = kv.first; + const BlockInfo& old_info = kv.second; + BlockInfo new_info = old_info; + ObjectPtr scope = make_object(); + scope->src2deps = Copy(old_info.scope->src2deps); + scope->dst2deps = Copy(old_info.scope->dst2deps); + scope->buffer_writers = Copy(old_info.scope->buffer_writers); + new_info.scope = BlockScope(std::move(scope)); + result[Copy(old_sref)] = std::move(new_info); + } + return result; + } + + /*! \brief Copy the stmt2ref */ + UMap Copy(const UMap& stmt2ref) { + UMap result; + result.reserve(stmt2ref.size()); + for (const auto& kv : stmt2ref) { + const StmtNode* stmt = kv.first; + const StmtSRef& sref = kv.second; + result.emplace(stmt, Copy(sref)); + } + return result; + } + + /*! \brief Copy the symbol table */ + TSymbolTable Copy(const TSymbolTable& tab) { + TSymbolTable result; + for (const auto& kv : tab) { + ObjectRef entry = kv.second; + if (const auto* sref = entry.as()) { + entry = Copy(sref); + } + result.Set(kv.first, entry); + } + return result; + } + + private: + std::unordered_map old2new_; +}; + +void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symbol_table) const { + ScheduleCopier::Copy(this, new_state, new_symbol_table); +} + +Schedule ConcreteScheduleNode::Copy() const { + ObjectPtr n = make_object(); + Copy(&n->state_, &n->symbol_table_); + n->analyzer_ = std::make_unique(); + return Schedule(std::move(n)); +} + +/******** Block/Loop relation ********/ + +BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { + Array blocks = tir::GetBlocks(this->state_, name, func_name); + CHECK_EQ(blocks.size(), 1) << "ValueError: There are " << blocks.size() + << " blocks with the name: " << name; + return CreateRV(blocks[0]); +} + +Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { + return CreateRV(tir::GetLoops(this->GetSRef(block_rv))); +} + +/******** FFI ********/ + +TVM_REGISTER_NODE_TYPE(ConcreteScheduleNode); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h new file mode 100644 index 000000000000..06164fb7babc --- /dev/null +++ b/src/tir/schedule/concrete_schedule.h @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_ +#define TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_ + +#include +#include + +#include +#include + +#include "./utils.h" + +namespace tvm { +namespace tir { + +class ConcreteScheduleNode : public ScheduleNode { + friend class Schedule; + friend class ScheduleCopier; + + public: + using TSymbolTable = Map; + + protected: + /*! \brief The internal state of scheduling */ + ScheduleState state_; + /*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */ + TSymbolTable symbol_table_; + /*! \brief A persistent stateless arithmetic analyzer. */ + std::unique_ptr analyzer_; + + public: + void VisitAttrs(tvm::AttrVisitor* v) { + // `state_` is not visited + // `symbol_table_` is not visited + // `analyzer_` is not visitied + } + + virtual ~ConcreteScheduleNode() = default; + + static constexpr const char* _type_key = "tir.ConcreteSchedule"; + TVM_DECLARE_BASE_OBJECT_INFO(ConcreteScheduleNode, ScheduleNode); + + public: + ScheduleState state() const final { return state_; } + Schedule Copy() const override; + + public: + /******** Lookup random variables ********/ + inline Block Get(const BlockRV& block_rv) const final; + inline For Get(const LoopRV& loop_rv) const final; + inline int64_t Get(const IntRV& int_rv) const final; + inline StmtSRef GetSRef(const BlockRV& block_rv) const final; + inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; + void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } + void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); } + void RemoveRV(const IntRV& int_rv) final { RemoveFromSymbolTable(int_rv); } + using ScheduleNode::GetSRef; + + public: + /******** Block/Loop relation ********/ + BlockRV GetBlock(const String& name, const String& func_name = "main") override; + Array GetLoops(const BlockRV& block_rv) override; + + /******** Utility functions ********/ + protected: + /*! + * \brief Copy the schedule state, as well as the symbol table + * \param new_state The ScheduleState copied + * \param new_symbol_table The symbol table copied + */ + void Copy(ScheduleState* new_state, TSymbolTable* new_symbol_table) const; + /*! + * \brief Add srefs as random variables into the symbol table + * \tparam T The type of the random variables + * \param srefs The srefs to be added to the symbol table + * \return The new random variables created + */ + template + inline Array CreateRV(const Array& srefs); + /*! + * \brief Add an sref as a random variable into the symbol table + * \tparam T The type of the random variable + * \param sref The sref to be added to the symbol table + * \return The new random variable created + */ + template + inline T CreateRV(const StmtSRef& sref); + /*! + * \brief Add an integer as a random variable into the symbol table + * \param number The integer to be added to the symbol table + * \return The new random variable created + */ + inline IntRV CreateRV(int64_t number); + /*! + * \brief Add integers as random variables into the symbol table + * \param numbers The integers to be added to the symbol table + * \return The new random variables created + */ + inline Array CreateRV(const Array& numbers); + /*! \brief Remove a random variable from the symbol table */ + inline void RemoveFromSymbolTable(const ObjectRef& rv); +}; + +// implementations + +/******** Lookup random variables ********/ + +inline Block ConcreteScheduleNode::Get(const BlockRV& block_rv) const { + StmtSRef sref = this->GetSRef(block_rv); + const auto* block = TVM_SREF_TO_BLOCK(block, sref); + return GetRef(block); +} + +inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const { + StmtSRef sref = this->GetSRef(loop_rv); + const auto* loop = TVM_SREF_TO_FOR(loop, sref); + return GetRef(loop); +} + +inline int64_t ConcreteScheduleNode::Get(const IntRV& int_rv) const { + auto it = this->symbol_table_.find(int_rv); + if (it == this->symbol_table_.end()) { + LOG(FATAL) << "IndexError: Cannot find corresponding IntRV: " << int_rv; + } + const ObjectRef& obj = (*it).second; + const auto* int_imm = obj.as(); + if (int_imm == nullptr) { + LOG(FATAL) << "ValueError: IntRV's corresponding type is invalid: " + << (obj.defined() ? obj->GetTypeKey() : "None"); + } + return int_imm->value; +} + +inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { + auto it = this->symbol_table_.find(block_rv); + if (it == this->symbol_table_.end()) { + LOG(FATAL) << "IndexError: Cannot find corresponding BlockRV: " << block_rv; + } + const ObjectRef& obj = (*it).second; + const auto* sref = obj.as(); + if (sref == nullptr) { + LOG(FATAL) << "ValueError: BlockRV's corresponding type is invalid: " + << (obj.defined() ? obj->GetTypeKey() : "None"); + } + if (sref->stmt == nullptr) { + LOG(FATAL) << "ValueError: The StmtSRef has expired"; + } + return GetRef(sref); +} + +inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { + static StmtSRef inline_mark = StmtSRef::InlineMark(); + static StmtSRef root_mark = StmtSRef::RootMark(); + auto it = this->symbol_table_.find(loop_rv); + if (it == this->symbol_table_.end()) { + LOG(FATAL) << "IndexError: Cannot find corresponding LoopRV: " << loop_rv; + } + const ObjectRef& obj = (*it).second; + if (obj.same_as(inline_mark)) { + return inline_mark; + } + if (obj.same_as(root_mark)) { + return root_mark; + } + const auto* sref = obj.as(); + if (sref == nullptr) { + LOG(FATAL) << "ValueError: LoopRV's corresponding type is invalid: " + << (obj.defined() ? obj->GetTypeKey() : "None"); + } + if (sref->stmt == nullptr) { + LOG(FATAL) << "ValueError: The StmtSRef has expired"; + } + return GetRef(sref); +} + +/******** Adding/Removing elements in the symbol table ********/ + +template +inline Array ConcreteScheduleNode::CreateRV(const Array& srefs) { + Array result; + result.reserve(srefs.size()); + for (const StmtSRef& sref : srefs) { + T rv; + this->symbol_table_.Set(rv, sref); + result.push_back(rv); + } + return result; +} + +template +inline T ConcreteScheduleNode::CreateRV(const StmtSRef& sref) { + T rv; + this->symbol_table_.Set(rv, sref); + return rv; +} + +inline IntRV ConcreteScheduleNode::CreateRV(int64_t number) { + Var rv; + this->symbol_table_.Set(rv, Integer(number)); + return std::move(rv); +} + +inline Array ConcreteScheduleNode::CreateRV(const Array& numbers) { + Array result; + result.reserve(numbers.size()); + for (int64_t number : numbers) { + Var rv; + this->symbol_table_.Set(rv, IntImm(DataType::Int(32), number)); + result.push_back(rv); + } + return result; +} + +inline void ConcreteScheduleNode::RemoveFromSymbolTable(const ObjectRef& obj) { + auto it = this->symbol_table_.find(obj); + if (it != this->symbol_table_.end()) { + this->symbol_table_.erase(obj); + } else { + LOG(FATAL) << "IndexError: Cannot find the object in the symbol table: " << obj; + throw; + } +} + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_ diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc new file mode 100644 index 000000000000..9cacbd0a536d --- /dev/null +++ b/src/tir/schedule/schedule.cc @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +namespace tvm { +namespace tir { + +/**************** Constructor ****************/ + +BlockRV::BlockRV() { this->data_ = make_object(); } + +LoopRV::LoopRV() { this->data_ = make_object(); } + +/**************** GetSRef ****************/ + +StmtSRef ScheduleNode::GetSRef(const StmtNode* stmt) const { + ScheduleState state = this->state(); + auto it = state->stmt2ref.find(stmt); + if (it == state->stmt2ref.end()) { + LOG(FATAL) << "IndexError: The stmt doesn't exist in the IR"; + } + return it->second; +} + +/**************** FFI ****************/ + +TVM_REGISTER_NODE_TYPE(BlockRVNode); +TVM_REGISTER_NODE_TYPE(LoopRVNode); +TVM_REGISTER_OBJECT_TYPE(ScheduleNode); + +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleModule") // + .set_body_method(&ScheduleNode::mod); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") // + .set_body_method(&ScheduleNode::state); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // + .set_body_method(&ScheduleNode::Seed); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // + .set_body_method(&ScheduleNode::Copy); + +/**************** (FFI) Constructor ****************/ + +TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") + .set_body_typed([](ObjectRef obj, int debug_mode) -> Schedule { + IRModule mod{nullptr}; + if (const auto* func = obj.as()) { + mod = IRModule({{GlobalVar("main"), GetRef(func)}}); + } else if (const auto* p_mod = obj.as()) { + mod = GetRef(p_mod); + } else { + LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: " + << obj->GetTypeKey(); + } + return Schedule::Concrete(mod, debug_mode); + }); + +/******** (FFI) Lookup random variables ********/ + +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGet") + .set_body_typed([](Schedule self, ObjectRef obj) -> ObjectRef { + if (const auto* loop_rv = obj.as()) { + return self->Get(GetRef(loop_rv)); + } + if (const auto* block_rv = obj.as()) { + return self->Get(GetRef(block_rv)); + } + if (const auto* int_rv = obj.as()) { + int64_t result = self->Get(GetRef(int_rv)); + return IntImm(DataType::Int(32), result); + } + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << obj->GetTypeKey() + << ". Its value is: " << obj; + throw; + }); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetSRef") + .set_body_typed([](Schedule self, ObjectRef obj) -> Optional { + if (const auto* loop_rv = obj.as()) { + return self->GetSRef(GetRef(loop_rv)); + } + if (const auto* block_rv = obj.as()) { + return self->GetSRef(GetRef(block_rv)); + } + if (const auto* stmt = obj.as()) { + return self->GetSRef(GetRef(stmt)); + } + LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); + throw; + }); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") + .set_body_typed([](Schedule self, ObjectRef obj) -> void { + if (const auto* loop_rv = obj.as()) { + return self->RemoveRV(GetRef(loop_rv)); + } + if (const auto* block_rv = obj.as()) { + return self->RemoveRV(GetRef(block_rv)); + } + if (const auto* int_rv = obj.as()) { + return self->RemoveRV(GetRef(int_rv)); + } + LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); + throw; + }); + +/***** (FFI) Block/Loop relation *****/ + +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") + .set_body_method(&ScheduleNode::GetBlock); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops") + .set_body_method(&ScheduleNode::GetLoops); + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py new file mode 100644 index 000000000000..af89ca252738 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +from tvm import tir +from tvm.ir import IRModule +from tvm.script import ty + + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "init") as [vi, vj]: + C[vi, vj] = tir.float32(0) + for k in range(0, 128): + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_tir_schedule_creation(): + # Tests: + # - Schedule.__init__ for PrimFunc and IRModule + # - Schedule.mod + # - Schedule.state + sch_1 = tir.Schedule(matmul, debug_mode=True) + sch_2 = tir.Schedule(IRModule({"main": matmul}), debug_mode=True) + assert sch_1.mod["main"].same_as(sch_2.mod["main"]) + assert sch_1.state.mod["main"].same_as(sch_2.state.mod["main"]) + + +def test_tir_schedule_get_block(): + # Tests: + # - Schedule.get_block + # - Schedule.get_sref + # - Schedule.get + sch = tir.Schedule(matmul, debug_mode=True) + block_rv = sch.get_block(name="update") + block_sref = sch.get_sref(block_rv) + block = sch.get(block_rv) + assert block.name_hint == "update" + assert block_sref.stmt.same_as(block) + assert sch.state.get_sref(block).same_as(block_sref) + assert block.same_as(matmul.body.block.body.body.body[1].body.block) + + +def test_tir_schedule_get_loops(): + # Tests: + # - Schedule.get_loops + # - Schedule.get + sch = tir.Schedule(matmul, debug_mode=True) + block_rv = sch.get_block(name="update") + i, j, k = sch.get_loops(block_rv) + assert sch.get(i).loop_var.name == "i" + assert sch.get(j).loop_var.name == "j" + assert sch.get(k).loop_var.name == "k" + + +def test_tir_schedule_copy(): + # Tests: + # - Schedule.copy + sch_1 = tir.Schedule(matmul, debug_mode=True) + block_rv = sch_1.get_block(name="update") + i, j, k = sch_1.get_loops(block_rv) + assert sch_1.get(i).loop_var.name == "i" + assert sch_1.get(j).loop_var.name == "j" + assert sch_1.get(k).loop_var.name == "k" + + sch_2 = sch_1.copy() + assert sch_2.get(block_rv).name_hint == "update" + assert sch_2.get(i).loop_var.name == "i" + assert sch_2.get(j).loop_var.name == "j" + assert sch_2.get(k).loop_var.name == "k" + + +def test_tir_schedule_remove_rv(): + # Tests: + # - Schedule.remove_rv + sch = tir.Schedule(matmul, debug_mode=True) + block_rv = sch.get_block(name="update") + assert sch.get(block_rv).name_hint == "update" + sch.remove_rv(block_rv) + with pytest.raises(IndexError): + sch.get(block_rv) + + +if __name__ == "__main__": + test_tir_schedule_creation() + test_tir_schedule_get_block() + test_tir_schedule_get_loops() + test_tir_schedule_copy() + test_tir_schedule_remove_rv()