Add using simple prompt for Qwen2 Audio to align (#360)
kcz358 committed Nov 23, 2024
1 parent 35e002e commit 971439f
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions lmms_eval/models/qwen2_audio.py
@@ -32,6 +32,7 @@ def __init__(
         use_cache=True,
         add_generation_prompt: bool = True,
         add_system_prompt: bool = True,
+        simple_prompt: bool = False,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -41,6 +42,9 @@ def __init__(
         accelerator = Accelerator()
         self.add_generation_prompt = add_generation_prompt
         self.add_system_prompt = add_system_prompt
+        # If simple_prompt is enabled, only prepend "<|audio_bos|><|AUDIO|><|audio_eos|>"
+        # to the context prompt, to align with the original Qwen2 Audio prompt format
+        self.simple_prompt = simple_prompt
         if accelerator.num_processes > 1:
             self._device = torch.device(f"cuda:{accelerator.local_process_index}")
             self.device_map = f"cuda:{accelerator.local_process_index}"
@@ -206,17 +210,20 @@ def _collate(x):
             if isinstance(contexts, tuple):
                 contexts = list(contexts)
 
-            conversations = []
-            for idx, context in enumerate(contexts):
-                conv = [{"role": "user", "content": []}]
-                for _ in batched_audios[idx]:
-                    # This placeholder is just used to make the chat template work;
-                    # we already have the sampled audio array
-                    conv[0]["content"].append({"type": "audio", "audio_url": "placeholder.wav"})
-                conv[0]["content"].append({"type": "text", "text": context})
-                conversations.append(conv)
-
-            text = [self.processor.apply_chat_template(conversation, add_generation_prompt=self.add_generation_prompt, tokenize=False) for conversation in conversations]
+            if not self.simple_prompt:
+                conversations = []
+                for idx, context in enumerate(contexts):
+                    conv = [{"role": "user", "content": []}]
+                    for _ in batched_audios[idx]:
+                        # This placeholder is just used to make the chat template work;
+                        # we already have the sampled audio array
+                        conv[0]["content"].append({"type": "audio", "audio_url": "placeholder.wav"})
+                    conv[0]["content"].append({"type": "text", "text": context})
+                    conversations.append(conv)
+
+                text = [self.processor.apply_chat_template(conversation, add_generation_prompt=self.add_generation_prompt, tokenize=False) for conversation in conversations]
+            else:
+                text = ["<|audio_bos|><|AUDIO|><|audio_eos|>" + context for context in contexts]
             audios = [downsample_audio(audio["array"], audio["sampling_rate"], self.processor.feature_extractor.sampling_rate) for audio in flattened_audios]
 
             inputs = self.processor(text=text, audios=audios, return_tensors="pt", padding=True, sampling_rate=self.processor.feature_extractor.sampling_rate)
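To see what the two branches above actually produce, here is a minimal sketch, not part of the commit, contrasting the chat-template path with the new simple-prompt path. It assumes the Qwen/Qwen2-Audio-7B-Instruct processor from Hugging Face transformers and an example question of our own; only the placeholder-audio trick and the "<|audio_bos|><|AUDIO|><|audio_eos|>" prefix are taken from the diff.

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
context = "What is being said in this audio?"  # example question, not from the commit

# simple_prompt=False: route the request through the chat template; the audio_url is a
# placeholder because the sampled audio arrays are passed to the processor separately.
conversation = [{
    "role": "user",
    "content": [
        {"type": "audio", "audio_url": "placeholder.wav"},
        {"type": "text", "text": context},
    ],
}]
chat_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)

# simple_prompt=True: prepend only the raw audio tokens, matching the completion-style
# prompt of the original (non-Instruct) Qwen2 Audio checkpoints.
simple_text = "<|audio_bos|><|AUDIO|><|audio_eos|>" + context

print(chat_text)
print(simple_text)

The difference is only the surrounding conversation markup: the chat template wraps the same audio tokens and question in the Instruct model's turn markers, while the simple prompt leaves them bare.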
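A hypothetical instantiation sketch follows; the Qwen2_Audio class name and the pretrained keyword are assumed from lmms-eval conventions rather than shown in this diff, while simple_prompt is the constructor flag the commit introduces.

# Hypothetical usage sketch: the class name and the `pretrained` kwarg are assumptions;
# only `simple_prompt` comes from this commit.
from lmms_eval.models.qwen2_audio import Qwen2_Audio

model = Qwen2_Audio(pretrained="Qwen/Qwen2-Audio-7B", simple_prompt=True)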