Commit
add alignatt policy
Signed-off-by: andrusenkoau <andrusenkoau@gmail.com>
andrusenkoau committed Oct 9, 2024
1 parent d5f47fd commit 0cef674
Showing 1 changed file with 221 additions and 8 deletions.
@@ -20,6 +20,9 @@
from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import shift_tokens_by_multi_audios
from nemo.collections.nlp.modules.common.megatron.utils import build_position_ids

from nemo.utils import logging
import numpy as np

# the text representation of eos_id, it applies for all tokenizers
END_OF_SEQ = '<|endoftext|>'

@@ -167,6 +170,56 @@ def end_of_generation_condition(


class CrossAttendAudioToTextGenerationStrategy(AudioToTextGenerationStrategy):

    def read_next_audio_chunk(self, **strategy_args):

        if strategy_args["debug_mode"]:
            logging.warning("\n" + "*******" * 10)
            logging.warning("read_next_audio_chunk")
            logging.warning("*******" * 10 + "\n")

        audio_length = self.audio_length
        audio_signal = self.audio_signal[:]

        # read params
        pre_decision_ratio = strategy_args['pre_decision_ratio']
        sample_rate = strategy_args.get('sample_rate', 16000)
        right_context = strategy_args.get('right_context', 13)
        audio_encoder_fs = strategy_args.get('audio_encoder_fs', 80)

        # read audio chunk
        cur_enc_len = pre_decision_ratio * (self.speech_chunk_step + strategy_args['alignatt_waitk'])
        self.speech_chunk_step += 1
        cur_src_len = (cur_enc_len + right_context) * audio_encoder_fs * sample_rate // 1000
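        # Illustrative arithmetic (values assumed, not prescribed): one encoder frame
        # spans audio_encoder_fs ms, i.e. audio_encoder_fs * sample_rate // 1000 samples.
        # Assuming self.speech_chunk_step starts at 0, with pre_decision_ratio=8,
        # alignatt_waitk=2, right_context=13, audio_encoder_fs=80 and sample_rate=16000,
        # the first read gives cur_enc_len = 8 * (0 + 2) = 16 encoder frames, backed by
        # (16 + 13) * 80 * 16000 // 1000 = 37120 raw audio samples.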
        audio_signal = audio_signal[:, :cur_src_len]
        audio_length = torch.minimum(
            audio_length, torch.from_numpy(np.array([cur_src_len])).to(audio_length.device)
        )
        if strategy_args["debug_mode"]:
            logging.warning(f"cur_enc_len: {cur_enc_len}")

        # run audio encoder model; output is [b, t, c]
        speech_encoded, speech_encoded_len = self.model.perception(
            input_signal=audio_signal,
            input_signal_length=audio_length,
            processed_signal=None,
            processed_signal_length=None,
        )
        # truncate to the current encoder length
        speech_encoded = speech_encoded[:, :cur_enc_len]
        speech_encoded_len = torch.from_numpy(np.array([cur_enc_len])).to(speech_encoded_len.device)

        self.cur_speech_encoded = speech_encoded
        self.cur_speech_encoded_len = speech_encoded_len

        self.audio_signal_is_finished = False
        if cur_enc_len * audio_encoder_fs * sample_rate // 1000 >= audio_length:
            logging.warning(f"audio_signal_is_finished: {cur_enc_len}")
            self.audio_signal_is_finished = True

        return speech_encoded, speech_encoded_len, cur_enc_len


    def init_batch_per_step(
        self,
        step: int,
@@ -179,7 +232,78 @@ def init_batch_per_step(
        assert torch.equal(context_lengths, torch.ones_like(context_lengths) * cl)
        audio_length = self.audio_length
        audio_signal = self.audio_signal[:]
        if strategy_args['decode_policy'] == 'alignatt':
            # increase speech chunk size
            speech_encoded, speech_encoded_len, cur_enc_len = self.read_next_audio_chunk(**strategy_args)

            # call xattn for step 0
            input_embeds = self.model._get_text_embeddings(self.context_tokens, None).transpose(0, 1)
            if step == 0:
                assert torch.equal(context_lengths, torch.ones_like(context_lengths) * context_lengths[0])
                context_length = context_lengths[0]
                # empty fixed feature for attention masking in context tokens
                encoder_input_prev, self.extra_outputs = self.model.perception_cross_attn(
                    torch.zeros_like(speech_encoded[:, :1]),
                    torch.ones_like(speech_encoded_len),
                    input_embeds[:, : context_length - 1],
                    input_lengths=context_lengths - 1,
                    return_mems=True,
                )

            # compute xatt of speech and current context tokens
            apply_xatt = True
            max_steps_with_same_audio = 3
            step_count = 0
            decoder_mems_list = self.extra_outputs.get('decoder_mems_list', None)
            while apply_xatt and step_count < max_steps_with_same_audio:
                encoder_input, self.extra_outputs = self.model.perception_cross_attn(
                    speech_encoded,
                    speech_encoded_len,
                    input_embeds[:, context_length - 1 : context_length],
                    input_lengths=torch.ones_like(context_lengths),
                    return_mems=True,
                    decoder_mems_list=decoder_mems_list,
                )
                step_count += 1
                # compute the most attended audio frame for each text token from the
                # specified cross attention layer; xatt_scores: [batch, head, text_context, audio]
                xatt_scores = self.extra_outputs['xatt_scores_list'][strategy_args['xatt_layer']]
                xatt_scores = torch.mean(xatt_scores, 1)
                # exclude the first two audio frames (attention sink tokens)
                most_attended_idx = torch.argmax(xatt_scores[:, :, 2:], dim=-1) + 2
if strategy_args["debug_mode"]:
logging.warning(f"=== most_attended_idx: {most_attended_idx[0].item()}")
logging.warning(f"with value: {xatt_scores[:,:, most_attended_idx].item():.4f}")
logging.warning(f"xatt_scores: {xatt_scores[-1,:,-8:]}")
if cur_enc_len-1 - most_attended_idx >= strategy_args['alignatt_thr']:
apply_xatt = False
else:
if strategy_args["debug_mode"]:
logging.warning(f"[zero step]: !!! condition is not OK, increase speech chunk size")
# increase speech chunk size
speech_encoded, speech_encoded_len, cur_enc_len = self.read_next_audio_chunk(**strategy_args)

# combain all the embeddings
encoder_input = torch.cat([encoder_input_prev, encoder_input, input_embeds[:, context_length:]], dim=1)
# logging.warning(f"encoder_input.shape: {encoder_input.shape}")
            base_module = self.model.model
            lm_embedding = (
                base_module.language_model.embedding
                if hasattr(base_module, 'language_model')
                else base_module.embedding
            )
            self.attention_mask = self.model._create_attention_mask(encoder_input)
            if not hasattr(lm_embedding, 'transpose_batch_sequence') or lm_embedding.transpose_batch_sequence:
                encoder_input = encoder_input.transpose(0, 1).contiguous()
            else:
                encoder_input = input_embeds.transpose(0, 1).contiguous()
                self.attention_mask = self.model._create_attention_mask(encoder_input)
            context_tokens = self.context_tokens

        elif strategy_args['decode_policy'] == 'waitk':
            waitk_lagging = strategy_args['waitk_lagging']
            pre_decision_ratio = strategy_args['pre_decision_ratio']
            sample_rate = strategy_args.get('sample_rate', 16000)
@@ -188,9 +312,14 @@ def init_batch_per_step(
            # for now only support sharing the same text context for a batch
            cur_enc_len = pre_decision_ratio * (step + waitk_lagging)
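            # wait-k schedule: at decoding step t the model has read
            # (t + waitk_lagging) pre-decisions of pre_decision_ratio encoder frames
            # each, so the text always lags the speech frontier by k reads.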
            cur_src_len = (cur_enc_len + right_context) * audio_encoder_fs * sample_rate // 1000
            audio_signal = audio_signal[:, :cur_src_len]
            audio_length = torch.minimum(
                audio_length, torch.from_numpy(np.array([cur_src_len])).to(audio_length.device)
            )
@@ -204,7 +333,7 @@ def init_batch_per_step(
            )
            # call xattn for step 0
            input_embeds = self.model._get_text_embeddings(self.context_tokens, None).transpose(0, 1)
            if step == 0:
                assert torch.equal(context_lengths, torch.ones_like(context_lengths) * context_lengths[0])
                context_length = context_lengths[0]
                # empty fixed feature for attention masking in context tokens
@@ -225,6 +354,7 @@ def init_batch_per_step(
                return_mems=True,
                decoder_mems_list=decoder_mems_list,
            )

            encoder_input = torch.cat([encoder_input_prev, encoder_input, input_embeds[:, context_length:]], dim=1)
            base_module = self.model.model
            lm_embedding = (
@@ -240,6 +370,7 @@ def init_batch_per_step(
            self.attention_mask = self.model._create_attention_mask(encoder_input)
            context_tokens = self.context_tokens
        else:
            raise NotImplementedError("This function is not implemented yet")
        batch = {
            'audio_signal': audio_signal,
            'audio_signal_length': audio_length,
@@ -258,11 +389,14 @@ def init_batch_per_step(
            (speech_encoded, speech_encoded_len, extra_outputs),
        ) = self.model.prepare_llm_input(batch, **strategy_args)

        if strategy_args['decode_policy'] in ['alignatt', 'waitk']:
            # get rid of right context
            speech_encoded_len = torch.minimum(
                speech_encoded_len, torch.from_numpy(np.array([cur_enc_len])).to(speech_encoded_len.device)
            )
            speech_encoded = speech_encoded[:, :cur_enc_len]
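            # the encoder consumed right_context extra frames only to stabilize the
            # newest features; truncating back to cur_enc_len keeps attention indices
            # aligned with the streaming frontier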
            self.cur_speech_encoded = speech_encoded
            self.cur_speech_encoded_len = speech_encoded_len
        self.position_ids = build_position_ids(encoder_input[:, :, 0].transpose(0, 1))
        return (
            context_tokens,
@@ -286,7 +420,7 @@ def init_batch(
        self.context_tokens = context_tokens[:]
        self.context_lengths = context_lengths[:]

        return self.init_batch_per_step(0, **strategy_args)

    def prepare_batch_at_step(
        self,
@@ -302,16 +436,25 @@ def prepare_batch_at_step(
    ) -> Tuple[List[torch.Tensor], List[int]]:
        # types2use = None
        input_embeddings, speech_encoded, speech_encoded_len = input_embeddings

        if step == 0:
            # Allocate memory for the entire context.
            set_inference_key_value_memory = True
            tokens2use = tokens[:, :curr_context_length]
            positions2use = self.position_ids[:, :curr_context_length]
            embeddings2use = input_embeddings[:curr_context_length]
        else:
            # Set this to false so the memory is not reallocated.
            set_inference_key_value_memory = False
            # take last predicted token
            tokens2use = tokens[:, curr_context_length - 1].view(micro_batch_size, -1)
            positions2use = self.position_ids[:, curr_context_length - 1].view(micro_batch_size, -1)
            embeddings2use = self.model._get_text_embeddings(tokens2use, positions2use).transpose(0, 1)
            started = context_lengths <= curr_context_length
@@ -321,12 +464,81 @@ def prepare_batch_at_step(
            decoder_mems_list = self.extra_outputs.get('decoder_mems_list', None)
            if decoder_mems_list is not None:
                decoder_mems_list = decoder_mems_list[:, :, : curr_context_length - 1]
            if strategy_args['decode_policy'] == 'waitk':
                # for now only support sharing the same text context for a batch
                assert torch.equal(context_lengths, torch.ones_like(context_lengths) * context_lengths[0])
                (_, (_, cur_speech_encoded, cur_speech_encoded_len), _) = self.init_batch_per_step(
                    step, **strategy_args
                )
            elif strategy_args['decode_policy'] == 'alignatt':
                # check alignatt condition by xatt before increasing speech chunk

                # 1. compute xatt of speech and current context tokens
                apply_xatt = True
                max_steps_with_same_token = 10  # max number of speech increasing steps with the same token
                step_count = 0
                decoder_mems_list = self.extra_outputs.get('decoder_mems_list', None)

                while apply_xatt and step_count < max_steps_with_same_token:
                    embeddings2use_tmp, self.extra_outputs_tmp = self.model.perception_cross_attn(
                        self.cur_speech_encoded,
                        self.cur_speech_encoded_len,
                        embeddings2use,
                        input_lengths=tokens2use.squeeze(-1) != self.model.tokenizer.eos_id,
                        decoder_mems_list=decoder_mems_list,
                        return_mems=True,
                    )
                    # 2. compute the most attended audio frame for each text token
                    # from the specified cross attention layer
                    step_count += 1
                    xatt_scores = self.extra_outputs_tmp['xatt_scores_list'][strategy_args['xatt_layer']]
                    xatt_scores = torch.mean(xatt_scores, 1)  # mean by attention heads
                    # exclude the first audio frames (attention sink) from the argmax
                    most_attended_idx = torch.argmax(xatt_scores[:, :, 8:], dim=-1) + 8
                    sum_of_last_xatt = torch.sum(xatt_scores[:, :, -strategy_args['alignatt_thr'] :])

                    average_pooling_xatt_scores = self.pool2d(xatt_scores[:, :, 8:])
                    most_attended_idx_pool = torch.argmax(average_pooling_xatt_scores, dim=-1) + 8
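                    # self.pool2d is presumably an average-pooling module defined
                    # elsewhere in the class (its definition is not part of this diff);
                    # pooling smooths the attention scores so a single spiky frame near
                    # the frontier does not by itself delay emission.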

if strategy_args["debug_mode"]:
logging.warning(f"self.cur_speech_encoded_len: {self.cur_speech_encoded_len.item()}")
logging.warning(f"=== most_attended_idx: {most_attended_idx[0].item()}")
logging.warning(f"=== most_attended_idx_pool: {most_attended_idx_pool[0].item()}")
logging.warning(f"with value: {xatt_scores[:,:, most_attended_idx].item():.4f}")
logging.warning(f"xatt_scores: {xatt_scores[-1,:,-8:]}")
logging.warning(f"+++sum_of_last_xatt: {sum_of_last_xatt}")
if self.audio_signal_is_finished or \
self.cur_speech_encoded_len-1 - most_attended_idx_pool >= strategy_args['alignatt_thr'] or \
step_count == max_steps_with_same_token-1:
# if self.audio_signal_is_finished or \
# sum_of_last_xatt < 0.1 or \
# step_count == max_steps_with_same_token-1:
# condition is OK, just use the same speech chunk
apply_xatt = False
if strategy_args["debug_mode"]:
logging.warning(f"condition is OK, just use the same speech chunk")
embeddings2use = embeddings2use_tmp
self.extra_outputs = self.extra_outputs_tmp
"""Prepare batch for each of the inference steps"""
setkey_value_array = torch.tensor(
[set_inference_key_value_memory] * micro_batch_size, device=torch.cuda.current_device()
)
len_array = torch.tensor([maxlen] * micro_batch_size, device=torch.cuda.current_device())

batch = [tokens2use, embeddings2use, self.attention_mask, positions2use, setkey_value_array, len_array]
tensor_shape = [tokens2use.shape[1], micro_batch_size, self.model.cfg.hidden_size]
return batch, tensor_shape
else:
if strategy_args["debug_mode"]:
logging.warning(f"!!! condition is not OK, increase speech chunk size")
# for now only support sharing the same text context for a batch
assert torch.equal(context_lengths, torch.ones_like(context_lengths) * context_lengths[0])

cur_speech_encoded, cur_speech_encoded_len, _ = self.read_next_audio_chunk(**strategy_args)
self.cur_speech_encoded = cur_speech_encoded
self.cur_speech_encoded_len = cur_speech_encoded_len

else:
cur_speech_encoded = speech_encoded
cur_speech_encoded_len = speech_encoded_len
@@ -340,6 +552,7 @@ def prepare_batch_at_step(
                decoder_mems_list=decoder_mems_list,
                return_mems=True,
            )

            embeddings2use = switch(
                input_embeddings[curr_context_length - 1].unsqueeze(0), embeddings2use.transpose(0, 1), started
            )
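For reference, a minimal sketch of the strategy_args consumed by this commit. The key names come from the diff above; every value shown is illustrative, not a tuned or recommended setting:

strategy_args = {
    'decode_policy': 'alignatt',   # or 'waitk'
    'alignatt_waitk': 2,           # chunks pre-read before the first emission (illustrative)
    'alignatt_thr': 8,             # required lag, in encoder frames, behind the frontier (illustrative)
    'xatt_layer': 1,               # cross-attention layer whose scores drive decisions (illustrative)
    'pre_decision_ratio': 8,       # encoder frames added per read (illustrative)
    'waitk_lagging': 2,            # wait-k policy only: the k value (illustrative)
    'sample_rate': 16000,          # default used in the diff
    'right_context': 13,           # default used in the diff
    'audio_encoder_fs': 80,        # encoder frame shift in ms; default used in the diff
    'debug_mode': False,
}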
