diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index b662c8a177ec2..8ea2019bb0d9a 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -18,8 +18,6 @@ */ #include "multi_level_tiling.h" -#include - #include "../utils.h" namespace tvm { diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index ac2bda83c66cb..279bae4734287 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -16,9 +16,12 @@ * specific language governing permissions and limitations * under the License. */ -#include +#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_ +#define TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_ -#include "../utils.h" +#include +#include +#include "../../support/array.h" namespace tvm { namespace meta_schedule { @@ -206,3 +209,5 @@ ScheduleRule MultiLevelTilingInitCommon(String structure, Optional } // namespace meta_schedule } // namespace tvm + +#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_ diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_vnni.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_vnni.cc index 45fb14fb4b757..79fb9df9086c0 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_vnni.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_vnni.cc @@ -16,180 +16,22 @@ * specific language governing permissions and limitations * under the License. */ -#include #include "../utils.h" #include "multi_level_tiling.h" +#include "../../tir/schedule/analysis.h" namespace tvm { namespace meta_schedule { using tir::LoopRV; -/*! \brief Necessary information used for tensorization */ -class TensorizeInfoNode : public Object { - public: - /*! \brief Maps block loops to desc loops */ - Map loop_map; - /*! \brief Maps loops in desc to its index, outer to inner */ - Map desc_loop_indexer; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("loop_map", &loop_map); - v->Visit("desc_loop_indexer", &desc_loop_indexer); - } - - static constexpr const char* _type_key = "tir.analysis.TensorizeInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object); -}; - -class TensorizeInfo : public ObjectRef { - public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode); -}; - -TVM_REGISTER_NODE_TYPE(TensorizeInfoNode); - -Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, - const tir::StmtSRef& block_sref, - const tir::PrimFunc& desc_func) { - // Try to do tiling automatically if possible - // Now the heuristic is that if block's block var binding is constant + loop var, - // in other words, with tir.block(..., vi=Ci+i, vj=Cj+j, vk=Ck+k), then we split and reorder - // i, j, k according to the loops outside desc_block - // Collect the loops outside block - arith::Analyzer analyzer; - const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); - // Step 1. Analyze desc_func, extract its block, loops and loop vars - const tir::BlockRealizeNode* desc_block = nullptr; - std::vector desc_loops; - std::unordered_set desc_loop_vars; - const auto* desc_scope_realize = desc_func->body.as(); - ICHECK(desc_scope_realize); - { - auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars, - &analyzer](const ObjectRef& obj) -> bool { - // Extract the block - if (const auto* block = obj.as()) { - desc_block = block; - return false; - } - // Extract the loops - if (const auto* loop = obj.as()) { - desc_loops.push_back(loop); - desc_loop_vars.insert(loop->loop_var.get()); - if (!analyzer.CanProve(loop->min == 0)) { - return false; - } - } - return true; - }; - tir::PostOrderVisit(desc_scope_realize->block->body, f_visit); - std::reverse(desc_loops.begin(), desc_loops.end()); - ICHECK(desc_block); - } - // Step 2. Check if `desc_block` matches `block` - // Ignore the scope of buffers when comparing, since we can do cache_read/write - const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); - const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); - std::vector block_loops; - std::unordered_set block_loop_vars; - { - for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) { - const auto* loop = loop_sref->StmtAs(); - if (loop == nullptr || loop->body->IsInstance()) { - break; - } - block_loops.push_back(loop); - block_loop_vars.insert(loop->loop_var.get()); - if (!analyzer.CanProve(loop->min == 0)) { - return NullOpt; - } - } - std::reverse(block_loops.begin(), block_loops.end()); - } - // Step 4. Map from block loops to desc block loops - ObjectPtr ret = make_object(); - int n_block_vars = block->iter_values.size(); - int n_desc_vars = desc_block->iter_values.size(); - int offset = n_block_vars - n_desc_vars; - if (offset < 0) { - return NullOpt; - } - // We align the block and desc block's bindings from the right side - // block (v0=..., v1=..., v2=...) - // ^ i_block - // desc_block( v1=..., v2=...) - // ^ i_desc - for (int i_desc = 0, i_block = offset; i_desc < n_desc_vars; ++i_desc, ++i_block) { - // For each block var binding, we find - const PrimExpr& block_bind = block->iter_values[i_block]; - const PrimExpr& desc_bind = desc_block->iter_values[i_desc]; - // Step 4.1. Find the corresponding loop of the i-th block var of block - const tir::ForNode* block_loop = nullptr; - for (int i = 0, n = block_loops.size(); i < n; ++i) { - // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars - PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var); - if (!tir::UsesVar(r, [&block_loop_vars](const tir::VarNode* var) { - return block_loop_vars.count(var); - })) { - block_loop = block_loops[i]; - break; - } - } - if (block_loop == nullptr) { - return NullOpt; - } - // Step 4.2. Find the corresponding loop of the i-th block var of desc - const tir::ForNode* desc_loop = nullptr; - for (int i = 0, n = desc_loops.size(); i < n; ++i) { - // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars - PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var); - if (!tir::UsesVar(r, [&desc_loop_vars](const tir::VarNode* var) { - return desc_loop_vars.count(var); - })) { - desc_loop = desc_loops[i]; - break; - } - } - if (block_loop == nullptr) { - return NullOpt; - } - // Step 4.3. Check divisibility of loop extents - PrimExpr block_extent = analyzer.Simplify(block_loop->extent); - PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent); - if (const auto* int_block_extent = block_extent.as()) { - if (const auto* int_desc_extent = desc_extent.as()) { - if (int_block_extent->value % int_desc_extent->value != 0) { - return NullOpt; - } - } else { - return NullOpt; - } - } else { - return NullOpt; - } - // Step 4.4. Maps the result of Step 4.1 to Step 4.2 - const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; - auto it = ret->loop_map.find(block_loop_sref); - if (it == ret->loop_map.end()) { - ret->loop_map.Set(block_loop_sref, GetRef(desc_loop)); - } else if ((*it).second.get() != desc_loop) { - return NullOpt; - } - } - for (int i = 0, n = desc_loops.size(); i < n; ++i) { - ret->desc_loop_indexer.Set(GetRef(desc_loops[i]), Integer(i)); - } - return TensorizeInfo(ret); -} - Optional TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, const String& intrin_name) { - Optional opt_tensorize_info = GetTensorizeLoopMapping( + Optional opt_tensorize_info = GetTensorizeLoopMapping( sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc); if (!opt_tensorize_info) return NullOpt; - const TensorizeInfoNode* info = opt_tensorize_info.value().get(); + const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get(); // Construct a mapping from tir loops back to LoopRVs Map loop2rv; { diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index e74b9ea264845..6ec5b34e03868 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -645,6 +645,32 @@ Array AnalyzeRegionLowerBound(const BufferRegion& region, const P const StmtSRef& dom_high_exclusive, arith::Analyzer* analyzer); +/*! \brief Necessary information used for tensorization */ +class TensorizeInfoNode : public Object { + public: + /*! \brief Maps block loops to desc loops */ + Map loop_map; + /*! \brief Maps loops in desc to its index, outer to inner */ + Map desc_loop_indexer; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("loop_map", &loop_map); + v->Visit("desc_loop_indexer", &desc_loop_indexer); + } + + static constexpr const char* _type_key = "tir.analysis.TensorizeInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object); +}; + +class TensorizeInfo : public ObjectRef { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode); +}; + +Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 435870471f29f..e2d8dc76969df 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1992,5 +1992,141 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // } } +TVM_REGISTER_NODE_TYPE(TensorizeInfoNode); + +Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func) { + // Try to do tiling automatically if possible + // Now the heuristic is that if block's block var binding is constant + loop var, + // in other words, with tir.block(..., vi=Ci+i, vj=Cj+j, vk=Ck+k), then we split and reorder + // i, j, k according to the loops outside desc_block + // Collect the loops outside block + arith::Analyzer analyzer; + const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); + // Step 1. Analyze desc_func, extract its block, loops and loop vars + const tir::BlockRealizeNode* desc_block = nullptr; + std::vector desc_loops; + std::unordered_set desc_loop_vars; + const auto* desc_scope_realize = desc_func->body.as(); + ICHECK(desc_scope_realize); + { + auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars, + &analyzer](const ObjectRef& obj) -> bool { + // Extract the block + if (const auto* block = obj.as()) { + desc_block = block; + return false; + } + // Extract the loops + if (const auto* loop = obj.as()) { + desc_loops.push_back(loop); + desc_loop_vars.insert(loop->loop_var.get()); + if (!analyzer.CanProve(loop->min == 0)) { + return false; + } + } + return true; + }; + tir::PostOrderVisit(desc_scope_realize->block->body, f_visit); + std::reverse(desc_loops.begin(), desc_loops.end()); + ICHECK(desc_block); + } + // Step 2. Check if `desc_block` matches `block` + // Ignore the scope of buffers when comparing, since we can do cache_read/write + const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); + const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + std::vector block_loops; + std::unordered_set block_loop_vars; + { + for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) { + const auto* loop = loop_sref->StmtAs(); + if (loop == nullptr || loop->body->IsInstance()) { + break; + } + block_loops.push_back(loop); + block_loop_vars.insert(loop->loop_var.get()); + if (!analyzer.CanProve(loop->min == 0)) { + return NullOpt; + } + } + std::reverse(block_loops.begin(), block_loops.end()); + } + // Step 4. Map from block loops to desc block loops + ObjectPtr ret = make_object(); + int n_block_vars = block->iter_values.size(); + int n_desc_vars = desc_block->iter_values.size(); + int offset = n_block_vars - n_desc_vars; + if (offset < 0) { + return NullOpt; + } + // We align the block and desc block's bindings from the right side + // block (v0=..., v1=..., v2=...) + // ^ i_block + // desc_block( v1=..., v2=...) + // ^ i_desc + for (int i_desc = 0, i_block = offset; i_desc < n_desc_vars; ++i_desc, ++i_block) { + // For each block var binding, we find + const PrimExpr& block_bind = block->iter_values[i_block]; + const PrimExpr& desc_bind = desc_block->iter_values[i_desc]; + // Step 4.1. Find the corresponding loop of the i-th block var of block + const tir::ForNode* block_loop = nullptr; + for (int i = 0, n = block_loops.size(); i < n; ++i) { + // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars + PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var); + if (!tir::UsesVar(r, [&block_loop_vars](const tir::VarNode* var) { + return block_loop_vars.count(var); + })) { + block_loop = block_loops[i]; + break; + } + } + if (block_loop == nullptr) { + return NullOpt; + } + // Step 4.2. Find the corresponding loop of the i-th block var of desc + const tir::ForNode* desc_loop = nullptr; + for (int i = 0, n = desc_loops.size(); i < n; ++i) { + // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars + PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var); + if (!tir::UsesVar(r, [&desc_loop_vars](const tir::VarNode* var) { + return desc_loop_vars.count(var); + })) { + desc_loop = desc_loops[i]; + break; + } + } + if (block_loop == nullptr) { + return NullOpt; + } + // Step 4.3. Check divisibility of loop extents + PrimExpr block_extent = analyzer.Simplify(block_loop->extent); + PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent); + if (const auto* int_block_extent = block_extent.as()) { + if (const auto* int_desc_extent = desc_extent.as()) { + if (int_block_extent->value % int_desc_extent->value != 0) { + return NullOpt; + } + } else { + return NullOpt; + } + } else { + return NullOpt; + } + // Step 4.4. Maps the result of Step 4.1 to Step 4.2 + const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; + auto it = ret->loop_map.find(block_loop_sref); + if (it == ret->loop_map.end()) { + ret->loop_map.Set(block_loop_sref, GetRef(desc_loop)); + } else if ((*it).second.get() != desc_loop) { + return NullOpt; + } + } + for (int i = 0, n = desc_loops.size(); i < n; ++i) { + ret->desc_loop_indexer.Set(GetRef(desc_loops[i]), Integer(i)); + } + return TensorizeInfo(ret); +} + } // namespace tir } // namespace tvm