diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 3b74e17781a4..4ae65ddc1768 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -105,7 +105,7 @@ TVM_DLL std::vector<int64_t> SamplePerfectTile(
  * The sampled tile size will be partitioned into two parts. The second part has a guarantee
  * that their extent's product have a factor of `innerpart_factor`. The first part is loops at
  * [0, partition_pos); the second part is loops at [partition_pos, n) and we will have
- * `innerpart_factor` | \prod_{l=partition_pos}^{n-1} l.extent
+ * `innerpart_factor` | prod_{l=partition_pos}^{n-1} l.extent
  *
  * \param rand_state The random state
  * \param extent The loop extent to be tiled
@@ -123,7 +123,7 @@ TVM_DLL std::vector<int64_t> SamplePartitionedTile(
  * The sampled tile size will be partitioned into two parts. The second part has a guarantee
  * that their extent's product have a factor of `innerpart_factor`. The first part is loops at
  * [0, partition_pos); the second part is loops at [partition_pos, n) and we will have
- * `innerpart_factor` | \prod_{l=partition_pos}^{n-1} l.extent
+ * `innerpart_factor` | prod_{l=partition_pos}^{n-1} l.extent
  *
  * \param rand_state The random state
  * \param loop_sref The loop to be tiled
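
The divisibility guarantee above is easiest to see on a concrete split. A minimal sketch with hand-picked numbers (illustrative only, not a real call into `SamplePartitionedTile`):

```python
# A split of extent 128 into n = 4 tiles with partition_pos = 2 and
# innerpart_factor = 8. The extents at [partition_pos, n) must multiply
# to a multiple of innerpart_factor.
tiles = [2, 2, 4, 8]  # prod(tiles) == 128
partition_pos, innerpart_factor = 2, 8

inner_prod = 1
for extent in tiles[partition_pos:]:
    inner_prod *= extent

assert inner_prod % innerpart_factor == 0  # 4 * 8 == 32, divisible by 8
```
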
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index 0a4cf2329ef3..6f9aa1127584 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -81,9 +81,11 @@ struct CacheStageInfo {
   Map<Block, Block> block_reuse;
   /*! \brief A set of blocks that will consume the new cache. */
   std::unordered_set<StmtSRef, ObjectPtrHash, ObjectPtrEqual> consumer_blocks;
+  /*! \brief The cache region for the buffer to be cached */
+  BufferRegion cache_region;
 };
 
-/*! \brief Return the buffer region realted with the buffer */
+/*! \brief Return the buffer region related with the buffer */
 Optional<BufferRegion> GetBufferRegionFromBuffer(const Array<BufferRegion>& buffer_regions,
                                                  const Buffer& buffer) {
   Optional<BufferRegion> res = NullOpt;
@@ -230,10 +232,12 @@ Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageInfo* info,
  * \param cache_region The cached copy region.
  * \param info The cache stage information, which will be updated in the function.
  * \param storage_scope The storage scope of the cached buffer (only used in naming here)
+ * \param cache_full_region A boolean indicating whether the cache buffer is allocated
+ * with the full region or a compact region.
  * \returns A block indicating the body of the loop nesting.
  */
 Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info,
-                     const String& storage_scope) {
+                     const String& storage_scope, bool cache_full_region = true) {
   // loop variables
   std::vector<Var> loop_vars;
   // bindings in block realize
@@ -242,22 +246,50 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info,
   for (const Range& axis_range : cache_region->region) {
     Var loop_var("ax" + std::to_string(loop_vars.size()), axis_range->extent.dtype());
     loop_vars.push_back(loop_var);
-    iter_values.push_back(axis_range->min + loop_var);
+    iter_values.push_back(cache_full_region ? (axis_range->min + loop_var) : loop_var);
   }
   // block variables
   Array<IterVar> block_vars;
   // block access region for read/write buffers
-  Region access_region;
+  Region read_access_region;
+  Region write_access_region;
   // indices used in block body
-  Array<PrimExpr> access_indices;
+  Array<PrimExpr> read_access_indices;
+  Array<PrimExpr> write_access_indices;
   // Create block vars, block's accessed region and accessing indices
-  for (const PrimExpr& dim : cache_region->buffer->shape) {
-    Var var("v" + std::to_string(access_indices.size()), dim.dtype());
-    block_vars.push_back(IterVar(/*dom=*/Range::FromMinExtent(make_zero(dim->dtype), dim),
-                                 /*var=*/var,
-                                 /*IterVarType=*/kDataPar));
-    access_indices.push_back(var);
-    access_region.push_back(Range::FromMinExtent(var, make_const(var.dtype(), 1)));
+  for (int i = 0; i < static_cast<int>(cache_region->buffer->shape.size()); ++i) {
+    Range axis_range = cache_region->region[i];
+    Var var("v" + std::to_string(read_access_indices.size()), axis_range->extent.dtype());
+    if (cache_full_region) {
+      PrimExpr dim = cache_region->buffer->shape[i];
+      block_vars.push_back(IterVar(/*dom=*/Range::FromMinExtent(make_zero(dim->dtype), dim),
+                                   /*var=*/var,
+                                   /*IterVarType=*/kDataPar));
+      read_access_indices.push_back(var);
+      write_access_indices.push_back(var);
+      read_access_region.push_back(Range::FromMinExtent(var, make_const(var.dtype(), 1)));
+      write_access_region.push_back(Range::FromMinExtent(var, make_const(var.dtype(), 1)));
+    } else {
+      block_vars.push_back(IterVar(
+          /*dom=*/Range::FromMinExtent(make_zero(axis_range->extent.dtype()), axis_range->extent),
+          /*var=*/var,
+          /*IterVarType=*/kDataPar));
+      if (cache_region->buffer.same_as(info->read_buffer)) {
+        // cache_read
+        read_access_indices.push_back(axis_range->min + var);
+        read_access_region.push_back(
+            Range::FromMinExtent(axis_range->min + var, make_const(var.dtype(), 1)));
+        write_access_indices.push_back(var);
+        write_access_region.push_back(Range::FromMinExtent(var, make_const(var.dtype(), 1)));
+      } else {
+        // cache_write
+        write_access_indices.push_back(axis_range->min + var);
+        write_access_region.push_back(
+            Range::FromMinExtent(axis_range->min + var, make_const(var.dtype(), 1)));
+        read_access_indices.push_back(var);
+        read_access_region.push_back(Range::FromMinExtent(var, make_const(var.dtype(), 1)));
+      }
+    }
   }
 
   // Create the body block:
@@ -266,12 +298,12 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info,
   //   write_buffer[access_indices] = read_buffer[access_indices]
   Block block(
       /*iter_vars=*/std::move(block_vars),
-      /*reads=*/{BufferRegion(info->read_buffer, access_region)},
-      /*writes=*/{BufferRegion(info->write_buffer, access_region)},
+      /*reads=*/{BufferRegion(info->read_buffer, read_access_region)},
+      /*writes=*/{BufferRegion(info->write_buffer, write_access_region)},
       /*name_hint=*/cache_region->buffer->name + "_" + storage_scope,
       /*body=*/
-      BufferStore(info->write_buffer, BufferLoad(info->read_buffer, access_indices),
-                  access_indices),
+      BufferStore(info->write_buffer, BufferLoad(info->read_buffer, read_access_indices),
+                  write_access_indices),
       /*init=*/NullOpt,
       /*alloc_buffers=*/{},
       /*match_buffers=*/{},
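
The compact branch above maps every access into the cached region's local coordinates. A minimal sketch of that mapping in plain Python (`to_cache_index` is a hypothetical helper, not a TVM API):

```python
def to_cache_index(indices, region_mins):
    """Map an index into the original buffer to an index into the
    compact cache by subtracting the cached region's minimum."""
    return [idx - lo for idx, lo in zip(indices, region_mins)]

# A block caching A[32:48, 48:64] into a (16, 16) buffer accesses
# A[35, 53] as cache[3, 5]:
assert to_cache_index([32 + 3, 48 + 5], [32, 48]) == [3, 5]
```
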
@@ -316,7 +348,7 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info,
                        BufferIndexType buffer_index_type) {
   // iters of the reindex block
   Array<IterVar> new_block_iters;
-  // the substition map from the original block iter to the iters of the reindex block
+  // the substitution map from the original block iter to the iters of the reindex block
   std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_replace_map;
   // indices to access the reindex buffer and the target buffer
   Array<PrimExpr> reindex_indices, target_indices;
@@ -404,7 +436,7 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info,
 }
 
 /*!
- * \brief Recalculate the `affine_binding` flag of a specifc block
+ * \brief Recalculate the `affine_binding` flag of a specific block
  * \param block_sref The sref to the specific block
  */
 bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref) {
@@ -472,7 +504,7 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) {
  * \param buffer The queried buffer
  * \return The sref of the only writer of the input buffer in the given scope,
  *         or `NullOpt` if no block writes it in the scope.
- * \throw NotSingleWriteBlock if there are more than one intrested block.
+ * \throw NotSingleWriteBlock if there are more than one interested block.
  */
 Optional<StmtSRef> GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_sref,
                                      const Buffer& buffer) {
@@ -490,6 +522,41 @@ Optional<StmtSRef> GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_sref,
   }
 }
 
+/*!
+ * \brief Check if all the consumer blocks of the given buffer in the given
+ * block scope are children blocks of the given target stmt.
+ * \param self The state of the schedule.
+ * \param buffer The buffer whose consumer blocks are to be checked.
+ * \param scope_sref The scope block of the check.
+ * \param stmt_sref The target stmt.
+ * \return A boolean indicating if all the consumer blocks of the input buffer
+ * meet the requirement.
+ */
+bool AllConsumersUnderStmt(ScheduleState self, Buffer buffer, StmtSRef scope_sref,
+                           StmtSRef stmt_sref) {
+  // Collect all children blocks of the target stmt.
+  std::unordered_set<const BlockNode*> blocks_under_target;
+  for (const StmtSRef& block_sref : GetChildBlocks(self, stmt_sref)) {
+    const auto* block = block_sref->StmtAs<BlockNode>();
+    ICHECK(block != nullptr);
+    blocks_under_target.insert(block);
+  }
+
+  // For each block in the scope, if it is a consumer of the
+  // input buffer, check if it is also a child block of the
+  // target stmt.
+  for (const StmtSRef& block_sref : GetChildBlocks(self, scope_sref)) {
+    const auto* block = block_sref->StmtAs<BlockNode>();
+    ICHECK(block != nullptr);
+    if (GetBufferRegionFromBuffer(block->reads, buffer).defined()) {
+      if (blocks_under_target.find(block) == blocks_under_target.end()) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
 /*!
  * \brief Get the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive)
  * \param self The state of the schedule.
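
The new check reduces to set containment: collect the blocks under the candidate insertion point, then verify that every consumer of the buffer is among them. A sketch with illustrative block names, not the C++ API:

```python
# Blocks under the insertion point (e.g. everything inside block "scope")
blocks_under_target = {"A", "A_local", "B"}
# Blocks anywhere in the scope that read the buffer
consumers = {"B"}

# Compact caching is only considered when this holds:
assert all(c in blocks_under_target for c in consumers)
```
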
@@ -628,7 +695,7 @@ class CacheLocDetector : public StmtVisitor {
   void VisitStmt_(const BlockNode* block) final {
     // Only visit the current scope under buffer writer's parent block
     if (block == scope_sref_->stmt) {
-      // The block vistied is the current parent scope
+      // The block visited is the current parent scope
       StmtVisitor::VisitStmt_(block);
       // Handling cases when insert outside any loop or cache_read for input buffer
       if (visited_related_ && !loc_sref_.defined()) {
@@ -728,7 +795,7 @@ class CacheInplaceLocDetector : public StmtVisitor {
   void VisitStmt_(const BlockNode* block) final {
     // Only visit the current scope under buffer writer's parent block
     if (block == scope_sref_->stmt) {
-      // The block vistied is the current parent scope
+      // The block visited is the current parent scope
       StmtVisitor::VisitStmt_(block);
       // Handling cases when insert outside any loop
       if (visited_block_ && !loc_sref_.defined()) {
@@ -777,21 +844,63 @@ class CacheReadRewriter : public StmtExprMutator {
   /*!
    * \brief Rewrite the AST and add a cache_read stage with the information provided
    * \param scope_sref The parent scope of this mutation
    * \param info The cache stage information
+   * \param cache_full_region A boolean indicating whether the cache buffer is allocated
+   * with the full region or a compact region.
    * \return The new AST rooting at the original parent scope
    */
-  static Stmt Rewrite(const StmtSRef& scope_sref, CacheStageInfo* info) {
-    CacheReadRewriter rewriter(scope_sref, info);
+  static Stmt Rewrite(const StmtSRef& scope_sref, CacheStageInfo* info,
+                      bool cache_full_region = true) {
+    CacheReadRewriter rewriter(scope_sref, info, cache_full_region);
     return rewriter(GetRef<Stmt>(scope_sref->stmt));
   }
 
  private:
-  explicit CacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info)
-      : scope_sref_(scope_sref), info_(info) {
-    update_access_regions = [&](Array<BufferRegion> regions) {
-      return ReplaceBuffer(std::move(regions), info_->read_buffer, info_->write_buffer);
+  explicit CacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info,
+                             bool cache_full_region = true)
+      : scope_sref_(scope_sref), info_(info), cache_full_region_(cache_full_region) {
+    auto update_region = [this](const Region& region, const Region& offset) -> Region {
+      ICHECK_EQ(region.size(), offset.size());
+      std::vector<Range> ret;
+      for (size_t i = 0; i < region.size(); ++i) {
+        ret.push_back(Range::FromMinExtent(ana_.Simplify(region[i]->min - offset[i]->min),
+                                           region[i]->extent));
+      }
+      return ret;
+    };
+
+    update_access_regions = [this, update_region](Array<BufferRegion> regions) {
+      if (cache_full_region_) {
+        return ReplaceBuffer(std::move(regions), info_->read_buffer, info_->write_buffer);
+      }
+
+      Array<BufferRegion> ret;
+      for (const BufferRegion& region : regions) {
+        if (region->buffer.same_as(info_->read_buffer)) {
+          ret.push_back(BufferRegion(info_->write_buffer,
+                                     update_region(region->region, info_->cache_region->region)));
+        } else {
+          ret.push_back(region);
+        }
+      }
+      return ret;
     };
-    update_match_buffers = [&](Array<MatchBufferRegion> match_buffers) {
-      return ReplaceBuffer(std::move(match_buffers), info_->read_buffer, info_->write_buffer);
+    update_match_buffers = [this, update_region](Array<MatchBufferRegion> match_buffers) {
+      if (cache_full_region_) {
+        return ReplaceBuffer(std::move(match_buffers), info_->read_buffer, info_->write_buffer);
+      }
+
+      Array<MatchBufferRegion> ret;
+      for (const MatchBufferRegion& match_buffer : match_buffers) {
+        if (match_buffer->source->buffer.same_as(info_->read_buffer)) {
+          ret.push_back(MatchBufferRegion(
+              match_buffer->buffer,
+              BufferRegion(info_->write_buffer, update_region(match_buffer->source->region,
+                                                              info_->cache_region->region))));
+        } else {
+          ret.push_back(match_buffer);
+        }
+      }
+      return ret;
     };
   }
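
`update_region` rebases each range against the cached region: the extent is kept and the min is shifted by the cache region's min. A worked sketch with `(min, extent)` pairs standing in for `Range` (the helper name mirrors the lambda above but is otherwise illustrative):

```python
def update_region(region, offset):
    """Keep each range's extent, rebase its min against the cache region."""
    return [(lo - off_lo, extent)
            for (lo, extent), (off_lo, _off_extent) in zip(region, offset)]

# A consumer reads A[32:48] and the cache covers exactly A[32:48],
# so inside the cache the access region becomes [0:16):
assert update_region([(32, 16)], [(32, 16)]) == [(0, 16)]
```
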
@@ -863,10 +972,21 @@ class CacheReadRewriter : public StmtExprMutator {
     return std::move(stmt);
   }
 
+  Array<PrimExpr> RewriteIndices(const Array<PrimExpr>& indices) {
+    std::vector<PrimExpr> ret;
+    for (size_t i = 0; i < indices.size(); ++i) {
+      ret.push_back(ana_.Simplify(indices[i] - info_->cache_region->region[i]->min));
+    }
+    return ret;
+  }
+
   PrimExpr VisitExpr_(const BufferLoadNode* load) override {
     if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) {
       ObjectPtr<BufferLoadNode> n = make_object<BufferLoadNode>(*load);
       n->buffer = info_->write_buffer;
+      if (!cache_full_region_) {
+        n->indices = std::move(RewriteIndices(load->indices));
+      }
       return PrimExpr(n);
     }
     return ExprMutator::VisitExpr_(load);
@@ -890,6 +1010,13 @@ class CacheReadRewriter : public StmtExprMutator {
   std::function<Array<BufferRegion>(Array<BufferRegion>)> update_access_regions;
   /*! \brief function to update match buffers of block being cache read.*/
   std::function<Array<MatchBufferRegion>(Array<MatchBufferRegion>)> update_match_buffers;
+  /*!
+   * \brief A boolean indicating whether the cache buffer is allocated with
+   * the full region or a compact region.
+   */
+  bool cache_full_region_;
+  /*! \brief Arithmetic analyzer. */
+  arith::Analyzer ana_;
 
   friend ReindexCacheReadRewriter;
 };
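
Why `RewriteIndices` runs the analyzer: after subtracting the region min, an index like `v_i0_o * 32 + v_i0_i` must fold back to the plain `v_i0_i` seen in the rewritten tests below. A sketch with sympy standing in for `arith::Analyzer`:

```python
from sympy import simplify, symbols

v_i0_o, v_i0_i = symbols("v_i0_o v_i0_i")

# (v_i0_o * 32 + v_i0_i) - v_i0_o * 32  ->  v_i0_i
assert simplify((v_i0_o * 32 + v_i0_i) - v_i0_o * 32) == v_i0_i
```
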
@@ -970,23 +1097,66 @@ class CacheWriteRewriter : public StmtExprMutator {
    * \param scope_sref The parent scope of this mutation.
    * \param writer_block_sref The only writer block in the scope.
    * \param info The cache stage information.
+   * \param cache_full_region A boolean indicating whether the cache buffer is allocated
+   * with the full region or a compact region.
    * \return The new AST rooting at the original parent scope.
    */
   static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref,
-                      CacheStageInfo* info) {
-    CacheWriteRewriter rewriter(scope_sref, writer_block_sref, info);
+                      CacheStageInfo* info, bool cache_full_region = true) {
+    CacheWriteRewriter rewriter(scope_sref, writer_block_sref, info, cache_full_region);
     return rewriter(GetRef<Stmt>(scope_sref->stmt));
   }
 
  private:
   explicit CacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref,
-                              CacheStageInfo* info)
-      : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) {
-    update_access_regions = [&](Array<BufferRegion> regions) {
-      return ReplaceBuffer(regions, info_->write_buffer, info_->read_buffer);
+                              CacheStageInfo* info, bool cache_full_region = true)
+      : scope_sref_(scope_sref),
+        writer_block_sref_(writer_block_sref),
+        info_(info),
+        cache_full_region_(cache_full_region) {
+    auto update_region = [this](const Region& region, const Region& offset) -> Region {
+      ICHECK_EQ(region.size(), offset.size());
+      std::vector<Range> ret;
+      for (size_t i = 0; i < region.size(); ++i) {
+        ret.push_back(Range::FromMinExtent(ana_.Simplify(region[i]->min - offset[i]->min),
+                                           region[i]->extent));
+      }
+      return ret;
     };
-    update_match_buffers = [&](Array<MatchBufferRegion> match_buffers) {
-      return ReplaceBuffer(match_buffers, info_->write_buffer, info_->read_buffer);
+
+    update_access_regions = [this, update_region](Array<BufferRegion> regions) {
+      if (cache_full_region_) {
+        return ReplaceBuffer(regions, info_->write_buffer, info_->read_buffer);
+      }
+
+      Array<BufferRegion> ret;
+      for (const BufferRegion& region : regions) {
+        if (region->buffer.same_as(info_->write_buffer)) {
+          ret.push_back(BufferRegion(info_->read_buffer,
+                                     update_region(region->region, info_->cache_region->region)));
+        } else {
+          ret.push_back(region);
+        }
+      }
+      return ret;
+    };
+    update_match_buffers = [this, update_region](Array<MatchBufferRegion> match_buffers) {
+      if (cache_full_region_) {
+        return ReplaceBuffer(match_buffers, info_->write_buffer, info_->read_buffer);
+      }
+
+      Array<MatchBufferRegion> ret;
+      for (const MatchBufferRegion& match_buffer : match_buffers) {
+        if (match_buffer->source->buffer.same_as(info_->write_buffer)) {
+          ret.push_back(MatchBufferRegion(
+              match_buffer->buffer,
+              BufferRegion(info_->read_buffer, update_region(match_buffer->source->region,
+                                                             info_->cache_region->region))));
+        } else {
+          ret.push_back(match_buffer);
+        }
+      }
+      return ret;
     };
   }
 
@@ -1072,11 +1242,22 @@ class CacheWriteRewriter : public StmtExprMutator {
     return std::move(stmt);
   }
 
+  Array<PrimExpr> RewriteIndices(const Array<PrimExpr>& indices) {
+    std::vector<PrimExpr> ret;
+    for (size_t i = 0; i < indices.size(); ++i) {
+      ret.push_back(ana_.Simplify(indices[i] - info_->cache_region->region[i]->min));
+    }
+    return ret;
+  }
+
   Stmt VisitStmt_(const BufferStoreNode* store) override {
     BufferStore stmt = Downcast<BufferStore>(StmtMutator::VisitStmt_(store));
     if (stmt->buffer.same_as(info_->write_buffer)) {
       auto n = CopyOnWrite(stmt.get());
       n->buffer = info_->read_buffer;
+      if (!cache_full_region_) {
+        n->indices = std::move(RewriteIndices(n->indices));
+      }
       return Stmt(n);
     } else {
       return std::move(stmt);
@@ -1087,6 +1268,9 @@
     if (load->buffer.same_as(info_->write_buffer)) {
       ObjectPtr<BufferLoadNode> n = make_object<BufferLoadNode>(*load);
       n->buffer = info_->read_buffer;
+      if (!cache_full_region_) {
+        n->indices = std::move(RewriteIndices(n->indices));
+      }
       return PrimExpr(n);
     }
     return ExprMutator::VisitExpr_(load);
@@ -1112,6 +1296,13 @@
   std::function<Array<BufferRegion>(Array<BufferRegion>)> update_access_regions;
   /*! \brief function to update match buffers of block being cache write.*/
   std::function<Array<MatchBufferRegion>(Array<MatchBufferRegion>)> update_match_buffers;
+  /*!
+   * \brief A boolean indicating whether the cache buffer is allocated with
+   * the full region or a compact region.
+   */
+  bool cache_full_region_;
+  /*! \brief Arithmetic analyzer. */
+  arith::Analyzer ana_;
 
   friend ReindexCacheWriteRewriter;
 };
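
The hunks below first drop the eager creation of the cache buffer in Step 2, then decide between the two layouts and, for the compact case, shrink the allocation to the region extents. A sketch of both steps under names used only for illustration, not the C++ signatures:

```python
def use_full_region(loc_is_block, all_consumers_under_loc):
    """Full region unless the stage sits directly under a block and
    every consumer of the buffer is under that insertion point."""
    return (not loc_is_block) or (not all_consumers_under_loc)

cache_region = [(32, 16), (48, 16)]  # (min, extent) per dimension
if not use_full_region(loc_is_block=True, all_consumers_under_loc=True):
    compact_shape = tuple(extent for _, extent in cache_region)
    assert compact_shape == (16, 16)  # instead of the full (128, 128)
```
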
@@ -1506,10 +1697,6 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
   // Step 2. Create CacheStageInfo
   CacheStageInfo info;
   info.read_buffer = read_buffer;
-  // Create the corresponding buffer to be written, i.e. result of cache_read
-  info.write_buffer = WithScope(read_buffer, storage_scope);
-  // Create the corresponding buffer allocation
-  info.alloc = info.write_buffer;
 
   // info.consumer_blocks indicates which buffers should consume the cache.
   for (auto consumer : consumer_blocks) {
@@ -1545,9 +1732,25 @@
   }
 
   // Step 4. Making new cache stage block and rewrite readers.
-  Block cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
-                                          /*storage_scope=*/storage_scope);
-  Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);
+  bool cache_full_region = info.loc_sref->StmtAs<BlockNode>() == nullptr ||
+                           !AllConsumersUnderStmt(self, read_buffer, scope_sref, info.loc_sref);
+  info.cache_region = cache_region;
+  info.write_buffer = WithScope(read_buffer, storage_scope);
+  if (!cache_full_region) {
+    auto* write_buffer = info.write_buffer.CopyOnWrite();
+    std::vector<PrimExpr> shape;
+    for (auto cache_range : info.cache_region->region) {
+      shape.push_back(cache_range->extent);
+    }
+    write_buffer->shape = std::move(shape);
+  }
+  info.alloc = info.write_buffer;
+
+  Block cache_read_stage =
+      MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
+                     /*storage_scope=*/storage_scope, /*cache_full_region=*/cache_full_region);
+  Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info,
+                                              /*cache_full_region=*/cache_full_region);
 
   // Step 5. Replacing and updating flags.
   self->Replace(scope_sref, new_scope, info.block_reuse);
@@ -1583,11 +1786,8 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
 
   // Step 2. Creating CacheStageInfo
   CacheStageInfo info;
-  info.read_buffer = WithScope(write_buffer, storage_scope);
   // Create the corresponding buffer to be written, i.e. result of cache_write
   info.write_buffer = write_buffer;
-  // Create the corresponding buffer allocation
-  info.alloc = info.read_buffer;
 
   // info.consumer_blocks indicates which buffers should consume the cache.
   for (auto consumer : consumer_blocks) {
@@ -1608,11 +1808,27 @@
   BufferRegion cache_region =
       RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref);
 
+  bool cache_full_region = info.loc_sref->StmtAs<BlockNode>() == nullptr ||
+                           !AllConsumersUnderStmt(self, write_buffer, scope_sref, info.loc_sref);
+  info.cache_region = cache_region;
+  info.read_buffer = WithScope(write_buffer, storage_scope);
+  if (!cache_full_region) {
+    auto* read_buffer = info.read_buffer.CopyOnWrite();
+    std::vector<PrimExpr> shape;
+    for (auto cache_range : info.cache_region->region) {
+      shape.push_back(cache_range->extent);
+    }
+    read_buffer->shape = std::move(shape);
+  }
+  info.alloc = info.read_buffer;
+
   // Step 5. Making new cache stage block and rewrite readers.
-  Block cache_write_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
-                                           /*storage_scope=*/storage_scope);
+  Block cache_write_stage =
+      MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
+                     /*storage_scope=*/storage_scope, /*cache_full_region=*/cache_full_region);
   Stmt new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref,
-                                               /*writer_block_sref=*/block_sref, /*info=*/&info);
+                                               /*writer_block_sref=*/block_sref, /*info=*/&info,
+                                               /*cache_full_region=*/cache_full_region);
 
   // Step 6. Replacing and updating flags.
self->Replace(scope_sref, new_scope, info.block_reuse); diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index 95955646c64b..2d460b359181 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -523,7 +523,7 @@ def cache_read_under_scope(b: T.handle, c: T.handle) -> None: for i0, j0 in T.grid(8, 8): with T.block("scope"): i, j = T.axis.remap("SS", [i0, j0]) - A_local = T.alloc_buffer((128, 128), scope="local") + A_local = T.alloc_buffer((16, 16), scope="local") for x, y in T.grid(16, 16): with T.block("A"): vi = T.axis.S(128, i * 16 + x) @@ -531,14 +531,14 @@ def cache_read_under_scope(b: T.handle, c: T.handle) -> None: A[vi, vj] = 1.0 for x, y in T.grid(16, 16): with T.block("A_local"): - vi = T.axis.S(128, i * 16 + x) - vj = T.axis.S(128, j * 16 + y) - A_local[vi, vj] = A[vi, vj] + vi = T.axis.S(16, x) + vj = T.axis.S(16, y) + A_local[vi, vj] = A[i * 16 + vi, j * 16 + vj] for x, y in T.grid(16, 16): with T.block("B"): vi = T.axis.S(128, i * 16 + x) vj = T.axis.S(128, j * 16 + y) - B[vi, vj] = A_local[vi, vj] + 1.0 + B[vi, vj] = A_local[vi - i * 16, vj - j * 16] + 1.0 for i, j in T.grid(128, 128): with T.block("A_global"): vi, vj = T.axis.remap("SS", [i, j]) @@ -866,28 +866,28 @@ def cache_write_under_scope(b: T.handle, c: T.handle) -> None: for i0, j0 in T.grid(8, 8): with T.block("scope"): i, j = T.axis.remap("SS", [i0, j0]) - A_local = T.alloc_buffer((128, 128), scope="local") - B_global = T.alloc_buffer((128, 128)) + A_local = T.alloc_buffer((16, 16), scope="local") + B_global = T.alloc_buffer((16, 16)) for x, y in T.grid(16, 16): with T.block("A_local"): vi = T.axis.S(128, i * 16 + x) vj = T.axis.S(128, j * 16 + y) - A_local[vi, vj] = 1.0 + A_local[vi - i * 16, vj - j * 16] = 1.0 for x, y in T.grid(16, 16): with T.block("A"): - vi = T.axis.S(128, i * 16 + x) - vj = T.axis.S(128, j * 16 + y) - A_global[vi, vj] = A_local[vi, vj] + vi = T.axis.S(16, x) + vj = T.axis.S(16, y) + A_global[i * 16 + vi, j * 16 + vj] = A_local[vi, vj] for x, y in T.grid(16, 16): - with T.block("B_global"): + with T.block("B"): vi = T.axis.S(128, i * 16 + x) vj = T.axis.S(128, j * 16 + y) - B_global[vi, vj] = A_global[vi, vj] + 1.0 + B_global[vi - i * 16, vj - j * 16] = A_global[vi, vj] + 1.0 for x, y in T.grid(16, 16): with T.block("B_global"): - vi = T.axis.S(128, i * 16 + x) - vj = T.axis.S(128, j * 16 + y) - B[vi, vj] = B_global[vi, vj] + vi = T.axis.S(16, x) + vj = T.axis.S(16, y) + B[i * 16 + vi, j * 16 + vj] = B_global[vi, vj] for i, j in T.grid(128, 128): with T.block("A_global"): vi, vj = T.axis.remap("SS", [i, j]) @@ -1167,6 +1167,104 @@ def block_predicate_cache_write_output_buf() -> None: B[v0] = B_shared[v0] +@T.prim_func +def symbolic_matmul_blocked(var_A: T.handle, var_B: T.handle, var_C: T.handle, n: T.int32): + A = T.match_buffer(var_A, ((n + 31) // 32 * 32, 4)) + B = T.match_buffer(var_B, (4, (n + 31) // 32 * 32)) + C = T.match_buffer(var_C, ((n + 31) // 32 * 32, (n + 31) // 32 * 32)) + for i0_0, i1_0 in T.grid((n + 31) // 32, (n + 31) // 32): + with T.block("matmul_o"): + v_i0_o, v_i1_o = T.axis.remap("SS", [i0_0, i1_0]) + T.reads( + A[v_i0_o * 32 : v_i0_o * 32 + 32, 0:4], + B[0:4, v_i1_o * 32 : v_i1_o * 32 + 32], + ) + T.writes(C[v_i0_o * 32 : v_i0_o * 32 + 32, v_i1_o * 32 : v_i1_o * 32 + 32]) + for i0_1, i1_1, k in T.grid(32, 32, 4): + with T.block("matmul"): + v_i0_i, v_i1_i, v_k_i = T.axis.remap("SSR", [i0_1, i1_1, k]) 
+ T.reads(A[v_i0_o * 32 + v_i0_i, v_k_i], B[v_k_i, v_i1_o * 32 + v_i1_i]) + T.writes(C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i]) + with T.init(): + C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = T.float32(0) + C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = ( + C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] + + A[v_i0_o * 32 + v_i0_i, v_k_i] * B[v_k_i, v_i1_o * 32 + v_i1_i] + ) + + +@T.prim_func +def symbolic_matmul_blocked_cache_read( + var_A: T.handle, var_B: T.handle, var_C: T.handle, n: T.int32 +): + A = T.match_buffer(var_A, ((n + 31) // 32 * 32, 4)) + B = T.match_buffer(var_B, (4, (n + 31) // 32 * 32)) + C = T.match_buffer(var_C, ((n + 31) // 32 * 32, (n + 31) // 32 * 32)) + for i0_0, i1_0 in T.grid((n + 31) // 32, (n + 31) // 32): + with T.block("matmul_o"): + v_i0_o, v_i1_o = T.axis.remap("SS", [i0_0, i1_0]) + T.reads( + A[v_i0_o * 32 : v_i0_o * 32 + 32, 0:4], + B[0:4, v_i1_o * 32 : v_i1_o * 32 + 32], + ) + T.writes(C[v_i0_o * 32 : v_i0_o * 32 + 32, v_i1_o * 32 : v_i1_o * 32 + 32]) + A_shared = T.alloc_buffer((32, 4), scope="shared") + for ax0, ax1 in T.grid(32, 4): + with T.block("A_shared"): + v0 = T.axis.spatial(32, ax0) + v1 = T.axis.spatial(4, ax1) + T.reads(A[v_i0_o * 32 + v0, v1]) + T.writes(A_shared[v0, v1]) + A_shared[v0, v1] = A[v_i0_o * 32 + v0, v1] + for i0_1, i1_1, k in T.grid(32, 32, 4): + with T.block("matmul"): + v_i0_i, v_i1_i, v_k_i = T.axis.remap("SSR", [i0_1, i1_1, k]) + T.reads(A_shared[v_i0_i, v_k_i], B[v_k_i, v_i1_o * 32 + v_i1_i]) + T.writes(C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i]) + with T.init(): + C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = T.float32(0) + C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = ( + C[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] + + A_shared[v_i0_i, v_k_i] * B[v_k_i, v_i1_o * 32 + v_i1_i] + ) + + +@T.prim_func +def symbolic_matmul_blocked_cache_write( + var_A: T.handle, var_B: T.handle, var_C: T.handle, n: T.int32 +): + A = T.match_buffer(var_A, ((n + 31) // 32 * 32, 4)) + B = T.match_buffer(var_B, (4, (n + 31) // 32 * 32)) + C = T.match_buffer(var_C, ((n + 31) // 32 * 32, (n + 31) // 32 * 32)) + for i0_0, i1_0 in T.grid((n + 31) // 32, (n + 31) // 32): + with T.block("matmul_o"): + v_i0_o, v_i1_o = T.axis.remap("SS", [i0_0, i1_0]) + T.reads( + A[v_i0_o * 32 : v_i0_o * 32 + 32, 0:4], + B[0:4, v_i1_o * 32 : v_i1_o * 32 + 32], + ) + T.writes(C[v_i0_o * 32 : v_i0_o * 32 + 32, v_i1_o * 32 : v_i1_o * 32 + 32]) + C_pad_local = T.alloc_buffer((32, 32), scope="local") + for i0_1, i1_1, k in T.grid(32, 32, 4): + with T.block("matmul"): + v_i0_i, v_i1_i, v_k_i = T.axis.remap("SSR", [i0_1, i1_1, k]) + T.reads(A[v_i0_o * 32 + v_i0_i, v_k_i], B[v_k_i, v_i1_o * 32 + v_i1_i]) + T.writes(C_pad_local[v_i0_i, v_i1_i]) + with T.init(): + C_pad_local[v_i0_i, v_i1_i] = T.float32(0) + C_pad_local[v_i0_i, v_i1_i] = ( + C_pad_local[v_i0_i, v_i1_i] + + A[v_i0_o * 32 + v_i0_i, v_k_i] * B[v_k_i, v_i1_o * 32 + v_i1_i] + ) + for ax0, ax1 in T.grid(32, 32): + with T.block("C_pad_local"): + v0 = T.axis.spatial(32, ax0) + v1 = T.axis.spatial(32, ax1) + T.reads(C_pad_local[v0, v1]) + T.writes(C[v_i0_o * 32 + v0, v_i1_o * 32 + v1]) + C[v_i0_o * 32 + v0, v_i1_o * 32 + v1] = C_pad_local[v0, v1] + + ########## Testcases for cache_read ########## use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) @@ -1215,7 +1313,7 @@ def test_cache_read_location(use_block_name): tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) - # Test that specific 
consumer block targetting works.
+    # Test that specific consumer block targeting works.
     sch = tir.Schedule(func_multi_consumer, debug_mask="all")
     block_b = "B" if use_block_name else sch.get_block("B")
     block_c = "C" if use_block_name else sch.get_block("C")
@@ -1355,7 +1453,7 @@ def test_cache_write_location(use_block_name):
         tvm.ir.assert_structural_equal(cache_write_multi_consumer, sch.mod["main"])
     verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
 
-    # Test that specific consumer block targetting works.
+    # Test that specific consumer block targeting works.
     # B read cache buffer and C read original output buffer
     sch = tir.Schedule(func_multi_consumer, debug_mask="all")
     block_a = "A" if use_block_name else sch.get_block("A")
@@ -1364,7 +1462,7 @@ def test_cache_write_location(use_block_name):
     tvm.ir.assert_structural_equal(cache_write_multi_consumer_B_consume_cache, sch.mod["main"])
     verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
 
-    # Test that specific consumer block targetting works.
+    # Test that specific consumer block targeting works.
     # B read original output buffer and C read cache buffer
     sch = tir.Schedule(func_multi_consumer, debug_mask="all")
     block_a = "A" if use_block_name else sch.get_block("A")
@@ -1373,7 +1471,7 @@ def test_cache_write_location(use_block_name):
     tvm.ir.assert_structural_equal(cache_write_multi_consumer_C_consume_cache, sch.mod["main"])
     verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
 
-    # Test that specific consumer block targetting works.
+    # Test that specific consumer block targeting works.
     # B and C read cache buffer
     sch = tir.Schedule(func_multi_consumer, debug_mask="all")
     block_a = "A" if use_block_name else sch.get_block("A")
@@ -1532,7 +1630,7 @@ def test_reindex_cache_read_fail_not_match():
     )
 
 
-def test_reindex_cache_read_faile_not_single_point():
+def test_reindex_cache_read_failed_not_single_point():
     sch = tir.Schedule(access_under_scope, debug_mask="all")
     with pytest.raises(tvm.tir.ScheduleError):
         sch.reindex_cache_read("scope", 0, "shared", lambda i, j: (i, j))
@@ -1571,5 +1669,21 @@ def test_reindex_cache_write_fail_not_single_point():
         sch.reindex_cache_write("scope", 0, "shared", lambda i, j: (i, j))
 
 
+def test_symbolic_matmul_blocked_cache_read(use_block_name):
+    sch = tir.Schedule(symbolic_matmul_blocked, debug_mask="all")
+    block = "matmul" if use_block_name else sch.get_block("matmul")
+    sch.cache_read(block=block, read_buffer_index=0, storage_scope="shared")
+    tvm.ir.assert_structural_equal(sch.mod["main"], symbolic_matmul_blocked_cache_read)
+    verify_trace_roundtrip(sch=sch, mod=symbolic_matmul_blocked)
+
+
+def test_symbolic_matmul_blocked_cache_write(use_block_name):
+    sch = tir.Schedule(symbolic_matmul_blocked, debug_mask="all")
+    block = "matmul" if use_block_name else sch.get_block("matmul")
+    sch.cache_write(block=block, write_buffer_index=0, storage_scope="local")
+    tvm.ir.assert_structural_equal(sch.mod["main"], symbolic_matmul_blocked_cache_write)
+    verify_trace_roundtrip(sch=sch, mod=symbolic_matmul_blocked)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
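
For reference, a minimal driver mirroring the new tests (assumes the `symbolic_matmul_blocked` prim_func defined above is in scope): since every consumer of A sits under the "matmul_o" block, `cache_read` now allocates the compact (32, 4) shared buffer rather than a full-size one.

```python
import tvm
from tvm import tir

sch = tir.Schedule(symbolic_matmul_blocked)
block = sch.get_block("matmul")
sch.cache_read(block=block, read_buffer_index=0, storage_scope="shared")
# Expect: A_shared = T.alloc_buffer((32, 4), scope="shared")
print(sch.mod["main"].script())
```
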