diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index b22d8beddbab..1c260d9d748a 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -313,6 +313,8 @@ class PyDatabaseNode : public DatabaseNode { */ class Database : public runtime::ObjectRef { public: + /*! An in-memory database. */ + TVM_DLL static Database MemoryDatabase(); /*! * \brief Create a default database that uses JSON file for tuning records. * \param path_workload The path to the workload table. diff --git a/python/tvm/meta_schedule/database/memory_database.py b/python/tvm/meta_schedule/database/memory_database.py index 95d937cc77aa..f50e5a1afa94 100644 --- a/python/tvm/meta_schedule/database/memory_database.py +++ b/python/tvm/meta_schedule/database/memory_database.py @@ -15,52 +15,17 @@ # specific language governing permissions and limitations # under the License. """A database that stores TuningRecords in memory""" -from typing import List +from tvm._ffi import register_object -from ...ir import IRModule, structural_equal -from ..utils import derived_object -from .database import PyDatabase, TuningRecord, Workload +from .. import _ffi_api +from .database import Database -@derived_object -class MemoryDatabase(PyDatabase): - """An in-memory database based on python list for testing.""" +@register_object("meta_schedule.MemoryDatabase") +class MemoryDatabase(Database): + """An in-memory database""" - def __init__(self): - super().__init__() - self.records = [] - self.workload_reg = [] - - def has_workload(self, mod: IRModule) -> bool: - for workload in self.workload_reg: - if structural_equal(workload.mod, mod): - return True - return False - - def commit_tuning_record(self, record: TuningRecord) -> None: - self.records.append(record) - - def commit_workload(self, mod: IRModule) -> Workload: - for workload in self.workload_reg: - if structural_equal(workload.mod, mod): - return workload - workload = Workload(mod) - self.workload_reg.append(workload) - return workload - - def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: - return list( - filter( - lambda x: x.workload == workload, - sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), - ) - )[: int(top_k)] - - def get_all_tuning_records(self) -> List[TuningRecord]: - return self.records - - def __len__(self) -> int: - return len(self.records) - - def print_results(self) -> None: - print("\n".join([str(r) for r in self.records])) + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.DatabaseMemoryDatabase, # type: ignore # pylint: disable=no-member + ) diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc new file mode 100644 index 000000000000..a00d5501ad1d --- /dev/null +++ b/src/meta_schedule/database/memory_database.cc @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class MemoryDatabaseNode : public DatabaseNode { + public: + Array records; + Array workloads; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("records", &records); + v->Visit("workloads", &workloads); + } + + static constexpr const char* _type_key = "meta_schedule.MemoryDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(MemoryDatabaseNode, DatabaseNode); + + public: + bool HasWorkload(const IRModule& mod) final { + for (const auto& workload : workloads) { + if (StructuralEqual()(workload->mod, mod)) { + return true; + } + } + return false; + } + + Workload CommitWorkload(const IRModule& mod) { + for (const auto& workload : workloads) { + if (StructuralEqual()(workload->mod, mod)) { + return workload; + } + } + Workload workload(mod, StructuralHash()(mod)); + workloads.push_back(workload); + return workload; + } + + void CommitTuningRecord(const TuningRecord& record) { records.push_back(record); } + + Array GetTopK(const Workload& workload, int top_k) { + std::vector> results; + results.reserve(this->records.size()); + for (const TuningRecord& record : records) { + if (!record->run_secs.defined()) { + continue; + } + Array run_secs = record->run_secs.value(); + if (run_secs.empty()) { + continue; + } + if (record->workload.same_as(workload)) { + double sum = 0.0; + for (const FloatImm& i : run_secs) { + sum += i->value; + } + results.emplace_back(sum / run_secs.size(), record); + } + } + std::sort(results.begin(), results.end()); + auto begin = results.begin(); + auto end = results.end(); + if (static_cast(results.size()) > top_k) { + end = begin + top_k; + } + Array ret; + ret.reserve(end - begin); + while (begin != end) { + ret.push_back(begin->second); + ++begin; + } + return ret; + } + + Array GetAllTuningRecords() { return records; } + + int64_t Size() { return records.size(); } +}; + +Database Database::MemoryDatabase() { + ObjectPtr n = make_object(); + n->records.clear(); + n->workloads.clear(); + return Database(n); +} + +TVM_REGISTER_NODE_TYPE(MemoryDatabaseNode); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseMemoryDatabase") + .set_body_typed(Database::MemoryDatabase); + +} // namespace meta_schedule +} // namespace tvm