add compute_alignatt_lagging function for LAAL and AL computation
Signed-off-by: andrusenkoau <andrusenkoau@gmail.com>
andrusenkoau committed Oct 11, 2024
1 parent f6046b0 commit c92ddcc
Showing 3 changed files with 39 additions and 8 deletions.
13 changes: 11 additions & 2 deletions nemo/collections/multimodal/speech_llm/models/modular_models.py
@@ -42,7 +42,7 @@
MultiAudioPerceptionModule,
)
from nemo.collections.multimodal.speech_llm.parts.mixins.adapter_mixin import SpeechLLMAdapterMixin
-from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import get_nested_dict_value, compute_waitk_lagging
+from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import get_nested_dict_value, compute_waitk_lagging, compute_alignatt_lagging
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
from nemo.collections.nlp.modules.common.megatron.utils import (
@@ -1260,6 +1260,7 @@ def inference_step(self, dataloader_iter, mode):
self.tokenizer.ids_to_text(t[l.item() :][: data_cfg.get('tokens_to_generate')])
for t, l in zip(output['token_ids'], batch['context_lengths'])
]
+pred_tokens_alignment = output['pred_tokens_alignment']

if data_cfg.get("end_string", None):
# sometimes data_cfg.end_string != self.tokenizer.ids_to_text(self.tokenizer.text_to_ids(data_cfg.end_string))
@@ -1284,11 +1285,19 @@
labels_text = [remove_punctuations(l.lower(), data_cfg.get("punctuations", None)) for l in labels_text]

strategy_args = self.get_inference_config()
-if 'waitk_lagging' in strategy_args:
+# if 'waitk_lagging' in strategy_args:
+if strategy_args['decode_policy'] == 'waitk':
    context_lengths = batch['context_lengths']
    assert torch.equal(context_lengths, torch.ones_like(context_lengths) * context_lengths[0])
    predicted_token_ids = [i[context_lengths[0].item() :] for i in output['token_ids']]
    # logging.warning(f"predicted_token_ids: {predicted_token_ids}")
    metadata = compute_waitk_lagging(batch, predicted_token_ids, metadata, labels_text, strategy_args, self.tokenizer)
+elif strategy_args['decode_policy'] == 'alignatt':
+    context_lengths = batch['context_lengths']
+    assert torch.equal(context_lengths, torch.ones_like(context_lengths) * context_lengths[0])
+    predicted_token_ids = [i[context_lengths[0].item() :] for i in output['token_ids']]
+    # logging.warning(f"predicted_token_ids: {predicted_token_ids}")
+    metadata = compute_alignatt_lagging(batch, predicted_token_ids, metadata, labels_text, strategy_args, self.tokenizer, pred_tokens_alignment)


if data_cfg.get("log_every_n_steps", None) is not None:
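Note: the alignatt branch above hands pred_tokens_alignment to compute_alignatt_lagging (added in data_utils.py below), which unpacks it as (cur_t, pred_idx) pairs and reads cur_t[0] as the token text. The commit does not document the structure, so here is a toy value matching what that consumer implies (illustrative only, not taken from the commit):

# Implied format of output['pred_tokens_alignment']: one entry per
# generated token, pairing the token's text (wrapped in a list) with the
# audio-encoder frame index it was aligned to when emitted.
pred_tokens_alignment = [
    (["\u2581hel"], 12),    # sentencepiece BOW prefix "\u2581" -> word boundary
    (["lo"], 14),           # word-internal piece -> no lagging entry
    (["\u2581world"], 31),  # next word boundary
]
# Each word-initial (or EOS) token contributes pred_idx * audio_encoder_fs
# milliseconds to the lagging list behind the LAAL/AL metrics.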
13 changes: 7 additions & 6 deletions [second changed file; filename not shown]
@@ -256,7 +256,7 @@ def synced_generate(
**strategy_args,
)

-for tokens, lengths, output_logits, full_logits, audio_feat_lens in batch_token_iterator:
+for tokens, lengths, output_logits, full_logits, audio_feat_lens, pred_tokens_alignment in batch_token_iterator:
context_length += 1
context_length += audio_feat_lens.min().item()
if parallel_state.is_pipeline_last_stage():
@@ -299,7 +299,7 @@
)
torch.distributed.broadcast(full_logits, src, group)
if tokens is not None:
-    return tokens[:, :context_length], output_logits, full_logits, audio_feat_lens
+    return tokens[:, :context_length], output_logits, full_logits, audio_feat_lens, pred_tokens_alignment
return None


@@ -449,7 +449,7 @@ def generate(
if hasattr(tokenizer, 'mask_token') and tokenizer.mask_token is not None:
special_tokens.add(tokenizer.mask_token)
if output is not None:
-    decode_tokens, output_logits, full_logits, audio_feat_lens = output
+    decode_tokens, output_logits, full_logits, audio_feat_lens, pred_tokens_alignment = output
resp_sentences = []
resp_sentences_seg = []

@@ -495,6 +495,7 @@
output['token_ids'] = decode_tokens
output['offsets'] = all_offsets
output['audio_feat_lens'] = audio_feat_lens
+output['pred_tokens_alignment'] = pred_tokens_alignment
output = inference_strategy.post_generation_process(output)
return output
return None
@@ -696,11 +697,11 @@ def sample_sequence_batch(
torch.distributed.broadcast(done, src, group)
if compute_logprob:
    if all_probs:
-        yield tokens, lengths, output_logits, full_logits, audio_feat_lens
+        yield tokens, lengths, output_logits, full_logits, audio_feat_lens, inference_strategy.token_alignatt
    else:
-        yield tokens, lengths, output_logits, None, audio_feat_lens
+        yield tokens, lengths, output_logits, None, audio_feat_lens, inference_strategy.token_alignatt
else:
-    yield tokens, lengths, None, None, audio_feat_lens
+    yield tokens, lengths, None, None, audio_feat_lens, inference_strategy.token_alignatt

else:
if parallel_state.is_pipeline_first_stage():
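The generator changes above thread inference_strategy.token_alignatt out to the caller, but this commit does not show how the strategy fills it. A minimal sketch of the AlignAtt idea, assuming per-step cross-attention weights over the audio frames are available; record_alignment, cross_attn, and token_text are illustrative names, not NeMo API:

import torch

def record_alignment(token_alignatt, cross_attn, token_text):
    # cross_attn: 1-D tensor of the new token's attention weights over
    # the encoded audio frames; token_text: the emitted token's text.
    attended_frame = int(torch.argmax(cross_attn).item())
    # One entry per generated token, matching the (cur_t, pred_idx)
    # unpacking in compute_alignatt_lagging below.
    token_alignatt.append(([token_text], attended_frame))

In the AlignAtt policy itself, this most-attended frame also gates emission: when it falls within the last few received frames, the model waits for more audio instead of committing to the token.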
21 changes: 21 additions & 0 deletions nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py
@@ -229,6 +229,27 @@ def compute_waitk_lagging(batch, predicted_token_ids, metadata, labels_text, str
return metadata


+def compute_alignatt_lagging(batch, predicted_token_ids, metadata, labels_text, strategy_args, tokenizer, pred_tokens_alignment, BOW_PREFIX="\u2581"):
+    assert len(predicted_token_ids[0]) == len(pred_tokens_alignment)  # sanity check for alignment length
+    target_length_word = [len(a.split()) for a in labels_text]
+    for i, tokens in enumerate(predicted_token_ids):
+        audio_signal_length = batch['audio_signal_length'][i] * 1000  # convert to ms
+        audio_signal_length = audio_signal_length // strategy_args.get('sample_rate', 16000)
+        audio_encoder_fs = strategy_args.get('audio_encoder_fs', 80)
+        # obtain lagging for alignatt
+        lagging = []
+        for cur_t, pred_idx in pred_tokens_alignment:
+            cur_t = cur_t[0]
+            eos_token = tokenizer.vocab[tokenizer.eos_id]
+            if (cur_t.startswith(BOW_PREFIX) and cur_t != BOW_PREFIX) or cur_t == eos_token:  # word boundary
+                lagging.append(pred_idx * audio_encoder_fs)
+            if cur_t == eos_token:
+                break
+        # logging.warning(f"lagging: {lagging}")
+        metadata[i]['LAAL'] = compute_laal(lagging, audio_signal_length, target_length_word[i]).tolist()
+        metadata[i]['AL'] = compute_al(lagging, audio_signal_length, target_length_word[i]).tolist()
+    return metadata


def build_loss_mask(processed_example: dict, answer_only_loss: bool = True):
"""Pad input_ids in batch to max batch length while building loss mask"""
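compute_laal and compute_al are called above but not shown in this commit. A minimal sketch of what they might look like, following the standard Average Lagging (AL) and Length-Adaptive Average Lagging (LAAL) definitions: delays is the per-word lagging list built above (ms), src_len the audio duration (ms), tgt_len the reference word count; tensors are returned so the .tolist() calls above still work.

import torch

def compute_al(delays, src_len, tgt_len):
    # Mean delay relative to an ideal linear schedule, cut off at the
    # first word emitted after the full source has been consumed.
    if len(delays) == 0:
        return torch.tensor(0.0)
    gamma = tgt_len / src_len  # ideal words emitted per ms of audio
    al, tau = 0.0, 0
    for i, d in enumerate(delays):
        al += d - i / gamma
        tau = i + 1
        if d >= src_len:
            break
    return torch.tensor(al / tau)

def compute_laal(delays, src_len, tgt_len):
    # LAAL normalizes by the longer of hypothesis and reference length,
    # so over-generation cannot lower the reported lagging.
    return compute_al(delays, src_len, max(tgt_len, len(delays)))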
