Skip to content

Commit

Permalink
[M2a] Storage align (#422)
Browse files Browse the repository at this point in the history
* Storage align

* Update src/tir/schedule/primitive/bind_annotate.cc

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>

* remove blank lines

* address comments

* Update python/tvm/tir/schedule/schedule.py

Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

* -

* lint

* -

* Update src/tir/schedule/analysis.h

Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

* Update src/tir/schedule/analysis.h

Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

* example

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
  • Loading branch information
3 people authored Aug 9, 2021
1 parent 3145867 commit 95805ae
Show file tree
Hide file tree
Showing 12 changed files with 795 additions and 3 deletions.
12 changes: 12 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,18 @@ class ScheduleNode : public runtime::Object {
* \return The rfactor block
*/
virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
/******** Schedule: Block annotation ********/
/*!
* \brief Set alignment requirement for specific dimension such that
* stride[axis] == k * factor + offset for some k.
* \param block_rv The producer block of the buffer
* \param buffer_index The index of the buffer in block's write region
* \param axis The dimension to be specified for alignment
* \param factor The factor multiple of alignment
* \param offset The required offset factor
*/
virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) = 0;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/******** Schedule: Misc ********/
Expand Down
68 changes: 68 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,74 @@ def after_rfactor(a: ty.handle, b: ty.handle) -> None:
"""
return _ffi_api.ScheduleRFactor(self, loop, factor_axis) # type: ignore # pylint: disable=no-member

######## Schedule: Block annotatoin ########

def storage_align(self, block: BlockRV, buffer_index: int, axis: int, factor: int,
offset: int) -> None:
"""Set alignment requirement for specific dimension such that
stride[axis] == k * factor + offset for some k.
Parameters
----------
block : BlockRV
The producer block of the buffer.
buffer_index : int
The index of the buffer in block's write region.
axis : int
The dimension to be specified for alignment.
factor : int
The factor multiple of alignment.
offset : int
The required offset factor.
Examples
--------
Before storage_align, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_storage_align(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.alloc_buffer((128, 128))
C = tir.match_buffer(c, (128, 128))
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0
with tir.block([128, 128], "C") as [vi, vj]:
C[vi, vj] = B[vi, vj] + 1.0
Create the schedule and do storage_align:
.. code-block:: python
sch = tir.Schedule(before_storage_align)
sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0, factor=128, offset=1)
print(tvm.script.asscript(sch.mod["main"]))
After applying rfactor, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_storage_align(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.alloc_buffer((128, 128))
C = tir.match_buffer(c, (128, 128))
with tir.block([128, 128], "B") as [vi, vj]:
tir.block_attr({"buffer_dim_align": [[[0, 128, 1]]]})
B[vi, vj] = A[vi, vj] * 2.0
with tir.block([128, 128], "C") as [vi, vj]:
C[vi, vj] = B[vi, vj] + 1.0
After lowering passes, buffer B will have strides as [129, 1].
Note
----
Storage_align requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
"""
_ffi_api.ScheduleStorageAlign(self, block, buffer_index, axis, factor, offset) # type: ignore # pylint: disable=no-member

########## Schedule: Blockize & Tensorize ##########

########## Schedule: Annotation ##########
Expand Down
24 changes: 24 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,19 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self
*/
BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref);

/******** Block-buffer relation ********/

/*!
* \brief Get the BlockRealize of the single child block of the block or loop specified by
* `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks
* \param self The schedule state
* \param block The queried block
* \param n The index of the queried buffer
* \return The buffer of the n-th write region of the block.
* \throw ScheduleError If the buffer index is out of bound.
*/
Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n);

/******** Commutative Reducer ********/

/*!
Expand All @@ -224,6 +237,17 @@ std::vector<TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters();
bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,
CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs);

/******** Annotation ********/

/*!
* \brief Create a new block with the given annotation added
* \param block The block with original annotation
* \param attr_key The annotation key to be added
* \param attr_value The annotation value to be added
* \return A new block with the given annotation as its last annotation
*/
Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value);

} // namespace tir
} // namespace tvm

Expand Down
48 changes: 48 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,54 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
}
}

/******** Block-buffer relation ********/

Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) {
class WriteBufferIndexOutOfRangeError : public ScheduleError {
public:
explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index)
: mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {}

String FastErrorString() const final {
return "ScheduleError: The input `buffer_index` is out of range. It is required to be in range "
"[0, num_write_regions) where `num_write_regions` is the number of buffer regions "
"written by the block.";
}

String DetailRenderTemplate() const final {
std::ostringstream os;
size_t num_writes = block_->writes.size();
os << "The block {0} has " << num_writes
<< " write regions, so `buffer_index` is required to be in [0, "
<< num_writes << "). However, the input `buffer_index` is " << buffer_index_
<< ", which is out of the expected range";
return os.str();
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

private:
IRModule mod_;
Block block_;
int buffer_index_;
};

if (n < 0 || n > block->writes.size()) {
throw WriteBufferIndexOutOfRangeError(self->mod, block, n);
}
return block->writes[n]->buffer;
}

/******** Annotation ********/
Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) {
Map<String, ObjectRef> annotations = block->annotations;
annotations.Set(attr_key, attr_value);
ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block);
new_block->annotations = std::move(annotations);
return Block(new_block);
}

/******** Pattern Matcher ********/

/*!
Expand Down
10 changes: 10 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,16 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
}

/******** Schedule: loop binding/annotation ********/
/******** Schedule: block annotation ********/

void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis,
int factor, int offset) {
TVM_TIR_SCHEDULE_BEGIN();
tir::StorageAlign(state_, this->GetSRef(block_rv), buffer_index, axis, factor, offset);
TVM_TIR_SCHEDULE_END("storage-align", this->error_render_level_);
this->state_->DebugVerify();
}

/******** Schedule: cache read/write ********/
/******** Schedule: reduction ********/

Expand Down
3 changes: 3 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ class ConcreteScheduleNode : public ScheduleNode {
void ReverseComputeInline(const BlockRV& block) override;
/******** Schedule: Reduction ********/
BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
/******** Schedule: Block annotation ********/
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) override;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/******** Schedule: Misc ********/
Expand Down
12 changes: 12 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref
* \return The sref of the rfactor block
*/
TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis);
/******** Schedule: Block annotation ********/
/*!
* \brief Set alignment requirement for specific dimension such that
* stride[axis] == k * factor + offset for some k
* \param block_sref The producer block of the buffer
* \param buffer_index The index of the buffer in block's write region
* \param axis The dimension to be specified for alignment
* \param factor The factor multiple of alignment
* \param offset The required offset factor
*/
TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
int axis, int factor, int offset);
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/******** Schedule: Misc ********/
Expand Down
Loading

0 comments on commit 95805ae

Please sign in to comment.