-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy pathsimple_slurm.py
59 lines (47 loc) · 1.88 KB
/
simple_slurm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import argparse
import os
import subprocess
import submitit
from powerful_benchmarker.utils.constants import add_default_args
from powerful_benchmarker.utils.utils import create_slurm_args
from validator_tests.utils import utils
from validator_tests.utils.constants import add_exp_group_args, exp_group_args
def exp_launcher(conda_env, command):
    """Run ``command`` inside ``conda_env`` through the repo's wrapper script.

    This is the callable that ``run`` submits via ``submitit``. It executes
    ``bash -i ./scripts/script_wrapper.sh <conda_env> <command>``. The ``-i``
    (interactive) flag presumably makes bash source the user's shell init,
    for example conda hooks. TODO confirm against script_wrapper.sh.

    Args:
        conda_env: Conda environment passed as the wrapper's first argument.
            Converted with ``str()``, as the original f-string did.
        command: The full command line, passed to the wrapper as a single
            argument.

    Note:
        The wrapper's exit status is not checked, so a failing command does
        not raise here.
    """
    # Build argv directly rather than formatting a string and splitting it
    # on " ": splitting broke a conda env name/path containing a space into
    # several separate arguments.
    full_command = [
        "bash",
        "-i",
        "./scripts/script_wrapper.sh",
        str(conda_env),
        command,
    ]
    subprocess.run(full_command)
def run(args, slurm_args, exp_group):
    """Submit a single Slurm job that runs ``args.command`` via ``exp_launcher``.

    Args:
        args: Parsed CLI namespace. Reads ``exp_folder``, ``slurm_folder``,
            ``command`` and ``conda_env``.
        slurm_args: Extra Slurm options, forwarded unchanged to submitit as
            ``slurm_additional_parameters``.
        exp_group: Experiment group name(s), space-separated when several are
            given. Appended to the command as ``--exp_groups``. When falsy
            (``None`` or ``""``), the command is submitted unchanged.
    """
    log_folder = os.path.join(args.exp_folder, args.slurm_folder)
    executor = submitit.AutoExecutor(folder=log_folder)

    job_params = dict(
        # NOTE(review): presumably 0 means "no time limit" under Slurm — confirm
        timeout_min=0,
        tasks_per_node=1,
        slurm_additional_parameters=slurm_args,
    )
    executor.update_parameters(**job_params)

    if exp_group:
        job_command = f"{args.command} --exp_groups {exp_group}"
    else:
        job_command = args.command

    job = executor.submit(exp_launcher, args.conda_env, job_command)
    print("started", job.job_id)
def main(args, slurm_args):
    """Submit one or more Slurm jobs for the parsed CLI arguments.

    Behaviour depends on the experiment-group filters:

    - If none of the options named by ``exp_group_args()`` is set, a single
      job runs ``args.command`` unchanged.
    - Otherwise the matching groups are resolved with
      ``utils.get_exp_groups(args)``.
        - With ``args.all_in_one``, all groups go to one job as a
          space-separated ``--exp_groups`` value.
        - Otherwise one job is submitted per group.
    - If the filters are set but resolve to no groups, nothing is submitted.

    Args:
        args: Parsed CLI namespace (forwarded to ``run``).
        slurm_args: Extra Slurm options (forwarded to ``run``).
    """
    if not any(getattr(args, k) for k in exp_group_args()):
        run(args, slurm_args, None)
        return

    # Materialize once so the emptiness check below works even if the helper
    # returns a lazy iterable.
    exp_groups = list(utils.get_exp_groups(args))
    if not exp_groups:
        # Filters were given but matched nothing. Without this guard, the
        # --all_in_one branch would pass "" to run(). run() treats that as
        # falsy and submits the bare command with no --exp_groups filter,
        # while the per-group branch submits nothing. Be consistent: submit
        # nothing.
        print("no experiment groups matched; nothing submitted")
        return

    if args.all_in_one:
        run(args, slurm_args, " ".join(exp_groups))
    else:
        for e in exp_groups:
            run(args, slurm_args, e)
if __name__ == "__main__":
    # CLI entry point. Options this parser knows configure the launcher.
    # Anything it does not recognise is handed to create_slurm_args together
    # with the Slurm config folder; presumably these become extra Slurm
    # options (and --slurm_config is read there). TODO confirm.
    cli_parser = argparse.ArgumentParser(allow_abbrev=False)
    add_default_args(cli_parser, ["exp_folder", "conda_env", "slurm_folder"])
    add_exp_group_args(cli_parser)
    for required_flag in ("--command", "--slurm_config_folder", "--slurm_config"):
        cli_parser.add_argument(required_flag, type=str, required=True)
    cli_parser.add_argument("--all_in_one", action="store_true")

    cli_args, passthrough_args = cli_parser.parse_known_args()
    main(
        cli_args,
        create_slurm_args(cli_args, passthrough_args, cli_args.slurm_config_folder),
    )