Skip to content

Commit

Permalink
[BugFix] Fix CrossThreadReduction on CUDA (apache#13)
Browse files Browse the repository at this point in the history
* Fix trailing whitespace

* Reorder

* Add sketch test for NRM and SFM
  • Loading branch information
MasterJH5574 authored Jan 15, 2022
1 parent ac09abc commit fd7712c
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 8 deletions.
2 changes: 1 addition & 1 deletion include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1358,7 +1358,7 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_
constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";

/*!
* \brief Mark that the block need to add predicate for block var bounds during lowering
* \brief Mark that the block need to add predicate for block var bounds during lowering
*/
constexpr const char* require_block_var_bound_predicate = "require_bound_predicate";

Expand Down
7 changes: 7 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,13 @@ TVM_DLL Pass PlanAndUpdateBufferAllocationLocation();
*/
TVM_DLL Pass ApplyBlockBoundPredicate();

/*!
 * \brief Narrow the extents of some loops by checking whether some constraints in the block iter
 * bound predicates can be directly applied on the loops.
 * \return The pass.
 */
// NOTE(review): this declaration is byte-identical to the ApplyBlockBoundPredicate() declared
// immediately above it — a harmless redeclaration in C++, but almost certainly a rebase/merge
// artifact. Confirm whether a differently-named pass was intended here, or drop the duplicate.
TVM_DLL Pass ApplyBlockBoundPredicate();

/*!
* \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the
* corresponding iter_values in BlockRealize, for opaque blocks by removing all
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/testing/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def get(target: Target) -> List[ScheduleRule]:
]
if target.kind.name == "cuda":
return [
cross_thread_reduction(target),
multi_level_tiling(target),
auto_inline_after_tiling(target),
cross_thread_reduction(target),
parallel_vectorize_unroll(target),
]
raise NotImplementedError(f"{target.kind.name} is not supported")
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def _sch_rules() -> List[ScheduleRule]:
)

return [
M.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]),
M.MultiLevelTiling(
structure="SSSRRSRS",
tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
Expand All @@ -177,6 +176,7 @@ def _sch_rules() -> List[ScheduleRule]:
require_ordered=False,
disallow_op=None,
),
M.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]),
M.ParallelizeVectorizeUnroll(
max_jobs_per_core=-1, # disable parallelize
max_vectorize_extent=-1, # disable vectorize
Expand Down
4 changes: 2 additions & 2 deletions src/meta_schedule/schedule_rule/random_compute_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class RandomComputeLocationNode : public ScheduleRuleNode {
if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) {
return false;
}
// Cond 5. The block is not tiled. We check this condition by examine the block's annotation.
if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_tiling_structure).defined()) {
// Cond 5. The block is not tiled.
if (tir::HasBeenMultiLevelTiled(block_sref)) {
return false;
}
// Cond 6. The block has at least one consumer.
Expand Down
8 changes: 8 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,14 @@ AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& wri
*/
bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref);

/*!
 * \brief Checks if the given block has had multi-level tiling applied. We check this by examining
* the block's annotation.
* \param block_sref The block to be checked
* \return A boolean indicating whether the block has been multi-level tiled.
*/
bool HasBeenMultiLevelTiled(const StmtSRef& block_sref);

