Skip to content

Commit

Permalink
use asr tokenizer for ctc head
Browse files Browse the repository at this point in the history
Signed-off-by: andrusenkoau <andrusenkoau@gmail.com>
  • Loading branch information
andrusenkoau committed Aug 5, 2024
1 parent fe11e08 commit f704df9
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def collate_text_data(
for k, v in text_processor._process_example(
context=cut.context,
output=cut.supervisions[0].text,
lang_id=cut.supervisions[0].language,
).items()
}
for cut in cuts
Expand Down
30 changes: 16 additions & 14 deletions nemo/collections/multimodal/speech_llm/models/modular_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@ def setup_ctc_head(self, cfg):
# self.cfg.aux_ctc.decoder.vocabulary = ListConfig(self.tokenizer.vocab)
# the error arises from 5303 token "${" in the tokenizer
# use dummy vocab for now (temporary fix)
self.cfg.aux_ctc.decoder.vocabulary = [1] * len(self.tokenizer.vocab)
self.cfg.aux_ctc.decoder.num_classes = len(self.tokenizer.vocab)
# self.cfg.aux_ctc.decoder.vocabulary = [1] * len(self.tokenizer.asr_tokenizer.vocab)
self.cfg.aux_ctc.decoder.vocabulary = self.tokenizer.asr_tokenizer.vocab
self.cfg.aux_ctc.decoder.num_classes = len(self.tokenizer.asr_tokenizer.vocab)

self.ctc_decoder = self.from_config_dict(self.cfg.aux_ctc.decoder)
self.ctc_loss_weight = self.cfg.aux_ctc.get("ctc_loss_weight", 0.1)
Expand Down Expand Up @@ -134,10 +135,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self.enforce_divisible_batch = False
self.setup_perception_modules(cfg)

### CTC head start:
self.setup_ctc_head(cfg)
### CTC head end.

# print out params in more details
self.summarize(max_depth=2)

Expand Down Expand Up @@ -550,14 +547,15 @@ def loss_func(output_tensor):

loss_for_ub = (1 - self.ctc_loss_weight) * loss_for_ub + self.ctc_loss_weight * ctc_loss

logging.warning("*************"*10)
# logging.warning(f"batch: {batch}")
# logging.warning(f"ctc_head_output[0].shape: {ctc_head_output[0].shape}")
logging.warning(f"batch['ctc_tokens']: {batch['ctc_tokens']}")
logging.warning(f"batch['ctc_tokens'][0]: {self.tokenizer.asr_tokenizer.ids_to_tokens(batch['ctc_tokens'][0].tolist())}")
logging.warning(f"CTC Loss: {ctc_loss}")
logging.warning(f"loss_for_ub: {loss_for_ub}")
raise NotImplementedError("CTC loss implementation in progress...")
# logging.warning("*************"*10)
# # logging.warning(f"batch: {batch}")
# # logging.warning(f"ctc_head_output[0].shape: {ctc_head_output[0].shape}")
# logging.warning(f"batch['ctc_tokens']: {batch['ctc_tokens']}")
# logging.warning(f"batch['ctc_tokens'][0]: {self.tokenizer.asr_tokenizer.ids_to_tokens(batch['ctc_tokens'][0].tolist())}")
# logging.warning(f"batch['ctc_tokens_length']: {batch['ctc_tokens_length']}")
# logging.warning(f"CTC Loss: {ctc_loss}")
# logging.warning(f"loss_for_ub: {loss_for_ub}")
# raise NotImplementedError("CTC loss implementation in progress...")

if self.cfg.data.get(
"return_output_tensors", False
Expand Down Expand Up @@ -866,6 +864,10 @@ def restore_from_pretrained_models(
# load audio model weights
model = cls.load_pretrained_audio_weights(cfg, model, audio_model, speaker_model)

### CTC head start:
model.setup_ctc_head(cfg.model)
### CTC head end.

if 'inference' in cfg:
inference_cfg = OmegaConf.to_container(cfg.inference, resolve=True)
model.set_inference_config(inference_cfg)
Expand Down
12 changes: 5 additions & 7 deletions nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def __init__(
self.prompt_template = self.prompt_template.encode('utf-8').decode('unicode_escape')
assert self.truncation_field in ["answer", "context"]

def _process_example(self, context: str, output: str):
def _process_example(self, context: str, output: str, lang_id: str):
"""
Create an example by concatenating text and answer.
Truncation is carried out when needed, but it is performed only on the prompt side.
Expand Down Expand Up @@ -318,13 +318,11 @@ def _process_example(self, context: str, output: str):
# Labels for ctc head
#ctc_tokens_ids = answer_ids[1:]
# logging.warning("++++"*10)
# logging.warning(f"text: {text}")
# logging.warning(f"answer_text: {answer_text}")
logging.warning(f"output: {output}")
# logging.warning(f"original_text: {original_text}")
ctc_tokens_ids = self.tokenizer.asr_tokenizer.text_to_ids(output, "en")
logging.warning(f"ctc_tokens_ids: {ctc_tokens_ids}")
raise ValueError("stop here")
ctc_tokens_ids = self.tokenizer.asr_tokenizer.text_to_ids(output, lang_id)
# logging.warning(f"lang_id: {lang_id}")
# logging.warning(f"ctc_tokens_ids: {ctc_tokens_ids}")
# raise ValueError("stop here")

if self.end_string:
answer_ids += self.tokenizer.text_to_ids(self.end_string)
Expand Down

0 comments on commit f704df9

Please sign in to comment.