Skip to content

Commit

Permalink
[AutoTuner] Improve ETCD fault tolerance (PaddlePaddle#58314)
Browse files Browse the repository at this point in the history
* add fault tolerant for etcd apis

* fix metric bug

* fix some bugs
  • Loading branch information
Caozhou1995 authored and zeroRains committed Nov 8, 2023
1 parent af20b13 commit 66376f9
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 25 deletions.
4 changes: 0 additions & 4 deletions python/paddle/distributed/auto_tuner/prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,6 @@ def prune_by_mp(tuner_cfg, cur_cfg, history_cfgs=None):
if mp_degree not in mp_degree_candidates:
return True

# prune default candidates
if mp_degree > 8:
return True

return False


Expand Down
5 changes: 2 additions & 3 deletions python/paddle/distributed/auto_tuner/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ def get_best(self, metric, direction, mode=None) -> Tuple[dict, bool]:
if first_few >= 5:
break
return (best_cfg, False)
if (
isinstance(self.history[0]["max_mem_usage"], str)
or self.history[0]["time"] == -1
if isinstance(self.history[0]["max_mem_usage"], str) or (
"time" in self.history[0] and self.history[0]["time"] == -1
):
return (self.history[0], True)
return (self.history[0], False)
Expand Down
20 changes: 6 additions & 14 deletions python/paddle/distributed/launch/controllers/master.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,9 @@ def __init__(self, ctx):

host, port = self.endpoint.split(':')
if ctx.is_auto_tuner_mode():
self.etcd_client = ETCDClient(host=host, port=port)
self.client = etcd3.client(host=host, port=port)
self.client = ETCDClient(host=host, port=port)
else:
self.client = etcd3.client(host=host, port=port)

def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int):
'''
Expand Down Expand Up @@ -256,22 +257,13 @@ def register_heartbeat(self, job_id, pod_id, ttl=10):

self.job_prefix = f'/paddle/{job_id}'
self.heartbeat_prefix = f'{self.job_prefix}/heartbeat'
if self.ctx.is_auto_tuner_mode():
self.etcd_client.delete_prefix(self.job_prefix)
lease = self.etcd_client.lease(ttl)
else:
self.client.delete_prefix(self.job_prefix)
lease = self.client.lease(ttl)
self.client.delete_prefix(self.job_prefix)
lease = self.client.lease(ttl)

# self.client.delete_prefix(self.job_prefix)

beat_path = f"{self.heartbeat_prefix}/{pod_id}"
if self.ctx.is_auto_tuner_mode():
self.etcd_client.put(
beat_path, pod_id.encode('latin-1'), lease=lease
)
else:
self.client.put(beat_path, pod_id.encode('latin-1'), lease=lease)
self.client.put(beat_path, pod_id.encode('latin-1'), lease=lease)

def _beat_watch(event):
self.ctx.status.restart()
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/distributed/launch/controllers/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Watcher:
def __init__(self, ctx):
self.ctx = ctx

self.interval = 30
self.interval = 5

self.gpu_util = []

Expand Down
11 changes: 9 additions & 2 deletions python/paddle/distributed/launch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,7 @@ def launch():
elif "OK" not in status:
timeout_flag = False

has_error = False
if err & (1 << 0):
ctx.logger.warning(
f"Read metric failed for parameters: {log_dir}"
Expand All @@ -665,6 +666,7 @@ def launch():
cur_cfg['time'] = -1
cur_cfg[tuner_cfg['metric_cfg']['name']] = None
cur_cfg["max_mem_usage"] = mem if not OOM_flag else "OOM"
has_error = True

if err & (1 << 1):
ctx.logger.warning(f"Out of memory for parameters: {log_dir}")
Expand All @@ -673,6 +675,7 @@ def launch():
cur_cfg['time'] = -1
cur_cfg[tuner_cfg['metric_cfg']['name']] = None
cur_cfg["max_mem_usage"] = "OOM"
has_error = True

# not err & (1 << 1): do not record memory usage when out of memory
if err & (1 << 2) and not err & (1 << 1):
Expand All @@ -684,18 +687,20 @@ def launch():
)
cur_cfg["max_mem_usage"] = None if not OOM_flag else "OOM"

if not err and timeout_flag:
if not has_error and timeout_flag:
# for pruner use
cur_cfg['time'] = metric
cur_cfg[tuner_cfg['metric_cfg']['name']] = metric
cur_cfg["max_mem_usage"] = mem if not OOM_flag else "OOM"

if not err and not timeout_flag:
if not has_error and not timeout_flag:
cur_cfg['time'] = -1
cur_cfg[tuner_cfg['metric_cfg']['name']] = None
cur_cfg["max_mem_usage"] = None if not OOM_flag else "OOM"

# record history
if tuner_cfg['metric_cfg']['name'] not in cur_cfg:
cur_cfg[tuner_cfg['metric_cfg']['name']] = None
cur_cfg['job_id'] = job_id
recorder.add_cfg(**cur_cfg)
recorder.store_history(history_file_path)
Expand Down Expand Up @@ -794,6 +799,8 @@ def launch():
ctx.logger.info(f"AutoTuner ends in {end_time-start_time}s.")
logger.info(f"AutoTuner ends in {end_time-start_time}s.")
# launch best cfg
if not tuner_cfg.get("run_best", True):
sys.exit()
new_args = gen_new_args(raw_args, best_cfg, tuner_cfg, run_best=True)
ctx.run_best = True
ctx.args.training_script_args = new_args
Expand Down
38 changes: 38 additions & 0 deletions python/paddle/distributed/launch/utils/etcd_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,41 @@ def lease(self, ttl, lease_id=None):

if times >= self.retry_times:
raise ValueError(f"Lease failed after {self.retry_times} times.")

def add_watch_prefix_callback(self, key_prefix, callback, **kwargs):
    """Register *callback* for etcd events under *key_prefix*, with retries.

    Delegates to the wrapped etcd3 client and retries on any exception,
    sleeping 1 second between attempts, up to ``self.retry_times`` times.

    Args:
        key_prefix: etcd key prefix to watch.
        callback: callable invoked for each watch event.
        **kwargs: extra options forwarded to the underlying client.

    Returns:
        The watch id returned by the underlying client on success.

    Raises:
        ValueError: if every attempt fails; the last underlying exception
            is chained as the cause.
    """
    last_exc = None
    for _ in range(self.retry_times):
        try:
            # Success returns immediately; the retry loop only continues
            # when the underlying call raises.
            return self.client.add_watch_prefix_callback(
                key_prefix, callback, **kwargs
            )
        except Exception as e:
            last_exc = e
            logging.info(
                f"Add watch prefix callback failed with exception {e}, retry after 1 second."
            )
            time.sleep(1)

    # Chain the last underlying exception so the root cause is not lost.
    raise ValueError(
        f"Add watch prefix callback failed after {self.retry_times} times."
    ) from last_exc

def cancel_watch(self, watch_id):
    """Cancel an active etcd watch, retrying on transient failures.

    Delegates to the wrapped etcd3 client and retries on any exception,
    sleeping 1 second between attempts, up to ``self.retry_times`` times.

    Args:
        watch_id: identifier previously returned when the watch was created.

    Returns:
        Whatever the underlying client's ``cancel_watch`` returns.

    Raises:
        ValueError: if every attempt fails; the last underlying exception
            is chained as the cause.
    """
    last_exc = None
    for _ in range(self.retry_times):
        try:
            # Success returns immediately; only exceptions trigger a retry.
            return self.client.cancel_watch(watch_id)
        except Exception as e:
            last_exc = e
            logging.info(
                f"Cancel watch failed with exception {e}, retry after 1 second."
            )
            time.sleep(1)

    # Chain the last underlying exception so the root cause is not lost.
    raise ValueError(
        f"Cancel watch failed after {self.retry_times} times."
    ) from last_exc
2 changes: 1 addition & 1 deletion python/paddle/distributed/launch/utils/nvsmi.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_gpu_util(index=None):
if index is None or isinstance(index, list)
else str(index).split(",")
)
if paddle.device.is_compiled_with_cuda():
if paddle.device.is_compiled_with_rocm():
return query_rocm_smi(q, index=index, dtype=d)
return query_smi(q, index=index, dtype=d)

Expand Down

0 comments on commit 66376f9

Please sign in to comment.