Temporarily disable cuda graph based RNN-T greedy inference for r2.0.0rc1. (#9904)

For very rare input shapes, PyTorch may use a cooperative kernel for LSTM operations. Cooperative kernels do not work inside a CUDA graph conditional node until CUDA 12.6.

Unfortunately, CUDA 12.6 is not part of the 24.07 PyTorch container release, which this release of NeMo targets.

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
galv authored and web-flow committed Jul 26, 2024
1 parent 74c2caf commit 0bac4f0
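
As context for the limitation described in the commit message, the sketch below (not part of this commit) shows one way to detect whether the runtime meets the CUDA 12.6 requirement. It assumes torch.version.cuda reports the CUDA toolkit version the PyTorch build targets; the helper name is hypothetical.

import torch

def cuda_graph_conditional_nodes_supported(min_version: tuple = (12, 6)) -> bool:
    """Heuristic check: cooperative kernels inside CUDA graph conditional nodes
    require CUDA 12.6 or newer (hypothetical helper, not a NeMo API)."""
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    major, minor = (int(x) for x in torch.version.cuda.split(".")[:2])
    return (major, minor) >= min_version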
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
@@ -331,7 +331,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
                     preserve_frame_confidence=self.preserve_frame_confidence,
                     confidence_method_cfg=self.confidence_method_cfg,
                     loop_labels=self.cfg.greedy.get('loop_labels', True),
-                    use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True),
+                    use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False),
                 )
             else:
                 self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer(
@@ -347,7 +347,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
                     preserve_frame_confidence=self.preserve_frame_confidence,
                     include_duration_confidence=self.tdt_include_duration_confidence,
                     confidence_method_cfg=self.confidence_method_cfg,
-                    use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True),
+                    use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False),
                 )

             else:
6 changes: 3 additions & 3 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -591,7 +591,7 @@ def __init__(
         preserve_frame_confidence: bool = False,
         confidence_method_cfg: Optional[DictConfig] = None,
         loop_labels: bool = True,
-        use_cuda_graph_decoder: bool = True,
+        use_cuda_graph_decoder: bool = False,
     ):
         super().__init__(
             decoder_model=decoder_model,
@@ -2358,7 +2358,7 @@ class GreedyBatchedRNNTInferConfig:
     tdt_include_duration_confidence: bool = False
     confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())
     loop_labels: bool = True
-    use_cuda_graph_decoder: bool = True
+    use_cuda_graph_decoder: bool = False

     def __post_init__(self):
         # OmegaConf.structured ensures that post_init check is always executed
@@ -2709,7 +2709,7 @@ def __init__(
         preserve_frame_confidence: bool = False,
         include_duration_confidence: bool = False,
         confidence_method_cfg: Optional[DictConfig] = None,
-        use_cuda_graph_decoder: bool = True,
+        use_cuda_graph_decoder: bool = False,
     ):
         super().__init__(
             decoder_model=decoder_model,
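
Because the diffs above only flip the default, users on CUDA 12.6 or newer can still opt back in through the decoding config. The sketch below is illustrative and not part of this commit: the greedy.use_cuda_graph_decoder key comes from the diff itself, while the model class, checkpoint name, and change_decoding_strategy() call are assumptions about the NeMo ASR API.

from omegaconf import open_dict
import nemo.collections.asr as nemo_asr

# Hypothetical checkpoint name; substitute whichever RNN-T / TDT model you use.
asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
    "stt_en_fastconformer_transducer_large"
)

decoding_cfg = asr_model.cfg.decoding
with open_dict(decoding_cfg):
    # Key shown in the diff; its default becomes False after this commit.
    decoding_cfg.greedy.use_cuda_graph_decoder = True

# Assumed NeMo method for re-applying a modified decoding config.
asr_model.change_decoding_strategy(decoding_cfg)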
