diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 5be1b9626d9c..bee9819a228e 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1547,6 +1547,9 @@ constexpr const char* software_pipeline_async_stages = "software_pipeline_async_ /*! \brief Mark the buffers which is const access and can be transformed layout. */ constexpr const char* layout_free_buffers = "layout_free_buffers"; +/*! \brief Mark that a local stage should be added for the shared memory access. */ +constexpr const char* manifest_shared_memory_local_stage = "tir.manifest_shared_memory_local_stage"; + /*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */ constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure"; diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index c758a00b3f0f..fd4261e4a4e3 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -674,6 +674,12 @@ TVM_DLL Pass InjectPTXAsyncCopy(); */ TVM_DLL Pass RemoveWeightLayoutRewriteBlock(); +/*! + * \brief Add the explicit local stage for the shared memory access on GPU. + * \return The pass. + */ +TVM_DLL Pass ManifestSharedMemoryLocalStage(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index eb2cff641ca3..324471c71891 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -949,3 +949,14 @@ def RemoveWeightLayoutRewriteBlock(): The result pass """ return _ffi_api.RemoveWeightLayoutRewriteBlock() # type: ignore + + +def ManifestSharedMemoryLocalStage(): + """Add the explicit local stage for the shared memory access on GPU. 
+ + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ManifestSharedMemoryLocalStage() # type: ignore diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 9bd2e8a812ea..e528686d967d 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -199,6 +199,7 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); pass_list.push_back(tir::transform::UnifyThreadBinding()); + pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage()); pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::InjectSoftwarePipeline()); diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc new file mode 100644 index 000000000000..3a3abf0b801c --- /dev/null +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file manifest_shared_memory_local_stage.cc + * \brief Add the explicit local stage for the shared memory access on GPU. + * + * This pass finds the cache_read stage on the shared memory, and creates another intermediate stage + * to store the data into local memory first, and then copy the data from local memory to the shared + * memory. This is similar to the schedule primitive cache_read, but it bypasses the limitation + * of requiring buffer access to be contiguous in each dimension. + */ +#include +#include +#include +#include +#include + +#include + +#include "../../runtime/thread_storage_scope.h" +#include "../schedule/transform.h" +#include "tvm/tir/stmt.h" + +namespace tvm { +namespace tir { + +/*! \brief Rewriter for the block storing to the target buffer. Create an intermediate cache stage + * to store the result. Rewrite the original block to load from the intermediate buffer. + */ +class IntermediateStageRewriter { + public: + explicit IntermediateStageRewriter(const Array& ancestor_loop_or_blocks) + : ancestor_loop_or_blocks_(ancestor_loop_or_blocks) {} + + std::tuple Rewrite(const BlockNode* block) { + const BufferStoreNode* store = block->body.as(); + CHECK(store != nullptr && runtime::StorageScope::Create(store->buffer.scope()).rank == + runtime::StorageRank::kShared) + << "ValueError: Expect the body of the block to be BufferStore to shared memory."; + + const Buffer& target_buffer = store->buffer; + + // Step 0: Collect relaxed loops + std::vector relaxed_loops = CollectRelaxedOuterLoops(block, target_buffer); + + // Step 1: Create buffer for the local stage + Buffer new_buffer{nullptr}; + Array buffer_indices; + std::tie(new_buffer, buffer_indices) = CreateIntermediateBuffer(relaxed_loops, target_buffer); + + // Step 2: Create the local stage block + Stmt local_stage = MakeLocalStage(block, new_buffer, buffer_indices, relaxed_loops, store); + + // Step 3: Create BufferLoad from the intermediate buffer + BufferLoad new_buffer_load = 
BufferLoad(new_buffer, buffer_indices); + BufferStore new_buffer_store = Downcast(block->body); + new_buffer_store.CopyOnWrite()->value = new_buffer_load; + Block new_block = GetRef(block); + new_block.CopyOnWrite()->body = std::move(new_buffer_store); + + return {target_buffer, new_buffer, new_block, local_stage}; + } + + private: + /*! \brief Collect relaxed outer loops from innermost to outermost */ + std::vector CollectRelaxedOuterLoops(const BlockNode* block, + const Buffer& target_buffer) { + std::vector relaxed_loops; + for (int n = static_cast(ancestor_loop_or_blocks_.size()) - 1, i = n - 1; i >= 0; --i) { + const Stmt& ancestor = ancestor_loop_or_blocks_[i]; + if (const ForNode* ancestor_loop = ancestor.as()) { + CHECK(ancestor_loop->kind == ForKind::kSerial || + ancestor_loop->kind == ForKind::kVectorized) + << "ValueError: Expect the ancestor loops to be serial or vectorized, got " + << ancestor_loop->kind; + relaxed_loops.push_back(ancestor.as()); + + if (i < n - 1) { + CHECK(ancestor_loop->body.same_as(ancestor_loop_or_blocks_[i + 1])) + << "ValueError: Expect the ancestor loops to have a single child."; + } else { + const BlockRealizeNode* block_realize = ancestor_loop->body.as(); + ICHECK(block_realize != nullptr); + CHECK(block_realize != nullptr && block_realize->block.get() == block) + << "ValueError: Expect the ancestor loops to have a single child."; + } + } else { + const BlockRealizeNode* ancestor_block_realize = ancestor.as(); + ICHECK(ancestor_block_realize != nullptr); + const BlockNode* ancestor_block = ancestor_block_realize->block.get(); + auto it = std::find_if( + ancestor_block->alloc_buffers.begin(), ancestor_block->alloc_buffers.end(), + [&target_buffer](const Buffer& buffer) { return buffer.same_as(target_buffer); }); + CHECK(it != ancestor_block->alloc_buffers.end()) + << "ValueError: Expect the shared memory allocation to be in the parent block."; + break; + } + } + return relaxed_loops; + } + + /*! 
\brief Create the intermediate stage. */ + Stmt MakeLocalStage(const BlockNode* block, const Buffer& new_buffer, + Array local_stage_indices, + std::vector relaxed_loops, const BufferStoreNode* store) { + // Step 0: Create the body of the local stage, which is BufferStore to the intermediate buffer. + Stmt local_stage = BufferStore(new_buffer, store->value, local_stage_indices); + + // Step 1: Make block and block realize + BufferRegion write_buffer_region = BufferRegion::FromPoint(new_buffer, local_stage_indices); + local_stage = + Block(/*iter_vars=*/{}, /*reads=*/block->reads, /*writes=*/{write_buffer_region}, "", + /*body=*/std::move(local_stage)); + local_stage = BlockRealize( + /*iter_values=*/{}, + /*predicate=*/ancestor_loop_or_blocks_.back().as()->predicate, + Downcast(local_stage)); + + // Step 2: Add outer loops + Map subst_map; + for (const ForNode* relaxed_loop : relaxed_loops) { + ObjectPtr for_node = make_object(*relaxed_loop); + for_node->loop_var = for_node->loop_var.copy_with_suffix(""); + for_node->body = std::move(local_stage); + local_stage = For(for_node); + subst_map.Set(relaxed_loop->loop_var, for_node->loop_var); + } + local_stage = Substitute(local_stage, subst_map); + return local_stage; + } + + /*! \brief Create the intermediate buffer with the extents of the relaxed outer loops. */ + std::pair> CreateIntermediateBuffer( + const std::vector relaxed_loops, const Buffer& buffer) const { + Array buffer_indices; + Array new_buffer_shape; + + // Create the intermediate buffer for the local stage. The shape of the new buffer is the + // extents of the relaxed outer loops. 
+ + for (auto it = relaxed_loops.rbegin(); it != relaxed_loops.rend(); ++it) { + const ForNode* relaxed_loop = *it; + buffer_indices.push_back(relaxed_loop->min + relaxed_loop->loop_var); + new_buffer_shape.push_back(relaxed_loop->extent); + } + Buffer new_buffer = WithScope(buffer, "local"); + new_buffer.CopyOnWrite()->shape = new_buffer_shape; + return {new_buffer, buffer_indices}; + } + + const Array& ancestor_loop_or_blocks_; +}; + +class SharedMemoryLocalStageInserter : public StmtMutator { + public: + Stmt VisitStmt_(const ForNode* op) final { + ancestor_loop_or_blocks_.push_back(GetRef(op)); + Stmt new_stmt = StmtMutator::VisitStmt_(op); + ancestor_loop_or_blocks_.pop_back(); + return new_stmt; + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + ancestor_loop_or_blocks_.push_back(GetRef(op)); + Stmt new_stmt = StmtMutator::VisitStmt_(op); + ancestor_loop_or_blocks_.pop_back(); + return new_stmt; + } + + Stmt VisitStmt_(const BlockNode* op) final { + if (op->annotations.count(attr::manifest_shared_memory_local_stage)) { + // Rewrite the shared memory access to load from the intermediate buffer. + // The annotated block must be a leaf block (will be checked during rewriting). No need to + // visit its body recursively. + + Buffer target_buffer{nullptr}; + Buffer new_buffer{nullptr}; + Block new_block{nullptr}; + Stmt local_stage{nullptr}; + IntermediateStageRewriter rewriter(ancestor_loop_or_blocks_); + std::tie(target_buffer, new_buffer, new_block, local_stage) = rewriter.Rewrite(op); + buffer_remap_.Set(target_buffer, new_buffer); + + new_block.CopyOnWrite()->annotations.erase(attr::manifest_shared_memory_local_stage); + buffer_local_stage_.Set(target_buffer, local_stage); + target_buffers_.push_back(target_buffer); + + return std::move(new_block); + } + + std::unordered_set allocated_buffers( + op->alloc_buffers.begin(), op->alloc_buffers.end()); + + // Visit children and insert local stages (if any) to the proper location. 
+ Array new_alloc_buffers; + Array new_seq; + + // Helper function to check if the subtree (body of the block) contains any target buffers. + // If so, the allocated intermediate buffer and the local stage should be lifted to the current + // block. + auto f_check_subtree = [&](int start, int end) { + for (int i = start; i < end; ++i) { + const Buffer& buffer = target_buffers_[i]; + if (allocated_buffers.count(buffer)) { + new_seq.push_back(buffer_local_stage_.at(buffer)); + new_alloc_buffers.push_back(buffer_remap_.at(buffer)); + } + } + }; + + if (const SeqStmtNode* seq = op->body.as()) { + // Visit each element of the SeqStmt. Create a new SeqStmt if any of the children is modified. + bool changed = false; // whether the SeqStmt has been changed + for (int i = 0, n = seq->seq.size(); i < n; ++i) { + int subtree_start = target_buffers_.size(); + Stmt new_seq_elem = VisitStmt(seq->seq[i]); + int subtree_end = target_buffers_.size(); + f_check_subtree(subtree_start, subtree_end); + new_seq.push_back(new_seq_elem); + if (!new_seq_elem.same_as(seq->seq[i])) { + changed = true; + } + } + if (!changed) { + return GetRef(op); + } + } else { + int subtree_start = target_buffers_.size(); + Stmt body = VisitStmt(op->body); + int subtree_end = target_buffers_.size(); + f_check_subtree(subtree_start, subtree_end); + if (body.same_as(op->body)) { + return GetRef(op); + } + new_seq.push_back(body); + } + + Block new_block = GetRef(op); + BlockNode* new_block_node = new_block.CopyOnWrite(); + // Add new buffer allocations if any. + if (new_alloc_buffers.size() > 0) { + new_block_node->alloc_buffers = Concat(new_block_node->alloc_buffers, new_alloc_buffers); + } + new_block_node->body = new_seq.size() == 1 ? 
new_seq[0] : SeqStmt(new_seq); + return std::move(new_block); + } + + std::vector ancestor_loop_or_blocks_; // ancestor loops or block realize + Map buffer_remap_; // mapping from the target buffer to the intermediate buffer + Map buffer_local_stage_; // mapping from the target buffer to the local stage + Array target_buffers_; // the target buffers for rewriting +}; + +namespace transform { + +Pass ManifestSharedMemoryLocalStage() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = SharedMemoryLocalStageInserter()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.ManifestSharedMemoryLocalStage", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.ManifestSharedMemoryLocalStage") + .set_body_typed(ManifestSharedMemoryLocalStage); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index cb6cb0f93fb6..3828412de054 100755 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4623,7 +4623,7 @@ def test_fn(f, dim=None, keepdim=False): return lambda x: f(x, dim=dim, keepdim=keepdim) def test_fn_no_arg(f): - return lambda x: f(x) + return lambda x: f(x) # pylint: disable=unnecessary-lambda for f in [torch.all, torch.any]: verify_model(test_fn(f, 0), [torch.rand(1, 2).bool()]) diff --git a/tests/python/unittest/test_tir_transform_manifest_shared_memory_local_stage.py b/tests/python/unittest/test_tir_transform_manifest_shared_memory_local_stage.py new file mode 100644 index 000000000000..111b91d5fd54 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_manifest_shared_memory_local_stage.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm.script import tir as T + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + + +@tvm.script.ir_module +class MatmulBefore: + @T.prim_func + def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + # body + # with T.block("root") + for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"): + for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"): + for threadIdx_y in T.thread_binding(2, thread="threadIdx.y"): + for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"): + for k_0 in T.serial(32): + with T.block(): + T.reads(A[blockIdx_y * 32 : blockIdx_y * 32 + 32, k_0 * 32 : k_0 * 32 + 32], B[k_0 * 32 : k_0 * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32]) + T.writes(C[blockIdx_y * 32 : blockIdx_y * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32]) + A_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") + for ax0_ax1_fused_0 in T.serial(64): + for ax0_ax1_fused_3 in 
T.vectorized(4): + with T.block("A_shared"): + T.reads(A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.writes(A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.block_attr({"tir.manifest_shared_memory_local_stage":1}) + A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] + for ax0_ax1_fused_0 in T.serial(64): + for ax0_ax1_fused_3 in T.vectorized(4): + with T.block("B_shared"): + T.reads(B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.writes(B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.block_attr({"tir.manifest_shared_memory_local_stage":1}) + B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] + for k_1, i_2, j_2, k_2 in T.grid(2, 16, 16, 16): + with 
T.block("C"): + T.reads(A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2], B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2]) + T.writes(C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2]) + if k_0 * 32 + k_1 * 16 + k_2 == 0: + C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] = T.float32(0) + C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] = C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] + A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2] * B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] + + +@tvm.script.ir_module +class MatmulAfter: + @T.prim_func + def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + # body + # with T.block("root") + for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"): + for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"): + for threadIdx_y in T.thread_binding(2, thread="threadIdx.y"): + for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"): + for k_0 in T.serial(32): + with T.block(): + T.reads(A[blockIdx_y * 32 : blockIdx_y * 32 + 32, k_0 * 32 : k_0 * 32 + 32], B[k_0 * 32 : k_0 * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32]) + T.writes(C[blockIdx_y * 32 : blockIdx_y * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32]) + A_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") + A_shared_local = T.alloc_buffer([64, 4], dtype="float32", scope="local") + B_shared_local = T.alloc_buffer([64, 4], dtype="float32", scope="local") + for ax0_ax1_fused_0 in T.serial(64): + for ax0_ax1_fused_3 in 
T.vectorized(4): + with T.block(): + T.reads(A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.writes(A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3]) + A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] = A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] + for ax0_ax1_fused_0 in T.serial(64): + for ax0_ax1_fused_3 in T.vectorized(4): + with T.block("A_shared"): + T.reads(A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.writes(A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] + for ax0_ax1_fused_0 in T.serial(64): + for ax0_ax1_fused_3 in T.vectorized(4): + with T.block(): + T.reads(B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.writes(B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3]) + B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] = B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] + for ax0_ax1_fused_0 in 
T.serial(64): + for ax0_ax1_fused_3 in T.vectorized(4): + with T.block("B_shared"): + T.reads(B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.writes(B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] + for k_1, i_2, j_2, k_2 in T.grid(2, 16, 16, 16): + with T.block("C"): + T.reads(A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2], B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2]) + T.writes(C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2]) + if k_0 * 32 + k_1 * 16 + k_2 == 0: + C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] = T.float32(0) + C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] = C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] + A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2] * B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] + + +# fmt: on +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + + +def _check(before, expected): + after = tvm.tir.transform.ManifestSharedMemoryLocalStage()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_transform_matmul(): + _check(MatmulBefore, MatmulAfter) + + +if __name__ == "__main__": + 
tvm.testing.main()