should be model main input
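
The generation tests previously fetched the tensor used for output checks with a fixed self.input_name lookup, which assumes every model's main generate input is input_ids. Models whose primary input is a different tensor (audio or vision models, for example) expose its name via the class attribute main_input_name, so the tests now look it up per model class. A minimal sketch of the difference; the inputs_dict contents below are hypothetical, and the attribute values match recent transformers releases but should be verified against your installed version:

    # Why the per-class lookup matters: main_input_name names each model's
    # primary input tensor. Values shown match recent transformers releases.
    from transformers import GPT2LMHeadModel, WhisperForConditionalGeneration

    print(GPT2LMHeadModel.main_input_name)                  # "input_ids" (the default)
    print(WhisperForConditionalGeneration.main_input_name)  # "input_features"

    # A speech model's test inputs carry no "input_ids" key (dict contents
    # here are hypothetical), so the old fixed lookup raised KeyError:
    inputs_dict = {"input_features": ..., "attention_mask": ...}
    # inputs_dict["input_ids"]                                      -> KeyError
    # inputs_dict[WhisperForConditionalGeneration.main_input_name]  -> works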
zucchini-nlp committed Oct 21, 2024
1 parent c39c5ed commit 37d25b1
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions tests/generation/test_utils.py
@@ -424,7 +424,7 @@ def test_greedy_generate(self):
     def test_greedy_generate_dict_outputs(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]
 
             model = model_class(config).to(torch_device).eval()
             output_generate = self._greedy_generate(
@@ -457,7 +457,7 @@ def test_greedy_generate_dict_outputs(self):
     def test_greedy_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]
 
             if not hasattr(config, "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@@ -503,7 +503,7 @@ def test_sample_generate(self):
     def test_sample_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]
 
             model = model_class(config).to(torch_device).eval()
             output_generate = self._sample_generate(
@@ -552,7 +552,7 @@ def test_beam_search_generate(self):
     def test_beam_search_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]
 
             model = model_class(config).to(torch_device).eval()
             beam_kwargs = self._get_beam_kwargs()
@@ -588,7 +588,7 @@ def test_beam_search_generate_dict_output(self):
     def test_beam_search_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]
 
             if not hasattr(config, "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@@ -696,7 +696,7 @@ def test_beam_sample_generate(self):
     def test_beam_sample_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]
 
             model = model_class(config).to(torch_device).eval()
             beam_kwargs = self._get_beam_kwargs()
@@ -786,7 +786,7 @@ def test_group_beam_search_generate(self):
     def test_group_beam_search_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]
 
             model = model_class(config).to(torch_device).eval()
             beam_kwargs = self._get_diverse_beam_kwargs()
@@ -880,7 +880,7 @@ def test_constrained_beam_search_generate(self):
     def test_constrained_beam_search_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]
 
             model = model_class(config).to(torch_device).eval()
 
@@ -963,7 +963,7 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
                 self.skipTest(reason="Won't fix: old model with different cache format")
 
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]
 
             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
@@ -1121,7 +1121,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
 
             # enable cache
             config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]
 
             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
@@ -1196,7 +1196,7 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
 
             # enable cache
             config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]
 
             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
@@ -1252,7 +1252,7 @@ def test_dola_decoding_sample(self):
 
             # enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]
 
             # Encoder-decoder models are not supported
             if config.is_encoder_decoder:
@@ -1310,7 +1310,7 @@ def test_assisted_decoding_sample(self):
 
             # enable cache
             config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]
 
             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
@@ -1857,7 +1857,7 @@ def test_generate_with_static_cache(self):
                 self.skipTest(reason="This model does not support the static cache format")
 
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[self.input_name]
+            main_input = inputs_dict[model_class.main_input_name]
 
             if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
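
For reference, main_input_name is a class attribute on transformers models: PreTrainedModel defaults it to "input_ids", and modality-specific models override it. A minimal sketch of that declaration pattern (MyAudioModel is a hypothetical illustration, not a real class):

    # Where main_input_name comes from: models override the inherited default.
    from transformers import PreTrainedModel

    class MyAudioModel(PreTrainedModel):
        main_input_name = "input_features"  # overrides the default "input_ids"

With the attribute resolved per class, inputs_dict[model_class.main_input_name] picks the right tensor for any modality, which is exactly what each hunk above switches to.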