/*!
* \brief Checks if the rfactor or cross thread reduction is beneficial to the given block.
* \param self The schedule state.
Expand Down
14 changes: 11 additions & 3 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1927,6 +1927,10 @@ bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref
return total_unused_block_vars >= 1;
}

bool HasBeenMultiLevelTiled(const StmtSRef& block_sref) {
  // A block counts as multi-level tiled as soon as the tiling-structure
  // annotation is attached to it; the annotation value itself is irrelevant.
  if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_tiling_structure).defined()) {
    return true;
  }
  return false;
}

std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref) {
Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
Expand Down Expand Up @@ -1976,12 +1980,16 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
return false;
}

// Cond 3. The block is a reduction block and has trivial binding.
// Cond 3. The block satisfies all the following properties
// - it is a reduction block;
// - it has trivial bindings;
// - it has not been tiled by multi-level tiling.
const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, //
/*require_stage_pipeline=*/false, //
/*require_subtree_compact_dataflow=*/false);
if (!(IsReductionBlock(self, block_sref, scope_sref) && //
IsTrivialBinding(self, block_sref))) {
if (!IsReductionBlock(self, block_sref, scope_sref) //
|| !IsTrivialBinding(self, block_sref) //
|| HasBeenMultiLevelTiled(block_sref)) {
return false;
}

Expand Down
130 changes: 130 additions & 0 deletions tests/python/unittest/test_meta_schedule_sketch_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def _target() -> Target:
return Target("cuda", host="llvm")


def _target_with_max_threads_per_block() -> Target:
    """Return a CUDA target created from a concrete device tag.

    NOTE(review): the name suggests the device tag populates
    ``max_threads_per_block`` for space generation — confirm against the
    Target tag registry.
    """
    cuda_target = Target("nvidia/geforce-rtx-3080")
    return cuda_target


def test_meta_schedule_cuda_sketch_matmul():
# pylint: disable=line-too-long
expected = [
Expand Down Expand Up @@ -289,8 +293,134 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disabl
check_trace(spaces, expected)


def test_meta_schedule_cuda_sketch_batchnorm():
    """Sketch generation for the norm_bmn (NRM) workload on a CUDA target.

    Exactly two design spaces are expected:
      * one where block "C" is split with a ``threadIdx.x`` binding and
        computed at its consumer with shared-memory scope — the
        cross-thread-reduction pattern this test exercises;
      * one that only applies the ``meta_schedule.unroll_explicit``
        annotation on the root block.
    """
    # pylint: disable=line-too-long
    # Expected schedule traces: one inner list of trace instructions per design space.
    expected = [
        [
            'b0 = sch.get_block(name="C", func_name="main")',
            'b1 = sch.get_block(name="root", func_name="main")',
            "b2, = sch.get_consumers(block=b0)",
            "l3, = sch.get_loops(block=b2)",
            "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
            "l5, l6 = sch.split(loop=l3, factors=[None, v4])",
            'sch.bind(loop=l6, thread_axis="threadIdx.x")',
            "sch.compute_at(block=b0, loop=l5, preserve_unit_loops=True)",
            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
            "l7, l8, l9, l10 = sch.get_loops(block=b0)",
            "l11 = sch.fuse(l9, l10)",
            "l12, l13 = sch.split(loop=l11, factors=[None, v4])",
            'sch.bind(loop=l13, thread_axis="threadIdx.x")',
            "v14 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
            'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v14)',
        ],
        [
            'b0 = sch.get_block(name="root", func_name="main")',
            "v1 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.unroll_explicit", ann_val=v1)',
        ],
    ]
    # pylint: enable=line-too-long
    # Build a tuning context for norm_bmn with a device-tagged target, so the
    # space generator can reason about thread limits.
    ctx = create_context(
        create_prim_func(
            te_workload.norm_bmn(
                B=1,
                M=256,
                N=256,
            )
        ),
        target=_target_with_max_threads_per_block(),
    )
    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
    assert len(spaces) == 2
    check_trace(spaces, expected)


def test_meta_schedule_cuda_sketch_softmax():
    """Sketch generation for the softmax_mn (SFM) workload on a CUDA target.

    Four design spaces are expected, covering the combinations of whether
    each of the two reduction blocks ("T_softmax_maxelem" and
    "T_softmax_expsum") gets the ``threadIdx.x``-bound cross-thread
    treatment; "T_softmax_exp" is inlined in every space that touches it.
    """
    # pylint: disable=line-too-long
    # Expected schedule traces: one inner list of trace instructions per design space.
    expected = [
        [
            'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
            'b1 = sch.get_block(name="T_softmax_exp", func_name="main")',
            'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")',
            'b3 = sch.get_block(name="root", func_name="main")',
            "sch.compute_inline(block=b1)",
            "b4, = sch.get_consumers(block=b2)",
            "l5, l6 = sch.get_loops(block=b4)",
            "v7 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
            "l8, l9 = sch.split(loop=l6, factors=[None, v7])",
            'sch.bind(loop=l9, thread_axis="threadIdx.x")',
            "sch.compute_at(block=b2, loop=l5, preserve_unit_loops=True)",
            'sch.set_scope(block=b2, buffer_index=0, storage_scope="shared")',
            "l10, l11, l12 = sch.get_loops(block=b2)",
            "l13, l14 = sch.split(loop=l12, factors=[None, v7])",
            'sch.bind(loop=l14, thread_axis="threadIdx.x")',
            "b15, b16 = sch.get_consumers(block=b0)",
            "l17, l18, l19, l20 = sch.get_loops(block=b15)",
            "sch.compute_at(block=b0, loop=l17, preserve_unit_loops=True)",
            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
            "l21, l22, l23 = sch.get_loops(block=b0)",
            "l24, l25 = sch.split(loop=l23, factors=[None, v7])",
            'sch.bind(loop=l25, thread_axis="threadIdx.x")',
            "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
            'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v26)',
        ],
        [
            'b0 = sch.get_block(name="T_softmax_exp", func_name="main")',
            'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")',
            'b2 = sch.get_block(name="root", func_name="main")',
            "sch.compute_inline(block=b0)",
            "b3, = sch.get_consumers(block=b1)",
            "l4, l5 = sch.get_loops(block=b3)",
            "v6 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
            "l7, l8 = sch.split(loop=l5, factors=[None, v6])",
            'sch.bind(loop=l8, thread_axis="threadIdx.x")',
            "sch.compute_at(block=b1, loop=l4, preserve_unit_loops=True)",
            'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
            "l9, l10, l11 = sch.get_loops(block=b1)",
            "l12, l13 = sch.split(loop=l11, factors=[None, v6])",
            'sch.bind(loop=l13, thread_axis="threadIdx.x")',
            "v14 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
            'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v14)',
        ],
        [
            'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
            'b1 = sch.get_block(name="T_softmax_exp", func_name="main")',
            'b2 = sch.get_block(name="root", func_name="main")',
            "sch.compute_inline(block=b1)",
            "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
            "l4, l5 = sch.get_loops(block=b0)",
            "l6, l7 = sch.split(loop=l5, factors=[None, v3])",
            'sch.bind(loop=l7, thread_axis="threadIdx.x")',
            "v8 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
            'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v8)',
        ],
        [
            'b0 = sch.get_block(name="T_softmax_exp", func_name="main")',
            'b1 = sch.get_block(name="root", func_name="main")',
            "sch.compute_inline(block=b0)",
            "v2 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
            'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v2)',
        ],
    ]
    # pylint: enable=line-too-long
    # Build a tuning context for softmax_mn with a device-tagged target, so the
    # space generator can reason about thread limits.
    ctx = create_context(
        create_prim_func(
            te_workload.softmax_mn(
                m=256,
                n=256,
            )
        ),
        target=_target_with_max_threads_per_block(),
    )
    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
    assert len(spaces) == 4
    check_trace(spaces, expected)


if __name__ == "__main__":
    # Run every sketch test in sequence when executed as a script.
    for _test_func in (
        test_meta_schedule_cuda_sketch_matmul,
        test_meta_schedule_cuda_sketch_matmul_relu,
        test_meta_schedule_cuda_sketch_conv2d_nchw,
        test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu,
        test_meta_schedule_cuda_sketch_batchnorm,
        test_meta_schedule_cuda_sketch_softmax,
    ):
        _test_func()

0 comments on commit fd7712c

Please sign in to comment.