From 8bd7a05f23323127eb0dfd088d83d3c30ddc65e9 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 12 Jul 2022 12:06:24 -0700 Subject: [PATCH] address comments --- include/tvm/meta_schedule/schedule_rule.h | 2 -- .../schedule_rule/multi_level_tiling.py | 2 -- .../postproc/rewrite_reduction_block.cc | 14 ++++---- .../postproc/rewrite_tensorize.cc | 32 ++++++++++--------- .../schedule_rule/multi_level_tiling.h | 25 --------------- .../multi_level_tiling_tensor_core.cc | 24 ++++++++++++++ ...eta_schedule_postproc_rewrite_tensorize.py | 2 +- 7 files changed, 50 insertions(+), 51 deletions(-) diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 2677864b44691..5e4698db17858 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -182,10 +182,8 @@ class ScheduleRule : public runtime::ObjectRef { * the map should be names of tensor intrinsics, must be registerd via TensorIntrin.register(...) * beforehand * \param structure The tiling structure. Recommended: - * - 'SSRSRS' on CPU * - 'SSSRRSRS' on GPU * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended: - * - NullOpt on CPU * - [blockIdx.y, blockIdx.x, threadIdx.y] on GPU * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit * \param vector_load_lens The length of vector lane in vectorized cooperative fetching. diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py index 9e455f5af4ed3..71fbaee4f60bf 100644 --- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py +++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py @@ -148,11 +148,9 @@ class MultiLevelTilingTensorCore(ScheduleRule): TensorIntrin.register(...) beforehand structure : str The tiling structure. 
Recommended: - - 'SSRSRS' on CPU - 'SSSRRSRS' on GPU tile_bind : Optional[List[str]] For each level of tiles, which thread axis it is bound to. Recommended: - - None on CPU - [blockIdx.y, vthread.x, threadIdx.y] on GPU max_innermost_factor : Optional[int] The maximum size of the innermost factor. None means no limit diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc index a31e6204f6d47..ea204e3061336 100644 --- a/src/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -140,15 +140,17 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize).defined()) { // Remove tensorization annotation as it shouldn't be propagated to the init block. sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize); - } - if (Optional<String> tensorize_init = - tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize_init)) { + Optional<String> tensorize_init = + tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize_init); // The annotation of tensorization of the init statement should be moved to the init block // after 'DecomposeReduction'. + // Annotate to hint `RewriteTensorize` postprocessor even if tensorize_init is NullOpt. 
sch->Annotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize, - tensorize_init.value()); - sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize_init); - sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize_init); + tensorize_init.value_or("")); + if (tensorize_init.defined()) { + sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize_init); + sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize_init); + } } ++rewritten; } diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc index e4ec270505560..3951ff2e9f2c8 100644 --- a/src/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/meta_schedule/postproc/rewrite_tensorize.cc @@ -38,21 +38,23 @@ void CollectTensorizationJobs( std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint; if (Optional<String> intrin_name = tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) { - jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) { - try { - sch->Tensorize(block, intrin_name.value()); - } catch (const std::exception& e) { - LOG(WARNING) << "Tensorize failed with error " << e.what(); - } - }); - } else if (block_name.find("init") && vectorize_init_loop) { - jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) { - Array<tir::BlockRV> child_blocks = sch->GetChildBlocks(block); - ICHECK(child_blocks.size() == 1); - Array<tir::LoopRV> init_loops = sch->GetLoops(child_blocks[0]); - ICHECK(init_loops.size() == 1); - sch->Vectorize(init_loops[0]); - }); + if (intrin_name.value() != "") { + jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) { + try { + sch->Tensorize(block, intrin_name.value()); + } catch (const std::exception& e) { + LOG(WARNING) << "Tensorize failed with error " << e.what(); + } + }); + } else if (block_name.find("init") && vectorize_init_loop) { + jobs->emplace_back(block_name, func_name, [sch,block_name](tir::BlockRV block) { + Array<tir::BlockRV> 
child_blocks = sch->GetChildBlocks(block); + ICHECK(child_blocks.size() == 1) << block_name << child_blocks; + Array<tir::LoopRV> init_loops = sch->GetLoops(child_blocks[0]); + ICHECK(init_loops.size() == 1); + sch->Vectorize(init_loops[0]); + }); + } } } }); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index 36c2efdafbef1..33982556741b5 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -117,31 +117,6 @@ class State : public ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); }; -class TensorCoreStateNode : public StateNode { - public: - /*! \brief The Tensor Core reindex block A for Tensor Core computation */ - tir::BlockRV tensor_core_reindex_A; - /*! \brief The Tensor Core reindex block B for Tensor Core computation */ - tir::BlockRV tensor_core_reindex_B; - /*! \brief The Tensor Core reindex store block for Tensor Core computation */ - tir::BlockRV tensor_core_reindex_store; - - State Copy() const final; - - static constexpr const char* _type_key = "meta_schedule.TensorCoreState"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode); -}; - -class TensorCoreState : public State { - public: - explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv, - Array<Array<tir::LoopRV>> tiles = {}); - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode); -}; - -struct AutoTensorizationState : public State {}; - -/*! 
* \brief Helper to apply a sub-rule to a list of auto scheduling states * \tparam FLambda The type of the sub-rule functor diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 7cf23fa8ad938..c41906d9713ff 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -40,6 +40,30 @@ struct TensorCoreIntrinGroup { String store_intrin; }; +class TensorCoreStateNode : public StateNode { + public: + /*! \brief The Tensor Core reindex block A for Tensor Core computation */ + tir::BlockRV tensor_core_reindex_A; + /*! \brief The Tensor Core reindex block B for Tensor Core computation */ + tir::BlockRV tensor_core_reindex_B; + /*! \brief The Tensor Core reindex store block for Tensor Core computation */ + tir::BlockRV tensor_core_reindex_store; + + State Copy() const final; + + static constexpr const char* _type_key = "meta_schedule.TensorCoreState"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode); +}; + +class TensorCoreState : public State { + public: + explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv, + Array<Array<tir::LoopRV>> tiles = {}); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode); +}; + + TVM_REGISTER_OBJECT_TYPE(TensorCoreStateNode); TensorCoreState::TensorCoreState(Schedule sch, BlockRV block_rv, Array<Array<tir::LoopRV>> tiles) { diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py index 6fae11c7fd547..a1184c1edfe77 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py @@ -361,7 +361,7 @@ def main( ) T.reads() T.writes(compute_local[i, j]) - T.block_attr({"meta_schedule.auto_tensorize": "dp4a"}) + 
T.block_attr({"meta_schedule.auto_tensorize": ""}) with T.block("compute_init"): T.reads() T.writes(compute_local[i, j])