[Generation] Fix Transition probs (huggingface#17311)
* [Draft] fix transition probs

* up

* up

* up

* make it work

* fix

* finish

* update
patrickvonplaten authored May 19, 2022
1 parent e8714c0 commit 518bd02
Showing 4 changed files with 197 additions and 75 deletions.
45 changes: 39 additions & 6 deletions src/transformers/generation_beam_search.py
@@ -212,6 +212,7 @@ def process(
next_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor]:
cur_len = input_ids.shape[-1]
batch_size = len(self._beam_hyps)
@@ -256,9 +257,16 @@ def process(
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
if is_beam_token_worse_than_top_num_beams:
continue
if beam_indices is not None:
beam_index = beam_indices[batch_beam_idx]
beam_index = beam_index + (next_index,)
else:
beam_index = None

beam_hyp.add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
)
else:
# add next predicted token since it is not eos_token
@@ -299,6 +307,7 @@ def finalize(
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)

@@ -313,11 +322,13 @@ def finalize(
batch_beam_idx = batch_idx * self.num_beams + beam_id
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_hyp.add(final_tokens, final_score)
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

# select the best hypotheses
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
best = []
best_indices = []
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

# retrieve best hypotheses
@@ -327,30 +338,50 @@ def finalize(
best_hyp_tuple = sorted_hyps.pop()
best_score = best_hyp_tuple[0]
best_hyp = best_hyp_tuple[1]
best_index = best_hyp_tuple[2]
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

# append to lists
# append hyp to lists
best.append(best_hyp)

# append indices to list
best_indices.append(best_index)

best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

# prepare for adding eos
sent_lengths_max = sent_lengths.max().item() + 1
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)

if len(best_indices) > 0 and best_indices[0] is not None:
indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
else:
indices = None

# shorter batches are padded if needed
if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`pad_token_id` has to be defined"
decoded.fill_(pad_token_id)

if indices is not None:
indices.fill_(-1)

# fill with hypotheses and eos_token_id if the latter fits in
for i, hypo in enumerate(best):
for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
decoded[i, : sent_lengths[i]] = hypo

if indices is not None:
indices[i, : len(best_idx)] = torch.tensor(best_idx)

if sent_lengths[i] < sent_max_len:
decoded[i, sent_lengths[i]] = eos_token_id

return UserDict(
{
"sequences": decoded,
"sequence_scores": best_scores,
"beam_indices": indices,
}
)

@@ -789,6 +820,7 @@ def finalize(

# prepare for adding eos
sent_lengths_max = sent_lengths.max().item() + 1

sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
# shorter batches are padded if needed
@@ -801,6 +833,7 @@ def finalize(
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < sent_max_len:
decoded[i, sent_lengths[i]] = eos_token_id

return UserDict(
{
"sequences": decoded,
@@ -826,15 +859,15 @@ def __len__(self):
"""
return len(self.beams)

def add(self, hyp: torch.LongTensor, sum_logprobs: float):
def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp))
self.beams.append((score, hyp, beam_indices))
if len(self) > self.num_beams:
sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
del self.beams[sorted_next_scores[0][1]]
self.worst_score = sorted_next_scores[1][0]
else:
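For context, below is a minimal usage sketch (not part of the diff) of how the beam indices that finalize now returns can be consumed from generate(). The checkpoint name (t5-small), the generation arguments, and the assumption that the generate-level output exposes the new tensor as out.beam_indices, where out.beam_indices[:, t] indexes the rows of out.scores[t] and -1 pads positions past the end of a finished hypothesis, are illustrative assumptions rather than something this commit prescribes.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: How are you?", return_tensors="pt")
out = model.generate(
    **inputs,
    num_beams=4,
    num_return_sequences=2,
    max_length=20,
    output_scores=True,
    return_dict_in_generate=True,
)

# out.scores: one (batch_size * num_beams, vocab_size) tensor per decoding step.
# out.beam_indices (assumed to surface the "beam_indices" entry of the UserDict
# above): for every returned sequence and step, the row of out.scores[t] that
# the token was chosen from, padded with -1 past the hypothesis end.
num_steps = min(len(out.scores), out.sequences.shape[1] - 1)
transition_scores = []
for t in range(num_steps):
    beam_idx = out.beam_indices[:, t]   # beam row each returned sequence sat on at step t
    token_id = out.sequences[:, t + 1]  # token emitted at step t (index 0 is the decoder start token)
    valid = beam_idx >= 0               # -1 marks steps past a finished hypothesis
    step = torch.zeros(beam_idx.shape[0], dtype=out.scores[t].dtype)
    step[valid] = out.scores[t][beam_idx[valid], token_id[valid]]
    transition_scores.append(step)
transition_scores = torch.stack(transition_scores, dim=-1)
print(transition_scores.shape)  # (batch_size * num_return_sequences, num_steps)

Note that out.scores holds whatever processed beam scores generate() recorded at each step, so turning the gathered values into true per-token log-probabilities may require extra normalization; the sketch only demonstrates the indexing that the new beam_indices output makes possible.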