diff --git a/python/paddle/distributed/auto_tuner/prune.py b/python/paddle/distributed/auto_tuner/prune.py index 71e7d95e2ed64..888b53ee6b2b6 100644 --- a/python/paddle/distributed/auto_tuner/prune.py +++ b/python/paddle/distributed/auto_tuner/prune.py @@ -26,7 +26,7 @@ def same_cfgs_beside(attr, cur_cfg, history_cfgs): for key in cur_cfg: if key == attr: continue - if key not in history_cfgs or history_cfgs[key] != cur_cfg[key]: + if key not in cfg or cfg[key] != cur_cfg[key]: same = False break if same: @@ -56,7 +56,7 @@ def prune_by_mp(tuner_cfg, cur_cfg, history_cfgs=None): hidden_size = tuner_cfg["model_cfg"].get("hidden_size", None) vocab_size = tuner_cfg["model_cfg"].get("vocab_size", None) - if not mp_degree: + if mp_degree is None: return False if hidden_size and hidden_size % mp_degree != 0: @@ -93,7 +93,7 @@ def prune_by_pp(tuner_cfg, cur_cfg, history_cfgs=None): num_layers = tuner_cfg["model_cfg"].get("num_layers", None) num_nodes = tuner_cfg.get("num_nodes", 1) - if not pp_degree: + if pp_degree is None: return False if num_layers: @@ -128,12 +128,15 @@ def prune_by_mbs(tuner_cfg, cur_cfg, history_cfgs=None): // cur_cfg["dp_degree"] // cur_cfg["sharding_degree"] ) + if local_batch_size == 0: + return True + mbs_candidates = tuner_cfg.get("micro_batch_size", None) if mbs_candidates == "auto": mbs_candidates = tuner_cfg["candidates"]["micro_batch_size"] - if not micro_batch_size: + if micro_batch_size is None: return False if local_batch_size: @@ -222,7 +225,7 @@ def prune_by_recompute(tuner_cfg, cur_cfg, history_cfgs): """ recompute_granularity = cur_cfg.get("recompute_granularity", None) use_recompute = cur_cfg.get("use_recompute", None) - if not use_recompute: + if use_recompute is None: return False recompute_granularity_candidates = tuner_cfg["candidates"].get( @@ -253,10 +256,11 @@ def prune_by_recompute(tuner_cfg, cur_cfg, history_cfgs): ): return True - if use_recompute is False: + if not use_recompute: cfgs = same_cfgs_beside("recompute_granularity", cur_cfg, history_cfgs) if cfgs: return True + return False diff --git a/python/paddle/distributed/auto_tuner/recorder.py b/python/paddle/distributed/auto_tuner/recorder.py index d742d751a7a2c..ad388a9bfe2f7 100644 --- a/python/paddle/distributed/auto_tuner/recorder.py +++ b/python/paddle/distributed/auto_tuner/recorder.py @@ -63,6 +63,7 @@ def store_history(self, path="./history.csv"): cols = df.columns.tolist() cols.insert(0, cols.pop(cols.index('job_id'))) df = df.reindex(columns=cols) + df = df.drop(columns=['time']) # write to csv df.to_csv(self.store_path, index=False) diff --git a/python/paddle/distributed/auto_tuner/tuner.py b/python/paddle/distributed/auto_tuner/tuner.py index 3be9519b1cbc2..26831a2e8fcbd 100644 --- a/python/paddle/distributed/auto_tuner/tuner.py +++ b/python/paddle/distributed/auto_tuner/tuner.py @@ -47,6 +47,9 @@ def search_once(self): return None new_cfg = self.algo.search_once(self.history_cfgs) self.cur_task_id += 1 - self.history_cfgs.append(new_cfg) return new_cfg + + def add_cfg(self, cfg): + """Add cfg into history cfgs""" + self.history_cfgs.append(cfg) diff --git a/python/paddle/distributed/auto_tuner/utils.py b/python/paddle/distributed/auto_tuner/utils.py index 9c322b1d7a535..8db11df08c5a9 100644 --- a/python/paddle/distributed/auto_tuner/utils.py +++ b/python/paddle/distributed/auto_tuner/utils.py @@ -38,11 +38,11 @@ def dist_degree(mode, num_gpus, num_nodes): assert mode in ["dp", "mp", "pp", "sharding"] results = [] if mode == "dp": - results = divisor(num_gpus, reverse=True) + 
results = divisor(num_gpus, reverse=False) elif mode == "pp": if num_nodes > 1: - results = list(range(num_nodes)) + results = list(range(1, num_nodes + 1)) else: results = divisor(num_gpus, reverse=True) @@ -123,9 +123,7 @@ def default_candidates(tuner_cfg): elif tuner_cfg.get("micro_batch_size", None): candidates["micro_batch_size"] = tuner_cfg.get("micro_batch_size") else: - candidates["micro_batch_size"] = [ - tuner_cfg["model_cfg"]["global_batch_size"] - ] + candidates["micro_batch_size"] = [None] return candidates @@ -133,7 +131,7 @@ def default_candidates(tuner_cfg): def search_all(tuner_cfg): """Permutate the candidates of all hyper params.""" candidates = tuner_cfg["candidates"] - # Order: dp -> mp -> pp -> mbs -> sharding-> recompute + # Order: dp -> sharding -> mbs -> pp -> mp -> recompute dp_degree_candidates = candidates["dp_degree"] mp_degree_candidates = candidates["mp_degree"] pp_degree_candidates = candidates["pp_degree"] @@ -145,22 +143,22 @@ def search_all(tuner_cfg): all_cfgs = list( itertools.product( dp_degree_candidates, - mp_degree_candidates, - pp_degree_candidates, - mbs_candidates, sharding_degree_candidates, sharding_stage_candidates, + mbs_candidates, + pp_degree_candidates, + mp_degree_candidates, use_recompute_candidates, recompute_granularity_candidates, ) ) mapping = { 0: "dp_degree", - 1: "mp_degree", - 2: "pp_degree", + 1: "sharding_degree", + 2: "sharding_stage", 3: "micro_batch_size", - 5: "sharding_stage", - 4: "sharding_degree", + 4: "pp_degree", + 5: "mp_degree", 6: "use_recompute", 7: "recompute_granularity", } @@ -179,48 +177,96 @@ def gen_new_args(raw_args, cfg, tuner_cfg): cmd = copy.deepcopy(tuner_cfg["run_cmd"]) res_args = copy.deepcopy(raw_args) if "dp_degree" in cmd and "dp_degree" in cfg: - cmd["dp_degree"][1] = cmd["dp_degree"][1] + "=" + str(cfg["dp_degree"]) - res_args.extend(cmd["dp_degree"]) + if "--" in cmd["dp_degree"][0]: + cmd["dp_degree"][1] = cmd["dp_degree"][1] + str(cfg["dp_degree"]) + res_args.extend(cmd["dp_degree"]) + else: + cmd["dp_degree"][1] = ( + cmd["dp_degree"][1] + "=" + str(cfg["dp_degree"]) + ) + res_args.extend(cmd["dp_degree"]) if "mp_degree" in cmd and "mp_degree" in cfg: - cmd["mp_degree"][1] = cmd["mp_degree"][1] + "=" + str(cfg["mp_degree"]) - res_args.extend(cmd["mp_degree"]) + if "--" in cmd["mp_degree"][0]: + cmd["mp_degree"][1] = cmd["mp_degree"][1] + str(cfg["mp_degree"]) + res_args.extend(cmd["mp_degree"]) + else: + cmd["mp_degree"][1] = ( + cmd["mp_degree"][1] + "=" + str(cfg["mp_degree"]) + ) + res_args.extend(cmd["mp_degree"]) if "pp_degree" in cmd and "pp_degree" in cfg: - cmd["pp_degree"][1] = cmd["pp_degree"][1] + "=" + str(cfg["pp_degree"]) - res_args.extend(cmd["pp_degree"]) + if "--" in cmd["pp_degree"][0]: + cmd["pp_degree"][1] = cmd["pp_degree"][1] + str(cfg["pp_degree"]) + res_args.extend(cmd["pp_degree"]) + else: + cmd["pp_degree"][1] = ( + cmd["pp_degree"][1] + "=" + str(cfg["pp_degree"]) + ) + res_args.extend(cmd["pp_degree"]) if "micro_batch_size" in cmd and "micro_batch_size" in cfg: - cmd["micro_batch_size"][1] = ( - cmd["micro_batch_size"][1] + "=" + str(cfg["micro_batch_size"]) - ) - res_args.extend(cmd["micro_batch_size"]) + if "--" in cmd["micro_batch_size"][0]: + cmd["micro_batch_size"][1] = cmd["micro_batch_size"][1] + str( + cfg["micro_batch_size"] + ) + res_args.extend(cmd["micro_batch_size"]) + else: + cmd["micro_batch_size"][1] = ( + cmd["micro_batch_size"][1] + "=" + str(cfg["micro_batch_size"]) + ) + res_args.extend(cmd["micro_batch_size"]) if "sharding_degree" in cmd and 
"sharding_degree" in cfg: - cmd["sharding_degree"][1] = ( - cmd["sharding_degree"][1] + "=" + str(cfg["sharding_degree"]) - ) - res_args.extend(cmd["sharding_degree"]) + if "--" in cmd["sharding_degree"][0]: + cmd["sharding_degree"][1] = cmd["sharding_degree"][1] + str( + cfg["sharding_degree"] + ) + res_args.extend(cmd["sharding_degree"]) + else: + cmd["sharding_degree"][1] = ( + cmd["sharding_degree"][1] + "=" + str(cfg["sharding_degree"]) + ) + res_args.extend(cmd["sharding_degree"]) if "sharding_stage" in cmd and "sharding_stage" in cfg: - cmd["sharding_stage"][1] = ( - cmd["sharding_stage"][1] + "=" + str(cfg["sharding_stage"]) - ) - res_args.extend(cmd["sharding_stage"]) + if "--" in cmd["sharding_stage"][0]: + cmd["sharding_stage"][1] = cmd["sharding_stage"][1] + str( + cfg["sharding_stage"] + ) + res_args.extend(cmd["sharding_stage"]) + else: + cmd["sharding_stage"][1] = ( + cmd["sharding_stage"][1] + "=" + str(cfg["sharding_stage"]) + ) + res_args.extend(cmd["sharding_stage"]) if "use_recompute" in cmd and "use_recompute" in cfg: - cmd["use_recompute"][1] = ( - cmd["use_recompute"][1] + "=" + str(cfg["use_recompute"]) - ) - res_args.extend(cmd["use_recompute"]) + if "--" in cmd["use_recompute"][0]: + cmd["use_recompute"][1] = cmd["use_recompute"][1] + str( + cfg["use_recompute"] + ) + res_args.extend(cmd["use_recompute"]) + else: + cmd["use_recompute"][1] = ( + cmd["use_recompute"][1] + "=" + str(cfg["use_recompute"]) + ) + res_args.extend(cmd["use_recompute"]) if "recompute_granularity" in cmd and "recompute_granularity" in cfg: - cmd["recompute_granularity"][1] = ( - cmd["recompute_granularity"][1] - + "=" - + str(cfg["recompute_granularity"]) - ) - res_args.extend(cmd["recompute_granularity"]) + if "--" in cmd["recompute_granularity"][0]: + cmd["recompute_granularity"][1] = cmd["recompute_granularity"][ + 1 + ] + str(cfg["recompute_granularity"]) + res_args.extend(cmd["recompute_granularity"]) + else: + cmd["recompute_granularity"][1] = ( + cmd["recompute_granularity"][1] + + "=" + + str(cfg["recompute_granularity"]) + ) + res_args.extend(cmd["recompute_granularity"]) if "local_batch_size" in cmd: local_batch_size = ( @@ -228,10 +274,48 @@ def gen_new_args(raw_args, cfg, tuner_cfg): // cfg["sharding_degree"] // cfg["dp_degree"] ) - cmd["local_batch_size"][1] = ( - cmd["local_batch_size"][1] + "=" + str(local_batch_size) - ) - res_args.extend(cmd["local_batch_size"]) + if "--" in cmd["local_batch_size"][0]: + cmd["local_batch_size"][1] = cmd["local_batch_size"][1] + str( + local_batch_size + ) + res_args.extend(cmd["local_batch_size"]) + else: + cmd["local_batch_size"][1] = ( + cmd["local_batch_size"][1] + "=" + str(local_batch_size) + ) + res_args.extend(cmd["local_batch_size"]) + + if "gradient_accumulation_steps" in cmd: + if "--" in cmd["gradient_accumulation_steps"][0]: + try: + gradient_accumulation_steps = ( + tuner_cfg["model_cfg"]["global_batch_size"] + // cfg["sharding_degree"] + // cfg["dp_degree"] + // cfg["micro_batch_size"] + ) + cmd["gradient_accumulation_steps"][1] = cmd[ + "gradient_accumulation_steps" + ][1] + str(gradient_accumulation_steps) + res_args.extend(cmd["gradient_accumulation_steps"]) + except: + pass + else: + try: + gradient_accumulation_steps = ( + tuner_cfg["model_cfg"]["global_batch_size"] + // cfg["sharding_degree"] + // cfg["dp_degree"] + // cfg["micro_batch_size"] + ) + cmd["gradient_accumulation_steps"][1] = ( + cmd["gradient_accumulation_steps"][1] + + "=" + + str(gradient_accumulation_steps) + ) + 
res_args.extend(cmd["gradient_accumulation_steps"]) + except: + pass return res_args @@ -245,7 +329,9 @@ def read_log( return (0.0, True) with open(target_file, "r") as f: # read file - re_metric_pattern = r'speed: (\d+(\.\d*)?) *' + target_metric + re_metric_pattern = ( + target_metric + r":* *(\d+(\.\d*)?)|(\d+(\.\d*)?) *" + target_metric + ) metric_list = [] lines = f.readlines() diff --git a/python/paddle/distributed/launch/context/__init__.py b/python/paddle/distributed/launch/context/__init__.py index 0d322094999c8..b252e966021bc 100644 --- a/python/paddle/distributed/launch/context/__init__.py +++ b/python/paddle/distributed/launch/context/__init__.py @@ -38,6 +38,8 @@ def __init__(self, enable_plugin=True): if enable_plugin: self._enable_plugin() + self.max_time_per_task = -1 + self.run_best = False def print(self): self.logger.info("----------- Configuration ----------------------") diff --git a/python/paddle/distributed/launch/controllers/collective.py b/python/paddle/distributed/launch/controllers/collective.py index bb938dd5c1f1d..ad3a811ec8f45 100644 --- a/python/paddle/distributed/launch/controllers/collective.py +++ b/python/paddle/distributed/launch/controllers/collective.py @@ -282,6 +282,13 @@ def run(self): self.job.replicas = replicas else: self.ctx.logger.warning(f"peer not ready {self.job}") + if self.ctx.is_auto_tuner_mode(): + self.ctx.logger.info( + "Failed to start peer, auto tuner exit." + ) + import sys + + sys.exit(-1) break self.ctx.logger.debug(f"Run {self.job}") diff --git a/python/paddle/distributed/launch/controllers/controller.py b/python/paddle/distributed/launch/controllers/controller.py index 34819fc35e963..9769ec9d6bf3f 100644 --- a/python/paddle/distributed/launch/controllers/controller.py +++ b/python/paddle/distributed/launch/controllers/controller.py @@ -36,6 +36,13 @@ def __init__(self, ctx): signal.signal(signal.SIGTERM, self.signal_handler) signal.signal(signal.SIGABRT, self.signal_handler) signal.signal(signal.SIGINT, self.signal_handler) + if ctx.is_auto_tuner_mode(): + if not ctx.run_best: + # set per task timeout + signal.signal(signal.SIGALRM, self.not_exit_signal_handler) + signal.alarm(ctx.max_time_per_task) + else: + signal.alarm(0) self.ctx = ctx self.master = Master.factory(self.ctx) diff --git a/python/paddle/distributed/launch/main.py b/python/paddle/distributed/launch/main.py index 7823ddad27ca3..908c0af8cc18f 100644 --- a/python/paddle/distributed/launch/main.py +++ b/python/paddle/distributed/launch/main.py @@ -295,7 +295,6 @@ def launch(): elif ctx.is_auto_tuner_mode(): import copy import json - import signal import sys import time @@ -304,6 +303,7 @@ def launch(): from ..auto_tuner.utils import gen_new_args, read_log from . 
import controllers + start_time = time.time() # read user defined tuner config json try: with open(ctx.args.auto_tuner_json, "r") as f: @@ -326,24 +326,48 @@ def launch(): gpus_per_node = len(ctx.args.devices.split(",")) nnodes = ctx.args.nnodes if isinstance(nnodes, str): - tuner_cfg["nodes"] = int(nnodes.split(":")[0]) + nnodes = int(nnodes.split(":")[0]) else: - tuner_cfg["nodes"] = int(nnodes) + nnodes = int(nnodes) + tuner_cfg["nodes"] = nnodes tuner_cfg["num_gpus"] = gpus_per_node * tuner_cfg["nodes"] + if nnodes > 1: + import etcd3 + + assert "etcd://" in ctx.args.master + master_ip, port = ctx.args.master.strip("etcd://").split(':') + client = etcd3.client(host=master_ip, port=port) + client.delete("best_cfg") + # build AutoTuner to get new config auto_tuner = AutoTuner(tuner_cfg) cur_cfg = auto_tuner.search_once() + auto_tuner.add_cfg(cur_cfg) # get max time per task run max_time_per_task = tuner_cfg.get("max_time_per_task", 1800) + ctx.max_time_per_task = max_time_per_task + + # warmup + warmup_time = ( + max_time_per_task + if "warmup_time" not in tuner_cfg + else tuner_cfg.get("warmup_time") + ) + is_first_task = True # build history recorder recorder = History_recorder() job_id = 0 + ctx.args.max_restart = -1 + raw_ctx = copy.deepcopy(ctx) while cur_cfg: - ctx.status._current_status = None + ctx = copy.deepcopy(raw_ctx) + if is_first_task: + ctx.max_time_per_task = warmup_time + is_first_task = False # auto tuner supports dp, mp, pp, micro batch size, sharding, recompute by default and every task has own log dir log_dir = "DP{}_MP{}_PP{}_Sharding_degree_{}_stage_{}_MBS_{}_Recompute_{}_granularity_{}".format( cur_cfg["dp_degree"], @@ -373,14 +397,10 @@ def launch(): task_job_id, log_dir, cur_cfg ) ) - c = controllers.init(ctx) - # set per task timeout - signal.signal(signal.SIGALRM, c.not_exit_signal_handler) - signal.alarm(max_time_per_task) c.run() - # Process generated result + # process generated result metric, err = read_log( path=ctx.args.log_dir, file="workerlog.0", @@ -388,12 +408,15 @@ def launch(): ) if err: ctx.logger.warning(f"Read log failed for parameters: {log_dir}") - cur_cfg['time'] = None # for pruner use. + # for pruner use + cur_cfg['time'] = -1 cur_cfg[tuner_cfg['metric_cfg']['name']] = None else: - cur_cfg['time'] = metric # for pruner use. + # for pruner use + cur_cfg['time'] = metric cur_cfg[tuner_cfg['metric_cfg']['name']] = metric - # record history + + # record history cur_cfg['job_id'] = job_id recorder.add_cfg(**cur_cfg) cur_best_cfgs, err = recorder.get_best( @@ -409,18 +432,78 @@ def launch(): ctx.logger.info( "Get best config failed. Currently there are no appropriate configs." ) + c.finalize(exit=False) + # generate a new config new_cfg = auto_tuner.search_once() - if new_cfg: - c.finalize(exit=False) - else: - c.finalize(exit=True) + cur_cfg = copy.deepcopy(new_cfg) + auto_tuner.add_cfg(cur_cfg) # per task launch interval - time.sleep(5) - - cur_cfg = copy.deepcopy(new_cfg) + time.sleep(3) recorder.store_history() + + # get best config to run + best_cfg = None + ctx = copy.deepcopy(raw_ctx) + if nnodes > 1: + import socket + + ip = None + try: + hostname = socket.gethostname() + ip = socket.gethostbyname(socket.getfqdn(hostname)) + except: + ip = '127.0.0.1' + if ip == master_ip: + best_cfg, err = recorder.get_best( + metric=tuner_cfg['metric_cfg']['name'], + direction=tuner_cfg['metric_cfg']['OptimizationDirection'], + ) + if err: + raise ValueError( + "Get best config failed. Currently there are no appropriate configs." 
+ ) + data = json.dumps(best_cfg) + while not client.put("best_cfg", data): + time.sleep(1) + continue + else: + for i in range(10): + try: + data = client.get("best_cfg")[0].decode() + best_cfg = json.loads(data) + except Exception as e: + ctx.logger.warning(e) + time.sleep(2) + if best_cfg: + break + assert best_cfg + else: + best_cfg, err = recorder.get_best( + metric=tuner_cfg['metric_cfg']['name'], + direction=tuner_cfg['metric_cfg']['OptimizationDirection'], + ) + if err: + raise ValueError( + "Get best config failed. Currently there are no appropriate configs." + ) + assert best_cfg + + end_time = time.time() + ctx.logger.info(f"AutoTuner ends in {end_time-start_time}s.") + # launch best cfg + new_args = gen_new_args(raw_args, best_cfg, tuner_cfg) + ctx.run_best = True + ctx.args.training_script_args = new_args + ctx.args.job_id = "best_cfg" + ctx.logger.info(f"Launch best cfg from auto tuner: {best_cfg}") + ctx.args.log_dir = "best_cfg" + # run best cfg + c = controllers.init(ctx) + c.run() + c.finalize(exit=True) + else: from . import controllers
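Notes on the patch (illustrative sketches, not part of the diff):

The reordered comment in search_all() matters because itertools.product treats its first iterable as the outermost loop: with the new order (dp -> sharding -> mbs -> pp -> mp -> recompute), recompute, mp and pp variations are exhausted before the tuner moves on to a new dp/sharding combination. A tiny illustration of that iteration order with stand-in candidate lists:

    import itertools

    dp = [1, 2]                  # outermost: changes last
    mp = [1, 2]
    recompute = [True, False]    # innermost: changes first

    for cfg in itertools.product(dp, mp, recompute):
        print(cfg)
    # (1, 1, True), (1, 1, False), (1, 2, True), (1, 2, False),
    # (2, 1, True), (2, 1, False), (2, 2, True), (2, 2, False)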
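The new gradient_accumulation_steps argument is derived rather than searched: global_batch_size // sharding_degree // dp_degree // micro_batch_size, with the bare except presumably absorbing candidates where micro_batch_size is None or the division is otherwise invalid. A worked example with made-up values:

    # hypothetical values; only the formula comes from gen_new_args()
    global_batch_size, sharding_degree, dp_degree, micro_batch_size = 32, 2, 2, 2
    acc_steps = global_batch_size // sharding_degree // dp_degree // micro_batch_size
    print(acc_steps)  # 4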
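The widened read_log() pattern accepts both "<metric>: <value>" and "<value> <metric>" log lines, instead of requiring the literal "speed: " prefix. A minimal sketch of how the pattern behaves; the metric name and sample lines are invented, only the expression itself comes from utils.py:

    import re

    target_metric = "interval_samples_per_second"   # hypothetical metric name
    pattern = target_metric + r":* *(\d+(\.\d*)?)|(\d+(\.\d*)?) *" + target_metric

    samples = [
        "step 10, interval_samples_per_second: 325.7",   # "<metric>: <value>" form
        "325.7 interval_samples_per_second",             # "<value> <metric>" form
    ]
    for line in samples:
        m = re.search(pattern, line)
        if m:
            # group 1 is filled by the first alternative, group 3 by the second
            print(float(m.group(1) or m.group(3)))       # 325.7 for both lines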
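The per-task timeout now lives in the Controller constructor: in auto-tuner mode (and not during the final best-config run) SIGALRM is armed with ctx.max_time_per_task, which main.py sets to warmup_time for the first trial, and disarmed with alarm(0) for the best-config run. A rough standalone sketch of the mechanism; the handler body is illustrative, not the controller's actual not_exit_signal_handler:

    import signal
    import time

    def on_timeout(signum, frame):
        # the real handler stops the trial's processes without exiting the launcher
        raise TimeoutError("trial exceeded its time budget")

    signal.signal(signal.SIGALRM, on_timeout)
    signal.alarm(3)              # analogous to ctx.max_time_per_task (seconds)
    try:
        time.sleep(10)           # stand-in for c.run() on a hung configuration
    except TimeoutError as e:
        print(e)
    finally:
        signal.alarm(0)          # disarm, as done before launching the best config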
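For multi-node runs, the node whose IP matches the etcd master address publishes the winning config under the "best_cfg" key (after deleting any stale value at startup), and the other nodes poll for it with a bounded number of retries. A condensed sketch of that handshake, assuming a reachable etcd endpoint; the host, port and config values below are placeholders:

    import json
    import time

    import etcd3

    client = etcd3.client(host="10.0.0.1", port=2379)   # parsed from --master etcd://...

    # master node: publish the winning config
    best_cfg = {"dp_degree": 2, "mp_degree": 2, "pp_degree": 1}   # example values
    client.put("best_cfg", json.dumps(best_cfg))

    # other nodes: poll until the key shows up (main.py retries up to 10 times)
    cfg = None
    for _ in range(10):
        value, _meta = client.get("best_cfg")
        if value is not None:
            cfg = json.loads(value.decode())
            break
        time.sleep(2)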