From 1d432c53aae50ff01761d16d0c371170fc2e378e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 4 Jul 2023 21:54:31 -0400 Subject: [PATCH] [TIR][Schedule] Scoped CacheRead/Write producing compact region This PR enhances CacheRead/Write so that when a cache operation is performed under an inner block, the generated cache buffer has a shape that is as compact as possible, based on analysis of the consumed region. This change is motivated by dynamic-shape TIR scheduling, where we may isolate a "static shape" internal block using blockize and then do further scheduling inside that block. In such cases, the current CacheRead/Write applied inside the static-shape block still produces dynamic-shape cache buffers, which is not ideal for analysis and subsequent scheduling. One thing worth noting is that, to ensure IR correctness after inserting the cache block, we only compact the cache buffer when all the consumer blocks of the read buffer (for CacheRead) or the write buffer (for CacheWrite) are child blocks of the cache block insertion location. Otherwise we fall back to allocating the full-size cache buffer. Co-authored-by: Bohan Hou --- src/tir/schedule/primitive.h | 4 +- .../schedule/primitive/cache_read_write.cc | 316 +++++++++++++++--- .../test_tir_schedule_cache_read_write.py | 156 +++++++-- 3 files changed, 403 insertions(+), 73 deletions(-) diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 3b74e17781a4..4ae65ddc1768 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -105,7 +105,7 @@ TVM_DLL std::vector SamplePerfectTile( * The sampled tile size will be partitioned into two parts. The second part has a guarantee * that their extent's product have a factor of `innerpart_factor`. The first part is loops at * [0, partition_pos); the second part is loops at [partition_pos, n) and we will have - * `innerpart_factor` | \prod_{l=partition_pos}^{n-1} l.extent + * `innerpart_factor` | prod_{l=partition_pos}^{n-1} l.extent * * \param rand_state The random state * \param extent The loop extent to be tiled @@ -123,7 +123,7 @@ TVM_DLL std::vector SamplePartitionedTile( * The sampled tile size will be partitioned into two parts. The second part has a guarantee * that their extent's product have a factor of `innerpart_factor`. The first part is loops at * [0, partition_pos); the second part is loops at [partition_pos, n) and we will have - * `innerpart_factor` | \prod_{l=partition_pos}^{n-1} l.extent + * `innerpart_factor` | prod_{l=partition_pos}^{n-1} l.extent * * \param rand_state The random state * \param loop_sref The loop to be tiled diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 0a4cf2329ef3..6f9aa1127584 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -81,9 +81,11 @@ struct CacheStageInfo { Map block_reuse; /*! \brief A set of blocks that will consume the new cache. */ std::unordered_set consumer_blocks; + /*! \brief cache region for the buffer to be cached */ + BufferRegion cache_region; }; -/*! \brief Return the buffer region realted with the buffer */ +/*!
\brief Return the buffer region related with the buffer */ Optional GetBufferRegionFromBuffer(const Array& buffer_regions, const Buffer& buffer) { Optional res = NullOpt; @@ -230,10 +232,12 @@ Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageI * \param cache_region The cached copy region. * \param info The cache stage information, which will be updated in the function. * \param storage_scope The storage scope of the cached buffer (only used in naming here) + * \param cache_full_region A boolean indicating if the cache buffer is allocated with + * full region or compact region. * \returns A block indicating the body of the loop nesting. */ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, - const String& storage_scope) { + const String& storage_scope, bool cache_full_region = true) { // loop variables std::vector loop_vars; // bindings in block realize @@ -242,22 +246,50 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, for (const Range& axis_range : cache_region->region) { Var loop_var("ax" + std::to_string(loop_vars.size()), axis_range->extent.dtype()); loop_vars.push_back(loop_var); - iter_values.push_back(axis_range->min + loop_var); + iter_values.push_back(cache_full_region ? (axis_range->min + loop_var) : loop_var); } // block variables Array block_vars; // block access region for read/write buffers - Region access_region; + Region read_access_region; + Region write_access_region; // indices used in block body - Array access_indices; + Array read_access_indices; + Array write_access_indices; // Create block vars, block's accessed region and accessing indices - for (const PrimExpr& dim : cache_region->buffer->shape) { - Var var("v" + std::to_string(access_indices.size()), dim.dtype()); - block_vars.push_back(IterVar(/*dom=*/Range::FromMinExtent(make_zero(dim->dtype), dim), - /*var=*/var, - /*IterVarType=*/kDataPar)); - access_indices.push_back(var); - access_region.push_back(Range::FromMinExtent(var, make_const(var.dtype(), 1))); + for (int i = 0; i < static_cast(cache_region->buffer->shape.size()); ++i) { + Range axis_range = cache_region->region[i]; + Var var("v" + std::to_string(read_access_indices.size()), axis_range->extent.dtype()); + if (cache_full_region) { + PrimExpr dim = cache_region->buffer->shape[i]; + block_vars.push_back(IterVar(/*dom=*/Range::FromMinExtent(make_zero(dim->dtype), dim), + /*var=*/var, + /*IterVarType=*/kDataPar)); + read_access_indices.push_back(var); + write_access_indices.push_back(var); + read_access_region.push_back(Range::FromMinExtent(var, make_const(var.dtype(), 1))); + write_access_region.push_back(Range::FromMinExtent(var, make_const(var.dtype(), 1))); + } else { + block_vars.push_back(IterVar( + /*dom=*/Range::FromMinExtent(make_zero(axis_range->extent.dtype()), axis_range->extent), + /*var=*/var, + /*IterVarType=*/kDataPar)); + if (cache_region->buffer.same_as(info->read_buffer)) { + // cache_read + read_access_indices.push_back(axis_range->min + var); + read_access_region.push_back( + Range::FromMinExtent(axis_range->min + var, make_const(var.dtype(), 1))); + write_access_indices.push_back(var); + write_access_region.push_back(Range::FromMinExtent(var, make_const(var.dtype(), 1))); + } else { + // cache_write + write_access_indices.push_back(axis_range->min + var); + write_access_region.push_back( + Range::FromMinExtent(axis_range->min + var, make_const(var.dtype(), 1))); + read_access_indices.push_back(var); + 
read_access_region.push_back(Range::FromMinExtent(var, make_const(var.dtype(), 1))); + } + } } // Create the body block: @@ -266,12 +298,12 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, // write_buffer[access_indices] = read_buffer[access_indices] Block block( /*iter_vars=*/std::move(block_vars), - /*reads=*/{BufferRegion(info->read_buffer, access_region)}, - /*writes=*/{BufferRegion(info->write_buffer, access_region)}, + /*reads=*/{BufferRegion(info->read_buffer, read_access_region)}, + /*writes=*/{BufferRegion(info->write_buffer, write_access_region)}, /*name_hint=*/cache_region->buffer->name + "_" + storage_scope, /*body=*/ - BufferStore(info->write_buffer, BufferLoad(info->read_buffer, access_indices), - access_indices), + BufferStore(info->write_buffer, BufferLoad(info->read_buffer, read_access_indices), + write_access_indices), /*init=*/NullOpt, /*alloc_buffers=*/{}, /*match_buffers=*/{}, @@ -316,7 +348,7 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info, BufferIndexType buffer_index_type) { // iters of the reindex block Array new_block_iters; - // the substition map from the original block iter to the iters of the reindex block + // the substitution map from the original block iter to the iters of the reindex block std::unordered_map block_var_replace_map; // indices to access the reindex buffer and the target buffer Array reindex_indices, target_indices; @@ -404,7 +436,7 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info, } /*! - * \brief Recalculate the `affine_binding` flag of a specifc block + * \brief Recalculate the `affine_binding` flag of a specific block * \param block_sref The sref to the specific block */ bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref) { @@ -472,7 +504,7 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { * \param buffer The queried buffer * \return The sref of the only writer of the input buffer in the given scope, * or `NullOpt` if no block writes it in the scope. - * \throw NotSingleWriteBlock if there are more than one intrested block. + * \throw NotSingleWriteBlock if there are more than one interested block. */ Optional GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_sref, const Buffer& buffer) { @@ -490,6 +522,41 @@ Optional GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_s } } +/*! + * \brief Check if all the consumer blocks of the given buffer in the given + * block scope are the children block of the given target stmt. + * \param self The state of the schedule . + * \param buffer The buffer whose consumer blocks are to be check. + * \param scope_sref The scope block of the check. + * \param stmt_sref The target stmt + * \return A boolean indicating if all the consumer blocks of the input buffer + * meet the requirement. + */ +bool AllConsumersUnderStmt(ScheduleState self, Buffer buffer, StmtSRef scope_sref, + StmtSRef stmt_sref) { + // Collect all children blocks of the target stmt. + std::unordered_set blocks_under_target; + for (const StmtSRef& block_sref : GetChildBlocks(self, stmt_sref)) { + const auto* block = block_sref->StmtAs(); + ICHECK(block != nullptr); + blocks_under_target.insert(block); + } + + // For each block in the scope, if it is a consumer of the + // input buffer, check if it is also a child block of the + // target stmt. 
+ for (const StmtSRef& block_sref : GetChildBlocks(self, scope_sref)) { + const auto* block = block_sref->StmtAs(); + ICHECK(block != nullptr); + if (GetBufferRegionFromBuffer(block->reads, buffer).defined()) { + if (blocks_under_target.find(block) == blocks_under_target.end()) { + return false; + } + } + } + return true; +} + /*! * \brief Get the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive) * \param self The state of the schedule. @@ -628,7 +695,7 @@ class CacheLocDetector : public StmtVisitor { void VisitStmt_(const BlockNode* block) final { // Only visit the current scope under buffer writer's parent block if (block == scope_sref_->stmt) { - // The block vistied is the current parent scope + // The block visited is the current parent scope StmtVisitor::VisitStmt_(block); // Handling cases when insert outside any loop or cache_read for input buffer if (visited_related_ && !loc_sref_.defined()) { @@ -728,7 +795,7 @@ class CacheInplaceLocDetector : public StmtVisitor { void VisitStmt_(const BlockNode* block) final { // Only visit the current scope under buffer writer's parent block if (block == scope_sref_->stmt) { - // The block vistied is the current parent scope + // The block visited is the current parent scope StmtVisitor::VisitStmt_(block); // Handling cases when insert outside any loop if (visited_block_ && !loc_sref_.defined()) { @@ -777,21 +844,63 @@ class CacheReadRewriter : public StmtExprMutator { * \brief Rewrite the AST and add a cache_read stage with the information provided * \param scope_sref The parent scope of this mutation * \param info The cache stage information + * \param cache_full_region A boolean indicating if the cache buffer is allocated with + * full region or compact region. * \return The new AST rooting at the original parent scope */ - static Stmt Rewrite(const StmtSRef& scope_sref, CacheStageInfo* info) { - CacheReadRewriter rewriter(scope_sref, info); + static Stmt Rewrite(const StmtSRef& scope_sref, CacheStageInfo* info, + bool cache_full_region = true) { + CacheReadRewriter rewriter(scope_sref, info, cache_full_region); return rewriter(GetRef(scope_sref->stmt)); } private: - explicit CacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info) - : scope_sref_(scope_sref), info_(info) { - update_access_regions = [&](Array regions) { - return ReplaceBuffer(std::move(regions), info_->read_buffer, info_->write_buffer); + explicit CacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info, + bool cache_full_region = true) + : scope_sref_(scope_sref), info_(info), cache_full_region_(cache_full_region) { + auto update_region = [this](const Region& region, const Region& offset) -> Region { + ICHECK_EQ(region.size(), offset.size()); + std::vector ret; + for (size_t i = 0; i < region.size(); ++i) { + ret.push_back(Range::FromMinExtent(ana_.Simplify(region[i]->min - offset[i]->min), + region[i]->extent)); + } + return ret; + }; + + update_access_regions = [this, update_region](Array regions) { + if (cache_full_region_) { + return ReplaceBuffer(std::move(regions), info_->read_buffer, info_->write_buffer); + } + + Array ret; + for (const BufferRegion& region : regions) { + if (region->buffer.same_as(info_->read_buffer)) { + ret.push_back(BufferRegion(info_->write_buffer, + update_region(region->region, info_->cache_region->region))); + } else { + ret.push_back(region); + } + } + return ret; }; - update_match_buffers = [&](Array match_buffers) { - return ReplaceBuffer(std::move(match_buffers), info_->read_buffer, 
info_->write_buffer); + update_match_buffers = [this, update_region](Array match_buffers) { + if (cache_full_region_) { + return ReplaceBuffer(std::move(match_buffers), info_->read_buffer, info_->write_buffer); + } + + Array ret; + for (const MatchBufferRegion& match_buffer : match_buffers) { + if (match_buffer->source->buffer.same_as(info_->read_buffer)) { + ret.push_back(MatchBufferRegion( + match_buffer->buffer, + BufferRegion(info_->write_buffer, update_region(match_buffer->source->region, + info_->cache_region->region)))); + } else { + ret.push_back(match_buffer); + } + } + return ret; }; } @@ -863,10 +972,21 @@ class CacheReadRewriter : public StmtExprMutator { return std::move(stmt); } + Array RewriteIndices(const Array& indices) { + std::vector ret; + for (size_t i = 0; i < indices.size(); ++i) { + ret.push_back(ana_.Simplify(indices[i] - info_->cache_region->region[i]->min)); + } + return ret; + } + PrimExpr VisitExpr_(const BufferLoadNode* load) override { if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) { ObjectPtr n = make_object(*load); n->buffer = info_->write_buffer; + if (!cache_full_region_) { + n->indices = std::move(RewriteIndices(load->indices)); + } return PrimExpr(n); } return ExprMutator::VisitExpr_(load); @@ -890,6 +1010,13 @@ class CacheReadRewriter : public StmtExprMutator { std::function(Array)> update_access_regions; /*! \brief function to update match buffers of block being cache read.*/ std::function(Array)> update_match_buffers; + /*! + * \brief A boolean indicating if the cache buffer is allocated with + * full region or compact region. + */ + bool cache_full_region_; + /*! \brief Arithmetic analyzer. */ + arith::Analyzer ana_; friend ReindexCacheReadRewriter; }; @@ -970,23 +1097,66 @@ class CacheWriteRewriter : public StmtExprMutator { * \param scope_sref The parent scope of this mutation. * \param writer_block_sref The only writer block in the scope. * \param info The cache stage information. + * \param cache_full_region A boolean indicating if the cache buffer is allocated with + * full region or compact region. * \return The new AST rooting at the original parent scope. 
*/ static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, - CacheStageInfo* info) { - CacheWriteRewriter rewriter(scope_sref, writer_block_sref, info); + CacheStageInfo* info, bool cache_full_region = true) { + CacheWriteRewriter rewriter(scope_sref, writer_block_sref, info, cache_full_region); return rewriter(GetRef(scope_sref->stmt)); } private: explicit CacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, - CacheStageInfo* info) - : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) { - update_access_regions = [&](Array regions) { - return ReplaceBuffer(regions, info_->write_buffer, info_->read_buffer); + CacheStageInfo* info, bool cache_full_region = true) + : scope_sref_(scope_sref), + writer_block_sref_(writer_block_sref), + info_(info), + cache_full_region_(cache_full_region) { + auto update_region = [this](const Region& region, const Region& offset) -> Region { + ICHECK_EQ(region.size(), offset.size()); + std::vector ret; + for (size_t i = 0; i < region.size(); ++i) { + ret.push_back(Range::FromMinExtent(ana_.Simplify(region[i]->min - offset[i]->min), + region[i]->extent)); + } + return ret; }; - update_match_buffers = [&](Array match_buffers) { - return ReplaceBuffer(match_buffers, info_->write_buffer, info_->read_buffer); + + update_access_regions = [this, update_region](Array regions) { + if (cache_full_region_) { + return ReplaceBuffer(regions, info_->write_buffer, info_->read_buffer); + } + + Array ret; + for (const BufferRegion& region : regions) { + if (region->buffer.same_as(info_->write_buffer)) { + ret.push_back(BufferRegion(info_->read_buffer, + update_region(region->region, info_->cache_region->region))); + } else { + ret.push_back(region); + } + } + return ret; + }; + update_match_buffers = [this, update_region](Array match_buffers) { + if (cache_full_region_) { + return ReplaceBuffer(match_buffers, info_->write_buffer, info_->read_buffer); + } + + Array ret; + for (const MatchBufferRegion& match_buffer : match_buffers) { + if (match_buffer->source->buffer.same_as(info_->write_buffer)) { + ret.push_back(MatchBufferRegion( + match_buffer->buffer, + BufferRegion(info_->read_buffer, update_region(match_buffer->source->region, + info_->cache_region->region)))); + } else { + ret.push_back(match_buffer); + } + } + return ret; }; } @@ -1072,11 +1242,22 @@ class CacheWriteRewriter : public StmtExprMutator { return std::move(stmt); } + Array RewriteIndices(const Array& indices) { + std::vector ret; + for (size_t i = 0; i < indices.size(); ++i) { + ret.push_back(ana_.Simplify(indices[i] - info_->cache_region->region[i]->min)); + } + return ret; + } + Stmt VisitStmt_(const BufferStoreNode* store) override { BufferStore stmt = Downcast(StmtMutator::VisitStmt_(store)); if (stmt->buffer.same_as(info_->write_buffer)) { auto n = CopyOnWrite(stmt.get()); n->buffer = info_->read_buffer; + if (!cache_full_region_) { + n->indices = std::move(RewriteIndices(n->indices)); + } return Stmt(n); } else { return std::move(stmt); @@ -1087,6 +1268,9 @@ class CacheWriteRewriter : public StmtExprMutator { if (load->buffer.same_as(info_->write_buffer)) { ObjectPtr n = make_object(*load); n->buffer = info_->read_buffer; + if (!cache_full_region_) { + n->indices = std::move(RewriteIndices(n->indices)); + } return PrimExpr(n); } return ExprMutator::VisitExpr_(load); @@ -1112,6 +1296,13 @@ class CacheWriteRewriter : public StmtExprMutator { std::function(Array)> update_access_regions; /*! 
\brief function to update match buffers of block being cache write.*/ std::function(Array)> update_match_buffers; + /*! + * \brief A boolean indicating if the cache buffer is allocated with + * full region or compact region. + */ + bool cache_full_region_; + /*! \brief Arithmetic analyzer. */ + arith::Analyzer ana_; friend ReindexCacheWriteRewriter; }; @@ -1506,10 +1697,6 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff // Step 2. Create CacheStageInfo CacheStageInfo info; info.read_buffer = read_buffer; - // Create the corresponding buffer to be written, i.e. result of cache_read - info.write_buffer = WithScope(read_buffer, storage_scope); - // Create the corresponding buffer allocation - info.alloc = info.write_buffer; // info.consumer_blocks indicates which buffers should consume the cache. for (auto consumer : consumer_blocks) { @@ -1545,9 +1732,25 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff } // Step 4. Making new cache stage block and rewrite readers. - Block cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info, - /*storage_scope=*/storage_scope); - Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info); + bool cache_full_region = info.loc_sref->StmtAs() == nullptr || + !AllConsumersUnderStmt(self, read_buffer, scope_sref, info.loc_sref); + info.cache_region = cache_region; + info.write_buffer = WithScope(read_buffer, storage_scope); + if (!cache_full_region) { + auto* write_buffer = info.write_buffer.CopyOnWrite(); + std::vector shape; + for (auto cache_range : info.cache_region->region) { + shape.push_back(cache_range->extent); + } + write_buffer->shape = std::move(shape); + } + info.alloc = info.write_buffer; + + Block cache_read_stage = + MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info, + /*storage_scope=*/storage_scope, /*cache_full_region=*/cache_full_region); + Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info, + /*cache_full_region=*/cache_full_region); // Step 5. Replacing and updating flags. self->Replace(scope_sref, new_scope, info.block_reuse); @@ -1583,11 +1786,8 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu // Step 2. Creating CacheStageInfo CacheStageInfo info; - info.read_buffer = WithScope(write_buffer, storage_scope); // Create the corresponding buffer to be written, i.e. result of cache_write info.write_buffer = write_buffer; - // Create the corresponding buffer allocation - info.alloc = info.read_buffer; // info.consumer_blocks indicates which buffers should consume the cache. for (auto consumer : consumer_blocks) { @@ -1608,11 +1808,27 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu BufferRegion cache_region = RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref); + bool cache_full_region = info.loc_sref->StmtAs() == nullptr || + !AllConsumersUnderStmt(self, write_buffer, scope_sref, info.loc_sref); + info.cache_region = cache_region; + info.read_buffer = WithScope(write_buffer, storage_scope); + if (!cache_full_region) { + auto* read_buffer = info.read_buffer.CopyOnWrite(); + std::vector shape; + for (auto cache_range : info.cache_region->region) { + shape.push_back(cache_range->extent); + } + read_buffer->shape = std::move(shape); + } + info.alloc = info.read_buffer; + // Step 5. Making new cache stage block and rewrite readers. 
- Block cache_write_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info, - /*storage_scope=*/storage_scope); + Block cache_write_stage = + MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info, + /*storage_scope=*/storage_scope, /*cache_full_region=*/cache_full_region); Stmt new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref, - /*writer_block_sref=*/block_sref, /*info=*/&info); + /*writer_block_sref=*/block_sref, /*info=*/&info, + /*cache_full_region=*/cache_full_region); // Step 6. Replacing and updating flags. self->Replace(scope_sref, new_scope, info.block_reuse); diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index 95955646c64b..2d460b359181 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -523,7 +523,7 @@ def cache_read_under_scope(b: T.handle, c: T.handle) -> None: for i0, j0 in T.grid(8, 8): with T.block("scope"): i, j = T.axis.remap("SS", [i0, j0]) - A_local = T.alloc_buffer((128, 128), scope="local") + A_local = T.alloc_buffer((16, 16), scope="local") for x, y in T.grid(16, 16): with T.block("A"): vi = T.axis.S(128, i * 16 + x) @@ -531,14 +531,14 @@ def cache_read_under_scope(b: T.handle, c: T.handle) -> None: A[vi, vj] = 1.0 for x, y in T.grid(16, 16): with T.block("A_local"): - vi = T.axis.S(128, i * 16 + x) - vj = T.axis.S(128, j * 16 + y) - A_local[vi, vj] = A[vi, vj] + vi = T.axis.S(16, x) + vj = T.axis.S(16, y) + A_local[vi, vj] = A[i * 16 + vi, j * 16 + vj] for x, y in T.grid(16, 16): with T.block("B"): vi = T.axis.S(128, i * 16 + x) vj = T.axis.S(128, j * 16 + y) - B[vi, vj] = A_local[vi, vj] + 1.0 + B[vi, vj] = A_local[vi - i * 16, vj - j * 16] + 1.0 for i, j in T.grid(128, 128): with T.block("A_global"): vi, vj = T.axis.remap("SS", [i, j]) @@ -866,28 +866,28 @@ def cache_write_under_scope(b: T.handle, c: T.handle) -> None: for i0, j0 in T.grid(8, 8): with T.block("scope"): i, j = T.axis.remap("SS", [i0, j0]) - A_local = T.alloc_buffer((128, 128), scope="local") - B_global = T.alloc_buffer((128, 128)) + A_local = T.alloc_buffer((16, 16), scope="local") + B_global = T.alloc_buffer((16, 16)) for x, y in T.grid(16, 16): with T.block("A_local"): vi = T.axis.S(128, i * 16 + x) vj = T.axis.S(128, j * 16 + y) - A_local[vi, vj] = 1.0 + A_local[vi - i * 16, vj - j * 16] = 1.0 for x, y in T.grid(16, 16): with T.block("A"): - vi = T.axis.S(128, i * 16 + x) - vj = T.axis.S(128, j * 16 + y) - A_global[vi, vj] = A_local[vi, vj] + vi = T.axis.S(16, x) + vj = T.axis.S(16, y) + A_global[i * 16 + vi, j * 16 + vj] = A_local[vi, vj] for x, y in T.grid(16, 16): - with T.block("B_global"): + with T.block("B"): vi = T.axis.S(128, i * 16 + x) vj = T.axis.S(128, j * 16 + y) - B_global[vi, vj] = A_global[vi, vj] + 1.0 + B_global[vi - i * 16, vj - j * 16] = A_global[vi, vj] + 1.0 for x, y in T.grid(16, 16): with T.block("B_global"): - vi = T.axis.S(128, i * 16 + x) - vj = T.axis.S(128, j * 16 + y) - B[vi, vj] = B_global[vi, vj] + vi = T.axis.S(16, x) + vj = T.axis.S(16, y) + B[i * 16 + vi, j * 16 + vj] = B_global[vi, vj] for i, j in T.grid(128, 128): with T.block("A_global"): vi, vj = T.axis.remap("SS", [i, j]) @@ -1167,6 +1167,104 @@ def block_predicate_cache_write_output_buf() -> None: B[v0] = B_shared[v0] +@T.prim_func +def symbolic_matmul_blocked(var_A: T.handle, var_B: T.handle, var_C: T.handle, n: T.int32): + A = T.match_buffer(var_A, ((n + 31) // 32 * 32, 4)) + B = 
T.match_buffer(var_B, (4, (n + 31) // 32 * 32)) + C = T.match_buffer(var_C, ((n + 31) // 32 * 32, (n + 31) // 32 * 32)) + for i0_0, i1_0 in T.grid((n + 31) // 32, (n + 31) // 32): + with T.block("matmul_o"): + v_i0_o, v_i1_o = T.axis.remap("SS", [i0_0, i1_0]) + T.reads( + A[v_i0_o * 32 : v_i0_o * 32 + 32, 0:4], + B[0:4, v_i1_o * 32 : v_i1_o * 32 + 32], + ) + T.writes(C[v_i0_o * 32 : v_i0_o * 32 + 32, v_i1_o * 32 : v_i1_o * 32 + 32]) + for i0_1, i1_1, k in T.grid(32, 32, 4): + with T.block("matmul"): + v_i0_i, v_i1_i, v_k_i = T.axis.remap("SSR", [i0_1, i1_1, k]) + T.reads(A[v_i0_o * 32 + v_i0_i, v_k_i], B[v_k_i, v_i1_o * 32 + v_i1_i]) + T.writes(C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i]) + with T.init(): + C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = T.float32(0) + C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = ( + C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] + + A[v_i0_o * 32 + v_i0_i, v_k_i] * B[v_k_i, v_i1_o * 32 + v_i1_i] + ) + + +@T.prim_func +def symbolic_matmul_blocked_cache_read( + var_A: T.handle, var_B: T.handle, var_C: T.handle, n: T.int32 +): + A = T.match_buffer(var_A, ((n + 31) // 32 * 32, 4)) + B = T.match_buffer(var_B, (4, (n + 31) // 32 * 32)) + C = T.match_buffer(var_C, ((n + 31) // 32 * 32, (n + 31) // 32 * 32)) + for i0_0, i1_0 in T.grid((n + 31) // 32, (n + 31) // 32): + with T.block("matmul_o"): + v_i0_o, v_i1_o = T.axis.remap("SS", [i0_0, i1_0]) + T.reads( + A[v_i0_o * 32 : v_i0_o * 32 + 32, 0:4], + B[0:4, v_i1_o * 32 : v_i1_o * 32 + 32], + ) + T.writes(C[v_i0_o * 32 : v_i0_o * 32 + 32, v_i1_o * 32 : v_i1_o * 32 + 32]) + A_shared = T.alloc_buffer((32, 4), scope="shared") + for ax0, ax1 in T.grid(32, 4): + with T.block("A_shared"): + v0 = T.axis.spatial(32, ax0) + v1 = T.axis.spatial(4, ax1) + T.reads(A[v_i0_o * 32 + v0, v1]) + T.writes(A_shared[v0, v1]) + A_shared[v0, v1] = A[v_i0_o * 32 + v0, v1] + for i0_1, i1_1, k in T.grid(32, 32, 4): + with T.block("matmul"): + v_i0_i, v_i1_i, v_k_i = T.axis.remap("SSR", [i0_1, i1_1, k]) + T.reads(A_shared[v_i0_i, v_k_i], B[v_k_i, v_i1_o * 32 + v_i1_i]) + T.writes(C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i]) + with T.init(): + C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = T.float32(0) + C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = ( + C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] + + A_shared[v_i0_i, v_k_i] * B[v_k_i, v_i1_o * 32 + v_i1_i] + ) + + +@T.prim_func +def symbolic_matmul_blocked_cache_write( + var_A: T.handle, var_B: T.handle, var_C: T.handle, n: T.int32 +): + A = T.match_buffer(var_A, ((n + 31) // 32 * 32, 4)) + B = T.match_buffer(var_B, (4, (n + 31) // 32 * 32)) + C = T.match_buffer(var_C, ((n + 31) // 32 * 32, (n + 31) // 32 * 32)) + for i0_0, i1_0 in T.grid((n + 31) // 32, (n + 31) // 32): + with T.block("matmul_o"): + v_i0_o, v_i1_o = T.axis.remap("SS", [i0_0, i1_0]) + T.reads( + A[v_i0_o * 32 : v_i0_o * 32 + 32, 0:4], + B[0:4, v_i1_o * 32 : v_i1_o * 32 + 32], + ) + T.writes(C[v_i0_o * 32 : v_i0_o * 32 + 32, v_i1_o * 32 : v_i1_o * 32 + 32]) + C_pad_local = T.alloc_buffer((32, 32), scope="local") + for i0_1, i1_1, k in T.grid(32, 32, 4): + with T.block("matmul"): + v_i0_i, v_i1_i, v_k_i = T.axis.remap("SSR", [i0_1, i1_1, k]) + T.reads(A[v_i0_o * 32 + v_i0_i, v_k_i], B[v_k_i, v_i1_o * 32 + v_i1_i]) + T.writes(C_pad_local[v_i0_i, v_i1_i]) + with T.init(): + C_pad_local[v_i0_i, v_i1_i] = T.float32(0) + C_pad_local[v_i0_i, v_i1_i] = ( + C_pad_local[v_i0_i, v_i1_i] + + A[v_i0_o * 32 + v_i0_i, v_k_i] * B[v_k_i, v_i1_o * 32 + v_i1_i] + ) + for ax0, ax1 in T.grid(32, 32): + with T.block("C_pad_local"): + v0 
= T.axis.spatial(32, ax0) + v1 = T.axis.spatial(32, ax1) + T.reads(C_pad_local[v0, v1]) + T.writes(C[v_i0_o * 32 + v0, v_i1_o * 32 + v1]) + C[v_i0_o * 32 + v0, v_i1_o * 32 + v1] = C_pad_local[v0, v1] + + ########## Testcases for cache_read ########## use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) @@ -1215,7 +1313,7 @@ def test_cache_read_location(use_block_name): tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) - # Test that specific consumer block targetting works. + # Test that specific consumer block targeting works. sch = tir.Schedule(func_multi_consumer, debug_mask="all") block_b = "B" if use_block_name else sch.get_block("B") block_c = "C" if use_block_name else sch.get_block("C") @@ -1355,7 +1453,7 @@ def test_cache_write_location(use_block_name): tvm.ir.assert_structural_equal(cache_write_multi_consumer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) - # Test that specific consumer block targetting works. + # Test that specific consumer block targeting works. # B read cache buffer and C read original output buffer sch = tir.Schedule(func_multi_consumer, debug_mask="all") block_a = "A" if use_block_name else sch.get_block("A") @@ -1364,7 +1462,7 @@ def test_cache_write_location(use_block_name): tvm.ir.assert_structural_equal(cache_write_multi_consumer_B_consume_cache, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) - # Test that specific consumer block targetting works. + # Test that specific consumer block targeting works. # B read original output buffer and C read cache buffer sch = tir.Schedule(func_multi_consumer, debug_mask="all") block_a = "A" if use_block_name else sch.get_block("A") @@ -1373,7 +1471,7 @@ def test_cache_write_location(use_block_name): tvm.ir.assert_structural_equal(cache_write_multi_consumer_C_consume_cache, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) - # Test that specific consumer block targetting works. + # Test that specific consumer block targeting works. 
# B and C read cache buffer sch = tir.Schedule(func_multi_consumer, debug_mask="all") block_a = "A" if use_block_name else sch.get_block("A") @@ -1532,7 +1630,7 @@ def test_reindex_cache_read_fail_not_match(): ) -def test_reindex_cache_read_faile_not_single_point(): +def test_reindex_cache_read_failed_not_single_point(): sch = tir.Schedule(access_under_scope, debug_mask="all") with pytest.raises(tvm.tir.ScheduleError): sch.reindex_cache_read("scope", 0, "shared", lambda i, j: (i, j)) @@ -1571,5 +1669,21 @@ def test_reindex_cache_write_fail_not_single_point(): sch.reindex_cache_write("scope", 0, "shared", lambda i, j: (i, j)) +def test_symbolic_matmul_blocked_cache_read(use_block_name): + sch = tir.Schedule(symbolic_matmul_blocked, debug_mask="all") + block = "matmul" if use_block_name else sch.get_block("matmul") + sch.cache_read(block=block, read_buffer_index=0, storage_scope="shared") + tvm.ir.assert_structural_equal(sch.mod["main"], symbolic_matmul_blocked_cache_read) + verify_trace_roundtrip(sch=sch, mod=symbolic_matmul_blocked) + + +def test_symbolic_matmul_blocked_cache_write(use_block_name): + sch = tir.Schedule(symbolic_matmul_blocked, debug_mask="all") + block = "matmul" if use_block_name else sch.get_block("matmul") + sch.cache_write(block=block, write_buffer_index=0, storage_scope="local") + tvm.ir.assert_structural_equal(sch.mod["main"], symbolic_matmul_blocked_cache_write) + verify_trace_roundtrip(sch=sch, mod=symbolic_matmul_blocked) + + if __name__ == "__main__": tvm.testing.main()
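
A usage sketch for reviewers (not part of the patch). It assumes a TVM build that includes this change; the `symbolic_matmul_blocked` workload is copied verbatim from the test file added above, and, unlike the tests (which exercise cache_read and cache_write on separate schedules), the sketch chains both on one schedule purely for illustration. Because every consumer of `A` and of `C` sits under the blockized `matmul_o` block, both inserted cache stages allocate compact buffers of shape (32, 4) and (32, 32) rather than the full `((n + 31) // 32 * 32, ...)` shapes.

# Usage sketch (not part of the patch); assumes a TVM build with this change.
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def symbolic_matmul_blocked(var_A: T.handle, var_B: T.handle, var_C: T.handle, n: T.int32):
    # Copied from the tests added in this patch: a dynamic-shape matmul whose
    # inner 32x32x4 computation has been isolated into the block "matmul".
    A = T.match_buffer(var_A, ((n + 31) // 32 * 32, 4))
    B = T.match_buffer(var_B, (4, (n + 31) // 32 * 32))
    C = T.match_buffer(var_C, ((n + 31) // 32 * 32, (n + 31) // 32 * 32))
    for i0_0, i1_0 in T.grid((n + 31) // 32, (n + 31) // 32):
        with T.block("matmul_o"):
            v_i0_o, v_i1_o = T.axis.remap("SS", [i0_0, i1_0])
            T.reads(
                A[v_i0_o * 32 : v_i0_o * 32 + 32, 0:4],
                B[0:4, v_i1_o * 32 : v_i1_o * 32 + 32],
            )
            T.writes(C[v_i0_o * 32 : v_i0_o * 32 + 32, v_i1_o * 32 : v_i1_o * 32 + 32])
            for i0_1, i1_1, k in T.grid(32, 32, 4):
                with T.block("matmul"):
                    v_i0_i, v_i1_i, v_k_i = T.axis.remap("SSR", [i0_1, i1_1, k])
                    T.reads(A[v_i0_o * 32 + v_i0_i, v_k_i], B[v_k_i, v_i1_o * 32 + v_i1_i])
                    T.writes(C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i])
                    with T.init():
                        C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = T.float32(0)
                    C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = (
                        C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i]
                        + A[v_i0_o * 32 + v_i0_i, v_k_i] * B[v_k_i, v_i1_o * 32 + v_i1_i]
                    )


sch = tir.Schedule(symbolic_matmul_blocked, debug_mask="all")
# "matmul" is the static-shape block living inside the blockized "matmul_o".
block = sch.get_block("matmul")

# Cache the read of A (read_buffer_index=0) in shared memory. All consumers of
# A are children of "matmul_o", so the cache buffer is compacted to (32, 4).
sch.cache_read(block=block, read_buffer_index=0, storage_scope="shared")

# Cache the write of C (write_buffer_index=0) in local memory. Likewise, the
# cache buffer is compacted to (32, 32) instead of the full padded shape.
sch.cache_write(block=block, write_buffer_index=0, storage_scope="local")

# Inspect the result: the alloc_buffers inside "matmul_o" now have compact
# shapes, and the cache-stage accesses are offset by v_i0_o * 32 / v_i1_o * 32.
print(sch.mod["main"].script())

If some block outside "matmul_o" also consumed `A` (or `C`), `AllConsumersUnderStmt` would report false and the primitive would fall back to allocating the full-size cache buffer, exactly as described in the commit message.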