From c92ddcc440de031b605d9e3867dd0fd152ff5028 Mon Sep 17 00:00:00 2001
From: andrusenkoau
Date: Fri, 11 Oct 2024 07:18:37 -0700
Subject: [PATCH] add compute_alignatt_lagging function for LAAL and AL computation

Signed-off-by: andrusenkoau
---
 .../speech_llm/models/modular_models.py       | 13 ++++++++++--
 .../common/audio_text_generation_utils.py     | 13 ++++++------
 .../speech_llm/parts/utils/data_utils.py      | 21 +++++++++++++++++++
 3 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py
index 5f03528e55ae..9ac0d5eae327 100644
--- a/nemo/collections/multimodal/speech_llm/models/modular_models.py
+++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py
@@ -42,7 +42,7 @@
     MultiAudioPerceptionModule,
 )
 from nemo.collections.multimodal.speech_llm.parts.mixins.adapter_mixin import SpeechLLMAdapterMixin
-from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import get_nested_dict_value, compute_waitk_lagging
+from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import get_nested_dict_value, compute_waitk_lagging, compute_alignatt_lagging
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
 from nemo.collections.nlp.modules.common.megatron.utils import (
@@ -1260,6 +1260,7 @@ def inference_step(self, dataloader_iter, mode):
             self.tokenizer.ids_to_text(t[l.item() :][: data_cfg.get('tokens_to_generate')])
             for t, l in zip(output['token_ids'], batch['context_lengths'])
         ]
+        pred_tokens_alignment = output['pred_tokens_alignment']
 
         if data_cfg.get("end_string", None):
             # sometimes data_cfg.end_string != self.tokenizer.ids_to_text(self.tokenizer.text_to_ids(data_cfg.end_string))
@@ -1284,11 +1285,19 @@ def inference_step(self, dataloader_iter, mode):
             labels_text = [remove_punctuations(l.lower(), data_cfg.get("punctuations", None)) for l in labels_text]
 
         strategy_args = self.get_inference_config()
-        if 'waitk_lagging' in strategy_args:
+        # if 'waitk_lagging' in strategy_args:
+        if strategy_args['decode_policy'] == 'waitk':
             context_lengths = batch['context_lengths']
             assert torch.equal(context_lengths, torch.ones_like(context_lengths) * context_lengths[0])
             predicted_token_ids = [i[context_lengths[0].item() :] for i in output['token_ids']]
+            # logging.warning(f"predicted_token_ids: {predicted_token_ids}")
             metadata = compute_waitk_lagging(batch, predicted_token_ids, metadata, labels_text, strategy_args, self.tokenizer)
+        elif strategy_args['decode_policy'] == 'alignatt':
+            context_lengths = batch['context_lengths']
+            assert torch.equal(context_lengths, torch.ones_like(context_lengths) * context_lengths[0])
+            predicted_token_ids = [i[context_lengths[0].item() :] for i in output['token_ids']]
+            # logging.warning(f"predicted_token_ids: {predicted_token_ids}")
+            metadata = compute_alignatt_lagging(batch, predicted_token_ids, metadata, labels_text, strategy_args, self.tokenizer, pred_tokens_alignment)
 
         if data_cfg.get("log_every_n_steps", None) is not None:
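Note on the dispatch above: the branch is selected by strategy_args['decode_policy'], and compute_alignatt_lagging additionally reads 'sample_rate' and 'audio_encoder_fs' from the same dict (see its defaults in the data_utils.py hunk below). A hypothetical strategy_args enabling the AlignAtt path could therefore look like the following sketch; only these three keys are read by the code in this patch, and the values shown are the patch defaults, not a recommendation:

    # Hypothetical strategy_args for the AlignAtt branch (illustrative only).
    strategy_args = {
        "decode_policy": "alignatt",  # selects the compute_alignatt_lagging branch
        "sample_rate": 16000,         # audio sample rate in Hz; patch default
        "audio_encoder_fs": 80,       # encoder frame shift in ms; patch default
    }
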
diff --git a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py
index 3d221cd245ec..50302f8c2d86 100644
--- a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py
+++ b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py
@@ -256,7 +256,7 @@ def synced_generate(
         **strategy_args,
     )
 
-    for tokens, lengths, output_logits, full_logits, audio_feat_lens in batch_token_iterator:
+    for tokens, lengths, output_logits, full_logits, audio_feat_lens, pred_tokens_alignment in batch_token_iterator:
         context_length += 1
     context_length += audio_feat_lens.min().item()
     if parallel_state.is_pipeline_last_stage():
@@ -299,7 +299,7 @@ def synced_generate(
             )
             torch.distributed.broadcast(full_logits, src, group)
 
     if tokens is not None:
-        return tokens[:, :context_length], output_logits, full_logits, audio_feat_lens
+        return tokens[:, :context_length], output_logits, full_logits, audio_feat_lens, pred_tokens_alignment
     return None
@@ -449,7 +449,7 @@ def generate(
     if hasattr(tokenizer, 'mask_token') and tokenizer.mask_token is not None:
         special_tokens.add(tokenizer.mask_token)
     if output is not None:
-        decode_tokens, output_logits, full_logits, audio_feat_lens = output
+        decode_tokens, output_logits, full_logits, audio_feat_lens, pred_tokens_alignment = output
 
         resp_sentences = []
         resp_sentences_seg = []
@@ -495,6 +495,7 @@
         output['token_ids'] = decode_tokens
         output['offsets'] = all_offsets
         output['audio_feat_lens'] = audio_feat_lens
+        output['pred_tokens_alignment'] = pred_tokens_alignment
         output = inference_strategy.post_generation_process(output)
         return output
     return None
@@ -696,11 +697,11 @@ def sample_sequence_batch(
                 torch.distributed.broadcast(done, src, group)
                 if compute_logprob:
                     if all_probs:
-                        yield tokens, lengths, output_logits, full_logits, audio_feat_lens
+                        yield tokens, lengths, output_logits, full_logits, audio_feat_lens, inference_strategy.token_alignatt
                     else:
-                        yield tokens, lengths, output_logits, None, audio_feat_lens
+                        yield tokens, lengths, output_logits, None, audio_feat_lens, inference_strategy.token_alignatt
                 else:
-                    yield tokens, lengths, None, None, audio_feat_lens
+                    yield tokens, lengths, None, None, audio_feat_lens, inference_strategy.token_alignatt
 
             else:
                 if parallel_state.is_pipeline_first_stage():
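The structure of inference_strategy.token_alignatt is defined by the inference strategy rather than by this patch; from the way compute_alignatt_lagging consumes it below, each entry appears to be a (token, frame_index) pair where token[0] is the SentencePiece piece emitted at that decoding step and frame_index is the encoder frame the AlignAtt policy had reached when the piece was emitted. A hypothetical value for a one-sample batch might look like this (illustrative only; "▁" is the assumed SentencePiece word-boundary prefix):

    # Hypothetical pred_tokens_alignment; pieces and frame indices are made up.
    pred_tokens_alignment = [
        (["▁he"], 12),     # first piece of "hello", most-attended encoder frame 12
        (["llo"], 14),
        (["▁world"], 25),  # BOW prefix: word boundary, delay 25 * audio_encoder_fs ms
        (["</s>"], 30),    # EOS closes the last word and ends the scan
    ]
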
diff --git a/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py b/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py
index 469fc43e4c9c..fda108114ea0 100644
--- a/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py
+++ b/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py
@@ -229,6 +229,27 @@ def compute_waitk_lagging(batch, predicted_token_ids, metadata, labels_text, str
     return metadata
 
 
+def compute_alignatt_lagging(batch, predicted_token_ids, metadata, labels_text, strategy_args, tokenizer, pred_tokens_alignment, BOW_PREFIX="\u2581"):
+    assert len(predicted_token_ids[0]) == len(pred_tokens_alignment)  # sanity check for alignment length
+    target_length_word = [len(a.split()) for a in labels_text]
+    for i, tokens in enumerate(predicted_token_ids):
+        audio_signal_length = batch['audio_signal_length'][i] * 1000  # convert to ms
+        audio_signal_length = audio_signal_length // strategy_args.get('sample_rate', 16000)
+        audio_encoder_fs = strategy_args.get('audio_encoder_fs', 80)
+        # obtain lagging for alignatt
+        lagging = []
+        for cur_t, pred_idx in pred_tokens_alignment:
+            cur_t = cur_t[0]
+            eos_token = tokenizer.vocab[tokenizer.eos_id]
+            if (cur_t.startswith(BOW_PREFIX) and cur_t != BOW_PREFIX) or cur_t == eos_token:  # word boundary
+                lagging.append(pred_idx * audio_encoder_fs)
+            if cur_t == eos_token:
+                break
+        # logging.warning(f"lagging: {lagging}")
+        metadata[i]['LAAL'] = compute_laal(lagging, audio_signal_length, target_length_word[i]).tolist()
+        metadata[i]['AL'] = compute_al(lagging, audio_signal_length, target_length_word[i]).tolist()
+    return metadata
+
 
 def build_loss_mask(processed_example: dict, answer_only_loss: bool = True):
     """Pad input_ids in batch to max batch length while building loss mask"""
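compute_laal and compute_al are called above but are not part of this patch. A minimal sketch of what they might look like, assuming the standard definitions (AL from Ma et al., 2019; LAAL from Papi et al., 2022, which replaces the hypothesis length with max(hypothesis, reference) length) with delays measured in ms, is below; the real NeMo implementations may differ:

    # Hedged sketch only: the names match the calls above, but everything else
    # is an assumption. Returning NumPy scalars keeps the .tolist() calls in
    # compute_alignatt_lagging working.
    import numpy as np

    def _average_lagging(delays_ms, source_length_ms, target_length):
        delays = np.asarray(delays_ms, dtype=np.float64)
        if delays.size == 0:
            return np.float64(0.0)
        gamma = target_length / source_length_ms  # target words per ms of audio
        # tau: first word whose delay already covers the whole source
        over = np.flatnonzero(delays >= source_length_ms)
        tau = over[0] + 1 if over.size > 0 else delays.size
        steps = np.arange(tau, dtype=np.float64)  # (i - 1) for i = 1..tau
        return np.mean(delays[:tau] - steps / gamma)

    def compute_al(lagging, audio_signal_length, target_length_word):
        # AL: rate gamma uses the hypothesis length (number of emitted words)
        return _average_lagging(lagging, float(audio_signal_length), len(lagging))

    def compute_laal(lagging, audio_signal_length, target_length_word):
        # LAAL: rate gamma uses max(hypothesis, reference) length, so neither
        # over- nor under-generation artificially lowers the lagging score
        return _average_lagging(
            lagging, float(audio_signal_length), max(len(lagging), target_length_word)
        )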