diff --git a/README.md b/README.md
index c86c40529..9329551b0 100644
--- a/README.md
+++ b/README.md
@@ -129,7 +129,7 @@ We select some of famous reinforcement learning platforms: 2 GitHub repos with m

 All of the platforms use 5 different seeds for testing. We erase those trials which failed for training. The reward threshold is 195.0 in CartPole and -250.0 in Pendulum over consecutive 100 episodes' mean returns (except for PyTorch-DRL).

-We will add results of Atari Pong / Mujoco these days.
+The Atari/Mujoco benchmark results are under the [examples/atari/](examples/atari/) and [examples/mujoco/](examples/mujoco/) folders.

 ### Reproducible

diff --git a/docs/tutorials/trick.rst b/docs/tutorials/trick.rst
index 36a55f3e7..5a73ff9e6 100644
--- a/docs/tutorials/trick.rst
+++ b/docs/tutorials/trick.rst
@@ -73,6 +73,12 @@ Tianshou has many short-but-efficient lines of code. For example, when we want t

 .. Jiayi: I write each line of code after quite a lot of time of consideration. Details make a difference.

+Atari/Mujoco Task Specific
+--------------------------
+
+Please refer to the `Atari examples page <https://github.com/thu-ml/tianshou/tree/master/examples/atari>`_ and the `Mujoco examples page <https://github.com/thu-ml/tianshou/tree/master/examples/mujoco>`_.
+
+
 Finally
 -------
diff --git a/examples/atari/README.md b/examples/atari/README.md
new file mode 100644
index 000000000..40c025c4f
--- /dev/null
+++ b/examples/atari/README.md
@@ -0,0 +1,25 @@
+# Atari General
+
+The sampling speed is \~3000 env steps per second (\~12000 Atari frames per second, since we use frame_stack=4) in normal mode (a CNN policy plus a collector that also stores data into the replay buffer). The main bottleneck is training the convolutional neural network.
+
+The Atari env seed cannot be fixed (see the discussion [here](https://github.com/openai/gym/issues/1478)), but this is not a big issue, since runs on Atari tend to produce similar results anyway.
+
+The env wrapper is crucial: without wrappers, the agent cannot perform well on Atari games. Many existing RL codebases use the [OpenAI wrapper](https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py), but it is not the original DeepMind version ([related issue](https://github.com/openai/baselines/issues/240)). Dopamine uses a different [wrapper](https://github.com/google/dopamine/blob/master/dopamine/discrete_domains/atari_lib.py), but unfortunately it does not work well in our codebase.
+
+# DQN (single run)
+
+One epoch here is equal to 100,000 env steps, so 100 epochs correspond to 10M env steps.
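+
+The runs below use the DeepMind-style wrapper chain from [atari_wrapper.py](atari_wrapper.py) in this folder (no-op reset, frame skipping, episodic life, fire reset, 84x84 grayscale, reward clipping, frame stacking). A rough usage sketch, mirroring `make_atari_env` in `atari_dqn.py` with its default arguments:
+
+```python
+from atari_wrapper import wrap_deepmind
+
+env = wrap_deepmind("PongNoFrameskip-v4", frame_stack=4)
+print(env.observation_space.shape)  # (4, 84, 84): stacked, resized grayscale frames
+```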
+
+| task                        | best reward | reward curve                          | parameters                                                                 | time cost           |
+| --------------------------- | ----------- | ------------------------------------- | -------------------------------------------------------------------------- | ------------------- |
+| PongNoFrameskip-v4          | 20          | ![](results/dqn/Pong_rew.png)         | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch_size 64`         | ~30 min (~15 epoch) |
+| BreakoutNoFrameskip-v4      | 316         | ![](results/dqn/Breakout_rew.png)     | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test_num 100`      | 3~4h (100 epoch)    |
+| EnduroNoFrameskip-v4        | 670         | ![](results/dqn/Enduro_rew.png)       | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4" --test_num 100`        | 3~4h (100 epoch)    |
+| QbertNoFrameskip-v4         | 7307        | ![](results/dqn/Qbert_rew.png)        | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test_num 100`         | 3~4h (100 epoch)    |
+| MsPacmanNoFrameskip-v4      | 2107        | ![](results/dqn/MsPacman_rew.png)     | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test_num 100`      | 3~4h (100 epoch)    |
+| SeaquestNoFrameskip-v4      | 2088        | ![](results/dqn/Seaquest_rew.png)     | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test_num 100`      | 3~4h (100 epoch)    |
+| SpaceInvadersNoFrameskip-v4 | 812.2       | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch)    |
+
+Note: `eps_train_final` and `eps_test` in the original DQN paper are 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that a smaller eps helps improve performance. Also, a larger batch size (say 64 instead of 32) helps faster convergence but slows down the training speed.
+
+We haven't tuned these results to the best, so have fun playing with these hyperparameters!
diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py
new file mode 100644
index 000000000..0ddbf07e7
--- /dev/null
+++ b/examples/atari/atari_dqn.py
@@ -0,0 +1,147 @@
+import os
+import torch
+import pprint
+import argparse
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.policy import DQNPolicy
+from tianshou.env import SubprocVectorEnv
+from tianshou.utils.net.discrete import DQN
+from tianshou.trainer import offpolicy_trainer
+from tianshou.data import Collector, ReplayBuffer
+
+from atari_wrapper import wrap_deepmind
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
+    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--eps_test', type=float, default=0.005)
+    parser.add_argument('--eps_train', type=float, default=1.)
+    parser.add_argument('--eps_train_final', type=float, default=0.05)
+    parser.add_argument('--buffer-size', type=int, default=100000)
+    parser.add_argument('--lr', type=float, default=0.0001)
+    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--n_step', type=int, default=3)
+    parser.add_argument('--target_update_freq', type=int, default=500)
+    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--step_per_epoch', type=int, default=10000)
+    parser.add_argument('--collect_per_step', type=int, default=10)
+    parser.add_argument('--batch_size', type=int, default=32)
+    parser.add_argument('--training_num', type=int, default=16)
+    parser.add_argument('--test_num', type=int, default=10)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    parser.add_argument('--frames_stack', type=int, default=4)
+    parser.add_argument('--resume_path', type=str, default=None)
+    parser.add_argument('--watch', default=False, action='store_true',
+                        help='watch the play of pre-trained policy only')
+    return parser.parse_args()
+
+
+def make_atari_env(args):
+    return wrap_deepmind(args.task, frame_stack=args.frames_stack)
+
+
+def make_atari_env_watch(args):
+    return wrap_deepmind(args.task, frame_stack=args.frames_stack,
+                         episode_life=False, clip_rewards=False)
+
+
+def test_dqn(args=get_args()):
+    env = make_atari_env(args)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.env.action_space.shape or env.env.action_space.n
+    # should be N_FRAMES x H x W
+    print("Observations shape: ", args.state_shape)
+    print("Actions shape: ", args.action_shape)
+    # make environments
+    train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
+                                   for _ in range(args.training_num)])
+    test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
+                                  for _ in range(args.test_num)])
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # define model
+    net = DQN(*args.state_shape,
+              args.action_shape, args.device).to(args.device)
+    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
+    # define policy
+    policy = DQNPolicy(net, optim, args.gamma, args.n_step,
+                       target_update_freq=args.target_update_freq)
+    # load a previous policy
+    if args.resume_path:
+        policy.load_state_dict(torch.load(args.resume_path))
+        print("Loaded agent from: ", args.resume_path)
+    # replay buffer: `save_only_last_obs` and `stack_num` can be removed
+    # together when you have enough RAM
+    buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
+                          save_only_last_obs=True,
+                          stack_num=args.frames_stack)
+    # collector
+    train_collector = Collector(policy, train_envs, buffer)
+    test_collector = Collector(policy, test_envs)
+    # log
+    log_path = os.path.join(args.logdir, args.task, 'dqn')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+    def stop_fn(x):
+        if env.spec.reward_threshold:
+            return x >= env.spec.reward_threshold
+        elif 'Pong' in args.task:
+            return x >= 20
+        else:
+            return False
+
+    def train_fn(x):
+        # nature DQN setting, linear decay in the first 1M steps
+        now = x * args.collect_per_step * args.step_per_epoch
+        if now <= 1e6:
+            eps = args.eps_train - now / 1e6 * \
+                (args.eps_train - args.eps_train_final)
+            policy.set_eps(eps)
+        else:
+            policy.set_eps(args.eps_train_final)
+        print("set eps =", policy.eps)
+
+    def test_fn(x):
+        policy.set_eps(args.eps_test)
+
+    # watch agent's performance
+    def watch():
+        print("Testing agent ...")
+        policy.eval()
+        policy.set_eps(args.eps_test)
+        test_envs.seed(args.seed)
+        test_collector.reset()
+        result = test_collector.collect(n_episode=[1] * args.test_num,
+                                        render=args.render)
+        pprint.pprint(result)
+
+    if args.watch:
+        watch()
+        exit(0)
+
+    # test train_collector and start filling replay buffer
+    train_collector.collect(n_step=args.batch_size * 4)
+    # trainer
+    result = offpolicy_trainer(
+        policy, train_collector, test_collector, args.epoch,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.batch_size, train_fn=train_fn, test_fn=test_fn,
+        stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False)
+
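+    # print the final training summary and watch the trained agent's
+    # performance on the test envs (built without episodic life / reward
+    # clipping, so the reported scores are raw game scores)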
+    pprint.pprint(result)
+    watch()
+
+
+if __name__ == '__main__':
+    test_dqn(get_args())
diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py
new file mode 100644
index 000000000..53a662613
--- /dev/null
+++ b/examples/atari/atari_wrapper.py
@@ -0,0 +1,237 @@
+# Borrowed heavily from OpenAI Baselines:
+# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
+
+import cv2
+import gym
+import numpy as np
+from collections import deque
+
+
+class NoopResetEnv(gym.Wrapper):
+    """Sample initial states by taking a random number of no-ops on reset.
+    No-op is assumed to be action 0.
+
+    :param gym.Env env: the environment to wrap.
+    :param int noop_max: the maximum number of no-ops to run.
+    """
+
+    def __init__(self, env, noop_max=30):
+        super().__init__(env)
+        self.noop_max = noop_max
+        self.noop_action = 0
+        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
+
+    def reset(self):
+        self.env.reset()
+        noops = np.random.randint(1, self.noop_max + 1)
+        for _ in range(noops):
+            obs, _, done, _ = self.env.step(self.noop_action)
+            if done:
+                obs = self.env.reset()
+        return obs
+
+
+class MaxAndSkipEnv(gym.Wrapper):
+    """Return only every `skip`-th frame (frameskipping), max-pooling over the
+    most recent raw observations.
+
+    :param gym.Env env: the environment to wrap.
+    :param int skip: the number of frames to skip (the action is repeated
+        for this many frames).
+    """
+
+    def __init__(self, env, skip=4):
+        super().__init__(env)
+        self._skip = skip
+
+    def step(self, action):
+        """Step the environment with the given action. Repeat the action,
+        sum the rewards, and max over the last two observations.
+        """
+        obs_list, total_reward, done = [], 0., False
+        for i in range(self._skip):
+            obs, reward, done, info = self.env.step(action)
+            obs_list.append(obs)
+            total_reward += reward
+            if done:
+                break
+        max_frame = np.max(obs_list[-2:], axis=0)
+        return max_frame, total_reward, done, info
+
+
+class EpisodicLifeEnv(gym.Wrapper):
+    """Make end-of-life == end-of-episode, but only reset on true game over.
+    It helps the value estimation.
+
+    :param gym.Env env: the environment to wrap.
+    """
+
+    def __init__(self, env):
+        super().__init__(env)
+        self.lives = 0
+        self.was_real_done = True
+
+    def step(self, action):
+        obs, reward, done, info = self.env.step(action)
+        self.was_real_done = done
+        # check current lives, make loss of life terminal, then update lives
+        # to handle bonus lives
+        lives = self.env.unwrapped.ale.lives()
+        if 0 < lives < self.lives:
+            # for Qbert sometimes we stay in lives == 0 condition for a few
+            # frames, so it's important to keep lives > 0, so that we only
+            # reset once the environment is actually done.
+            done = True
+        self.lives = lives
+        return obs, reward, done, info
+
+    def reset(self):
+        """Calls the Gym environment reset, only when lives are exhausted.
+        This way all states are still reachable even though lives are
+        episodic, and the learner need not know about any of this
+        behind-the-scenes.
+        """
+        if self.was_real_done:
+            obs = self.env.reset()
+        else:
+            # no-op step to advance from terminal/lost life state
+            obs = self.env.step(0)[0]
+        self.lives = self.env.unwrapped.ale.lives()
+        return obs
+
+
+class FireResetEnv(gym.Wrapper):
+    """Take the fire action on reset, for environments that are fixed until
+    firing. Related discussion: https://github.com/openai/baselines/issues/240
+
+    :param gym.Env env: the environment to wrap.
+ """ + + def __init__(self, env): + super().__init__(env) + assert env.unwrapped.get_action_meanings()[1] == 'FIRE' + assert len(env.unwrapped.get_action_meanings()) >= 3 + + def reset(self): + self.env.reset() + return self.env.step(1)[0] + + +class WarpFrame(gym.ObservationWrapper): + """Warp frames to 84x84 as done in the Nature paper and later work. + + :param gym.Env env: the environment to wrap. + """ + + def __init__(self, env): + super().__init__(env) + self.size = 84 + self.observation_space = gym.spaces.Box( + low=np.min(env.observation_space.low), + high=np.max(env.observation_space.high), + shape=(self.size, self.size), dtype=env.observation_space.dtype) + + def observation(self, frame): + """returns the current observation from a frame""" + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) + return cv2.resize(frame, (self.size, self.size), + interpolation=cv2.INTER_AREA) + + +class ScaledFloatFrame(gym.ObservationWrapper): + """Normalize observations to 0~1. + + :param gym.Env env: the environment to wrap. + """ + + def __init__(self, env): + super().__init__(env) + low = np.min(env.observation_space.low) + high = np.max(env.observation_space.high) + self.bias = low + self.scale = high - low + self.observation_space = gym.spaces.Box( + low=0., high=1., shape=env.observation_space.shape, + dtype=np.float32) + + def observation(self, observation): + return (observation - self.bias) / self.scale + + +class ClipRewardEnv(gym.RewardWrapper): + """clips the reward to {+1, 0, -1} by its sign. + + :param gym.Env env: the environment to wrap. + """ + + def __init__(self, env): + super().__init__(env) + self.reward_range = (-1, 1) + + def reward(self, reward): + """Bin reward to {+1, 0, -1} by its sign. Note: np.sign(0) == 0.""" + return np.sign(reward) + + +class FrameStack(gym.Wrapper): + """Stack n_frames last frames. + + :param gym.Env env: the environment to wrap. + :param int n_frames: the number of frames to stack. + """ + + def __init__(self, env, n_frames): + super().__init__(env) + self.n_frames = n_frames + self.frames = deque([], maxlen=n_frames) + shape = (n_frames,) + env.observation_space.shape + self.observation_space = gym.spaces.Box( + low=np.min(env.observation_space.low), + high=np.max(env.observation_space.high), + shape=shape, dtype=env.observation_space.dtype) + + def reset(self): + obs = self.env.reset() + for _ in range(self.n_frames): + self.frames.append(obs) + return self._get_ob() + + def step(self, action): + obs, reward, done, info = self.env.step(action) + self.frames.append(obs) + return self._get_ob(), reward, done, info + + def _get_ob(self): + # the original wrapper use `LazyFrames` but since we use np buffer, + # it has no effect + return np.stack(self.frames, axis=0) + + +def wrap_deepmind(env_id, episode_life=True, clip_rewards=True, + frame_stack=4, scale=False, warp_frame=True): + """Configure environment for DeepMind-style Atari. The observation is + channel-first: (c, h, w) instead of (h, w, c). + + :param str env_id: the atari environment id. + :param bool episode_life: wrap the episode life wrapper. + :param bool clip_rewards: wrap the reward clipping wrapper. + :param int frame_stack: wrap the frame stacking wrapper. + :param bool scale: wrap the scaling observation wrapper. + :param bool warp_frame: wrap the grayscale + resize observation wrapper. + :return: the wrapped atari environment. 
+ """ + assert 'NoFrameskip' in env_id + env = gym.make(env_id) + env = NoopResetEnv(env, noop_max=30) + env = MaxAndSkipEnv(env, skip=4) + if episode_life: + env = EpisodicLifeEnv(env) + if 'FIRE' in env.unwrapped.get_action_meanings(): + env = FireResetEnv(env) + if warp_frame: + env = WarpFrame(env) + if scale: + env = ScaledFloatFrame(env) + if clip_rewards: + env = ClipRewardEnv(env) + if frame_stack: + env = FrameStack(env, frame_stack) + return env diff --git a/examples/atari/results/dqn/Breakout_rew.png b/examples/atari/results/dqn/Breakout_rew.png new file mode 100644 index 000000000..2deed236a Binary files /dev/null and b/examples/atari/results/dqn/Breakout_rew.png differ diff --git a/examples/atari/results/dqn/Enduro_rew.png b/examples/atari/results/dqn/Enduro_rew.png new file mode 100644 index 000000000..27e0c6f39 Binary files /dev/null and b/examples/atari/results/dqn/Enduro_rew.png differ diff --git a/examples/atari/results/dqn/MsPacman_rew.png b/examples/atari/results/dqn/MsPacman_rew.png new file mode 100644 index 000000000..6a3a88ab4 Binary files /dev/null and b/examples/atari/results/dqn/MsPacman_rew.png differ diff --git a/examples/atari/results/dqn/Pong_rew.png b/examples/atari/results/dqn/Pong_rew.png new file mode 100644 index 000000000..75b289873 Binary files /dev/null and b/examples/atari/results/dqn/Pong_rew.png differ diff --git a/examples/atari/results/dqn/Qbert_rew.png b/examples/atari/results/dqn/Qbert_rew.png new file mode 100644 index 000000000..8f674881a Binary files /dev/null and b/examples/atari/results/dqn/Qbert_rew.png differ diff --git a/examples/atari/results/dqn/Seaquest_rew.png b/examples/atari/results/dqn/Seaquest_rew.png new file mode 100644 index 000000000..5ed082103 Binary files /dev/null and b/examples/atari/results/dqn/Seaquest_rew.png differ diff --git a/examples/atari/results/dqn/SpaceInvader_rew.png b/examples/atari/results/dqn/SpaceInvader_rew.png new file mode 100644 index 000000000..57c51fe9a Binary files /dev/null and b/examples/atari/results/dqn/SpaceInvader_rew.png differ diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 16bf5c34f..9fcccd904 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -91,11 +91,13 @@ def test_stack(size=5, bufsize=9, stack_num=4): env = MyTestEnv(size) buf = ReplayBuffer(bufsize, stack_num=stack_num) buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) + buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True) obs = env.reset(1) for i in range(16): obs_next, rew, done, info = env.step(1) buf.add(obs, 1, rew, done, None, info) buf2.add(obs, 1, rew, done, None, info) + buf3.add([None, None, obs], 1, rew, done, [None, obs], info) obs = obs_next if done: obs = env.reset(1) @@ -104,6 +106,8 @@ def test_stack(size=5, bufsize=9, stack_num=4): [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]]) + assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs')) + assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs_next')) _, indice = buf2.sample(0) assert indice.tolist() == [2, 6] _, indice = buf2.sample(1) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 7d77b3058..3c5658f8e 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -115,6 +115,9 @@ class ReplayBuffer: than or equal to 1, defaults to 1 (no stacking). :param bool ignore_obs_next: whether to store obs_next, defaults to ``False``. 
+    :param bool save_only_last_obs: only save the last obs/obs_next when it
+        has a shape of (timestep, ...) because of temporal stacking, defaults
+        to ``False``.
     :param bool sample_avail: the parameter indicating sampling only available
         index when using frame-stack sampling method, defaults to ``False``.
         This feature is not supported in Prioritized Replay Buffer currently.
@@ -122,6 +125,7 @@ def __init__(self, size: int, stack_num: int = 1,
                  ignore_obs_next: bool = False,
+                 save_only_last_obs: bool = False,
                  sample_avail: bool = False) -> None:
         super().__init__()
         self._maxsize = size
@@ -131,6 +135,7 @@ def __init__(self, size: int, stack_num: int = 1,
         self._avail = sample_avail and stack_num > 1
         self._avail_index = []
         self._save_s_ = not ignore_obs_next
+        self._last_obs = save_only_last_obs
         self._index = 0
         self._size = 0
         self._meta = Batch()
@@ -210,6 +215,8 @@ def add(self,
         """Add a batch of data into replay buffer."""
         assert isinstance(info, (dict, Batch)), \
             'You should return a dict in the last argument of env.step().'
+        if self._last_obs:
+            obs = obs[-1]
         self._add_to_buffer('obs', obs)
         self._add_to_buffer('act', act)
         self._add_to_buffer('rew', rew)
@@ -217,6 +224,8 @@ def add(self,
         if self._save_s_:
             if obs_next is None:
                 obs_next = Batch()
+            elif self._last_obs:
+                obs_next = obs_next[-1]
             self._add_to_buffer('obs_next', obs_next)
         self._add_to_buffer('info', info)
         self._add_to_buffer('policy', policy)
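For reference, a minimal sketch of how the new `save_only_last_obs` flag combines with `stack_num`, based on the `ReplayBuffer` changes above and its usage in `atari_dqn.py` (the stacked observation passed to `add` has shape `(stack, ...)`; only its last frame is stored, and the stack is rebuilt when sampling with `stack_num`):

```python
import numpy as np
from tianshou.data import ReplayBuffer

# store only obs[-1] of each frame-stacked observation to save RAM;
# frame stacks are reconstructed from stored frames via stack_num
buf = ReplayBuffer(100000, stack_num=4,
                   ignore_obs_next=True, save_only_last_obs=True)
obs = np.zeros((4, 84, 84))  # what the frame-stacked Atari env returns
buf.add(obs, act=0, rew=0.0, done=False)
```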