Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Aug 5, 2021
1 parent d49ac86 commit 2912342
Show file tree
Hide file tree
Showing 10 changed files with 189 additions and 74 deletions.
2 changes: 1 addition & 1 deletion include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class ScheduleNode : public runtime::Object {
* \return The rfactor block
*/
virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
/******** Schedule: Block annotations ********/
/******** Schedule: Block annotation ********/
/*!
* \brief Set alignment requirement for specific dimension such that
* stride[axis] == k * factor + offset for some k.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ def after_rfactor(a: ty.handle, b: ty.handle) -> None:
"""
return _ffi_api.ScheduleRFactor(self, loop, factor_axis) # type: ignore # pylint: disable=no-member

######## Schedule: Block annotatoins ########
######## Schedule: Block annotatoin ########

def storage_align(self, block: BlockRV, buffer_index: int, axis: int, factor: int,
offset: int) -> None:
Expand Down
23 changes: 23 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,18 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self
*/
BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref);

/******** Block-block relation ********/

/*!
* \brief Get the BlockRealize of the single child block of the block or loop specified by
* `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks
* \param self The schedule state
* \param block The queried block
* \return The buffer of the n-th write region of the block.
* \throw ScheduleError If the buffer index is out of bound.
*/
Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n);

/******** Commutative Reducer ********/

/*!
Expand All @@ -224,6 +236,17 @@ std::vector<TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters();
bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,
CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs);

/******** Annotation ********/

/*!
* \brief Create a new block with the given annotation added
* \param block The block with original annotation
* \param attr_key The annotation key to be added
* \param attr_value The annotation value to be added
* \return A new block with the given annotation as its last annotation
*/
Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value);

} // namespace tir
} // namespace tvm

Expand Down
51 changes: 51 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,57 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
}
}

/******** Block-buffer relation ********/

Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) {
class WriteBufferIndexOutOfRangeError : public ScheduleError {
public:
explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index)
: mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {}

String FastErrorString() const final {
return "ScheduleError: The input `buffer_index` is out of range. It is required to be in range "
"[0, num_write_regions) where `num_write_regions` is the number of buffer regions "
"written by the block.";
}

String DetailRenderTemplate() const final {
std::ostringstream os;
size_t num_writes = block_->writes.size();
os << "The block {0} has " << num_writes
<< " write regions, so `buffer_index` is required to be in [0, "
<< num_writes << "). However, the input `buffer_index` is " << buffer_index_
<< ", which is out of the expected range";
return os.str();
}

static Buffer CheckAndGetBuffer(const IRModule& mod, const Block& block, int buffer_index) {
if (buffer_index < 0 || buffer_index > block->writes.size()) {
throw WriteBufferIndexOutOfRangeError(mod, block, buffer_index);
}
return block->writes[buffer_index]->buffer;
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {mod_}; }

private:
IRModule mod_;
Block block_;
int buffer_index_;
};
return WriteBufferIndexOutOfRangeError::CheckAndGetBuffer(self->mod, block, n);
}

/******** Annotation ********/
Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) {
Map<String, ObjectRef> annotations = block->annotations;
annotations.Set(attr_key, attr_value);
ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block);
new_block->annotations = std::move(annotations);
return Block(new_block);
}

/******** Pattern Matcher ********/

/*!
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
this->state_->DebugVerify();
}

/******** Schedule: block annotations ********/
/******** Schedule: block annotation ********/

void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis,
int factor, int offset) {
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class ConcreteScheduleNode : public ScheduleNode {
void ReverseComputeInline(const BlockRV& block) override;
/******** Schedule: Reduction ********/
BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
/******** Schedule: Block annotations ********/
/******** Schedule: Block annotation ********/
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) override;
/******** Schedule: Blockize & Tensorize ********/
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref
* \return The sref of the rfactor block
*/
TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis);
/******** Schedule: Block annotations ********/
/******** Schedule: Block annotation ********/
/*!
* \brief Set alignment requirement for specific dimension such that
* stride[axis] == k * factor + offset for some k
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,6 @@
namespace tvm {
namespace tir {

/*!
* \brief Create a new block with the given annotation added
* \param block The block with original annotation
* \param attr_key The annotation key to be added
* \param attr_value The annotation value to be added
* \return A new block with the given annotation as its last annotation
*/
Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) {
Map<String, ObjectRef> annotations = block->annotations;
annotations.Set(attr_key, attr_value);
ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block);
new_block->annotations = std::move(annotations);
return Block(new_block);
}

