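"""gridworld.py

A simple grid-world MDP environment with plotting helpers. Cells with a
positive value are goals, cells with -1 are walls/obstacles, and cells with 0
are non-terminal states. The grid is padded with a wall of -1 cells so that
moves off the board are handled uniformly.
"""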
import matplotlib.pyplot as plt
import numpy as np


class GridWorld:
    def __init__(self, reward_wall=-5):
        # initialize the grid as a 2D numpy array
        #   >0: goal
        #   -1: wall/obstacle
        #    0: non-terminal
        self._grid = np.array(
            [[0, 0, 0, 0, 0, -1, 0, 0],
             [0, 0, 0, -1, 0, 0, 0, 5],
             [0, 0, 0, -1, -1, 0, 0, 0],
             [0, 0, 0, -1, -1, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0]])
        # wall around the grid: pad the grid with -1
        self._grid_padded = np.pad(self._grid, pad_width=1, mode='constant', constant_values=-1)
        self._reward_wall = reward_wall

        # set the start state
        self._start_state = (1, 1)
        self._random_start = False

        # store the positions of the goal states and the non-terminal states
        idx_goal_state_y, idx_goal_state_x = np.nonzero(self._grid > 0)
        self._goal_states = [(idx_goal_state_y[i], idx_goal_state_x[i]) for i in range(len(idx_goal_state_x))]

        idx_non_term_y, idx_non_term_x = np.nonzero(self._grid == 0)
        self._non_term_states = [(idx_non_term_y[i], idx_non_term_x[i]) for i in range(len(idx_non_term_x))]

        # store the current state in padded-grid coordinates
        self._state_padded = (self._start_state[0] + 1, self._start_state[1] + 1)

    def get_state_num(self):
        # get the number of states (total_state_number) in the grid; note that
        # walls/obstacles inside the grid are counted as states as well
        return np.prod(np.shape(self._grid))

    def get_state_grid(self):
        # build a grid of integer state indices (row-major, from 0 to
        # total_state_number - 1); wall cells are marked -1
        state_grid = np.multiply(np.reshape(np.arange(self.get_state_num()), self._grid.shape),
                                 self._grid >= 0) - (self._grid == -1)
        return state_grid, np.pad(state_grid, pad_width=1, mode='constant', constant_values=-1)

    def get_current_state(self):
        # get the current state as an integer from 0 to total_state_number - 1
        y, x = self._state_padded
        return (y - 1) * self._grid.shape[1] + (x - 1)

    def int_to_state(self, int_obs):
        # convert an integer from 0 to total_state_number - 1 to a position on the non-padded grid
        x = int_obs % self._grid.shape[1]
        y = int_obs // self._grid.shape[1]
        return y, x
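
    # Index-mapping example: with the 7x8 grid above, the start state (1, 1)
    # corresponds to state index 1 * 8 + 1 = 9, and int_to_state(9) == (1, 1).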

    def reset(self):
        # reset the gridworld to a start state
        if self._random_start:
            # randomly start at a non-terminal state
            idx_start = np.random.randint(len(self._non_term_states))
            start_state = self._non_term_states[idx_start]
            self._state_padded = (start_state[0] + 1, start_state[1] + 1)
        else:
            # start at the designated start state
            self._state_padded = (self._start_state[0] + 1, self._start_state[1] + 1)

    def step(self, action):
        # take one step according to the action
        # input:  action      integer between 0 and 3
        # output: reward      reward of this action
        #         terminated  1 if the terminal state is reached, 0 otherwise
        #         next_state  next state after this action, integer from 0 to total_state_number - 1
        y, x = self._state_padded
        if action == 0:  # up
            new_state_padded = (y - 1, x)
        elif action == 1:  # right
            new_state_padded = (y, x + 1)
        elif action == 2:  # down
            new_state_padded = (y + 1, x)
        elif action == 3:  # left
            new_state_padded = (y, x - 1)
        else:
            raise ValueError("Invalid action: {} is not 0, 1, 2, or 3.".format(action))

        new_y, new_x = new_state_padded
        if self._grid_padded[new_y, new_x] == -1:  # wall/obstacle: penalize and stay in place
            reward = self._reward_wall
            new_state_padded = (y, x)
        elif self._grid_padded[new_y, new_x] == 0:  # non-terminal cell
            reward = 0.
        else:  # a goal: collect the reward and reset to the start state
            reward = self._grid_padded[new_y, new_x]
            self.reset()
            terminated = 1
            return reward, terminated, self.get_current_state()

        terminated = 0
        self._state_padded = new_state_padded
        return reward, terminated, self.get_current_state()
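
    # Example: from the start state (index 9), step(1) moves right to (1, 2),
    # i.e. state index 10, with reward 0; stepping into a wall returns
    # reward_wall and leaves the agent where it was.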

    def plot_grid(self, plot_title=None):
        # plot the grid, marking the start state with S and each goal with its reward
        plt.figure(figsize=(5, 5))
        plt.imshow(self._grid_padded <= -1, cmap='binary', interpolation="nearest")
        ax = plt.gca()
        ax.grid(0)
        plt.xticks([])
        plt.yticks([])

        if plot_title:
            plt.title(plot_title)

        plt.text(
            self._start_state[1] + 1, self._start_state[0] + 1,
            r"$\mathbf{S}$", ha='center', va='center')
        for goal_state in self._goal_states:
            plt.text(
                goal_state[1] + 1, goal_state[0] + 1,
                "{:d}".format(self._grid[goal_state[0], goal_state[1]]), ha='center', va='center')

        h, w = self._grid_padded.shape
        for y in range(h - 1):
            plt.plot([-0.5, w - 0.5], [y + 0.5, y + 0.5], '-k', lw=2)
        for x in range(w - 1):
            plt.plot([x + 0.5, x + 0.5], [-0.5, h - 0.5], '-k', lw=2)

    def plot_state_values(self, state_values, value_format="{:.1f}", plot_title=None):
        # plot the state values
        # input: state_values  (total_state_number,)-numpy array, state value function
        #        value_format  str, format string for printing each value
        #        plot_title    str, title of the plot
        plt.figure(figsize=(5, 5))
        plt.imshow((self._grid_padded <= -1) + (self._grid_padded > 0) * 0.5, cmap='Greys', vmin=0, vmax=1)
        ax = plt.gca()
        ax.grid(0)
        plt.xticks([])
        plt.yticks([])

        if plot_title:
            plt.title(plot_title)

        for (int_obs, state_value) in enumerate(state_values):
            y, x = self.int_to_state(int_obs)
            if (y, x) in self._non_term_states:
                plt.text(x + 1, y + 1, value_format.format(state_value), ha='center', va='center')

        h, w = self._grid_padded.shape
        for y in range(h - 1):
            plt.plot([-0.5, w - 0.5], [y + 0.5, y + 0.5], '-k', lw=2)
        for x in range(w - 1):
            plt.plot([x + 0.5, x + 0.5], [-0.5, h - 0.5], '-k', lw=2)

    def plot_policy(self, policy, plot_title=None):
        # plot a deterministic policy
        # input: policy      (total_state_number,)-numpy array, contains actions as integers from 0 to 3
        #        plot_title  str, title of the plot
        action_names = [r"$\uparrow$", r"$\rightarrow$", r"$\downarrow$", r"$\leftarrow$"]
        plt.figure(figsize=(5, 5))
        plt.imshow((self._grid_padded <= -1) + (self._grid_padded > 0) * 0.5, cmap='Greys', vmin=0, vmax=1)
        ax = plt.gca()
        ax.grid(0)
        plt.xticks([])
        plt.yticks([])

        if plot_title:
            plt.title(plot_title)

        for (int_obs, action) in enumerate(policy):
            y, x = self.int_to_state(int_obs)
            if (y, x) in self._non_term_states:
                action_arrow = action_names[action]
                plt.text(x + 1, y + 1, action_arrow, ha='center', va='center')

    def transition(self, action):
        # compute the one-step model of the given action over all states
        # input:  action       integer between 0 and 3
        # output: reward       (total_state_number,)-numpy array, reward for taking the action in each state
        #         probability  (total_state_number, total_state_number)-numpy array, transition matrix of the action
        if action == 0:  # up
            anchor_state_padded = (0, 1)
        elif action == 1:  # right
            anchor_state_padded = (1, 2)
        elif action == 2:  # down
            anchor_state_padded = (2, 1)
        elif action == 3:  # left
            anchor_state_padded = (1, 0)
        else:
            raise ValueError("Invalid action: {} is not 0, 1, 2, or 3.".format(action))

        state_num = self.get_state_num()
        h, w = self._grid.shape
        y_a, x_a = anchor_state_padded

        # slicing the padded grid at the anchor offset gives, for every cell,
        # the value of the cell this action moves into
        next_grid = self._grid_padded[y_a:y_a + h, x_a:x_a + w]

        # reward collected from each non-terminal state: the goal value when the
        # move reaches a goal, reward_wall when it hits a wall (matching step),
        # and 0 otherwise
        reward = np.multiply(next_grid, self._grid == 0).astype(float)
        reward[np.logical_and(self._grid == 0, next_grid == -1)] = self._reward_wall

        state_grid, state_grid_padded = self.get_state_grid()
        next_state = state_grid_padded[y_a:y_a + h, x_a:x_a + w]
        # bumping into a wall keeps the agent in place
        next_state = np.multiply(state_grid, next_state == -1) + np.multiply(next_state, next_state > -1)
        # walls have no outgoing transitions; goal states are modeled as absorbing
        next_state[self._grid == -1] = -1
        next_state[self._grid > 0] = state_grid[self._grid > 0]

        next_state_vec = next_state.flatten()
        state_vec = state_grid.flatten()
        probability = np.zeros((state_num, state_num))
        probability[state_vec[state_vec > -1], next_state_vec[state_vec > -1]] = 1
        return reward.flatten(), probability
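

# A minimal usage sketch, not part of the original file: it runs a short
# random-walk episode and then does value iteration on the tabular model from
# transition(). The discount factor and iteration count below are illustrative
# choices, not values taken from the source.
if __name__ == "__main__":
    env = GridWorld()
    env.reset()

    # roll out a few random actions through the environment
    for _ in range(10):
        action = np.random.randint(4)
        reward, terminated, next_state = env.step(action)
        print("action={} reward={} terminated={} next_state={}".format(
            action, reward, terminated, next_state))

    # value iteration: q stacks the Bellman backup for each of the 4 actions
    gamma = 0.9  # illustrative discount factor
    rewards, probs = zip(*(env.transition(a) for a in range(4)))
    values = np.zeros(env.get_state_num())
    for _ in range(100):
        q = np.stack([rewards[a] + gamma * probs[a].dot(values) for a in range(4)])
        values = q.max(axis=0)
    policy = q.argmax(axis=0)

    env.plot_grid(plot_title="gridworld")
    env.plot_state_values(values, plot_title="state values")
    env.plot_policy(policy, plot_title="greedy policy")
    plt.show()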