Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Advantage Actor Critic (A2C) Model #598

Merged
merged 46 commits into from
Aug 13, 2021
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
6f1afc9
a2c draft
blahBlahhhJ Mar 19, 2021
d6e6652
finish logic but not training
blahBlahhhJ Mar 19, 2021
b9ee7e9
cli pass converge on cartpole environment
blahBlahhhJ Mar 19, 2021
9a3a309
test by calling from package, fix code formatting, ready for review
blahBlahhhJ Mar 20, 2021
ed891bc
add tests, fix formatting
blahBlahhhJ Mar 20, 2021
415437b
fix typo
blahBlahhhJ Mar 20, 2021
47932be
fix tests, ready for review
blahBlahhhJ Mar 20, 2021
f2b19c8
Add A2C to __init__
akihironitta Mar 20, 2021
22f3b85
Update docs
akihironitta Mar 20, 2021
8221035
Fix formatting
akihironitta Mar 20, 2021
16bcd4a
Use self.hparams and remove n_steps
akihironitta Mar 20, 2021
e2ffd14
Update CHANGELOG
akihironitta Mar 20, 2021
a06528e
Merge branch 'master' into feature/596_a2c
blahBlahhhJ Mar 20, 2021
e397c47
fix typing hints, add documentation for A2C
blahBlahhhJ Mar 21, 2021
245feb0
minor formatting issue
blahBlahhhJ Mar 21, 2021
9211f20
delete print and add normalization
blahBlahhhJ Mar 21, 2021
17fc418
Adjust fig size
akihironitta Mar 21, 2021
b26b271
Fix typing
akihironitta Mar 21, 2021
f7d0a74
switch to function based pytest
blahBlahhhJ Apr 19, 2021
a1f2949
Merge branch 'feature/596_a2c' of https://github.com/blahBlahhhJ/ligh…
blahBlahhhJ Apr 19, 2021
85c407e
fix formatting
blahBlahhhJ Apr 19, 2021
0d10f0a
fix import
blahBlahhhJ Apr 19, 2021
cc9909b
fix format again
blahBlahhhJ Apr 19, 2021
46785bd
fix format again again
blahBlahhhJ Apr 19, 2021
bf14f13
ad another function test
blahBlahhhJ May 8, 2021
53a5703
Merge branch 'master' into feature/596_a2c
Borda Jun 24, 2021
83f5cef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2021
fa64829
formt
Borda Jun 24, 2021
8e1c783
Merge branch 'feature/596_a2c' of https://github.com/blahBlahhhJ/ligh…
Borda Jun 24, 2021
023912b
Apply suggestions from code review
Borda Jun 24, 2021
6167d04
Merge branch 'master' into feature/596_a2c
mergify[bot] Jun 25, 2021
53ff8cc
Merge branch 'master' into feature/596_a2c
mergify[bot] Jun 25, 2021
1159c63
Merge branch 'master' into feature/596_a2c
mergify[bot] Jun 29, 2021
1faa5f5
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 1, 2021
73b240f
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 4, 2021
cdada9d
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 4, 2021
89a3b1a
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 7, 2021
c90beb9
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 13, 2021
baa512a
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 13, 2021
eb30b22
fix test
blahBlahhhJ Jul 20, 2021
b37888d
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 26, 2021
a509d04
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 28, 2021
74bfa34
Merge branch 'master' into feature/596_a2c
mergify[bot] Aug 9, 2021
57542aa
Merge branch 'master' into feature/596_a2c
mergify[bot] Aug 13, 2021
d717a71
Merge branch 'master' into feature/596_a2c
mergify[bot] Aug 13, 2021
4687f9a
Update CHANGELOG.md
Borda Aug 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased] - 2021-MM-DD

### Added

