diff --git a/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py
index 608989a3417b..4cdc296eccc7 100644
--- a/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py
+++ b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py
@@ -124,6 +124,7 @@ def collate_text_data(
                 for k, v in text_processor._process_example(
                     context=cut.context,
                     output=cut.supervisions[0].text,
+                    lang_id=cut.supervisions[0].language,
                 ).items()
             }
             for cut in cuts
diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py
index 5daa729847b1..82ada46685a1 100644
--- a/nemo/collections/multimodal/speech_llm/models/modular_models.py
+++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py
@@ -90,8 +90,9 @@ def setup_ctc_head(self, cfg):
         # self.cfg.aux_ctc.decoder.vocabulary = ListConfig(self.tokenizer.vocab)
         # the error arises from 5303 token "${" in the tokenizer
         # use dummy vocab for now (temporary fix)
-        self.cfg.aux_ctc.decoder.vocabulary = [1] * len(self.tokenizer.vocab)
-        self.cfg.aux_ctc.decoder.num_classes = len(self.tokenizer.vocab)
+        # self.cfg.aux_ctc.decoder.vocabulary = [1] * len(self.tokenizer.asr_tokenizer.vocab)
+        self.cfg.aux_ctc.decoder.vocabulary = self.tokenizer.asr_tokenizer.vocab
+        self.cfg.aux_ctc.decoder.num_classes = len(self.tokenizer.asr_tokenizer.vocab)
         self.ctc_decoder = self.from_config_dict(self.cfg.aux_ctc.decoder)
         self.ctc_loss_weight = self.cfg.aux_ctc.get("ctc_loss_weight", 0.1)
 
@@ -134,10 +135,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         self.enforce_divisible_batch = False
         self.setup_perception_modules(cfg)
 
-        ### CTC head start:
-        self.setup_ctc_head(cfg)
-        ### CTC head end.
-
         # print out params in more details
         self.summarize(max_depth=2)
 
@@ -550,14 +547,15 @@ def loss_func(output_tensor):
             loss_for_ub = (1 - self.ctc_loss_weight) * loss_for_ub + self.ctc_loss_weight * ctc_loss
 
-            logging.warning("*************"*10)
-            # logging.warning(f"batch: {batch}")
-            # logging.warning(f"ctc_head_output[0].shape: {ctc_head_output[0].shape}")
-            logging.warning(f"batch['ctc_tokens']: {batch['ctc_tokens']}")
-            logging.warning(f"batch['ctc_tokens'][0]: {self.tokenizer.asr_tokenizer.ids_to_tokens(batch['ctc_tokens'][0].tolist())}")
-            logging.warning(f"CTC Loss: {ctc_loss}")
-            logging.warning(f"loss_for_ub: {loss_for_ub}")
-            raise NotImplementedError("CTC loss implementation in progress...")
+            # logging.warning("*************"*10)
+            # # logging.warning(f"batch: {batch}")
+            # # logging.warning(f"ctc_head_output[0].shape: {ctc_head_output[0].shape}")
+            # logging.warning(f"batch['ctc_tokens']: {batch['ctc_tokens']}")
+            # logging.warning(f"batch['ctc_tokens'][0]: {self.tokenizer.asr_tokenizer.ids_to_tokens(batch['ctc_tokens'][0].tolist())}")
+            # logging.warning(f"batch['ctc_tokens_length']: {batch['ctc_tokens_length']}")
+            # logging.warning(f"CTC Loss: {ctc_loss}")
+            # logging.warning(f"loss_for_ub: {loss_for_ub}")
+            # raise NotImplementedError("CTC loss implementation in progress...")
 
             if self.cfg.data.get(
                 "return_output_tensors", False
             ):
@@ -866,6 +864,10 @@ def restore_from_pretrained_models(
         # load audio model weights
         model = cls.load_pretrained_audio_weights(cfg, model, audio_model, speaker_model)
 
+        ### CTC head start:
+        model.setup_ctc_head(cfg.model)
+        ### CTC head end.
+
         if 'inference' in cfg:
             inference_cfg = OmegaConf.to_container(cfg.inference, resolve=True)
             model.set_inference_config(inference_cfg)
diff --git a/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py b/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py
index a35ba57dfe3f..8c7ab41486ff 100644
--- a/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py
+++ b/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py
@@ -265,7 +265,7 @@ def __init__(
             self.prompt_template = self.prompt_template.encode('utf-8').decode('unicode_escape')
         assert self.truncation_field in ["answer", "context"]
 
-    def _process_example(self, context: str, output: str):
+    def _process_example(self, context: str, output: str, lang_id: str):
         """
         Create an example by concatenating text and answer.
         Truncation is carried out when needed, but it is performed only on the prompt side.
@@ -318,13 +318,11 @@ def _process_example(self, context: str, output: str):
         # Labels for ctc head
         #ctc_tokens_ids = answer_ids[1:]
         # logging.warning("++++"*10)
-        logging.warning(f"text: {text}")
-        logging.warning(f"answer_text: {answer_text}")
-        logging.warning(f"output: {output}")
         # logging.warning(f"original_text: {original_text}")
-        ctc_tokens_ids = self.tokenizer.asr_tokenizer.text_to_ids(output, "en")
-        logging.warning(f"ctc_tokens_ids: {ctc_tokens_ids}")
-        raise ValueError("stop here")
+        ctc_tokens_ids = self.tokenizer.asr_tokenizer.text_to_ids(output, lang_id)
+        # logging.warning(f"lang_id: {lang_id}")
+        # logging.warning(f"ctc_tokens_ids: {ctc_tokens_ids}")
+        # raise ValueError("stop here")
 
         if self.end_string:
             answer_ids += self.tokenizer.text_to_ids(self.end_string)
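
Note for reviewers (not part of the patch): the auxiliary CTC loss is blended into the LLM loss by the `loss_for_ub` line in `loss_func`, with `ctc_loss_weight` defaulting to 0.1 via `cfg.aux_ctc.get("ctc_loss_weight", 0.1)`. A minimal standalone sketch of that interpolation, with placeholder tensor values; only the formula and the 0.1 default come from the patch:

```python
import torch

# Standalone sketch of the loss blending in loss_func above; the tensor
# values are placeholders, only the weighting formula is from the patch.
llm_loss = torch.tensor(2.5)  # placeholder LLM cross-entropy loss
ctc_loss = torch.tensor(4.0)  # placeholder auxiliary CTC loss
ctc_loss_weight = 0.1         # default from cfg.aux_ctc.get("ctc_loss_weight", 0.1)

loss_for_ub = (1 - ctc_loss_weight) * llm_loss + ctc_loss_weight * ctc_loss
print(loss_for_ub)  # 0.9 * 2.5 + 0.1 * 4.0 = tensor(2.6500)
```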
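Similarly, the point of threading `lang_id` from `cut.supervisions[0].language` down to `text_to_ids(output, lang_id)` is that an aggregate-style ASR tokenizer uses that key to pick a per-language sub-tokenizer, where the old code hard-coded `"en"`. A toy stand-in (not NeMo's actual tokenizer API) showing why the key matters:

```python
# Toy aggregate tokenizer: the lang_id key selects which sub-vocabulary
# maps text to ids, so the same text yields different CTC label ids per language.
class ToyAggregateTokenizer:
    def __init__(self, vocabs: dict[str, dict[str, int]]):
        self.vocabs = vocabs

    def text_to_ids(self, text: str, lang_id: str) -> list[int]:
        vocab = self.vocabs[lang_id]  # raises KeyError on an unknown language
        return [vocab.get(ch, 0) for ch in text]

tok = ToyAggregateTokenizer({"en": {"a": 1, "b": 2}, "de": {"a": 7, "b": 8}})
assert tok.text_to_ids("ab", "en") == [1, 2]
assert tok.text_to_ids("ab", "de") == [7, 8]  # same text, different ids
```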