
Commit

[DOCS] Add descriptive docstring to MinNewTokensLength (#25196)
* Add descriptive docstring to MinNewTokensLength

It addresses #24783

* Refine the differences between `min_length` and `min_new_tokens`

* Remove extra line

* Remove extra arguments in generate

* Add a missing space

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Run the linter

* Add clarification comments

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
nablabits and amyeroberts authored Aug 8, 2023
1 parent 080a971 commit a23ac36
Showing 1 changed file with 40 additions and 2 deletions.
42 changes: 40 additions & 2 deletions src/transformers/generation/logits_process.py
@@ -133,14 +133,53 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0.
Note that for decoder-only models, such as Llama2, `min_length` counts the length of the prompt plus the newly
generated tokens, whereas for other models it behaves like `min_new_tokens`, that is, it takes into account only
the newly generated tokens.
Args:
prompt_length_to_skip (`int`):
The input tokens length. Not a valid argument when used with `generate`, as it is set automatically to the
input length; see the last example below for constructing the processor directly.
min_new_tokens (`int`):
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> model.config.pad_token_id = model.config.eos_token_id
>>> model.generation_config.pad_token_id = model.config.eos_token_id
>>> input_context = "Hugging Face Company is"
>>> input_ids = tokenizer.encode(input_context, return_tensors="pt")
>>> # Without a custom `eos_token_id`, generation stops at the default `max_length` of 20 tokens, ignoring `min_new_tokens`
>>> outputs = model.generate(input_ids=input_ids, min_new_tokens=30)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Hugging Face Company is a company that has been working on a new product for the past year.
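>>> # Note (an aside, not part of the original example): for a decoder-only model like this one, `min_length`
>>> # would count the prompt tokens as well, so `min_length` and `min_new_tokens` with the same value
>>> # constrain generation differently.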
>>> # If `eos_token_id` is set to ` company`, generation stops only once at least `min_new_tokens` new tokens
>>> # have been generated. Note that ` Company` (5834) and ` company` (1664) are not actually the same token,
>>> # and even if they were, ` Company` would be ignored by `min_new_tokens` since it excludes the prompt.
>>> outputs = model.generate(input_ids=input_ids, min_new_tokens=1, eos_token_id=1664)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Hugging Face Company is a company
>>> # Increasing `min_new_tokens` suppresses the first occurrence of ` company`, generating a different sequence.
>>> outputs = model.generate(input_ids=input_ids, min_new_tokens=2, eos_token_id=1664)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Hugging Face Company is a new company
>>> # If ` company` does not occur again after the first `min_new_tokens` tokens, generation falls back to the default `max_length` of 20 tokens.
>>> outputs = model.generate(input_ids=input_ids, min_new_tokens=10, eos_token_id=1664)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Hugging Face Company is a new and innovative brand of facial recognition technology that is designed to help you
```
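For reference, here is a minimal sketch of building the processor by hand; the usual path is `generate`, which
constructs it internally from `min_new_tokens`. It reuses the `model` and `input_ids` defined above.

```python
>>> from transformers import LogitsProcessorList, MinNewTokensLengthLogitsProcessor

>>> # `prompt_length_to_skip` has to be supplied explicitly here; `generate` derives it
>>> # from the prompt automatically, which is why it is not a valid `generate` argument.
>>> processors = LogitsProcessorList(
...     [
...         MinNewTokensLengthLogitsProcessor(
...             prompt_length_to_skip=input_ids.shape[-1],
...             min_new_tokens=2,
...             eos_token_id=model.config.eos_token_id,
...         )
...     ]
... )
>>> outputs = model.generate(input_ids=input_ids, logits_processor=processors)
```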
"""

def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]):
@@ -194,7 +233,6 @@ class TemperatureLogitsWarper(LogitsWarper):
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> model.config.pad_token_id = model.config.eos_token_id