class StorageAlignAxisOutOfRangeError : public ScheduleError {
public:
explicit StorageAlignAxisOutOfRangeError(IRModule mod, Buffer buffer, int axis)
Expand Down Expand Up @@ -78,43 +63,6 @@ class StorageAlignAxisOutOfRangeError : public ScheduleError {
int axis_;
};

class WriteBufferIndexOutOfRangeError : public ScheduleError {
public:
explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index)
: mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {}

String FastErrorString() const final {
return "ScheduleError: The input `buffer_index` is out of range. It is required to be in range "
"[0, num_write_regions) where `num_write_regions` is the number of buffer regions "
"written by the block.";
}

String DetailRenderTemplate() const final {
std::ostringstream os;
size_t num_writes = block_->writes.size();
os << "The block {0} has " << num_writes
<< " write regions, so `buffer_index` is required to be in [0, "
<< num_writes << "). However, the input `buffer_index` is " << buffer_index_
<< ", which is out of the expected range";
return os.str();
}

static Buffer CheckAndGetBuffer(const IRModule& mod, const Block& block, int buffer_index) {
if (buffer_index < 0 || buffer_index > block->writes.size()) {
throw WriteBufferIndexOutOfRangeError(mod, block, buffer_index);
}
return block->writes[buffer_index]->buffer;
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {mod_}; }

private:
IRModule mod_;
Block block_;
int buffer_index_;
};

/*!
* \brief Find the defining site of the buffer in the given block and its ancestors
* \param block_sref The block sref
Expand Down Expand Up @@ -189,7 +137,7 @@ class NonAllocatedBufferError : public ScheduleError {

class StorageAlignInvalidFactorError : public ScheduleError {
public:
explicit StorageAlignInvalidFactorError(const IRModule& mod, int factor)
explicit StorageAlignInvalidFactorError(IRModule mod, int factor)
: mod_(std::move(mod)), factor_(factor) {}

String FastErrorString() const final {
Expand Down Expand Up @@ -219,29 +167,86 @@ class StorageAlignInvalidFactorError : public ScheduleError {
int factor_;
};

class StorageAlignInvalidAnnotationError : public ScheduleError {
public:
explicit StorageAlignInvalidAnnotationError(IRModule mod, Block block)
: mod_(std::move(mod)), block_(std::move(block)) {}

String FastErrorString() const final {
return "ScheduleError: The block annotation for storage align is expected to be an array of "
"3-integer-tuples (axis, factor, offset).";
}

String DetailRenderTemplate() const final {
std::ostringstream os;
os << "The block annotation for storage align is expected to be an array of 3-integer-tuples "
"(axis, factor, offset). However, the block annotation with key "
<< attr::buffer_dim_align << " of the block {0} is "
<< block_->annotations.at(attr::buffer_dim_align) << ", which is unexpected.";
return os.str();
}

static Array<Array<Array<Integer>>> CheckAndGetAnnotation(const IRModule& mod,
const Block& block) {
// Get existing annotation value.
auto it = block->annotations.find(attr::buffer_dim_align);
if (it != block->annotations.end()) {
if (!IsValidAnnotatoin(block, (*it).second)) {
throw StorageAlignInvalidAnnotationError(mod, block);
}
return Downcast<Array<Array<Array<Integer>>>>((*it).second);
}

// Create new annotation value
Array<Array<Array<Integer>>> storage_align_annotation;
storage_align_annotation.resize(block->writes.size());
return storage_align_annotation;
}

Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
IRModule mod() const final { return mod_; }

private:
static bool IsValidAnnotatoin(const Block& block, const ObjectRef& anno_value) {
if (!anno_value->IsInstance<ArrayNode>()) {
return false;
}
const auto& buffer_annotations = Downcast<Array<ObjectRef>>(anno_value);
if (buffer_annotations.size() != block->writes.size()) {
return false;
}
for (const ObjectRef buffer_annotation: buffer_annotations) {
if (!buffer_annotation->IsInstance<ArrayNode>()) {
return false;
}
const auto& dim_annotations = Downcast<Array<ObjectRef>>(buffer_annotation);
// Check if the annotations are consist of 3-tuples.
if (dim_annotations.size() != 3) {
return false;
}
for (const ObjectRef& dim_anno_element : dim_annotations) {
if (!dim_anno_element->IsInstance<IntImmNode>()) {
return false;
}
}
}
return true;
}

IRModule mod_;
Block block_;
};

void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis,
int factor, int offset) {
const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
Buffer buffer = WriteBufferIndexOutOfRangeError::CheckAndGetBuffer(
self->mod, GetRef<Block>(block_ptr), buffer_index);
Buffer buffer = GetNthWriteBuffer(self, GetRef<Block>(block_ptr), buffer_index);
StorageAlignInvalidFactorError::Check(self->mod, factor);
axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis);
NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer);

// Step 1: Get existing or create new annotation value.
auto it = block_ptr->annotations.find(attr::buffer_dim_align);

// Use an array to store the storage alignement information for each output tensor.
// For each output tensor, we use an array of tuples (axis, factor, offset) to specify storage
// alignment for each dimension.
Array<Array<Array<Integer>>> storage_align_annotation;

if (it != block_ptr->annotations.end()) {
storage_align_annotation = Downcast<Array<Array<Array<Integer>>>>((*it).second);
ICHECK(storage_align_annotation.size() == block_ptr->writes.size());
} else {
storage_align_annotation.resize(block_ptr->writes.size());
}
auto storage_align_annotation = StorageAlignInvalidAnnotationError::CheckAndGetAnnotation(self->mod, GetRef<Block>(block_ptr));

// Step 2: Update the annotation value
Array<Array<Integer>> buffer_storage_align = storage_align_annotation[buffer_index];
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline")
/******** (FFI) Reduction ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor")
.set_body_method<Schedule>(&ScheduleNode::RFactor);
/******** (FFI) Block annotations ********/
/******** (FFI) Block annotation ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign")
.set_body_method<Schedule>(&ScheduleNode::StorageAlign);
/******** (FFI) Blockize & Tensorize ********/
Expand Down
36 changes: 36 additions & 0 deletions tests/python/unittest/test_tir_schedule_storage_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,33 @@ def element_wise_storage_align(a: ty.handle, c: ty.handle) -> None:
C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1))


