From 0bac4f0fd582beb818706e2d7418107bd6f5199c Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Fri, 26 Jul 2024 06:29:19 -0700 Subject: [PATCH] Temporarily disable cuda graph based RNN-T greedy inference for (#9904) r2.0.0rc1. For very rare input shapes, a cooperative kernel might be used by pytorch for LSTM operations. This does not work within a cuda graph conditional node until CUDA 12.6. Unfortunately CUDA 12.6 is not part of the 24.07 pytorch container release, which this release of nemo is intended for. Signed-off-by: Daniel Galvez --- nemo/collections/asr/parts/submodules/rnnt_decoding.py | 4 ++-- .../asr/parts/submodules/rnnt_greedy_decoding.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index eb4088f84cae..2416d916ac13 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -331,7 +331,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): preserve_frame_confidence=self.preserve_frame_confidence, confidence_method_cfg=self.confidence_method_cfg, loop_labels=self.cfg.greedy.get('loop_labels', True), - use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True), + use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False), ) else: self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer( @@ -347,7 +347,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): preserve_frame_confidence=self.preserve_frame_confidence, include_duration_confidence=self.tdt_include_duration_confidence, confidence_method_cfg=self.confidence_method_cfg, - use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True), + use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False), ) else: diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 70ab74e7b014..7616912fe8d5 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -591,7 +591,7 @@ def __init__( preserve_frame_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, loop_labels: bool = True, - use_cuda_graph_decoder: bool = True, + use_cuda_graph_decoder: bool = False, ): super().__init__( decoder_model=decoder_model, @@ -2358,7 +2358,7 @@ class GreedyBatchedRNNTInferConfig: tdt_include_duration_confidence: bool = False confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig()) loop_labels: bool = True - use_cuda_graph_decoder: bool = True + use_cuda_graph_decoder: bool = False def __post_init__(self): # OmegaConf.structured ensures that post_init check is always executed @@ -2709,7 +2709,7 @@ def __init__( preserve_frame_confidence: bool = False, include_duration_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, - use_cuda_graph_decoder: bool = True, + use_cuda_graph_decoder: bool = False, ): super().__init__( decoder_model=decoder_model,