Add PPO and GAE (#50)
natolambert authored Feb 5, 2025
1 parent 1ba6254 commit 5b19066
Showing 5 changed files with 102 additions and 6 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/static.yml
@@ -56,6 +56,12 @@ jobs:
make
make files
# Need this if main domain is www.rlhfbook.com
# For now, it is rlhfbook.com, and seems to be working.
# - name: Create CNAME file
# run: |
# echo "www.rlhfbook.com" > build/html/CNAME

# pages deploy steps
- name: Setup Pages
uses: actions/configure-pages@v5
92 changes: 88 additions & 4 deletions chapters/11-policy-gradients.md
@@ -150,11 +150,52 @@ Other implementations of REINFORCE algorithms have been designed for language models
*This section follows similar to [@achiam2018spinning].*

Proximal Policy Optimization (PPO) [@schulman2017proximal] is one of the foundational algorithms behind Deep RL's successes (such as OpenAI Five's Dota 2 performance [@berner2019dota] and a large body of research).
The loss function is as follows:

$$J(\theta) = \frac{1}{G}\sum_{i=1}^G \min\left(\frac{\pi_\theta(a_i|s)}{\pi_{\theta_{old}}(a_i|s)}A_i, \text{clip} \left( \frac{\pi_\theta(a_i|s)}{\pi_{\theta_{old}}(a_i|s)}, 1-\varepsilon, 1+\varepsilon \right) A_i \right).$$ {#eq:PPO_EQN}

*For further background, see https://spinningup.openai.com/en/latest/algorithms/ppo.html.*

Here we will explain the different cases this loss function triggers given various advantages and policy ratios.
At an implementation level, the inner computations for PPO involve a standard policy gradient term and a clipped policy gradient term.

To understand how different situations emerge, we can define the policy ratio as:

$$R(\theta) = \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)}$$ {#eq:PPO_POL_RATIO}

The first case is when the advantage is positive and the policy ratio exceeds $1+\varepsilon$ (meaning that the new policy is more likely to take said action). Here the ratio is clipped, and the objective becomes:

$$J(\theta) = \min \left(R(\theta), (1 + \varepsilon)\right)A = (1 + \varepsilon)A $$

This still makes the action more likely, but only up to the point set by the clipping parameter $\varepsilon$: once the ratio exceeds $1+\varepsilon$, the objective is constant in $\theta$ and provides no further incentive to increase the probability. For example, with $\varepsilon=0.2$ and $R(\theta)=1.5$, the objective is held at $1.2A$ rather than $1.5A$.
Similar conditions arise when the advantage is still positive but the likelihood ratio is smaller.

For a positive advantage and a ratio less than $1-\varepsilon$, we get a partially substituted equation:

$$J(\theta) = \min \left(R(\theta), (1 - \varepsilon)\right)A$$

That reduces to

$$J(\theta) = R(\theta)A$$

because the assumption $R(\theta) < 1-\varepsilon$ means the $\min$ selects the unclipped ratio.

Similarly, if the probability ratio is not being clipped, i.e. $1-\varepsilon \le R(\theta) \le 1+\varepsilon$, the objective reduces to $\min(R(\theta),R(\theta))A = R(\theta)A$, yielding a standard policy gradient weighted by the advantage estimator.

If the advantage is negative, the cases look similar. A clipped objective occurs when $R(\theta) < (1-\varepsilon)$, appearing through:

$$J(\theta) = \min \left(R(\theta)A, (1 - \varepsilon)A\right),$$

which, because $A<0$ and $R(\theta) < 1-\varepsilon$ together imply $R(\theta)A > (1-\varepsilon)A$, is equivalent to (the $\min$ flips to a $\max$ when $A$ is factored out):

$$J(\theta) = \max \left(R(\theta), (1 - \varepsilon)\right)A.$$

Since $R(\theta) < 1 - \varepsilon$, the $\max$ evaluates to the clipping bound and the objective becomes:

$$J(\theta) = (1 - \varepsilon)A$$

The other cases follow as above, inverted, and are left as an exercise to the reader.

All of these cases are designed to make behaviors with positive advantage more likely while keeping the gradient step within the trust region.
It is crucial to remember that, within the trust region, PPO is the same as standard forms of policy gradient.
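
To make the two terms concrete, here is a minimal sketch of the inner computation in PyTorch. It assumes per-token log-probabilities have already been gathered for the sampled actions, and the function and variable names are illustrative rather than taken from any specific library:

```python
import torch

def ppo_clip_objective(logprobs, old_logprobs, advantages, clip_eps=0.2):
    # Policy ratio R(theta), computed in log-space for numerical stability.
    ratio = torch.exp(logprobs - old_logprobs)

    # Unclipped term: the standard policy gradient surrogate.
    unclipped = ratio * advantages
    # Clipped term: the ratio is bounded to [1 - eps, 1 + eps].
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages

    # PPO takes the elementwise minimum of the two surrogates; returning the
    # negative mean gives a loss suitable for gradient descent.
    return -torch.min(unclipped, clipped).mean()
```

In the clipped cases discussed above, the selected term is constant in $\theta$ and contributes no gradient; otherwise the gradient matches the standard policy gradient surrogate.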

### Group Relative Policy Optimization

@@ -172,6 +213,7 @@ To state this formally, the GRPO objective is very similar to the PPO objective

$$J(\theta) = \frac{1}{G}\sum_{i=1}^G \left(\min\left(\frac{\pi_\theta(a_i|s)}{\pi_{\theta_{old}}(a_i|s)}A_i, \text{clip} \left( \frac{\pi_\theta(a_i|s)}{\pi_{\theta_{old}}(a_i|s)}, 1-\varepsilon, 1+\varepsilon \right) A_i \right) - \beta D_{KL}(\pi_\theta||\pi_{ref})\right).$$

Note that, relative to PPO, the standard implementation of GRPO includes the KL divergence penalty directly in the loss.
With the advantage computation for the completion index $i$:

$$A_i = \frac{r_i - \text{mean}({r_1, r_2, \cdots, r_G})}{\text{std}({r_1, r_2, \cdots, r_G})}.$$ {#eq:GRPO_ADV}
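
As a small illustrative sketch (not the reference implementation), this advantage computation can be written directly over a group of scalar rewards; the small epsilon added to the denominator is an implementation convenience for numerical stability, not part of the equation:

```python
import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Group-normalized advantages: rewards has shape (G,), one per completion."""
    mean = rewards.mean()
    std = rewards.std()
    # Each completion is scored relative to its own group of samples.
    return (rewards - mean) / (std + eps)

# Example: three completions sampled for the same prompt.
advantages = grpo_advantages(torch.tensor([1.0, 0.0, 0.5]))
```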
@@ -351,14 +393,56 @@ with torch.no_grad():

For more details on how to interpret this code, see the PPO section above.

## Auxiliary Topics

### Generalized Advantage Estimation (GAE)

Generalized Advantage Estimation (GAE) is an alternate method to compute the advantage for policy gradient algorithms [@schulman2015high] that better balances the bias-variance tradeoff.
Traditional single-step advantage estimates (which bootstrap from a learned value function) often suffer from high bias, while using complete trajectories can introduce too much variance.
GAE works by combining two ideas -- multi-step prediction and an exponentially weighted average over those predictions -- and its mixing parameter can also recover either extreme alone.

Advantage estimates can take many forms, but we can define an $n$-step advantage estimator (similar to the TD residual at the beginning of the chapter) as follows:

$$
\hat{A}_t^{(n)} = \begin{cases}
r_t + \gamma V(s_{t+1}) - V(s_t), & n = 1 \\
r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2}) - V(s_t), & n = 2 \\
\vdots \\
r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots - V(s_t), & n = \infty
\end{cases}
$$ {#eq:K_STEP_ADV}
Here a smaller $n$ will have lower variance but higher bias, as it leans more heavily on the learned value function, while a larger $n$ depends on more sampled rewards from a single trajectory and trades that bias for variance.
GAE attempts to generalize this formulation into a weighted multi-step average instead of committing to a specific $n$.
To start, we must define the temporal difference (TD) residual of predicted value.
$$
\delta_t^V = r_t + \gamma V(s_{t+1}) - V(s_t)
$$ {#eq:TD_RESIDUAL}
To utilize this, we introduce another variable $\lambda$ as the GAE mixing parameter. This folds into an exponentially decaying weighting over the multi-step advantage estimates we wish to combine:
$$
\begin{array}{l}
\hat{A}_t^{GAE(\gamma,\lambda)} = (1-\lambda)(\hat{A}_t^{(1)} + \lambda\hat{A}_t^{(2)} + \lambda^2\hat{A}_t^{(3)} + \cdots) \\
= (1-\lambda)(\delta_t^V + \lambda(\delta_t^V + \gamma\delta_{t+1}^V) + \lambda^2(\delta_t^V + \gamma\delta_{t+1}^V + \gamma^2\delta_{t+2}^V) + \cdots) \\
= (1-\lambda)(\delta_t^V(1 + \lambda + \lambda^2 + \cdots) + \gamma\delta_{t+1}^V(\lambda + \lambda^2 + \cdots) + \cdots) \\
= (1-\lambda)(\delta_t^V\frac{1}{1-\lambda} + \gamma\delta_{t+1}^V\frac{\lambda}{1-\lambda} + \cdots) \\
= \sum_{l=0}^{\infty}(\gamma\lambda)^l\delta_{t+l}^V
\end{array}
$$ {#eq:GAE_DFN}
Intuitively, this can be used to average over many multi-step estimates of the advantage in an elegant fashion. Setting $\lambda=0$ recovers the single-step TD residual above, while $\lambda=1$ recovers a Monte Carlo style estimate.
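
A minimal sketch of this computation, using the backward recursion $\hat{A}_t = \delta_t^V + \gamma\lambda\hat{A}_{t+1}$, might look as follows (assuming a single non-terminating trajectory with a bootstrap value for the final state; episode/termination masking is omitted for brevity, and the names are illustrative):

```python
import torch

def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    """rewards: (T,) per-step rewards; values: (T + 1,) value estimates,
    where values[T] bootstraps the state after the final step."""
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    next_advantage = 0.0
    # Accumulate discounted TD residuals from the end of the trajectory backwards.
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        next_advantage = delta + gamma * lam * next_advantage
        advantages[t] = next_advantage
    return advantages
```
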
*For further reading, see [@seita2017gae].*

### KL Controllers

TODO: adaptive vs static KL control.
See Table 10 of the Tulu 2.5 paper for implementation details.
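
As a rough sketch of the distinction (the class names, constants, and update rule here are illustrative, patterned after the proportional controller used in early RLHF implementations rather than any specific codebase):

```python
class FixedKLController:
    """Static control: the KL penalty coefficient never changes."""

    def __init__(self, kl_coef: float):
        self.kl_coef = kl_coef

    def update(self, observed_kl: float, n_steps: int) -> float:
        return self.kl_coef


class AdaptiveKLController:
    """Adaptive control: nudge the coefficient so the observed KL tracks a target."""

    def __init__(self, init_kl_coef: float, target_kl: float, horizon: int):
        self.kl_coef = init_kl_coef
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, observed_kl: float, n_steps: int) -> float:
        # Proportional error, clipped so one noisy batch cannot swing the penalty.
        error = max(min(observed_kl / self.target_kl - 1.0, 0.2), -0.2)
        self.kl_coef *= 1.0 + error * n_steps / self.horizon
        return self.kl_coef
```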

### Double regularization

Many popular policy gradient algorithms from Deep Reinforcement Learning originated due to the need to control the learning process of the agent.
In RLHF, as discussed extensively in Chapter 8 on Regularization and in Chapter 4 on Problem Formulation, there is a built-in regularization term via the distance penalty relative to the original policy one is finetuning.
7 changes: 6 additions & 1 deletion chapters/bib.bib
@@ -210,7 +210,12 @@ @article{schulman2015high
journal={arXiv preprint arXiv:1506.02438},
year={2015}
}

@misc{seita2017gae,
author = {Daniel Seita},
title = {Notes on the Generalized Advantage Estimation Paper},
year = {2017},
url = {https://danieltakeshi.github.io/2017/04/02/notes-on-the-generalized-advantage-estimation-paper/}
}
@article{lambert2020objective,
title={Objective mismatch in model-based reinforcement learning},
author={Lambert, Nathan and Amos, Brandon and Yadan, Omry and Calandra, Roberto},
2 changes: 1 addition & 1 deletion metadata.yml
@@ -9,7 +9,7 @@ lang: en-US
mainlang: english
otherlang: english
tags: [rlhf, ebook, ai, ml]
date: 19 January 2025
date: 04 February 2025
abstract: |
Reinforcement learning from human feedback (RLHF) has become an important technical and storytelling tool to deploy the latest machine learning systems.
In this book, we hope to give a gentle introduction to the core methods for people with some level of quantitative background.
1 change: 1 addition & 0 deletions templates/html.html
@@ -71,6 +71,7 @@ <h2>Abstract</h2>
<body>
<section id="changelog" style="padding: 20px; text-align: center;">
<h2>Changelog</h2>
<p><strong>4 Feb. 2025</strong>: PPO and GAE </p>
<p><strong>2 Feb. 2025</strong>: Added changelog, revamped introduction, </p>
</section>
<section id="acknowledgements" style="padding: 20px; text-align: center;">
