-
Notifications
You must be signed in to change notification settings - Fork 0
/
minatar_env.py
77 lines (61 loc) · 3.16 KB
/
minatar_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import numpy as np
from minatar import Environment
from minatar.gui import GUI
import gym
from gym.spaces import Discrete, Box
class MinAtarEnv(gym.Env):
    """Gym wrapper around a MinAtar game with optional "broken pixel" corruption.

    Observations are the game state converted to float32 with its last axis
    moved to the front, matching the declared space shape
    ``[n_channels, 10, 10]``.  When broken pixels are active, masked cells are
    overwritten with the value 2 while the remaining cells keep their game
    value (0 or 1 according to the original author's note).

    Two corruption modes exist:

    * random: no mask is supplied, so a fresh ``[1, 10, 10]`` mask is drawn
      from ``np.random`` for every observation, each cell being broken with
      probability ``broken_pixels_ratio``;
    * fixed: the caller supplies ``broken_pixel_mask`` and the same mask is
      applied to every observation.

    NOTE(review): ``observation_space`` is declared as ``Box(0, 1)`` although
    corrupted observations contain the value 2.  The bound is left unchanged
    here because downstream code may rely on it; confirm whether it should be
    widened.
    """

    def __init__(self, env_id='breakout', sticky_action_prob=0.1, difficulty_ramping=True, random_seed=None,
                 broken_pixel_mask=None, broken_pixels_ratio=0):
        """Create the underlying game and configure pixel corruption.

        Args:
            env_id: Name of the MinAtar game, forwarded to ``Environment``.
            sticky_action_prob: Forwarded to ``Environment``.
            difficulty_ramping: Forwarded to ``Environment``.
            random_seed: Forwarded to ``Environment``.
            broken_pixel_mask: Optional fixed mask (non-zero = broken).  It
                must broadcast against ``[n_channels, 10, 10]``.  When given,
                ``broken_pixels_ratio`` is derived from it and the argument of
                the same name is ignored.
            broken_pixels_ratio: Per-cell probability of being broken when no
                fixed mask is supplied; 0 disables corruption.
        """
        self.sticky_action_prob = sticky_action_prob
        self.difficulty_ramping = difficulty_ramping
        self.env_id = env_id
        self.game = Environment(env_id, sticky_action_prob=sticky_action_prob,
                                difficulty_ramping=difficulty_ramping, random_seed=random_seed)
        self.observation_space = Box(0, 1, shape=[self.game.n_channels, 10, 10])
        # The oracle (uncorrupted) view shares the very same space object.
        self.oracle_observation_space = self.observation_space
        self.action_space = Discrete(self.game.num_actions())

        if broken_pixel_mask is None:
            self.broken_pixels_ratio = broken_pixels_ratio
            # In random mode this initial mask is only exposed as an attribute;
            # get_obs() draws a new one for every observation.  The draw is
            # kept so the global NumPy random stream is consumed as before.
            self.broken_pixel_mask = self._sample_broken_pixel_mask()
            self.random_broken_pixels = True
        else:
            self.broken_pixel_mask = broken_pixel_mask
            # 100 = number of cells in the 10x10 grid.
            self.broken_pixels_ratio = np.sum(broken_pixel_mask) / 100
            self.random_broken_pixels = False

    def _sample_broken_pixel_mask(self):
        """Draw a random 0/1 mask of shape [1, 10, 10] from the global NumPy RNG."""
        return np.random.choice(2, size=[1, 10, 10],
                                p=[1 - self.broken_pixels_ratio, self.broken_pixels_ratio])

    def _state_as_observation(self):
        """Return the current game state as float32 with the last axis first."""
        # Two swaps move the last axis to the front: same as transpose(2, 0, 1).
        return self.game.state().astype(np.float32).swapaxes(0, 2).swapaxes(2, 1)

    def seed(self, seeding=None):
        """Replace the underlying game with a new one built from ``seeding``.

        The new game is constructed with the stored ``env_id``,
        ``sticky_action_prob`` and ``difficulty_ramping``; it is not reset
        here.  The global NumPy RNG used for random masks is not touched.
        """
        self.game = Environment(self.env_id, sticky_action_prob=self.sticky_action_prob,
                                difficulty_ramping=self.difficulty_ramping, random_seed=seeding)

    def reset(self):
        """Reset the game and return the first (possibly corrupted) observation."""
        self.game.reset()
        # Same corruption path as every later step; no duplicated logic.
        return self.get_obs()

    def get_oracle_obs(self):
        """Return the uncorrupted observation of the current game state."""
        return self._state_as_observation()

    def get_full_obs(self):
        """Return the corrupted and oracle observations stacked along axis 0.

        In random mode this draws a new mask, exactly like ``get_obs()``.
        """
        return np.concatenate([self.get_obs(), self.get_oracle_obs()], axis=0)

    def get_obs(self):
        """Return the current observation with broken pixels applied.

        Returns:
            A float32 array.  If ``broken_pixels_ratio`` is falsy the state is
            returned untouched; otherwise masked cells are set to 2.
        """
        obs = self._state_as_observation()
        if not self.broken_pixels_ratio:
            return obs

        if self.random_broken_pixels:
            mask = self._sample_broken_pixel_mask()
        else:
            mask = self.broken_pixel_mask
        # Cast to the observation dtype: an integer mask would promote the
        # result to float64, and a nested-list mask would make `2 * mask`
        # repeat the list instead of scaling it.
        mask = np.asarray(mask, dtype=obs.dtype)
        # Broken cells become 2; all other cells keep their original value.
        return obs - obs * mask + 2 * mask

    def step(self, action):
        """Apply ``action`` to the game.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` comes from
            ``get_obs()`` and ``info`` is always an empty dict.
        """
        reward, done = self.game.act(action)
        obs = self.get_obs()
        return obs, reward, done, {}

    def render(self, mode='human'):
        """Rendering is not implemented; this is a no-op."""
        pass

    def close(self):
        """Nothing to release; this is a no-op."""
        pass