diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 202b5283c0..f0649e2541 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -729,6 +729,52 @@ def storage_align(self, block: BlockRV, buffer_index: int, axis: int, factor: in The factor multiple of alignment. offset : int The required offset factor. + + Examples + -------- + + Before storage_align, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_storage_align(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + Create the schedule and do storage_align: + + .. code-block:: python + + sch = tir.Schedule(before_storage_align) + sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0, factor=128, offset=1) + print(tvm.script.asscript(sch.mod["main"])) + + After applying rfactor, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_storage_align(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + tir.block_attr({"buffer_dim_align": [[[0, 128, 1]]]}) + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + After lowering passes, buffer B will have strides as [129, 1]. + + Note + ---- + Storage_align requires the buffer to be an intermediate buffer defined via `alloc_buffer`. """ _ffi_api.ScheduleStorageAlign(self, block, buffer_index, axis, factor, offset) # type: ignore # pylint: disable=no-member diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index e7a419f991..5c73e275d7 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -552,7 +552,7 @@ Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) { } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {mod_}; } + Array LocationsOfInterest() const final { return {block_}; } private: IRModule mod_;