Skip to content

Commit

Permalink
Fixing bugs in calling the method ctc_decoder_predictions_tensor. (#4414)
Browse files Browse the repository at this point in the history
* Updated CTC decoding calls.

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* fixed the ones for timestamp_utils.py

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* fixed the ones for timestamp_utils.py

Signed-off-by: Vahid <vnoroozi@nvidia.com>

* fixed the ones for timestamp_utils.py

Signed-off-by: Vahid <vnoroozi@nvidia.com>
  • Loading branch information
VahidooX authored Jun 22, 2022
1 parent 94a464f commit 41f27a5
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion examples/asr/quantization/speech_to_text_quant_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def evaluate(asr_model, labels_map, wer):
log_probs, encoded_len, greedy_predictions = asr_model(
input_signal=test_batch[0], input_signal_length=test_batch[1]
)
hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
hypotheses += wer.decoding.ctc_decoder_predictions_tensor(greedy_predictions)[0]
for batch_ind in range(greedy_predictions.shape[0]):
seq_len = test_batch[3][batch_ind].cpu().detach().numpy()
seq_ids = test_batch[2][batch_ind].cpu().detach().numpy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def evaluate(asr_model, asr_onnx, labels_map, wer, qat):
input_signal=processed_signal,
input_signal_length=processed_signal_length,
)
hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
hypotheses += wer.decoding.ctc_decoder_predictions_tensor(greedy_predictions)[0]
for batch_ind in range(greedy_predictions.shape[0]):
seq_len = test_batch[3][batch_ind].cpu().detach().numpy()
seq_ids = test_batch[2][batch_ind].cpu().detach().numpy()
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
else:
log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len)

transcribed_texts = self._wer.ctc_decoder_predictions_tensor(
transcribed_texts, _ = self._wer.decoding.ctc_decoder_predictions_tensor(
predictions=log_probs, predictions_len=encoded_len, return_hypotheses=False,
)

Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def transcribe_partial_audio(
lg = logits[idx][: logits_len[idx]]
hypotheses.append(lg.cpu().numpy())
else:
current_hypotheses = asr_model._wer.ctc_decoder_predictions_tensor(
current_hypotheses, _ = asr_model._wer.decoding.ctc_decoder_predictions_tensor(
greedy_predictions, predictions_len=logits_len, return_hypotheses=return_hypotheses,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def autocast():
for batch_idx, probs in enumerate(all_probs):
preds = np.argmax(probs, axis=1)
preds_tensor = torch.tensor(preds, device='cpu').unsqueeze(0)
pred_text = asr_model._wer.ctc_decoder_predictions_tensor(preds_tensor)[0]
pred_text = asr_model._wer.decoding.ctc_decoder_predictions_tensor(preds_tensor)[0][0]

pred_split_w = pred_text.split()
target_split_w = target_transcripts[batch_idx].split()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
#!/usr/bin/env bash
# install Boost package
sudo apt-get install build-essential libboost-all-dev cmake zlib1g-dev libbz2-dev liblzma-dev
git clone https://github.com/NVIDIA/OpenSeq2Seq -b ctc-decoders
git clone https://github.com/NVIDIA/OpenSeq2Seq
cd OpenSeq2Seq
git checkout ctc-decoders
cd ..
mv OpenSeq2Seq/decoders .
rm -rf OpenSeq2Seq
cd decoders
./setup.sh
cd ..
cd ..
2 changes: 1 addition & 1 deletion tutorials/asr/ASR_with_NeMo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1090,7 +1090,7 @@
" logits = torch.from_numpy(alogits[0])\n",
" greedy_predictions = logits.argmax(dim=-1, keepdim=False)\n",
" wer = WER(vocabulary=quartznet.decoder.vocabulary, batch_dim_index=0, use_cer=False, ctc_decode=True)\n",
" hypotheses = wer.ctc_decoder_predictions_tensor(greedy_predictions)\n",
" hypotheses, _ = wer.decoding.ctc_decoder_predictions_tensor(greedy_predictions)\n",
" print(hypotheses)\n",
" break\n"
],
Expand Down

0 comments on commit 41f27a5

Please sign in to comment.