-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBaseline_CartPole.py
94 lines (80 loc) · 2.58 KB
/
Baseline_CartPole.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from MDPs.cartpole import get_CartPole_MDP
from agents.MLP_agent import MLP_Agent
from algorithms.Baseline import Baseline_Algorithm
import time
from tools.tools import time_hr_min_sec
from tools.savefig import savefig_mean_std
import numpy as np
from multiprocessing import Pool, cpu_count
import torch
def func(w_hidden_sizes, theta_hidden_sizes, w_lr, theta_lr, input_mode, n, seed):
    """Run one seeded baseline job on the CartPole MDP and return its curves.

    The torch and numpy RNGs are seeded once up front. After that a fresh MDP
    and agent are built and ``Baseline_Algorithm`` is run, over and over,
    until a run hands back exactly ``n`` returns. The RNGs are not re-seeded
    between attempts.

    Args:
        w_hidden_sizes: Forwarded to ``MLP_Agent`` as ``w_hidden_sizes``.
        theta_hidden_sizes: Forwarded to ``MLP_Agent`` as
            ``theta_hidden_sizes``.
        w_lr: Forwarded to ``MLP_Agent`` as ``w_lr``.
        theta_lr: Forwarded to ``MLP_Agent`` as ``theta_lr``.
        input_mode: Forwarded to ``MLP_Agent`` as ``input_mode``.
        n: Forwarded to ``Baseline_Algorithm`` as ``n``; an attempt is only
            accepted when ``len(returns) == n``.
        seed: Seed given to both ``torch.manual_seed`` and ``np.random.seed``.

    Returns:
        The ``(lc1, lc2, returns)`` triple of the first accepted attempt.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    hyperparams = (w_hidden_sizes, theta_hidden_sizes, w_lr, theta_lr)
    print(*hyperparams)

    # NOTE(review): there is no cap on the number of attempts; the loop only
    # ends once an attempt produces exactly n returns.
    while True:
        t0 = time.time()

        mdp = get_CartPole_MDP(n_bins=50)
        learner = MLP_Agent(
            MDP=mdp,
            state_size=4,
            w_hidden_sizes=w_hidden_sizes,
            w_lr=w_lr,
            theta_hidden_sizes=theta_hidden_sizes,
            theta_lr=theta_lr,
            input_mode=input_mode,
        )
        outcome = Baseline_Algorithm(
            mdp,
            learner,
            min_sigma=1.0,
            max_sigma=2.0,
            n=n,
            print_output=0,
            n_steps_limit=200,
            error_curves=False,
        )
        lc1, lc2, returns = outcome

        # Wall-clock seconds of this attempt, printed without a newline so
        # successive attempts line up on one row.
        elapsed = time.time() - t0
        print('%.2f' % elapsed, end=' ')

        if len(returns) != n:
            # Incomplete attempt: discard it and try again.
            continue

        print(*hyperparams)
        return lc1, lc2, returns
if __name__ == '__main__':
    # ---- Experiment configuration -------------------------------------
    n_runs = 20
    w_hidden_sizes = [50, 50]
    theta_hidden_sizes = [50, 50, 50]
    w_lr = 1e-7
    theta_lr = 3e-9
    input_mode = 'default'
    n = 100000

    # Uncomment for a quicker, smaller experiment:
    # n_runs = 4
    # n = 10 ** 4

    # One argument tuple per run; the run index doubles as the RNG seed.
    shared_args = (w_hidden_sizes, theta_hidden_sizes, w_lr, theta_lr, input_mode, n)
    parameters = [shared_args + (seed,) for seed in range(n_runs)]

    # Fan the runs out over at most 20 worker processes.
    n_workers = min(20, n_runs)
    with Pool(processes=n_workers) as pool:
        res = pool.starmap(func, parameters)

    save_dir = 'results'
    save_name = '1211_baseline_cartpole'

    # Each spec: (index into func's result triple, per-run transform,
    # file-name suffix, x label, y label). The transform is applied lazily,
    # right before the corresponding figure is saved.
    figure_specs = (
        (0, lambda curve: curve,
         "_lc1", 'Number of actions', 'Number of episodes'),
        (1, lambda curve: np.array(curve),
         '_lc2', 'Number of episodes', 'Mean squared error'),
        # Returns are negated before being plotted as a step count.
        (2, lambda curve: -np.array(curve),
         '_lc3', 'Number of episodes', 'number of steps'),
    )
    for position, transform, suffix, x_label, y_label in figure_specs:
        curves = [transform(run[position]) for run in res]
        savefig_mean_std(
            curves,
            eval_interval=1,
            save_dir=save_dir,
            save_name=save_name + suffix,
            xlabel=x_label,
            ylabel=y_label,
        )