diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index e14424cec5c1..0539f961a1ca 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from abc import ABC, abstractmethod -from typing import List +from abc import ABC +from typing import List, Optional import torch +from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs from nemo.core.classes import ModelPT from nemo.core.classes.common import PretrainedModelInfo from nemo.core.classes.exportable import Exportable @@ -171,6 +172,52 @@ def on_after_backward(self): logging.warning(f'detected inf or nan values in gradients! Setting gradients to zero.') self.zero_grad() + def on_train_epoch_start(self) -> None: + """ + The decoder with CUDA graphs does not release memory, thus we disable it for the training epoch. + EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs + """ + WithOptionalCudaGraphs.disable_cuda_graphs_recursive(self, attribute_path="decoding.decoding") + + def on_train_epoch_end(self) -> None: + """ + After the training epoch, we can enable the decoder with CUDA graphs again. + EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs + """ + WithOptionalCudaGraphs.enable_cuda_graphs_recursive(self, attribute_path="decoding.decoding") + + def on_validation_epoch_start(self) -> None: + """ + For validation, we enable CUDA graphs to speed up validation. + EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs. + """ + WithOptionalCudaGraphs.enable_cuda_graphs_recursive(self, attribute_path="decoding.decoding") + + def on_validation_epoch_end(self) -> Optional[dict[str, dict[str, torch.Tensor]]]: + """ + After validation, we disable CUDA graphs, since `validation` can be called in the training loop, and + training will continue after validation. + EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs. + """ + WithOptionalCudaGraphs.disable_cuda_graphs_recursive(self, attribute_path="decoding.decoding") + return super().on_validation_epoch_end() + + def on_test_epoch_start(self) -> None: + """ + For testing, we enable CUDA graphs to speed up testing. + We do not need to disable CUDA graphs after testing, since `test` cannot be called in the training loop. + EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs. + """ + WithOptionalCudaGraphs.enable_cuda_graphs_recursive(self, attribute_path="decoding.decoding") + + def on_predict_epoch_start(self) -> None: + """ + For prediction, we enable CUDA graphs to speed up prediction. + We do not need to disable CUDA graphs after predicting, since `predict` cannot be called in the training loop.
+ EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs + """ + WithOptionalCudaGraphs.enable_cuda_graphs_recursive(self, attribute_path="decoding.decoding") + class ExportableEncDecModel(Exportable): """ diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index 055066c00660..2355cfb7005b 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -312,7 +312,9 @@ def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]: batch = y.size(0) # state contains context_size - 1 elements for each utterance in batch, # consistent with the state returned from StatelessNet.forward - state = [torch.ones([batch, self.context_size - 1], dtype=torch.long, device=y.device) * self.blank_idx] + state = [ + torch.full([batch, self.context_size - 1], fill_value=self.blank_idx, dtype=torch.long, device=y.device) + ] return state def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]): diff --git a/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py index 388737443fd4..93cef4d4138e 100644 --- a/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py @@ -292,14 +292,21 @@ def __call__( partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None, ): if partial_hypotheses is not None: - raise NotImplementedError("`partial_hypotheses` support is not available with cuda graphs (but could be)") + raise NotImplementedError( + "`partial_hypotheses` support is not available " + "with Frame-Looping algorithm with Cuda graphs (not implemented yet)" + ) if self.caller.preserve_alignments: - raise NotImplementedError("`preserve_alignments` support is not available with cuda graphs (but could be)") + raise NotImplementedError( + "`preserve_alignments` support is not available " + "with Frame-Looping algorithm with Cuda graphs (not implemented yet)" + ) if self.caller.preserve_frame_confidence: raise NotImplementedError( - "`preserve_frame_confidence` support is not available with cuda graphs (but could be)" + "`preserve_frame_confidence` support is not available " + "with Frame-Looping algorithm with Cuda graphs (not implemented yet)" ) batch_size = x.shape[0] diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index 71079f4b6382..5fa225864f8c 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -331,7 +331,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): preserve_frame_confidence=self.preserve_frame_confidence, confidence_method_cfg=self.confidence_method_cfg, loop_labels=self.cfg.greedy.get('loop_labels', True), - use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False), + use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True), ) else: self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer( @@ -347,7 +347,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): preserve_frame_confidence=self.preserve_frame_confidence, include_duration_confidence=self.tdt_include_duration_confidence, confidence_method_cfg=self.confidence_method_cfg, - use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False), +
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True), ) else: diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index e5de99cf0776..b2fa9b85b5fd 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -38,6 +38,7 @@ from nemo.collections.asr.parts.submodules.tdt_loop_labels_computer import GreedyBatchedTDTLoopLabelsComputer from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin +from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs from nemo.collections.common.parts.rnn import label_collate from nemo.core.classes import Typing, typecheck from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType @@ -508,7 +509,7 @@ def _greedy_decode( return hypothesis -class GreedyBatchedRNNTInfer(_GreedyRNNTInfer): +class GreedyBatchedRNNTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs): """A batch level greedy transducer decoder. Batch level greedy decoding, performed auto-regressively. @@ -589,7 +590,7 @@ def __init__( preserve_frame_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, loop_labels: bool = True, - use_cuda_graph_decoder: bool = False, + use_cuda_graph_decoder: bool = True, ): super().__init__( decoder_model=decoder_model, @@ -602,13 +603,14 @@ def __init__( ) self.use_cuda_graph_decoder = use_cuda_graph_decoder + self.loop_labels = loop_labels # Depending on availability of `blank_as_pad` support # switch between more efficient batch decoding technique self._decoding_computer = None if self.decoder.blank_as_pad: - if loop_labels: - # default (faster) algo: loop over labels + if self.loop_labels: + # Label-Looping algorithm (default, faster) self._greedy_decode = self._greedy_decode_blank_as_pad_loop_labels self._decoding_computer = GreedyBatchedRNNTLoopLabelsComputer( decoder=self.decoder, @@ -618,20 +620,74 @@ def __init__( preserve_alignments=preserve_alignments, preserve_frame_confidence=preserve_frame_confidence, confidence_method_cfg=confidence_method_cfg, - allow_cuda_graphs=use_cuda_graph_decoder, + allow_cuda_graphs=self.use_cuda_graph_decoder, ) - elif use_cuda_graph_decoder: - from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import ( - RNNTGreedyDecodeCudaGraph, - ) - - self._greedy_decode = RNNTGreedyDecodeCudaGraph(max_symbols_per_step, self) else: - # previous algo: loop over frames - self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames + # Frame-Looping algorithm + if not self.use_cuda_graph_decoder: + self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames + else: + if self.preserve_alignments: + logging.warning("`preserve_alignments` is not implemented for Frame-Looping + CUDA graphs") + self.use_cuda_graph_decoder = False + if self.preserve_frame_confidence: + logging.warning( + "`preserve_frame_confidence` is not implemented for Frame-Looping + CUDA graphs" + ) + self.use_cuda_graph_decoder = False + if not torch.cuda.is_available(): + self.use_cuda_graph_decoder = False + + if self.use_cuda_graph_decoder: + try: + from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import ( + RNNTGreedyDecodeCudaGraph, + ) + + self._greedy_decode = RNNTGreedyDecodeCudaGraph(max_symbols_per_step, self) + 
except (ImportError, ModuleNotFoundError, ValueError) as e: + self.use_cuda_graph_decoder = False + logging.warning(f"Cannot use decoder with CUDA graphs, reason: {e}") + self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames + else: + self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames else: self._greedy_decode = self._greedy_decode_masked + def disable_cuda_graphs(self): + """Disable CUDA graphs (e.g., for decoding in training)""" + if not self.use_cuda_graph_decoder: + # CUDA graphs not allowed, nothing to do + return + + if not self.decoder.blank_as_pad: + # decoding without blank-as-pad support does not use CUDA graphs + return + + if self.loop_labels: + # Label-Looping implementation + self._decoding_computer.disable_cuda_graphs() + else: + self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames + + def maybe_enable_cuda_graphs(self): + """Enable CUDA graphs (if allowed)""" + if not self.use_cuda_graph_decoder: + # CUDA graphs not allowed, nothing to do + return + + if not self.decoder.blank_as_pad: + # decoding without blank-as-pad support does not use CUDA graphs + return + + if self.loop_labels: + # Label-Looping implementation + self._decoding_computer.maybe_enable_cuda_graphs() + else: + from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import RNNTGreedyDecodeCudaGraph + + self._greedy_decode = RNNTGreedyDecodeCudaGraph(self.max_symbols, self) + @typecheck() def forward( self, @@ -2302,7 +2358,7 @@ class GreedyBatchedRNNTInferConfig: tdt_include_duration_confidence: bool = False confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig()) loop_labels: bool = True - use_cuda_graph_decoder: bool = False + use_cuda_graph_decoder: bool = True def __post_init__(self): # OmegaConf.structured ensures that post_init check is always executed @@ -2580,7 +2636,7 @@ def _greedy_decode( return hypothesis -class GreedyBatchedTDTInfer(_GreedyRNNTInfer): +class GreedyBatchedTDTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs): """A batch level greedy TDT decoder. Batch level greedy decoding, performed auto-regressively. Args: @@ -2652,7 +2708,7 @@ def __init__( preserve_frame_confidence: bool = False, include_duration_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, - use_cuda_graph_decoder: bool = False, + use_cuda_graph_decoder: bool = True, ): super().__init__( decoder_model=decoder_model, @@ -2759,3 +2815,13 @@ def _greedy_decode_blank_as_pad_loop_labels( for hyp, state in zip(hyps, self.decoder.batch_split_states(last_decoder_state)): hyp.dec_state = state return hyps + + def disable_cuda_graphs(self): + """Disable CUDA graphs (e.g., for decoding in training)""" + if self._decoding_computer is not None: + self._decoding_computer.disable_cuda_graphs() + + def maybe_enable_cuda_graphs(self): + """Enable CUDA graphs (if allowed)""" + if self._decoding_computer is not None: + self._decoding_computer.maybe_enable_cuda_graphs() diff --git a/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py index 92cb8a36aeb5..b920dba09cfd 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License.
-from typing import Any, Optional, Tuple +from dataclasses import dataclass, field +from typing import Any, Optional, Tuple, Union import numpy as np import torch @@ -21,6 +22,7 @@ from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin +from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs from nemo.core.utils.cuda_python_utils import ( check_cuda_python_cuda_graphs_conditional_nodes_supported, cu_call, @@ -28,6 +30,7 @@ with_conditional_node, ) from nemo.utils import logging +from nemo.utils.enum import PrettyStrEnum try: from cuda import cudart @@ -161,7 +164,17 @@ def need_reinit(self, encoder_output_projected: torch.Tensor) -> bool: ) -class GreedyBatchedRNNTLoopLabelsComputer(ConfidenceMethodMixin): +@dataclass +class SeparateGraphsLoopLabels: + """Class to store Cuda graphs for decoding when separate graphs are used""" + + before_outer_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + before_inner_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + inner_loop_code: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + after_inner_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + + +class GreedyBatchedRNNTLoopLabelsComputer(WithOptionalCudaGraphs, ConfidenceMethodMixin): """ Label Looping algorithm implementation: optimized batched greedy decoding. Callable. Iterates over labels, on each step finding the next non-blank label @@ -174,6 +187,16 @@ class GreedyBatchedRNNTLoopLabelsComputer(ConfidenceMethodMixin): INITIAL_MAX_TIME = 375 # initial max time, used to init state for Cuda graphs CUDA_PROGRAM_NAME = b"while_loop_labels_conditional_rnnt.cu" + class CudaGraphsMode(PrettyStrEnum): + FULL_GRAPH = "full_graph" # Cuda graphs with conditional nodes, fastest implementation + NO_WHILE_LOOPS = "no_while_loops" # Decoding with PyTorch while loops + partial Cuda graphs + NO_GRAPHS = "no_graphs" # decoding without graphs, stateful implementation, only for testing purposes + + separate_graphs: Optional[SeparateGraphsLoopLabels] + full_graph: Optional[torch.cuda.CUDAGraph] + cuda_graphs_mode: Optional[CudaGraphsMode] + state: Optional[LoopLabelsState] + def __init__( self, decoder, @@ -203,24 +226,66 @@ def __init__( self.max_symbols = max_symbols_per_step self.preserve_alignments = preserve_alignments self.preserve_frame_confidence = preserve_frame_confidence + self.allow_cuda_graphs = allow_cuda_graphs self._SOS = self._blank_index self._init_confidence_method(confidence_method_cfg=confidence_method_cfg) assert self._SOS == self._blank_index # "blank as pad" algorithm only - self.use_cuda_graphs = allow_cuda_graphs + self.state = None + self.full_graph = None + self.separate_graphs = None - if self.use_cuda_graphs and self.max_symbols is None: - logging.warning("Max symbols is None, which is not allowed with Cuda graphs.") - self.use_cuda_graphs = False + self.cuda_graphs_mode = None + self.maybe_enable_cuda_graphs() - if self.use_cuda_graphs: + def force_cuda_graphs_mode(self, mode: Optional[Union[str, CudaGraphsMode]]): + """ + Method to set graphs mode. Use only for testing purposes. + For debugging the algorithm use "no_graphs" mode, since it is impossible to debug CUDA graphs directly. 
+ """ + self.cuda_graphs_mode = self.CudaGraphsMode(mode) if mode is not None else None + self.state = None + + def maybe_enable_cuda_graphs(self): + """Enable CUDA graphs if conditions met""" + if self.cuda_graphs_mode is not None: + # CUDA graphs are already enabled + return + + if not self.allow_cuda_graphs: + self.cuda_graphs_mode = None + else: + # cuda graphs are allowed + # check basic requirements for cuda graphs + if self.max_symbols is None: + logging.warning("Max symbols per step is None, which is not allowed with Cuda graphs. Setting to `10`") + self.max_symbols = 10 + # basic requirements met, need to check while loops try: check_cuda_python_cuda_graphs_conditional_nodes_supported() - except ImportError as e: - logging.warning(f"No conditional node support. Cuda graphs will be disabled,\n{e.msg}") - self.use_cuda_graphs = False - - self.state: Optional[LoopLabelsState] = None + self.cuda_graphs_mode = self.CudaGraphsMode.FULL_GRAPH + except (ImportError, ModuleNotFoundError) as e: + logging.warning( + "No conditional node support for Cuda.\n" + "Cuda graphs with while loops are disabled, decoding speed will be slower\n" + f"Reason: {e.msg}" + ) + self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS + self.reset_cuda_graphs_state() + + def disable_cuda_graphs(self): + """Disable CUDA graphs, can be used to disable graphs temporary, e.g., in training process""" + if self.cuda_graphs_mode is None: + # nothing to disable + return + self.cuda_graphs_mode = None + self.reset_cuda_graphs_state() + + def reset_cuda_graphs_state(self): + """Reset state to release memory (for CUDA graphs implementations)""" + self.state = None + self.full_graph = None + self.separate_graphs = None def loop_labels_torch( self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, @@ -237,6 +302,7 @@ def loop_labels_torch( # do not recalculate joint projection, project only once encoder_output_projected = self.joint.project_encoder(encoder_output) + float_dtype = encoder_output_projected.dtype # init output structures: BatchedHyps (for results), BatchedAlignments + last decoder state # init empty batched hypotheses @@ -244,7 +310,7 @@ def loop_labels_torch( batch_size=batch_size, init_length=max_time * self.max_symbols if self.max_symbols is not None else max_time, device=device, - float_dtype=encoder_output_projected.dtype, + float_dtype=float_dtype, ) # sample state, will be replaced further when the decoding for hypothesis is done last_decoder_state = self.decoder.initialize_state(encoder_output_projected) @@ -256,7 +322,7 @@ def loop_labels_torch( logits_dim=self.joint.num_classes_with_blank, init_length=max_time * 2 if use_alignments else 1, # blank for each timestep + text tokens device=device, - float_dtype=encoder_output_projected.dtype, + float_dtype=float_dtype, store_alignments=self.preserve_alignments, store_frame_confidence=self.preserve_frame_confidence, ) @@ -312,7 +378,7 @@ def loop_labels_torch( time_indices=time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)) + confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) if self.preserve_frame_confidence else None, ) @@ -350,7 +416,7 @@ def loop_labels_torch( time_indices=time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=more_labels if self.preserve_alignments else None, - 
confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)) + confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) if self.preserve_frame_confidence else None, ) @@ -413,6 +479,8 @@ def loop_labels_cuda_graphs( encoder_output: output from the encoder encoder_output_length: lengths of the utterances in `encoder_output` """ + assert self.cuda_graphs_mode is not None + # do not recalculate joint projection, project only once encoder_output = self.joint.project_encoder(encoder_output) current_batch_size = encoder_output.shape[0] @@ -430,16 +498,27 @@ def loop_labels_cuda_graphs( self.state.encoder_output_length[: encoder_output_length.shape[0]].copy_(encoder_output_length) # set length to zero for elements outside the current batch self.state.encoder_output_length[current_batch_size:].fill_(0) - self.graph.replay() - - # example manual loop (can be used instead of graph.replay()) - # self._before_outer_loop() - # while self.state.active_mask_any.item(): - # self._before_inner_loop_get_decoder_output() - # self._before_inner_loop_get_joint_output() - # while self.state.advance_mask_any.item(): - # self._inner_loop_code() - # self._after_inner_loop() + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: + self.full_graph.replay() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: + self.separate_graphs.before_outer_loop.replay() + while self.state.active_mask_any.item(): + self.separate_graphs.before_inner_loop.replay() + while self.state.advance_mask_any.item(): + self.separate_graphs.inner_loop_code.replay() + self.separate_graphs.after_inner_loop.replay() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS: + # this mode is only for testing purposes + # manual loop instead of using graphs + self._before_outer_loop() + while self.state.active_mask_any.item(): + self._before_inner_loop_get_decoder_output() + self._before_inner_loop_get_joint_output() + while self.state.advance_mask_any.item(): + self._inner_loop_code() + self._after_inner_loop() + else: + raise NotImplementedError(f"Unknown graph mode: {self.cuda_graphs_mode}") return ( self.state.batched_hyps, @@ -509,12 +588,49 @@ def _graph_reinitialize( ) # to avoid recalculation of joint projection, store decoder output in state self.state.decoder_output = self.joint.project_prednet(decoder_output) + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: + self._full_graph_compile() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: + self._partial_graphs_compile() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS: + # no graphs needed + pass + else: + raise NotImplementedError + + def _partial_graphs_compile(self): + """Compile decoding by parts""" + # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. 
+ stream_for_graph = torch.cuda.Stream(self.state.device) + self.separate_graphs = SeparateGraphsLoopLabels() + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.before_outer_loop, stream=stream_for_graph + ): + self._before_outer_loop() + + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.before_inner_loop, stream=stream_for_graph + ): + self._before_inner_loop_get_decoder_output() + self._before_inner_loop_get_joint_output() + + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.inner_loop_code, stream=stream_for_graph + ): + self._inner_loop_code() + + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.after_inner_loop, stream=stream_for_graph + ): + self._after_inner_loop() + def _full_graph_compile(self): + """Compile full graph for decoding""" # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. stream_for_graph = torch.cuda.Stream(self.state.device) - self.graph = torch.cuda.CUDAGraph() + self.full_graph = torch.cuda.CUDAGraph() with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( - self.graph, stream=stream_for_graph + self.full_graph, stream=stream_for_graph ): self._before_outer_loop() @@ -612,12 +728,13 @@ def _before_inner_loop_get_joint_output(self): # blank_mask = self.labels == self._blank_index self.state.time_indices_current_labels.copy_(self.state.time_indices, non_blocking=True) if self.state.alignments is not None: + float_dtype = self.state.float_dtype self.state.alignments.add_results_masked_no_checks_( active_mask=self.state.active_mask, time_indices=self.state.time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=self.state.labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)) + confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) if self.preserve_frame_confidence else None, ) @@ -662,12 +779,13 @@ def _inner_loop_code(self): torch.where(self.state.advance_mask, more_scores, self.state.scores, out=self.state.scores) if self.state.alignments is not None: + float_dtype = self.state.float_dtype self.state.alignments.add_results_masked_no_checks_( active_mask=self.state.advance_mask, time_indices=self.state.time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=more_labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)) + confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype) if self.preserve_frame_confidence else None, ) @@ -721,7 +839,7 @@ def _after_inner_loop(self): def __call__( self, x: torch.Tensor, out_len: torch.Tensor, ) -> Tuple[rnnt_utils.BatchedHyps, Optional[rnnt_utils.BatchedAlignments], Any]: - if self.use_cuda_graphs and x.device.type == "cuda": + if self.cuda_graphs_mode is not None and x.device.type == "cuda": return self.loop_labels_cuda_graphs(encoder_output=x, encoder_output_length=out_len) return self.loop_labels_torch(encoder_output=x, encoder_output_length=out_len) diff --git a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py index b136446d97fb..4e514966db2b 100644 --- 
a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py @@ -13,6 +13,7 @@ # limitations under the License. +from dataclasses import dataclass, field from typing import Any, Optional, Tuple, Union import numpy as np @@ -22,6 +23,7 @@ from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin +from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs from nemo.core.utils.cuda_python_utils import ( check_cuda_python_cuda_graphs_conditional_nodes_supported, cu_call, @@ -29,6 +31,7 @@ with_conditional_node, ) from nemo.utils import logging +from nemo.utils.enum import PrettyStrEnum try: from cuda import cudart @@ -167,7 +170,17 @@ def need_reinit(self, encoder_output_projected: torch.Tensor) -> bool: ) -class GreedyBatchedTDTLoopLabelsComputer(ConfidenceMethodMixin): +@dataclass +class SeparateGraphsLoopLabels: + """Class to store Cuda graphs for decoding when separate graphs are used""" + + before_outer_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + before_inner_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + inner_loop_code: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + after_inner_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + + +class GreedyBatchedTDTLoopLabelsComputer(WithOptionalCudaGraphs, ConfidenceMethodMixin): """ Label Looping algorithm implementation: optimized batched greedy decoding. Callable. Iterates over labels, on each step finding the next non-blank label @@ -180,6 +193,16 @@ class GreedyBatchedTDTLoopLabelsComputer(ConfidenceMethodMixin): INITIAL_MAX_TIME = 375 # initial max time, used to init state for Cuda graphs CUDA_PROGRAM_NAME = b"while_loop_labels_conditional_tdt.cu" + class CudaGraphsMode(PrettyStrEnum): + FULL_GRAPH = "full_graph" # Cuda graphs with conditional nodes, fastest implementation + NO_WHILE_LOOPS = "no_while_loops" # Decoding with PyTorch while loops + partial Cuda graphs + NO_GRAPHS = "no_graphs" # decoding without graphs, stateful implementation, only for testing purposes + + separate_graphs: Optional[SeparateGraphsLoopLabels] + full_graph: Optional[torch.cuda.CUDAGraph] + cuda_graphs_mode: Optional[CudaGraphsMode] + state: Optional[LoopLabelsState] + def __init__( self, decoder, @@ -215,25 +238,67 @@ def __init__( self.max_symbols = max_symbols_per_step self.preserve_alignments = preserve_alignments self.preserve_frame_confidence = preserve_frame_confidence + self.allow_cuda_graphs = allow_cuda_graphs self.include_duration_confidence = include_duration_confidence self._SOS = self._blank_index self._init_confidence_method(confidence_method_cfg=confidence_method_cfg) assert self._SOS == self._blank_index # "blank as pad" algorithm only - self.use_cuda_graphs = allow_cuda_graphs + self.state = None + self.full_graph = None + self.separate_graphs = None - if self.use_cuda_graphs and self.max_symbols is None: - logging.warning("Max symbols is None, which is not allowed with Cuda graphs.") - self.use_cuda_graphs = False + self.cuda_graphs_mode = None + self.maybe_enable_cuda_graphs() - if self.use_cuda_graphs: + def maybe_enable_cuda_graphs(self): + """Enable CUDA graphs if conditions met""" + if self.cuda_graphs_mode is not None: + # CUDA graphs are enabled + return + + if not self.allow_cuda_graphs: + self.cuda_graphs_mode = None + else: + # cuda graphs are 
allowed + # check basic requirements for cuda graphs + if self.max_symbols is None: + logging.warning("Max symbols per step is None, which is not allowed with Cuda graphs. Setting to `10`") + self.max_symbols = 10 + # basic requirements met, need to check while loops try: check_cuda_python_cuda_graphs_conditional_nodes_supported() - except ImportError as e: - logging.warning(f"No conditional node support. Cuda graphs will be disabled,\n{e.msg}") - self.use_cuda_graphs = False - - self.state: Optional[LoopLabelsState] = None + self.cuda_graphs_mode = self.CudaGraphsMode.FULL_GRAPH + except (ImportError, ModuleNotFoundError) as e: + logging.warning( + "No conditional node support for Cuda.\n" + "Cuda graphs with while loops are disabled, decoding will be slower\n" + f"Reason: {e.msg}" + ) + self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS + self.reset_cuda_graphs_state() + + def disable_cuda_graphs(self): + """Disable CUDA graphs, can be used to disable graphs temporarily, e.g., during training""" + if self.cuda_graphs_mode is None: + # nothing to disable + return + self.cuda_graphs_mode = None + self.reset_cuda_graphs_state() + + def reset_cuda_graphs_state(self): + """Reset state to release memory (for CUDA graphs implementations)""" + self.state = None + self.full_graph = None + self.separate_graphs = None + + def force_cuda_graphs_mode(self, mode: Optional[Union[str, CudaGraphsMode]]): + """ + Method to set graphs mode. Use only for testing purposes. + For debugging the algorithm use "no_graphs" mode, since it is impossible to debug CUDA graphs directly. + """ + self.cuda_graphs_mode = self.CudaGraphsMode(mode) if mode is not None else None + self.state = None def loop_labels_torch( self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, @@ -250,7 +315,7 @@ def loop_labels_torch( # do not recalculate joint projection, project only once encoder_output_projected = self.joint.project_encoder(encoder_output) - dtype = encoder_output_projected.dtype + float_dtype = encoder_output_projected.dtype # init output structures: BatchedHyps (for results), BatchedAlignments + last decoder state # init empty batched hypotheses @@ -258,7 +323,7 @@ def loop_labels_torch( batch_size=batch_size, init_length=max_time * self.max_symbols if self.max_symbols is not None else max_time, device=device, - float_dtype=dtype, + float_dtype=float_dtype, ) # sample state, will be replaced further when the decoding for hypothesis is done last_decoder_state = self.decoder.initialize_state(encoder_output_projected) @@ -270,7 +335,7 @@ def loop_labels_torch( logits_dim=self.joint.num_classes_with_blank, init_length=max_time * 2 if use_alignments else 1, # blank for each timestep + text tokens device=device, - float_dtype=dtype, + float_dtype=float_dtype, store_alignments=self.preserve_alignments, store_frame_confidence=self.preserve_frame_confidence, with_duration_confidence=self.include_duration_confidence, ) @@ -338,16 +403,18 @@ def loop_labels_torch( confidence=torch.stack( ( self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( - dtype=dtype + dtype=float_dtype ), self._get_confidence_tensor(F.log_softmax(logits[:, -num_durations:], dim=-1)).to( - dtype=dtype + dtype=float_dtype ), ), dim=-1, ) if self.include_duration_confidence - else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to(dtype=dtype) + else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( + dtype=float_dtype + ) if
self.preserve_frame_confidence else None, ) @@ -390,17 +457,17 @@ def loop_labels_torch( confidence=torch.stack( ( self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( - dtype=dtype + dtype=float_dtype ), self._get_confidence_tensor(F.log_softmax(logits[:, -num_durations:], dim=-1)).to( - dtype=dtype + dtype=float_dtype ), ), dim=-1, ) if self.include_duration_confidence else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( - dtype=dtype + dtype=float_dtype ) if self.preserve_frame_confidence else None, @@ -467,6 +534,8 @@ def loop_labels_cuda_graphs( encoder_output: output from the encoder encoder_output_length: lengths of the utterances in `encoder_output` """ + assert self.cuda_graphs_mode is not None + # do not recalculate joint projection, project only once encoder_output = self.joint.project_encoder(encoder_output) current_batch_size = encoder_output.shape[0] @@ -484,16 +553,27 @@ def loop_labels_cuda_graphs( self.state.encoder_output_length[: encoder_output_length.shape[0]].copy_(encoder_output_length) # set length to zero for elements outside the current batch self.state.encoder_output_length[current_batch_size:].fill_(0) - self.graph.replay() - - # example manual loop (can be used instead of graph.replay()) - # self._before_outer_loop() - # while self.state.active_mask_any.item(): - # self._before_inner_loop_get_decoder_output() - # self._before_inner_loop_get_joint_output() - # while self.state.advance_mask_any.item(): - # self._inner_loop_code() - # self._after_inner_loop() + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: + self.full_graph.replay() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: + self.separate_graphs.before_outer_loop.replay() + while self.state.active_mask_any.item(): + self.separate_graphs.before_inner_loop.replay() + while self.state.advance_mask_any.item(): + self.separate_graphs.inner_loop_code.replay() + self.separate_graphs.after_inner_loop.replay() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS: + # this mode is only for testing purposes + # manual loop instead of using graphs + self._before_outer_loop() + while self.state.active_mask_any.item(): + self._before_inner_loop_get_decoder_output() + self._before_inner_loop_get_joint_output() + while self.state.advance_mask_any.item(): + self._inner_loop_code() + self._after_inner_loop() + else: + raise NotImplementedError(f"Unknown graph mode: {self.cuda_graphs_mode}") return ( self.state.batched_hyps, @@ -565,12 +645,49 @@ def _graph_reinitialize( ) # to avoid recalculation of joint projection, store decoder output in state self.state.decoder_output = self.joint.project_prednet(decoder_output) + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: + self._full_graph_compile() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: + self._partial_graphs_compile() + elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS: + # no graphs needed + pass + else: + raise NotImplementedError + + def _partial_graphs_compile(self): + """Compile decoding by parts""" + # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. 
+ stream_for_graph = torch.cuda.Stream(self.state.device) + self.separate_graphs = SeparateGraphsLoopLabels() + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.before_outer_loop, stream=stream_for_graph + ): + self._before_outer_loop() + + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.before_inner_loop, stream=stream_for_graph + ): + self._before_inner_loop_get_decoder_output() + self._before_inner_loop_get_joint_output() + + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.inner_loop_code, stream=stream_for_graph + ): + self._inner_loop_code() + + with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( + self.separate_graphs.after_inner_loop, stream=stream_for_graph + ): + self._after_inner_loop() + def _full_graph_compile(self): + """Compile full graph for decoding""" # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. stream_for_graph = torch.cuda.Stream(self.state.device) - self.graph = torch.cuda.CUDAGraph() + self.full_graph = torch.cuda.CUDAGraph() with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( - self.graph, stream=stream_for_graph + self.full_graph, stream=stream_for_graph ): self._before_outer_loop() @@ -651,7 +768,6 @@ def _before_inner_loop_get_joint_output(self): # stage 2: get joint output, iteratively seeking for non-blank labels # blank label in `labels` tensor means "end of hypothesis" (for this index) self.state.active_mask_prev.copy_(self.state.active_mask, non_blocking=True) - dtype = self.state.encoder_output_projected.dtype logits = ( self.joint.joint_after_projection( self.state.encoder_output_projected[self.state.batch_indices, self.state.safe_time_indices].unsqueeze( @@ -675,6 +791,7 @@ def _before_inner_loop_get_joint_output(self): # for blank labels force duration >= 1 durations.masked_fill_(torch.logical_and(durations == 0, self.state.blank_mask), 1) if self.state.alignments is not None: + float_dtype = self.state.float_dtype self.state.alignments.add_results_masked_no_checks_( active_mask=self.state.active_mask, time_indices=self.state.time_indices_current_labels, @@ -684,17 +801,17 @@ def _before_inner_loop_get_joint_output(self): ( self._get_confidence_tensor( F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) - ).to(dtype=dtype), + ).to(dtype=float_dtype), self._get_confidence_tensor( F.log_softmax(logits[:, -self.state.all_durations.shape[0] :], dim=-1) - ).to(dtype=dtype), + ).to(dtype=float_dtype), ), dim=-1, ) if self.include_duration_confidence else self._get_confidence_tensor( F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) - ).to(dtype=dtype) + ).to(dtype=float_dtype) if self.preserve_frame_confidence else None, ) @@ -720,7 +837,6 @@ def _inner_loop_code(self): self.state.time_indices_current_labels, out=self.state.time_indices_current_labels, ) - dtype = self.state.encoder_output_projected.dtype logits = ( self.joint.joint_after_projection( self.state.encoder_output_projected[self.state.batch_indices, self.state.safe_time_indices].unsqueeze( @@ -742,6 +858,7 @@ def _inner_loop_code(self): torch.where(self.state.advance_mask, more_scores, self.state.scores, out=self.state.scores) if self.state.alignments is not None: + float_dtype = self.state.float_dtype self.state.alignments.add_results_masked_no_checks_( 
active_mask=self.state.advance_mask, time_indices=self.state.time_indices_current_labels, @@ -751,17 +868,17 @@ def _inner_loop_code(self): ( self._get_confidence_tensor( F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) - ).to(dtype=dtype), + ).to(dtype=float_dtype), self._get_confidence_tensor( F.log_softmax(logits[:, -self.state.all_durations.shape[0] :], dim=-1) - ).to(dtype=dtype), + ).to(dtype=float_dtype), ), dim=-1, ) if self.include_duration_confidence else self._get_confidence_tensor( F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) - ).to(dtype=dtype) + ).to(dtype=float_dtype) if self.preserve_frame_confidence else None, ) @@ -822,7 +939,7 @@ def _after_inner_loop(self): def __call__( self, x: torch.Tensor, out_len: torch.Tensor, ) -> Tuple[rnnt_utils.BatchedHyps, Optional[rnnt_utils.BatchedAlignments], Any]: - if self.use_cuda_graphs and x.device.type == "cuda": + if self.cuda_graphs_mode is not None and x.device.type == "cuda": return self.loop_labels_cuda_graphs(encoder_output=x, encoder_output_length=out_len) return self.loop_labels_torch(encoder_output=x, encoder_output_length=out_len) diff --git a/nemo/collections/common/parts/optional_cuda_graphs.py b/nemo/collections/common/parts/optional_cuda_graphs.py new file mode 100644 index 000000000000..2417d9e00370 --- /dev/null +++ b/nemo/collections/common/parts/optional_cuda_graphs.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Optional + +import torch.nn as nn + +from nemo.utils import logging + + +class WithOptionalCudaGraphs(abc.ABC): + """ + Abstract interface for modules with CUDA graphs. + Allows enabling/disabling CUDA graphs on the fly. + """ + + @classmethod + def disable_cuda_graphs_recursive(cls, module: nn.Module, attribute_path: Optional[str] = None): + """ + Disable CUDA graphs, finding the submodule recursively. + + Args: + module: instance of nn.Module + attribute_path: field containing an instance of WithOptionalCudaGraphs. + E.g., "decoding.decoding" means that the `.decoding.decoding` attribute of each submodule is checked. + If None, the submodule itself is checked.
+ """ + attributes = attribute_path.split(".") if attribute_path else [] + + for name, submodule in module.named_modules(): + object_to_check = submodule + try: + # recursively get attribute by iterating attribute_path + for attribute in attributes: + object_to_check = getattr(object_to_check, attribute) + except AttributeError: + continue # loop over modules, no attribute + + if isinstance(object_to_check, cls): + object_to_check.disable_cuda_graphs() + logging.info(f"Disabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes)) + + @classmethod + def enable_cuda_graphs_recursive(cls, module: nn.Module, attribute_path: Optional[str] = None): + """ + Enable CUDA graphs, finding submodule recursively + + Args: + module: instance of nn.Module + attribute_path: field containing instance of WithOptionalCudaGraphs + E.g., "decoding.decoding" means that ".decoding.decoding" are checked. + If None, "" is checked. + """ + attributes = attribute_path.split(".") if attribute_path else [] + + for name, submodule in module.named_modules(): + object_to_check = submodule + try: + # recursively get attribute by iterating attribute_path + for attribute in attributes: + object_to_check = getattr(object_to_check, attribute) + except AttributeError: + continue # loop over modules, no attribute + + if isinstance(object_to_check, cls): + object_to_check.maybe_enable_cuda_graphs() + logging.info(f"Enabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes)) + + @abc.abstractmethod + def disable_cuda_graphs(self): + """Disable (maybe temporary) CUDA graphs""" + raise NotImplementedError + + @abc.abstractmethod + def maybe_enable_cuda_graphs(self): + """Enable CUDA graphs if all conditions met""" + raise NotImplementedError diff --git a/nemo/core/utils/cuda_python_utils.py b/nemo/core/utils/cuda_python_utils.py index fb47c22ceee0..eb8897df0797 100644 --- a/nemo/core/utils/cuda_python_utils.py +++ b/nemo/core/utils/cuda_python_utils.py @@ -25,7 +25,7 @@ def check_cuda_python_cuda_graphs_conditional_nodes_supported(): try: from cuda import cuda except ImportError: - raise ModuleNotFoundError("Please do `pip install cuda-python>=12.3`") + raise ModuleNotFoundError("No `cuda-python` module. 
Please do `pip install cuda-python>=12.3`") from cuda import __version__ as cuda_python_version diff --git a/tests/collections/asr/decoding/rnnt_alignments_check.py b/tests/collections/asr/decoding/rnnt_alignments_check.py index aa4d5f044de1..d44f7f8fd985 100644 --- a/tests/collections/asr/decoding/rnnt_alignments_check.py +++ b/tests/collections/asr/decoding/rnnt_alignments_check.py @@ -28,13 +28,14 @@ PRETRAINED_MODEL_NAME = "stt_en_conformer_transducer_small" -def get_rnnt_alignments(strategy: str, loop_labels: bool = True, location="cuda"): +def get_rnnt_alignments(strategy: str, loop_labels: bool = True, use_cuda_graph_decoder=False, location="cuda"): cfg = OmegaConf.structured(TranscriptionConfig(pretrained_name=PRETRAINED_MODEL_NAME)) cfg.rnnt_decoding.confidence_cfg.preserve_frame_confidence = True cfg.rnnt_decoding.preserve_alignments = True cfg.rnnt_decoding.strategy = strategy if cfg.rnnt_decoding.strategy == "greedy_batch": cfg.rnnt_decoding.greedy.loop_labels = loop_labels + cfg.rnnt_decoding.greedy.use_cuda_graph_decoder = use_cuda_graph_decoder cfg.dataset_manifest = TEST_DATA_PATH filepaths = prepare_audio_data(cfg)[0][:10] # selecting 10 files only @@ -73,10 +74,15 @@ def cleanup_local_folder(): # TODO: add the same tests for multi-blank RNNT decoding @pytest.mark.skipif(not os.path.exists('/home/TestData'), reason='Not a Jenkins machine') @pytest.mark.parametrize("loop_labels", [True, False]) -def test_rnnt_alignments(loop_labels: bool): +@pytest.mark.parametrize("use_cuda_graph_decoder", [True, False]) +def test_rnnt_alignments(loop_labels: bool, use_cuda_graph_decoder: bool): + if not loop_labels and use_cuda_graph_decoder: + pytest.skip("Frame-Looping algorithm with CUDA graphs does not yet support alignments") # using greedy as baseline and comparing all other configurations to it ref_transcriptions = get_rnnt_alignments("greedy") - transcriptions = get_rnnt_alignments("greedy_batch", loop_labels=loop_labels) + transcriptions = get_rnnt_alignments( + "greedy_batch", loop_labels=loop_labels, use_cuda_graph_decoder=use_cuda_graph_decoder + ) # comparing that label sequence in alignments is exactly the same # we can't compare logits as well, because they are expected to be # slightly different in batched and single-sample mode diff --git a/tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py b/tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py index 538ff9d71cf1..31fe822573ce 100644 --- a/tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py +++ b/tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py @@ -11,19 +11,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import copy import glob -import tempfile import jiwer import pytest import torch -from omegaconf import OmegaConf, open_dict +from omegaconf import open_dict from nemo.collections.asr.models import ASRModel from nemo.core.utils.cuda_python_utils import skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported +@pytest.fixture(scope="module") +def stt_en_fastconformer_transducer_xlarge(): + model_name = "stt_en_fastconformer_transducer_xlarge" + return ASRModel.from_pretrained(model_name, map_location="cpu") + + +@pytest.fixture(scope="module") +def stt_en_fastconformer_transducer_xxlarge(): + model_name = "stt_en_fastconformer_transducer_xxlarge" + return ASRModel.from_pretrained(model_name, map_location="cpu") + + +@pytest.fixture(scope="module") +def stt_en_fastconformer_transducer_large(): + model_name = "stt_en_fastconformer_transducer_large" + return ASRModel.from_pretrained(model_name, map_location="cpu") + + +@pytest.mark.with_downloads +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA decoder can run only on CUDA") @pytest.mark.parametrize( ("model_name", "batch_size", "enable_bfloat16"), [ @@ -42,28 +61,87 @@ ], ) @pytest.mark.parametrize("loop_labels", [False, True]) -def test_cuda_graph_rnnt_greedy_decoder(model_name, batch_size, enable_bfloat16, loop_labels: bool): - skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported() +def test_cuda_graph_rnnt_greedy_decoder(model_name, batch_size, enable_bfloat16, loop_labels: bool, request): + if not loop_labels: + skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported() + if enable_bfloat16 and not torch.cuda.is_bf16_supported(): + pytest.skip("bfloat16 is not supported") + + device = torch.device("cuda") + nemo_model = request.getfixturevalue(model_name).to(device) + decoding_config = copy.deepcopy(nemo_model.cfg.decoding) + + with open_dict(decoding_config): + decoding_config["greedy"]["max_symbols"] = 5 + decoding_config["greedy"]["loop_labels"] = loop_labels + decoding_config["greedy"]["use_cuda_graph_decoder"] = False + + nemo_model.change_decoding_strategy(decoding_config) + audio_filepaths = glob.glob("tests/.data/asr/test/an4/wav/*.wav") + + with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=enable_bfloat16): + actual_transcripts, _ = nemo_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None) + + decoding_config["greedy"]["use_cuda_graph_decoder"] = True + + nemo_model.change_decoding_strategy(decoding_config) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=enable_bfloat16): + fast_transcripts, _ = nemo_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None) - conf = ASRModel.from_pretrained(model_name, return_config=True) - with open_dict(conf): - conf["decoding"]["greedy"]["max_symbols"] = 5 - conf["decoding"]["greedy"]["loop_labels"] = loop_labels - conf["decoding"]["greedy"]["use_cuda_graph_decoder"] = False + wer = jiwer.wer(actual_transcripts, fast_transcripts) - with tempfile.NamedTemporaryFile() as fp: - OmegaConf.save(config=conf, f=fp.name) - nemo_model = ASRModel.from_pretrained(model_name, override_config_path=fp.name, map_location="cuda") + assert wer <= 1e-3, "Cuda graph greedy decoder should match original decoder implementation." 
+ for actual, fast in zip(actual_transcripts, fast_transcripts): + if actual != fast: + print("erroneous samples:") + print("Original transcript:", actual) + print("New transcript:", fast) + + +@pytest.mark.with_downloads +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA decoder can run only on CUDA") +@pytest.mark.parametrize("force_mode", ["no_graphs", "no_while_loops", "full_graph"]) +@pytest.mark.parametrize("enable_bfloat16", [False, True]) +def test_loop_labels_cuda_graph_rnnt_greedy_decoder_forced_mode( + stt_en_fastconformer_transducer_large, force_mode: str, enable_bfloat16: bool +): + """ + Test the Label-Looping algorithm with CUDA graphs in each forced mode. + This guarantees that the fallback behavior is also tested. + NB: Since it is impossible to directly debug CUDA graphs, when making changes, + start testing and debugging the code with the forced "no_graphs" mode. + """ + if enable_bfloat16 and not torch.cuda.is_bf16_supported(): + pytest.skip("bfloat16 is not supported") + + if force_mode == "full_graph": + skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported() + + batch_size = 16 + device = torch.device("cuda") + nemo_model = stt_en_fastconformer_transducer_large.to(device) + decoding_config = copy.deepcopy(nemo_model.cfg.decoding) + + with open_dict(decoding_config): + decoding_config["greedy"]["max_symbols"] = 5 + decoding_config["greedy"]["loop_labels"] = True + decoding_config["greedy"]["use_cuda_graph_decoder"] = False + # test that alignments and confidence do not introduce failures + decoding_config["greedy"]["preserve_alignments"] = True + decoding_config["greedy"]["preserve_frame_confidence"] = True + + nemo_model.change_decoding_strategy(decoding_config) audio_filepaths = glob.glob("tests/.data/asr/test/an4/wav/*.wav") with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=enable_bfloat16): actual_transcripts, _ = nemo_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None) - with open_dict(conf): - conf["decoding"]["greedy"]["use_cuda_graph_decoder"] = True - - nemo_model.change_decoding_strategy(conf["decoding"]) + # transcribe with the implementation that uses CUDA graphs + decoding_config["greedy"]["use_cuda_graph_decoder"] = True + nemo_model.change_decoding_strategy(decoding_config) + nemo_model.decoding.decoding._decoding_computer.force_cuda_graphs_mode(mode=force_mode) with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=enable_bfloat16): fast_transcripts, _ = nemo_model.transcribe(audio_filepaths, batch_size=batch_size, num_workers=None) @@ -79,27 +157,27 @@ def test_cuda_graph_rnnt_greedy_decoder(model_name, batch_size, enable_bfloat16, print("erroneous samples:") print("Original transcript:", actual) print("New transcript:", fast) +@pytest.mark.with_downloads +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason="Test requires 2 GPUs") @pytest.mark.parametrize("loop_labels", [False, True]) -def test_change_devices(loop_labels: bool): - skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported() - - if torch.cuda.device_count() < 2: - pytest.skip("Test requires more than 2 GPUs") +def test_change_devices(loop_labels: bool, stt_en_fastconformer_transducer_xlarge): + if not loop_labels: + skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported() first_device = torch.device("cuda:0") second_device = torch.device("cuda:1") - model_name = "stt_en_fastconformer_transducer_xlarge" batch_size = 8 - conf = ASRModel.from_pretrained(model_name, return_config=True) - with open_dict(conf): -
conf["decoding"]["greedy"]["max_symbols"] = 5 - conf["decoding"]["greedy"]["loop_labels"] = loop_labels - conf["decoding"]["greedy"]["use_cuda_graph_decoder"] = True + nemo_model = stt_en_fastconformer_transducer_xlarge.to(second_device) + decoding_config = copy.deepcopy(nemo_model.cfg.decoding) + + with open_dict(decoding_config): + decoding_config["greedy"]["max_symbols"] = 5 + decoding_config["greedy"]["loop_labels"] = loop_labels + decoding_config["greedy"]["use_cuda_graph_decoder"] = True - nemo_model = ASRModel.from_pretrained(model_name, map_location=second_device) - nemo_model.change_decoding_strategy(conf["decoding"]) + nemo_model.change_decoding_strategy(decoding_config) # Test that the model can run successfully when it is first # initialized on second_device and then transferred to diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py index a6e3714f20f5..c3b214751d04 100644 --- a/tests/collections/asr/test_asr_rnnt_encdec_model.py +++ b/tests/collections/asr/test_asr_rnnt_encdec_model.py @@ -432,9 +432,14 @@ def test_BeamRNNTInferConfig(self): ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ("greedy_class", "loop_labels"), + [ + (greedy_decode.GreedyRNNTInfer, None), + (greedy_decode.GreedyBatchedRNNTInfer, True), + (greedy_decode.GreedyBatchedRNNTInfer, False), + ], ) - def test_greedy_decoding(self, greedy_class): + def test_greedy_decoding(self, greedy_class, loop_labels: Optional[bool]): token_list = [" ", "a", "b", "c"] vocab_size = len(token_list) @@ -454,7 +459,14 @@ def test_greedy_decoding(self, greedy_class): for joint_type in [RNNTJoint, HATJoint]: joint_net = joint_type(jointnet_cfg, vocab_size, vocabulary=token_list) - greedy = greedy_class(decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5) + additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels} + greedy = greedy_class( + decoder, + joint_net, + blank_index=len(token_list) - 1, + max_symbols_per_step=5, + **additional_decoding_kwargs, + ) # (B, D, T) enc_out = torch.randn(1, encoder_output_size, 30) diff --git a/tests/collections/common/test_optional_cuda_graphs.py b/tests/collections/common/test_optional_cuda_graphs.py new file mode 100644 index 000000000000..7b1dda775863 --- /dev/null +++ b/tests/collections/common/test_optional_cuda_graphs.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from types import SimpleNamespace + +import torch.nn as nn + +from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs + + +class MockClassWithCudaGraphs(WithOptionalCudaGraphs): + def __init__(self): + super().__init__() + self.cuda_graphs_used = True + + def disable_cuda_graphs(self): + self.cuda_graphs_used = False + + def maybe_enable_cuda_graphs(self): + self.cuda_graphs_used = True + + +class MockModuleWithCudaGraphs(MockClassWithCudaGraphs, nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 20) + + def forward(self, x): + return self.linear(x) + + +class MockModuleWithCudaGraphsByPath(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 20) + self.decoding = SimpleNamespace(decoding=MockClassWithCudaGraphs()) + + def forward(self, x): + return self.linear(x) + + +class TestWithOptionalCudaGraphs: + def test_module_toggle_cuda_graphs(self): + module_with_graphs = MockModuleWithCudaGraphs() + assert module_with_graphs.cuda_graphs_used + WithOptionalCudaGraphs.disable_cuda_graphs_recursive(module_with_graphs) + assert not module_with_graphs.cuda_graphs_used + WithOptionalCudaGraphs.enable_cuda_graphs_recursive(module_with_graphs) + assert module_with_graphs.cuda_graphs_used + + def test_module_toggle_cuda_graphs_by_path(self): + module_with_graphs_by_path = MockModuleWithCudaGraphsByPath() + assert module_with_graphs_by_path.decoding.decoding.cuda_graphs_used + WithOptionalCudaGraphs.disable_cuda_graphs_recursive( + module_with_graphs_by_path, attribute_path="decoding.decoding" + ) + assert not module_with_graphs_by_path.decoding.decoding.cuda_graphs_used + WithOptionalCudaGraphs.enable_cuda_graphs_recursive( + module_with_graphs_by_path, attribute_path="decoding.decoding" + ) + assert module_with_graphs_by_path.decoding.decoding.cuda_graphs_used
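Reviewer note: a minimal usage sketch of the new behavior may help here. With this PR, `use_cuda_graph_decoder` defaults to `True` for `greedy_batch` decoding, and `WithOptionalCudaGraphs` exposes the toggles that the new Lightning hooks call. This is an illustration only, assuming a CUDA-capable machine; the checkpoint name is borrowed from the tests above, and "audio.wav" is a placeholder path.

```python
import copy

from omegaconf import open_dict

from nemo.collections.asr.models import ASRModel
from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs

model = ASRModel.from_pretrained("stt_en_fastconformer_transducer_large", map_location="cuda")

decoding_config = copy.deepcopy(model.cfg.decoding)
with open_dict(decoding_config):
    # CUDA graph decoding is now on by default; set False to opt out explicitly
    decoding_config["greedy"]["use_cuda_graph_decoder"] = True
model.change_decoding_strategy(decoding_config)

transcripts, _ = model.transcribe(["audio.wav"], batch_size=1)

# The same toggling that the new on_train/on_validation epoch hooks perform automatically:
WithOptionalCudaGraphs.disable_cuda_graphs_recursive(model, attribute_path="decoding.decoding")
WithOptionalCudaGraphs.enable_cuda_graphs_recursive(model, attribute_path="decoding.decoding")
```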
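The `NO_WHILE_LOOPS` fallback mode introduced above relies on the standard PyTorch capture-and-replay pattern: each decoding stage is captured once into its own `torch.cuda.CUDAGraph` on a side stream (as in `_partial_graphs_compile`), and plain Python `while` loops decide how many times each stage is replayed. A self-contained toy sketch of that pattern, using a static buffer instead of the real decoder state:

```python
import torch

device = torch.device("cuda")
x = torch.zeros(16, device=device)  # static buffer: a graph always replays on the same storage

graph = torch.cuda.CUDAGraph()
# Always create a new stream: the per-thread default stream disallows stream capture.
stream = torch.cuda.Stream(device)
with torch.cuda.stream(stream), torch.inference_mode(), torch.cuda.graph(graph, stream=stream):
    x += 1  # stage body: must touch only pre-allocated (static) tensors

# Capture records the kernels without executing them; each replay runs them cheaply.
for _ in range(5):  # Python-level control flow, as in the NO_WHILE_LOOPS mode
    graph.replay()
torch.cuda.synchronize()
print(x[0].item())  # 5.0
```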