From 612f5c102e98b90975e01da834c922d689b4850f Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 9 Aug 2022 16:23:16 -0700 Subject: [PATCH 1/3] [TIR] Add pass ManifestSharedMemoryLocalStage --- include/tvm/tir/stmt.h | 3 + include/tvm/tir/transform.h | 6 + python/tvm/tir/transform/transform.py | 11 + src/driver/driver_api.cc | 1 + .../manifest_shared_memory_local_stage.cc | 287 +++++++++ ...form_manifest_shared_memory_local_stage.py | 574 ++++++++++++++++++ 6 files changed, 882 insertions(+) create mode 100644 src/tir/transforms/manifest_shared_memory_local_stage.cc create mode 100644 tests/python/unittest/test_tir_transform_manifest_shared_memory_local_stage.py diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 5be1b9626d9c..bee9819a228e 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1547,6 +1547,9 @@ constexpr const char* software_pipeline_async_stages = "software_pipeline_async_ /*! \brief Mark the buffers which is const access and can be transformed layout. */ constexpr const char* layout_free_buffers = "layout_free_buffers"; +/*! \brief Mark the local stage for the shared memory access should be added. */ +constexpr const char* manifest_shared_memory_local_stage = "tir.manifest_shared_memory_local_stage"; + /*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */ constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure"; diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index c758a00b3f0f..fd4261e4a4e3 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -674,6 +674,12 @@ TVM_DLL Pass InjectPTXAsyncCopy(); */ TVM_DLL Pass RemoveWeightLayoutRewriteBlock(); +/*! + * \brief Add the explicit local stage for the shared memory access on GPU. + * \return The pass. + */ +TVM_DLL Pass ManifestSharedMemoryLocalStage(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index eb2cff641ca3..324471c71891 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -949,3 +949,14 @@ def RemoveWeightLayoutRewriteBlock(): The result pass """ return _ffi_api.RemoveWeightLayoutRewriteBlock() # type: ignore + + +def ManifestSharedMemoryLocalStage(): + """Add the explicit local stage for the shared memory access on GPU. 
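+ + The pass looks for leaf blocks annotated with "tir.manifest_shared_memory_local_stage" whose body stores to a shared memory buffer, and stages the stored value through a newly allocated local buffer before it is written to shared memory.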
+ + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ManifestSharedMemoryLocalStage() # type: ignore diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 9bd2e8a812ea..e528686d967d 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -199,6 +199,7 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); pass_list.push_back(tir::transform::UnifyThreadBinding()); + pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage()); pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::InjectSoftwarePipeline()); diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc new file mode 100644 index 000000000000..3a3abf0b801c --- /dev/null +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file manifest_shared_memroy_local_stage.cc + * \brief Add the explicit local stage for the shared memory access on GPU. + * + * This pass finds the cache_read stage on the shared memory, and create another intermediate stage + * to store the data into local memory first, and then copy the data from local memory to the shared + * memory. This is similar to the schedule primitive cache_read, but it bypasses the limitation + * of requiring buffer access to be contiguous in each dimension. + */ +#include +#include +#include +#include +#include + +#include + +#include "../../runtime/thread_storage_scope.h" +#include "../schedule/transform.h" +#include "tvm/tir/stmt.h" + +namespace tvm { +namespace tir { + +/*! \brief Rewriter for the block storing to the target buffer. Create an intermediate cache stage + * to store the result. Rewrite the original block to load from the intermediate buffer. 
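+ *
+ * Sketch of the rewrite (buffer names are illustrative, not taken from this patch):
+ * a leaf block whose body is
+ *   A_shared[i, j] = A[i, j]
+ * under relaxed serial/vectorized loops i and j is split into a local stage
+ *   A_local[i, j] = A[i, j]
+ * with A_local allocated in the enclosing block that allocates A_shared, while the original
+ * block is rewritten to
+ *   A_shared[i, j] = A_local[i, j]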
+ */ +class IntermediateStageRewriter { + public: + explicit IntermediateStageRewriter(const Array& ancestor_loop_or_blocks) + : ancestor_loop_or_blocks_(ancestor_loop_or_blocks) {} + + std::tuple Rewrite(const BlockNode* block) { + const BufferStoreNode* store = block->body.as(); + CHECK(store != nullptr && runtime::StorageScope::Create(store->buffer.scope()).rank == + runtime::StorageRank::kShared) + << "ValueError: Expect the body of the block to be BufferStore to shared memory."; + + const Buffer& target_buffer = store->buffer; + + // Step 0: Collect relaxed loops + std::vector relaxed_loops = CollectRelaxedOuterLoops(block, target_buffer); + + // Step 1: Create buffer for the local stage + Buffer new_buffer{nullptr}; + Array buffer_indices; + std::tie(new_buffer, buffer_indices) = CreateIntermediateBuffer(relaxed_loops, target_buffer); + + // Step 2: Create the local stage block + Stmt local_stage = MakeLocalStage(block, new_buffer, buffer_indices, relaxed_loops, store); + + // Step 3: Create BufferLoad from the intermediate buffer + BufferLoad new_buffer_load = BufferLoad(new_buffer, buffer_indices); + BufferStore new_buffer_store = Downcast(block->body); + new_buffer_store.CopyOnWrite()->value = new_buffer_load; + Block new_block = GetRef(block); + new_block.CopyOnWrite()->body = std::move(new_buffer_store); + + return {target_buffer, new_buffer, new_block, local_stage}; + } + + private: + /*! \brief Collect relaxed outer loops from innermost to outermost */ + std::vector CollectRelaxedOuterLoops(const BlockNode* block, + const Buffer& target_buffer) { + std::vector relaxed_loops; + for (int n = static_cast(ancestor_loop_or_blocks_.size()) - 1, i = n - 1; i >= 0; --i) { + const Stmt& ancestor = ancestor_loop_or_blocks_[i]; + if (const ForNode* ancestor_loop = ancestor.as()) { + CHECK(ancestor_loop->kind == ForKind::kSerial || + ancestor_loop->kind == ForKind::kVectorized) + << "ValueError: Expect the ancestor loops to be serial or vectorized, got " + << ancestor_loop->kind; + relaxed_loops.push_back(ancestor.as()); + + if (i < n - 1) { + CHECK(ancestor_loop->body.same_as(ancestor_loop_or_blocks_[i + 1])) + << "ValueError: Expect the ancestor loops to have a single child."; + } else { + const BlockRealizeNode* block_realize = ancestor_loop->body.as(); + ICHECK(block_realize != nullptr); + CHECK(block_realize != nullptr && block_realize->block.get() == block) + << "ValueError: Expect the ancestor loops to have a single child."; + } + } else { + const BlockRealizeNode* ancestor_block_realize = ancestor.as(); + ICHECK(ancestor_block_realize != nullptr); + const BlockNode* ancestor_block = ancestor_block_realize->block.get(); + auto it = std::find_if( + ancestor_block->alloc_buffers.begin(), ancestor_block->alloc_buffers.end(), + [&target_buffer](const Buffer& buffer) { return buffer.same_as(target_buffer); }); + CHECK(it != ancestor_block->alloc_buffers.end()) + << "ValueError: Expect the shared memory allocation to be in the parent block."; + break; + } + } + return relaxed_loops; + } + + /*! \brief Create the intermediate stage. */ + Stmt MakeLocalStage(const BlockNode* block, const Buffer& new_buffer, + Array local_stage_indices, + std::vector relaxed_loops, const BufferStoreNode* store) { + // Step 0: Create the body of the local stage, which is BufferStore to the intermediate buffer. 
+ Stmt local_stage = BufferStore(new_buffer, store->value, local_stage_indices); + + // Step 1: Make block and block realize + BufferRegion write_buffer_region = BufferRegion::FromPoint(new_buffer, local_stage_indices); + local_stage = + Block(/*iter_vars=*/{}, /*reads=*/block->reads, /*writes=*/{write_buffer_region}, "", + /*body=*/std::move(local_stage)); + local_stage = BlockRealize( + /*iter_values=*/{}, + /*predicate=*/ancestor_loop_or_blocks_.back().as()->predicate, + Downcast(local_stage)); + + // Step 2: Add outer loops + Map subst_map; + for (const ForNode* relaxed_loop : relaxed_loops) { + ObjectPtr for_node = make_object(*relaxed_loop); + for_node->loop_var = for_node->loop_var.copy_with_suffix(""); + for_node->body = std::move(local_stage); + local_stage = For(for_node); + subst_map.Set(relaxed_loop->loop_var, for_node->loop_var); + } + local_stage = Substitute(local_stage, subst_map); + return local_stage; + } + + /*! \brief Create the intermediate buffer with the extents of the relaxed outer loops. */ + std::pair> CreateIntermediateBuffer( + const std::vector relaxed_loops, const Buffer& buffer) const { + Array buffer_indices; + Array new_buffer_shape; + + // Create the intermediate buffer for the local stage. The shape of the new buffer is the + // extents of the relaxed outer loops. + + for (auto it = relaxed_loops.rbegin(); it != relaxed_loops.rend(); ++it) { + const ForNode* relaxed_loop = *it; + buffer_indices.push_back(relaxed_loop->min + relaxed_loop->loop_var); + new_buffer_shape.push_back(relaxed_loop->extent); + } + Buffer new_buffer = WithScope(buffer, "local"); + new_buffer.CopyOnWrite()->shape = new_buffer_shape; + return {new_buffer, buffer_indices}; + } + + const Array& ancestor_loop_or_blocks_; +}; + +class SharedMemoryLocalStageInserter : public StmtMutator { + public: + Stmt VisitStmt_(const ForNode* op) final { + ancestor_loop_or_blocks_.push_back(GetRef(op)); + Stmt new_stmt = StmtMutator::VisitStmt_(op); + ancestor_loop_or_blocks_.pop_back(); + return new_stmt; + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + ancestor_loop_or_blocks_.push_back(GetRef(op)); + Stmt new_stmt = StmtMutator::VisitStmt_(op); + ancestor_loop_or_blocks_.pop_back(); + return new_stmt; + } + + Stmt VisitStmt_(const BlockNode* op) final { + if (op->annotations.count(attr::manifest_shared_memory_local_stage)) { + // Rewrite the shared memory access to load from the intermediate buffer. + // The annotated block must be a leaf block (will be checked during rewriting). No need to + // visit its body recursively. + + Buffer target_buffer{nullptr}; + Buffer new_buffer{nullptr}; + Block new_block{nullptr}; + Stmt local_stage{nullptr}; + IntermediateStageRewriter rewriter(ancestor_loop_or_blocks_); + std::tie(target_buffer, new_buffer, new_block, local_stage) = rewriter.Rewrite(op); + buffer_remap_.Set(target_buffer, new_buffer); + + new_block.CopyOnWrite()->annotations.erase(attr::manifest_shared_memory_local_stage); + buffer_local_stage_.Set(target_buffer, local_stage); + target_buffers_.push_back(target_buffer); + + return std::move(new_block); + } + + std::unordered_set allocated_buffers( + op->alloc_buffers.begin(), op->alloc_buffers.end()); + + // Visit children and insert local stages (if any) to the proper location. + Array new_alloc_buffers; + Array new_seq; + + // Helper function to check if the subtree (body of the block) contains any target buffers. + // If so, the allocated intermediate buffer and the local stage should be lifted to the current + // block. 
+ auto f_check_subtree = [&](int start, int end) { + for (int i = start; i < end; ++i) { + const Buffer& buffer = target_buffers_[i]; + if (allocated_buffers.count(buffer)) { + new_seq.push_back(buffer_local_stage_.at(buffer)); + new_alloc_buffers.push_back(buffer_remap_.at(buffer)); + } + } + }; + + if (const SeqStmtNode* seq = op->body.as()) { + // Visit each element of the SeqStmt. Create a new SeqStmt if any of the children is modified. + bool changed = false; // whether the SeqStmt has been changed + for (int i = 0, n = seq->seq.size(); i < n; ++i) { + int subtree_start = target_buffers_.size(); + Stmt new_seq_elem = VisitStmt(seq->seq[i]); + int subtree_end = target_buffers_.size(); + f_check_subtree(subtree_start, subtree_end); + new_seq.push_back(new_seq_elem); + if (!new_seq_elem.same_as(seq->seq[i])) { + changed = true; + } + } + if (!changed) { + return GetRef(op); + } + } else { + int subtree_start = target_buffers_.size(); + Stmt body = VisitStmt(op->body); + int subtree_end = target_buffers_.size(); + f_check_subtree(subtree_start, subtree_end); + if (body.same_as(op->body)) { + return GetRef(op); + } + new_seq.push_back(body); + } + + Block new_block = GetRef(op); + BlockNode* new_block_node = new_block.CopyOnWrite(); + // Add new buffer allocations if any. + if (new_alloc_buffers.size() > 0) { + new_block_node->alloc_buffers = Concat(new_block_node->alloc_buffers, new_alloc_buffers); + } + new_block_node->body = new_seq.size() == 1 ? new_seq[0] : SeqStmt(new_seq); + return std::move(new_block); + } + + std::vector ancestor_loop_or_blocks_; // ancestor loops or block realize + Map buffer_remap_; // mapping from the target buffer to the intermediate buffer + Map buffer_local_stage_; // mapping from the target buffer to the local stage + Array target_buffers_; // the target buffers for rewriting +}; + +namespace transform { + +Pass ManifestSharedMemoryLocalStage() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = SharedMemoryLocalStageInserter()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.ManifestSharedMemoryLocalStage", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.ManifestSharedMemoryLocalStage") + .set_body_typed(ManifestSharedMemoryLocalStage); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_manifest_shared_memory_local_stage.py b/tests/python/unittest/test_tir_transform_manifest_shared_memory_local_stage.py new file mode 100644 index 000000000000..bd36993b67e0 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_manifest_shared_memory_local_stage.py @@ -0,0 +1,574 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
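+# MatmulBefore annotates its shared-memory copy blocks ("A_shared", "B_shared") with the
+# "tir.manifest_shared_memory_local_stage" block attribute; MatmulAfter is the expected output,
+# where each copy is staged through a local buffer (A_shared_local / B_shared_local) before the
+# write to shared memory.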
+import tvm +import tvm.testing +from tvm.script import tir as T + + +@tvm.script.ir_module +class MatmulBefore: + @T.prim_func + def main( + A: T.Buffer[(1024, 1024), "float32"], + B: T.Buffer[(1024, 1024), "float32"], + C: T.Buffer[(1024, 1024), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + # body + # with T.block("root") + for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"): + for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"): + for threadIdx_y in T.thread_binding(2, thread="threadIdx.y"): + for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"): + for k_0 in T.serial(32): + with T.block(): + T.reads( + A[ + blockIdx_y * 32 : blockIdx_y * 32 + 32, + k_0 * 32 : k_0 * 32 + 32, + ], + B[ + k_0 * 32 : k_0 * 32 + 32, + blockIdx_x * 32 : blockIdx_x * 32 + 32, + ], + ) + T.writes( + C[ + blockIdx_y * 32 : blockIdx_y * 32 + 32, + blockIdx_x * 32 : blockIdx_x * 32 + 32, + ] + ) + A_shared = T.alloc_buffer( + [1024, 1024], dtype="float32", scope="shared" + ) + B_shared = T.alloc_buffer( + [1024, 1024], dtype="float32", scope="shared" + ) + for ax0_ax1_fused_0 in T.serial(64): + for ax0_ax1_fused_3 in T.vectorized(4): + with T.block("A_shared"): + T.reads( + A[ + blockIdx_y * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] + ) + T.writes( + A_shared[ + blockIdx_y * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] + ) + T.block_attr( + {"tir.manifest_shared_memory_local_stage": 1} + ) + A_shared[ + blockIdx_y * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] = A[ + blockIdx_y * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] + for ax0_ax1_fused_0 in T.serial(64): + for ax0_ax1_fused_3 in T.vectorized(4): + with T.block("B_shared"): + T.reads( + B[ + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + blockIdx_x * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] + ) + T.writes( + B_shared[ + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + blockIdx_x * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] + ) + T.block_attr( + {"tir.manifest_shared_memory_local_stage": 1} + ) + B_shared[ + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + blockIdx_x * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] = B[ + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + blockIdx_x * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] + for k_1, i_2, j_2, k_2 in 
T.grid(2, 16, 16, 16): + with T.block("C"): + T.reads( + A_shared[ + blockIdx_y * 32 + threadIdx_y * 16 + i_2, + k_0 * 32 + k_1 * 16 + k_2, + ], + B_shared[ + k_0 * 32 + k_1 * 16 + k_2, + blockIdx_x * 32 + threadIdx_x * 16 + j_2, + ], + ) + T.writes( + C[ + blockIdx_y * 32 + threadIdx_y * 16 + i_2, + blockIdx_x * 32 + threadIdx_x * 16 + j_2, + ] + ) + if k_0 * 32 + k_1 * 16 + k_2 == 0: + C[ + blockIdx_y * 32 + threadIdx_y * 16 + i_2, + blockIdx_x * 32 + threadIdx_x * 16 + j_2, + ] = T.float32(0) + C[ + blockIdx_y * 32 + threadIdx_y * 16 + i_2, + blockIdx_x * 32 + threadIdx_x * 16 + j_2, + ] = ( + C[ + blockIdx_y * 32 + threadIdx_y * 16 + i_2, + blockIdx_x * 32 + threadIdx_x * 16 + j_2, + ] + + A_shared[ + blockIdx_y * 32 + threadIdx_y * 16 + i_2, + k_0 * 32 + k_1 * 16 + k_2, + ] + * B_shared[ + k_0 * 32 + k_1 * 16 + k_2, + blockIdx_x * 32 + threadIdx_x * 16 + j_2, + ] + ) + + +@tvm.script.ir_module +class MatmulAfter: + @T.prim_func + def main( + A: T.Buffer[(1024, 1024), "float32"], + B: T.Buffer[(1024, 1024), "float32"], + C: T.Buffer[(1024, 1024), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + # body + # with T.block("root") + for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"): + for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"): + for threadIdx_y in T.thread_binding(2, thread="threadIdx.y"): + for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"): + for k_0 in T.serial(32): + with T.block(): + T.reads( + A[ + blockIdx_y * 32 : blockIdx_y * 32 + 32, + k_0 * 32 : k_0 * 32 + 32, + ], + B[ + k_0 * 32 : k_0 * 32 + 32, + blockIdx_x * 32 : blockIdx_x * 32 + 32, + ], + ) + T.writes( + C[ + blockIdx_y * 32 : blockIdx_y * 32 + 32, + blockIdx_x * 32 : blockIdx_x * 32 + 32, + ] + ) + A_shared = T.alloc_buffer( + [1024, 1024], dtype="float32", scope="shared" + ) + B_shared = T.alloc_buffer( + [1024, 1024], dtype="float32", scope="shared" + ) + A_shared_local = T.alloc_buffer( + [64, 4], dtype="float32", scope="local" + ) + B_shared_local = T.alloc_buffer( + [64, 4], dtype="float32", scope="local" + ) + for ax0_ax1_fused_0 in T.serial(64): + for ax0_ax1_fused_3 in T.vectorized(4): + with T.block(): + T.reads( + A[ + blockIdx_y * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] + ) + T.writes( + A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] + ) + A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] = A[ + blockIdx_y * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] + for ax0_ax1_fused_0 in T.serial(64): + for ax0_ax1_fused_3 in T.vectorized(4): + with T.block("A_shared"): + T.reads( + A[ + blockIdx_y * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] + ) + T.writes( + A_shared[ + blockIdx_y * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] + ) + A_shared[ + blockIdx_y * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + 
+ threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] = A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] + for ax0_ax1_fused_0 in T.serial(64): + for ax0_ax1_fused_3 in T.vectorized(4): + with T.block(): + T.reads( + B[ + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + blockIdx_x * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] + ) + T.writes( + B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] + ) + B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] = B[ + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + blockIdx_x * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] + for ax0_ax1_fused_0 in T.serial(64): + for ax0_ax1_fused_3 in T.vectorized(4): + with T.block("B_shared"): + T.reads( + B[ + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + blockIdx_x * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] + ) + T.writes( + B_shared[ + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + blockIdx_x * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] + ) + B_shared[ + k_0 * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + // 32, + blockIdx_x * 32 + + ( + ax0_ax1_fused_0 * 16 + + threadIdx_y * 8 + + threadIdx_x * 4 + + ax0_ax1_fused_3 + ) + % 32, + ] = B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] + for k_1, i_2, j_2, k_2 in T.grid(2, 16, 16, 16): + with T.block("C"): + T.reads( + A_shared[ + blockIdx_y * 32 + threadIdx_y * 16 + i_2, + k_0 * 32 + k_1 * 16 + k_2, + ], + B_shared[ + k_0 * 32 + k_1 * 16 + k_2, + blockIdx_x * 32 + threadIdx_x * 16 + j_2, + ], + ) + T.writes( + C[ + blockIdx_y * 32 + threadIdx_y * 16 + i_2, + blockIdx_x * 32 + threadIdx_x * 16 + j_2, + ] + ) + if k_0 * 32 + k_1 * 16 + k_2 == 0: + C[ + blockIdx_y * 32 + threadIdx_y * 16 + i_2, + blockIdx_x * 32 + threadIdx_x * 16 + j_2, + ] = T.float32(0) + C[ + blockIdx_y * 32 + threadIdx_y * 16 + i_2, + blockIdx_x * 32 + threadIdx_x * 16 + j_2, + ] = ( + C[ + blockIdx_y * 32 + threadIdx_y * 16 + i_2, + blockIdx_x * 32 + threadIdx_x * 16 + j_2, + ] + + A_shared[ + blockIdx_y * 32 + threadIdx_y * 16 + i_2, + k_0 * 32 + k_1 * 16 + k_2, + ] + * B_shared[ + k_0 * 32 + k_1 * 16 + k_2, + blockIdx_x * 32 + threadIdx_x * 16 + j_2, + ] + ) + + +def _check(before, expected): + after = tvm.tir.transform.ManifestSharedMemoryLocalStage()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_transform_matmul(): + _check(MatmulBefore, MatmulAfter) + + +if __name__ == "__main__": + tvm.testing.main() From 54a068324dfe9baeb3e413832ccb14c16fd9af13 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 12 Aug 2022 17:12:47 -0700 Subject: [PATCH 2/3] update style --- ...form_manifest_shared_memory_local_stage.py | 536 ++---------------- 1 file changed, 48 insertions(+), 488 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_manifest_shared_memory_local_stage.py b/tests/python/unittest/test_tir_transform_manifest_shared_memory_local_stage.py index bd36993b67e0..111b91d5fd54 100644 --- 
a/tests/python/unittest/test_tir_transform_manifest_shared_memory_local_stage.py +++ b/tests/python/unittest/test_tir_transform_manifest_shared_memory_local_stage.py @@ -19,14 +19,14 @@ from tvm.script import tir as T +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + + @tvm.script.ir_module class MatmulBefore: @T.prim_func - def main( - A: T.Buffer[(1024, 1024), "float32"], - B: T.Buffer[(1024, 1024), "float32"], - C: T.Buffer[(1024, 1024), "float32"], - ) -> None: + def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) # body @@ -37,240 +37,37 @@ def main( for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"): for k_0 in T.serial(32): with T.block(): - T.reads( - A[ - blockIdx_y * 32 : blockIdx_y * 32 + 32, - k_0 * 32 : k_0 * 32 + 32, - ], - B[ - k_0 * 32 : k_0 * 32 + 32, - blockIdx_x * 32 : blockIdx_x * 32 + 32, - ], - ) - T.writes( - C[ - blockIdx_y * 32 : blockIdx_y * 32 + 32, - blockIdx_x * 32 : blockIdx_x * 32 + 32, - ] - ) - A_shared = T.alloc_buffer( - [1024, 1024], dtype="float32", scope="shared" - ) - B_shared = T.alloc_buffer( - [1024, 1024], dtype="float32", scope="shared" - ) + T.reads(A[blockIdx_y * 32 : blockIdx_y * 32 + 32, k_0 * 32 : k_0 * 32 + 32], B[k_0 * 32 : k_0 * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32]) + T.writes(C[blockIdx_y * 32 : blockIdx_y * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32]) + A_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") for ax0_ax1_fused_0 in T.serial(64): for ax0_ax1_fused_3 in T.vectorized(4): with T.block("A_shared"): - T.reads( - A[ - blockIdx_y * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] - ) - T.writes( - A_shared[ - blockIdx_y * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] - ) - T.block_attr( - {"tir.manifest_shared_memory_local_stage": 1} - ) - A_shared[ - blockIdx_y * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] = A[ - blockIdx_y * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] + T.reads(A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.writes(A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.block_attr({"tir.manifest_shared_memory_local_stage":1}) + A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 
32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] for ax0_ax1_fused_0 in T.serial(64): for ax0_ax1_fused_3 in T.vectorized(4): with T.block("B_shared"): - T.reads( - B[ - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - blockIdx_x * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] - ) - T.writes( - B_shared[ - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - blockIdx_x * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] - ) - T.block_attr( - {"tir.manifest_shared_memory_local_stage": 1} - ) - B_shared[ - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - blockIdx_x * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] = B[ - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - blockIdx_x * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] + T.reads(B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.writes(B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.block_attr({"tir.manifest_shared_memory_local_stage":1}) + B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] for k_1, i_2, j_2, k_2 in T.grid(2, 16, 16, 16): with T.block("C"): - T.reads( - A_shared[ - blockIdx_y * 32 + threadIdx_y * 16 + i_2, - k_0 * 32 + k_1 * 16 + k_2, - ], - B_shared[ - k_0 * 32 + k_1 * 16 + k_2, - blockIdx_x * 32 + threadIdx_x * 16 + j_2, - ], - ) - T.writes( - C[ - blockIdx_y * 32 + threadIdx_y * 16 + i_2, - blockIdx_x * 32 + threadIdx_x * 16 + j_2, - ] - ) + T.reads(A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2], B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2]) + T.writes(C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2]) if k_0 * 32 + k_1 * 16 + k_2 == 0: - C[ - blockIdx_y * 32 + threadIdx_y * 16 + i_2, - blockIdx_x * 32 + threadIdx_x * 16 + j_2, - ] = T.float32(0) - C[ - blockIdx_y * 32 + threadIdx_y * 16 + i_2, - blockIdx_x * 32 + threadIdx_x * 16 + j_2, - ] = ( - C[ - blockIdx_y * 32 + threadIdx_y * 16 + i_2, - blockIdx_x * 32 + threadIdx_x * 16 + j_2, - ] - + A_shared[ - blockIdx_y * 32 + threadIdx_y * 16 + i_2, - k_0 * 32 + k_1 * 16 + k_2, - ] - * B_shared[ - k_0 * 32 + k_1 * 16 + k_2, - blockIdx_x * 32 + threadIdx_x * 16 + j_2, - ] - ) + 
C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] = T.float32(0) + C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] = C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] + A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2] * B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] @tvm.script.ir_module class MatmulAfter: @T.prim_func - def main( - A: T.Buffer[(1024, 1024), "float32"], - B: T.Buffer[(1024, 1024), "float32"], - C: T.Buffer[(1024, 1024), "float32"], - ) -> None: + def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) # body @@ -281,284 +78,47 @@ def main( for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"): for k_0 in T.serial(32): with T.block(): - T.reads( - A[ - blockIdx_y * 32 : blockIdx_y * 32 + 32, - k_0 * 32 : k_0 * 32 + 32, - ], - B[ - k_0 * 32 : k_0 * 32 + 32, - blockIdx_x * 32 : blockIdx_x * 32 + 32, - ], - ) - T.writes( - C[ - blockIdx_y * 32 : blockIdx_y * 32 + 32, - blockIdx_x * 32 : blockIdx_x * 32 + 32, - ] - ) - A_shared = T.alloc_buffer( - [1024, 1024], dtype="float32", scope="shared" - ) - B_shared = T.alloc_buffer( - [1024, 1024], dtype="float32", scope="shared" - ) - A_shared_local = T.alloc_buffer( - [64, 4], dtype="float32", scope="local" - ) - B_shared_local = T.alloc_buffer( - [64, 4], dtype="float32", scope="local" - ) + T.reads(A[blockIdx_y * 32 : blockIdx_y * 32 + 32, k_0 * 32 : k_0 * 32 + 32], B[k_0 * 32 : k_0 * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32]) + T.writes(C[blockIdx_y * 32 : blockIdx_y * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32]) + A_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared") + A_shared_local = T.alloc_buffer([64, 4], dtype="float32", scope="local") + B_shared_local = T.alloc_buffer([64, 4], dtype="float32", scope="local") for ax0_ax1_fused_0 in T.serial(64): for ax0_ax1_fused_3 in T.vectorized(4): with T.block(): - T.reads( - A[ - blockIdx_y * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] - ) - T.writes( - A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] - ) - A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] = A[ - blockIdx_y * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] + T.reads(A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.writes(A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3]) + A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] = A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] for ax0_ax1_fused_0 in T.serial(64): for ax0_ax1_fused_3 in T.vectorized(4): with T.block("A_shared"): - T.reads( - A[ - blockIdx_y * 32 - + ( - ax0_ax1_fused_0 * 16 - + 
threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] - ) - T.writes( - A_shared[ - blockIdx_y * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] - ) - A_shared[ - blockIdx_y * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] = A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] + T.reads(A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.writes(A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] for ax0_ax1_fused_0 in T.serial(64): for ax0_ax1_fused_3 in T.vectorized(4): with T.block(): - T.reads( - B[ - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - blockIdx_x * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] - ) - T.writes( - B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] - ) - B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] = B[ - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - blockIdx_x * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] + T.reads(B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.writes(B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3]) + B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] = B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] for ax0_ax1_fused_0 in T.serial(64): for ax0_ax1_fused_3 in T.vectorized(4): with T.block("B_shared"): - T.reads( - B[ - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - blockIdx_x * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] - ) - T.writes( - B_shared[ - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - blockIdx_x * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] - ) - B_shared[ - k_0 * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - // 32, - blockIdx_x * 32 - + ( - ax0_ax1_fused_0 * 16 - + threadIdx_y * 8 - + threadIdx_x * 4 - + ax0_ax1_fused_3 - ) - % 32, - ] = 
B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] + T.reads(B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + T.writes(B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) + B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] for k_1, i_2, j_2, k_2 in T.grid(2, 16, 16, 16): with T.block("C"): - T.reads( - A_shared[ - blockIdx_y * 32 + threadIdx_y * 16 + i_2, - k_0 * 32 + k_1 * 16 + k_2, - ], - B_shared[ - k_0 * 32 + k_1 * 16 + k_2, - blockIdx_x * 32 + threadIdx_x * 16 + j_2, - ], - ) - T.writes( - C[ - blockIdx_y * 32 + threadIdx_y * 16 + i_2, - blockIdx_x * 32 + threadIdx_x * 16 + j_2, - ] - ) + T.reads(A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2], B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2]) + T.writes(C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2]) if k_0 * 32 + k_1 * 16 + k_2 == 0: - C[ - blockIdx_y * 32 + threadIdx_y * 16 + i_2, - blockIdx_x * 32 + threadIdx_x * 16 + j_2, - ] = T.float32(0) - C[ - blockIdx_y * 32 + threadIdx_y * 16 + i_2, - blockIdx_x * 32 + threadIdx_x * 16 + j_2, - ] = ( - C[ - blockIdx_y * 32 + threadIdx_y * 16 + i_2, - blockIdx_x * 32 + threadIdx_x * 16 + j_2, - ] - + A_shared[ - blockIdx_y * 32 + threadIdx_y * 16 + i_2, - k_0 * 32 + k_1 * 16 + k_2, - ] - * B_shared[ - k_0 * 32 + k_1 * 16 + k_2, - blockIdx_x * 32 + threadIdx_x * 16 + j_2, - ] - ) + C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] = T.float32(0) + C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] = C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] + A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2] * B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] + + +# fmt: on +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks def _check(before, expected): From 09f326bdbf4267d1eb59c01dc25093b461028e72 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 12 Aug 2022 18:58:32 -0700 Subject: [PATCH 3/3] fix lint --- tests/python/frontend/pytorch/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index cb6cb0f93fb6..3828412de054 100755 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4623,7 +4623,7 @@ def test_fn(f, dim=None, keepdim=False): return lambda x: f(x, dim=dim, keepdim=keepdim) def test_fn_no_arg(f): - return lambda x: f(x) + return lambda x: f(x) # pylint: disable=unnecessary-lambda for f in [torch.all, torch.any]: verify_model(test_fn(f, 0), [torch.rand(1, 2).bool()])
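
Usage sketch (not part of the patch): the pass can be applied directly to an annotated module, mirroring test_transform_matmul above; MatmulBefore and MatmulAfter refer to the fixtures defined in that test. In the default lowering pipeline the pass now runs after UnifyThreadBinding and before CompactBufferAllocation.

    import tvm

    # Rewrite the annotated shared-memory copies to go through an explicit local stage,
    # then check the result against the expected module, as the unit test does.
    after = tvm.tir.transform.ManifestSharedMemoryLocalStage()(MatmulBefore)
    tvm.ir.assert_structural_equal(after, MatmulAfter)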