diff --git a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py index b113a5717f29..f6b190952d2c 100644 --- a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py +++ b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py @@ -257,6 +257,8 @@ def init_batch_per_step( max_steps_with_same_token = 10 step_count = 0 decoder_mems_list = self.extra_outputs.get('decoder_mems_list', None) + exclude_sink_frames = strategy_args.get('exclude_sink_frames', 2) + exclude_sink_frames = exclude_sink_frames // 2 while apply_xatt and step_count < max_steps_with_same_token: encoder_input, self.extra_outputs = self.model.perception_cross_attn( speech_encoded, @@ -272,9 +274,11 @@ def init_batch_per_step( xatt_scores = self.extra_outputs['xatt_scores_list'][strategy_args['xatt_layer']] xatt_scores = torch.mean(xatt_scores, 1) # exclude first audio tokens (sink tokens) - most_attended_idx = torch.argmax(xatt_scores[:,:,2:], dim=-1)+2 - average_pooling_xatt_scores = self.pool2d(xatt_scores[:,:,2:]) - most_attended_idx_pool = torch.argmax(average_pooling_xatt_scores, dim=-1)+2 + # most_attended_idx = torch.argmax(xatt_scores[:,:,2:], dim=-1)+2 + # alternative search for the most attended frame by smoothing the attention scores + most_attended_idx = torch.argmax(xatt_scores[:,:,exclude_sink_frames:], dim=-1)+exclude_sink_frames + average_pooling_xatt_scores = self.pool2d(xatt_scores[:,:,exclude_sink_frames:]) + most_attended_idx_pool = torch.argmax(average_pooling_xatt_scores, dim=-1)+exclude_sink_frames if strategy_args["debug_mode"]: logging.warning(f"=== most_attended_idx: {most_attended_idx[0].item()}") logging.warning(f"=== most_attended_idx_pool: {most_attended_idx_pool[0].item()}") @@ -372,7 +376,7 @@ def init_batch_per_step( self.attention_mask = 
self.model._create_attention_mask(encoder_input) context_tokens = self.context_tokens else: - raise NotImplementedError("This function is not implemented yet") + # raise NotImplementedError("This function is not implemented yet") batch = { 'audio_signal': audio_signal, 'audio_signal_length': audio_length, @@ -480,6 +484,7 @@ def prepare_batch_at_step( max_steps_with_same_token = 10 # max number of speech increasing steps with the same token step_count = 0 decoder_mems_list = self.extra_outputs.get('decoder_mems_list', None) + exclude_sink_frames = strategy_args.get('exclude_sink_frames', 2) while apply_xatt and step_count < max_steps_with_same_token: embeddings2use_tmp, self.extra_outputs_tmp = self.model.perception_cross_attn( @@ -497,11 +502,12 @@ def prepare_batch_at_step( xatt_scores = self.extra_outputs_tmp['xatt_scores_list'][strategy_args['xatt_layer']] xatt_scores = torch.mean(xatt_scores, 1) # mean by attention heads # most_attended_idx = torch.argmax(xatt_scores, dim=-1) - most_attended_idx = torch.argmax(xatt_scores[:,:,8:], dim=-1)+8 + most_attended_idx = torch.argmax(xatt_scores[:,:,exclude_sink_frames:], dim=-1)+exclude_sink_frames sum_of_last_xatt = torch.sum(xatt_scores[:,:, -strategy_args['alignatt_thr']:]) - average_pooling_xatt_scores = self.pool2d(xatt_scores[:,:,8:]) - most_attended_idx_pool = torch.argmax(average_pooling_xatt_scores, dim=-1)+8 + # alternative search for the most attended frame by smoothing the attention scores + average_pooling_xatt_scores = self.pool2d(xatt_scores[:,:,exclude_sink_frames:]) + most_attended_idx_pool = torch.argmax(average_pooling_xatt_scores, dim=-1)+exclude_sink_frames if strategy_args["debug_mode"]: logging.warning(f"self.cur_speech_encoded_len: {self.cur_speech_encoded_len.item()}")