fix ctc head restoring
Signed-off-by: andrusenkoau <andrusenkoau@gmail.com>
andrusenkoau committed Aug 16, 2024
1 parent 24f8e2e commit 30e9251
Showing 1 changed file with 47 additions and 13 deletions.
60 changes: 47 additions & 13 deletions nemo/collections/multimodal/speech_llm/models/modular_models.py
@@ -108,7 +108,7 @@ def setup_ctc_head(self, cfg):
with open_dict(self.cfg.aux_ctc):
self.cfg.aux_ctc.decoding = ctc_decoding_cfg

self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer)
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer.asr_tokenizer)
self.ctc_wer = WER(
decoding=self.ctc_decoding,
use_cer=self.cfg.aux_ctc.get('use_cer', False),
@@ -548,7 +548,39 @@ def loss_func(output_tensor):

loss_for_ub = (1 - self.ctc_loss_weight) * loss_for_ub + self.ctc_loss_weight * ctc_loss

# if validation_step:
# logging.warning("*************"*100)

if validation_step:
# compute WER from the CTC head

# logging.warning("*************"*10)
# logging.warning(f"CTC Loss: {ctc_loss}")
# logging.warning(f"ctc_log_probs.shape: {ctc_log_probs.shape}")
# logging.warning(f"batch['ctc_tokens_length']: {batch['ctc_tokens_length']}")
# predictions_logprobs, predictions_labels = ctc_log_probs.max(dim=-1)
# for i in range(predictions_labels.shape[0]):
# predictions_labels = predictions_labels[i, :batch["ctc_tokens_length"][i]].tolist()
# logging.warning(f"predictions_labels: {predictions_labels}")
# predictions_labels = self.tokenizer.asr_tokenizer.ids_to_tokens(predictions_labels)
# logging.warning(f"predictions_labels: {predictions_labels}")

self.ctc_wer.update(
predictions=ctc_log_probs,
targets=batch["ctc_tokens"],
targets_lengths=batch["ctc_tokens_length"],
predictions_lengths=ctc_input_lengths,
)
ctc_wer, _, _ = self.ctc_wer.compute()
self.ctc_wer.reset()
# logging.warning("*************"*10)
# logging.warning(f"ctc_wer: {ctc_wer}")
self.log("training_batch_wer_from_ctc_head", ctc_wer, batch_size=1)
# self.log('W_training_batch_wer_ctc_head_v1', ctc_wer, prog_bar=True, rank_zero_only=True, batch_size=1, sync_dist=True)
# raise NotImplementedError("CTC loss implementation in progress...")

# logging.warning("*************"*10)
# logging.warning(f"output_tensor.shape: {output_tensor.shape}")
# logging.warning(f"ctc_log_probs.shape: {ctc_log_probs.shape}")
# logging.warning(f"ctc_input_lengths: {ctc_input_lengths}")
# logging.warning(f"batch['ctc_tokens']: {batch['ctc_tokens']}")
@@ -597,19 +629,9 @@ def loss_func(output_tensor):
loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group()
)

# compute WER from the CTC head
self.ctc_wer.update(
predictions=ctc_log_probs,
targets=batch["ctc_tokens"],
targets_lengths=batch["ctc_tokens_length"],
predictions_lengths=ctc_input_lengths,
)
ctc_wer, _, _ = self.ctc_wer.compute()
self.ctc_wer.reset()
self.log({'training_batch_wer_ctc_head': ctc_wer})

return loss_for_ub * cp_size, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu}
else:

else:
reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
# logging.warning(f"reduced_loss: {reduced_loss}")
return loss_for_ub * cp_size, {'avg': reduced_loss}
@@ -1032,18 +1054,26 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False):
return_state_dict = {}

state_dict = self.perception.state_dict(prefix="perception.")
state_dict_ctc_ma = self.ctc_modality_adapter.state_dict(prefix="ctc_modality_adapter.")
state_dict_ctc_dec = self.ctc_decoder.state_dict(prefix="ctc_decoder.")
if self.cfg.freeze_audio_encoder:
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("perception.encoder.")}

return_state_dict.update(state_dict)
state_dict = self.perception.state_dict(prefix="perception.")
return_state_dict.update(state_dict)
return_state_dict.update(state_dict_ctc_ma)
return_state_dict.update(state_dict_ctc_dec)
return return_state_dict
elif self.setup_complete and self.trainer.state.fn != "fit":
# used to save the whole model as a nemo file
return_state_dict = self.model.state_dict(prefix="model.")
state_dict = self.perception.state_dict(prefix="perception.")
state_dict_ctc_ma = self.ctc_modality_adapter.state_dict(prefix="ctc_modality_adapter.")
state_dict_ctc_dec = self.ctc_decoder.state_dict(prefix="ctc_decoder.")
return_state_dict.update(state_dict)
return_state_dict.update(state_dict_ctc_ma)
return_state_dict.update(state_dict_ctc_dec)
return return_state_dict
else:
# we want all the params with the same keys as calling self.state_dict()
@@ -1054,9 +1084,13 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False):
else:
return_state_dict = {}
state_dict = self.perception.state_dict(prefix="perception.")
state_dict_ctc_ma = self.ctc_modality_adapter.state_dict(prefix="ctc_modality_adapter.")
state_dict_ctc_dec = self.ctc_decoder.state_dict(prefix="ctc_decoder.")
if self.cfg.freeze_audio_encoder:
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("perception.encoder.")}
return_state_dict.update(state_dict)
return_state_dict.update(state_dict_ctc_ma)
return_state_dict.update(state_dict_ctc_dec)
return return_state_dict

def load_state_dict(self, state_dict, strict: bool = True):
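Note on the fix (not part of the diff): the core change is that state_dict() now also collects the ctc_modality_adapter and ctc_decoder weights, so the CTC head is saved to and restored from checkpoints together with the perception module; the other hunks pass self.tokenizer.asr_tokenizer to CTCBPEDecoding and move the CTC-head WER computation/logging into the validation path. Below is a minimal, self-contained sketch of the merge-prefixed-state-dicts pattern in plain PyTorch; the module shapes and names (SpeechLLMSketch, checkpoint_state_dict) are hypothetical stand-ins, not the NeMo implementation.

import torch.nn as nn

class SpeechLLMSketch(nn.Module):
    """Toy stand-in for the modular speech-LLM model; only the checkpointing pattern matters here."""

    def __init__(self, freeze_audio_encoder: bool = True):
        super().__init__()
        self.perception = nn.Sequential(nn.Linear(80, 64), nn.ReLU())  # stand-in for the audio perception stack
        self.ctc_modality_adapter = nn.Linear(64, 64)                  # stand-in CTC modality adapter
        self.ctc_decoder = nn.Linear(64, 128)                          # stand-in CTC decoder head
        self.freeze_audio_encoder = freeze_audio_encoder

    def checkpoint_state_dict(self):
        # Collect perception weights first, optionally dropping the frozen encoder,
        # then merge in the prefixed CTC-head modules so they are saved and restored as well.
        return_state_dict = {}
        state_dict = self.perception.state_dict(prefix="perception.")
        if self.freeze_audio_encoder:
            state_dict = {k: v for k, v in state_dict.items() if not k.startswith("perception.encoder.")}
        return_state_dict.update(state_dict)
        return_state_dict.update(self.ctc_modality_adapter.state_dict(prefix="ctc_modality_adapter."))
        return_state_dict.update(self.ctc_decoder.state_dict(prefix="ctc_decoder."))
        return return_state_dict

model = SpeechLLMSketch()
print(sorted(model.checkpoint_state_dict().keys()))
# e.g. ['ctc_decoder.bias', 'ctc_decoder.weight', 'ctc_modality_adapter.bias', ..., 'perception.0.weight', ...]

Merging prefixed per-module state dicts, rather than calling the full self.state_dict(), keeps adapter-style checkpoints small while still capturing the auxiliary CTC head, which appears to be what the commit title means by restoring the CTC head.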
