From f6046b07480427f61341af8c436e228f628bc4fd Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Thu, 10 Oct 2024 12:08:08 -0700 Subject: [PATCH] some fixes Signed-off-by: andrusenkoau --- .../modules/common/audio_text_generation_strategy.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py index 65164b378b7f..b113a5717f29 100644 --- a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py +++ b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py @@ -254,10 +254,10 @@ def init_batch_per_step( # compute xatt of speech and current context tokens apply_xatt = True - max_steps_with_same_audio = 3 + max_steps_with_same_token = 10 step_count = 0 decoder_mems_list = self.extra_outputs.get('decoder_mems_list', None) - while apply_xatt and step_count < max_steps_with_same_audio: + while apply_xatt and step_count < max_steps_with_same_token: encoder_input, self.extra_outputs = self.model.perception_cross_attn( speech_encoded, speech_encoded_len, @@ -273,9 +273,11 @@ def init_batch_per_step( xatt_scores = torch.mean(xatt_scores, 1) # exclude first audio tokens (sink tokens) most_attended_idx = torch.argmax(xatt_scores[:,:,2:], dim=-1)+2 - # logging.warning(f"cur_enc_len ({cur_enc_len}) -1 - most_attended_idx ({most_attended_idx[0].item()}) = {cur_enc_len-1 - most_attended_idx[0].item()}") + average_pooling_xatt_scores = self.pool2d(xatt_scores[:,:,2:]) + most_attended_idx_pool = torch.argmax(average_pooling_xatt_scores, dim=-1)+2 if strategy_args["debug_mode"]: - logging.warning(f"=== most_attended_idx: {most_attended_idx[0].item()}") + logging.warning(f"=== most_attended_idx: {most_attended_idx[0].item()}") + logging.warning(f"=== most_attended_idx_pool: {most_attended_idx_pool[0].item()}") logging.warning(f"with value: {xatt_scores[:,:, most_attended_idx].item():.4f}") logging.warning(f"xatt_scores: {xatt_scores[-1,:,-8:]}") if cur_enc_len-1 - most_attended_idx >= strategy_args['alignatt_thr']: @@ -508,9 +510,11 @@ def prepare_batch_at_step( logging.warning(f"with value: {xatt_scores[:,:, most_attended_idx].item():.4f}") logging.warning(f"xatt_scores: {xatt_scores[-1,:,-8:]}") logging.warning(f"+++sum_of_last_xatt: {sum_of_last_xatt}") + # # alignatt policy if self.audio_signal_is_finished or \ self.cur_speech_encoded_len-1 - most_attended_idx_pool >= strategy_args['alignatt_thr'] or \ step_count == max_steps_with_same_token-1: + # EDatt policy # if self.audio_signal_is_finished or \ # sum_of_last_xatt < 0.1 or \ # step_count == max_steps_with_same_token-1: