Skip to content

Commit

Permalink
Temporarily use the previous RNN-T decoding algorithm as default (NVI…
Browse files Browse the repository at this point in the history
…DIA#8226)

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
  • Loading branch information
artbataev authored Jan 23, 2024
1 parent be75c7c commit 2f3bf49
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/asr/parts/submodules/rnnt_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_method_cfg=self.confidence_method_cfg,
loop_labels=self.cfg.greedy.get('loop_labels', True),
loop_labels=self.cfg.greedy.get('loop_labels', False),
)
else:
self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer(
Expand Down
10 changes: 5 additions & 5 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,10 +566,10 @@ class GreedyBatchedRNNTInfer(_GreedyRNNTInfer):
- 'lin' for using the linear mapping.
- 'exp' for using exponential mapping with linear shift.
loop_labels: Switching between decoding algorithms. Both algorithms produce equivalent results.
loop_labels=True (default) algorithm is faster (especially for large batches) but can use a bit more memory
loop_labels=True algorithm is faster (especially for large batches) but can use a bit more memory
(negligible overhead compared to the amount of memory used by the encoder).
loop_labels=False is an implementation of a traditional decoding algorithm, which iterates over frames
(encoder output vectors), and in the inner loop, decodes labels for the current frame one by one,
loop_labels=False (default) is an implementation of a traditional decoding algorithm, which iterates over
frames (encoder output vectors), and in the inner loop, decodes labels for the current frame one by one,
stopping when <blank> is found.
loop_labels=True iterates over labels, on each step finding the next non-blank label
(evaluating Joint multiple times in inner loop); It uses a minimal possible amount of calls
Expand All @@ -586,7 +586,7 @@ def __init__(
preserve_alignments: bool = False,
preserve_frame_confidence: bool = False,
confidence_method_cfg: Optional[DictConfig] = None,
loop_labels: bool = True,
loop_labels: bool = False,
):
super().__init__(
decoder_model=decoder_model,
Expand Down Expand Up @@ -2421,7 +2421,7 @@ class GreedyBatchedRNNTInferConfig:
preserve_alignments: bool = False
preserve_frame_confidence: bool = False
confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())
loop_labels: bool = True
loop_labels: bool = False

def __post_init__(self):
# OmegaConf.structured ensures that post_init check is always executed
Expand Down

0 comments on commit 2f3bf49

Please sign in to comment.