diff --git a/README.rst b/README.rst index 43bb89139045..bf089101c198 100644 --- a/README.rst +++ b/README.rst @@ -61,6 +61,7 @@ Key Features ------------ * Speech processing + * `HuggingFace Space for Audio Transcription (File, Microphone and YouTube) `_ * `Automatic Speech Recognition (ASR) `_ * Supported models: Jasper, QuartzNet, CitriNet, Conformer-CTC, Conformer-Transducer, Squeezeformer-CTC, Squeezeformer-Transducer, ContextNet, LSTM-Transducer (RNNT), LSTM-CTC, ... * Supports CTC and Transducer/RNNT losses/decoders diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index 27ec9e43b897..476b5a43c663 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -445,7 +445,7 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp # keep the original predictions, wrap with the number of repetitions per token and alignments # this is done so that `rnnt_decoder_predictions_tensor()` can process this hypothesis # in order to compute exact time stamps. 
- alignments = hypotheses_list[ind].alignments + alignments = copy.deepcopy(hypotheses_list[ind].alignments) token_repetitions = [1] * len(alignments) # preserve number of repetitions per token hypothesis = (prediction, alignments, token_repetitions) else: diff --git a/nemo/collections/asr/parts/utils/streaming_utils.py b/nemo/collections/asr/parts/utils/streaming_utils.py index 784cf24207a0..2852b75fd9f4 100644 --- a/nemo/collections/asr/parts/utils/streaming_utils.py +++ b/nemo/collections/asr/parts/utils/streaming_utils.py @@ -362,7 +362,7 @@ def __init__(self, asr_model, chunk_size, buffer_size): ''' self.NORM_CONSTANT = 1e-5 - if asr_model.cfg.preprocessor.log: + if hasattr(asr_model.preprocessor, 'log') and asr_model.preprocessor.log: self.ZERO_LEVEL_SPEC_DB_VAL = -16.635 # Log-Melspectrogram value for zero signal else: self.ZERO_LEVEL_SPEC_DB_VAL = 0.0 @@ -576,7 +576,7 @@ def __init__(self, asr_model, frame_len=1.6, batch_size=4, total_buffer=4.0): frame_overlap: duration of overlaps before and after current frame, seconds offset: number of symbols to drop for smooth streaming ''' - if asr_model.cfg.preprocessor.log: + if hasattr(asr_model.preprocessor, 'log') and asr_model.preprocessor.log: self.ZERO_LEVEL_SPEC_DB_VAL = -16.635 # Log-Melspectrogram value for zero signal else: self.ZERO_LEVEL_SPEC_DB_VAL = 0.0 @@ -963,6 +963,7 @@ def __init__( self.all_alignments = [[] for _ in range(self.batch_size)] self.all_preds = [[] for _ in range(self.batch_size)] + self.all_timestamps = [[] for _ in range(self.batch_size)] self.previous_hypotheses = None self.batch_index_map = { idx: idx for idx in range(self.batch_size) @@ -990,6 +991,7 @@ def reset(self): self.all_alignments = [[] for _ in range(self.batch_size)] self.all_preds = [[] for _ in range(self.batch_size)] + self.all_timestamps = [[] for _ in range(self.batch_size)] self.previous_hypotheses = None self.batch_index_map = {idx: idx for idx in range(self.batch_size)} @@ -1110,6 +1112,14 @@ def 
_get_batch_preds(self): if not has_signal_ended: self.all_preds[global_index_key].append(pred.cpu().numpy()) + timestamps = [hyp.timestep for hyp in best_hyp] + for idx, timestep in enumerate(timestamps): + global_index_key = new_batch_keys[idx] # get index of this sample in the global batch + + has_signal_ended = self.frame_bufferer.signal_end[global_index_key] + if not has_signal_ended: + self.all_timestamps[global_index_key].append(timestep) + if self.stateful_decoding: # State resetting is being done on sub-batch only, global index information is not being updated reset_states = self.asr_model.decoder.initialize_state(encoded)