From 518bd02c9b71291333ef374f055a4d1ac3042654 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 19 May 2022 22:17:02 +0200 Subject: [PATCH] [Generation] Fix Transition probs (#17311) * [Draft] fix transition probs * up * up * up * make it work * fix * finish * update --- src/transformers/generation_beam_search.py | 45 +++++++- src/transformers/generation_utils.py | 106 ++++++++---------- .../generation/test_generation_beam_search.py | 21 +++- tests/generation/test_generation_utils.py | 100 ++++++++++++++++- 4 files changed, 197 insertions(+), 75 deletions(-) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 7a9ffe7908504f..2dfb275c2c3425 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -212,6 +212,7 @@ def process( next_indices: torch.LongTensor, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, + beam_indices: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor]: cur_len = input_ids.shape[-1] batch_size = len(self._beam_hyps) @@ -256,9 +257,16 @@ def process( is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size if is_beam_token_worse_than_top_num_beams: continue + if beam_indices is not None: + beam_index = beam_indices[batch_beam_idx] + beam_index = beam_index + (next_index,) + else: + beam_index = None + beam_hyp.add( input_ids[batch_beam_idx].clone(), next_score.item(), + beam_indices=beam_index, ) else: # add next predicted token since it is not eos_token @@ -299,6 +307,7 @@ def finalize( max_length: int, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, + beam_indices: Optional[torch.LongTensor] = None, ) -> Tuple[torch.LongTensor]: batch_size = len(self._beam_hyps) @@ -313,11 +322,13 @@ def finalize( batch_beam_idx = batch_idx * self.num_beams + beam_id final_score = final_beam_scores[batch_beam_idx].item() final_tokens = input_ids[batch_beam_idx] - beam_hyp.add(final_tokens, final_score) + beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None + beam_hyp.add(final_tokens, final_score, beam_indices=beam_index) # select the best hypotheses sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) best = [] + best_indices = [] best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32) # retrieve best hypotheses @@ -327,23 +338,42 @@ def finalize( best_hyp_tuple = sorted_hyps.pop() best_score = best_hyp_tuple[0] best_hyp = best_hyp_tuple[1] + best_index = best_hyp_tuple[2] sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) - # append to lists + # append hyp to lists best.append(best_hyp) + + # append indices to list + best_indices.append(best_index) + best_scores[i * self.num_beam_hyps_to_keep + j] = best_score # prepare for adding eos sent_lengths_max = sent_lengths.max().item() + 1 sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) + + if len(best_indices) > 0 and best_indices[0] is not None: + indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) + else: + indices = None + # shorter batches are padded if needed if sent_lengths.min().item() != sent_lengths.max().item(): assert pad_token_id is not None, "`pad_token_id` has to be defined" decoded.fill_(pad_token_id) + + if indices is not None: + indices.fill_(-1) + # fill with hypotheses and eos_token_id if the latter fits in - for i, hypo in enumerate(best): + for i, (hypo, best_idx) in enumerate(zip(best, best_indices)): decoded[i, : sent_lengths[i]] = hypo + + if indices is not None: + indices[i, : len(best_idx)] = torch.tensor(best_idx) + if sent_lengths[i] < sent_max_len: decoded[i, sent_lengths[i]] = eos_token_id @@ -351,6 +381,7 @@ def finalize( { "sequences": decoded, "sequence_scores": best_scores, + "beam_indices": indices, } ) @@ -789,6 +820,7 @@ def finalize( # prepare for adding eos sent_lengths_max = sent_lengths.max().item() + 1 + sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) # shorter batches are padded if needed @@ -801,6 +833,7 @@ def finalize( decoded[i, : sent_lengths[i]] = hypo if sent_lengths[i] < sent_max_len: decoded[i, sent_lengths[i]] = eos_token_id + return UserDict( { "sequences": decoded, @@ -826,15 +859,15 @@ def __len__(self): """ return len(self.beams) - def add(self, hyp: torch.LongTensor, sum_logprobs: float): + def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None): """ Add a new hypothesis to the list. """ score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty) if len(self) < self.num_beams or score > self.worst_score: - self.beams.append((score, hyp)) + self.beams.append((score, hyp, beam_indices)) if len(self) > self.num_beams: - sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) + sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)]) del self.beams[sorted_next_scores[0][1]] self.worst_score = sorted_next_scores[1][0] else: diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 8e4bcbd7ad0883..0c8187acc7e196 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -217,8 +217,8 @@ class BeamSearchDecoderOnlyOutput(ModelOutput): `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`). beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped - tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors. + Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, input_ids.shape[-1])`. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. @@ -230,7 +230,7 @@ class BeamSearchDecoderOnlyOutput(ModelOutput): sequences: torch.LongTensor = None sequences_scores: Optional[torch.FloatTensor] = None scores: Optional[Tuple[torch.FloatTensor]] = None - beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None + beam_indices: Optional[torch.LongTensor] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None @@ -254,8 +254,8 @@ class BeamSearchEncoderDecoderOutput(ModelOutput): `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams, config.vocab_size)`). beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped - tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors. + Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, max_length-1)`. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, @@ -278,7 +278,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput): sequences: torch.LongTensor = None sequences_scores: Optional[torch.FloatTensor] = None scores: Optional[Tuple[torch.FloatTensor]] = None - beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None + beam_indices: Optional[torch.LongTensor] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None @@ -303,8 +303,8 @@ class BeamSampleDecoderOnlyOutput(ModelOutput): `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`). beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped - tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors. + Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, input_ids.shape[-1])`. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. @@ -316,7 +316,7 @@ class BeamSampleDecoderOnlyOutput(ModelOutput): sequences: torch.LongTensor = None sequences_scores: Optional[torch.FloatTensor] = None scores: Optional[Tuple[torch.FloatTensor]] = None - beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None + beam_indices: Optional[torch.LongTensor] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None @@ -339,9 +339,9 @@ class BeamSampleEncoderDecoderOutput(ModelOutput): of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams, config.vocab_size)`). - beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped - tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors. + beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, max_length-1)`. encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. @@ -362,7 +362,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput): sequences: torch.LongTensor = None sequences_scores: Optional[torch.FloatTensor] = None scores: Optional[Tuple[torch.FloatTensor]] = None - beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None + beam_indices: Optional[torch.LongTensor] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None @@ -811,32 +811,33 @@ def compute_transition_beam_scores( """compute the transition probabilities of sequences given generation scores and beam indices""" - # reshape scores as [vocab_size * batch_size, # generation steps] + # 1. reshape scores as [vocab_size * batch_size, # generation steps] # with batch_size being 2 * vocab_size and # generation steps being # seq_len - input_length scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1) - # start of generated tokens - cut_idx = sequences.shape[-1] - scores.shape[-1] - # adjust for beam indices - beam_sequence_indices = torch.tensor(beam_indices, device=sequences.device) * self.config.vocab_size - # compute real indices + # 2. cut beam_indices to longest beam length + beam_indices_mask = beam_indices < 0 + max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max() + beam_indices = beam_indices[:, :max_beam_length] + beam_indices_mask = beam_indices_mask[:, :max_beam_length] + + # 3. Set indices of beams that finished early to 0 + # such indices will be masked correctly afterwards + beam_indices[beam_indices_mask] = 0 + + # 4. multiply beam_indices with vocab size to gather correctly from scores + beam_sequence_indices = beam_indices * self.config.vocab_size + + # 5. Define which indices contributed to scores + cut_idx = sequences.shape[-1] - max_beam_length indices = sequences[:, cut_idx:] + beam_sequence_indices - # gather scores and run + + # 6. Compute scores transition_scores = scores.gather(0, indices) - # make sure that if EOS token was used before length of sequence `sequence.shape[-1]` - # get first occurence of EOS token - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - if eos_token_id is not None: - is_eos_token_id = sequences[:, cut_idx:] == eos_token_id - # make sure first eos token still contributes to transition probs - is_eos_token_id[:, -1] = False - is_eos_token_id = is_eos_token_id.roll(1, -1) - # all indices after eos shoud be masked - zero_transition_prob_mask = is_eos_token_id.cumsum(-1).bool() - # zero out padded probs - transition_scores.masked_fill_(zero_transition_prob_mask, 0.0) + # 7. Mask out transition_scores of beams that stopped early + transition_scores[beam_indices_mask] = 0 return transition_scores @@ -2256,6 +2257,7 @@ def beam_search( next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, + beam_indices=beam_indices, ) beam_scores = beam_outputs["next_beam_scores"] @@ -2290,25 +2292,19 @@ def beam_search( pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, + beam_indices=beam_indices, ) if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None - else: - num_return_sequences = beam_scorer.num_beam_hyps_to_keep - # return only as many indices as sequences - beam_indices = tuple( - (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size)) - ) - beam_indices = sum(beam_indices, ()) if self.config.is_encoder_decoder: return BeamSearchEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, - beam_indices=beam_indices, + beam_indices=sequence_outputs["beam_indices"], encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, @@ -2320,7 +2316,7 @@ def beam_search( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, - beam_indices=beam_indices, + beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) @@ -2580,6 +2576,7 @@ def beam_sample( next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, + beam_indices=beam_indices, ) beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] @@ -2613,25 +2610,19 @@ def beam_sample( pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, + beam_indices=beam_indices, ) if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None - else: - num_return_sequences = beam_scorer.num_beam_hyps_to_keep - # return only as many indices as sequences - beam_indices = tuple( - (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size)) - ) - beam_indices = sum(beam_indices, ()) if self.config.is_encoder_decoder: return BeamSampleEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, - beam_indices=beam_indices, + beam_indices=sequence_outputs["beam_indices"], encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, @@ -2643,7 +2634,7 @@ def beam_sample( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, - beam_indices=beam_indices, + beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) @@ -2909,6 +2900,7 @@ def group_beam_search( next_tokens = next_tokens % vocab_size # stateless + process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None beam_outputs = beam_scorer.process( group_input_ids, next_token_scores, @@ -2916,6 +2908,7 @@ def group_beam_search( next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, + beam_indices=process_beam_indices, ) beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] @@ -2971,6 +2964,7 @@ def group_beam_search( else: this_peer_finished = True + final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, @@ -2979,26 +2973,19 @@ def group_beam_search( pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, + beam_indices=final_beam_indices, ) if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None - else: - beam_indices = sum(beam_indices, ()) - num_return_sequences = beam_scorer.num_beam_hyps_to_keep - # return only as many indices as sequences - beam_indices = tuple( - (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size)) - ) - beam_indices = sum(beam_indices, ()) if self.config.is_encoder_decoder: return BeamSearchEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, - beam_indices=beam_indices, + beam_indices=sequence_outputs["beam_indices"], encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, @@ -3010,6 +2997,7 @@ def group_beam_search( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, + beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) diff --git a/tests/generation/test_generation_beam_search.py b/tests/generation/test_generation_beam_search.py index 7ca4ac9b08baa6..885cefa62cbd51 100644 --- a/tests/generation/test_generation_beam_search.py +++ b/tests/generation/test_generation_beam_search.py @@ -126,7 +126,11 @@ def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_sc tokens = next_tokens.clone() tokens[:, : self.num_beams] = self.eos_token_id - beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id) + beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device) + beam_indices = tuple(tuple(b) for b in beam_indices) + beam_scorer.process( + input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices + ) # beam scorer should be done self.parent.assertTrue(beam_scorer.is_done) @@ -136,7 +140,7 @@ def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_sc tokens = next_tokens.clone() tokens[:, 1] = self.eos_token_id beam_outputs = beam_scorer.process( - input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id + input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices ) output_scores = beam_outputs["next_beam_scores"] output_tokens = beam_outputs["next_beam_tokens"] @@ -161,10 +165,15 @@ def cut_expected_tensor(tensor): self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3)) # make sure ids of eos token are correctly saved in beam_hyps of beam scorer + expected_beam_indices = list(range(10)) for batch_idx in range(self.batch_size): correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1] self.parent.assertListEqual( - input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist() + input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist() + ) + self.parent.assertListEqual( + expected_beam_indices + [next_indices[batch_idx, 1].item()], + torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(), ) def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores): @@ -188,6 +197,8 @@ def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_ input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1) # finalize + beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device) + beam_indices = tuple(tuple(b) for b in beam_indices) sequence_output = beam_scorer.finalize( input_ids, output_scores, @@ -196,6 +207,7 @@ def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_ pad_token_id=self.pad_token_id, eos_token_id=self.eos_token_id, max_length=max_length, + beam_indices=beam_indices, ) sequences = sequence_output["sequences"] @@ -225,6 +237,7 @@ def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_ pad_token_id=self.pad_token_id, eos_token_id=self.eos_token_id, max_length=max_length, + beam_indices=beam_indices, ) sequences = sequence_output["sequences"] sequence_scores = sequence_output["sequence_scores"] @@ -394,7 +407,7 @@ def cut_expected_tensor(tensor): for batch_idx in range(self.batch_size): correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1] self.parent.assertListEqual( - input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist() + input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist() ) def check_constrained_beam_scorer_finalize( diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index 707f1f84d738d9..952b9792d64588 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -2322,6 +2322,94 @@ def test_transition_scores_group_beam_search_encoder_decoder(self): self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3)) + @slow + def test_transition_scores_early_stopping(self): + # This is an aggressive test that makes sure that `beam_search's` + # transition scores are computed correctly for varying `num_return_sequences`, + # `num_beams` and `batch_size > 1` + # 2 x input_ids for "question: How are you? \n context: I had a long day, " + input_ids = torch.tensor(2 * [[822, 10, 571, 33, 25, 58, 2625, 10, 27, 141, 3, 9, 307, 239, 6, 1]]).to( + torch_device + ) + + model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(torch_device) + + result = model.generate( + input_ids, + max_length=10, + return_dict_in_generate=True, + output_scores=True, + forced_eos_token_id=model.config.eos_token_id, + num_beams=4, + do_sample=False, + num_return_sequences=3, + length_penalty=0.0, + ) + + transition_scores = model.compute_transition_beam_scores( + sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices + ) + + sum_transition_scores = torch.sum(transition_scores, dim=1) + + self.assertListEqual(sum_transition_scores.cpu().tolist(), result.sequences_scores.cpu().tolist()) + + def test_log_scores_sample_decoder_only(self): + articles = ["I need input_ids to generate", "Short and"] + tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + tokenizer.padding_side = "left" + tokenizer.pad_token = tokenizer.eos_token + + model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + + inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device) + + result = model.generate( + **inputs, + max_length=15, + return_dict_in_generate=True, + do_sample=False, + output_scores=True, + ) + + # decoder-only starts generating from `input_ids` + begin_generation = inputs.input_ids.shape[-1] + + gen_sequences = result.sequences[:, begin_generation:] + probs = torch.stack(result.scores, dim=1).softmax(-1) + + gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1) + expected_probs = torch.tensor([[0.0014, 0.0015], [0.0014, 0.0014]]) + + self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3)) + + def test_log_scores_sample_encoder_decoder(self): + articles = ["I need input_ids to generate", "Short and"] + tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device) + + inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device) + + result = model.generate( + **inputs, + max_length=3, + return_dict_in_generate=True, + do_sample=False, + num_beams=1, + output_scores=True, + ) + + # encoder-decoder has one decoder_start_token_id by default + begin_generation = 1 + + gen_sequences = result.sequences[:, begin_generation:] + probs = torch.stack(result.scores, dim=1).softmax(-1) + + gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1) + expected_probs = torch.tensor([[0.0013, 1.0000], [0.0013, 1.0000]]) + + self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3)) + @slow def test_beam_search_example_integration(self): # exactly the example provided in the docstrings of beam search, which previously @@ -2366,8 +2454,8 @@ def test_beam_search_example_integration(self): @slow def test_constrained_beam_search(self): - model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("../gpt2") + model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids @@ -2403,8 +2491,8 @@ def test_constrained_beam_search(self): @slow def test_constrained_beam_search_mixed(self): - model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("../gpt2") + model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids flexible_phrases = tokenizer( @@ -2442,8 +2530,8 @@ def test_constrained_beam_search_mixed(self): @slow def test_constrained_beam_search_mixed_mixin(self): - model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("../gpt2") + model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") force_word = "scared" force_flexible = ["scream", "screams", "screaming", "screamed"]