Skip to content

Commit

Permalink
Fix ASR Buffered inference scripts (#5552)
Browse files Browse the repository at this point in the history
* Fix log calculation

Signed-off-by: smajumdar <titu1994@gmail.com>

* Fix log calculation

Signed-off-by: smajumdar <titu1994@gmail.com>

* Fix log check

Signed-off-by: smajumdar <titu1994@gmail.com>

* Deepcopy the hypothesis to prevent inplace corrections

Signed-off-by: smajumdar <titu1994@gmail.com>

* Deepcopy the hypothesis to prevent inplace corrections

Signed-off-by: smajumdar <titu1994@gmail.com>

* Revert changes

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add link to HF space to readme

Signed-off-by: smajumdar <titu1994@gmail.com>

Signed-off-by: smajumdar <titu1994@gmail.com>
Co-authored-by: fayejf <36722593+fayejf@users.noreply.github.com>
  • Loading branch information
titu1994 and fayejf authored Dec 6, 2022
1 parent 786a850 commit 2a61014
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Key Features
------------

* Speech processing
* `HuggingFace Space for Audio Transcription (File, Microphone and YouTube) <https://huggingface.co/spaces/smajumdar/nemo_multilingual_language_id>`_
* `Automatic Speech Recognition (ASR) <https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/intro.html>`_
* Supported models: Jasper, QuartzNet, CitriNet, Conformer-CTC, Conformer-Transducer, Squeezeformer-CTC, Squeezeformer-Transducer, ContextNet, LSTM-Transducer (RNNT), LSTM-CTC, ...
* Supports CTC and Transducer/RNNT losses/decoders
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/rnnt_wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp
# keep the original predictions, wrap with the number of repetitions per token and alignments
# this is done so that `rnnt_decoder_predictions_tensor()` can process this hypothesis
# in order to compute exact time stamps.
alignments = hypotheses_list[ind].alignments
alignments = copy.deepcopy(hypotheses_list[ind].alignments)
token_repetitions = [1] * len(alignments) # preserve number of repetitions per token
hypothesis = (prediction, alignments, token_repetitions)
else:
Expand Down
14 changes: 12 additions & 2 deletions nemo/collections/asr/parts/utils/streaming_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def __init__(self, asr_model, chunk_size, buffer_size):
'''

self.NORM_CONSTANT = 1e-5
if asr_model.cfg.preprocessor.log:
if hasattr(asr_model.preprocessor, 'log') and asr_model.preprocessor.log:
self.ZERO_LEVEL_SPEC_DB_VAL = -16.635 # Log-Melspectrogram value for zero signal
else:
self.ZERO_LEVEL_SPEC_DB_VAL = 0.0
Expand Down Expand Up @@ -576,7 +576,7 @@ def __init__(self, asr_model, frame_len=1.6, batch_size=4, total_buffer=4.0):
frame_overlap: duration of overlaps before and after current frame, seconds
offset: number of symbols to drop for smooth streaming
'''
if asr_model.cfg.preprocessor.log:
if hasattr(asr_model.preprocessor, 'log') and asr_model.preprocessor.log:
self.ZERO_LEVEL_SPEC_DB_VAL = -16.635 # Log-Melspectrogram value for zero signal
else:
self.ZERO_LEVEL_SPEC_DB_VAL = 0.0
Expand Down Expand Up @@ -963,6 +963,7 @@ def __init__(

self.all_alignments = [[] for _ in range(self.batch_size)]
self.all_preds = [[] for _ in range(self.batch_size)]
self.all_timestamps = [[] for _ in range(self.batch_size)]
self.previous_hypotheses = None
self.batch_index_map = {
idx: idx for idx in range(self.batch_size)
Expand Down Expand Up @@ -990,6 +991,7 @@ def reset(self):

self.all_alignments = [[] for _ in range(self.batch_size)]
self.all_preds = [[] for _ in range(self.batch_size)]
self.all_timestamps = [[] for _ in range(self.batch_size)]
self.previous_hypotheses = None
self.batch_index_map = {idx: idx for idx in range(self.batch_size)}

Expand Down Expand Up @@ -1110,6 +1112,14 @@ def _get_batch_preds(self):
if not has_signal_ended:
self.all_preds[global_index_key].append(pred.cpu().numpy())

timestamps = [hyp.timestep for hyp in best_hyp]
for idx, timestep in enumerate(timestamps):
global_index_key = new_batch_keys[idx] # get index of this sample in the global batch

has_signal_ended = self.frame_bufferer.signal_end[global_index_key]
if not has_signal_ended:
self.all_timestamps[global_index_key].append(timestep)

if self.stateful_decoding:
# State resetting is being done on sub-batch only, global index information is not being updated
reset_states = self.asr_model.decoder.initialize_state(encoded)
Expand Down

0 comments on commit 2a61014

Please sign in to comment.