diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index a123b3f3bfe8..ab232d9d6c56 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -108,7 +108,7 @@ def setup_ctc_head(self, cfg): with open_dict(self.cfg.aux_ctc): self.cfg.aux_ctc.decoding = ctc_decoding_cfg - self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer) + self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer.asr_tokenizer) self.ctc_wer = WER( decoding=self.ctc_decoding, use_cer=self.cfg.aux_ctc.get('use_cer', False), @@ -548,7 +548,39 @@ def loss_func(output_tensor): loss_for_ub = (1 - self.ctc_loss_weight) * loss_for_ub + self.ctc_loss_weight * ctc_loss + # if validation_step: + # logging.warning("*************"*100) + + if validation_step: + # compute WER from the CTC head + + # logging.warning("*************"*10) + # logging.warning(f"CTC Loss: {ctc_loss}") + # logging.warning(f"ctc_log_probs.shape: {ctc_log_probs.shape}") + # logging.warning(f"batch['ctc_tokens_length']: {batch['ctc_tokens_length']}") + # predictions_logprobs, predictions_labels = ctc_log_probs.max(dim=-1) + # for i in range(predictions_labels.shape[0]): + # predictions_labels = predictions_labels[i, :batch["ctc_tokens_length"][i]].tolist() + # logging.warning(f"predictions_labels: {predictions_labels}") + # predictions_labels = self.tokenizer.asr_tokenizer.ids_to_tokens(predictions_labels) + # logging.warning(f"predictions_labels: {predictions_labels}") + + self.ctc_wer.update( + predictions=ctc_log_probs, + targets=batch["ctc_tokens"], + targets_lengths=batch["ctc_tokens_length"], + predictions_lengths=ctc_input_lengths, + ) + ctc_wer, _, _ = self.ctc_wer.compute() + self.ctc_wer.reset() + # logging.warning("*************"*10) + # logging.warning(f"ctc_wer: {ctc_wer}") + self.log("training_batch_wer_from_ctc_head", ctc_wer, batch_size=1) + # self.log('W_training_batch_wer_ctc_head_v1', ctc_wer, prog_bar=True, rank_zero_only=True, batch_size=1, sync_dist=True) + # raise NotImplementedError("CTC loss implementation in progress...") + # logging.warning("*************"*10) + # logging.warning(f"output_tensor.shape: {output_tensor.shape}") # logging.warning(f"ctc_log_probs.shape: {ctc_log_probs.shape}") # logging.warning(f"ctc_input_lengths: {ctc_input_lengths}") # logging.warning(f"batch['ctc_tokens']: {batch['ctc_tokens']}") @@ -597,19 +629,9 @@ def loss_func(output_tensor): loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group() ) - # compute WER from the CTC head - self.ctc_wer.update( - predictions=ctc_log_probs, - targets=batch["ctc_tokens"], - targets_lengths=batch["ctc_tokens_length"], - predictions_lengths=ctc_input_lengths, - ) - ctc_wer, _, _ = self.ctc_wer.compute() - self.ctc_wer.reset() - self.log({'training_batch_wer_ctc_head': ctc_wer}) - return loss_for_ub * cp_size, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu} - else: + + else: reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) # logging.warning(f"reduced_loss: {reduced_loss}") return loss_for_ub * cp_size, {'avg': reduced_loss} @@ -1032,18 +1054,26 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False): return_state_dict = {} state_dict = self.perception.state_dict(prefix="perception.") + state_dict_ctc_ma = self.ctc_modality_adapter.state_dict(prefix="ctc_modality_adapter.") + state_dict_ctc_dec = self.ctc_decoder.state_dict(prefix="ctc_decoder.") if self.cfg.freeze_audio_encoder: state_dict = {k: v for k, v in state_dict.items() if not k.startswith("perception.encoder.")} return_state_dict.update(state_dict) state_dict = self.perception.state_dict(prefix="perception.") return_state_dict.update(state_dict) + return_state_dict.update(state_dict_ctc_ma) + return_state_dict.update(state_dict_ctc_dec) return return_state_dict elif self.setup_complete and self.trainer.state.fn != "fit": # used to save the whole model as a nemo file return_state_dict = self.model.state_dict(prefix="model.") state_dict = self.perception.state_dict(prefix="perception.") + state_dict_ctc_ma = self.ctc_modality_adapter.state_dict(prefix="ctc_modality_adapter.") + state_dict_ctc_dec = self.ctc_decoder.state_dict(prefix="ctc_decoder.") return_state_dict.update(state_dict) + return_state_dict.update(state_dict_ctc_ma) + return_state_dict.update(state_dict_ctc_dec) return return_state_dict else: # we want all the params with the same keys as calling self.state_dict() @@ -1054,9 +1084,13 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False): else: return_state_dict = {} state_dict = self.perception.state_dict(prefix="perception.") + state_dict_ctc_ma = self.ctc_modality_adapter.state_dict(prefix="ctc_modality_adapter.") + state_dict_ctc_dec = self.ctc_decoder.state_dict(prefix="ctc_decoder.") if self.cfg.freeze_audio_encoder: state_dict = {k: v for k, v in state_dict.items() if not k.startswith("perception.encoder.")} return_state_dict.update(state_dict) + return_state_dict.update(state_dict_ctc_ma) + return_state_dict.update(state_dict_ctc_dec) return return_state_dict def load_state_dict(self, state_dict, strict: bool = True):