-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
127 lines (103 loc) · 4.03 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import matplotlib.pyplot as plt
import numpy as np
import gym
def plotLearning(x, scores, epsilons, filename, lines=None):
    """Plot epsilon and a running-average score against game index and save it.

    Two axes share one figure cell: the left axis (``ax``) draws ``epsilons``
    as a line, and the right axis (``ax2``) draws a trailing running average
    of ``scores`` as a scatter on a symmetric-log y scale.

    Args:
        x: x-coordinates (game indices) used for both series.
        scores: per-game scores. Each plotted point is the mean of a trailing
            window of up to 21 entries (the current score plus the 20 before it).
        epsilons: per-game exploration-rate values.
        filename: destination path passed to ``savefig``.
        lines: optional iterable of x positions at which vertical marker lines
            are drawn on the score axis.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, label="1")
    # A second, frameless axes on the same cell so the score series gets its
    # own y-axis on the right-hand side.
    ax2 = fig.add_subplot(111, label="2", frame_on=False)

    ax.plot(x, epsilons, color="C0")
    ax.set_xlabel("Game", color="C0")
    ax.set_ylabel("Epsilon", color="C0")
    ax.tick_params(axis='x', colors="C0")
    ax.tick_params(axis='y', colors="C0")

    N = len(scores)
    running_avg = np.empty(N)
    for t in range(N):
        running_avg[t] = np.mean(scores[max(0, t-20):(t+1)])

    # Fixed horizontal reference levels on the score axis.
    for level in (-10, -100, -300, -400, -500):
        ax2.axhline(y=level, color="C3")
    ax2.scatter(x, running_avg, color="C1", s=10)
    ax2.xaxis.set_visible(False)
    ax2.yaxis.tick_right()
    ax2.set_ylabel('Score', color="C1")
    ax2.yaxis.set_label_position('right')
    ax2.tick_params(axis='y', colors="C1")

    # Target ax2 explicitly instead of relying on pyplot's "current axes"
    # (which is ax2 here, since it was added last).
    if lines is not None:
        for line in lines:
            ax2.axvline(x=line)
    ax2.set_yscale('symlog')

    fig.savefig(filename)
    # pyplot keeps every figure alive until it is closed; without this,
    # calling the function repeatedly (e.g. during training) leaks figures.
    plt.close(fig)
class SkipEnv(gym.Wrapper):
    """Repeat each action ``skip`` times and return the summed reward.

    NOTE(review): ``step`` unpacks a 4-tuple ``(obs, reward, done, info)``
    from the wrapped env, i.e. the older gym step API — confirm against the
    installed gym version.
    """

    def __init__(self, env=None, skip=4):
        super(SkipEnv, self).__init__(env)
        if skip < 1:
            # With skip == 0 the loop in step() never runs, leaving obs/info
            # unbound; reject it here with a clear error instead.
            raise ValueError("skip must be a positive integer, got %r" % (skip,))
        self._skip = skip

    def step(self, action):
        """Apply ``action`` up to ``self._skip`` times, stopping early when done.

        Returns the last observation, the reward summed over the repeated
        steps, and the last ``done`` flag and ``info`` from the wrapped env.
        """
        t_reward = 0.0
        done = False
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            t_reward += reward
            if done:
                break
        return obs, t_reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env (forwarding keyword arguments) and return its observation."""
        # Holds only the observation from the latest reset; nothing in this
        # class reads it afterwards.
        self._obs_buffer = []
        obs = self.env.reset(**kwargs)
        self._obs_buffer.append(obs)
        return obs
