Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate: group_beam_search requires diversity_penalty>0.0 #24456

Merged
merged 2 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions docs/source/en/generation_strategies.md
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,9 @@ the `num_beams` greater than 1, and set `do_sample=True` to use this decoding st

The diverse beam search decoding strategy is an extension of the beam search strategy that allows for generating a more diverse
set of beam sequences to choose from. To learn how it works, refer to [Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence Models](https://arxiv.org/pdf/1610.02424.pdf).
This approach has two main parameters: `num_beams` and `num_beam_groups`.
The groups are selected to ensure they are distinct enough compared to the others, and regular beam search is used within each group.
This approach has three main parameters: `num_beams`, `num_beam_groups`, and `diversity_penalty`.
The diversily penalty ensures the outputs are distinct across groups, and beam search is used within each group.


```python
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
Expand All @@ -328,9 +329,9 @@ The groups are selected to ensure they are distinct enough compared to the other

>>> model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

>>> outputs = model.generate(**inputs, num_beams=5, num_beam_groups=5, max_new_tokens=30)
>>> outputs = model.generate(**inputs, num_beams=5, num_beam_groups=5, max_new_tokens=30, diversity_penalty=1.0)
>>> tokenizer.decode(outputs[0], skip_special_tokens=True)
'The Design Principles are a set of universal design principles that can be applied to any location, climate and culture, and they allow us to design the most efficient and sustainable human habitation and food production systems.'
'The aim of this project is to create a new type of living system, one that is more sustainable and efficient than the current one.'
```

This guide illustrates the main parameters that enable various decoding strategies. More advanced parameters exist for the
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1669,6 +1669,11 @@ def generate(
if generation_config.num_beams % generation_config.num_beam_groups != 0:
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")

if generation_config.diversity_penalty == 0.0:
raise ValueError(
"`diversity_penalty` should be greater than `0.0`, otherwise your beam groups will be identical."
)

if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")

Expand Down
1 change: 1 addition & 0 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2366,6 +2366,7 @@ def test_transition_scores_group_beam_search_encoder_decoder(self):
num_beams=2,
num_beam_groups=2,
num_return_sequences=2,
diversity_penalty=1.0,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
Expand Down