Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[MetaSchedule] Migrate MemoryDatabase to C++ (apache#12514)
Browse files Browse the repository at this point in the history
This PR migrates the existing MemoryDatabase, which is implemented in
python at the moment, to C++. The original intent of having an in-memory
database that does not persist on disk is merely for testing, but as
times go on, we found it useful in production workflow, and thus decided
to migrate it C++ for potentially better performance.
  • Loading branch information
junrushao authored and xinetzone committed Nov 25, 2022
1 parent 785d81f commit bd5d979
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 45 deletions.
2 changes: 2 additions & 0 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,8 @@ class PyDatabaseNode : public DatabaseNode {
*/
class Database : public runtime::ObjectRef {
public:
/*! An in-memory database. */
TVM_DLL static Database MemoryDatabase();
/*!
* \brief Create a default database that uses JSON file for tuning records.
* \param path_workload The path to the workload table.
Expand Down
55 changes: 10 additions & 45 deletions python/tvm/meta_schedule/database/memory_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,52 +15,17 @@
# specific language governing permissions and limitations
# under the License.
"""A database that stores TuningRecords in memory"""
from typing import List
from tvm._ffi import register_object

from ...ir import IRModule, structural_equal
from ..utils import derived_object
from .database import PyDatabase, TuningRecord, Workload
from .. import _ffi_api
from .database import Database


@derived_object
class MemoryDatabase(PyDatabase):
"""An in-memory database based on python list for testing."""
@register_object("meta_schedule.MemoryDatabase")
class MemoryDatabase(Database):
"""An in-memory database"""

def __init__(self):
super().__init__()
self.records = []
self.workload_reg = []

def has_workload(self, mod: IRModule) -> bool:
for workload in self.workload_reg:
if structural_equal(workload.mod, mod):
return True
return False

def commit_tuning_record(self, record: TuningRecord) -> None:
self.records.append(record)

def commit_workload(self, mod: IRModule) -> Workload:
for workload in self.workload_reg:
if structural_equal(workload.mod, mod):
return workload
workload = Workload(mod)
self.workload_reg.append(workload)
return workload

def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
return list(
filter(
lambda x: x.workload == workload,
sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
)
)[: int(top_k)]

def get_all_tuning_records(self) -> List[TuningRecord]:
return self.records

def __len__(self) -> int:
return len(self.records)

def print_results(self) -> None:
print("\n".join([str(r) for r in self.records]))
def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.DatabaseMemoryDatabase, # type: ignore # pylint: disable=no-member
)
111 changes: 111 additions & 0 deletions src/meta_schedule/database/memory_database.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

class MemoryDatabaseNode : public DatabaseNode {
public:
Array<TuningRecord> records;
Array<Workload> workloads;

void VisitAttrs(AttrVisitor* v) {
v->Visit("records", &records);
v->Visit("workloads", &workloads);
}

static constexpr const char* _type_key = "meta_schedule.MemoryDatabase";
TVM_DECLARE_FINAL_OBJECT_INFO(MemoryDatabaseNode, DatabaseNode);

public:
bool HasWorkload(const IRModule& mod) final {
for (const auto& workload : workloads) {
if (StructuralEqual()(workload->mod, mod)) {
return true;
}
}
return false;
}

Workload CommitWorkload(const IRModule& mod) {
for (const auto& workload : workloads) {
if (StructuralEqual()(workload->mod, mod)) {
return workload;
}
}
Workload workload(mod, StructuralHash()(mod));
workloads.push_back(workload);
return workload;
}

void CommitTuningRecord(const TuningRecord& record) { records.push_back(record); }

Array<TuningRecord> GetTopK(const Workload& workload, int top_k) {
std::vector<std::pair<double, TuningRecord>> results;
results.reserve(this->records.size());
for (const TuningRecord& record : records) {
if (!record->run_secs.defined()) {
continue;
}
Array<FloatImm> run_secs = record->run_secs.value();
if (run_secs.empty()) {
continue;
}
if (record->workload.same_as(workload)) {
double sum = 0.0;
for (const FloatImm& i : run_secs) {
sum += i->value;
}
results.emplace_back(sum / run_secs.size(), record);
}
}
std::sort(results.begin(), results.end());
auto begin = results.begin();
auto end = results.end();
if (static_cast<int>(results.size()) > top_k) {
end = begin + top_k;
}
Array<TuningRecord> ret;
ret.reserve(end - begin);
while (begin != end) {
ret.push_back(begin->second);
++begin;
}
return ret;
}

Array<TuningRecord> GetAllTuningRecords() { return records; }

int64_t Size() { return records.size(); }
};

Database Database::MemoryDatabase() {
ObjectPtr<MemoryDatabaseNode> n = make_object<MemoryDatabaseNode>();
n->records.clear();
n->workloads.clear();
return Database(n);
}

TVM_REGISTER_NODE_TYPE(MemoryDatabaseNode);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseMemoryDatabase")
.set_body_typed(Database::MemoryDatabase);

} // namespace meta_schedule
} // namespace tvm

0 comments on commit bd5d979

Please sign in to comment.