import os
import json
import argparse
import tensorflow as tf
import tensorport


class TaskSpec(object):
    """
    Specification of a task in the cluster: the job name, the index of the task, the parameter
    server hosts and the worker hosts.
    """

    def __init__(self, job_name='master', index=0, ps_hosts=None, worker_hosts=None,
                 with_evaluator=False):
        self.job_name = job_name
        self.index = index
        self.evaluator = False
        self.cluster_spec = None
        self.num_workers = 1
        if ps_hosts and worker_hosts:
            ps = ps_hosts if isinstance(ps_hosts, list) else ps_hosts.split(',')
            worker = worker_hosts if isinstance(worker_hosts, list) else worker_hosts.split(',')
            if with_evaluator and len(worker) > 1:
                # the last worker is removed from the cluster and acts as the evaluator
                worker = worker[:-1]
                self.evaluator = self.is_worker() and len(worker) == index
            self.cluster_spec = tf.train.ClusterSpec({'ps': ps, 'worker': worker})
            self.num_workers = len(worker)

    def is_chief(self):
        return self.index == 0

    def is_master(self):
        return self.job_name == 'master'

    def is_ps(self):
        return self.job_name == 'ps'

    def is_worker(self):
        return self.job_name == 'worker' or self.job_name == 'master'

    def is_evaluator(self):
        return self.evaluator

    def join_if_ps(self):
        if self.is_ps():
            server = tf.train.Server(self.cluster_spec,
                                     job_name=self.job_name,
                                     task_index=self.index)
            server.join()
            return True
        return False
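
# A minimal usage sketch (not part of the original module): constructing a TaskSpec directly
# from explicit host lists. The host names and ports below are made up for the example.
#
#   spec = TaskSpec(job_name='worker', index=0,
#                   ps_hosts='ps0.example.com:2222',
#                   worker_hosts='worker0.example.com:2222,worker1.example.com:2222',
#                   with_evaluator=True)
#   spec.is_chief()    # True, this is worker 0
#   spec.num_workers   # 1, the second worker was reserved as the evaluator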


def get_task_spec(with_evaluator=False):
    """
    Loads the task information from the command line or from the environment variables (if the
    command line parameters are not set) and returns a TaskSpec object.
    :param bool with_evaluator: whether the last worker should be reserved as the evaluator
    :return TaskSpec: a TaskSpec object with the information about the task
    """
    # get the task from the command line parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--job_name', dest='job_name', default=None)
    parser.add_argument('--task_index', dest='task_index', type=int, default=0)
    parser.add_argument('--ps_hosts', dest='ps_hosts', default=None)
    parser.add_argument('--worker_hosts', dest='worker_hosts', default=None)
    args, _ = parser.parse_known_args()
    if args.job_name:
        return TaskSpec(job_name=args.job_name, index=args.task_index,
                        ps_hosts=args.ps_hosts, worker_hosts=args.worker_hosts,
                        with_evaluator=with_evaluator)
    # get the task from the environment variables
    if 'JOB_NAME' in os.environ:
        return TaskSpec(job_name=os.environ['JOB_NAME'], index=int(os.environ['TASK_INDEX']),
                        ps_hosts=os.environ.get('PS_HOSTS', None),
                        worker_hosts=os.environ.get('WORKER_HOSTS', None),
                        with_evaluator=with_evaluator)
    if 'TF_CONFIG' in os.environ:
        env = json.loads(os.environ.get('TF_CONFIG', '{}'))
        task_data = env.get('task', None) or {'type': 'master', 'index': 0}
        cluster_data = env.get('cluster', None) or {'ps': None, 'worker': None}
        return TaskSpec(job_name=task_data['type'], index=int(task_data['index']),
                        ps_hosts=cluster_data.get('ps'), worker_hosts=cluster_data.get('worker'),
                        with_evaluator=with_evaluator)
    # return an empty task spec for running locally
    return TaskSpec()
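
# A hedged sketch of how this helper is typically wired into a training script. The script
# itself is hypothetical; only get_task_spec, join_if_ps and cluster_spec come from this module.
#
#   task_spec = get_task_spec(with_evaluator=True)
#   if task_spec.join_if_ps():
#       pass  # parameter servers block inside server.join() and never reach the model code
#   device_fn = tf.train.replica_device_setter(cluster=task_spec.cluster_spec)
#   with tf.device(device_fn):
#       ...  # build the model; variables are placed on the parameter servers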


def get_logs_path(path):
    """
    Log directory specification, see get_logs_path:
    https://tensorport.com/documentation/api/#get_logs_path
    If the path starts with gs:// it is returned unchanged.
    :param str path: the path for the logs directory
    :return str: the real path for the logs
    """
    if path.startswith('gs://'):
        return path
    return tensorport.get_logs_path(path)
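
# Illustrative calls (the concrete paths are made up for the example):
#   get_logs_path('gs://my-bucket/logs/run1')   # returned unchanged
#   get_logs_path('/home/username/logs/run1')   # resolved through tensorport.get_logs_path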
def get_data_path(dataset_name, local_root, local_repo='', path=''):
"""
Dataset specification, see: get_data_path,
https://tensorport.com/documentation/api/#get_data_path
If local_root starts with gs:// we suppose a bucket in google cloud and return
local_root / local_repo / local_path
:param str name: TensorPort dataset repository name,
e.g. user_name/repo_name
:param str local_root: specifies the root directory for dataset.
e.g. /home/username/datasets, gs://my-project/my_dir
:param str local_repo: specifies the repo name inside the root data path.
e.g. my_repo_data/
:param str path: specifies the path inside the repository, (optional)
e.g. train
:return str: the real path of the dataset
"""
if local_root.startswith('gs://'):
return os.path.join(local_root, local_repo, path)
return tensorport.get_data_path(
dataset_name=dataset_name,
local_root=local_root,
local_repo=local_repo,
path=path)
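
# Illustrative call (the dataset name and paths are made up for the example):
#   data_dir = get_data_path(dataset_name='user_name/mnist-data',
#                            local_root='/home/username/datasets',
#                            local_repo='mnist-data',
#                            path='train')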