diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index e3a853e4c7..b2736a30e4 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -25,6 +25,7 @@
 #define TVM_TIR_BUFFER_H_
 
 #include <tvm/ir/expr.h>
+#include <tvm/node/script_printer.h>
 #include <tvm/runtime/container/array.h>
 #include <tvm/runtime/container/string.h>
 #include <tvm/tir/var.h>
@@ -162,6 +163,7 @@ class BufferNode : public Object {
   static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const bool _type_has_method_shash_reduce = true;
   TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
+  TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
 };
 
 /*!
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 0198feb3cd..fe3b662d61 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -493,6 +493,13 @@ TVM_DLL PrimExpr bitwise_neg(PrimExpr a, Span span = Span());
  */
 TVM_DLL PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value,
                               Span span = Span());
+/*!
+ * \brief Protected write. This is only used on a BufferStore's immediate RHS to indicate that
+ * no out-of-bounds access will be performed.
+ * \param expr The expression to be protected.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr protected_write(PrimExpr expr);
 /*!
  * \brief Mark condition as likely.
  * \param cond The condition
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 4439a9c3d7..1350ca46f2 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -462,6 +462,13 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value,
                    {cond, true_value, false_value}, span);
 }
 
+TVM_TIR_REGISTER_OP("protected_write");
+
+PrimExpr protected_write(PrimExpr expr) {
+  static const Op& op = Op::Get("tir.protected_write");
+  return tir::Call(expr.dtype(), op, {expr}, Span());
+}
+
 // likely
 PrimExpr likely(PrimExpr cond, Span span) {
   if (is_const_int(cond)) return cond;
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 7b9c2a9a24..6e361cbe05 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -707,24 +707,6 @@ Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region, const P
                                              const StmtSRef& dom_high_exclusive,
                                              arith::Analyzer* analyzer);
 
-/*!
- * \brief Check if buffer indices are all Vars and extr
- * \param buffer_access The BufferLoad or BufferStore
- * \return The indices if the indices are all Vars, otherwise NullOpt
- */
-template <typename T>
-Optional<Array<Var>> CheckTrivialBufferIndices(const T& buffer_access) {
-  Array<Var> indices;
-  for (const PrimExpr& index : buffer_access->indices) {
-    const VarNode* var = index.as<VarNode>();
-    if (var == nullptr) {
-      return NullOpt;
-    }
-    indices.push_back(GetRef<Var>(var));
-  }
-  return indices;
-}
-
 /*!
  * \brief Simplify non-trivial expressions
  * \param expr The expression to be simplified
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 78d1cab05c..efd0449ae3 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -622,8 +622,7 @@ TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref
  * \param block_sref The block sref that matches the Einsum pattern.
  * \param padding The padding for each block iter.
 */
-TVM_DLL void PadEinsum(ScheduleState self, const StmtSRef& block_sref,
-                       const Array<IntImm>& padding);
+TVM_DLL void PadEinsum(ScheduleState self, StmtSRef block_sref, Array<IntImm> padding);
 
 /******** Schedule: Buffer transformation ********/
 /*!
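Together with the `PadEinsum` signature change above, the meaning of `padding` changes: the reworked implementation in pad_einsum.cc below computes `ceildiv(dom, padding[i]) * padding[i]`, so each entry is a multiple that the corresponding block-iter extent is rounded up to, not an additive amount (which is also why `InvalidPaddingError` starts rejecting values <= 0 instead of < 0). A minimal sketch of the rounding rule in plain Python, for illustration only (the helper name `padded_extent` is not part of the patch):

    def padded_extent(extent: int, multiple: int) -> int:
        # Mirrors ceildiv(dom, padding[i]) * padding[i] in the new PadEinsum.
        # `multiple` must be positive, matching the tightened check
        # (pad->value <= 0 now raises InvalidPaddingError).
        assert multiple > 0
        return (extent + multiple - 1) // multiple * multiple

    assert padded_extent(127, 32) == 128  # a dynamic dim n becomes (n + 31) // 32 * 32
    assert padded_extent(128, 32) == 128  # already aligned, so no copy blocks are needed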
diff --git a/src/tir/schedule/primitive/pad_einsum.cc b/src/tir/schedule/primitive/pad_einsum.cc
index 2190dc69d3..5f1b97a837 100644
--- a/src/tir/schedule/primitive/pad_einsum.cc
+++ b/src/tir/schedule/primitive/pad_einsum.cc
@@ -17,13 +17,46 @@
  * under the License.
  */
 
-#include <optional>
+#include <tvm/arith/analyzer.h>
 
 #include "../utils.h"
 
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief Check if buffer indices are all Vars and extract them
+ * \param buffer_access The indices of the BufferLoad or BufferStore
+ * \return The indices if the indices are all Vars, otherwise NullOpt
+ */
+Optional<Array<Var>> CheckTrivialBufferIndices(const Array<PrimExpr>& buffer_access) {
+  Array<Var> indices;
+  for (const PrimExpr& index : buffer_access) {
+    const VarNode* var = index.as<VarNode>();
+    if (var == nullptr) {
+      return NullOpt;
+    }
+    indices.push_back(GetRef<Var>(var));
+  }
+  return indices;
+}
+
+Optional<Array<Var>> CheckTrivialBufferAccess(const BufferRegion& buffer_region) {
+  Array<Var> indices;
+  indices.reserve(buffer_region->region.size());
+  for (const Range& range : buffer_region->region) {
+    if (!tir::is_one(range->extent)) {
+      return NullOpt;
+    }
+    if (const auto* var = range->min.as<VarNode>()) {
+      indices.push_back(GetRef<Var>(var));
+    } else {
+      return NullOpt;
+    }
+  }
+  return indices;
+}
+
 /*! \brief The schedule error class when the padding size is invalid. */
 class InvalidPaddingError : public ScheduleError {
  public:
@@ -46,7 +79,7 @@ class InvalidPaddingError : public ScheduleError {
       throw InvalidPaddingError(self->mod, block, padding);
     }
     for (const auto& pad : padding) {
-      if (pad->value < 0) {
+      if (pad->value <= 0) {
        throw InvalidPaddingError(self->mod, block, padding);
       }
     }
@@ -81,116 +114,123 @@ class NonEinsumError : public ScheduleError {
 
 /*! \brief Data structure that represents a Einsum computation. */
 struct Einsum {
   // The output buffer
-  Buffer output_buffer;
+  Array<Buffer> output_buffers;
   // The indices of the output buffer
-  Array<Var> output_indices;
+  Map<Buffer, Array<Var>> output_indices;
+  // The input buffers
+  Array<Buffer> input_buffers;
   // The indices of the input buffers
   Map<Buffer, Array<Var>> input_indices;
 };
 
-class EinsumExtractor : public ExprVisitor {
- public:
-  EinsumExtractor() = default;
-
-  std::optional<Einsum> Extract(const Block& block) {
-    const BufferStoreNode* update = block->body.as<BufferStoreNode>();
-    // Step 1: Check the body is a BufferStore and the block has the init statement, and the
-    // BufferStore and the init statement store have the same output buffer indices.
-    if (update == nullptr || !block->init.defined()) {
-      return std::nullopt;
+struct BufferPadding {
+  Buffer buffer;
+  Buffer padded_buffer;
+
+  static BufferPadding FromBufferRegion(const BufferRegion& buffer_region,
+                                        const Map<Var, PrimExpr>& iter_extents) {
+    BufferPadding result;
+    result.buffer = buffer_region->buffer;
+    Array<PrimExpr> shape;
+    shape.reserve(buffer_region->region.size());
+    int ndim = buffer_region->region.size();
+    for (int i = 0; i < ndim; ++i) {
+      Var var = Downcast<Var>(buffer_region->region[i]->min);
+      if (Optional<PrimExpr> extent = iter_extents.Get(var)) {
+        shape.push_back(extent.value());
+      } else {
+        shape.push_back(buffer_region->buffer->shape[i]);
+      }
     }
+    result.padded_buffer = decl_buffer(shape, result.buffer->dtype, result.buffer->name + "_pad",
+                                       result.buffer.scope());
+    return result;
+  }
 
-    if (Optional<Array<Var>> opt_indices = CheckTrivialBufferIndices(update);
-        opt_indices.defined()) {
-      ein_sum_.output_indices = std::move(opt_indices.value());
+  Stmt MakeCopyBlock(bool is_read, Array<Block>* blocks, arith::Analyzer* analyzer) {
+    Array<Var> loop_vars;
+    Array<Range> loop_doms;
+    Array<IterVar> iter_vars;
+    Array<Range> instance_dom;
+    Array<PrimExpr> indices;
+    int ndim = buffer->shape.size();
+    for (int i = 0; i < ndim; ++i) {
+      PrimExpr dim{nullptr};
+      if (is_read) {
+        dim = padded_buffer->shape[i];
+      } else {
+        dim = buffer->shape[i];
+      }
+      Range dom = Range::FromMinExtent(IntImm(dim->dtype, 0), dim);
+      loop_vars.push_back(Var("i" + std::to_string(i), dim->dtype));
+      loop_doms.push_back(dom);
+      IterVar iter_var(dom, Var("v" + std::to_string(i), dim->dtype), kDataPar);
+      instance_dom.push_back(Range::FromMinExtent(iter_var->var, IntImm(dim->dtype, 1)));
+      iter_vars.push_back(iter_var);
+      indices.push_back(iter_var->var);
+    }
+    Stmt body{nullptr};
+    if (is_read) {
+      PrimExpr predicate = Bool(true);
+      for (int i = 0; i < ndim; ++i) {
+        if (!analyzer->CanProveEqual(buffer->shape[i], padded_buffer->shape[i])) {
+          predicate = predicate && (indices[i] < buffer->shape[i]);
+        }
+      }
+      PrimExpr rhs = BufferLoad(buffer, indices);
+      body =
+          BufferStore(padded_buffer, if_then_else(predicate, rhs, make_zero(rhs->dtype)), indices);
     } else {
-      return std::nullopt;
+      body = BufferStore(buffer, BufferLoad(padded_buffer, indices), indices);
     }
-    ein_sum_.output_buffer = update->buffer;
-
-    const BufferStoreNode* init = block->init.value().as<BufferStoreNode>();
-    ICHECK(init != nullptr);
-    if (!CompareBufferIndices(init->indices, ein_sum_.output_indices)) {
-      return std::nullopt;
+    BufferRegion read_region(buffer, instance_dom);
+    BufferRegion write_region(padded_buffer, instance_dom);
+    if (!is_read) {
+      std::swap(read_region, write_region);
     }
-    // Step 2: Check the BufferStore updates the output buffer and the input buffers indices are
-    // block iter variables.
-    CheckStoreValue(update->value);
-    if (fail_) {
-      return std::nullopt;
+    Block new_block(iter_vars, {read_region}, {write_region}, padded_buffer->name, std::move(body));
+    blocks->push_back(new_block);
+    body = BlockRealize(Array<PrimExpr>{loop_vars.begin(), loop_vars.end()}, Bool(true), new_block);
+    for (int i = ndim - 1; i >= 0; --i) {
+      body = For(loop_vars[i], loop_doms[i]->min, loop_doms[i]->extent, ForKind::kSerial,
+                 std::move(body));
     }
-    return std::move(ein_sum_);
+    return body;
   }
+};
 
- private:
-  void CheckStoreValue(const PrimExpr& update) {
-    // Check the update part has the form:
-    // Output[output_indices] += Input_0[input_indices_0] op_0 Input_1[input_indices_1] op_1 ...
-    // where output_indices and input_indices_i are the indices are arrays whose elements are the
-    // block iter variables instead of composite PrimExpr, and op_i are the binary operations.
-
-    // Check the value is Add and eithe LHS or RHS is the BufferLoad from the output buffer.
-    const AddNode* add = update.as<AddNode>();
-    if (add == nullptr) {
-      fail_ = true;
-      return;
-    }
-    const BufferLoadNode* lhs = add->a.as<BufferLoadNode>();
-    const BufferLoadNode* rhs = add->b.as<BufferLoadNode>();
-    if (lhs == nullptr && rhs != nullptr) {
-      std::swap(lhs, rhs);
-    }
-    if (lhs == nullptr || !lhs->buffer.same_as(ein_sum_.output_buffer) ||
-        !CompareBufferIndices(lhs->indices, ein_sum_.output_indices)) {
-      fail_ = true;
-      return;
+Einsum ExtractEinsum(const ScheduleState& self, const Block& block) {
+  Einsum result;
+  std::unordered_set<const BufferNode*> buffer_used;
+  int n_reads = block->reads.size();
+  for (int i = 0; i < n_reads; ++i) {
+    const Buffer& buffer = block->reads[i]->buffer;
+    if (buffer_used.count(buffer.get()) != 0) {
+      throw NonEinsumError(self->mod, block);
     }
-    VisitExpr(add->b);
-  }
-
-  void VisitExpr(const PrimExpr& n) final {
-    if (n->IsInstance<BufferLoadNode>() || n->IsInstance<MulNode>() || n->IsInstance<CastNode>()) {
-      ExprVisitor::VisitExpr(n);
+    buffer_used.insert(buffer.get());
+    if (Optional<Array<Var>> opt_indices = CheckTrivialBufferAccess(block->reads[i])) {
+      result.input_buffers.push_back(buffer);
+      result.input_indices.Set(buffer, opt_indices.value());
     } else {
-      fail_ = true;
-      return;
+      throw NonEinsumError(self->mod, block);
    }
   }
-
-  void VisitExpr_(const BufferLoadNode* op) final {
-    if (auto it = ein_sum_.input_indices.find(op->buffer);
-        it != ein_sum_.input_indices.end() && !CompareBufferIndices(op->indices, (*it).second)) {
-      fail_ = true;
-      return;
+  int n_writes = block->writes.size();
+  for (int i = 0; i < n_writes; ++i) {
+    const Buffer& buffer = block->writes[i]->buffer;
+    if (buffer_used.count(buffer.get()) != 0) {
+      throw NonEinsumError(self->mod, block);
     }
-    if (Optional<Array<Var>> opt_indices = CheckTrivialBufferIndices(op); opt_indices.defined()) {
-      ein_sum_.input_indices.Set(op->buffer, std::move(opt_indices.value()));
+    buffer_used.insert(buffer.get());
+    if (Optional<Array<Var>> opt_indices = CheckTrivialBufferAccess(block->writes[i])) {
+      result.output_buffers.push_back(buffer);
+      result.output_indices.Set(buffer, opt_indices.value());
     } else {
-      fail_ = true;
-      return;
    }
   }
-
-  void VisitExpr_(const CastNode* op) { VisitExpr(op->value); }
-
-  bool Fail() { return fail_; }
-
-  bool CompareBufferIndices(const Array<PrimExpr>& indices, const Array<Var>& other) {
-    return std::equal(indices.begin(), indices.end(), other.begin(), other.end(),
-                      [](const PrimExpr& a, const Var& b) { return a.same_as(b); });
-  }
-
-  Einsum ein_sum_;
-  bool fail_{false};
-};
-
-Einsum ExtractEinsum(const ScheduleState& self, const Block& block) {
-  EinsumExtractor extractor;
-  std::optional<Einsum> einsum = extractor.Extract(block);
-  if (!einsum.has_value()) {
-    throw NonEinsumError(self->mod, block);
-  }
-  return einsum.value();
+  return result;
 }
 
 class BufferNotAllocatedInScopeError : public ScheduleError {
@@ -218,69 +258,6 @@ class BufferNotAllocatedInScopeError : public ScheduleError {
   Buffer buffer_;
 };
 
-class PadEinsumRewriter : public ReplaceBufferMutator {
- public:
-  PadEinsumRewriter(const std::unordered_map<const BlockNode*, PrimExpr> producer_predicate,
-                    Map<Var, PrimExpr> padded_iter_extents, const Map<Buffer, Buffer>& buffer_remap,
-                    Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer)
-      : ReplaceBufferMutator(buffer_remap, block_sref_reuse),
-        producer_predicate_(producer_predicate),
-        padded_iter_extents_(padded_iter_extents),
-        analyzer_(analyzer) {}
-  using ReplaceBufferMutator::VisitExpr_;
-  using ReplaceBufferMutator::VisitStmt_;
-
-  Stmt VisitStmt_(const ForNode* op) final {
-    For new_for = Downcast<For>(ReplaceBufferMutator::VisitStmt_(op));
-    if (padded_iter_extents_.count(new_for->loop_var)) {
-      new_for.CopyOnWrite()->extent = padded_iter_extents_.at(new_for->loop_var);
-    }
-    return std::move(new_for);
-  }
-
-  Block PadProducerBlock(Block block, const PrimExpr& predicate) {
-    BufferStore store = Downcast<BufferStore>(block->body);
-    store.CopyOnWrite()->value =
-        analyzer_->Simplify(if_then_else(predicate, store->value, make_zero(store->value.dtype())));
-    block.CopyOnWrite()->body = std::move(store);
-    return block;
-  }
-
-  Stmt VisitStmt_(const BlockNode* op) final {
-    Block old_block = GetRef<Block>(op);
-    Block new_block = Downcast<Block>(ReplaceBufferMutator::VisitStmt_(op));
-    if (auto it = producer_predicate_.find(op); it != producer_predicate_.end()) {
-      new_block = PadProducerBlock(std::move(new_block), (*it).second);
-    }
-
-    // Mutate block iters
-    Array<IterVar> new_iters;
-    bool changed = false;
-    for (const IterVar& iter : new_block->iter_vars) {
-      if (auto it = padded_iter_extents_.find(iter->var); it != padded_iter_extents_.end()) {
-        changed = true;
-        new_iters.push_back(
-            IterVar(Range::FromMinExtent(0, (*it).second), iter->var, iter->iter_type));
-      } else {
-        new_iters.push_back(iter);
-      }
-    }
-    if (changed) {
-      new_block.CopyOnWrite()->iter_vars = std::move(new_iters);
-    }
-    if (!old_block.same_as(new_block)) {
-      block_sref_reuse_->Set(old_block, new_block);
-    }
-    return std::move(new_block);
-  }
-
- private:
-  const std::unordered_set<const BlockNode*> producer_blocks_;
-  const std::unordered_map<const BlockNode*, PrimExpr> producer_predicate_;
-  const Map<Var, PrimExpr> padded_iter_extents_;
-  arith::Analyzer* analyzer_;
-};
-
 /*! \brief The schedule error class when the producer block cannot be padded. */
 class InvalidProducerError : public ScheduleError {
  public:
@@ -307,140 +284,187 @@ class InvalidProducerError : public ScheduleError {
   Block producer_;
 };
 
-void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<IntImm>& padding) {
-  arith::Analyzer analyzer;
-  // Step 1: Input checking and error handling
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
-  BlockRealize realize = GetBlockRealize(self, block_sref);
-
-  const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
-  InvalidPaddingError::Check(self, GetRef<Block>(block), padding);
-
-  const Array<StmtSRef> producers = GetProducers(self, block_sref);
-  {
-    auto f_check_block_properties = [&](const StmtSRef& block_sref, bool is_producer) {
-      CheckBlockHasTrivialBinding(self, block_sref);
-      if (is_producer) {
-        CheckCompleteBlock(self, block_sref, scope_sref);
+class BufferReplacer : public StmtExprMutator {
+ public:
+  Stmt VisitStmt_(const BlockNode* old_block_ptr) final {
+    Block old_block = GetRef<Block>(old_block_ptr);
+    Block block = Downcast<Block>(StmtMutator::VisitStmt_(old_block_ptr));
+    Array<IterVar> iter_vars;
+    iter_vars.reserve(block->iter_vars.size());
+    for (const IterVar& iter_var : block->iter_vars) {
+      if (Optional<PrimExpr> new_dom = iter2padded_extents.Get(iter_var->var)) {
+        ObjectPtr<IterVarNode> new_iter_var = make_object<IterVarNode>(*iter_var.get());
+        new_iter_var->dom = Range::FromMinExtent(iter_var->dom->min, new_dom.value());
+        iter_vars.push_back(IterVar(new_iter_var));
      } else {
-        CheckReductionBlock(self, block_sref, scope_sref);
+        iter_vars.push_back(iter_var);
+      }
+    }
+    Array<BufferRegion> reads;
+    reads.reserve(block->reads.size());
+    for (const BufferRegion& read : block->reads) {
+      if (Optional<Buffer> buffer = buffer_map_.Get(read->buffer)) {
+        reads.push_back(BufferRegion(buffer.value(), read->region));
+      } else {
+        reads.push_back(read);
      }
-      Array<StmtSRef> loops = GetLoops(block_sref);
-      ICHECK(!loops.empty());
-      CheckGetSingleChildBlockRealizeOnSRefTree(self, loops.front());
-    };
-
-    // Check block properties of the computation block
-    f_check_block_properties(block_sref, false);
-
-    // Check block properties of the producer block
-    for (const StmtSRef& producer_sref : producers) {
-      f_check_block_properties(producer_sref, true);
    }
+    Array<BufferRegion> writes;
+    writes.reserve(block->writes.size());
+    for (const BufferRegion& write : block->writes) {
+      if (Optional<Buffer> buffer = buffer_map_.Get(write->buffer)) {
+        writes.push_back(BufferRegion(buffer.value(), write->region));
+      } else {
+        writes.push_back(write);
+      }
+    }
+    Block new_block = Block(iter_vars, reads, writes, block->name_hint, block->body, block->init);
+    block_sref_reuse_.Set(old_block, new_block);
+    return new_block;
  }
 
-  Einsum einsum = ExtractEinsum(self, GetRef<Block>(block));
+  Stmt VisitStmt_(const ForNode* old_for_ptr) final {
+    For old_for = GetRef<For>(old_for_ptr);
+    For new_for = Downcast<For>(StmtMutator::VisitStmt_(old_for_ptr));
+    if (Optional<PrimExpr> new_extent = loop_var2padded_extent.Get(new_for->loop_var)) {
+      ObjectPtr<ForNode> new_for_ptr = make_object<ForNode>(*new_for.get());
+      new_for_ptr->extent = new_extent.value();
+      new_for = For(new_for_ptr);
+    }
+    return new_for;
+  }
 
-  // Check input and output buffers are all allocated in the current scope.
-  {
-    auto f_check_buffer_allocated = [&](const Buffer& buffer) {
-      auto [defining_site_sref, is_allocate] = GetBufferDefiningSite(block_sref, buffer);
-      if (!defining_site_sref.defined() || !is_allocate) {
-        throw BufferNotAllocatedInScopeError(self->mod, buffer);
-      }
-    };
-    f_check_buffer_allocated(einsum.output_buffer);
-    for (const auto& buffer_indices_pair : einsum.input_indices) {
-      f_check_buffer_allocated(buffer_indices_pair.first);
+  Stmt VisitStmt_(const BufferStoreNode* old_store_ptr) final {
+    BufferStore store = Downcast<BufferStore>(StmtMutator::VisitStmt_(old_store_ptr));
+    if (Optional<Buffer> buffer = buffer_map_.Get(store->buffer)) {
+      return BufferStore(buffer.value(), store->value, store->indices);
+    } else {
+      return store;
    }
   }
 
-  // Step 2: Prepare buffer and variable remapping. Infer the new shape of the input and the output
-  // buffers. Infer the new extent of the block iters of the computation block and the producer
-  // block.
+  PrimExpr VisitExpr_(const BufferLoadNode* old_load_ptr) final {
+    BufferLoad load = Downcast<BufferLoad>(ExprMutator::VisitExpr_(old_load_ptr));
+    if (Optional<Buffer> buffer = buffer_map_.Get(load->buffer)) {
+      return BufferLoad(buffer.value(), load->indices);
+    } else {
+      return load;
+    }
  }
 
-  Map<Var, PrimExpr> padded_iter_extents;  // The new extents of both the block iters and loop vars
+  Map<Var, PrimExpr> iter2padded_extents;
+  Map<Var, PrimExpr> loop_var2padded_extent;
+  Map<Buffer, Buffer> buffer_map_;
+  Map<Block, Block> block_sref_reuse_;
+};
 
-  // Convert the input padding array to a map from variables to the padded extents
+void PadEinsum(ScheduleState self, StmtSRef block_sref, Array<IntImm> padding) {
+  arith::Analyzer analyzer;
+  // Step 1: Input checking and error handling
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  BlockRealize realize = GetBlockRealize(self, block_sref);
+  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
+  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);
+  InvalidPaddingError::Check(self, GetRef<Block>(block), padding);
+  // Step 2. Extract the Einsum pattern
+  Einsum einsum = ExtractEinsum(self, GetRef<Block>(block));
+  // Step 3. Figure out the padding needed
+  BufferReplacer replacer;
   for (int i = 0, n = padding.size(); i < n; ++i) {
     const IterVar& iter = block->iter_vars[i];
-    PrimExpr new_extent =
-        IntImm(iter->var->dtype, Downcast<IntImm>(iter->dom->extent)->value + padding[i]->value);
-    padded_iter_extents.Set(iter->var, new_extent);
-    padded_iter_extents.Set(Downcast<Var>(realize->iter_values[i]), new_extent);
+    Var loop_var = Downcast<Var>(realize->iter_values[i]);
+    PrimExpr dom = iter->dom->extent;
+    PrimExpr new_dom = analyzer.Simplify(ceildiv(dom, padding[i]) * padding[i]);
+    if (!analyzer.CanProveEqual(new_dom, dom)) {
+      replacer.iter2padded_extents.Set(iter->var, new_dom);
+      replacer.iter2padded_extents.Set(Downcast<Var>(realize->iter_values[i]), new_dom);
+      replacer.loop_var2padded_extent.Set(loop_var, new_dom);
+    }
   }
-
-  Map<Buffer, Buffer> buffer_remap;  // mapping from buffers to new buffers with padded shapes
-
-  // Utility function to pad a buffer with the new shape
-  auto f_pad_buffer = [&padded_iter_extents](Buffer buffer, const Array<Var>& indices) -> Buffer {
-    Array<PrimExpr> new_shape;
-    for (const Var& index : indices) {
-      new_shape.push_back(padded_iter_extents.at(index));
+  auto f_needs_padding = [&replacer](const Array<Range>& region) {
+    for (const Range& range : region) {
+      if (const auto* var = range->min.as<VarNode>()) {
+        if (replacer.iter2padded_extents.count(GetRef<Var>(var))) {
+          return true;
+        }
+      }
    }
-    ICHECK_EQ(buffer->shape.size(), new_shape.size());
-    buffer.CopyOnWrite()->shape = std::move(new_shape);
-    return buffer;
+    return false;
  };
-
-  buffer_remap.Set(einsum.output_buffer, f_pad_buffer(einsum.output_buffer, einsum.output_indices));
-
-  std::unordered_map<const BlockNode*, PrimExpr> producer_predicate;
-
-  // Different from the output block, the padding for the producer block is not directly specified
-  // as the input argument. Instead, it is inferred from indices of the producer buffer accessed in
-  // the output block.
-  // We will find the indices (which are block iters) in BufferStore to the producer buffer
-  // and infer the new extents of the block iters and the corresponding loop vars.
-  for (const StmtSRef& producer_sref : producers) {
-    const BlockNode* producer_block = TVM_SREF_TO_BLOCK(producer_sref);
-    const BufferStoreNode* buffer_store = producer_block->body.as<BufferStoreNode>();
-    Optional<Array<Var>> producer_store_indices;
-    if (!buffer_store || producer_block->writes.size() != 1 ||
-        !(producer_store_indices = CheckTrivialBufferIndices(buffer_store)).defined()) {
-      throw InvalidProducerError(self->mod, GetRef<Block>(producer_block));
-    }
-    BlockRealize producer_realize = GetBlockRealize(self, producer_sref);
-
-    const Buffer& old_buffer = producer_block->writes[0]->buffer;
-    Buffer new_buffer = f_pad_buffer(old_buffer, einsum.input_indices.at(old_buffer));
-    buffer_remap.Set(old_buffer, new_buffer);
-
-    // The predicate to ensure the producer block is in the original bound before padding
-    PrimExpr predicate = Bool(true);
-    Map<Var, PrimExpr> indices_to_padded_extents;  // buffer indices to padded extents
-    for (int i = 0, n = producer_store_indices.value().size(); i < n; ++i) {
-      const Var& index = producer_store_indices.value()[i];
-      PrimExpr padded_extent = new_buffer->shape[i];
-      if (!analyzer.CanProveEqual(padded_extent, old_buffer->shape[i])) {
-        predicate = predicate && (index < old_buffer->shape[i]);
+  // Step 4. Convert the subtree under the scope root
+  Array<Stmt> scope_body;
+  if (const auto* seq_stmt = scope_block->body.as<SeqStmtNode>()) {
+    scope_body = seq_stmt->seq;
+  } else {
+    scope_body.push_back(scope_block->body);
+  }
+  // Step 5. Find out the block of our interest
+  int pos = -1;
+  for (int i = 0; i < static_cast<int>(scope_body.size()); ++i) {
+    bool found = false;
+    PostOrderVisit(scope_body[i], [&found, &block](const ObjectRef& node) {
+      if (node.get() == block) {
+        found = true;
      }
-      indices_to_padded_extents.Set(index, padded_extent);
+    });
+    if (found) {
+      pos = i;
+      break;
    }
-
-    for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) {
-      const IterVar& iter = producer_block->iter_vars[i];
-      if (auto it = indices_to_padded_extents.find(iter->var);
-          it != indices_to_padded_extents.end()) {
-        const PrimExpr& padded_extent = (*it).second;
-        padded_iter_extents.Set(iter->var, padded_extent);
-        padded_iter_extents.Set(Downcast<Var>(producer_realize->iter_values[i]), padded_extent);
-      } else if (!is_one(iter->dom->extent)) {
-        throw InvalidProducerError(self->mod, GetRef<Block>(producer_block));
-      }
+  }
+  ICHECK_NE(pos, -1);
+  // Step 6. For each buffer, if it needs padding, create a new buffer and a new block
+  Array<Stmt> read_blocks;
+  Array<Stmt> write_blocks;
+  Array<Block> new_copy_blocks;
+  Array<Buffer> alloc_buffers;
+  for (const BufferRegion& buffer_region : block->reads) {
+    if (f_needs_padding(buffer_region->region)) {
+      BufferPadding bp =
+          BufferPadding::FromBufferRegion(buffer_region, replacer.iter2padded_extents);
+      replacer.buffer_map_.Set(bp.buffer, bp.padded_buffer);
+      read_blocks.push_back(bp.MakeCopyBlock(true, &new_copy_blocks, &analyzer));
+      alloc_buffers.push_back(bp.padded_buffer);
    }
-    producer_predicate[producer_block] = predicate;
  }
-
-  // Step 3: Mutate the AST subtree with the new buffers and the new block iter extents.
-  Map<Block, Block> block_sref_reuse;
-  PadEinsumRewriter rewriter(producer_predicate, padded_iter_extents, buffer_remap,
-                             &block_sref_reuse, &analyzer);
-  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);
-  Stmt new_scope_block = rewriter(GetRef<Block>(scope_block));
-
-  // Step 4: Do the actual replacement.
-  self->Replace(scope_sref, new_scope_block, block_sref_reuse);
+  for (const BufferRegion& buffer_region : block->writes) {
+    if (f_needs_padding(buffer_region->region)) {
+      BufferPadding bp =
+          BufferPadding::FromBufferRegion(buffer_region, replacer.iter2padded_extents);
+      replacer.buffer_map_.Set(bp.buffer, bp.padded_buffer);
+      write_blocks.push_back(bp.MakeCopyBlock(false, &new_copy_blocks, &analyzer));
+      alloc_buffers.push_back(bp.padded_buffer);
+    }
  }
+  // Step 7. Create new scope body
+  Array<Stmt> new_scope_body;
+  for (int i = 0; i < static_cast<int>(scope_body.size()); ++i) {
+    if (i != pos) {
+      new_scope_body.push_back(scope_body[i]);
+      continue;
    }
+    new_scope_body.insert(new_scope_body.end(), read_blocks.begin(), read_blocks.end());
+    new_scope_body.push_back(replacer(scope_body[i]));
+    new_scope_body.insert(new_scope_body.end(), write_blocks.begin(), write_blocks.end());
+  }
+  // Step 8. Create new scope
+  Block new_scope_block{nullptr};
+  {
+    ObjectPtr<BlockNode> n = make_object<BlockNode>(*scope_block);
+    n->body = SeqStmt::Flatten(new_scope_body);
+    n->alloc_buffers.insert(n->alloc_buffers.end(), alloc_buffers.begin(), alloc_buffers.end());
+    new_scope_block = Block(n);
+  }
+  replacer.block_sref_reuse_.Set(GetRef<Block>(scope_block), new_scope_block);
+  // Step 9. Do replacement and update flags
+  self->Replace(scope_sref, new_scope_block, replacer.block_sref_reuse_);
+  for (const Block& block : new_copy_blocks) {
+    StmtSRef block_sref = self->stmt2ref.at(block.get());
+    BlockInfo& block_info = self->block_info[block_sref];
+    block_info.affine_binding = true;
+    block_info.region_cover = true;
+    block_info.scope->stage_pipeline = true;
+  }
 }
 
 /******** Instruction Registration ********/
diff --git a/tests/python/unittest/test_tir_schedule_pad_einsum.py b/tests/python/unittest/test_tir_schedule_pad_einsum.py
index ec4d000655..0b0288e3f7 100644
--- a/tests/python/unittest/test_tir_schedule_pad_einsum.py
+++ b/tests/python/unittest/test_tir_schedule_pad_einsum.py
@@ -15,108 +15,213 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-function-docstring,missing-module-docstring
-import sys
-
 import pytest
 import tvm
 import tvm.testing
-from tvm import tir, te
+from tvm import tir
 from tvm.script import tir as T
-from tvm.tir.schedule.schedule import ScheduleError
 from tvm.tir.schedule.testing import verify_trace_roundtrip
-from tvm.meta_schedule.testing import te_workload
 
-# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
 
+def test_pad_matmul():
+    # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
 
-@T.prim_func
-def matmul_before(
-    A: T.Buffer((128, 127), "float32"),
-    B: T.Buffer((127, 127), "float32"),
-    C: T.Buffer((128, 127), "float32"),
-) -> None:
-    A_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
-    B_shared = T.alloc_buffer((127, 127), "float32", scope="shared")
-    C_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
-    for i0, i1 in T.grid(128, 127):
-        with T.block("A"):
-            i, j = T.axis.remap("SS", [i0, i1])
-            A_shared[i, j] = A[i, j]
-    for i0, i1 in T.grid(127, 127):
-        with T.block("B"):
-            i, j = T.axis.remap("SS", [i0, i1])
-            B_shared[i, j] = B[i, j]
-    for i0, i1, i2 in T.grid(128, 127, 127):
-        with T.block("C_shared"):
-            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
-            with T.init():
-                C_shared[i, j] = T.float32(0)
-            C_shared[i, j] = C_shared[i, j] + A_shared[i, k] * B_shared[k, j]
-    for i0, i1 in T.grid(128, 127):
-        with T.block("C"):
-            i, j = T.axis.remap("SS", [i0, i1])
-            C[i, j] = C_shared[i, j]
+    @T.prim_func
+    def matmul_before(
+        a: T.handle,
+        b: T.handle,
+        c: T.handle,
+    ) -> None:
+        n = T.int32()
+        A = T.match_buffer(a, (128, 128), "float32")
+        B = T.match_buffer(b, (n, 128), "float32")
+        C = T.match_buffer(c, (128, n), "float32")
+        for i0, i1, i2 in T.grid(128, n, 128):
+            with T.block("C"):
+                i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+                with T.init():
+                    C[i, j] = T.float32(0)
+                C[i, j] = C[i, j] + A[i, k] * B[j, k]
 
+    @T.prim_func
+    def matmul_after(
+        a: T.handle,
+        b: T.handle,
+        c: T.handle,
+    ):
+        n = T.int32()
+        A = T.match_buffer(a, (128, 128), "float32")
+        B = T.match_buffer(b, (n, 128), "float32")
+        C = T.match_buffer(c, (128, n), "float32")
+        B_pad = T.alloc_buffer(((n + 31) // 32 * 32, 128))
+        C_pad = T.alloc_buffer((128, (n + 31) // 32 * 32))
+        for i0, i1 in T.grid((n + 31) // 32 * 32, 128):
+            with T.block("B_pad"):
+                v0, v1 = T.axis.remap("SS", [i0, i1])
+                B_pad[v0, v1] = T.if_then_else(v0 < n, B[v0, v1], T.float32(0))
+        for i0, i1, i2 in T.grid(128, (n + 31) // 32 * 32, 128):
+            with T.block("C"):
+                i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+                T.reads(A[i, k], B_pad[j, k])
+                T.writes(C_pad[i, j])
+                with T.init():
+                    C_pad[i, j] = T.float32(0)
+                C_pad[i, j] = C_pad[i, j] + A[i, k] * B_pad[j, k]
+        for i0, i1 in T.grid(128, n):
+            with T.block("C_pad"):
+                v0, v1 = T.axis.remap("SS", [i0, i1])
+                C[v0, v1] = C_pad[v0, v1]
 
-@T.prim_func
-def matmul_expected(
-    A: T.Buffer((128, 127), "float32"),
-    B: T.Buffer((127, 127), "float32"),
-    C: T.Buffer((128, 127), "float32"),
-) -> None:
-    A_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
-    B_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
-    C_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
-    for i0, i1 in T.grid(128, 128):
-        with T.block("A"):
-            i, j = T.axis.remap("SS", [i0, i1])
-            T.reads(A[i, j])
-            T.writes(A_shared_padded[i, j])
-            A_shared_padded[i, j] = T.if_then_else(j < 127, A[i, j], T.float32(0), dtype="float32")
-    for i0, i1 in T.grid(128, 128):
-        with T.block("B"):
-            i, j = T.axis.remap("SS", [i0, i1])
-            T.reads(B[i, j])
-            T.writes(B_shared_padded[i, j])
-            B_shared_padded[i, j] = T.if_then_else(
-                i < 127 and j < 127, B[i, j], T.float32(0), dtype="float32"
-            )
-    for i0, i1, i2 in T.grid(128, 128, 128):
-        with T.block("C_shared"):
-            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
-            T.reads(A_shared_padded[i, k], B_shared_padded[k, j])
-            T.writes(C_shared_padded[i, j])
-            with T.init():
-                C_shared_padded[i, j] = T.float32(0)
-            C_shared_padded[i, j] = (
-                C_shared_padded[i, j] + A_shared_padded[i, k] * B_shared_padded[k, j]
-            )
-    for i0, i1 in T.grid(128, 127):
-        with T.block("C"):
-            i, j = T.axis.remap("SS", [i0, i1])
-            T.reads(C_shared_padded[i, j])
-            T.writes(C[i, j])
-            C[i, j] = C_shared_padded[i, j]
+    sch = tir.Schedule(matmul_before, debug_mask="all")
+    C = sch.get_block("C")
+    sch.pad_einsum(C, [32, 32, 32])
+    tvm.ir.assert_structural_equal(matmul_after, sch.mod["main"])
+    verify_trace_roundtrip(sch, mod=matmul_before)
 
-# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
 
+def test_pad_matmul_2():
+    @T.prim_func
+    def before(
+        a: T.handle,
+        b: T.handle,
+        m: T.handle,
+        d: T.handle,
+    ):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        n = T.int32()
+        A = T.match_buffer(a, (1, n, 4096))
+        B = T.match_buffer(b, (11008, 4096))
+        M = T.match_buffer(m, (1, n, 11008))
+        D = T.match_buffer(d, (1, n, 11008))
+        C = T.alloc_buffer((1, n, 11008))
+        for i0, i1, i2, k in T.grid(1, n, 11008, 4096):
+            with T.block("C"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k])
+                T.writes(C[v_i0, v_i1, v_i2])
+                with T.init():
+                    C[v_i0, v_i1, v_i2] = T.float32(0)
+                C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k]
+        for ax0, ax1, ax2 in T.grid(1, n, 11008):
+            with T.block("D"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                D[v_ax0, v_ax1, v_ax2] = M[v_ax0, v_ax1, v_ax2] * C[v_ax0, v_ax1, v_ax2]
 
+    @T.prim_func
+    def after(a: T.handle, b: T.handle, m: T.handle, d: T.handle):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        n = T.int32()
+        A = T.match_buffer(a, (1, n, 4096))
+        B = T.match_buffer(b, (11008, 4096))
+        M = T.match_buffer(m, (1, n, 11008))
+        D = T.match_buffer(d, (1, n, 11008))
+        # with T.block("root"):
+        C = T.alloc_buffer((1, n, 11008))
+        A_pad = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096))
+        C_pad = T.alloc_buffer((1, (n + 31) // 32 * 32, 11008))
+        for i0, i1, i2 in T.grid(1, (n + 31) // 32 * 32, 4096):
+            with T.block("A_pad"):
+                v0, v1, v2 = T.axis.remap("SSS", [i0, i1, i2])
+                A_pad[v0, v1, v2] = T.if_then_else(v1 < n, A[v0, v1, v2], T.float32(0))
+        for i0, i1, i2, k in T.grid(1, (n + 31) // 32 * 32, 11008, 4096):
+            with T.block("C"):
T.block("C"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(A_pad[v_i0, v_i1, v_k], B[v_i2, v_k]) + T.writes(C_pad[v_i0, v_i1, v_i2]) + with T.init(): + C_pad[v_i0, v_i1, v_i2] = T.float32(0) + C_pad[v_i0, v_i1, v_i2] = ( + C_pad[v_i0, v_i1, v_i2] + A_pad[v_i0, v_i1, v_k] * B[v_i2, v_k] + ) + for i0, i1, i2 in T.grid(1, n, 11008): + with T.block("C_pad"): + v0, v1, v2 = T.axis.remap("SSS", [i0, i1, i2]) + C[v0, v1, v2] = C_pad[v0, v1, v2] + for ax0, ax1, ax2 in T.grid(1, n, 11008): + with T.block("D"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + D[v_ax0, v_ax1, v_ax2] = M[v_ax0, v_ax1, v_ax2] * C[v_ax0, v_ax1, v_ax2] -def test_pad_matmul(): - sch = tir.Schedule(matmul_before, debug_mask="all") - C = sch.get_block("C_shared") - sch.pad_einsum(C, [0, 1, 1]) - tvm.ir.assert_structural_equal(matmul_expected, sch.mod["main"]) - verify_trace_roundtrip(sch, mod=matmul_before) + sch = tir.Schedule(before, debug_mask="all") + C = sch.get_block("C") + sch.pad_einsum(C, [1, 32, 32, 32]) + tvm.ir.assert_structural_equal(after, sch.mod["main"]) + verify_trace_roundtrip(sch, mod=before) -def test_pad_matmul_error_non_intermediate_buffer(): - func = te.create_prim_func(te_workload.matmul(128, 127, 127)) - sch = tir.Schedule(func, debug_mask="all") - C = sch.get_block("C") - with pytest.raises(ScheduleError): - sch.pad_einsum(C, [0, 1, 1]) +def test_pad_rms(): + @T.prim_func + def before( + a: T.handle, + w: T.handle, + r: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int32() + A = T.match_buffer(a, (1, n, 4096)) + W = T.match_buffer(w, (4096,), "float32") + R = T.match_buffer(r, (1, n, 4096), "float32") + S = T.alloc_buffer((1, n), "float32") + for bsz, i, k in T.grid(1, n, 4096): + with T.block("S"): + v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) + T.reads(A[v_bsz, v_i, v_k]) + T.writes(S[v_bsz, v_i]) + with T.init(): + S[v_bsz, v_i] = T.float32(0) + S[v_bsz, v_i] = S[v_bsz, v_i] + A[v_bsz, v_i, v_k] * A[v_bsz, v_i, v_k] + for bsz, i, k in T.grid(1, n, 4096): + with T.block("R"): + v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) + R[v_bsz, v_i, v_k] = W[v_k] * ( + A[v_bsz, v_i, v_k] + / T.sqrt(S[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(1e-6)) + ) + + @T.prim_func + def after(a: T.handle, w: T.handle, r: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int32() + A = T.match_buffer(a, (1, n, 4096)) + W = T.match_buffer(w, (4096,), "float32") + R = T.match_buffer(r, (1, n, 4096)) + S = T.alloc_buffer((1, n)) + A_pad = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096)) + S_pad = T.alloc_buffer((1, (n + 31) // 32 * 32)) + for i0, i1, i2 in T.grid(1, (n + 31) // 32 * 32, 4096): + with T.block("A_pad"): + v0, v1, v2 = T.axis.remap("SSS", [i0, i1, i2]) + A_pad[v0, v1, v2] = T.if_then_else(v1 < n, A[v0, v1, v2], T.float32(0)) + for bsz, i, k in T.grid(1, (n + 31) // 32 * 32, 4096): + with T.block("S"): + v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) + T.reads(A_pad[v_bsz, v_i, v_k]) + T.writes(S_pad[v_bsz, v_i]) + with T.init(): + S_pad[v_bsz, v_i] = T.float32(0) + S_pad[v_bsz, v_i] = ( + S_pad[v_bsz, v_i] + A_pad[v_bsz, v_i, v_k] * A_pad[v_bsz, v_i, v_k] + ) + for i0, i1 in T.grid(1, n): + with T.block("S_pad"): + v0, v1 = T.axis.remap("SS", [i0, i1]) + S[v0, v1] = S_pad[v0, v1] + for bsz, i, k in T.grid(1, n, 4096): + with T.block("R"): + v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) + R[v_bsz, v_i, v_k] = W[v_k] * ( + A[v_bsz, v_i, v_k] + / T.sqrt(S[v_bsz, v_i] * T.float32(0.000244140625) + 
 
+    sch = tir.Schedule(before, debug_mask="all")
+    C = sch.get_block("S")
+    sch.pad_einsum(C, [1, 32, 1])
+    tvm.ir.assert_structural_equal(after, sch.mod["main"])
+    verify_trace_roundtrip(sch, mod=before)
 
 
 if __name__ == "__main__":
-    tvm.testing.main()
+    test_pad_matmul()
+    test_pad_matmul_2()
+    test_pad_rms()
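Taken together, the tests show the intended workflow for the reworked primitive: select the Einsum block, pass one multiple per block iter, and let `pad_einsum` allocate the `*_pad` buffers and insert the guarded copy-in/copy-out blocks. A condensed usage sketch in Python; `mod` is a hypothetical IRModule whose "main" PrimFunc contains an Einsum-style reduction block named "C" with three block iters (illustration only):

    import tvm
    from tvm import tir

    sch = tir.Schedule(mod, debug_mask="all")
    block = sch.get_block("C")
    # Round every block-iter extent up to a multiple of 32; iters that must
    # stay untouched can be given a multiple of 1, as in test_pad_rms above.
    sch.pad_einsum(block, [32, 32, 32])
    print(sch.mod["main"].script())  # shows the *_pad buffers and copy blocks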