Skip to content

Commit

Permalink
example
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Aug 9, 2021
1 parent c8d0f59 commit 1bf6f27
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
46 changes: 46 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,52 @@ def storage_align(self, block: BlockRV, buffer_index: int, axis: int, factor: in
The factor multiple of alignment.
offset : int
The required offset factor.
Examples
--------
Before storage_align, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_storage_align(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.alloc_buffer((128, 128))
C = tir.match_buffer(c, (128, 128))
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0
with tir.block([128, 128], "C") as [vi, vj]:
C[vi, vj] = B[vi, vj] + 1.0
Create the schedule and do storage_align:
.. code-block:: python
sch = tir.Schedule(before_storage_align)
sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0, factor=128, offset=1)
print(tvm.script.asscript(sch.mod["main"]))
After applying rfactor, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_storage_align(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.alloc_buffer((128, 128))
C = tir.match_buffer(c, (128, 128))
with tir.block([128, 128], "B") as [vi, vj]:
tir.block_attr({"buffer_dim_align": [[[0, 128, 1]]]})
B[vi, vj] = A[vi, vj] * 2.0
with tir.block([128, 128], "C") as [vi, vj]:
C[vi, vj] = B[vi, vj] + 1.0
After lowering passes, buffer B will have strides as [129, 1].
Note
----
Storage_align requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
"""
_ffi_api.ScheduleStorageAlign(self, block, buffer_index, axis, factor, offset) # type: ignore # pylint: disable=no-member

Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) {
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {mod_}; }
Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

private:
IRModule mod_;
Expand Down

0 comments on commit 1bf6f27

Please sign in to comment.