diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 9c28ed1370..3351f6c2c6 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -264,7 +264,7 @@ class ScheduleNode : public runtime::Object { * \return The rfactor block */ virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0; - /******** Schedule: Block annotations ********/ + /******** Schedule: Block annotation ********/ /*! * \brief Set alignment requirement for specific dimension such that * stride[axis] == k * factor + offset for some k. diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index b7129bba9f..52bb9d6fb4 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -727,7 +727,7 @@ def after_rfactor(a: ty.handle, b: ty.handle) -> None: """ return _ffi_api.ScheduleRFactor(self, loop, factor_axis) # type: ignore # pylint: disable=no-member - ######## Schedule: Block annotatoins ######## + ######## Schedule: Block annotatoin ######## def storage_align(self, block: BlockRV, buffer_index: int, axis: int, factor: int, offset: int) -> None: diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 9baf4b5245..0a465630fe 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -202,6 +202,18 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self */ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref); +/******** Block-block relation ********/ + +/*! + * \brief Get the BlockRealize of the single child block of the block or loop specified by + * `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks + * \param self The schedule state + * \param block The queried block + * \return The buffer of the n-th write region of the block. + * \throw ScheduleError If the buffer index is out of bound. + */ +Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n); + /******** Commutative Reducer ********/ /*! @@ -224,6 +236,17 @@ std::vector> GetReducerGetters(); bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs); +/******** Annotation ********/ + +/*! + * \brief Create a new block with the given annotation added + * \param block The block with original annotation + * \param attr_key The annotation key to be added + * \param attr_value The annotation value to be added + * \return A new block with the given annotation as its last annotation + */ +Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 3ee98ec5b7..eeb881220e 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -527,6 +527,57 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr } } +/******** Block-buffer relation ********/ + +Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) { + class WriteBufferIndexOutOfRangeError : public ScheduleError { + public: + explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index) + : mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {} + + String FastErrorString() const final { + return "ScheduleError: The input `buffer_index` is out of range. It is required to be in range " + "[0, num_write_regions) where `num_write_regions` is the number of buffer regions " + "written by the block."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + size_t num_writes = block_->writes.size(); + os << "The block {0} has " << num_writes + << " write regions, so `buffer_index` is required to be in [0, " + << num_writes << "). However, the input `buffer_index` is " << buffer_index_ + << ", which is out of the expected range"; + return os.str(); + } + + static Buffer CheckAndGetBuffer(const IRModule& mod, const Block& block, int buffer_index) { + if (buffer_index < 0 || buffer_index > block->writes.size()) { + throw WriteBufferIndexOutOfRangeError(mod, block, buffer_index); + } + return block->writes[buffer_index]->buffer; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {mod_}; } + + private: + IRModule mod_; + Block block_; + int buffer_index_; + }; + return WriteBufferIndexOutOfRangeError::CheckAndGetBuffer(self->mod, block, n); +} + +/******** Annotation ********/ +Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) { + Map annotations = block->annotations; + annotations.Set(attr_key, attr_value); + ObjectPtr new_block = make_object(*block); + new_block->annotations = std::move(annotations); + return Block(new_block); +} + /******** Pattern Matcher ********/ /*! diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index e6dc29dd8a..af3aa3e56e 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -361,7 +361,7 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { this->state_->DebugVerify(); } -/******** Schedule: block annotations ********/ +/******** Schedule: block annotation ********/ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index be6d370171..cfdd9c8452 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -88,7 +88,7 @@ class ConcreteScheduleNode : public ScheduleNode { void ReverseComputeInline(const BlockRV& block) override; /******** Schedule: Reduction ********/ BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override; - /******** Schedule: Block annotations ********/ + /******** Schedule: Block annotation ********/ void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) override; /******** Schedule: Blockize & Tensorize ********/ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 05786b84c5..01ee590384 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -104,7 +104,7 @@ TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref * \return The sref of the rfactor block */ TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis); -/******** Schedule: Block annotations ********/ +/******** Schedule: Block annotation ********/ /*! * \brief Set alignment requirement for specific dimension such that * stride[axis] == k * factor + offset for some k diff --git a/src/tir/schedule/primitive/bind_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc similarity index 74% rename from src/tir/schedule/primitive/bind_annotate.cc rename to src/tir/schedule/primitive/block_annotate.cc index 1507a7599e..c573d2834e 100644 --- a/src/tir/schedule/primitive/bind_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -21,21 +21,6 @@ namespace tvm { namespace tir { -/*! - * \brief Create a new block with the given annotation added - * \param block The block with original annotation - * \param attr_key The annotation key to be added - * \param attr_value The annotation value to be added - * \return A new block with the given annotation as its last annotation - */ -Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) { - Map annotations = block->annotations; - annotations.Set(attr_key, attr_value); - ObjectPtr new_block = make_object(*block); - new_block->annotations = std::move(annotations); - return Block(new_block); -} - class StorageAlignAxisOutOfRangeError : public ScheduleError { public: explicit StorageAlignAxisOutOfRangeError(IRModule mod, Buffer buffer, int axis) @@ -78,43 +63,6 @@ class StorageAlignAxisOutOfRangeError : public ScheduleError { int axis_; }; -class WriteBufferIndexOutOfRangeError : public ScheduleError { - public: - explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index) - : mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {} - - String FastErrorString() const final { - return "ScheduleError: The input `buffer_index` is out of range. It is required to be in range " - "[0, num_write_regions) where `num_write_regions` is the number of buffer regions " - "written by the block."; - } - - String DetailRenderTemplate() const final { - std::ostringstream os; - size_t num_writes = block_->writes.size(); - os << "The block {0} has " << num_writes - << " write regions, so `buffer_index` is required to be in [0, " - << num_writes << "). However, the input `buffer_index` is " << buffer_index_ - << ", which is out of the expected range"; - return os.str(); - } - - static Buffer CheckAndGetBuffer(const IRModule& mod, const Block& block, int buffer_index) { - if (buffer_index < 0 || buffer_index > block->writes.size()) { - throw WriteBufferIndexOutOfRangeError(mod, block, buffer_index); - } - return block->writes[buffer_index]->buffer; - } - - IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {mod_}; } - - private: - IRModule mod_; - Block block_; - int buffer_index_; -}; - /*! * \brief Find the defining site of the buffer in the given block and its ancestors * \param block_sref The block sref @@ -189,7 +137,7 @@ class NonAllocatedBufferError : public ScheduleError { class StorageAlignInvalidFactorError : public ScheduleError { public: - explicit StorageAlignInvalidFactorError(const IRModule& mod, int factor) + explicit StorageAlignInvalidFactorError(IRModule mod, int factor) : mod_(std::move(mod)), factor_(factor) {} String FastErrorString() const final { @@ -219,29 +167,86 @@ class StorageAlignInvalidFactorError : public ScheduleError { int factor_; }; +class StorageAlignInvalidAnnotationError : public ScheduleError { + public: + explicit StorageAlignInvalidAnnotationError(IRModule mod, Block block) + : mod_(std::move(mod)), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The block annotation for storage align is expected to be an array of " + "3-integer-tuples (axis, factor, offset)."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The block annotation for storage align is expected to be an array of 3-integer-tuples " + "(axis, factor, offset). However, the block annotation with key " + << attr::buffer_dim_align << " of the block {0} is " + << block_->annotations.at(attr::buffer_dim_align) << ", which is unexpected."; + return os.str(); + } + + static Array>> CheckAndGetAnnotation(const IRModule& mod, + const Block& block) { + // Get existing annotation value. + auto it = block->annotations.find(attr::buffer_dim_align); + if (it != block->annotations.end()) { + if (!IsValidAnnotatoin(block, (*it).second)) { + throw StorageAlignInvalidAnnotationError(mod, block); + } + return Downcast>>>((*it).second); + } + + // Create new annotation value + Array>> storage_align_annotation; + storage_align_annotation.resize(block->writes.size()); + return storage_align_annotation; + } + + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod() const final { return mod_; } + + private: + static bool IsValidAnnotatoin(const Block& block, const ObjectRef& anno_value) { + if (!anno_value->IsInstance()) { + return false; + } + const auto& buffer_annotations = Downcast>(anno_value); + if (buffer_annotations.size() != block->writes.size()) { + return false; + } + for (const ObjectRef buffer_annotation: buffer_annotations) { + if (!buffer_annotation->IsInstance()) { + return false; + } + const auto& dim_annotations = Downcast>(buffer_annotation); + // Check if the annotations are consist of 3-tuples. + if (dim_annotations.size() != 3) { + return false; + } + for (const ObjectRef& dim_anno_element : dim_annotations) { + if (!dim_anno_element->IsInstance()) { + return false; + } + } + } + return true; + } + + IRModule mod_; + Block block_; +}; + void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, int factor, int offset) { const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); - Buffer buffer = WriteBufferIndexOutOfRangeError::CheckAndGetBuffer( - self->mod, GetRef(block_ptr), buffer_index); + Buffer buffer = GetNthWriteBuffer(self, GetRef(block_ptr), buffer_index); StorageAlignInvalidFactorError::Check(self->mod, factor); axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis); NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer); // Step 1: Get existing or create new annotation value. - auto it = block_ptr->annotations.find(attr::buffer_dim_align); - - // Use an array to store the storage alignement information for each output tensor. - // For each output tensor, we use an array of tuples (axis, factor, offset) to specify storage - // alignment for each dimension. - Array>> storage_align_annotation; - - if (it != block_ptr->annotations.end()) { - storage_align_annotation = Downcast>>>((*it).second); - ICHECK(storage_align_annotation.size() == block_ptr->writes.size()); - } else { - storage_align_annotation.resize(block_ptr->writes.size()); - } + auto storage_align_annotation = StorageAlignInvalidAnnotationError::CheckAndGetAnnotation(self->mod, GetRef(block_ptr)); // Step 2: Update the annotation value Array> buffer_storage_align = storage_align_annotation[buffer_index]; diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 2375e5c2f5..0e547a7ec6 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -135,7 +135,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline") /******** (FFI) Reduction ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor") .set_body_method(&ScheduleNode::RFactor); -/******** (FFI) Block annotations ********/ +/******** (FFI) Block annotation ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") .set_body_method(&ScheduleNode::StorageAlign); /******** (FFI) Blockize & Tensorize ********/ diff --git a/tests/python/unittest/test_tir_schedule_storage_align.py b/tests/python/unittest/test_tir_schedule_storage_align.py index 7174c8c4ec..074c639457 100644 --- a/tests/python/unittest/test_tir_schedule_storage_align.py +++ b/tests/python/unittest/test_tir_schedule_storage_align.py @@ -76,6 +76,32 @@ def element_wise_storage_align(a: ty.handle, c: ty.handle) -> None: C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1)) +@tvm.script.tir +def element_wise_invalid_annotation(a: ty.handle, c: ty.handle) -> None: + C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with tir.block([], "root"): + tir.reads([]) + tir.writes([]) + B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in tir.serial(0, 128): + for ax1 in tir.serial(0, 128): + with tir.block([128, 128], "B", annotations={"buffer_dim_align": [0]}) as [vi, vj]: + tir.bind(vi, i0) + tir.bind(vj, ax1) + tir.reads([A[vi, vj]]) + tir.writes([B[vi, vj]]) + B[vi, vj] = (A[vi, vj]*tir.float32(2)) + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "C") as [vi_1, vj_1]: + tir.bind(vi_1, i0) + tir.bind(vj_1, i1) + tir.reads([B[vi_1, vj_1]]) + tir.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1)) + + def test_storage_align(): func = element_wise s = tir.Schedule(func, debug_mode=True) @@ -133,6 +159,14 @@ def test_storage_align_invalid_axis(): s.storage_align(B, 0, axis=2, factor=128, offset=127) +def test_storage_align_invalid_annotation(): + func = element_wise_invalid_annotation + s = tir.Schedule(func, debug_mode=True) + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 0, axis=2, factor=128, offset=127) + + if __name__ == "__main__": test_storage_align() test_storage_align_update() @@ -141,3 +175,4 @@ def test_storage_align_invalid_axis(): test_storage_align_invalid_buffer() test_storage_align_invalid_buffer_index() test_storage_align_invalid_axis() + test_storage_align_invalid_annotation()