-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrun.py
89 lines (69 loc) · 2.83 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import platform
import pprint
from utils.utils import set_seed
import os
import json
from tensorboardX import SummaryWriter
import warnings
from model.lstm import LSTM
from execute.rollout import rollout
from execute.train import trainer
from expr.tokenizer import MyTokenizer
from options import get_options
def run(opts):
    """Train or evaluate the LSTM agent according to ``opts``.

    Exactly one of ``opts.train`` / ``opts.test`` must be set, and at most one
    of ``opts.load_path`` / ``opts.resume_path``. The function then:

    * pretty-prints the run configuration;
    * optionally opens a tensorboardX ``SummaryWriter`` under
      ``<log_dir>/<dim>D/<run_name>`` (skipped when ``opts.no_tb``);
    * optionally creates ``opts.save_dir`` and writes ``args.json`` there
      (skipped when ``opts.no_saving``);
    * seeds the RNG, builds the ``LSTM`` model and a single ``trainer``
      around it, and restores a checkpoint if a load/resume path is given;
    * in test mode runs one ``rollout(..., testing=True)``, otherwise calls
      ``runner.start_training``.

    Args:
        opts: argparse-style options namespace (presumably the result of
            ``get_options()``). Mutated in place: ``is_linux`` and ``device``
            are always assigned, and ``epoch_start`` is assigned when resuming.

    Raises:
        AssertionError: if both or neither of train/test are given, or if both
            ``load_path`` and ``resume_path`` are given.
    """
    # Validate mutually exclusive options before creating any logs/files.
    assert (opts.train is None) != (opts.test is None), \
        'Between train&test, only one mode can be given in one time'
    assert opts.load_path is None or opts.resume_path is None, \
        "Only one of load path and resume can be given"

    opts.is_linux = platform.system() == 'Linux'

    # Pretty print the run args
    pprint.pprint(vars(opts))

    # Optionally configure tensorboard
    tb_logger = None
    if not opts.no_tb:
        tb_logger = SummaryWriter(
            os.path.join(opts.log_dir, "{}D".format(opts.dim), opts.run_name))

    try:
        # Save arguments so the exact configuration can always be found.
        # This must happen before opts.device is set: torch.device is not
        # JSON-serializable.
        if not opts.no_saving:
            os.makedirs(opts.save_dir, exist_ok=True)
            with open(os.path.join(opts.save_dir, "args.json"), 'w') as f:
                json.dump(vars(opts), f, indent=True)

        # Set the device; change it according to your actual situation.
        opts.device = torch.device("cuda" if opts.use_cuda else "cpu")

        # Set the random seed before initializing the network.
        set_seed(opts.seed)

        # Init agent and a single trainer, so that a checkpoint restored below
        # is the one actually used for testing / training.
        model = LSTM(opts, tokenizer=MyTokenizer())
        runner = trainer(model, opts)

        load_path = opts.load_path if opts.load_path is not None else opts.resume_path
        if load_path is not None:
            runner.load(load_path)

        if opts.test:
            # Test only.
            from env import SubprocVectorEnv, DummyVectorEnv
            if load_path is None:
                print('No load_path/resume_path given: testing an untrained model')
            # Re-seed (with set_seed's default) before building the test tasks.
            set_seed()
            runner.vector_env = SubprocVectorEnv if opts.is_linux else DummyVectorEnv
            print(f'run_name:{opts.run_name}')
            rollout(opts, runner, -1, tb_logger, MyTokenizer(), testing=True)
        else:
            if opts.resume_path:
                # The epoch is parsed as the second '-'-separated field of the
                # checkpoint file stem (e.g. a name like 'epoch-12.pt').
                stem = os.path.splitext(os.path.basename(opts.resume_path))[0]
                epoch_resume = int(stem.split("-")[1])
                print("Resuming after {}".format(epoch_resume))
                opts.epoch_start = epoch_resume + 1
            # Training (fresh, fine-tuned from load_path, or resumed).
            runner.start_training(tb_logger)
    finally:
        # Flush/close tensorboard even if training or testing raised.
        if tb_logger is not None:
            tb_logger.close()
if __name__ == '__main__':
    # Suppress every Python warning for the whole run.
    warnings.filterwarnings("ignore")
    # Limit torch's intra-op CPU parallelism to one thread.
    torch.set_num_threads(1)
    # Allow a duplicate OpenMP runtime to be loaded; presumably a workaround
    # for the Intel OpenMP "duplicate library" abort -- confirm if needed.
    os.environ.update(KMP_DUPLICATE_LIB_OK='True')
    # Prefer reproducible cuDNN kernels over autotuned (faster) ones.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    # Parse command-line options only after the global setup above.
    cli_opts = get_options()
    run(cli_opts)