diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 5e4698db1785..b5f4a17b698d 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -174,13 +174,13 @@ class ScheduleRule : public runtime::ObjectRef { Optional> reuse_read, Optional> reuse_write); /*! - * \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core - * intrinsics - * \param intrin_group A group of tensor core intrinsics. The map should contains key "init", - * "load_a", "load_b", "compute", "store", which represent the tensor intrin for initialization, - * loading operand A, loading operand B, tensor core computation, storing the result. The value of - * the map should be names of tensor intrinsics, must be registerd via TensorIntrin.register(...) - * beforehand + * \brief Extension of MultiLevelTiling for auto-tensorizing with multiple groups of candidate + * tensor core intrinsics + * \param intrin_groups A list of groups of tensor core intrinsics. The map should contains key + * "init", "load_a", "load_b", "compute", "store", which represent the tensor intrin for + * initialization, loading operand A, loading operand B, tensor core computation, storing the + * result. The value of the map should be names of tensor intrinsics, must be registerd via + * TensorIntrin.register(...) beforehand * \param structure The tiling structure. Recommended: * - 'SSSRRSRS' on GPU * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended: @@ -193,9 +193,10 @@ class ScheduleRule : public runtime::ObjectRef { * \return The schedule rule created */ TVM_DLL static ScheduleRule MultiLevelTilingTensorCore( - Map intrin_group, String structure, Optional> tile_binds, - Optional max_innermost_factor, Optional> vector_load_lens, - Optional> reuse_read, Optional> reuse_write); + Array> intrin_groups, String structure, + Optional> tile_binds, Optional max_innermost_factor, + Optional> vector_load_lens, Optional> reuse_read, + Optional> reuse_write); /*! * \brief Create a rule: add-rfactor to some blocks if needed diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py index 71fbaee4f60b..a728a91eb74e 100644 --- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py +++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py @@ -135,15 +135,15 @@ def __init__( @register_object("meta_schedule.MultiLevelTilingTensorCore") class MultiLevelTilingTensorCore(ScheduleRule): - """Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core - intrinsics. + """Extension of MultiLevelTiling for auto-tensorizing with multiple groups of candidate tensor + core intrinsics. Parameters ---------- - intrin_group : Mapping[str, str] - A group of tensor core intrinsics. The map should contains key "init", "load_a", "load_b", - "compute", "store", which represent the tensor intrin for initialization, loading operand A, - loading operand B, tensor core computation, storing the result. + intrin_groups : List[Mapping[str, str]] + A list of groups of tensor core intrinsics. The map should contains key "init", "load_a", + "load_b", "compute", "store", which represent the tensor intrin for initialization, + loading operand A, loading operand B, tensor core computation, storing the result. The value of the map should be names of tensor intrinsics, must be registerd via TensorIntrin.register(...) beforehand structure : str @@ -165,7 +165,7 @@ class MultiLevelTilingTensorCore(ScheduleRule): def __init__( self, - intrin_group: Mapping[str, str], + intrin_groups: List[Mapping[str, str]], structure: str, tile_binds: Optional[List[str]] = None, max_innermost_factor: Optional[int] = None, @@ -175,7 +175,7 @@ def __init__( ) -> None: self.__init_handle_by_constructor__( _ffi_api.ScheduleRuleMultiLevelTilingTensorCore, # type: ignore # pylint: disable=no-member - intrin_group, + intrin_groups, structure, tile_binds, max_innermost_factor, diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py index 717be5951240..ea748ddc0538 100644 --- a/python/tvm/meta_schedule/testing/schedule_rule.py +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Default schedule rules""" +from typing import List, Union from tvm.meta_schedule.schedule_rule import ( AddRFactor, AutoBind, @@ -114,18 +115,29 @@ def multi_level_tiling(target: Target) -> ScheduleRule: def multi_level_tiling_tensor_core( target: Target, - write_reuse_scope="shared", - in_dtype="float16", - out_dtype="float32", - trans_b=False, + write_reuse_scope: str = "shared", + in_dtype: Union[str, List[str]] = "float16", + out_dtype: Union[str, List[str]] = "float32", + trans_b: Union[bool, List[bool]] = False, ) -> ScheduleRule: """Default schedule rules for with multi-level tiling reuse for tensor core""" assert write_reuse_scope in ["shared", "global"] + if not isinstance(in_dtype, list): + in_dtype = [in_dtype] + if not isinstance(out_dtype, list): + out_dtype = [out_dtype] + if not isinstance(trans_b, list): + trans_b = [trans_b] + if target.kind.name == "cuda": + intrin_groups = [ + tensor_intrin.get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b) + for _in_dtype in in_dtype + for _out_dtype in out_dtype + for _trans_b in trans_b + ] return MultiLevelTilingTensorCore( - intrin_group=tensor_intrin.get_wmma_intrin_group( - write_reuse_scope, in_dtype, out_dtype, trans_b - ), + intrin_groups=intrin_groups, structure="SSSRRSRS", tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"], max_innermost_factor=4, # 64 // tensor intrin size diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 91df62fc3663..6d34f7b64e34 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -38,10 +38,42 @@ struct TensorCoreIntrinGroup { String load_b_intrin; String compute_intrin; String store_intrin; + + /*! \brief Create TensorCoreIntrinGroup from config in a map. The map should contains the + * following keys: + * - init + * - load_a + * - load_b + * - compute + * - store + * The values of the keys should be the names of the corresponding intrinsics and should be + * registered via TensorIntrin.Register beforehand. + */ + static TensorCoreIntrinGroup FromConfig(const Map& config); }; +TensorCoreIntrinGroup TensorCoreIntrinGroup::FromConfig(const Map& config) { + auto f_initialize_intrin = [&config](String key_name, String* intrin_name) { + CHECK(config.count(key_name)) << "ValueError: " << key_name << " is not set."; + *intrin_name = config.at(key_name); + // Check the existence of the intrin + tir::TensorIntrin::Get(*intrin_name); + }; + TensorCoreIntrinGroup intrin_group; + f_initialize_intrin("init", &intrin_group.init_intrin); + f_initialize_intrin("load_a", &intrin_group.load_a_intrin); + f_initialize_intrin("load_b", &intrin_group.load_b_intrin); + f_initialize_intrin("compute", &intrin_group.compute_intrin); + f_initialize_intrin("store", &intrin_group.store_intrin); + return intrin_group; +} + class TensorCoreStateNode : public StateNode { public: + /*! \brief The tensor core intrinsic group. */ + TensorCoreIntrinGroup intrin_group; + /*! \brief The auto tensorization maping info. */ + tir::AutoTensorizeMappingInfo mapping_info{nullptr}; /*! \brief The Tensor Core reindex block A for Tensor Core computation */ tir::BlockRV tensor_core_reindex_A; /*! \brief The Tensor Core reindex block B for Tensor Core computation */ @@ -57,16 +89,21 @@ class TensorCoreStateNode : public StateNode { class TensorCoreState : public State { public: - explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv, - Array> tiles = {}); + explicit TensorCoreState(TensorCoreIntrinGroup intrin_group, + tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, + BlockRV block_rv, Array> tiles = {}); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode); }; TVM_REGISTER_OBJECT_TYPE(TensorCoreStateNode); -TensorCoreState::TensorCoreState(Schedule sch, BlockRV block_rv, Array> tiles) { +TensorCoreState::TensorCoreState(TensorCoreIntrinGroup intrin_group, + tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, + BlockRV block_rv, Array> tiles) { ObjectPtr node = make_object(); + node->intrin_group = intrin_group; + node->mapping_info = mapping_info; node->sch = std::move(sch); node->block_rv = std::move(block_rv); node->tiles = std::move(tiles); @@ -116,16 +153,12 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { const String& intrin_name) const; public: - /*! \brief The tensor core intrin group to apply */ - TensorCoreIntrinGroup intrin_group; + /*! \brief The candidate tensor core intrin groups to apply */ + std::vector intrin_groups; static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingTensorCore"; TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingTensorCoreNode, MultiLevelTilingNode); private: - /*! - * \brief The mapping info for auto tensorization - */ - tir::AutoTensorizeMappingInfo mapping_info_{nullptr}; }; // Entry of the mega rule; Inherited from ScheduleRuleNode @@ -135,21 +168,36 @@ Array MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch, return {sch}; } - Optional mapping_info = - tir::GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block_rv), - tir::TensorIntrin::Get(intrin_group.compute_intrin)->desc); - if (!mapping_info.defined()) { + std::unordered_map intrin_group_to_mapping_info; + for (int i = 0, n = intrin_groups.size(); i < n; ++i) { + TensorCoreIntrinGroup intrin_group = intrin_groups[i]; + Optional mapping_info = tir::GetAutoTensorizeMappingInfo( + sch->state(), sch->GetSRef(block_rv), + tir::TensorIntrin::Get(intrin_groups[i].compute_intrin)->desc); + if (mapping_info.defined()) { + intrin_group_to_mapping_info.emplace(i, mapping_info.value()); + } + } + + if (intrin_group_to_mapping_info.empty()) { + // No tensor intrinsics can be applied. return {sch}; } - mapping_info_ = mapping_info.value(); - // Create a copy of the schedule so that we can roll back transformations if tensorization + // Save the original schedule so that we can roll back transformations if tensorization // fail. - Schedule original_sch = sch->Copy(); - sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); - + Schedule original_sch = sch; + + std::vector initial_states; + for (const auto& kv : intrin_group_to_mapping_info) { + const TensorCoreIntrinGroup& intrin_group = intrin_groups[kv.first]; + const tir::AutoTensorizeMappingInfo& mapping_info = kv.second; + Schedule new_sch = sch->Copy(); + new_sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); + initial_states.push_back(TensorCoreState(intrin_group, mapping_info, new_sch, block_rv)); + } Array results; - for (auto&& state : ApplySubRules({TensorCoreState(sch, block_rv)})) { + for (auto&& state : ApplySubRules(initial_states)) { results.push_back(std::move(state->sch)); } if (results.empty()) { @@ -196,7 +244,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( AnnotateCooperativeFetching(&sch, state->write_reuse[0]); } sch->ReverseComputeInline(state->tensor_core_reindex_store); - TileAndAnnotateTensorize(&sch, cache_write, intrin_group.store_intrin); + TileAndAnnotateTensorize(&sch, cache_write, state->intrin_group.store_intrin); return {state}; } @@ -212,8 +260,8 @@ std::vector MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore( TileAndAnnotateTensorize(&sch, cache_read, intrin_name); }; - f_tensorize_load(0, "wmma.matrix_a", intrin_group.load_a_intrin); - f_tensorize_load(1, "wmma.matrix_b", intrin_group.load_b_intrin); + f_tensorize_load(0, "wmma.matrix_a", state->intrin_group.load_a_intrin); + f_tensorize_load(1, "wmma.matrix_b", state->intrin_group.load_b_intrin); sch->ComputeInline(state->tensor_core_reindex_A); sch->ComputeInline(state->tensor_core_reindex_B); @@ -238,6 +286,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore( Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( TensorCoreStateNode* state, const String& intrin_name) const { BlockRV block_rv = state->block_rv; + const tir::AutoTensorizeMappingInfo& mapping_info = state->mapping_info; tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv); // Add reindex stages @@ -258,24 +307,24 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( // Transform the layout of reindex buffers accordingly. // The index map defines the mapping for the computation block. We need to extract the sub index // map to transform the load and store block. - ICHECK_EQ(mapping_info_->mappings.size(), 1U); // assume only one mapping is present - const tir::IndexMap& index_map = mapping_info_->mappings[0]; + ICHECK_EQ(mapping_info->mappings.size(), 1U); // assume only one mapping is present + const tir::IndexMap& index_map = mapping_info->mappings[0]; // Find the correspondence between block iters and the iters in the index map. std::unordered_map lhs_to_index_map_src; std::unordered_map rhs_to_index_map_tgt; std::unordered_set unmapped_index_map_src; - ICHECK_EQ(mapping_info_->lhs_iters.size(), index_map->initial_indices.size()); - for (int i = 0; i < static_cast(mapping_info_->lhs_iters.size()); ++i) { - lhs_to_index_map_src[mapping_info_->lhs_iters[i]->var] = index_map->initial_indices[i]; + ICHECK_EQ(mapping_info->lhs_iters.size(), index_map->initial_indices.size()); + for (int i = 0; i < static_cast(mapping_info->lhs_iters.size()); ++i) { + lhs_to_index_map_src[mapping_info->lhs_iters[i]->var] = index_map->initial_indices[i]; } // The number of result iters in the index map is equal or more than the number of rhs (the - // tensor intrin) iters. When there are extra iters, these iters represent unmapped iters from the - // lhs. They will be skipped during pattern matching for tensorization. - // An example of such case is batch matmul, the batch dimension is kept after layout - // transformations and it will be kept as a outer loop after tensorization. + // tensor intrin) iters. When there are extra iters, these iters represent unmapped iters from + // the lhs. They will be skipped during pattern matching for tensorization. An example of such + // case is batch matmul, the batch dimension is kept after layout transformations and it will be + // kept as a outer loop after tensorization. int offset = static_cast(index_map->final_indices.size()) - - static_cast(mapping_info_->rhs_iters.size()); + static_cast(mapping_info->rhs_iters.size()); ICHECK_GE(offset, 0); for (int i = 0; i < offset; ++i) { const tir::VarNode* var_ptr = index_map->final_indices[i].as(); @@ -283,13 +332,13 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( unmapped_index_map_src.insert(GetRef(var_ptr)); } for (int i = offset; i < static_cast(index_map->final_indices.size()); ++i) { - rhs_to_index_map_tgt[mapping_info_->rhs_iters[i - offset]->var] = index_map->final_indices[i]; + rhs_to_index_map_tgt[mapping_info->rhs_iters[i - offset]->var] = index_map->final_indices[i]; } auto f_get_sub_index_map = [&](const tir::Buffer& lhs_buffer, const tir::Region& lhs_region) { std::vector sub_index_map_src; std::vector sub_index_map_tgt; - const tir::Buffer& rhs_buffer = mapping_info_->lhs_buffer_map[lhs_buffer]; + const tir::Buffer& rhs_buffer = mapping_info->lhs_buffer_map[lhs_buffer]; for (const Range& range : lhs_region) { ICHECK(tir::is_one(range->extent)); const tir::VarNode* var_ptr = range->min.as(); @@ -300,8 +349,8 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( sub_index_map_tgt.push_back(lhs_representer); } } - for (size_t i = 0; i < mapping_info_->rhs_buffer_indices[rhs_buffer].size(); ++i) { - const tir::VarNode* var = mapping_info_->rhs_buffer_indices[rhs_buffer][i].as(); + for (size_t i = 0; i < mapping_info->rhs_buffer_indices[rhs_buffer].size(); ++i) { + const tir::VarNode* var = mapping_info->rhs_buffer_indices[rhs_buffer][i].as(); ICHECK(var != nullptr); sub_index_map_tgt.push_back(rhs_to_index_map_tgt[GetRef(var)]); } @@ -345,7 +394,7 @@ inline std::vector MultiLevelTilingTensorCoreNode::TransformForTensorizat TensorCoreState state) const { // Do reindex and layout transformations. Optional transformed_loop_rv = - TransformWithTensorIntrin(state.operator->(), intrin_group.compute_intrin); + TransformWithTensorIntrin(state.operator->(), state->intrin_group.compute_intrin); if (!transformed_loop_rv.defined()) { // The workload can't be tensorized. return {}; @@ -356,32 +405,24 @@ inline std::vector MultiLevelTilingTensorCoreNode::TransformForTensorizat // Add annotations for post processors. state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize, - intrin_group.compute_intrin); + state->intrin_group.compute_intrin); state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize_init, - intrin_group.init_intrin); + state->intrin_group.init_intrin); state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Bool(true)); return {std::move(state)}; } ScheduleRule ScheduleRule::MultiLevelTilingTensorCore( - Map intrin_group, String structure, Optional> tile_binds, + Array> intrin_groups, String structure, Optional> tile_binds, Optional max_innermost_factor, Optional> vector_load_lens, Optional> reuse_read, Optional> reuse_write) { auto node = MultiLevelTilingInitCommon( structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); - auto f_initialize_intrin = [&intrin_group](String key_name, String* intrin_name) { - CHECK(intrin_group.count(key_name)) << "ValueError: " << key_name << " is not set."; - *intrin_name = intrin_group.at(key_name); - // Check the existence of the intrin - tir::TensorIntrin::Get(*intrin_name); - }; - f_initialize_intrin("init", &node->intrin_group.init_intrin); - f_initialize_intrin("load_a", &node->intrin_group.load_a_intrin); - f_initialize_intrin("load_b", &node->intrin_group.load_b_intrin); - f_initialize_intrin("compute", &node->intrin_group.compute_intrin); - f_initialize_intrin("store", &node->intrin_group.store_intrin); - + node->intrin_groups.reserve(intrin_groups.size()); + for (const auto& intrin_group_config : intrin_groups) { + node->intrin_groups.emplace_back(TensorCoreIntrinGroup::FromConfig(intrin_group_config)); + } return ScheduleRule(node); } diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py index 1ceef0afc3f5..c43645832b6f 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -563,25 +563,6 @@ def test_multi_level_tiling_dense_dpa4(): check_trace(spaces, expected) -def test_cuda_tensor_core_conv2d(): - target = Target("cuda", host="llvm") - ctx = _create_context( - create_prim_func( - te_workload.conv2d_nhwc_f16( - N=1, H=16, W=16, CI=16, CO=16, kernel_size=3, stride=1, padding=1 - ) - ), - target, - multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"), - ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 1 - - expected = [] - print("".join(spaces[0].trace.as_python())) - check_trace(spaces, expected) - - def test_cuda_tensor_core_matmul_relu(): m = n = k = 128 target = Target("cuda", host="llvm") @@ -719,14 +700,15 @@ def test_cuda_tensor_core_matmul_relu(): def test_cuda_tensor_core_matmul_relu_global(): m = n = k = 128 target = Target("cuda", host="llvm") - ctx = _create_context( - create_prim_func( - te_workload.matmul_relu_fp16( - n=n, - m=m, - k=k, - ), + workload = create_prim_func( + te_workload.matmul_relu_fp16( + n=n, + m=m, + k=k, ), + ) + ctx = _create_context( + workload, target=target, rule=[ multi_level_tiling_tensor_core(target=target, write_reuse_scope="global"), @@ -822,6 +804,106 @@ def test_cuda_tensor_core_matmul_relu_global(): ] check_trace(spaces, expected) + ctx = _create_context( + workload, + target=target, + rule=[ + multi_level_tiling_tensor_core( + target=target, write_reuse_scope="global", trans_b=[False, True] + ), + auto_inline(target), + ], + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 2 + + expected = [ + expected[0], + """b0 = sch.get_block(name="C", func_name="main") +sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") +b1 = sch.reindex(block=b0, buffer=("write", 0)) +b2 = sch.reindex(block=b0, buffer=("read", 0)) +b3 = sch.reindex(block=b0, buffer=("read", 1)) +sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, )) +sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (j, k, )) +sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, )) +sch.transform_block_layout(block=b1, index_map=lambda i, j, k: (i, j, k, )) +sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, )) +sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, )) +sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, )) +l4, l5, l6 = sch.get_loops(block=b0) +l7, l8 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True) +l9, l10 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True) +l11, l12 = sch.split(loop=l4, factors=[None, 16], preserve_unit_iters=True) +l13, l14, l15, l16, l17, l18 = sch.get_loops(block=b0) +sch.reorder(l15, l17, l12, l10, l8) +b19 = sch.blockize(loop=l12) +sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32_trans") +sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32") +sch.annotate(block_or_loop=b19, ann_key="warp_execution", ann_val=1) +l20, l21, l22 = sch.get_loops(block=b19) +v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l20, n=5, max_innermost_factor=4) +l28, l29, l30, l31, l32 = sch.split(loop=l20, factors=[v23, v24, v25, v26, v27], preserve_unit_iters=True) +v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4) +l38, l39, l40, l41, l42 = sch.split(loop=l21, factors=[v33, v34, v35, v36, v37], preserve_unit_iters=True) +v43, v44, v45 = sch.sample_perfect_tile(loop=l22, n=3, max_innermost_factor=4) +l46, l47, l48 = sch.split(loop=l22, factors=[v43, v44, v45], preserve_unit_iters=True) +sch.reorder(l28, l38, l29, l39, l30, l40, l46, l47, l31, l41, l48, l32, l42) +l49 = sch.fuse(l28, l38, preserve_unit_iters=True) +sch.bind(loop=l49, thread_axis="blockIdx.y") +l50 = sch.fuse(l29, l39, preserve_unit_iters=True) +sch.bind(loop=l50, thread_axis="blockIdx.x") +l51 = sch.fuse(l30, l40, preserve_unit_iters=True) +sch.bind(loop=l51, thread_axis="threadIdx.y") +b52 = sch.cache_write(block=b19, write_buffer_index=0, storage_scope="wmma.accumulator") +sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True) +sch.reverse_compute_inline(block=b1) +l53, l54, l55, l56, l57 = sch.get_loops(block=b52) +l58, l59 = sch.split(loop=l57, factors=[None, 16], preserve_unit_iters=True) +l60, l61 = sch.split(loop=l56, factors=[None, 16], preserve_unit_iters=True) +l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b52) +sch.reorder(l67, l61, l59) +b69 = sch.blockize(loop=l61) +sch.annotate(block_or_loop=b69, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_global") +b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared") +sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True) +l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70) +l77 = sch.fuse(l75, l76, preserve_unit_iters=True) +v78 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) +sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) +b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared") +sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True) +l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79) +l86 = sch.fuse(l84, l85, preserve_unit_iters=True) +v87 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) +sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87) +b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a") +sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True) +l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b88) +l96, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True) +l98, l99 = sch.split(loop=l94, factors=[None, 16], preserve_unit_iters=True) +l100, l101, l102, l103, l104, l105, l106, l107, l108 = sch.get_loops(block=b88) +sch.reorder(l107, l99, l97) +b109 = sch.blockize(loop=l99) +sch.annotate(block_or_loop=b109, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a") +b110 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="wmma.matrix_b") +sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True) +l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110) +l118, l119 = sch.split(loop=l117, factors=[None, 16], preserve_unit_iters=True) +l120, l121 = sch.split(loop=l116, factors=[None, 16], preserve_unit_iters=True) +l122, l123, l124, l125, l126, l127, l128, l129, l130 = sch.get_loops(block=b110) +sch.reorder(l129, l121, l119) +b131 = sch.blockize(loop=l121) +sch.annotate(block_or_loop=b131, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b_trans") +sch.compute_inline(block=b2) +sch.compute_inline(block=b3) +sch.storage_align(block=b70, buffer_index=0, axis=-2, factor=32, offset=8) +sch.storage_align(block=b79, buffer_index=0, axis=-2, factor=32, offset=8)""".split( + "\n" + ), + ] + check_trace(spaces, expected) + def test_multi_level_tiling_non_tensorizable(): # expected to do nothing on non-tensorizable workloads @@ -850,13 +932,13 @@ def test_multi_level_tiling_non_tensorizable(): def test_cuda_tensor_core_conv2d(): target = Target("cuda", host="llvm") + workload = create_prim_func( + te_workload.conv2d_nhwc_f16( + N=1, H=16, W=16, CI=32, CO=32, kernel_size=3, stride=1, padding=1 + ) + ) ctx = _create_context( - create_prim_func( - # dtype doesn't match tensor intrin - te_workload.conv2d_nhwc_f16( - N=1, H=16, W=16, CI=32, CO=32, kernel_size=3, stride=1, padding=1 - ) - ), + workload, target=target, rule=multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"), ) @@ -955,6 +1037,21 @@ def test_cuda_tensor_core_conv2d(): ] check_trace(spaces, expected) + # test adding unappliable tensor intrinsics doesn't change the search space + ctx = _create_context( + workload, + target, + multi_level_tiling_tensor_core( + target=target, + write_reuse_scope="shared", + in_dtype="float16", + out_dtype=["float16", "float32"], + ), + ) + check_trace(spaces, expected) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + if __name__ == "__main__": tvm.testing.main()