-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Example: Simple RL example using DQN/Lightning (#1232)
* Example: Simple RL example using DQN/Lightning * DQN RL Agent using Lightning * Uses Iterable Dataset for Replay Buffer * Buffer is populated by agent as training is carried out, updating the dataset * Applied autopep8 fixes * * Updated line length from 120 to 110 * Update pl_examples/domain_templates/dqn.py simplify get_device method Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pl_examples/domain_templates/dqn.py Re-ordered imports Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * CI: split tests-examples (#990) * CI: split tests-examples * tests without template * comment depends * CircleCI typo * add doctest * update test req. * CI tests * setup macOS * longer train * lover pred acc * fix model * rename default model * lower tests acc * typo * imports * fix test optimizer * update calls * fix Win * lower Drone image * fix call * pytorch image * fix test * add dev image * add dev image * update image * drone volume * lint * update test notes * rename tests/models >> tests/base * group models * conftest * optim imports * typos * fix import * fix tests * install AMP * tests * fix import * Clean up * added module docstring * renamed variables to be more descriptive * Added missing docstrings and type annotations * Added gym to example requirements * Added note to changelog * updated example image * update types * rename script * Update CHANGELOG.md Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * another rename * Disable validation when val_percent_check=0 (#1251) * fix disable validation * add test * update changelog * update docs for val_percent_check * make "fast training" docs consistent * calling self.forward() -> self() (#1211) * self.forward() -> self() * update changelog Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Fix requirements-extra.txt Trains package to release version (#1229) * Fix requirement-extra use released Trains package * Update README.md add Trains and links to the external Visualization section Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Remove unnecessary parameters to super() in documentation and source code (#1240) Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * update deprecation warning (#1258) * update docs for progress bat values (#1253) * lower timeouts for inactive issues (#1250) * update contrib list (#1241) Co-authored-by: William Falcon <waf2107@columbia.edu> * Fix outdated docs (#1227) * Fix typo (#1224) * drop unused Tox (#1242) * system info (#1234) * system info * update big info * test script * update config * rename script * import path * Changed smoothing in tqdm to decrease variability of time remaining between training / eval (#1194) * Example: Simple RL example using DQN/Lightning * DQN RL Agent using Lightning * Uses Iterable Dataset for Replay Buffer * Buffer is populated by agent as training is carried out, updating the dataset * Applied autopep8 fixes * * Updated line length from 120 to 110 * Update pl_examples/domain_templates/dqn.py simplify get_device method Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pl_examples/domain_templates/dqn.py Re-ordered imports Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Clean up * added module docstring * renamed variables to be more descriptive * Added missing docstrings and type annotations * Added gym to example requirements * Added note to changelog * update types * rename script * Update CHANGELOG.md Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * another rename Co-authored-by: Donal Byrne <Donal.Byrne@xperi.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: William Falcon <waf2107@columbia.edu> Co-authored-by: Adrian Wälchli <adrian.waelchli@students.unibe.ch> Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> Co-authored-by: Martin.B <51887611+bmartinn@users.noreply.github.com> Co-authored-by: Tyler Yep <tyep@stanford.edu> Co-authored-by: Shunta Komatsu <59395084+skmatz@users.noreply.github.com> Co-authored-by: Jack Pertschuk <jackpertschuk@gmail.com>
- Loading branch information
1 parent
4e0d0ab
commit dab3b96
Showing
3 changed files
with
363 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,360 @@ | ||
""" | ||
# Deep Reinforcement Learning: Deep Q-network (DQN) | ||
this example is based off https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On- | ||
Second-Edition/blob/master/Chapter06/02_dqn_pong.py | ||
The template illustrates using Lightning for Reinforcement Learning. The example builds a basic DQN using the | ||
classic CartPole environment. | ||
to run the template just run: | ||
python dqn.py | ||
After ~1500 steps, you will see the total_reward hitting the max score of 200. Open up tensor boards to | ||
see the metrics. | ||
tensorboard --logdir default | ||
""" | ||
|
||
import pytorch_lightning as pl | ||
|
||
from typing import Tuple, List | ||
|
||
import argparse | ||
from collections import OrderedDict, deque, namedtuple | ||
|
||
import gym | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
from torch.optim import Optimizer | ||
from torch.utils.data import DataLoader | ||
from torch.utils.data.dataset import IterableDataset | ||
|
||
|
||
class DQN(nn.Module): | ||
""" | ||
Simple MLP network | ||
Args: | ||
obs_size: observation/state size of the environment | ||
n_actions: number of discrete actions available in the environment | ||
hidden_size: size of hidden layers | ||
""" | ||
|
||
def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128): | ||
super(DQN, self).__init__() | ||
self.net = nn.Sequential( | ||
nn.Linear(obs_size, hidden_size), | ||
nn.ReLU(), | ||
nn.Linear(hidden_size, n_actions) | ||
) | ||
|
||
def forward(self, x): | ||
return self.net(x.float()) | ||
|
||
|
||
# Named tuple for storing experience steps gathered in training | ||
Experience = namedtuple( | ||
'Experience', field_names=['state', 'action', 'reward', | ||
'done', 'new_state']) | ||
|
||
|
||
class ReplayBuffer: | ||
""" | ||
Replay Buffer for storing past experiences allowing the agent to learn from them | ||
Args: | ||
capacity: size of the buffer | ||
""" | ||
|
||
def __init__(self, capacity: int) -> None: | ||
self.buffer = deque(maxlen=capacity) | ||
|
||
def __len__(self) -> None: | ||
return len(self.buffer) | ||
|
||
def append(self, experience: Experience) -> None: | ||
""" | ||
Add experience to the buffer | ||
Args: | ||
experience: tuple (state, action, reward, done, new_state) | ||
""" | ||
self.buffer.append(experience) | ||
|
||
def sample(self, batch_size: int) -> Tuple: | ||
indices = np.random.choice(len(self.buffer), batch_size, replace=False) | ||
states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices]) | ||
|
||
return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32), | ||
np.array(dones, dtype=np.bool), np.array(next_states)) | ||
|
||
|
||
class RLDataset(IterableDataset): | ||
""" | ||
Iterable Dataset containing the ExperienceBuffer | ||
which will be updated with new experiences during training | ||
Args: | ||
buffer: replay buffer | ||
sample_size: number of experiences to sample at a time | ||
""" | ||
|
||
def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None: | ||
self.buffer = buffer | ||
self.sample_size = sample_size | ||
|
||
def __iter__(self) -> Tuple: | ||
states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size) | ||
for i in range(len(dones)): | ||
yield states[i], actions[i], rewards[i], dones[i], new_states[i] | ||
|
||
|
||
class Agent: | ||
""" | ||
Base Agent class handeling the interaction with the environment | ||
Args: | ||
env: training environment | ||
replay_buffer: replay buffer storing experiences | ||
""" | ||
|
||
def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None: | ||
self.env = env | ||
self.replay_buffer = replay_buffer | ||
self.reset() | ||
self.state = self.env.reset() | ||
|
||
def reset(self) -> None: | ||
""" Resents the environment and updates the state""" | ||
self.state = self.env.reset() | ||
|
||
def get_action(self, net: nn.Module, epsilon: float, device: str) -> int: | ||
""" | ||
Using the given network, decide what action to carry out | ||
using an epsilon-greedy policy | ||
Args: | ||
net: DQN network | ||
epsilon: value to determine likelihood of taking a random action | ||
device: current device | ||
Returns: | ||
action | ||
""" | ||
if np.random.random() < epsilon: | ||
action = self.env.action_space.sample() | ||
else: | ||
state = torch.tensor([self.state]) | ||
|
||
if device not in ['cpu']: | ||
state = state.cuda(device) | ||
|
||
q_values = net(state) | ||
_, action = torch.max(q_values, dim=1) | ||
action = int(action.item()) | ||
|
||
return action | ||
|
||
@torch.no_grad() | ||
def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') -> Tuple[float, bool]: | ||
""" | ||
Carries out a single interaction step between the agent and the environment | ||
Args: | ||
net: DQN network | ||
epsilon: value to determine likelihood of taking a random action | ||
device: current device | ||
Returns: | ||
reward, done | ||
""" | ||
|
||
action = self.get_action(net, epsilon, device) | ||
|
||
# do step in the environment | ||
new_state, reward, done, _ = self.env.step(action) | ||
|
||
exp = Experience(self.state, action, reward, done, new_state) | ||
|
||
self.replay_buffer.append(exp) | ||
|
||
self.state = new_state | ||
if done: | ||
self.reset() | ||
return reward, done | ||
|
||
|
||
class DQNLightning(pl.LightningModule): | ||
""" Basic DQN Model """ | ||
|
||
def __init__(self, hparams: argparse.Namespace) -> None: | ||
super().__init__() | ||
self.hparams = hparams | ||
|
||
self.env = gym.make(self.hparams.env) | ||
obs_size = self.env.observation_space.shape[0] | ||
n_actions = self.env.action_space.n | ||
|
||
self.net = DQN(obs_size, n_actions) | ||
self.target_net = DQN(obs_size, n_actions) | ||
|
||
self.buffer = ReplayBuffer(self.hparams.replay_size) | ||
self.agent = Agent(self.env, self.buffer) | ||
self.total_reward = 0 | ||
self.episode_reward = 0 | ||
self.populate(self.hparams.warm_start_steps) | ||
|
||
def populate(self, steps: int = 1000) -> None: | ||
""" | ||
Carries out several random steps through the environment to initially fill | ||
up the replay buffer with experiences | ||
Args: | ||
steps: number of random steps to populate the buffer with | ||
""" | ||
for i in range(steps): | ||
self.agent.play_step(self.net, epsilon=1.0) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Passes in a state x through the network and gets the q_values of each action as an output | ||
Args: | ||
x: environment state | ||
Returns: | ||
q values | ||
""" | ||
output = self.net(x) | ||
return output | ||
|
||
def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: | ||
""" | ||
Calculates the mse loss using a mini batch from the replay buffer | ||
Args: | ||
batch: current mini batch of replay data | ||
Returns: | ||
loss | ||
""" | ||
states, actions, rewards, dones, next_states = batch | ||
|
||
state_action_values = self.net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1) | ||
|
||
with torch.no_grad(): | ||
next_state_values = self.target_net(next_states).max(1)[0] | ||
next_state_values[dones] = 0.0 | ||
next_state_values = next_state_values.detach() | ||
|
||
expected_state_action_values = next_state_values * self.hparams.gamma + rewards | ||
|
||
return nn.MSELoss()(state_action_values, expected_state_action_values) | ||
|
||
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict: | ||
""" | ||
Carries out a single step through the environment to update the replay buffer. | ||
Then calculates loss based on the minibatch recieved | ||
Args: | ||
batch: current mini batch of replay data | ||
nb_batch: batch number | ||
Returns: | ||
Training loss and log metrics | ||
""" | ||
device = self.get_device(batch) | ||
epsilon = max(self.hparams.eps_end, self.hparams.eps_start - | ||
self.global_step + 1 / self.hparams.eps_last_frame) | ||
|
||
# step through environment with agent | ||
reward, done = self.agent.play_step(self.net, epsilon, device) | ||
self.episode_reward += reward | ||
|
||
# calculates training loss | ||
loss = self.dqn_mse_loss(batch) | ||
|
||
if self.trainer.use_dp or self.trainer.use_ddp2: | ||
loss = loss.unsqueeze(0) | ||
|
||
if done: | ||
self.total_reward = self.episode_reward | ||
self.episode_reward = 0 | ||
|
||
# Soft update of target network | ||
if self.global_step % self.hparams.sync_rate == 0: | ||
self.target_net.load_state_dict(self.net.state_dict()) | ||
|
||
log = {'total_reward': torch.tensor(self.total_reward).to(device), | ||
'reward': torch.tensor(reward).to(device), | ||
'steps': torch.tensor(self.global_step).to(device)} | ||
|
||
return OrderedDict({'loss': loss, 'log': log, 'progress_bar': log}) | ||
|
||
def configure_optimizers(self) -> List[Optimizer]: | ||
""" Initialize Adam optimizer""" | ||
optimizer = optim.Adam(self.net.parameters(), lr=self.hparams.lr) | ||
return [optimizer] | ||
|
||
def __dataloader(self) -> DataLoader: | ||
"""Initialize the Replay Buffer dataset used for retrieving experiences""" | ||
dataset = RLDataset(self.buffer, self.hparams.episode_length) | ||
dataloader = DataLoader(dataset=dataset, | ||
batch_size=self.hparams.batch_size, | ||
sampler=None | ||
) | ||
return dataloader | ||
|
||
def train_dataloader(self) -> DataLoader: | ||
"""Get train loader""" | ||
return self.__dataloader() | ||
|
||
def get_device(self, batch) -> str: | ||
"""Retrieve device currently being used by minibatch""" | ||
return batch[0].device.index if self.on_gpu else 'cpu' | ||
|
||
|
||
def main(hparams) -> None: | ||
model = DQNLightning(hparams) | ||
|
||
trainer = pl.Trainer( | ||
gpus=1, | ||
distributed_backend='dp', | ||
early_stop_callback=False, | ||
val_check_interval=100 | ||
) | ||
|
||
trainer.fit(model) | ||
|
||
|
||
if __name__ == '__main__': | ||
torch.manual_seed(0) | ||
np.random.seed(0) | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--batch_size", type=int, default=16, help="size of the batches") | ||
parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") | ||
parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag") | ||
parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") | ||
parser.add_argument("--sync_rate", type=int, default=10, | ||
help="how many frames do we update the target network") | ||
parser.add_argument("--replay_size", type=int, default=1000, | ||
help="capacity of the replay buffer") | ||
parser.add_argument("--warm_start_size", type=int, default=1000, | ||
help="how many samples do we use to fill our buffer at the start of training") | ||
parser.add_argument("--eps_last_frame", type=int, default=1000, | ||
help="what frame should epsilon stop decaying") | ||
parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon") | ||
parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon") | ||
parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode") | ||
parser.add_argument("--max_episode_reward", type=int, default=200, | ||
help="max episode reward in the environment") | ||
parser.add_argument("--warm_start_steps", type=int, default=1000, | ||
help="max episode reward in the environment") | ||
|
||
args = parser.parse_args() | ||
|
||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
torchvision>=0.4.0 | ||
torchvision>=0.4.0 | ||
gym>=0.17.0 |
dab3b96
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we please next time correct the commit message, this generated plenty of false cross-references and all contributors to these issues are notified with this new commit...