Generate: sequence bias can handle same terminations #24822

Merged
merged 3 commits on Jul 20, 2023
Changes from 1 commit
34 changes: 6 additions & 28 deletions src/transformers/generation/logits_process.py
@@ -624,9 +624,7 @@ def __init__(self, sequence_bias: Dict[Tuple[int], float]):

         # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
         # is infered in the first usage, which inhibits initializing here)
-        self.sequences_length_greater_than_1 = []
         self.length_1_bias = None
-        self.length_greather_than_1_bias = None
         self.prepared_bias_variables = False

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
@@ -642,11 +640,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         bias += self.length_1_bias

         # 4 - include the bias from length > 1, after determining which biased sequences may be completed.
-        # `matching_mask` is a (batch_size, vocab_size) boolean mask that is True for all tokens whose corresponding
-        # bias should be applied. The bias is applied on the last token of the sequence, if (and only if) the sequence
-        # may become complete this iteration.
-        matching_mask = torch.zeros_like(scores, dtype=torch.bool)
-        for sequence_ids in self.sequences_length_greater_than_1:
+        for sequence_ids, sequence_bias in self.sequence_bias.items():
+            if len(sequence_ids) == 1:  # the sequence is of length 1, already applied
+                continue
             if len(sequence_ids) > input_ids.shape[1]:  # the sequence is longer than the context, ignore
                 continue
             prefix_length = len(sequence_ids) - 1
@@ -655,25 +651,18 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
                 input_ids[:, -prefix_length:],
                 torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device),
             ).prod(dim=1)
-            matching_mask[:, last_token] |= matching_rows.bool()
-        bias += torch.where(
-            matching_mask,
-            self.length_greather_than_1_bias,
-            torch.tensor(0.0, device=self.length_greather_than_1_bias.device),
-        )
+            bias[:, last_token] += sequence_bias * matching_rows

         # 5 - apply the bias to the scores
         scores = scores + bias
         return scores

     def _prepare_bias_variables(self, scores: torch.FloatTensor):
         vocabulary_size = scores.shape[-1]
-        sequence_bias = self.sequence_bias
-        tokens_with_bias = []

         # Check biased tokens out of bounds
         invalid_biases = []
-        for sequence_ids in sequence_bias:
+        for sequence_ids in self.sequence_bias:
             for token_id in sequence_ids:
                 if token_id >= vocabulary_size:
                     invalid_biases.append(token_id)
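
The block removed above applied one shared bias value per terminating token through `length_greather_than_1_bias`, which is why two biased sequences ending in the same token could not coexist; the new line instead accumulates each sequence's own bias directly into `bias[:, last_token]`. A rough, self-contained sketch of that step (illustration only, not the library code; the helper name `apply_long_sequence_bias` and the zero-initialized `bias` tensor are assumptions):

```python
import torch

def apply_long_sequence_bias(input_ids: torch.LongTensor, bias: torch.FloatTensor, sequence_bias: dict):
    # `bias` is assumed to be a zero tensor of shape (batch_size, vocab_size), as prepared in earlier steps.
    for sequence_ids, value in sequence_bias.items():
        if len(sequence_ids) == 1:  # length-1 sequences are covered by the precomputed length-1 bias
            continue
        if len(sequence_ids) > input_ids.shape[1]:  # longer than the current context, cannot complete yet
            continue
        prefix_length = len(sequence_ids) - 1
        last_token = sequence_ids[-1]
        # 1 for rows whose trailing tokens match the sequence prefix, 0 otherwise
        matching_rows = torch.eq(
            input_ids[:, -prefix_length:],
            torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device),
        ).prod(dim=1)
        # Accumulate per sequence, so e.g. (2, 7) and (3, 7) no longer overwrite each other's bias on token 7
        bias[:, last_token] += value * matching_rows
    return bias
```
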
@@ -686,20 +675,9 @@ def _prepare_bias_variables(self, scores: torch.FloatTensor):
         # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
         # with simpler logic.
         self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
-        self.length_greather_than_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
-        for sequence_ids, bias in sequence_bias.items():
+        for sequence_ids, bias in self.sequence_bias.items():
             if len(sequence_ids) == 1:
                 self.length_1_bias[sequence_ids[-1]] = bias
-            else:
-                self.sequences_length_greater_than_1.append(sequence_ids)
-                if self.length_greather_than_1_bias[sequence_ids[-1]] != 0.0:
-                    raise ValueError(
-                        "Setting a bias on sequences that share a common token termination is not yet supported. "
-                        "Please open an issue if you see this error message (after checking that it doesn't already "
-                        "exist)."
-                    )
-                self.length_greather_than_1_bias[sequence_ids[-1]] = bias
-                tokens_with_bias.append(sequence_ids[-1])

         self.prepared_bias_variables = True
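
With the ValueError branch gone, biases on sequences that share the same final token can now be set together. A minimal usage sketch, assuming made-up token ids, vocabulary size, and dummy scores for illustration:

```python
import torch
from transformers.generation.logits_process import SequenceBiasLogitsProcessor

# Two biased sequences that end in the same token id (7) -- previously this raised a ValueError.
processor = SequenceBiasLogitsProcessor(sequence_bias={(2, 7): 10.0, (3, 7): -10.0})

input_ids = torch.tensor([[0, 2], [0, 3]])  # batch of 2 prefixes
scores = torch.zeros((2, 10))               # dummy logits over a 10-token vocabulary

biased_scores = processor(input_ids, scores)
# Row 0 ends in 2, so (2, 7) may complete: token 7 gets +10.0 there.
# Row 1 ends in 3, so (3, 7) may complete: token 7 gets -10.0 there.
print(biased_scores[:, 7])  # tensor([ 10., -10.])
```
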
