add exclude_sink_frames parameter
Signed-off-by: andrusenkoau <andrusenkoau@gmail.com>
andrusenkoau committed Oct 12, 2024
1 parent c92ddcc commit 549e051
Showing 1 changed file with 13 additions and 7 deletions.

@@ -257,6 +257,8 @@ def init_batch_per_step(
         max_steps_with_same_token = 10
         step_count = 0
         decoder_mems_list = self.extra_outputs.get('decoder_mems_list', None)
+        exclude_sink_frames = strategy_args.get('exclude_sink_frames', 2)
+        exclude_sink_frames = exclude_sink_frames // 2
         while apply_xatt and step_count < max_steps_with_same_token:
             encoder_input, self.extra_outputs = self.model.perception_cross_attn(
                 speech_encoded,
@@ -272,9 +274,11 @@ def init_batch_per_step(
             xatt_scores = self.extra_outputs['xatt_scores_list'][strategy_args['xatt_layer']]
             xatt_scores = torch.mean(xatt_scores, 1)
             # exclude first audio tokens (sink tokens)
-            most_attended_idx = torch.argmax(xatt_scores[:,:,2:], dim=-1)+2
-            average_pooling_xatt_scores = self.pool2d(xatt_scores[:,:,2:])
-            most_attended_idx_pool = torch.argmax(average_pooling_xatt_scores, dim=-1)+2
+            # most_attended_idx = torch.argmax(xatt_scores[:,:,2:], dim=-1)+2
+            # alternative search for the most attended frame by smoothing the attention scores
+            most_attended_idx = torch.argmax(xatt_scores[:,:,exclude_sink_frames:], dim=-1)+exclude_sink_frames
+            average_pooling_xatt_scores = self.pool2d(xatt_scores[:,:,exclude_sink_frames:])
+            most_attended_idx_pool = torch.argmax(average_pooling_xatt_scores, dim=-1)+exclude_sink_frames
             if strategy_args["debug_mode"]:
                 logging.warning(f"=== most_attended_idx: {most_attended_idx[0].item()}")
                 logging.warning(f"=== most_attended_idx_pool: {most_attended_idx_pool[0].item()}")
@@ -372,7 +376,7 @@ def init_batch_per_step(
             self.attention_mask = self.model._create_attention_mask(encoder_input)
             context_tokens = self.context_tokens
         else:
-            raise NotImplementedError("This function is not implemented yet")
+            # raise NotImplementedError("This function is not implemented yet")
         batch = {
             'audio_signal': audio_signal,
             'audio_signal_length': audio_length,
@@ -480,6 +484,7 @@ def prepare_batch_at_step(
         max_steps_with_same_token = 10 # max number of speech increasing steps with the same token
         step_count = 0
         decoder_mems_list = self.extra_outputs.get('decoder_mems_list', None)
+        exclude_sink_frames = strategy_args.get('exclude_sink_frames', 2)
 
         while apply_xatt and step_count < max_steps_with_same_token:
             embeddings2use_tmp, self.extra_outputs_tmp = self.model.perception_cross_attn(
@@ -497,11 +502,12 @@ def prepare_batch_at_step(
             xatt_scores = self.extra_outputs_tmp['xatt_scores_list'][strategy_args['xatt_layer']]
             xatt_scores = torch.mean(xatt_scores, 1) # mean by attention heads
             # most_attended_idx = torch.argmax(xatt_scores, dim=-1)
-            most_attended_idx = torch.argmax(xatt_scores[:,:,8:], dim=-1)+8
+            most_attended_idx = torch.argmax(xatt_scores[:,:,exclude_sink_frames:], dim=-1)+exclude_sink_frames
             sum_of_last_xatt = torch.sum(xatt_scores[:,:, -strategy_args['alignatt_thr']:])
 
-            average_pooling_xatt_scores = self.pool2d(xatt_scores[:,:,8:])
-            most_attended_idx_pool = torch.argmax(average_pooling_xatt_scores, dim=-1)+8
+            # alternative search for the most attended frame by smoothing the attention scores
+            average_pooling_xatt_scores = self.pool2d(xatt_scores[:,:,exclude_sink_frames:])
+            most_attended_idx_pool = torch.argmax(average_pooling_xatt_scores, dim=-1)+exclude_sink_frames
 
             if strategy_args["debug_mode"]:
                 logging.warning(f"self.cur_speech_encoded_len: {self.cur_speech_encoded_len.item()}")