From 2f3bf495d40f8e1d2215f51eea18a047b3877dad Mon Sep 17 00:00:00 2001 From: Vladimir Bataev Date: Tue, 23 Jan 2024 23:05:08 +0400 Subject: [PATCH] Temporarily use the previous RNN-T decoding algorithm as default (#8226) Signed-off-by: Vladimir Bataev --- nemo/collections/asr/parts/submodules/rnnt_decoding.py | 2 +- .../asr/parts/submodules/rnnt_greedy_decoding.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index 3f4e0bc6eac05..5c474ee21f8fd 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -316,7 +316,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, confidence_method_cfg=self.confidence_method_cfg, - loop_labels=self.cfg.greedy.get('loop_labels', True), + loop_labels=self.cfg.greedy.get('loop_labels', False), ) else: self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer( diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 83fdad35f6de9..fafa18b631265 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -566,10 +566,10 @@ class GreedyBatchedRNNTInfer(_GreedyRNNTInfer): - 'lin' for using the linear mapping. - 'exp' for using exponential mapping with linear shift. loop_labels: Switching between decoding algorithms. Both algorithms produce equivalent results. 
- loop_labels=True (default) algorithm is faster (especially for large batches) but can use a bit more memory + loop_labels=True algorithm is faster (especially for large batches) but can use a bit more memory (negligible overhead compared to the amount of memory used by the encoder). - loop_labels=False is an implementation of a traditional decoding algorithm, which iterates over frames - (encoder output vectors), and in the inner loop, decodes labels for the current frame one by one, + loop_labels=False (default) is an implementation of a traditional decoding algorithm, which iterates over + frames (encoder output vectors), and in the inner loop, decodes labels for the current frame one by one, stopping when <blank> is found. loop_labels=True iterates over labels, on each step finding the next non-blank label (evaluating Joint multiple times in inner loop); It uses a minimal possible amount of calls @@ -586,7 +586,7 @@ def __init__( preserve_alignments: bool = False, preserve_frame_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, - loop_labels: bool = True, + loop_labels: bool = False, ): super().__init__( decoder_model=decoder_model, @@ -2421,7 +2421,7 @@ class GreedyBatchedRNNTInferConfig: preserve_alignments: bool = False preserve_frame_confidence: bool = False confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig()) - loop_labels: bool = True + loop_labels: bool = False def __post_init__(self): # OmegaConf.structured ensures that post_init check is always executed