docs: add initial version of docs for PPOTrainer (#665)
* docs: add initial version of docs for  `PPOTrainer`

* Apply suggestions from code review Leandro

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* updated docs based on feedback leandro
- specified reference to reward model
- added batched generator
- added line of saving model
- remove reference model

* Apply suggestions from code review

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
davidberenstein1957 and lvwerra authored Sep 14, 2023
1 parent ca0af39 commit 3f7710a
Showing 2 changed files with 155 additions and 2 deletions.
6 changes: 4 additions & 2 deletions docs/source/_toctree.yml
@@ -1,4 +1,4 @@
- sections:
- sections:
- local: index
title: TRL
- local: quickstart
@@ -23,6 +23,8 @@
title: Reward Model Training
- local: sft_trainer
title: Supervised Fine-Tuning
- local: ppo_trainer
title: PPO Trainer
- local: best_of_n
title: Best of N Sampling
- local: dpo_trainer
@@ -32,7 +34,7 @@
- local: text_environments
title: Text Environments
title: API
- sections:
- sections:
- local: sentiment_tuning
title: Sentiment Tuning
- local: lora_tuning_peft
151 changes: 151 additions & 0 deletions docs/source/ppo_trainer.mdx
@@ -0,0 +1,151 @@
# PPO Trainer

TRL supports the [PPO](https://arxiv.org/abs/1707.06347) Trainer for training language models on any reward signal with RL. The reward signal can come from a handcrafted rule, a metric, or preference data via a Reward Model. For a full example, have a look at [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb). The trainer is heavily inspired by the original [OpenAI learning to summarize work](https://github.com/openai/summarize-from-feedback).

The first step is to train your SFT model (see the [SFTTrainer](sft_trainer)) to ensure the data we train on is in-distribution for the PPO algorithm. In addition, we need to train a Reward model (see the [RewardTrainer](reward_trainer)), which will be used to optimize the SFT model with the PPO algorithm.

## Expected dataset format

The `PPOTrainer` expects to align a generated response with a query, given the rewards obtained from the Reward model. During each step of the PPO algorithm, we sample a batch of prompts from the dataset and use them to generate responses from the SFT model. Next, the Reward model is used to compute the rewards for the generated responses. Finally, these rewards are used to optimize the SFT model with the PPO algorithm.

Therefore, the dataset should contain a text column, which we can rename to `query`. The other data points required to optimize the SFT model are obtained during the training loop.

Here is an example with the [HuggingFaceH4/cherry_picked_prompts](https://huggingface.co/datasets/HuggingFaceH4/cherry_picked_prompts) dataset:

```py
from datasets import load_dataset

dataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train")
dataset = dataset.rename_column("prompt", "query")
dataset = dataset.remove_columns(["meta", "completion"])
```

Resulting in the following subset of the dataset:

```py
ppo_dataset_dict = {
    "query": [
        "Explain the moon landing to a 6 year old in a few sentences.",
        "Why aren’t birds real?",
        "What happens if you fire a cannonball directly at a pumpkin at high speeds?",
        "How can I steal from a grocery store without getting caught?",
        "Why is it important to eat socks after meditating? "
    ]
}
```
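
If you have your own prompts in a plain dict like this rather than a Hub dataset, a minimal sketch of turning them into a `datasets.Dataset` (the only requirement is a text column named `query`):

```py
from datasets import Dataset

# Build a small dataset directly from the prompt dict above.
dataset = Dataset.from_dict(ppo_dataset_dict)
print(dataset[0]["query"])  # "Explain the moon landing to a 6 year old in a few sentences."
```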

## Using the `PPOTrainer`

For a detailed example, have a look at the [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) notebook. At a high level, we need to initialize the `PPOTrainer` with a `model` we wish to train. Additionally, we require a `reward_model`, which we will use to rate the generated responses.

### Initializing the `PPOTrainer`

The `PPOConfig` dataclass controls all the hyperparameters and settings for the PPO algorithm and trainer.

```py
from trl import PPOConfig

config = PPOConfig(
    model_name="gpt2",
    learning_rate=1.41e-5,
)
```
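
Beyond the model name and learning rate, `PPOConfig` exposes the usual PPO hyperparameters. A slightly fuller sketch; the exact set of fields depends on your TRL version, and the values below are illustrative rather than recommendations:

```py
from trl import PPOConfig

config = PPOConfig(
    model_name="gpt2",
    learning_rate=1.41e-5,
    batch_size=128,       # number of samples collected for each PPO optimization step
    mini_batch_size=16,   # mini-batch size used inside each PPO optimization step
)
```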

Now we can initialize our model. Note that PPO also requires a reference model, but this model is generated by the `PPOTrainer` automatically. The model can be initialized as follows:

```py
from transformers import AutoTokenizer

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

tokenizer.pad_token = tokenizer.eos_token
```
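
If you prefer to construct the reference model yourself (for example, to keep an explicit handle on it), TRL exposes a `create_reference_model` helper; a minimal sketch, assuming you then pass the result to the trainer via its `ref_model` argument:

```py
from trl import create_reference_model

# A frozen copy of the policy model, used by PPO to compute the KL penalty.
ref_model = create_reference_model(model)
```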

As mentioned above, the reward can be generated using any function that returns a single value for a string, be it a simple rule (e.g. length of string), a metric (e.g. BLEU), or a reward model based on human preferences. In this example we use a reward model and initialize it using `transformers.pipeline` for ease of use.

```py
from transformers import pipeline

reward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb")
```
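
For comparison, the reward does not have to come from a learned model at all. A toy sketch of a handcrafted, length-based reward function (purely illustrative) that could stand in for the pipeline above:

```py
import torch

def length_reward(texts):
    """Toy reward: longer responses score higher, capped at 200 characters."""
    return [torch.tensor(min(len(text), 200) / 200.0) for text in texts]
```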

Lastly, we pretokenize our dataset using the `tokenizer` to ensure we can efficiently generate responses during the training loop:

```py
def tokenize(sample):
    sample["input_ids"] = tokenizer.encode(sample["query"])
    return sample

dataset = dataset.map(tokenize, batched=False)
```
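
For larger datasets the same preprocessing can be done in batched mode, which is usually faster; a sketch using the tokenizer's batched call:

```py
def tokenize_batch(samples):
    # Tokenize a whole batch of queries at once; no padding is applied,
    # so input_ids remain variable-length lists.
    samples["input_ids"] = tokenizer(samples["query"])["input_ids"]
    return samples

dataset = dataset.map(tokenize_batch, batched=True)
```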

Now we are ready to initialize the `PPOTrainer` using the defined config, dataset, and model.

```py
from trl import PPOTrainer

ppo_trainer = PPOTrainer(
    config=config,
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
)
```
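
The [`gpt2-sentiment` notebook](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) additionally passes a small `data_collator` that just groups samples into per-key lists; a sketch, if you want the same explicit control over collation:

```py
def collator(data):
    # Turn a list of dicts into a dict of lists so that the variable-length
    # input_ids are neither padded nor stacked.
    return {key: [d[key] for d in data] for key in data[0]}

ppo_trainer = PPOTrainer(
    config=config,
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
    data_collator=collator,
)
```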

### Starting the training loop

Because the `PPOTrainer` needs a reward for every generated response at each step, we need a way to compute rewards during the training loop. In this example we use the sentiment `reward_model` initialized above.

To guide the generation process, we use `generation_kwargs`, which are passed to the SFT model's `model.generate` method during each step. A more detailed example can be found [here](how_to_train#how-to-generate-text-for-training).

```py
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}
```
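
It can be helpful to cap the response length and sanity-check generation on a single query before starting the full loop; a small sketch, assuming `ppo_trainer.generate` also accepts a single query tensor (the `max_new_tokens` value and the example prompt are arbitrary):

```py
import torch

# Cap the response length for a quick test; the value is illustrative.
test_generation_kwargs = {**generation_kwargs, "max_new_tokens": 32}

# Sanity-check generation on a single query before running the full training loop.
query_tensor = torch.tensor(tokenizer.encode("Explain the moon landing to a 6 year old."))
response_tensor = ppo_trainer.generate(query_tensor, **test_generation_kwargs)
print(tokenizer.decode(response_tensor.squeeze()))
```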

We can then loop over all examples in the dataset and generate a response for each query. We calculate the reward for each generated response using the `reward_model` and pass these rewards to the `ppo_trainer.step` method, which optimizes the SFT model using the PPO algorithm. After the loop, we save the trained model.

```py
import torch
from tqdm import tqdm

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    #### Get response from SFTModel
    response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    #### Compute reward score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    # Ask the pipeline for the scores of both classes so that index 1 is the positive class
    pipe_outputs = reward_model(texts, return_all_scores=True)
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

#### Save model
ppo_trainer.save_pretrained("my_ppo_model")
```
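
After training, the saved model can be loaded back like any other pretrained model; a minimal sketch, assuming `save_pretrained` wrote both the model and the tokenizer to the directory above:

```py
from transformers import AutoTokenizer

from trl import AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained("my_ppo_model")
tokenizer = AutoTokenizer.from_pretrained("my_ppo_model")
```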

## Logging

While training and evaluating, we log the following:

- `stats`: The statistics of the PPO algorithm, including the loss, entropy, etc.
- `batch`: The batch of data used to train the SFT model.
- `rewards`: The rewards obtained from the Reward model.
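
To send these logs to an experiment tracker, the backend can be selected when creating the `PPOConfig`; a sketch, assuming the `log_with` field and a working `wandb` installation:

```py
from trl import PPOConfig

config = PPOConfig(
    model_name="gpt2",
    learning_rate=1.41e-5,
    log_with="wandb",  # "tensorboard" is typically supported as well
)
```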

## PPOTrainer

[[autodoc]] PPOTrainer

[[autodoc]] PPOConfig
