diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 37a315bf744e9..ad6454161b65b 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -187,6 +187,11 @@ class DatabaseNode : public runtime::Object { * \return An array of top K tuning records for the given workload. */ virtual Array GetTopK(const Workload& workload, int top_k) = 0; + /*! + * \brief Get all tuning records from the database. + * \return An Array of all the tuning records in the database. + */ + virtual Array GetAllTuningRecords() = 0; /*! * \brief Get the size of the database. * \return The size of the database. @@ -224,6 +229,11 @@ class PyDatabaseNode : public DatabaseNode { * \return An array of top K tuning records for the given workload. */ using FGetTopK = runtime::TypedPackedFunc(const Workload&, int)>; + /*! + * \brief The function type of `GetAllTuningRecords` method. + * \return An Array of all the tuning records in the database. + */ + using FGetAllTuningRecords = runtime::TypedPackedFunc()>; /*! * \brief The function type of `Size` method. * \return The size of the database. @@ -238,6 +248,8 @@ class PyDatabaseNode : public DatabaseNode { FCommitTuningRecord f_commit_tuning_record; /*! \brief The packed function to the `GetTopK` function. */ FGetTopK f_get_top_k; + /*! \brief The packed function to the `GetAllTuningRecords` function. */ + FGetAllTuningRecords f_get_all_tuning_records; /*! \brief The packed function to the `Size` function. */ FSize f_size; @@ -273,6 +285,12 @@ class PyDatabaseNode : public DatabaseNode { return f_get_top_k(workload, top_k); } + Array GetAllTuningRecords() final { + ICHECK(f_get_all_tuning_records != nullptr) + << "PyDatabase's GetAllTuningRecords method not implemented!"; + return f_get_all_tuning_records(); + } + int64_t Size() final { ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!"; return f_size(); @@ -309,6 +327,7 @@ class Database : public runtime::ObjectRef { PyDatabaseNode::FCommitWorkload f_commit_workload, PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, PyDatabaseNode::FGetTopK f_get_top_k, + PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records, PyDatabaseNode::FSize f_size); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode); }; diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 3d732e7fbd992..cdc82502d19b8 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -99,6 +99,9 @@ class TuneContextNode : public runtime::Object { /*! \brief Initialize members that needs initialization with tune context. */ void Initialize(); + /*! \brief Construct the measure candidate given initial IR module and trace. */ + MeasureCandidate _GetMeasureCandidate( + const IRModule& mod, const tir::Trace& trace); /*! \brief Set the measure candidates from the SearchStrategy */ void _SetMeasureCandidates(const Array& candidates); /*! diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index 802a739e69582..c1d749852550f 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -203,6 +203,16 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: """ return _ffi_api.DatabaseGetTopK(self, workload, top_k) # type: ignore # pylint: disable=no-member + def get_all_tuning_records(self) -> List[TuningRecord]: + """Get all the tuning records from the database. + + Returns + ------- + tuning_records : List[TuningRecord] + All tuning records from the database. + """ + return _ffi_api.DatabaseGetAllTuningRecords(self) # type: ignore # pylint: disable=no-member + def __len__(self) -> int: """Get the number of records in the database. @@ -229,6 +239,7 @@ def __init__( f_commit_workload: Callable = None, f_commit_tuning_record: Callable = None, f_get_top_k: Callable = None, + f_get_all_tuning_records : Callable = None, f_size: Callable = None, ): """Constructor.""" @@ -239,6 +250,7 @@ def __init__( f_commit_workload, f_commit_tuning_record, f_get_top_k, + f_get_all_tuning_records, f_size, ) @@ -258,6 +270,7 @@ class PyDatabase: "commit_workload", "commit_tuning_record", "get_top_k", + "get_all_tuning_records", "__len__", ], } @@ -317,6 +330,16 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: """ raise NotImplementedError + def get_all_tuning_records(self) -> List[TuningRecord]: + """Get all the tuning records from the database. + + Returns + ------- + tuning_records : List[TuningRecord] + All tuning records from the database. + """ + raise NotImplementedError + def __len__(self) -> int: """Get the number of records in the database. diff --git a/python/tvm/meta_schedule/testing/dataset_sample_candidates.py b/python/tvm/meta_schedule/testing/dataset_sample_candidates.py index c80d78173e2e4..35b872e7351e8 100644 --- a/python/tvm/meta_schedule/testing/dataset_sample_candidates.py +++ b/python/tvm/meta_schedule/testing/dataset_sample_candidates.py @@ -103,6 +103,14 @@ def sample_candidates(task, task_name, model_name): ------- None """ + candidate_path = os.path.join( + args.candidate_cache_dir, model_name, task_name + "_candidates.json" + ) + workload_path = os.path.join(args.candidate_cache_dir, model_name, task_name + "_workload.json") + database = ms.database.JSONDatabase( + path_workload=workload_path, + path_tuning_record=candidate_path, + ) sample_init_population = tvm.get_global_func( "meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation" ) @@ -128,7 +136,7 @@ def sample_candidates(task, task_name, model_name): context.initialize() context.pre_tuning( context.generate_design_space(), - database=ms.database.MemoryDatabase(), # type: ignore + database=database, cost_model=ms.cost_model.RandomModel(), # type: ignore ) @@ -148,16 +156,9 @@ def sample_candidates(task, task_name, model_name): all_states = all_states[: args.num_samples_per_task] workload = ms.database.Workload(context.mod) - file_path = os.path.join(args.candidate_cache_dir, model_name, task_name + ".json") - with open(file_path, "w", encoding="utf8") as file: - for i, state in enumerate(all_states): - tuning_record = ms.database.TuningRecord(state.trace, workload) - json_str = json.dumps(tuning_record.as_json()) - assert "\n" not in json_str, "Failed to generate single line string." - if i == len(all_states) - 1: - file.write(json_str) - else: - file.write(json_str + "\n") + database.commit_workload(context.mod) + for state in all_states: + database.commit_tuning_record(ms.database.TuningRecord(state.trace, workload)) args = _parse_args() # pylint: disable=invalid-name diff --git a/python/tvm/meta_schedule/testing/distributed_measure_candidates.py b/python/tvm/meta_schedule/testing/distributed_measure_candidates.py new file mode 100644 index 0000000000000..739fdf2312bde --- /dev/null +++ b/python/tvm/meta_schedule/testing/distributed_measure_candidates.py @@ -0,0 +1,196 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring + +import argparse +import glob +import os +import time + +from tqdm import tqdm # type: ignore +from tvm import meta_schedule as ms +from tvm.target import Target + + +def _parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--candidate_cache_dir", type=str, help="Please provide the full path to the candidates." + ) + parser.add_argument( + "--result_cache_dir", type=str, help="Please provide the full path to the result database." + ) + parser.add_argument( + "--target", + type=str, + default="nvidia/nvidia-v100", + help="Please specify the target hardware for tuning context.", + ) + parser.add_argument( + "--rpc_host", type=str, help="Please provide the private IPv4 address for the tracker." + ) + parser.add_argument( + "--rpc_port", type=int, default=4445, help="Please provide the port for the tracker." + ) + parser.add_argument( + "--rpc_key", + type=str, + default="p3.2xlarge", + help="Please provide the key for the rpc servers.", + ) + parser.add_argument( + "--builder_timeout_sec", + type=int, + default=10, + help="The time for the builder session to time out.", + ) + parser.add_argument( + "--min_repeat_ms", type=int, default=100, help="The time for preheating the gpu." + ) + parser.add_argument( + "--runner_timeout_sec", + type=int, + default=100, + help="The time for the runner session to time out.", + ) + return parser.parse_args() + + +# pylint: disable=too-many-locals +def measure_candidates(database, builder, runner): + """Send the candidates to builder and runner for distributed measurement, + and save the results in a new json database. + + Parameters + ---------- + database : JSONDatabase + The database for candidates to be measured. + builder : Builder + The builder for building the candidates. + runner : Runner + The runner for measuring the candidates. + + Returns + ------- + None + """ + context = ms.TuneContext(target=Target(args.target)) + tuning_records = database.get_all_tuning_records() + candidates = [] + for record in tuning_records: + candidate = context.get_measure_candidate(record.workload.mod, record.trace) + candidates.append(candidate) + context.set_measure_candidates(candidates) + + build_start_time = time.time() + context.send_to_builder(builder) + build_end_time = time.time() + build_fail_indices = [] + for i, result in enumerate(context.builder_results): + if result.error_msg is not None: + build_fail_indices.append(i) + print( + f"Builder time: {build_end_time - build_start_time}\n\ + Failed number of builds: {len(build_fail_indices)}" + ) + + context.send_to_runner(runner) + runner_results = context.join() + for result in context.builder_results: + ms.utils.remove_build_dir(result.artifact_path) + context.clear_measure_state() + run_end_time = time.time() + run_fail_indices = [] + for i, result in enumerate(runner_results): + if result.error_msg is not None: + run_fail_indices.append(i) + print( + f"Runner time: {run_end_time - build_end_time}\n\ + Failed number of runs: {len(run_fail_indices)}" + ) + + model_name, workload_name = database.path_workload.split("/")[-2:] + record_name = database.path_tuning_record.split("/")[-1] + new_database = ms.database.JSONDatabase( + path_workload=os.path.join(args.result_cache_dir, model_name, workload_name), + path_tuning_record=os.path.join(args.result_cache_dir, model_name, record_name), + ) + workload = tuning_records[0].workload + new_database.commit_workload(workload.mod) + for record, result in zip(tuning_records, runner_results): + if result.error_msg is None: + new_database.commit_tuning_record( + ms.database.TuningRecord( + trace=record.trace, + workload=workload, + run_secs=[v.value for v in result.run_secs], + target=Target(args.target), + ) + ) + fail_indices_name = workload_name[:-13] + "failed_indices.txt" + with open( + os.path.join(args.result_cache_dir, model_name, fail_indices_name), "w", encoding="utf8" + ) as file: + file.write(" ".join(run_fail_indices)) + + +args = _parse_args() # pylint: disable=invalid-name + + +def main(): + builder = ms.builder.LocalBuilder(timeout_sec=args.builder_timeout_sec) + runner = ms.runner.RPCRunner( + rpc_config=ms.runner.RPCConfig( + tracker_host=args.rpc_host, + tracker_port=args.rpc_port, + tracker_key=args.rpc_key, + session_timeout_sec=args.runner_timeout_sec, + ), + evaluator_config=ms.runner.EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=args.min_repeat_ms, + enable_cpu_cache_flush=False, + ), + max_workers=os.cpu_count(), + ) + if not os.path.isdir(args.candidate_cache_dir): + raise Exception("Please provide a correct candidate cache dir.") + try: + os.makedirs(args.result_cache_dir, exist_ok=True) + except OSError: + print(f"Directory {args.result_cache_dir} cannot be created successfully.") + model_dirs = glob.glob(os.path.join(args.candidate_cache_dir, "*")) + for model_dir in model_dirs: + model_name = model_dir.split("/")[-1] + os.makedirs(os.path.join(args.result_cache_dir, model_name), exist_ok=True) + all_tasks = glob.glob(os.path.join(model_dir, "*.json")) + workload_paths = [] + for path in all_tasks: + if "workload" in path: + workload_paths.append(path) + for workload_path in tqdm(workload_paths): + candidate_path = workload_path[:-13] + "candidates.json" + database = ms.database.JSONDatabase( + path_workload=workload_path, + path_tuning_record=candidate_path, + ) + measure_candidates(database, builder, runner) + + +if __name__ == "__main__": + main() diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 78fd3d659fafa..b97777df81ec8 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -236,3 +236,64 @@ def notify_runner_results( measure_candidates, results, ) + + def get_measure_candidate(self, mod, trace): + """Generate a measure candidate given an initial IR module and a trace. + + Parameters + ----------- + mod : IRModule + The initial IR module. + trace : Trace + The trace applying to the IR Module. + + Returns + ------- + candidate : MeasureCandidate + A generated candidate. + """ + return _ffi_api.TuneContextGetMeasureCandidate(self, mod, trace) + + def set_measure_candidates(self, candidates): + """Set candidates in a tuning context. + + Parameters + ---------- + candidates : List[MeasureCandidate] + A list of measure candidates for the tuning context. + """ + _ffi_api.TuneContextSetMeasureCandidates(self, candidates) + + def send_to_builder(self, builder): + """Send candidates to builder. + + Parameters + ---------- + builder : Builder + The builder for building the candidates. + """ + _ffi_api.TuneContextSendToBuilder(self, builder) + + def send_to_runner(self, runner): + """Send candidates to runner. + + Parameters + ---------- + runner : Runner + The runner for running the candidates. + """ + _ffi_api.TuneContextSendToRunner(self, runner) + + def join(self): + """Join the runner processes. + + Returns + ------- + result : List[RunnerResult] + The runner results. + """ + return _ffi_api.TuneContextJoin(self) + + def clear_measure_state(self): + """Clear the measure states.""" + _ffi_api.TuneContextClearMeasureState(self) diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index 9905ff73c792c..47561bd845cfd 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -152,7 +152,9 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, PyDatabaseNode::FCommitWorkload f_commit_workload, PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, - PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FSize f_size) { + PyDatabaseNode::FGetTopK f_get_top_k, + PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records, + PyDatabaseNode::FSize f_size) { ObjectPtr n = make_object(); n->f_has_workload = f_has_workload; n->f_commit_workload = f_commit_workload; @@ -190,6 +192,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord") .set_body_method(&DatabaseNode::CommitTuningRecord); TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK") .set_body_method(&DatabaseNode::GetTopK); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetAllTuningRecords") + .set_body_method(&DatabaseNode::GetAllTuningRecords); TVM_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method(&DatabaseNode::Size); TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase); diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 4f5bd9b136131..9bb7ee1027b99 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -156,6 +156,15 @@ class JSONDatabaseNode : public DatabaseNode { return results; } + Array GetAllTuningRecords() { + Array results; + results.reserve(Size()); + for (const TuningRecord& record : this->tuning_records_) { + results.push_back(record); + } + return results; + } + int64_t Size() { return tuning_records_.size(); } }; diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 362db0a380971..d2215562c555b 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -70,6 +70,19 @@ void TuneContextNode::Initialize() { } } +MeasureCandidate TuneContextNode::_GetMeasureCandidate( + const IRModule& mod, const tir::Trace& trace) { + tir::Schedule sch = tir::Schedule::Traced(mod, -1, 0, tir::ScheduleErrorRenderLevel::kDetail); + trace->ApplyToSchedule(sch, false, nullptr); + tir::PrimFunc func; + for (const auto& kv : sch->mod()->functions) { + func = Downcast(kv.second); + } + Array args_info = ArgInfo::FromPrimFunc(func); + MeasureCandidate candidate = MeasureCandidate(sch, args_info); + return candidate; +} + void TuneContextNode::_SetMeasureCandidates(const Array& candidates) { this->measure_candidates = candidates; } @@ -137,7 +150,9 @@ Array TuneContextNode::_Join() { for (RunnerFuture future : futures) { results.push_back(future->Result()); } - this->search_strategy.value()->NotifyRunnerResults(this->measure_candidates.value(), results); + if (this->search_strategy.defined()) { + this->search_strategy.value()->NotifyRunnerResults(this->measure_candidates.value(), results); + } ICHECK(this->measure_candidates.defined()); ICHECK(this->builder_results.defined()); ICHECK_EQ(results.size(), this->measure_candidates.value().size()); @@ -172,6 +187,18 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") TVM_REGISTER_GLOBAL("meta_schedule._SHash2Hex").set_body_typed(SHash2Hex); TVM_REGISTER_GLOBAL("meta_schedule.TuneContextInitialize") .set_body_method(&TuneContextNode::Initialize); +TVM_REGISTER_GLOBAL("meta_schedule.TuneContextGetMeasureCandidate") + .set_body_method(&TuneContextNode::_GetMeasureCandidate); +TVM_REGISTER_GLOBAL("meta_schedule.TuneContextSetMeasureCandidates") + .set_body_method(&TuneContextNode::_SetMeasureCandidates); +TVM_REGISTER_GLOBAL("meta_schedule.TuneContextSendToBuilder") + .set_body_method(&TuneContextNode::_SendToBuilder); +TVM_REGISTER_GLOBAL("meta_schedule.TuneContextSendToRunner") + .set_body_method(&TuneContextNode::_SendToRunner); +TVM_REGISTER_GLOBAL("meta_schedule.TuneContextJoin") + .set_body_method(&TuneContextNode::_Join); +TVM_REGISTER_GLOBAL("meta_schedule.TuneContextClearMeasureState") + .set_body_method(&TuneContextNode::_ClearMeasureState); } // namespace meta_schedule } // namespace tvm