Move BCO to separate BCOTrainer with fixes (#1869)
* kto_trainer: skip KL data for BCO
* kto_trainer: BCO allow no positives or no negatives in batch
* kto_trainer: make RunningMoments object serializable
* add BCOTrainer
* fix BCO UDM for not interleaved data
* kto_trainer: remove unused UDM part
* bco_trainer: add tests and docs, minor fixes
* code style fixes
* Update docs/source/bco_trainer.mdx
  Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* fix BCO UDM for bfloat16
* Update trl/trainer/bco_config.py
* Update trl/trainer/bco_config.py
  Co-authored-by: Seungjae Jung <seanexplode@gmail.com>
* Update trl/trainer/utils.py
  Co-authored-by: Seungjae Jung <seanexplode@gmail.com>
* Update trl/trainer/bco_trainer.py
  Co-authored-by: Seungjae Jung <seanexplode@gmail.com>
* Update trl/trainer/bco_config.py
* Update _toctree.yml
* Update trl/trainer/bco_config.py
* Update trl/trainer/bco_trainer.py
* RunningMoments, fix multi GPU serialization
* fix tests

---------

Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Seungjae Jung <seanexplode@gmail.com>
1 parent 6171cdd · commit 9929370
Showing 12 changed files with 2,179 additions and 351 deletions.
# BCO Trainer

TRL supports Binary Classifier Optimization (BCO).
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward, so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
For a full example, have a look at [`examples/scripts/bco.py`].

## Expected dataset format

The BCO trainer expects a very specific dataset format, since it does not require pairwise preferences. The model is trained to directly optimize examples that consist of a prompt, a model completion, and a label indicating whether the completion is "good" or "bad", so we expect a dataset with the following columns:

- `prompt`
- `completion`
- `label`

for example:

```py
bco_dataset_dict = {
    "prompt": [
        "Hey, hello",
        "How are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ],
    "completion": [
        "hi nice to meet you",
        "leave me alone",
        "I don't have a name",
        "My name is Mary",
        "Python",
        "C++",
        "Java",
    ],
    "label": [
        True,
        False,
        False,
        True,
        True,
        False,
        False,
    ],
}
```

where `prompt` contains the context inputs, `completion` contains the corresponding responses, and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
A prompt can have multiple responses, and this is reflected in the entries being repeated in the dictionary's value arrays. The dataset must contain at least one desirable and at least one undesirable completion.

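For instance, assuming the `datasets` library is used (a minimal sketch, not part of the original example), the dictionary above can be wrapped in a `Dataset` and passed to the trainer as `train_dataset`:

```py
from datasets import Dataset

# Build a Dataset from the in-memory dictionary shown above
train_dataset = Dataset.from_dict(bco_dataset_dict)
```
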
## Expected model format

The BCO trainer expects a model of `AutoModelForCausalLM`, compared to PPO, which expects `AutoModelForCausalLMWithValueHead` for the value function.

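As a minimal sketch (the `gpt2` checkpoint below is only a placeholder assumption), the policy and reference models used in the next section can be loaded as plain causal language models together with their tokenizer:

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # placeholder; use the checkpoint you actually want to fine-tune

model = AutoModelForCausalLM.from_pretrained(model_id)
model_ref = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
```
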
## Using the `BCOTrainer`

For a detailed example, have a look at the `examples/scripts/bco.py` script. At a high level, we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected responses.

The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (i.e. decoder-only or encoder-decoder).

```py
from trl import BCOConfig, BCOTrainer

training_args = BCOConfig(
    beta=0.1,
)

bco_trainer = BCOTrainer(
    model,
    model_ref,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
```
After this, one can then call:

```py
bco_trainer.train()
```

## Underlying Distribution Matching (UDM)

In practical scenarios, the thumbs-up and thumbs-down datasets are likely to have divergent underlying distributions of prompts.
Consider an LLM deployed for user feedback: if the model excels in writing tasks but underperforms in coding, the thumbs-up dataset will be dominated by writing-related prompts, while the thumbs-down dataset will contain mostly coding-related prompts.
If the prompts in your desired and undesired datasets differ a lot, it is useful to enable UDM.

Choose an embedding model and tokenizer:

```py
from functools import partial

from accelerate import Accelerator
from transformers import AutoModel, AutoTokenizer

# `your_model_id` is a placeholder for the embedding model checkpoint you want to use
embedding_model = AutoModel.from_pretrained(your_model_id)
embedding_tokenizer = AutoTokenizer.from_pretrained(your_model_id)

# customize this function depending on your embedding model
def embed_prompt(input_ids, attention_mask, model):
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    return outputs.last_hidden_state.mean(dim=1)

embedding_model = Accelerator().prepare_model(embedding_model)
embedding_func = partial(embed_prompt, model=embedding_model)
```

Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier, and start the training with the provided embedding function:

```py
training_args = BCOConfig(
    beta=0.1,
    prompt_sample_size=512,
)

bco_trainer = BCOTrainer(
    model,
    model_ref,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    embedding_func=embedding_func,
    embedding_tokenizer=embedding_tokenizer,
)

bco_trainer.train()
```

### For Mixture of Experts Models: Enabling the auxiliary loss

MOEs are most efficient if the load is roughly equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.

This option is enabled by setting `output_router_logits=True` in the model config (e.g. `MixtralConfig`).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`).

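As an illustrative sketch (the Mixtral checkpoint below is an assumed example, not prescribed here), one way to set both options is to edit the config before loading the model:

```py
from transformers import AutoConfig, AutoModelForCausalLM

# Assumed example checkpoint; any Mixtral-style MoE model works the same way.
config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
config.output_router_logits = True   # add the load-balancing auxiliary loss to the LM loss
config.router_aux_loss_coef = 0.001  # weight of the auxiliary loss in the total loss

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", config=config)
```
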
## BCOTrainer

[[autodoc]] BCOTrainer

## BCOConfig

[[autodoc]] BCOConfig