class PreProcessFrame(gym.ObservationWrapper):
    """Convert frames to cropped, downsampled (80, 80, 1) uint8 grayscale images."""

    def __init__(self, env=None):
        super().__init__(env)
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(80, 80, 1),
            dtype=np.uint8,
        )

    def observation(self, obs):
        return PreProcessFrame.process(obs)

    @staticmethod
    def process(frame):
        """Return ``frame`` as an (80, 80, 1) uint8 grayscale image.

        Assumes ``frame`` is (H, W, 3) with H >= 195 and W == 160, so the crop
        below yields exactly 80x80 pixels — TODO confirm for non-Atari envs.
        """
        pixels = np.asarray(frame, dtype=np.float32)
        red = pixels[:, :, 0]
        green = pixels[:, :, 1]
        blue = pixels[:, :, 2]
        # ITU-R BT.601 luma weights.
        gray = 0.299 * red + 0.587 * green + 0.114 * blue
        # Keep rows 35..194 and every other row/column, then add a channel axis.
        downsampled = gray[35:195:2, ::2]
        with_channel = downsampled.reshape(80, 80, 1)
        return with_channel.astype(np.uint8)
class MoveImgChannel(gym.ObservationWrapper):
    """Move the last axis of each observation to the front (HWC -> CHW)."""

    def __init__(self, env):
        super().__init__(env)
        src_shape = self.observation_space.shape
        chw_shape = (src_shape[-1], src_shape[0], src_shape[1])
        # NOTE(review): this declares a float32 [0, 1] space, but observation()
        # only transposes the data — it neither rescales nor casts it.
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=chw_shape,
            dtype=np.float32,
        )

    def observation(self, observation):
        """Return ``observation`` with axis 2 moved to position 0."""
        return np.moveaxis(observation, source=2, destination=0)
class ScaleFrame(gym.ObservationWrapper):
    """Return observations as float32 arrays divided by 255.

    NOTE(review): this class does not set ``observation_space``; the wrapped
    env's space is exposed unchanged.
    """

    def observation(self, obs):
        as_float = np.asarray(obs, dtype=np.float32)
        return as_float / 255.0
class BufferWrapper(gym.ObservationWrapper):
    """Stack the most recent ``n_steps`` observations along axis 0.

    The stacked space repeats the wrapped env's ``low``/``high`` bounds
    ``n_steps`` times along axis 0. Each new observation is written into the
    last slot after older ones are shifted toward the front; slots that have
    not yet been filled since the last reset are zeros.
    """

    def __init__(self, env, n_steps):
        super(BufferWrapper, self).__init__(env)
        self.observation_space = gym.spaces.Box(
            env.observation_space.low.repeat(n_steps, axis=0),
            env.observation_space.high.repeat(n_steps, axis=0),
            dtype=np.float32)
        # Number of axis-0 slices occupied by one observation of the wrapped env.
        self._frame_depth = env.observation_space.shape[0]
        # Allocate up front so observation() also works before the first reset().
        self.buffer = np.zeros_like(self.observation_space.low, dtype=np.float32)

    def reset(self, **kwargs):
        """Clear the history, reset the wrapped env (forwarding kwargs), return the stack."""
        self.buffer = np.zeros_like(self.observation_space.low, dtype=np.float32)
        return self.observation(self.env.reset(**kwargs))

    def observation(self, observation):
        """Push ``observation`` into the history and return a copy of the stack."""
        depth = self._frame_depth
        # Shift by one whole observation so multi-channel inputs stay aligned;
        # for single-channel inputs (depth == 1) this is a one-slice shift.
        self.buffer[:-depth] = self.buffer[depth:]
        self.buffer[-depth:] = observation
        # The buffer is mutated in place on every call; returning it directly
        # would make previously returned observations change underneath callers.
        return self.buffer.copy()
def make_env(env_name, skip=4, n_steps=4):
    """Build the wrapped environment pipeline for ``env_name``.

    Wrappers are applied innermost-first: action repeat (SkipEnv),
    80x80 grayscale preprocessing (PreProcessFrame), channel-first transpose
    (MoveImgChannel), stacking of recent observations (BufferWrapper), and
    division by 255 (ScaleFrame).

    Args:
        env_name: environment id passed to ``gym.make``.
        skip: number of times each action is repeated (default 4).
        n_steps: number of recent observations stacked together (default 4).

    Returns:
        The outermost wrapper (a ScaleFrame instance).
    """
    env = gym.make(env_name)
    env = SkipEnv(env, skip=skip)
    env = PreProcessFrame(env)
    env = MoveImgChannel(env)
    env = BufferWrapper(env, n_steps)
    return ScaleFrame(env)