-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adiciona DDPG #45
base: main
Are you sure you want to change the base?
Adiciona DDPG #45
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No geral tem que colocar as equações como imagens (coloquei já como sugerido) e colocar um gráfico do aumento de rewards no notebook, assim como nos outros do repo.
@@ -0,0 +1 @@ | |||
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"DDPG.ipynb","provenance":[],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"hjGWyD9Hyfg-"},"source":["# DDPG - Deep Deterministic Policy Gradient\n","\n","Deep Deterministic Policy Gradient (DDPG) é um algoritmo off-policy que aprende simultaneamente uma função Q-valor e uma política. Com a equação de Bellman, aprende a função Q e, com esta, aprende a política.\n","\n","A ideia de um algoritmo com gradiente da política surgiu primeiro com [Silver (sempre ele)](http://proceedings.mlr.press/v32/silver14.pdf), mas o algoritmo DDPG em si foi estabelecido na publicação [por Lilicrap](https://arxiv.org/abs/1509.02971)."]},{"cell_type":"markdown","metadata":{"id":"EXwpMV1OAnb_"},"source":["## O algoritmo de DDPG\n","\n","Este algoritmo é um Actor-Critic, mas, diferentemente dos que vimos antes, é off policy. Diferente do que estamos acostumados, este algoritmo só pode ser usado em espaços de ação contínuos. Podemos pensar então no DDPG como uma \"DQN para espaço de ação contínuo\".\n","\n","$$a^*(s) = arg max_a Q^*(s,a)$$\n","\n","A equação acima, da ação ótima para um dado estado, é familiar, mas se torna problemática para ações não discretas. É fácil contabilizar o máximo de uma série finita de ações, mas ações contínuas tornam essa operação problemática.\n","\n","Como sabemos que o espaço de ações é contínuo, podemos definir a política (ou função ator) $\\mu_\\theta(s) $, que vai nos ajudar a computar o max. Lembrando que, nesse algoritmo, a política é determinística.\n","\n","Dessa forma, temos então $max_a Q(s,a) \\approx Q(s, \\mu(s))$, facilitando a computação do max.\n","\n","### Q-Learning\n","\n","$$Q^*(s,a) = E_{s \\sim P} \\bigg[r(s,a) + \\gamma max_{a'} Q^*(s',a') \\bigg] $$\n","\n","A parte de Q learning, cuja equação para encontrar o valor ótimo de Q está acima, consiste em minimizar o erro quadrático médio de Bellman (MSBE), próximo do que já se era feito com a DQN.\n","\n","O [post da OpenAI](https://spinningup.openai.com/en/latest/algorithms/ddpg.html) oferece duas ferramentas para facilitar o treinamento da função Q, o uso de Replay Buffers e o uso de Target Networks. Uma vantagem de DDPG ser off-policy é permitir o uso de Replay Buffers.\n","\n","\n","### Policy Gradient\n","\n","O objetivo é achar a política determinística $\\mu_{\\theta}(s) $ que fornece a ação que maximiza $Q(s,a)$. Diferenciando a função Q em a, é possível então realizar gradiente ascendente para maximizar a política:\n","\n","$$\\nabla_\\theta J = \\mathbb{E}_{s_t} \\bigg[ \\nabla_a Q(s,a|\\theta^Q) |_{s = s_t, a = \\mu(s_t)} \\nabla_\\theta \\mu_\\theta(s|\\theta^\\mu)|_{s = s_t} \\bigg] $$\n","\n","\n","## Exploração\n","\n","Um problema para casos de ação contínua é a exploração. Por outro lado, por se tratar de um algoritmo off-policy, podemos tratar o problema da exploração independente do algoritmo de aprendizado. Dessa forma, podemos criar uma política de exploração $\\mu'$ adicionando ruído na nossa política/função ator:\n","\n","$$\\mu'(s) = \\mu_\\theta(s) + \\mathcal{N} $$\n","\n","No paper original, o ruído foi gerado por um processo [Ornstein-Uhlenbeck](https://en.wikipedia.org/wiki/Ornstein%E2%80%93Uhlenbeck_process). Porém, a openAI recomenda o uso de um ruído gaussiano não correlacionado e de média zero, por ser mais simples e funcionar tão bem quanto.\n","\n","\n","## Pseudocódigo\n","\n","![DDPG](https://i.postimg.cc/bwfJDNFb/DDPG-1.png)\n","\n","\n","## Referências\n","https://spinningup.openai.com/en/latest/algorithms/ddpg.html\n","\n","http://proceedings.mlr.press/v32/silver14.pdf"]},{"cell_type":"markdown","metadata":{"id":"btLijj8y7gLH"},"source":["# Código\n"]},{"cell_type":"markdown","metadata":{"id":"YkTEyq25ms2B"},"source":["## Ator"]},{"cell_type":"code","metadata":{"id":"JOdt9KwNmcFR","executionInfo":{"status":"ok","timestamp":1606592473854,"user_tz":180,"elapsed":924,"user":{"displayName":"Bernardo Coutinho","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjbcuim7oGIm-uXpKRCJxDYg0Nhguq2a4_xKQcpjw=s64","userId":"08343358744938767290"}}},"source":["import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","\n","class Actor(nn.Module):\n"," \"\"\"Rede do Ator.\"\"\"\n"," def __init__(self, observation_shape, action_shape, action_high, action_low):\n"," \"\"\"Inicializa a rede.\n"," \n"," Parâmetros\n"," ----------\n"," observation_shape: int\n"," Formato do estado do ambiente.\n"," \n"," action_shape: int\n"," Número de ações do ambiente.\n"," \"\"\"\n"," super(Actor, self).__init__()\n"," self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"," self.action_high = torch.from_numpy(action_high).to(self.device)\n"," self.action_low = torch.from_numpy(action_low).to(self.device)\n","\n"," self.linear1 = nn.Linear(observation_shape, 512)\n"," self.linear2 = nn.Linear(512, 256)\n"," self.linear3 = nn.Linear(256, action_shape)\n","\n"," def forward(self, state):\n"," \"\"\"\n"," Calcula a probabilidade de ação para o estado atual.\n"," \n"," Parâmetros\n"," ----------\n"," state: np.array\n"," Estado atual.\n"," \n"," Retorna\n"," -------\n"," action: torch.Tensor\n"," Ações.\n"," \"\"\"\n"," x = F.relu(self.linear1(state))\n"," x = F.relu(self.linear2(x))\n"," x = torch.tanh(self.linear3(x))\n"," \n"," action = x * (self.action_high - self.action_low) / 2.0 +\\\n"," (self.action_high + self.action_low) / 2.0\n","\n"," return action"],"execution_count":1,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"TDbUOyZDmu6o"},"source":["## Crítico"]},{"cell_type":"code","metadata":{"id":"L2Ne0at-msW4","executionInfo":{"status":"ok","timestamp":1606592473857,"user_tz":180,"elapsed":920,"user":{"displayName":"Bernardo Coutinho","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjbcuim7oGIm-uXpKRCJxDYg0Nhguq2a4_xKQcpjw=s64","userId":"08343358744938767290"}}},"source":["import torch.nn as nn\n","import torch.nn.functional as F\n","\n","class Critic(nn.Module):\n"," \"\"\"Rede do Crítico.\"\"\"\n"," def __init__(self, observation_shape, action_shape):\n"," \"\"\"Inicializa a rede.\n"," \n"," Parâmetros\n"," ----------\n"," observation_shape: int\n"," Formato do estado do ambiente.\n","\n"," action_shape: int\n"," Número de ações do ambiente.\n"," \"\"\"\n"," super(Critic, self).__init__()\n","\n"," self.linear1 = nn.Linear(observation_shape, 512)\n"," self.linear2 = nn.Linear(512 + action_shape, 512)\n"," self.linear3 = nn.Linear(512, 300)\n"," self.linear4 = nn.Linear(300, 1)\n","\n"," def forward(self, state, action):\n"," \"\"\"\n"," Calcula o valor do estado atual.\n"," \n"," Parâmetros\n"," ----------\n"," state: np.array\n"," Estado atual.\n","\n"," state: np.array\n"," Ação escolhida.\n"," \n"," Retorna\n"," -------\n"," q: float\n"," Valor da ação escolhida.\n"," \"\"\"\n"," x = F.relu(self.linear1(state))\n"," xa_cat = torch.cat([x, action], 1)\n"," xa = F.relu(self.linear2(xa_cat))\n"," xa = F.relu(self.linear3(xa))\n"," q = self.linear4(xa)\n","\n"," return q"],"execution_count":2,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1ZiSoaCSnX-3"},"source":["## Experience Replay"]},{"cell_type":"code","metadata":{"id":"oh5ztI22nXiX","executionInfo":{"status":"ok","timestamp":1606592473859,"user_tz":180,"elapsed":918,"user":{"displayName":"Bernardo Coutinho","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjbcuim7oGIm-uXpKRCJxDYg0Nhguq2a4_xKQcpjw=s64","userId":"08343358744938767290"}}},"source":["import numpy as np\n","\n","class ExperienceReplay:\n"," \"\"\"Experience Replay Buffer para A2C.\"\"\"\n"," def __init__(self, max_length, observation_space):\n"," \"\"\"Cria um Replay Buffer.\n","\n"," Parâmetros\n"," ----------\n"," max_length: int\n"," Tamanho máximo do Replay Buffer.\n"," observation_space: int\n"," Tamanho do espaço de observação.\n"," \"\"\"\n"," self.index, self.length, self.max_length = 0, 0, max_length\n","\n"," self.states = np.zeros((max_length, *observation_space), dtype=np.float32)\n"," self.actions = np.zeros((max_length), dtype=np.int32)\n"," self.rewards = np.zeros((max_length), dtype=np.float32)\n"," self.next_states = np.zeros((max_length, *observation_space), dtype=np.float32)\n"," self.dones = np.zeros((max_length), dtype=np.float32)\n","\n"," def update(self, state, action, reward, next_state, done):\n"," \"\"\"Adiciona uma experiência ao Replay Buffer.\n","\n"," Parâmetros\n"," ----------\n"," state: np.array\n"," Estado da transição.\n"," action: int\n"," Ação tomada.\n"," reward: float\n"," Recompensa recebida.\n"," state: np.array\n"," Estado seguinte.\n"," done: int\n"," Flag indicando se o episódio acabou.\n"," \"\"\"\n"," self.states[self.index] = state\n"," self.actions[self.index] = action\n"," self.rewards[self.index] = reward\n"," self.next_states[self.index] = next_state\n"," self.dones[self.index] = done\n","\n"," self.index = (self.index + 1) % self.max_length\n"," if self.length < self.max_length:\n"," self.length += 1\n","\n"," def sample(self, batch_size):\n"," \"\"\"Retorna um batch de experiências.\n"," \n"," Parâmetros\n"," ----------\n"," batch_size: int\n"," Tamanho do batch de experiências.\n","\n"," Retorna\n"," -------\n"," states: np.array\n"," Batch de estados.\n"," actions: np.array\n"," Batch de ações.\n"," rewards: np.array\n"," Batch de recompensas.\n"," next_states: np.array\n"," Batch de estados seguintes.\n"," dones: np.array\n"," Batch de flags indicando se o episódio acabou.\n"," \"\"\"\n"," # Escolhe índices aleatoriamente do Replay Buffer\n"," idxs = np.random.randint(0, self.length, size=batch_size)\n","\n"," return (self.states[idxs], self.actions[idxs], self.rewards[idxs], self.next_states[idxs], self.dones[idxs])"],"execution_count":3,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-u-dalX7m-Co"},"source":["## Agente"]},{"cell_type":"code","metadata":{"id":"nNBne9yyr4vC","executionInfo":{"status":"ok","timestamp":1606592474496,"user_tz":180,"elapsed":1551,"user":{"displayName":"Bernardo Coutinho","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjbcuim7oGIm-uXpKRCJxDYg0Nhguq2a4_xKQcpjw=s64","userId":"08343358744938767290"}}},"source":["import numpy as np\n","import torch\n","from torch import optim\n","\n","class DDPG:\n"," def __init__(self, observation_space, action_space, pi_lr=0.001, q_lr=0.001, gamma=0.99, tau=0.005, action_noise=0.1):\n"," self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","\n"," self.gamma = gamma\n"," self.tau = tau\n","\n"," self.memory = ExperienceReplay(10000, observation_space.shape)\n","\n"," self.action_noise = action_noise\n"," self.action_shape = action_space.shape\n"," self.action_high = action_space.high\n"," self.action_low = action_space.low\n","\n"," self.actor = Actor(observation_space.shape[0], action_space.shape[0], self.action_high, self.action_low).to(self.device)\n"," self.target_actor = Actor(observation_space.shape[0], action_space.shape[0], self.action_high, self.action_low).to(self.device)\n","\n"," self.critic = Critic(observation_space.shape[0], action_space.shape[0]).to(self.device)\n"," self.target_critic = Critic(observation_space.shape[0], action_space.shape[0]).to(self.device)\n","\n"," for target_param, param in zip(self.target_critic.parameters(), self.critic.parameters()):\n"," target_param.data.copy_(param.data)\n","\n"," for target_param, param in zip(self.target_actor.parameters(), self.actor.parameters()):\n"," target_param.data.copy_(param.data)\n","\n"," self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=pi_lr)\n"," self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=q_lr)\n","\n"," def act(self, state):\n"," state = torch.FloatTensor(state).unsqueeze(0).to(self.device)\n"," \n"," with torch.no_grad():\n"," action = self.actor.forward(state)\n"," \n"," action = action.squeeze(0).cpu().numpy()\n"," action += self.action_noise * np.random.randn(*self.action_shape)\n","\n"," return np.clip(action, self.action_low, self.action_high)\n","\n"," def remember(self, state, action, reward, next_state, done):\n"," self.memory.update(state, action, reward, next_state, done)\n","\n"," def train(self, batch_size=32, epochs=1):\n"," if batch_size > self.memory.length:\n"," return\n"," \n"," for epoch in range(epochs):\n"," (states, actions, rewards, next_states, dones) = self.memory.sample(batch_size)\n","\n"," states = torch.FloatTensor(states).to(self.device)\n"," actions = torch.FloatTensor(actions).unsqueeze(-1).to(self.device)\n"," rewards = torch.FloatTensor(rewards).unsqueeze(-1).to(self.device)\n"," next_states = torch.FloatTensor(next_states).to(self.device)\n"," dones = torch.FloatTensor(dones).unsqueeze(-1).to(self.device)\n","\n"," self._train_critic(states, actions, rewards, next_states, dones)\n"," self._train_actor(states, actions, rewards, next_states, dones)\n"," self.update_target()\n","\n"," def _train_critic(self, states, actions, rewards, next_states, dones):\n"," q = self.critic.forward(states, actions)\n","\n"," with torch.no_grad():\n"," a2 = self.target_actor.forward(next_states)\n"," q2 = self.target_critic.forward(next_states, a2)\n"," \n"," target = rewards + self.gamma * q2\n"," critic_loss = F.mse_loss(q, target)\n","\n"," self.critic_optimizer.zero_grad()\n"," critic_loss.backward() \n"," self.critic_optimizer.step()\n","\n"," def _train_actor(self, states, actions, rewards, next_states, dones):\n"," policy_loss = -self.critic.forward(states, self.actor.forward(states)).mean()\n"," \n"," self.actor_optimizer.zero_grad()\n"," policy_loss.backward()\n"," self.actor_optimizer.step()\n","\n"," def update_target(self):\n"," with torch.no_grad():\n"," for target_param, param in zip(self.target_actor.parameters(), self.actor.parameters()):\n"," target_param.data.mul_(1 - self.tau)\n"," torch.add(target_param.data, param.data, alpha=self.tau, out=target_param.data)\n","\n"," for target_param, param in zip(self.target_critic.parameters(), self.critic.parameters()):\n"," target_param.data.mul_(1 - self.tau)\n"," torch.add(target_param.data, param.data, alpha=self.tau, out=target_param.data)"],"execution_count":4,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iu7U6Vmzoydm"},"source":["## Treinando"]},{"cell_type":"code","metadata":{"id":"zI1aZ7DypOfN","executionInfo":{"status":"ok","timestamp":1606592474498,"user_tz":180,"elapsed":1551,"user":{"displayName":"Bernardo Coutinho","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjbcuim7oGIm-uXpKRCJxDYg0Nhguq2a4_xKQcpjw=s64","userId":"08343358744938767290"}}},"source":["import math\n","from collections import deque\n","\n","def train(agent, env, total_timesteps):\n"," total_reward = 0\n"," episode_returns = deque(maxlen=20)\n"," avg_returns = []\n","\n"," state = env.reset()\n"," timestep = 0\n"," episode = 0\n","\n"," while timestep < total_timesteps:\n"," action = agent.act(state)\n"," next_state, reward, done, _ = env.step(action)\n"," agent.remember(state, action, reward, next_state, done)\n"," loss = agent.train()\n"," timestep += 1\n","\n"," total_reward += reward\n","\n"," if done:\n"," episode_returns.append(total_reward)\n"," episode += 1\n"," next_state = env.reset()\n","\n"," if episode_returns:\n"," avg_returns.append(np.mean(episode_returns))\n","\n"," total_reward *= 1 - done\n"," state = next_state\n","\n"," ratio = math.ceil(100 * timestep / total_timesteps)\n","\n"," avg_return = avg_returns[-1] if avg_returns else np.nan\n"," \n"," print(f\"\\r[{ratio:3d}%] timestep = {timestep}/{total_timesteps}, episode = {episode:3d}, avg_return = {avg_return:10.4f}\", end=\"\")\n","\n"," return avg_returns"],"execution_count":5,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"cJIRJQ1NLVvV"},"source":["### Pendulum-v0\n","\n","![Pendulum-v0](https://mspries.github.io/img/jimmy-pendulum/pendulum.gif)"]},{"cell_type":"code","metadata":{"id":"1OGWuPmeox9X","executionInfo":{"status":"ok","timestamp":1606592474498,"user_tz":180,"elapsed":1548,"user":{"displayName":"Bernardo Coutinho","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjbcuim7oGIm-uXpKRCJxDYg0Nhguq2a4_xKQcpjw=s64","userId":"08343358744938767290"}}},"source":["import gym\n","\n","env = gym.make(\"Pendulum-v0\")"],"execution_count":6,"outputs":[]},{"cell_type":"code","metadata":{"id":"BMax8WJxo01m","executionInfo":{"status":"ok","timestamp":1606592478730,"user_tz":180,"elapsed":5777,"user":{"displayName":"Bernardo Coutinho","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjbcuim7oGIm-uXpKRCJxDYg0Nhguq2a4_xKQcpjw=s64","userId":"08343358744938767290"}}},"source":["agent = DDPG(env.observation_space, env.action_space)"],"execution_count":7,"outputs":[]},{"cell_type":"code","metadata":{"id":"IcWebqIBo2PG","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1606592595686,"user_tz":180,"elapsed":122665,"user":{"displayName":"Bernardo Coutinho","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gjbcuim7oGIm-uXpKRCJxDYg0Nhguq2a4_xKQcpjw=s64","userId":"08343358744938767290"}},"outputId":"f4bb8987-d3dd-416a-bccb-db5b32a09e35"},"source":["returns = train(agent, env, 15000)"],"execution_count":8,"outputs":[{"output_type":"stream","text":["[100%] timestep = 15000/15000, episode = 75, avg_return = -129.0745"],"name":"stdout"}]}]} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seria legal colocar um gráfico do aumento da reward, assim como tem nos outros.
Co-authored-by: Nelson Alves Yamashita <46365985+nelsonayamashita@users.noreply.github.com>
Acho que dá pra dar uma reformulada no notebook pra aula de DDPG, e só depois botar as mudanças no repo. |
No description provided.