From 41f27a5e68934761b05dd373e7b288aa798858e5 Mon Sep 17 00:00:00 2001 From: Vahid Noroozi Date: Tue, 21 Jun 2022 19:52:18 -0700 Subject: [PATCH] Fixing bugs in calling method ctc_decoder_predictions_tensor. (#4414) * updated ctc decoding calls. Signed-off-by: Vahid * fixed the ones for timestamp_utils.py Signed-off-by: Vahid * fixed the ones for timestamp_utils.py Signed-off-by: Vahid * fixed the ones for timestamp_utils.py Signed-off-by: Vahid --- examples/asr/quantization/speech_to_text_quant_infer.py | 2 +- .../asr/quantization/speech_to_text_quant_infer_trt.py | 2 +- nemo/collections/asr/models/ctc_models.py | 2 +- nemo/collections/asr/parts/utils/transcribe_utils.py | 2 +- .../ngram_lm/eval_beamsearch_ngram.py | 2 +- .../ngram_lm/install_beamsearch_decoders.sh | 8 ++++++-- tutorials/asr/ASR_with_NeMo.ipynb | 2 +- 7 files changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/asr/quantization/speech_to_text_quant_infer.py b/examples/asr/quantization/speech_to_text_quant_infer.py index 87f26ecae88d..bed52fb6cc7a 100644 --- a/examples/asr/quantization/speech_to_text_quant_infer.py +++ b/examples/asr/quantization/speech_to_text_quant_infer.py @@ -202,7 +202,7 @@ def evaluate(asr_model, labels_map, wer): log_probs, encoded_len, greedy_predictions = asr_model( input_signal=test_batch[0], input_signal_length=test_batch[1] ) - hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions) + hypotheses += wer.decoding.ctc_decoder_predictions_tensor(greedy_predictions)[0] for batch_ind in range(greedy_predictions.shape[0]): seq_len = test_batch[3][batch_ind].cpu().detach().numpy() seq_ids = test_batch[2][batch_ind].cpu().detach().numpy() diff --git a/examples/asr/quantization/speech_to_text_quant_infer_trt.py b/examples/asr/quantization/speech_to_text_quant_infer_trt.py index 72c5cf43f787..017d935cef6d 100644 --- a/examples/asr/quantization/speech_to_text_quant_infer_trt.py +++ b/examples/asr/quantization/speech_to_text_quant_infer_trt.py @@ -212,7 +212,7 @@ def evaluate(asr_model, asr_onnx, labels_map, wer, qat): input_signal=processed_signal, input_signal_length=processed_signal_length, ) - hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions) + hypotheses += wer.decoding.ctc_decoder_predictions_tensor(greedy_predictions)[0] for batch_ind in range(greedy_predictions.shape[0]): seq_len = test_batch[3][batch_ind].cpu().detach().numpy() seq_ids = test_batch[2][batch_ind].cpu().detach().numpy() diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 80d45bc604b9..619c6cb105e0 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -592,7 +592,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): else: log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len) - transcribed_texts = self._wer.ctc_decoder_predictions_tensor( + transcribed_texts, _ = self._wer.decoding.ctc_decoder_predictions_tensor( predictions=log_probs, predictions_len=encoded_len, return_hypotheses=False, ) diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py index 2947da215122..999e53aea41e 100644 --- a/nemo/collections/asr/parts/utils/transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/transcribe_utils.py @@ -74,7 +74,7 @@ def transcribe_partial_audio( lg = logits[idx][: logits_len[idx]] hypotheses.append(lg.cpu().numpy()) else: - current_hypotheses = asr_model._wer.ctc_decoder_predictions_tensor( + current_hypotheses, _ = asr_model._wer.decoding.ctc_decoder_predictions_tensor( greedy_predictions, predictions_len=logits_len, return_hypotheses=return_hypotheses, ) diff --git a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram.py b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram.py index e67a7d64d4fb..277c4695a939 100644 --- a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram.py +++ b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram.py @@ -280,7 +280,7 @@ def autocast(): for batch_idx, probs in enumerate(all_probs): preds = np.argmax(probs, axis=1) preds_tensor = torch.tensor(preds, device='cpu').unsqueeze(0) - pred_text = asr_model._wer.ctc_decoder_predictions_tensor(preds_tensor)[0] + pred_text = asr_model._wer.decoding.ctc_decoder_predictions_tensor(preds_tensor)[0][0] pred_split_w = pred_text.split() target_split_w = target_transcripts[batch_idx].split() diff --git a/scripts/asr_language_modeling/ngram_lm/install_beamsearch_decoders.sh b/scripts/asr_language_modeling/ngram_lm/install_beamsearch_decoders.sh index e530520b50f6..0a2836ad9385 100644 --- a/scripts/asr_language_modeling/ngram_lm/install_beamsearch_decoders.sh +++ b/scripts/asr_language_modeling/ngram_lm/install_beamsearch_decoders.sh @@ -1,8 +1,12 @@ +#!/usr/bin/env bash # install Boost package sudo apt-get install build-essential libboost-all-dev cmake zlib1g-dev libbz2-dev liblzma-dev -git clone https://github.com/NVIDIA/OpenSeq2Seq -b ctc-decoders +git clone https://github.com/NVIDIA/OpenSeq2Seq +cd OpenSeq2Seq +git checkout ctc-decoders +cd .. mv OpenSeq2Seq/decoders . rm -rf OpenSeq2Seq cd decoders ./setup.sh -cd .. +cd .. \ No newline at end of file diff --git a/tutorials/asr/ASR_with_NeMo.ipynb b/tutorials/asr/ASR_with_NeMo.ipynb index 6f3eac25e52b..959ba9750a44 100644 --- a/tutorials/asr/ASR_with_NeMo.ipynb +++ b/tutorials/asr/ASR_with_NeMo.ipynb @@ -1090,7 +1090,7 @@ " logits = torch.from_numpy(alogits[0])\n", " greedy_predictions = logits.argmax(dim=-1, keepdim=False)\n", " wer = WER(vocabulary=quartznet.decoder.vocabulary, batch_dim_index=0, use_cer=False, ctc_decode=True)\n", - " hypotheses = wer.ctc_decoder_predictions_tensor(greedy_predictions)\n", + " hypotheses, _ = wer.decoding.ctc_decoder_predictions_tensor(greedy_predictions)\n", " print(hypotheses)\n", " break\n" ],