From 321e30fe7aab7bb634a3fcda39ea11c862fde205 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Fri, 6 Aug 2021 17:46:53 +0100 Subject: [PATCH] [TIR][USMP] further changes for extract buffer info * moved the comparison to a lambda * lint fixes Change-Id: If917a3ec12d2a5689eb584e0ac5918a39f9ac12e --- python/tvm/tir/ir_builder.py | 6 +++++- src/tir/usmp/analysis/extract_buffer_info.cc | 19 +++++++++---------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 9abdcb4e5ea43..05ff20de88e1e 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -422,7 +422,11 @@ def allocate(self, dtype, shape, name="buf", scope="", pinned_memory=""): buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope)) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] - self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x, pinned_memory)) + self.emit( + lambda x: _stmt.Allocate( + buffer_var, dtype, shape, const(1, dtype="uint1"), x, pinned_memory + ) + ) return BufferVar(self, buffer_var, shape, dtype) def pointer(self, content_type, name="ptr", scope=""): diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index 7997aec62cdeb..8faae828f9c9e 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -212,15 +212,6 @@ Map BufferInfoExtractor::operator()(const PrimFunc& main_ size_t tick; LivenessEventType le_type; Allocate allocate; - bool operator<(const LivenessEvent& other) { - if (tick < other.tick) { - return true; - } else if (tick == other.tick && le_type == START && other.le_type == END) { - return true; - } - return false; - } - bool operator==(const LivenessEvent& other) { if (tick == other.tick && le_type == other.le_type && allocate == other.allocate) { return true; @@ -252,7 +243,15 @@ Map BufferInfoExtractor::operator()(const PrimFunc& main_ le_events.push_back(le_event_end); } - std::sort(le_events.begin(), le_events.end()); + std::sort(le_events.begin(), le_events.end(), + [](const LivenessEvent& lhs, const LivenessEvent& rhs) { + if (lhs.tick < rhs.tick) { + return true; + } else if (lhs.tick == rhs.tick && lhs.le_type == START && rhs.le_type == END) { + return true; + } + return false; + }); std::unordered_set open_set; for (const auto& le_event : le_events) { if (le_event.le_type == START) {