From e6c6eb49fde3e9a64a9833c358c30a6ec7147ae9 Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Wed, 11 Sep 2024 01:03:06 -0700 Subject: [PATCH] disable ctc modality adapter Signed-off-by: andrusenkoau --- .../collections/multimodal/speech_llm/models/modular_models.py | 3 ++- .../multimodal/speech_llm/modules/perception_modules.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index b6ba3c59d06b..a932e4667f68 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -519,7 +519,8 @@ def loss_func(output_tensor): cp_size = self.cfg.get('context_parallel_size', 1) # compute ctc loss - ctc_encoded, ctc_encoded_len = self.perception.ctc_modality_adapter(audio_signal=audio_encoder_outputs[0], length=audio_encoder_outputs[1]) + # ctc_encoded, ctc_encoded_len = self.perception.ctc_modality_adapter(audio_signal=audio_encoder_outputs[0], length=audio_encoder_outputs[1]) + ctc_encoded, ctc_encoded_len = audio_encoder_outputs[0], audio_encoder_outputs[1] ctc_log_probs = self.perception.ctc_decoder(encoder_output=ctc_encoded) ctc_input_lengths = ctc_encoded_len # ctc_log_probs = self.ctc_decoder(encoder_output=audio_encoder_outputs[0]) diff --git a/nemo/collections/multimodal/speech_llm/modules/perception_modules.py b/nemo/collections/multimodal/speech_llm/modules/perception_modules.py index 33e3b9f2b89a..96b4b06882e7 100644 --- a/nemo/collections/multimodal/speech_llm/modules/perception_modules.py +++ b/nemo/collections/multimodal/speech_llm/modules/perception_modules.py @@ -102,7 +102,7 @@ def __init__(self, cfg: DictConfig): raise ValueError( "The config need to have a section for the CTC decoder named as aux_ctc for Hybrid models." ) - self.ctc_modality_adapter = self.from_config_dict(cfg.aux_ctc.modality_adapter) + # self.ctc_modality_adapter = self.from_config_dict(cfg.aux_ctc.modality_adapter) self.cfg.aux_ctc.decoder.vocabulary = [1]*self.cfg.aux_ctc.decoder.num_classes # self.cfg.aux_ctc.decoder.num_classes = len(ctc_tokenizer.vocab)