import argparse
import os
import socket
from utils import has_slurm, random_port, runcmd
EXP_DIR_ENV_NAME = "VL_EXP_DIR"
# If a key appears in the hostname, the corresponding args are applied to sbatch.
DEFAULT_SLURM_ARGS = dict(login="-p gpu --mem=240GB -c 64 -t 2-00:00:00")
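# A minimal sketch of registering another cluster (the "a100" hostname substring
# and its resources are hypothetical; match them to your own cluster):
#   DEFAULT_SLURM_ARGS["a100"] = "-p a100 --mem=480GB -c 128 -t 1-00:00:00"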


def get_default_slurm_args():
    """Get the default slurm args for the current cluster.

    Returns: str. The sbatch args whose hostname key matches, or "" if none match.
    """
    hostname = socket.gethostname()
    for k, v in DEFAULT_SLURM_ARGS.items():
        if k in hostname:
            return v
    return ""


def parse_args():
    parser = argparse.ArgumentParser()
    # slurm
    parser.add_argument(
        "--slurm_args", type=str, default="", help="extra args passed to sbatch."
    )
    parser.add_argument(
        "--no_slurm",
        action="store_true",
        help="If specified, launch the job without slurm.",
    )
    parser.add_argument("--jobname", type=str, required=True, help="experiment name.")
    parser.add_argument(
        "--dep_jobname",
        type=str,
        default="impossible_jobname",
        help="the name of the job this one depends on.",
    )
    parser.add_argument("--nnodes", "-n", type=int, default=1, help="the number of nodes.")
    parser.add_argument(
        "--ngpus", "-g", type=int, default=1, help="the number of gpus per node."
    )
    # task
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        help="one of: pretrain, retrieval, retrieval_mc, vqa.",
    )
    parser.add_argument("--config", type=str, required=True, help="config file name.")
    parser.add_argument("--model_args", type=str, default="", help="args for the model.")
    args = parser.parse_args()
    return args


def get_output_dir(args):
    """Get the output_dir for this job."""
    return os.path.join(os.environ[EXP_DIR_ENV_NAME], args.jobname)
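# For example, with VL_EXP_DIR=/path/to/exps (a hypothetical value) and
# --jobname my_exp, the output_dir is /path/to/exps/my_exp.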


def prepare(args: argparse.Namespace):
    """Prepare for job submission.

    Args:
        args (argparse.Namespace): The command-line arguments.

    Returns: The path to the copied source code.
    """
    output_dir = get_output_dir(args)
    code_dir = os.path.join(output_dir, "code")
    project_dirname = os.path.basename(os.getcwd())
    # refuse to overwrite an existing experiment when submitting via slurm;
    # allow reuse when running locally.
    if os.path.isdir(output_dir):
        if has_slurm() and not args.no_slurm:
            raise ValueError(f"output_dir {output_dir} already exists. Exit.")
    else:
        os.mkdir(output_dir)
    # snapshot the project source into the experiment directory.
    cmd = f"cd ..; rsync -ar {project_dirname} {code_dir} --exclude='*.out'"
    print(cmd)
    runcmd(cmd)
    return os.path.join(code_dir, project_dirname)
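# After prepare(), the experiment directory holds a snapshot of the project,
# e.g. (with a hypothetical project checkout named "my_project"):
#   $VL_EXP_DIR/<jobname>/code/my_project/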


def submit_job(args: argparse.Namespace):
    """Build the launch command and submit the job.

    Args:
        args (argparse.Namespace): The command-line args.

    Returns: None. The command is written to cmd.txt and executed.
    """
    output_dir = get_output_dir(args)
    code_dir = prepare(args)
    master_port = os.environ.get("MASTER_PORT", random_port())
    init_cmd = f" cd {code_dir}; export MASTER_PORT={master_port}; "
    if has_slurm() and not args.no_slurm:
        # prepare slurm args.
        mode = "slurm"
        default_slurm_args = get_default_slurm_args()
        launcher = (
            f" sbatch --output {output_dir}/%j.out --error {output_dir}/%j.out"
            f" {default_slurm_args}"
            f" {args.slurm_args} --job-name={args.jobname} --nodes {args.nnodes} "
            f" --ntasks {args.nnodes} "
            f" --gpus-per-node={args.ngpus} "
            # resolve the dependent job's id by name; empty if no such job is queued.
            f" --dependency=$(squeue --noheader --format %i --name {args.dep_jobname}) "
        )
    else:
        mode = "local"
        launcher = "bash "
    # build job cmd
    job_cmd = (
        f" tasks/{args.task}.py"
        f" {args.config}"
        f" output_dir {output_dir}"
        f" {args.model_args}"
    )
    cmd = (
        f" {init_cmd} {launcher} "
        f" tools/submit.sh "
        f" {mode} {args.nnodes} {args.ngpus} {job_cmd} "
    )
    # record the exact command for reproducibility, then run it.
    with open(os.path.join(output_dir, "cmd.txt"), "w") as f:
        f.write(cmd)
    print(cmd)
    runcmd(cmd)
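# In slurm mode the assembled command looks roughly like (values illustrative):
#   cd <code_dir>; export MASTER_PORT=<port>; \
#   sbatch <slurm args> tools/submit.sh slurm <nnodes> <ngpus> \
#   tasks/<task>.py <config> output_dir <output_dir> <model_args>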


if __name__ == "__main__":
    args = parse_args()
    submit_job(args)
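
# Example invocations (config paths and model args are hypothetical):
#   python run.py --task pretrain --config configs/pretrain.yaml --jobname pt_base -n 2 -g 8
#   python run.py --task retrieval --config configs/ret.yaml --jobname ret_base --no_slurm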