forked from tjmccue00/Thesis-Code
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CartPoleTrain.py
executable file
·81 lines (67 loc) · 2.39 KB
/
CartPoleTrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from Environments.CartPoleEnv import CartPole
from RL.QLearning import Agent
import time
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
# Tabular Q-learning training run on the custom CartPole environment.
#
# Every episode's score is recorded in data['score']. Every `report_interval`
# episodes the script:
#   * prints the average and maximum score over the last `report_interval`
#     episodes,
#   * saves the Q-table when that window average beats the best seen so far,
#   * appends "episode,window_average" to a CSV under `agent.chkpt_dir`,
#   * appends the window average to data['avg'].
# `ep` and data['avg'] are kept the same length (one entry per report plus a
# leading 0) so they can be plotted against each other.

# NOTE(review): the meaning of the positional Agent arguments is defined in
# RL.QLearning and is not visible here -- confirm against Agent.__init__.
agent = Agent(3, 4, [-200, 200], [[-.2095, .2095], [-10, 10], [-0.62, 0.62], [-10, 10]], 0.2, 0.995, 30)
agent.initialize_table()
cp = CartPole()

report_interval = 100   # episodes per reporting / averaging window
num_episodes = 150000   # total training episodes
max_score = 3600        # an episode is cut off once its score exceeds this
epsilon = 0.2           # initial epsilon-greedy exploration rate
# Linear schedule: epsilon decays towards 0 over `num_episodes` episodes.
epsilon_decay = epsilon / num_episodes

# x-axis for the window averages: 0, report_interval, 2 * report_interval, ...
ep = list(range(0, num_episodes + 1, report_interval))
# 'score': every episode's score. 'avg': one window average per report,
# seeded with 0 so that it lines up with ep[0] == 0.
data = {'score': [], 'avg': [0]}
best_score = 0
# Filesystem-safe timestamp for the CSV name; str(datetime.now()) contains
# ':' characters, which are not allowed in Windows file names.
date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

for episode in range(1, num_episodes + 1):
    current_state = agent.discretize(cp.reset())  # initial observation
    score = 0
    done = False
    while not done and score <= max_score:
        # Epsilon-greedy action selection.
        if np.random.uniform(0, 1) < epsilon:
            action, action_idx = agent.get_sample_action()
        else:
            action, action_idx = agent.get_action(current_state)
        next_state, reward, done = cp.step(int(action))
        next_state = agent.discretize(next_state)
        score += reward
        # NOTE(review): the transition that ends the episode (done=True) is
        # never passed to update_table, so the outcome of failing is not
        # learned from it directly; a terminal update may be intended --
        # confirm against Agent.update_table.
        if not done:
            agent.update_table(next_state, current_state, action_idx, reward)
            current_state = next_state

    # Record every episode so the windowed statistics below really cover the
    # last `report_interval` episodes.
    data['score'].append(score)
    epsilon -= epsilon_decay

    if episode % report_interval == 0:
        window = data['score'][-report_interval:]
        window_mean = float(np.mean(window))
        window_avg = round(window_mean, 2)
        if window_avg > best_score:
            best_score = window_avg
            agent.save_qtable("Test4")
        print('Episode : {} | Reward -> {} | Max reward : {} |'.format(episode, window_avg, max(window)))
        data['avg'].append(window_mean)
        # NOTE(review): plain string concatenation assumes chkpt_dir already
        # ends with a path separator (or is meant as a filename prefix).
        with open(agent.chkpt_dir + "QL_data_" + date + ".csv", "a") as f:
            f.write(f"{episode},{window_avg}\n")
# Plot the per-window average score recorded during training.
# Release the environment first: plt.show() blocks (with the default,
# non-interactive backend) until the figure window is closed, and the
# environment would otherwise stay open for that whole time.
cp.close()
plt.plot(ep, data['avg'], label='Avg')
plt.title('Average Reward v Episode')
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.legend()  # without this call the 'Avg' label is never displayed
plt.show()