Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchit-gandhi committed Jul 25, 2024
1 parent eaa5afe commit 7f60594
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1820,6 +1820,31 @@ def test_custom_4d_attention_mask(self):
normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)

def test_generate_output_type(self):
for model_class in self.all_generative_model_classes:
config, inputs = self.model_tester.prepare_config_and_inputs()
model = model_class(config).to(torch_device).eval()

# short-form generation without fallback
pred_ids = model.generate(**inputs)
assert isinstance(pred_ids, torch.Tensor)

# short-form generation with fallback
pred_ids = model.generate(**inputs, logprob_threshold=-1.0, temperature=[0.0, 0.1])
assert isinstance(pred_ids, torch.Tensor)

# create artificial long-form inputs
inputs["input_features"] = torch.cat([inputs["input_features"], inputs["input_features"]], dim=-1)
inputs["attention_mask"] = torch.ones(inputs["input_features"].shape[:2], dtype=torch.int, device=inputs["input_features"].device)
model.generation_config.no_timestamps_token_id = model.generation_config.decoder_start_token_id

# long-form generation without fallback
pred_ids = model.generate(**inputs)
assert isinstance(pred_ids, torch.Tensor)

# short-form generation with fallback
pred_ids = model.generate(**inputs, logprob_threshold=-1.0, temperature=[0.0, 0.1])
assert isinstance(pred_ids, torch.Tensor)

@require_torch
@require_torchaudio
Expand Down

0 comments on commit 7f60594

Please sign in to comment.