-
Notifications
You must be signed in to change notification settings - Fork 3
/
train.py
35 lines (21 loc) · 1.38 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import json
from argparse import ArgumentParser
import os
from src.run.run import run
def train_main(episodes : int, agn_par_file: str, env_par_file: str, mode : str, checkpoint_file: "str | None", tag : "str | None" ):
    """Load the agent and environment parameter files, then start a run.

    Args:
        episodes: Number of episodes; forwarded unchanged to ``run``.
        agn_par_file: File name of the agent parameter JSON. It is looked up
            in ``parameters/agn_train_par/`` under the directory containing
            this script. An absolute path is used as-is, because
            ``os.path.join`` discards the preceding components.
        env_par_file: File name of the environment parameter JSON. It is
            looked up in ``parameters/env_train_par/`` under the same
            directory.
        mode: Mode string forwarded to ``run``. The CLI entry point passes
            ``'train'``.
        checkpoint_file: Forwarded to ``run``. The CLI help describes it as a
            checkpoint file to load and defaults it to ``None``.
        tag: Forwarded to ``run``. The CLI help describes it as a wandb run
            tag and defaults it to ``None``.

    Raises:
        FileNotFoundError: If either parameter file does not exist.
        json.JSONDecodeError: If either parameter file is not valid JSON.
    """
    # Anchor every path at the script's own directory instead of the current
    # working directory, so the script behaves the same wherever it is launched.
    project_root = os.path.dirname(os.path.abspath(__file__))
    agn_par_path = os.path.join(project_root, 'parameters', 'agn_train_par', agn_par_file)
    env_par_path = os.path.join(project_root, 'parameters', 'env_train_par', env_par_file)

    # JSON is UTF-8 by specification; don't depend on the platform's default encoding.
    with open(agn_par_path, encoding='utf-8') as agn_json_file, \
            open(env_par_path, encoding='utf-8') as env_json_file:
        par_agent = json.load(agn_json_file)
        par_environment = json.load(env_json_file)

    # Run outside the `with` so the parameter files are closed before the
    # (potentially long) run starts.
    # NOTE(review): the positional False is passed through unchanged; its
    # meaning is defined by src.run.run.run — confirm there.
    run(episodes, par_agent, par_environment, mode, checkpoint_file, False, project_root, tag)
if __name__ == "__main__":
    # Command-line entry point: parse the options, then start a run in 'train' mode.
    cli = ArgumentParser()
    cli.add_argument("-e", "--episodes", dest="episodes", type=int, default=500,
                     help="Number of episodes to run")
    cli.add_argument("-pa", "--par_agn", dest="par_agn", default='agn_base.json',
                     help="Agent parameter file")
    cli.add_argument("-pe", "--par_env", dest="par_env", default='env_base.json',
                     help="Env parameter file")
    cli.add_argument("-c", "--checkpoint", dest="checkpoint", default=None,
                     help="Checkpoint file to be loaded")
    cli.add_argument("-t", "--tag", dest="tag", default=None,
                     help="Useful to tag a run in wandb")
    opts = cli.parse_args()

    train_main(
        episodes=opts.episodes,
        agn_par_file=opts.par_agn,
        env_par_file=opts.par_env,
        mode='train',
        checkpoint_file=opts.checkpoint,
        tag=opts.tag,
    )