diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index ae59da97ec417d..3c702c580f5fd6 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -424,7 +424,7 @@ def test_greedy_generate(self):
     def test_greedy_generate_dict_outputs(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]

             model = model_class(config).to(torch_device).eval()
             output_generate = self._greedy_generate(
@@ -457,7 +457,7 @@ def test_greedy_generate_dict_outputs(self):
     def test_greedy_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]

             if not hasattr(config, "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@@ -503,7 +503,7 @@ def test_sample_generate(self):
     def test_sample_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]

             model = model_class(config).to(torch_device).eval()
             output_generate = self._sample_generate(
@@ -552,7 +552,7 @@ def test_beam_search_generate(self):
     def test_beam_search_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]

             model = model_class(config).to(torch_device).eval()
             beam_kwargs = self._get_beam_kwargs()
@@ -588,7 +588,7 @@ def test_beam_search_generate_dict_output(self):
     def test_beam_search_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]

             if not hasattr(config, "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@@ -696,7 +696,7 @@ def test_beam_sample_generate(self):
     def test_beam_sample_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]

             model = model_class(config).to(torch_device).eval()
             beam_kwargs = self._get_beam_kwargs()
@@ -786,7 +786,7 @@ def test_group_beam_search_generate(self):
     def test_group_beam_search_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]

             model = model_class(config).to(torch_device).eval()
             beam_kwargs = self._get_diverse_beam_kwargs()
@@ -880,7 +880,7 @@ def test_constrained_beam_search_generate(self):
     def test_constrained_beam_search_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]

             model = model_class(config).to(torch_device).eval()

@@ -963,7 +963,7 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
                 self.skipTest(reason="Won't fix: old model with different cache format")

             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]

             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
@@ -1121,7 +1121,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):

             # enable cache
             config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]

             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
@@ -1196,7 +1196,7 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):

             # enable cache
             config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]

             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
@@ -1252,7 +1252,7 @@ def test_dola_decoding_sample(self):

             # enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]

             # Encoder-decoder models are not supported
             if config.is_encoder_decoder:
@@ -1310,7 +1310,7 @@ def test_assisted_decoding_sample(self):

             # enable cache
             config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]

             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
@@ -1857,7 +1857,7 @@ def test_generate_with_static_cache(self):
                 self.skipTest(reason="This model does not support the static cache format")

             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]

             if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
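
Note (not part of the diff): every hunk above makes the same change, replacing the fixed test-mixin attribute self.input_name with model_class.main_input_name when fetching the main input tensor from inputs_dict. The sketch below is a minimal, standalone illustration of that lookup pattern; the class names and the pick_main_input helper are hypothetical and are not transformers code, though the main_input_name attribute itself is real ("input_ids" by default on PreTrainedModel, "input_features" for some speech models).

# Illustrative sketch only: mock classes standing in for model classes that
# declare different main inputs.
class TextLikeModel:
    main_input_name = "input_ids"


class SpeechLikeModel:
    main_input_name = "input_features"


def pick_main_input(model_class, inputs_dict):
    # Same lookup pattern as the updated test lines above: index the prepared
    # inputs by whichever tensor the model class declares as its main input.
    return inputs_dict[model_class.main_input_name]


inputs_dict = {"input_ids": [[1, 2, 3]], "input_features": [[0.1, 0.2, 0.3]]}
assert pick_main_input(TextLikeModel, inputs_dict) == [[1, 2, 3]]
assert pick_main_input(SpeechLikeModel, inputs_dict) == [[0.1, 0.2, 0.3]]

This is why the per-class lookup lets one shared generation test body cover models whose main input is not input_ids, instead of hard-coding a single input name on the test mixin.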