diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 8783294170..01c7634125 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -1,4 +1,4 @@ -- sections: +- sections: - local: index title: TRL - local: quickstart @@ -23,12 +23,14 @@ title: Reward Model Training - local: sft_trainer title: Supervised Fine-Tuning + - local: ppo_trainer + title: PPO Trainer - local: best_of_n title: Best of N Sampling - local: dpo_trainer title: DPO Trainer title: API -- sections: +- sections: - local: sentiment_tuning title: Sentiment Tuning - local: lora_tuning_peft diff --git a/docs/source/ppo_trainer.mdx b/docs/source/ppo_trainer.mdx new file mode 100644 index 0000000000..16c99c5ab9 --- /dev/null +++ b/docs/source/ppo_trainer.mdx @@ -0,0 +1,151 @@ +# PPO Trainer + +TRL supports the [PPO](https://arxiv.org/abs/1707.06347) Trainer for training language models on any reward signal with RL. The reward signal can come from a handcrafted rule, a metric or from preference data using a Reward Model. For a full example have a look at [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb). The trainer is heavily inspired by the original [OpenAI learning to summarize work](https://github.com/openai/summarize-from-feedback). + +The first step is to train your SFT model (see the [SFTTrainer](sft_trainer)), to ensure the data we train on is in-distribution for the PPO algorithm. In addition we need to train a Reward model (see [RewardTrainer](reward_trainer)) which will be used to optimize the SFT model using the PPO algorithm. + +## Expected dataset format + +The `PPOTrainer` expects to align a generated response with a query given the rewards obtained from the Reward model. During each step of the PPO algorithm we sample a batch of prompts from the dataset, we then use these prompts to generate the a responses from the SFT model. Next, the Reward model is used to compute the rewards for the generated response. Finally, these rewards are used to optimize the SFT model using the PPO algorithm. + +Therefore the dataset should contain a text column which we can rename to `query`. Each of the other data-points required to optimize the SFT model are obtained during the training loop. + +Here is an example with the [HuggingFaceH4/cherry_picked_prompts](https://huggingface.co/datasets/HuggingFaceH4/cherry_picked_prompts) dataset: + +```py +from datasets import load_dataset + +dataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train") +dataset = dataset.rename_column("prompt", "query") +dataset = dataset.remove_columns(["meta", "completion"]) +``` + +Resulting in the following subset of the dataset: + +```py +ppo_dataset_dict = { + "prompt": [ + "Explain the moon landing to a 6 year old in a few sentences.", + "Why aren’t birds real?", + "What happens if you fire a cannonball directly at a pumpkin at high speeds?", + "How can I steal from a grocery store without getting caught?", + "Why is it important to eat socks after meditating? " + ] +} +``` + +## Using the `PPOTrainer` + +For a detailed example have a look at the [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) notebook. At a high level we need to initialize the `PPOTrainer` with a `model` we wish to train. Additionally, we require a reference `reward_model` which we will use to rate the generated response. + +### Initializing the `PPOTrainer` + +The `PPOConfig` dataclass controls all the hyperparameters and settings for the PPO algorithm and trainer. + +```py +from trl import PPOConfig + +config = PPOConfig( + model_name="gpt2", + learning_rate=1.41e-5, +) +``` + +Now we can initialize our model. Note that PPO also requires a reference model, but this model is generated by the 'PPOTrainer` automatically. The model can be initialized as follows: + +```py +from transformers import AutoTokenizer + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer + +model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name) +tokenizer = AutoTokenizer.from_pretrained(config.model_name) + +tokenizer.pad_token = tokenizer.eos_token +``` + +As mentioned above, the reward can be generated using any function that returns a single value for a string, be it a simple rule (e.g. length of string), a metric (e.g. BLEU), or a reward model based on human preferences. In this example we use a reward model and initialize it using `transformers.pipeline` for ease of use. + +```py +from transformers import pipeline + +reward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb") +``` + +Lastly, we pretokenize our dataset using the `tokenizer` to ensure we can efficiently generate responses during the training loop: + +```py +def tokenize(sample): + sample["input_ids"] = tokenizer.encode(sample["query"]) + return sample + +dataset = dataset.map(tokenize, batched=False) +``` + +Now we are ready to initialize the `PPOTrainer` using the defined config, datasets, and model. + +```py +from trl import PPOTrainer + +ppo_trainer = PPOTrainer( + model, + config=config, + train_dataset=train_dataset, + tokenizer=tokenizer, +) +``` + +### Starting the training loop + +Because the `PPOTrainer` needs an active `reward` per execution step, we need to define a method to get rewards during each step of the PPO algorithm. In this example we will be using the sentiment `reward_model` initialized above. + +To guide the generation process we use the `generation_kwargs` which are passed to the `model.generate` method for the SFT-model during each step. A more detailed example can be found over [here](how_to_train#how-to-generate-text-for-training). + +```py +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, +} +``` + +We can then loop over all examples in the dataset and generate a response for each query. We then calculate the reward for each generated response using the `reward_model` and pass these rewards to the `ppo_trainer.step` method. The `ppo_trainer.step` method will then optimize the SFT model using the PPO algorithm. + +```py +from tqdm import tqdm + +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + query_tensors = batch["input_ids"] + + #### Get response from SFTModel + response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs) + batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors] + + #### Compute reward score + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = reward_model(texts) + rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] + + #### Run PPO step + stats = ppo_trainer.step(query_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards) + +#### Save model +ppo_trainer.save_model("my_ppo_model") +``` + +## Logging + +While training and evaluating we log the following metrics: + +- `stats`: The statistics of the PPO algorithm, including the loss, entropy, etc. +- `batch`: The batch of data used to train the SFT model. +- `rewards`: The rewards obtained from the Reward model. + +## PPOTrainer + +[[autodoc]] PPOTrainer + +[[autodoc]] PPOConfig \ No newline at end of file