From 2f08d20e1a0de091b1234fe475bfdd228d4899a0 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 23 Jan 2022 12:32:46 -0800 Subject: [PATCH] Add threadIdx filtering in Multi-Level-Tiling and Verify-GPU-Code (#20) * Add threadIdx filtering in Multi-Level-Tiling and Verify-GPU-Code * minor * minor * turn off debug flag --- include/tvm/tir/stmt.h | 10 +- src/meta_schedule/postproc/verify_gpu_code.cc | 53 +++++ .../schedule_rule/multi_level_tiling.cc | 41 +++- .../space_generator/post_order_apply.cc | 8 +- tests/python/meta_schedule/run_ansor_cuda.sh | 10 +- .../meta_schedule/run_meta_schedule_cuda.sh | 10 +- ..._meta_schedule_postproc_verify_gpu_code.py | 195 ++++++++++++++++++ ...hedule_schedule_rule_multi_level_tiling.py | 4 +- 8 files changed, 311 insertions(+), 20 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 429d4c2c54b4..d8726541aecd 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1373,9 +1373,17 @@ constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperati */ constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize"; -/*! \brief Mark that tensor core is enbaled in the PrimExpr */ +/*! \brief Mark that tensor core is enabled in the PrimExpr */ constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled"; +/*! \brief The allowed range of thread extent in thread bindings */ +constexpr const char* meta_schedule_thread_extent_low_inclusive = + "meta_schedule.thread_extent_low_inclusive"; + +/*! \brief The allowed range of thread extent in thread bindings */ +constexpr const char* meta_schedule_thread_extent_high_inclusive = + "meta_schedule.thread_extent_high_inclusive"; + /*! * \brief Mark a block as generated by cache_read or cache_write block. * 0 means cache_read; 1 means cache_write. diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 196dd0c403ba..19e5fde06f23 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -20,6 +20,56 @@ #include "../utils.h" +namespace tvm { +namespace tir { + +class ThreadExtentChecker : private StmtVisitor { + public: + static bool Check(const Stmt& stmt) { + try { + ThreadExtentChecker().VisitStmt(stmt); + return true; + } catch (const dmlc::Error& e) { + return false; + } + } + + private: + void VisitStmt_(const ForNode* loop) { + if (IsThreadIdx(GetThreadScope(loop))) { + if (const int64_t* p_ext = GetLoopIntExtent(loop)) { + thread_extent_product *= *p_ext; + StmtVisitor::VisitStmt_(loop); + thread_extent_product /= *p_ext; + return; + } else { + throw dmlc::Error("Dynamic thread extent"); + } + } + StmtVisitor::VisitStmt_(loop); + } + + void VisitStmt_(const BlockNode* block) { + if (Optional low_inclusive = + GetAnn(block, attr::meta_schedule_thread_extent_low_inclusive)) { + if (Optional high_inclusive = + GetAnn(block, attr::meta_schedule_thread_extent_high_inclusive)) { + int64_t low = low_inclusive.value()->value; + int64_t high = high_inclusive.value()->value; + if (!(low <= thread_extent_product && thread_extent_product <= high)) { + throw dmlc::Error("Thread extent"); + } + } + } + StmtVisitor::VisitStmt_(block); + } + + int64_t thread_extent_product = 1; +}; + +} // namespace tir +} // namespace tvm + namespace tvm { namespace meta_schedule { @@ -66,6 +116,9 @@ class VerifyGPUCodeNode : public PostprocNode { const GlobalVar& g_var = kv.first; const BaseFunc& base_func = kv.second; if (const auto* prim_func = base_func.as()) { + if (!tir::ThreadExtentChecker::Check(prim_func->body)) { + return false; + } IRModule lowered{nullptr}; try { auto pass_list = Array(); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index a5d677c5cdf2..e0438a2eb1ed 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -287,7 +287,17 @@ class MultiLevelTilingNode : public ScheduleRuleNode { } // Do nothing; Inherited from ScheduleRuleNode - void InitializeWithTuneContext(const TuneContext& context) final {} + void InitializeWithTuneContext(const TuneContext& context) final { + if (Optional v = context->target.value()->GetAttr("max_threads_per_block")) { + this->max_threads_per_block_ = v.value()->value; + if (Optional v = context->target.value()->GetAttr("thread_warp_size")) { + this->thread_warp_size_ = v.value()->value; + } else { + LOG(INFO) << "'thread_warp_size' is not defined in the target"; + } + } + } + // Entry of the mega rule; Inherited from ScheduleRuleNode Array Apply(const Schedule& sch, const BlockRV& block_rv) final { if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { @@ -331,6 +341,10 @@ class MultiLevelTilingNode : public ScheduleRuleNode { std::vector s_indices_; /*! \brief The indices of reduction tiles in `structure` */ std::vector r_indices_; + /*! \brief The size of the thread warp */ + int thread_warp_size_; + /*! \brief The maximum number of threads to be used size of a thread warp */ + int max_threads_per_block_; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("structure", &structure); @@ -342,6 +356,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode { // `reuse_write_` is not visited // `s_indices_` is not visited // `r_indices_` is not visited + // `thread_warp_size_` is not visited + // `max_threads_per_block` is not visited } static constexpr const char* _type_key = "meta_schedule.MultiLevelTiling"; @@ -419,11 +435,20 @@ inline std::vector MultiLevelTilingNode::TileLoopNest(State state) const std::vector iter_types = GetBlockVarTypes(sch->GetSRef(state.block_rv)); ICHECK_EQ(loops.size(), iter_types.size()); // Step 2. For each loop axis, tile it + int64_t spatial_loop_product = 1; std::vector> tiles(s_indices_.size() + r_indices_.size()); for (int i = 0, n = loops.size(); i < n; ++i) { + LoopRV loop = loops[i]; const std::vector* idx = nullptr; if (iter_types[i] == IterVarType::kDataPar) { idx = &s_indices_; + if (spatial_loop_product != -1) { + if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) { + spatial_loop_product *= *extent; + } else { + spatial_loop_product = -1; + } + } } else if (iter_types[i] == IterVarType::kCommReduce) { idx = &r_indices_; } else { @@ -431,7 +456,6 @@ inline std::vector MultiLevelTilingNode::TileLoopNest(State state) const } // Do the split int n_tiles = idx->size(); - LoopRV loop = loops[i]; Array factors = sch->SamplePerfectTile( /*loop=*/loop, /*n=*/n_tiles, @@ -453,6 +477,17 @@ inline std::vector MultiLevelTilingNode::TileLoopNest(State state) const tiles[i] = {fused}; } state.tiles = Array>{tiles.begin(), tiles.end()}; + if (this->thread_warp_size_ != -1) { + int64_t low_inclusive = 1; + int64_t high_inclusive = this->max_threads_per_block_; + if (spatial_loop_product > 2 * this->thread_warp_size_) { + low_inclusive = this->thread_warp_size_; + } + sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_low_inclusive, + Integer(low_inclusive)); + sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_high_inclusive, + Integer(high_inclusive)); + } return {state}; } @@ -578,6 +613,8 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optionalthread_warp_size_ = -1; + n->max_threads_per_block_ = -1; return ScheduleRule(n); } diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index fff7c2711218..00abb567ded5 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -106,10 +106,10 @@ class PostOrderApplyNode : public SpaceGeneratorNode { Array GenerateDesignSpace(const IRModule& mod_) final { using ScheduleAndUnvisitedBlocks = std::pair>; - tir::Schedule sch = tir::Schedule::Traced( // - /*mod=*/mod_, // - /*rand_state=*/ForkSeed(&this->rand_state_), // - /*debug_mode=*/tir::kVerifySRefTree | tir::kVerifyCachedFlags, // + tir::Schedule sch = tir::Schedule::Traced( // + /*mod=*/mod_, // + /*rand_state=*/ForkSeed(&this->rand_state_), // + /*debug_mode=*/0, // /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); std::vector stack; diff --git a/tests/python/meta_schedule/run_ansor_cuda.sh b/tests/python/meta_schedule/run_ansor_cuda.sh index 67d5d959d61f..6eda12fe119c 100644 --- a/tests/python/meta_schedule/run_ansor_cuda.sh +++ b/tests/python/meta_schedule/run_ansor_cuda.sh @@ -4,8 +4,8 @@ RPC_HOST="192.168.6.66" RPC_PORT="4445" RPC_KEY="jetson-agx-xavier" TARGET="nvidia/jetson-agx-xavier" -NUM_TRIALS=800 LOG_DIR=$HOME/logs/ansor-cuda/ +NUM_TRIALS=2000 mkdir -p $LOG_DIR @@ -23,19 +23,17 @@ run () { 2>&1 | tee "$LOG_DIR/$name.log" } -# Single op run C1D run C2D -run C3D run CAP run DEP run DIL run GMM run GRP -run NRM -run SFM run T2D -# Subgraph run C2d-BN-RELU run TBG +run C3D +run NRM +run SFM diff --git a/tests/python/meta_schedule/run_meta_schedule_cuda.sh b/tests/python/meta_schedule/run_meta_schedule_cuda.sh index 6509af5f532a..28132a05045a 100644 --- a/tests/python/meta_schedule/run_meta_schedule_cuda.sh +++ b/tests/python/meta_schedule/run_meta_schedule_cuda.sh @@ -4,7 +4,7 @@ RPC_HOST="192.168.6.66" RPC_PORT="4445" RPC_KEY="jetson-agx-xavier" TARGET="nvidia/jetson-agx-xavier" -LOG_DIR=/tmp/logs/ms-cuda/ +LOG_DIR=$HOME/logs/ms-cuda/ NUM_TRIALS=2000 mkdir -p $LOG_DIR @@ -25,19 +25,17 @@ run () { 2>&1 | tee "$work_dir/$name.log" } -# Single op run C1D run C2D -run C3D run CAP run DEP run DIL run GMM run GRP -run NRM -run SFM run T2D -# Subgraph run C2d-BN-RELU run TBG +run C3D +run NRM +run SFM diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py index cdebcddf5d6d..f3a318f91358 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py +++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py @@ -193,6 +193,177 @@ def main(a: T.handle, b: T.handle) -> None: T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on +@T.prim_func +def GmmCuda0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None: + Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local") + X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"): + for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"): + for i1_3_init, i2_4_init in T.grid(4, 2): + with T.block("Z_init"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init) + T.reads() + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = T.float32(0) + for i3_0 in T.serial(4): + for ax0_ax1_ax2_fused_0 in T.serial(4): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(2): + with T.block("X_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32) + v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32) + T.reads(X[v0, v1, v2]) + T.writes(X_shared[v0, v1, v2]) + X_shared[v0, v1, v2] = X[v0, v1, v2] + for ax0_ax1_ax2_fused_0 in T.serial(8): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("Y_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32) + T.reads(Y[v0, v1, v2]) + T.writes(Y_shared[v0, v1, v2]) + Y_shared[v0, v1, v2] = Y[v0, v1, v2] + for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2): + with T.block("Z_update"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4) + k = T.axis.reduce(128, i3_0 * 32 + i3_2) + T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j]) + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] + for ax0, ax1, ax2 in T.grid(1, 4, 2): + with T.block("Z_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2) + T.reads(Z_local[v0, v1, v2]) + T.writes(Z[v0, v1, v2]) + Z[v0, v1, v2] = Z_local[v0, v1, v2] + +@T.prim_func +def GmmCuda1(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None: + Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local") + X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"): + for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"): + for i1_3_init, i2_4_init in T.grid(4, 2): + with T.block("Z_init"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init) + T.reads() + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = T.float32(0) + for i3_0 in T.serial(4): + for ax0_ax1_ax2_fused_0 in T.serial(4): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(2): + with T.block("X_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32) + v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32) + T.reads(X[v0, v1, v2]) + T.writes(X_shared[v0, v1, v2]) + X_shared[v0, v1, v2] = X[v0, v1, v2] + for ax0_ax1_ax2_fused_0 in T.serial(8): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("Y_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32) + T.reads(Y[v0, v1, v2]) + T.writes(Y_shared[v0, v1, v2]) + Y_shared[v0, v1, v2] = Y[v0, v1, v2] + for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2): + with T.block("Z_update"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4) + k = T.axis.reduce(128, i3_0 * 32 + i3_2) + T.block_attr({ + "meta_schedule.thread_extent_low_inclusive": 0, + "meta_schedule.thread_extent_high_inclusive": 32, + }) + T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j]) + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] + for ax0, ax1, ax2 in T.grid(1, 4, 2): + with T.block("Z_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2) + T.reads(Z_local[v0, v1, v2]) + T.writes(Z[v0, v1, v2]) + Z[v0, v1, v2] = Z_local[v0, v1, v2] + + +@T.prim_func +def GmmCuda2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None: + Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local") + X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"): + for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"): + for i1_3_init, i2_4_init in T.grid(4, 2): + with T.block("Z_init"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init) + T.reads() + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = T.float32(0) + for i3_0 in T.serial(4): + for ax0_ax1_ax2_fused_0 in T.serial(4): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(2): + with T.block("X_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32) + v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32) + T.reads(X[v0, v1, v2]) + T.writes(X_shared[v0, v1, v2]) + X_shared[v0, v1, v2] = X[v0, v1, v2] + for ax0_ax1_ax2_fused_0 in T.serial(8): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("Y_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32) + T.reads(Y[v0, v1, v2]) + T.writes(Y_shared[v0, v1, v2]) + Y_shared[v0, v1, v2] = Y[v0, v1, v2] + for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2): + with T.block("Z_update"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4) + k = T.axis.reduce(128, i3_0 * 32 + i3_2) + T.block_attr({ + "meta_schedule.thread_extent_low_inclusive": 1024, + "meta_schedule.thread_extent_high_inclusive": 1024, + }) + T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j]) + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] + for ax0, ax1, ax2 in T.grid(1, 4, 2): + with T.block("Z_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2) + T.reads(Z_local[v0, v1, v2]) + T.writes(Z[v0, v1, v2]) + Z[v0, v1, v2] = Z_local[v0, v1, v2] + # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant @@ -225,8 +396,32 @@ def test_postproc_verify_gpu_3(): assert not ctx.postprocs[0].apply(sch) +def test_postproc_verify_gpu_4(): + mod = GmmCuda0 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_5(): + mod = GmmCuda1 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_6(): + mod = GmmCuda2 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + if __name__ == "__main__": test_postproc_verify_gpu_0() test_postproc_verify_gpu_1() test_postproc_verify_gpu_2() test_postproc_verify_gpu_3() + test_postproc_verify_gpu_4() + test_postproc_verify_gpu_5() + test_postproc_verify_gpu_6() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py index dba661bba03c..c2ad9258f275 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -186,6 +186,8 @@ def test_cuda_matmul(): 'sch.bind(loop=l32, thread_axis="vthread.x")', "l33 = sch.fuse(l12, l22)", 'sch.bind(loop=l33, thread_axis="threadIdx.x")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)', 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)", "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", @@ -202,7 +204,7 @@ def test_cuda_matmul(): ] ] # pylint: enable=line-too-long - target = Target("cuda", host="llvm") + target = Target("cuda --max_threads_per_block=1024 --thread_warp_size=32", host="llvm") ctx = _create_context( create_prim_func( te_workload.matmul(