-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathmbexp.py
64 lines (46 loc) · 1.96 KB
/
mbexp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import pprint
from dotmap import DotMap
from MBExperiment import MBExperiment
from MPC import MPC
from config import create_config
import env # We run this so that the env is registered
import torch
import numpy as np
import random
import tensorflow as tf
def set_global_seeds(seed):
    """Seed every random number generator the experiment relies on.

    Covers Python's ``random`` module, NumPy's global generator, PyTorch
    (CPU, plus all CUDA devices when CUDA is available) and TensorFlow's
    graph-level seed.

    Args:
        seed: Seed value applied to each of the generators above.
    """
    # Standard-library and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: CPU generator first, then every visible GPU when CUDA exists.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # TensorFlow graph-level seed (TF1-style top-level API).
    tf.set_random_seed(seed)
def main(env, ctrl_type, ctrl_args, overrides, logdir, seed=0):
    """Build the experiment configuration and run a model-based experiment.

    Args:
        env: Environment name, forwarded to ``create_config``. Inside this
            function it shadows the module-level ``env`` import, which is
            only imported for its environment-registration side effect.
        ctrl_type: Controller type. Only ``"MPC"`` is supported.
        ctrl_args: Iterable of ``(key, value)`` pairs (as collected by the
            ``-ca`` CLI flag); converted to a ``DotMap`` before use.
        overrides: Override pairs (as collected by the ``-o`` CLI flag),
            forwarded unchanged to ``create_config``.
        logdir: Root logging directory, forwarded to ``create_config``.
        seed: Seed passed to ``set_global_seeds`` before the configuration
            is built. Defaults to 0, the previously hard-coded value.

    Raises:
        ValueError: If ``ctrl_type`` is not ``"MPC"``.
    """
    # Validate up front with a real exception: an ``assert`` is stripped under
    # ``python -O``, and checking first avoids building and printing a config
    # for a controller that can never be run.
    if ctrl_type != "MPC":
        raise ValueError(
            "Unsupported controller type %r; only 'MPC' is supported." % (ctrl_type,)
        )

    set_global_seeds(seed)

    ctrl_args = DotMap(**{key: val for (key, val) in ctrl_args})
    cfg = create_config(env, ctrl_type, ctrl_args, overrides, logdir)
    cfg.pprint()

    # The policy slot lives on the nested ``exp_cfg.exp_cfg`` node of the config.
    cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg)
    exp = MBExperiment(cfg.exp_cfg)

    # os.makedirs without exist_ok raises if exp.logdir already exists, so an
    # earlier run's config.txt is never silently overwritten.
    os.makedirs(exp.logdir)
    with open(os.path.join(exp.logdir, "config.txt"), "w") as f:
        f.write(pprint.pformat(cfg.toDict()))

    exp.run_experiment()
if __name__ == "__main__":
    # Command-line entry point: parse the options, then launch an MPC run.
    cli = argparse.ArgumentParser()

    cli.add_argument(
        '-env',
        type=str,
        required=True,
        help='Environment name: select from [cartpole, reacher, pusher, halfcheetah]',
    )
    # -ca and -o each take exactly two tokens (key, value) and may be repeated;
    # action='append' gathers them into a list of [key, value] pairs.
    cli.add_argument(
        '-ca', '--ctrl_arg',
        action='append',
        nargs=2,
        default=[],
        help='Controller arguments, see https://github.com/kchua/handful-of-trials#controller-arguments',
    )
    cli.add_argument(
        '-o', '--override',
        action='append',
        nargs=2,
        default=[],
        help='Override default parameters, see https://github.com/kchua/handful-of-trials#overrides',
    )
    cli.add_argument(
        '-logdir',
        type=str,
        default='log',
        help='Directory to which results will be logged (default: ./log)',
    )

    opts = cli.parse_args()
    main(
        env=opts.env,
        ctrl_type="MPC",
        ctrl_args=opts.ctrl_arg,
        overrides=opts.override,
        logdir=opts.logdir,
    )