Skip to content

Commit

Permalink
some fixes
Browse files Browse the repository at this point in the history
Signed-off-by: andrusenkoau <andrusenkoau@gmail.com>
  • Loading branch information
andrusenkoau committed Oct 10, 2024
1 parent b0170e5 commit f6046b0
Showing 1 changed file with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,10 @@ def init_batch_per_step(

# compute xatt of speech and current context tokens
apply_xatt = True
max_steps_with_same_audio = 3
max_steps_with_same_token = 10
step_count = 0
decoder_mems_list = self.extra_outputs.get('decoder_mems_list', None)
while apply_xatt and step_count < max_steps_with_same_audio:
while apply_xatt and step_count < max_steps_with_same_token:
encoder_input, self.extra_outputs = self.model.perception_cross_attn(
speech_encoded,
speech_encoded_len,
Expand All @@ -273,9 +273,11 @@ def init_batch_per_step(
xatt_scores = torch.mean(xatt_scores, 1)
# exclude first audio tokens (sink tokens)
most_attended_idx = torch.argmax(xatt_scores[:,:,2:], dim=-1)+2
# logging.warning(f"cur_enc_len ({cur_enc_len}) -1 - most_attended_idx ({most_attended_idx[0].item()}) = {cur_enc_len-1 - most_attended_idx[0].item()}")
average_pooling_xatt_scores = self.pool2d(xatt_scores[:,:,2:])
most_attended_idx_pool = torch.argmax(average_pooling_xatt_scores, dim=-1)+2
if strategy_args["debug_mode"]:
logging.warning(f"=== most_attended_idx: {most_attended_idx[0].item()}")
logging.warning(f"=== most_attended_idx: {most_attended_idx[0].item()}")
logging.warning(f"=== most_attended_idx_pool: {most_attended_idx_pool[0].item()}")
logging.warning(f"with value: {xatt_scores[:,:, most_attended_idx].item():.4f}")
logging.warning(f"xatt_scores: {xatt_scores[-1,:,-8:]}")
if cur_enc_len-1 - most_attended_idx >= strategy_args['alignatt_thr']:
Expand Down Expand Up @@ -508,9 +510,11 @@ def prepare_batch_at_step(
logging.warning(f"with value: {xatt_scores[:,:, most_attended_idx].item():.4f}")
logging.warning(f"xatt_scores: {xatt_scores[-1,:,-8:]}")
logging.warning(f"+++sum_of_last_xatt: {sum_of_last_xatt}")
# # alignatt policy
if self.audio_signal_is_finished or \
self.cur_speech_encoded_len-1 - most_attended_idx_pool >= strategy_args['alignatt_thr'] or \
step_count == max_steps_with_same_token-1:
# EDatt policy
# if self.audio_signal_is_finished or \
# sum_of_last_xatt < 0.1 or \
# step_count == max_steps_with_same_token-1:
Expand Down

0 comments on commit f6046b0

Please sign in to comment.