-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplay.py
126 lines (104 loc) · 3.51 KB
/
play.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from __future__ import print_function
import argparse
from pyglet.window import key
import gym
import numpy as np
import pickle
import os
from datetime import datetime
import gzip
import json
import copy
def key_press(k, mod):
    """Update the shared action vector ``a`` when a key goes down.

    Slots of ``a`` written here: ``a[0]`` is set by LEFT/RIGHT (-1.0 / +1.0),
    ``a[1]`` is the gas, and ``a[3]`` is a flag recording that UP is held.
    ESCAPE sets the module-level ``restart`` flag.

    Args:
        k: Key symbol delivered by the window (compared against ``key.*``).
        mod: Modifier bitmask from the window event; unused.
    """
    global restart

    if k == key.ESCAPE:
        restart = True

    if k == key.UP:
        # Remember that UP is held so key_release can restore the gas later.
        a[3] = +1.0
        # Only accelerate when the wheel is centred.
        if a[0] == 0.0:
            a[1] = +1.0

    for steer_key, direction in ((key.LEFT, -1.0), (key.RIGHT, +1.0)):
        if k == steer_key:
            a[0] = direction
            a[1] = 0.0  # Cut gas while turning

    # Braking via key.DOWN is currently disabled: a[2] is never written here.
def key_release(k, mod):
    """Update the shared action vector ``a`` when a key is released.

    Releasing LEFT/RIGHT recentres ``a[0]`` only if that direction is still
    the active one, and re-applies the gas (``a[1]``) when the UP flag
    (``a[3]``) is still set. Releasing UP clears both the gas and the flag.

    Args:
        k: Key symbol delivered by the window (compared against ``key.*``).
        mod: Modifier bitmask from the window event; unused.
    """
    for steer_key, direction in ((key.LEFT, -1.0), (key.RIGHT, +1.0)):
        # Ignore the release if the opposite direction has taken over since.
        if k == steer_key and a[0] == direction:
            a[0] = 0.0
            if a[3] == 1.0:
                a[1] = 1.0

    if k == key.UP:
        a[1] = 0.0
        a[3] = 0.0

    # Braking via key.DOWN is currently disabled: a[2] is never reset here.
def store_data(data, datasets_dir="./data"):
    """Pickle ``data`` into ``<datasets_dir>/data.pkl.gzip`` (gzip-compressed).

    The target directory, including any missing parent directories, is
    created when it does not exist yet. An existing file of the same name is
    overwritten.

    Args:
        data: Any picklable object.
        datasets_dir: Directory that receives ``data.pkl.gzip``.

    Raises:
        OSError: If the directory cannot be created or the file cannot be
            written.
    """
    # makedirs (rather than mkdir) so nested target paths work as well.
    if not os.path.isdir(datasets_dir):
        os.makedirs(datasets_dir)

    data_file = os.path.join(datasets_dir, 'data.pkl.gzip')
    # The context manager guarantees the gzip trailer is flushed and the
    # handle is closed even when pickling raises.
    with gzip.open(data_file, 'wb') as f:
        pickle.dump(data, f)
if __name__ == "__main__":
    # Interactive session: drive CarRacing-v0 with the arrow keys and record
    # every completed episode's transitions. ESCAPE restarts the current
    # episode and discards its samples.
    parser = argparse.ArgumentParser()
    parser.add_argument("--collect_data", action="store_true", default=True, help="Collect the data in a pickle file.")
    args = parser.parse_args()

    # One list entry per kept episode; each entry is that episode's list of
    # per-step values.
    good_samples = {
        "state": [],
        "next_state": [],
        "reward": [],
        "action": [],
        "terminal": [],
    }

    # Shared with key_press/key_release: slots 0..2 are sent to the
    # environment, slot 3 is the handlers' private "UP is held" flag.
    a = np.zeros(4, dtype=np.float32)

    episode_rewards = []
    good_steps = episode_steps = 0

    env = gym.make('CarRacing-v0').unwrapped
    try:
        env.reset()
        # CarRacing builds its viewer lazily inside render(); without this
        # call env.viewer is still None and the two lines below would fail.
        env.render()
        env.viewer.window.on_key_press = key_press
        env.viewer.window.on_key_release = key_release

        # Episode loop (runs until the process is interrupted)
        while True:
            episode_samples = {name: [] for name in good_samples}
            episode_reward = 0
            state = env.reset()
            restart = False
            # Resume the step counter from the last kept episode so that
            # steps of discarded episodes are not counted.
            episode_steps = good_steps

            # State loop
            while True:
                next_state, r, done, _ = env.step(a[:3])
                episode_reward += r

                # state has shape (96, 96, 3) according to the original note
                episode_samples["state"].append(state)
                # copy of the first three slots, shape (3,)
                episode_samples["action"].append(np.array(a[:3]))
                episode_samples["next_state"].append(next_state)
                episode_samples["reward"].append(r)
                episode_samples["terminal"].append(done)

                state = next_state
                episode_steps += 1

                if episode_steps % 1000 == 0 or done:
                    print("\nstep {}".format(episode_steps))

                env.render()
                if done or restart:
                    break

            if restart:
                # Episode aborted with ESCAPE: drop its samples.
                continue

            good_steps = episode_steps
            episode_rewards.append(episode_reward)
            for name, values in episode_samples.items():
                good_samples[name].append(values)

            # The flag defaults to True, so data is saved unless the default
            # is changed in the parser above.
            if args.collect_data:
                print('... saving data')
                store_data(good_samples, "./data")
    finally:
        # Reached on Ctrl-C or any error; releases the window and simulator.
        env.close()