[AutoTuner] Distribute best cfg (#54834)
* distribute best cfg

* adapt to multi-arg transmission

* update metric extraction

* fix bugs in pruning and log reading

* fix the time default value

* remove the time record

* adjust the order of search dimensions

* fix prune bugs

* fix the cfg-adding bug

* fix the multi-node bug

* reset status

* remove alarm and set logdir

* deepcopy ctx

* change alarm

* fix restart bug

* add exit

* best cfg run needs no alarm

* add warmup time
Caozhou1995 authored Jul 14, 2023
1 parent 5de773d commit 7f6d222
Showing 8 changed files with 265 additions and 72 deletions.
16 changes: 10 additions & 6 deletions python/paddle/distributed/auto_tuner/prune.py
@@ -26,7 +26,7 @@ def same_cfgs_beside(attr, cur_cfg, history_cfgs):
for key in cur_cfg:
if key == attr:
continue
if key not in history_cfgs or history_cfgs[key] != cur_cfg[key]:
if key not in cfg or cfg[key] != cur_cfg[key]:
same = False
break
if same:
@@ -56,7 +56,7 @@ def prune_by_mp(tuner_cfg, cur_cfg, history_cfgs=None):
hidden_size = tuner_cfg["model_cfg"].get("hidden_size", None)
vocab_size = tuner_cfg["model_cfg"].get("vocab_size", None)

if not mp_degree:
if mp_degree is None:
return False

if hidden_size and hidden_size % mp_degree != 0:
@@ -93,7 +93,7 @@ def prune_by_pp(tuner_cfg, cur_cfg, history_cfgs=None):
num_layers = tuner_cfg["model_cfg"].get("num_layers", None)
num_nodes = tuner_cfg.get("num_nodes", 1)

if not pp_degree:
if pp_degree is None:
return False

if num_layers:
@@ -128,12 +128,15 @@ def prune_by_mbs(tuner_cfg, cur_cfg, history_cfgs=None):
// cur_cfg["dp_degree"]
// cur_cfg["sharding_degree"]
)
if local_batch_size == 0:
return True

mbs_candidates = tuner_cfg.get("micro_batch_size", None)

if mbs_candidates == "auto":
mbs_candidates = tuner_cfg["candidates"]["micro_batch_size"]

if not micro_batch_size:
if micro_batch_size is None:
return False

if local_batch_size:
@@ -222,7 +225,7 @@ def prune_by_recompute(tuner_cfg, cur_cfg, history_cfgs):
"""
recompute_granularity = cur_cfg.get("recompute_granularity", None)
use_recompute = cur_cfg.get("use_recompute", None)
if not use_recompute:
if use_recompute is None:
return False

recompute_granularity_candidates = tuner_cfg["candidates"].get(
@@ -253,10 +256,11 @@ def prune_by_recompute(tuner_cfg, cur_cfg, history_cfgs):
):
return True

if use_recompute is False:
if not use_recompute:
cfgs = same_cfgs_beside("recompute_granularity", cur_cfg, history_cfgs)
if cfgs:
return True

return False


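The recurring change from "if not x" to "if x is None" in this file is more than style: a degree value of 0 is falsy, so the old check treated it like a missing key and returned early, skipping the pruning rules. A minimal sketch of the distinction; the helper below is illustrative, not the actual prune function:

def skip_prune(mp_degree):
    # "if not mp_degree" also fires for mp_degree == 0, silently skipping
    # the divisibility checks; "is None" only fires when the value is absent.
    return mp_degree is None

print(skip_prune(None))  # True  -> nothing to prune against
print(skip_prune(0))     # False -> 0 still reaches the pruning logic
print(not 0, 0 is None)  # True False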
1 change: 1 addition & 0 deletions python/paddle/distributed/auto_tuner/recorder.py
@@ -63,6 +63,7 @@ def store_history(self, path="./history.csv"):
cols = df.columns.tolist()
cols.insert(0, cols.pop(cols.index('job_id')))
df = df.reindex(columns=cols)
df = df.drop(columns=['time'])
# write to csv
df.to_csv(self.store_path, index=False)

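The new drop() call keeps the internal time record out of the persisted history. A toy illustration with pandas; every column name except job_id and time is invented for the example:

import pandas as pd

df = pd.DataFrame(
    {"dp_degree": [2], "job_id": [0], "time": [123.4], "speed": [1500.0]}
)
cols = df.columns.tolist()
cols.insert(0, cols.pop(cols.index("job_id")))  # move job_id to the front
df = df.reindex(columns=cols)
df = df.drop(columns=["time"])  # the time column is no longer written to history.csv
print(df.to_csv(index=False))
# job_id,dp_degree,speed
# 0,2,1500.0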
5 changes: 4 additions & 1 deletion python/paddle/distributed/auto_tuner/tuner.py
@@ -47,6 +47,9 @@ def search_once(self):
return None
new_cfg = self.algo.search_once(self.history_cfgs)
self.cur_task_id += 1
self.history_cfgs.append(new_cfg)

return new_cfg

def add_cfg(self, cfg):
"""Add cfg into history cfgs"""
self.history_cfgs.append(cfg)
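With this change a config is added to history_cfgs the moment search_once hands it out, so pruning for the next candidate already sees it; add_cfg remains available for configs obtained elsewhere, for example a best cfg received from another node. A self-contained toy stand-in, not the real AutoTuner:

class ToyTuner:
    def __init__(self, candidates):
        self.candidates = list(candidates)
        self.history_cfgs = []
        self.cur_task_id = 0

    def search_once(self):
        if self.cur_task_id >= len(self.candidates):
            return None
        new_cfg = self.candidates[self.cur_task_id]
        self.cur_task_id += 1
        self.history_cfgs.append(new_cfg)  # recorded as soon as it is handed out
        return new_cfg

    def add_cfg(self, cfg):
        # still used for configs that arrive from outside the search loop
        self.history_cfgs.append(cfg)

tuner = ToyTuner([{"dp_degree": 1}, {"dp_degree": 2}])
print(tuner.search_once(), len(tuner.history_cfgs))  # {'dp_degree': 1} 1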
178 changes: 132 additions & 46 deletions python/paddle/distributed/auto_tuner/utils.py
@@ -38,11 +38,11 @@ def dist_degree(mode, num_gpus, num_nodes):
assert mode in ["dp", "mp", "pp", "sharding"]
results = []
if mode == "dp":
results = divisor(num_gpus, reverse=True)
results = divisor(num_gpus, reverse=False)

elif mode == "pp":
if num_nodes > 1:
results = list(range(num_nodes))
results = list(range(1, num_nodes + 1))
else:
results = divisor(num_gpus, reverse=True)

@@ -123,17 +123,15 @@ def default_candidates(tuner_cfg):
elif tuner_cfg.get("micro_batch_size", None):
candidates["micro_batch_size"] = tuner_cfg.get("micro_batch_size")
else:
candidates["micro_batch_size"] = [
tuner_cfg["model_cfg"]["global_batch_size"]
]
candidates["micro_batch_size"] = [None]

return candidates


def search_all(tuner_cfg):
"""Permutate the candidates of all hyper params."""
candidates = tuner_cfg["candidates"]
# Order: dp -> mp -> pp -> mbs -> sharding-> recompute
# Order: dp -> sharding -> mbs -> pp -> mp -> recompute
dp_degree_candidates = candidates["dp_degree"]
mp_degree_candidates = candidates["mp_degree"]
pp_degree_candidates = candidates["pp_degree"]
@@ -145,22 +143,22 @@ def search_all(tuner_cfg):
all_cfgs = list(
itertools.product(
dp_degree_candidates,
mp_degree_candidates,
pp_degree_candidates,
mbs_candidates,
sharding_degree_candidates,
sharding_stage_candidates,
mbs_candidates,
pp_degree_candidates,
mp_degree_candidates,
use_recompute_candidates,
recompute_granularity_candidates,
)
)
mapping = {
0: "dp_degree",
1: "mp_degree",
2: "pp_degree",
1: "sharding_degree",
2: "sharding_stage",
3: "micro_batch_size",
5: "sharding_stage",
4: "sharding_degree",
4: "pp_degree",
5: "mp_degree",
6: "use_recompute",
7: "recompute_granularity",
}
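The search dimensions are now expanded in the order dp -> sharding degree -> sharding stage -> micro batch size -> pp -> mp -> recompute, and the index-to-name mapping is updated to match. A toy version of the expansion with made-up, deliberately tiny candidate lists:

import itertools

candidates = {
    "dp_degree": [1, 2],
    "sharding_degree": [1],
    "sharding_stage": [1],
    "micro_batch_size": [1, 2],
    "pp_degree": [1],
    "mp_degree": [1],
    "use_recompute": [True],
    "recompute_granularity": ["full"],
}
order = [
    "dp_degree", "sharding_degree", "sharding_stage", "micro_batch_size",
    "pp_degree", "mp_degree", "use_recompute", "recompute_granularity",
]
all_cfgs = itertools.product(*(candidates[key] for key in order))
new_cfgs = [dict(zip(order, values)) for values in all_cfgs]
print(len(new_cfgs), new_cfgs[0])  # 4 configs; later dimensions vary fastest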
@@ -179,59 +177,145 @@ def gen_new_args(raw_args, cfg, tuner_cfg):
cmd = copy.deepcopy(tuner_cfg["run_cmd"])
res_args = copy.deepcopy(raw_args)
if "dp_degree" in cmd and "dp_degree" in cfg:
cmd["dp_degree"][1] = cmd["dp_degree"][1] + "=" + str(cfg["dp_degree"])
res_args.extend(cmd["dp_degree"])
if "--" in cmd["dp_degree"][0]:
cmd["dp_degree"][1] = cmd["dp_degree"][1] + str(cfg["dp_degree"])
res_args.extend(cmd["dp_degree"])
else:
cmd["dp_degree"][1] = (
cmd["dp_degree"][1] + "=" + str(cfg["dp_degree"])
)
res_args.extend(cmd["dp_degree"])

if "mp_degree" in cmd and "mp_degree" in cfg:
cmd["mp_degree"][1] = cmd["mp_degree"][1] + "=" + str(cfg["mp_degree"])
res_args.extend(cmd["mp_degree"])
if "--" in cmd["mp_degree"][0]:
cmd["mp_degree"][1] = cmd["mp_degree"][1] + str(cfg["mp_degree"])
res_args.extend(cmd["mp_degree"])
else:
cmd["mp_degree"][1] = (
cmd["mp_degree"][1] + "=" + str(cfg["mp_degree"])
)
res_args.extend(cmd["mp_degree"])

if "pp_degree" in cmd and "pp_degree" in cfg:
cmd["pp_degree"][1] = cmd["pp_degree"][1] + "=" + str(cfg["pp_degree"])
res_args.extend(cmd["pp_degree"])
if "--" in cmd["pp_degree"][0]:
cmd["pp_degree"][1] = cmd["pp_degree"][1] + str(cfg["pp_degree"])
res_args.extend(cmd["pp_degree"])
else:
cmd["pp_degree"][1] = (
cmd["pp_degree"][1] + "=" + str(cfg["pp_degree"])
)
res_args.extend(cmd["pp_degree"])

if "micro_batch_size" in cmd and "micro_batch_size" in cfg:
cmd["micro_batch_size"][1] = (
cmd["micro_batch_size"][1] + "=" + str(cfg["micro_batch_size"])
)
res_args.extend(cmd["micro_batch_size"])
if "--" in cmd["micro_batch_size"][0]:
cmd["micro_batch_size"][1] = cmd["micro_batch_size"][1] + str(
cfg["micro_batch_size"]
)
res_args.extend(cmd["micro_batch_size"])
else:
cmd["micro_batch_size"][1] = (
cmd["micro_batch_size"][1] + "=" + str(cfg["micro_batch_size"])
)
res_args.extend(cmd["micro_batch_size"])

if "sharding_degree" in cmd and "sharding_degree" in cfg:
cmd["sharding_degree"][1] = (
cmd["sharding_degree"][1] + "=" + str(cfg["sharding_degree"])
)
res_args.extend(cmd["sharding_degree"])
if "--" in cmd["sharding_degree"][0]:
cmd["sharding_degree"][1] = cmd["sharding_degree"][1] + str(
cfg["sharding_degree"]
)
res_args.extend(cmd["sharding_degree"])
else:
cmd["sharding_degree"][1] = (
cmd["sharding_degree"][1] + "=" + str(cfg["sharding_degree"])
)
res_args.extend(cmd["sharding_degree"])

if "sharding_stage" in cmd and "sharding_stage" in cfg:
cmd["sharding_stage"][1] = (
cmd["sharding_stage"][1] + "=" + str(cfg["sharding_stage"])
)
res_args.extend(cmd["sharding_stage"])
if "--" in cmd["sharding_stage"][0]:
cmd["sharding_stage"][1] = cmd["sharding_stage"][1] + str(
cfg["sharding_stage"]
)
res_args.extend(cmd["sharding_stage"])
else:
cmd["sharding_stage"][1] = (
cmd["sharding_stage"][1] + "=" + str(cfg["sharding_stage"])
)
res_args.extend(cmd["sharding_stage"])

if "use_recompute" in cmd and "use_recompute" in cfg:
cmd["use_recompute"][1] = (
cmd["use_recompute"][1] + "=" + str(cfg["use_recompute"])
)
res_args.extend(cmd["use_recompute"])
if "--" in cmd["use_recompute"][0]:
cmd["use_recompute"][1] = cmd["use_recompute"][1] + str(
cfg["use_recompute"]
)
res_args.extend(cmd["use_recompute"])
else:
cmd["use_recompute"][1] = (
cmd["use_recompute"][1] + "=" + str(cfg["use_recompute"])
)
res_args.extend(cmd["use_recompute"])

if "recompute_granularity" in cmd and "recompute_granularity" in cfg:
cmd["recompute_granularity"][1] = (
cmd["recompute_granularity"][1]
+ "="
+ str(cfg["recompute_granularity"])
)
res_args.extend(cmd["recompute_granularity"])
if "--" in cmd["recompute_granularity"][0]:
cmd["recompute_granularity"][1] = cmd["recompute_granularity"][
1
] + str(cfg["recompute_granularity"])
res_args.extend(cmd["recompute_granularity"])
else:
cmd["recompute_granularity"][1] = (
cmd["recompute_granularity"][1]
+ "="
+ str(cfg["recompute_granularity"])
)
res_args.extend(cmd["recompute_granularity"])

if "local_batch_size" in cmd:
local_batch_size = (
tuner_cfg["model_cfg"]["global_batch_size"]
// cfg["sharding_degree"]
// cfg["dp_degree"]
)
cmd["local_batch_size"][1] = (
cmd["local_batch_size"][1] + "=" + str(local_batch_size)
)
res_args.extend(cmd["local_batch_size"])
if "--" in cmd["local_batch_size"][0]:
cmd["local_batch_size"][1] = cmd["local_batch_size"][1] + str(
local_batch_size
)
res_args.extend(cmd["local_batch_size"])
else:
cmd["local_batch_size"][1] = (
cmd["local_batch_size"][1] + "=" + str(local_batch_size)
)
res_args.extend(cmd["local_batch_size"])

if "gradient_accumulation_steps" in cmd:
if "--" in cmd["gradient_accumulation_steps"][0]:
try:
gradient_accumulation_steps = (
tuner_cfg["model_cfg"]["global_batch_size"]
// cfg["sharding_degree"]
// cfg["dp_degree"]
// cfg["micro_batch_size"]
)
cmd["gradient_accumulation_steps"][1] = cmd[
"gradient_accumulation_steps"
][1] + str(gradient_accumulation_steps)
res_args.extend(cmd["gradient_accumulation_steps"])
except:
pass
else:
try:
gradient_accumulation_steps = (
tuner_cfg["model_cfg"]["global_batch_size"]
// cfg["sharding_degree"]
// cfg["dp_degree"]
// cfg["micro_batch_size"]
)
cmd["gradient_accumulation_steps"][1] = (
cmd["gradient_accumulation_steps"][1]
+ "="
+ str(gradient_accumulation_steps)
)
res_args.extend(cmd["gradient_accumulation_steps"])
except:
pass

return res_args
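Every branch added above follows the same pattern: if the first element of the run_cmd entry looks like a launcher flag (it contains "--"), the tuned value is appended directly to the second element; otherwise the entry is treated as a config key that needs "=value" appended. A hedged sketch of the two styles; the example run_cmd entries are invented:

def render_arg(arg_spec, value):
    name, tail = arg_spec
    if "--" in name:
        # flag style: ["--micro_batch_size", ""] -> ["--micro_batch_size", "2"]
        return [name, tail + str(value)]
    # key=value style: ["-o", "Model.micro_batch_size"] -> ["-o", "Model.micro_batch_size=2"]
    return [name, tail + "=" + str(value)]

print(render_arg(["--micro_batch_size", ""], 2))
print(render_arg(["-o", "Model.micro_batch_size"], 2))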

@@ -245,7 +329,9 @@ def read_log(
return (0.0, True)
with open(target_file, "r") as f:
# read file
re_metric_pattern = r'speed: (\d+(\.\d*)?) *' + target_metric
re_metric_pattern = (
target_metric + r":* *(\d+(\.\d*)?)|(\d+(\.\d*)?) *" + target_metric
)

metric_list = []
lines = f.readlines()
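The metric pattern no longer assumes a "speed: value" layout; it now accepts either "metric: value" or "value metric". A quick check of the broadened pattern against both layouts; the sample log lines are invented:

import re

target_metric = "interval_samples_per_second"
re_metric_pattern = (
    target_metric + r":* *(\d+(\.\d*)?)|(\d+(\.\d*)?) *" + target_metric
)

for line in [
    "step 10, interval_samples_per_second: 42.5",
    "step 10, 42.5 interval_samples_per_second",
]:
    match = re.search(re_metric_pattern, line)
    value = [g for g in match.groups() if g][0]  # the value lands in group 1 or group 3
    print(value)  # 42.5 both times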
2 changes: 2 additions & 0 deletions python/paddle/distributed/launch/context/__init__.py
@@ -38,6 +38,8 @@ def __init__(self, enable_plugin=True):

if enable_plugin:
self._enable_plugin()
self.max_time_per_task = -1
self.run_best = False

def print(self):
self.logger.info("----------- Configuration ----------------------")
7 changes: 7 additions & 0 deletions python/paddle/distributed/launch/controllers/collective.py
@@ -282,6 +282,13 @@ def run(self):
self.job.replicas = replicas
else:
self.ctx.logger.warning(f"peer not ready {self.job}")
if self.ctx.is_auto_tuner_mode():
self.ctx.logger.info(
"Failed to start peer, auto tuner exit."
)
import sys

sys.exit(-1)
break

self.ctx.logger.debug(f"Run {self.job}")
7 changes: 7 additions & 0 deletions python/paddle/distributed/launch/controllers/controller.py
@@ -36,6 +36,13 @@ def __init__(self, ctx):
signal.signal(signal.SIGTERM, self.signal_handler)
signal.signal(signal.SIGABRT, self.signal_handler)
signal.signal(signal.SIGINT, self.signal_handler)
if ctx.is_auto_tuner_mode():
if not ctx.run_best:
# set per task timeout
signal.signal(signal.SIGALRM, self.not_exit_signal_handler)
signal.alarm(ctx.max_time_per_task)
else:
signal.alarm(0)

self.ctx = ctx
self.master = Master.factory(self.ctx)
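In auto-tuner mode each trial is bounded by a SIGALRM timeout, and the alarm is disarmed when the best cfg is re-run (run_best). A minimal, Unix-only sketch of the idea; the handler body and the 2-second timeout are illustrative, not the launcher's actual implementation:

import signal
import time

def not_exit_signal_handler(signum, frame):
    # stand-in body: interrupt the current task without killing the process
    raise TimeoutError("task exceeded max_time_per_task")

signal.signal(signal.SIGALRM, not_exit_signal_handler)
signal.alarm(2)          # arm the per-task alarm (seconds)

try:
    time.sleep(10)       # stand-in for one training trial
except TimeoutError as err:
    print("task timed out:", err)
finally:
    signal.alarm(0)      # the run_best path disarms the alarm entirely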