@tvm.script.tir
def element_wise_invalid_annotation(a: ty.handle, c: ty.handle) -> None:
C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
# body
with tir.block([], "root"):
tir.reads([])
tir.writes([])
B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
for i0 in tir.serial(0, 128):
for ax1 in tir.serial(0, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.block_attr({"buffer_dim_align": [0]})
tir.bind(vi, i0)
tir.bind(vj, ax1)
tir.reads([A[vi, vj]])
tir.writes([B[vi, vj]])
B[vi, vj] = (A[vi, vj]*tir.float32(2))
for i1 in tir.serial(0, 128):
with tir.block([128, 128], "C") as [vi_1, vj_1]:
tir.bind(vi_1, i0)
tir.bind(vj_1, i1)
tir.reads([B[vi_1, vj_1]])
tir.writes([C[vi_1, vj_1]])
C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1))


def test_storage_align():
func = element_wise
s = tir.Schedule(func, debug_mode=True)
Expand Down Expand Up @@ -133,6 +160,14 @@ def test_storage_align_invalid_axis():
s.storage_align(B, 0, axis=2, factor=128, offset=127)


def test_storage_align_invalid_annotation():
func = element_wise_invalid_annotation
s = tir.Schedule(func, debug_mode=True)
B = s.get_block("B")
with pytest.raises(tir.ScheduleError):
s.storage_align(B, 0, axis=2, factor=128, offset=127)


if __name__ == "__main__":
test_storage_align()
test_storage_align_update()
Expand All @@ -141,3 +176,4 @@ def test_storage_align_invalid_axis():
test_storage_align_invalid_buffer()
test_storage_align_invalid_buffer_index()
test_storage_align_invalid_axis()
test_storage_align_invalid_annotation()

0 comments on commit 2912342

Please sign in to comment.