Skip to content

Commit

Permalink
[TIR][USMP] further changes for extract buffer info
Browse files Browse the repository at this point in the history
* moved the comparison to a lambda
* lint fixes

Change-Id: If917a3ec12d2a5689eb584e0ac5918a39f9ac12e
  • Loading branch information
manupak committed Aug 6, 2021
1 parent 0271916 commit 321e30f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
6 changes: 5 additions & 1 deletion python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,11 @@ def allocate(self, dtype, shape, name="buf", scope="", pinned_memory=""):
buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope))
if not isinstance(shape, (list, tuple, _container.Array)):
shape = [shape]
self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x, pinned_memory))
self.emit(
lambda x: _stmt.Allocate(
buffer_var, dtype, shape, const(1, dtype="uint1"), x, pinned_memory
)
)
return BufferVar(self, buffer_var, shape, dtype)

def pointer(self, content_type, name="ptr", scope=""):
Expand Down
19 changes: 9 additions & 10 deletions src/tir/usmp/analysis/extract_buffer_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,6 @@ Map<tir::Stmt, BufferInfo> BufferInfoExtractor::operator()(const PrimFunc& main_
size_t tick;
LivenessEventType le_type;
Allocate allocate;
bool operator<(const LivenessEvent& other) {
if (tick < other.tick) {
return true;
} else if (tick == other.tick && le_type == START && other.le_type == END) {
return true;
}
return false;
}

bool operator==(const LivenessEvent& other) {
if (tick == other.tick && le_type == other.le_type && allocate == other.allocate) {
return true;
Expand Down Expand Up @@ -252,7 +243,15 @@ Map<tir::Stmt, BufferInfo> BufferInfoExtractor::operator()(const PrimFunc& main_
le_events.push_back(le_event_end);
}

std::sort(le_events.begin(), le_events.end());
std::sort(le_events.begin(), le_events.end(),
[](const LivenessEvent& lhs, const LivenessEvent& rhs) {
if (lhs.tick < rhs.tick) {
return true;
} else if (lhs.tick == rhs.tick && lhs.le_type == START && rhs.le_type == END) {
return true;
}
return false;
});
std::unordered_set<Allocate, ObjectPtrHash, ObjectPtrEqual> open_set;
for (const auto& le_event : le_events) {
if (le_event.le_type == START) {
Expand Down

0 comments on commit 321e30f

Please sign in to comment.