From fd7712c771d1a9b98d409e95732d7f95ae0bc3c2 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 15 Jan 2022 16:53:57 +0800 Subject: [PATCH] [BugFix] Fix CrossThreadReduction on CUDA (#13) * Fix traling whitespace * Reorder * Add sketch test for NRM and SFM --- include/tvm/tir/stmt.h | 2 +- include/tvm/tir/transform.h | 7 + .../meta_schedule/testing/schedule_rule.py | 2 +- python/tvm/meta_schedule/tune.py | 2 +- .../schedule_rule/random_compute_location.cc | 4 +- src/tir/schedule/analysis.h | 8 ++ src/tir/schedule/analysis/analysis.cc | 14 +- .../test_meta_schedule_sketch_cuda.py | 130 ++++++++++++++++++ 8 files changed, 161 insertions(+), 8 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 20ad447a9b3c..429d4c2c54b4 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1358,7 +1358,7 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_ constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint"; /*! - * \brief Mark that the block need to add predicate for block var bounds during lowering + * \brief Mark that the block need to add predicate for block var bounds during lowering */ constexpr const char* require_block_var_bound_predicate = "require_bound_predicate"; diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 9d4a7e976a45..edd75998d757 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -390,6 +390,13 @@ TVM_DLL Pass PlanAndUpdateBufferAllocationLocation(); */ TVM_DLL Pass ApplyBlockBoundPredicate(); +/*! + * \brief Narrow the extents of some loops by checking whether some constraints in the block iter + * bound predicates can be directly applied on the loops. + * \return The pass. + */ +TVM_DLL Pass ApplyBlockBoundPredicate(); + /*! * \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the * corresponding iter_values in BlockRealize, for opaque blocks by removing all diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py index dec5ab68e701..d62a54bebac6 100644 --- a/python/tvm/meta_schedule/testing/schedule_rule.py +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -42,9 +42,9 @@ def get(target: Target) -> List[ScheduleRule]: ] if target.kind.name == "cuda": return [ - cross_thread_reduction(target), multi_level_tiling(target), auto_inline_after_tiling(target), + cross_thread_reduction(target), parallel_vectorize_unroll(target), ] raise NotImplementedError(f"{target.kind.name} is not supported") diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 54229ee9823e..16f4104ad8ec 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -150,7 +150,6 @@ def _sch_rules() -> List[ScheduleRule]: ) return [ - M.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]), M.MultiLevelTiling( structure="SSSRRSRS", tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], @@ -177,6 +176,7 @@ def _sch_rules() -> List[ScheduleRule]: require_ordered=False, disallow_op=None, ), + M.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]), M.ParallelizeVectorizeUnroll( max_jobs_per_core=-1, # disable parallelize max_vectorize_extent=-1, # disable vectorize diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc index 20f6ac51595f..ba1476719491 100644 --- a/src/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -47,8 +47,8 @@ class RandomComputeLocationNode : public ScheduleRuleNode { if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) { return false; } - // Cond 5. The block is not tiled. We check this condition by examine the block's annotation. - if (tir::GetAnn(block_sref, tir::attr::meta_schedule_tiling_structure).defined()) { + // Cond 5. The block is not tiled. + if (tir::HasBeenMultiLevelTiled(block_sref)) { return false; } // Cond 6. The block has at lease one consumer. diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index ed2507c76b8c..35c5ed76ccfd 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -629,6 +629,14 @@ AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& wri */ bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref); +/*! + * \brief Checks if the given block has been applied by multi-level tiling. We check this by examine + * the block's annotation. + * \param block_sref The block to be checked + * \return A boolean indicating whether the block has been multi-level tiled. + */ +bool HasBeenMultiLevelTiled(const StmtSRef& block_sref); + /*! * \brief Checks if the rfactor or cross thread reduction is beneficial to the given block. * \param self The schedule state. diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index faf577cb4ecd..a6bb3f3e17b6 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1927,6 +1927,10 @@ bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref return total_unused_block_vars >= 1; } +bool HasBeenMultiLevelTiled(const StmtSRef& block_sref) { + return tir::GetAnn(block_sref, tir::attr::meta_schedule_tiling_structure).defined(); +} + std::pair GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self, const tir::StmtSRef& block_sref) { Array loops = tir::GetLoops(block_sref); @@ -1976,12 +1980,16 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // return false; } - // Cond 3. The block is a reduction block and has trivial binding. + // Cond 3. The block satisfies all the following properties + // - it is a reduction block; + // - it has trivial bindings; + // - it has not been tiled by multi-level tiling. const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, // /*require_stage_pipeline=*/false, // /*require_subtree_compact_dataflow=*/false); - if (!(IsReductionBlock(self, block_sref, scope_sref) && // - IsTrivialBinding(self, block_sref))) { + if (!IsReductionBlock(self, block_sref, scope_sref) // + || !IsTrivialBinding(self, block_sref) // + || HasBeenMultiLevelTiled(block_sref)) { return false; } diff --git a/tests/python/unittest/test_meta_schedule_sketch_cuda.py b/tests/python/unittest/test_meta_schedule_sketch_cuda.py index 86bbfecd6980..ff31db46351c 100644 --- a/tests/python/unittest/test_meta_schedule_sketch_cuda.py +++ b/tests/python/unittest/test_meta_schedule_sketch_cuda.py @@ -26,6 +26,10 @@ def _target() -> Target: return Target("cuda", host="llvm") +def _target_with_max_threads_per_block() -> Target: + return Target("nvidia/geforce-rtx-3080") + + def test_meta_schedule_cuda_sketch_matmul(): # pylint: disable=line-too-long expected = [ @@ -289,8 +293,134 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disabl check_trace(spaces, expected) +def test_meta_schedule_cuda_sketch_batchnorm(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + "b2, = sch.get_consumers(block=b0)", + "l3, = sch.get_loops(block=b2)", + "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l5, l6 = sch.split(loop=l3, factors=[None, v4])", + 'sch.bind(loop=l6, thread_axis="threadIdx.x")', + "sch.compute_at(block=b0, loop=l5, preserve_unit_loops=True)", + 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', + "l7, l8, l9, l10 = sch.get_loops(block=b0)", + "l11 = sch.fuse(l9, l10)", + "l12, l13 = sch.split(loop=l11, factors=[None, v4])", + 'sch.bind(loop=l13, thread_axis="threadIdx.x")', + "v14 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v14)', + ], + [ + 'b0 = sch.get_block(name="root", func_name="main")', + "v1 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.unroll_explicit", ann_val=v1)', + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.norm_bmn( + B=1, + M=256, + N=256, + ) + ), + target=_target_with_max_threads_per_block(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 2 + check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_softmax(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b1)", + "b4, = sch.get_consumers(block=b2)", + "l5, l6 = sch.get_loops(block=b4)", + "v7 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l8, l9 = sch.split(loop=l6, factors=[None, v7])", + 'sch.bind(loop=l9, thread_axis="threadIdx.x")', + "sch.compute_at(block=b2, loop=l5, preserve_unit_loops=True)", + 'sch.set_scope(block=b2, buffer_index=0, storage_scope="shared")', + "l10, l11, l12 = sch.get_loops(block=b2)", + "l13, l14 = sch.split(loop=l12, factors=[None, v7])", + 'sch.bind(loop=l14, thread_axis="threadIdx.x")', + "b15, b16 = sch.get_consumers(block=b0)", + "l17, l18, l19, l20 = sch.get_loops(block=b15)", + "sch.compute_at(block=b0, loop=l17, preserve_unit_loops=True)", + 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', + "l21, l22, l23 = sch.get_loops(block=b0)", + "l24, l25 = sch.split(loop=l23, factors=[None, v7])", + 'sch.bind(loop=l25, thread_axis="threadIdx.x")', + "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v26)', + ], + [ + 'b0 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b0)", + "b3, = sch.get_consumers(block=b1)", + "l4, l5 = sch.get_loops(block=b3)", + "v6 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l7, l8 = sch.split(loop=l5, factors=[None, v6])", + 'sch.bind(loop=l8, thread_axis="threadIdx.x")', + "sch.compute_at(block=b1, loop=l4, preserve_unit_loops=True)", + 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")', + "l9, l10, l11 = sch.get_loops(block=b1)", + "l12, l13 = sch.split(loop=l11, factors=[None, v6])", + 'sch.bind(loop=l13, thread_axis="threadIdx.x")', + "v14 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v14)', + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b1)", + "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l4, l5 = sch.get_loops(block=b0)", + "l6, l7 = sch.split(loop=l5, factors=[None, v3])", + 'sch.bind(loop=l7, thread_axis="threadIdx.x")', + "v8 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v8)', + ], + [ + 'b0 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b0)", + "v2 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v2)', + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.softmax_mn( + m=256, + n=256, + ) + ), + target=_target_with_max_threads_per_block(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 4 + check_trace(spaces, expected) + + if __name__ == "__main__": test_meta_schedule_cuda_sketch_matmul() test_meta_schedule_cuda_sketch_matmul_relu() test_meta_schedule_cuda_sketch_conv2d_nchw() test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu() + test_meta_schedule_cuda_sketch_batchnorm() + test_meta_schedule_cuda_sketch_softmax()