- Added Advantage Actor-Critic (A2C) Model [#598](https://github.com/PyTorchLightning/lightning-bolts/pull/598))


## [0.3.4] - 2021-06-17

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
114 changes: 106 additions & 8 deletions docs/source/reinforce_learn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ Contributions by: `Donal Byrne <https://github.com/djbyrne>`_

------------

.. note::
.. note::=
Borda marked this conversation as resolved.
Show resolved Hide resolved
RL models currently only support CPU and single GPU training with `distributed_backend=dp`.
Full GPU support will be added in later updates.

------------

DQN Models
----------
Expand Down Expand Up @@ -86,7 +87,7 @@ Example::
trainer = Trainer()
trainer.fit(dqn)

.. autoclass:: pl_bolts.models.rl.dqn_model.DQN
.. autoclass:: pl_bolts.models.rl.DQN
:noindex:

---------------
Expand Down Expand Up @@ -150,7 +151,7 @@ Example::
trainer = Trainer()
trainer.fit(ddqn)

.. autoclass:: pl_bolts.models.rl.double_dqn_model.DoubleDQN
.. autoclass:: pl_bolts.models.rl.DoubleDQN
:noindex:

---------------
Expand Down Expand Up @@ -240,7 +241,7 @@ Example::
trainer = Trainer()
trainer.fit(dueling_dqn)

.. autoclass:: pl_bolts.models.rl.dueling_dqn_model.DuelingDQN
.. autoclass:: pl_bolts.models.rl.DuelingDQN
:noindex:

--------------
Expand Down Expand Up @@ -326,7 +327,7 @@ Example::
trainer = Trainer()
trainer.fit(noisy_dqn)

.. autoclass:: pl_bolts.models.rl.noisy_dqn_model.NoisyDQN
.. autoclass:: pl_bolts.models.rl.NoisyDQN
:noindex:

--------------
Expand Down Expand Up @@ -519,7 +520,7 @@ Example::
trainer = Trainer()
trainer.fit(per_dqn)

.. autoclass:: pl_bolts.models.rl.per_dqn_model.PERDQN
.. autoclass:: pl_bolts.models.rl.PERDQN
:noindex:


Expand Down Expand Up @@ -611,7 +612,7 @@ Example::
trainer = Trainer()
trainer.fit(reinforce)

.. autoclass:: pl_bolts.models.rl.reinforce_model.Reinforce
.. autoclass:: pl_bolts.models.rl.Reinforce
:noindex:

--------------
Expand Down Expand Up @@ -664,5 +665,102 @@ Example::
trainer = Trainer()
trainer.fit(vpg)

.. autoclass:: pl_bolts.models.rl.vanilla_policy_gradient_model.VanillaPolicyGradient
.. autoclass:: pl_bolts.models.rl.VanillaPolicyGradient
:noindex:

--------------

Actor-Critic Models
-------------------
The following models are based on Actor Critic. Actor Critic conbines the approaches of value-based learning (the DQN family)
and the policy-based learning (the PG family) by learning the value function as well as the policy distribution. This approach
updates the policy network according to the policy gradient, and updates the value network to fit the discounted rewards.

Actor Critic Key Points:
- Actor outputs a distribution of actions for controlling the agent
- Critic outputs a value of current state for policy update suggestion
- The addition of critic allows the model to do n-step training instead of generating an entire trajectory

Advantage Actor Critic (A2C)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

(Asynchronous) Advantage Actor Critic model introduced in `Asynchronous Methods for Deep Reinforcement Learning <https://arxiv.org/abs/1602.01783>`_
Paper authors: Volodymyr Mnih, Adrià Puigdomènech Badia, Mehdi Mirza, Alex Graves, Timothy P. Lillicrap, Tim Harley, David Silver, Koray Kavukcuoglu

Original implementation by: `Jason Wang <https://github.com/blahBlahhhJ>`_

Advantage Actor Critic (A2C) is the classical actor critic approach in reinforcement learning. The underlying neural
network has an actor head and a critic head to output action distribution as well as value of current state. Usually the
first few layers are shared by the two heads to prevent learning similar stuff twice. It builds upon the idea of using a
baseline of average reward to reduce variance (in VPG) by using the critic as a baseline which could theoretically have
better performance.

The algorithm can use an n-step training approach instead of generating an entire trajectory. The algorithm is as follows:

1. Initialize our network.
2. Rollout n steps and save the transitions (states, actions, rewards, values, dones).
3. Calculate the n-step (discounted) return by bootstrapping the last value.

.. math::

G_{n+1} = V_{n+1}, G_t = r_t + \gamma G_{t+1} \ \forall t \in [0,n]

4. Calculate actor loss using values as baseline.

.. math::

L_{actor} = - \frac1n \sum_t (G_t - V_t) \log \pi (a_t | s_t)

5. Calculate critic loss using returns as target.

.. math::
L_{critic} = \frac1n \sum_t (V_t - G_t)^2

6. Calculate entropy bonus to encourage exploration.

.. math::

H_\pi = - \frac1n \sum_t \pi (a_t | s_t) \log \pi (a_t | s_t)

7. Calculate total loss as a weighted sum of the three components above.

.. math::

L = L_{actor} + \beta_{critic} L_{critic} - \beta_{entropy} H_\pi

8. Perform gradient descent to update our network.

.. note::
The current implementation only support discrete action space, and has only been tested on the CartPole environment.

A2C Benefits
~~~~~~~~~~~~~~~

- Combines the benefit from value-based learning and policy-based learning

- Further reduces variance using the critic as a value estimator

A2C Results
~~~~~~~~~~~~~~~~

Hyperparameters:

- Batch Size: 32
- Learning Rate: 0.001
- Entropy Beta: 0.01
- Critic Beta: 0.5
- Gamma: 0.99

.. image:: _images/rl_benchmark/cartpole_a2c_results.jpg
:width: 300
:alt: A2C Results

Example::

from pl_bolts.models.rl import AdvantageActorCritic
a2c = AdvantageActorCritic("CartPole-v0")
trainer = Trainer()
trainer.fit(a2c)

.. autoclass:: pl_bolts.models.rl.AdvantageActorCritic
:noindex:
16 changes: 9 additions & 7 deletions pl_bolts/models/rl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from pl_bolts.models.rl.double_dqn_model import DoubleDQN # noqa: F401
from pl_bolts.models.rl.dqn_model import DQN # noqa: F401
from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN # noqa: F401
from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN # noqa: F401
from pl_bolts.models.rl.per_dqn_model import PERDQN # noqa: F401
from pl_bolts.models.rl.reinforce_model import Reinforce # noqa: F401
from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient # noqa: F401
from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic
from pl_bolts.models.rl.double_dqn_model import DoubleDQN
from pl_bolts.models.rl.dqn_model import DQN
from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN
from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN
from pl_bolts.models.rl.per_dqn_model import PERDQN
from pl_bolts.models.rl.reinforce_model import Reinforce
from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient

__all__ = [
"AdvantageActorCritic",
"DoubleDQN",
"DQN",
"DuelingDQN",
Expand Down
Loading