RNN-T and TDT inference: use CUDA graphs by default (#8972)
* Use CUDA graphs by default for RNN-T and TDT

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>

artbataev authored May 3, 2024
1 parent e16d069 commit 894e502
Showing 13 changed files with 746 additions and 133 deletions.
51 changes: 49 additions & 2 deletions nemo/collections/asr/models/asr_model.py
@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from abc import ABC, abstractmethod
from typing import List
from abc import ABC
from typing import List, Optional

import torch

from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.classes.exportable import Exportable
@@ -171,6 +172,52 @@ def on_after_backward(self):
logging.warning(f'detected inf or nan values in gradients! Setting gradients to zero.')
self.zero_grad()

def on_train_epoch_start(self) -> None:
"""
The decoder with CUDA graphs does not release memory, so we disable it for the training epoch.
EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs.
"""
WithOptionalCudaGraphs.disable_cuda_graphs_recursive(self, attribute_path="decoding.decoding")

def on_train_epoch_end(self) -> None:
"""
After the training epoch, we re-enable the decoder with CUDA graphs.
EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs.
"""
WithOptionalCudaGraphs.enable_cuda_graphs_recursive(self, attribute_path="decoding.decoding")

def on_validation_epoch_start(self) -> None:
"""
For validation, we enable CUDA graphs to speed up validation.
EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs.
"""
WithOptionalCudaGraphs.enable_cuda_graphs_recursive(self, attribute_path="decoding.decoding")

def on_validation_epoch_end(self) -> Optional[dict[str, dict[str, torch.Tensor]]]:
"""
After validation, we disable CUDA graphs, since `validation` can be called in the training loop,
and training will continue after validation.
EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs.
"""
WithOptionalCudaGraphs.disable_cuda_graphs_recursive(self, attribute_path="decoding.decoding")
return super().on_validation_epoch_end()

def on_test_epoch_start(self) -> None:
"""
For testing, we enable CUDA graphs to speed up testing.
We do not need to disable CUDA graphs after testing, since `test` cannot be called in the training loop.
EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs.
"""
WithOptionalCudaGraphs.enable_cuda_graphs_recursive(self, attribute_path="decoding.decoding")

def on_predict_epoch_start(self) -> None:
"""
For prediction, we enable CUDA graphs to speed up inference.
We do not need to disable CUDA graphs after predicting, since `predict` cannot be called in the training loop.
EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs.
"""
WithOptionalCudaGraphs.enable_cuda_graphs_recursive(self, attribute_path="decoding.decoding")


class ExportableEncDecModel(Exportable):
"""
4 changes: 3 additions & 1 deletion nemo/collections/asr/modules/rnnt.py
@@ -312,7 +312,9 @@ def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]:
batch = y.size(0)
# state contains context_size - 1 elements for each utterance in batch,
# consistent with the state returned from StatelessNet.forward
state = [torch.ones([batch, self.context_size - 1], dtype=torch.long, device=y.device) * self.blank_idx]
state = [
torch.full([batch, self.context_size - 1], fill_value=self.blank_idx, dtype=torch.long, device=y.device)
]
return state

def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]):
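The change above swaps `torch.ones(...) * self.blank_idx` for a single `torch.full(...)`, which produces the same tensor without materializing an intermediate all-ones tensor or launching a separate multiply kernel. A standalone illustration with hypothetical sizes:

```python
import torch

batch, context_size, blank_idx = 4, 2, 1024  # hypothetical values

old_style = torch.ones([batch, context_size - 1], dtype=torch.long) * blank_idx
new_style = torch.full([batch, context_size - 1], fill_value=blank_idx, dtype=torch.long)
assert torch.equal(old_style, new_style)  # identical result, one op instead of two
```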
@@ -292,14 +292,21 @@ def __call__(
partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None,
):
if partial_hypotheses is not None:
raise NotImplementedError("`partial_hypotheses` support is not available with cuda graphs (but could be)")
raise NotImplementedError(
"`partial_hypotheses` support is not available "
"with the Frame-Looping algorithm with CUDA graphs (not implemented yet)"
)

if self.caller.preserve_alignments:
raise NotImplementedError("`preserve_alignments` support is not available with cuda graphs (but could be)")
raise NotImplementedError(
"`preserve_alignments` support is not available "
"with the Frame-Looping algorithm with CUDA graphs (not implemented yet)"
)

if self.caller.preserve_frame_confidence:
raise NotImplementedError(
"`preserve_frame_confidence` support is not available with cuda graphs (but could be)"
"`preserve_frame_confidence` support is not available "
"with the Frame-Looping algorithm with CUDA graphs (not implemented yet)"
)

batch_size = x.shape[0]
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
@@ -331,7 +331,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_method_cfg=self.confidence_method_cfg,
loop_labels=self.cfg.greedy.get('loop_labels', True),
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False),
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True),
)
else:
self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer(
@@ -347,7 +347,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
preserve_frame_confidence=self.preserve_frame_confidence,
include_duration_confidence=self.tdt_include_duration_confidence,
confidence_method_cfg=self.confidence_method_cfg,
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False),
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True),
)

else:
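With the config default flipped to `True`, decoding now prefers CUDA graphs and falls back automatically when they are unavailable, but users can still opt out explicitly. A hedged sketch of a config override (`asr_model` and the `change_decoding_strategy` entry point are assumed from NeMo's usual ASR API, not shown in this diff):

```python
from omegaconf import OmegaConf

# Sketch: explicitly disabling the CUDA-graph decoder via the decoding config.
decoding_cfg = OmegaConf.create(
    {"strategy": "greedy_batch", "greedy": {"use_cuda_graph_decoder": False}}
)
# asr_model.change_decoding_strategy(decoding_cfg)  # typical NeMo entry point
```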
98 changes: 82 additions & 16 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -38,6 +38,7 @@
from nemo.collections.asr.parts.submodules.tdt_loop_labels_computer import GreedyBatchedTDTLoopLabelsComputer
from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin
from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
from nemo.collections.common.parts.rnn import label_collate
from nemo.core.classes import Typing, typecheck
from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType
@@ -508,7 +509,7 @@ def _greedy_decode(
return hypothesis


class GreedyBatchedRNNTInfer(_GreedyRNNTInfer):
class GreedyBatchedRNNTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs):
"""A batch level greedy transducer decoder.
Batch level greedy decoding, performed auto-regressively.
@@ -589,7 +590,7 @@ def __init__(
preserve_frame_confidence: bool = False,
confidence_method_cfg: Optional[DictConfig] = None,
loop_labels: bool = True,
use_cuda_graph_decoder: bool = False,
use_cuda_graph_decoder: bool = True,
):
super().__init__(
decoder_model=decoder_model,
Expand All @@ -602,13 +603,14 @@ def __init__(
)

self.use_cuda_graph_decoder = use_cuda_graph_decoder
self.loop_labels = loop_labels

# Depending on availability of `blank_as_pad` support
# switch between more efficient batch decoding technique
self._decoding_computer = None
if self.decoder.blank_as_pad:
if loop_labels:
# default (faster) algo: loop over labels
if self.loop_labels:
# Label-Looping algorithm (default, faster)
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_labels
self._decoding_computer = GreedyBatchedRNNTLoopLabelsComputer(
decoder=self.decoder,
@@ -618,20 +620,74 @@
preserve_alignments=preserve_alignments,
preserve_frame_confidence=preserve_frame_confidence,
confidence_method_cfg=confidence_method_cfg,
allow_cuda_graphs=use_cuda_graph_decoder,
allow_cuda_graphs=self.use_cuda_graph_decoder,
)
elif use_cuda_graph_decoder:
from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import (
RNNTGreedyDecodeCudaGraph,
)

self._greedy_decode = RNNTGreedyDecodeCudaGraph(max_symbols_per_step, self)
else:
# previous algo: loop over frames
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
# Frame-Looping algorithm
if not self.use_cuda_graph_decoder:
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
else:
if self.preserve_alignments:
logging.warning("`preserve_alignments` is not implemented for Frame-Looping + CUDA graphs")
self.use_cuda_graph_decoder = False
if self.preserve_frame_confidence:
logging.warning(
"`preserve_frame_confidence` is not implemented for Frame-Looping + CUDA graphs"
)
self.use_cuda_graph_decoder = False
if not torch.cuda.is_available():
self.use_cuda_graph_decoder = False

if self.use_cuda_graph_decoder:
try:
from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import (
RNNTGreedyDecodeCudaGraph,
)

self._greedy_decode = RNNTGreedyDecodeCudaGraph(max_symbols_per_step, self)
except (ImportError, ModuleNotFoundError, ValueError) as e:
self.use_cuda_graph_decoder = False
logging.warning(f"Cannot use decoder with CUDA graphs, reason: {e}")
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
else:
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
else:
self._greedy_decode = self._greedy_decode_masked

def disable_cuda_graphs(self):
"""Disable CUDA graphs (e.g., for decoding in training)"""
if not self.use_cuda_graph_decoder:
# CUDA graphs not allowed, nothing to do
return

if not self.decoder.blank_as_pad:
# without blank-as-pad support, decoding does not use CUDA graphs
return

if self.loop_labels:
# Label-Looping implementation
self._decoding_computer.disable_cuda_graphs()
else:
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames

def maybe_enable_cuda_graphs(self):
"""Enable CUDA graphs (if allowed)"""
if not self.use_cuda_graph_decoder:
# CUDA graphs not allowed, nothing to do
return

if not self.decoder.blank_as_pad:
# without blank-as-pad support, decoding does not use CUDA graphs
return

if self.loop_labels:
# Label-Looping implementation
self._decoding_computer.maybe_enable_cuda_graphs()
else:
from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import RNNTGreedyDecodeCudaGraph

self._greedy_decode = RNNTGreedyDecodeCudaGraph(self.max_symbols, self)

@typecheck()
def forward(
self,
@@ -2302,7 +2358,7 @@ class GreedyBatchedRNNTInferConfig:
tdt_include_duration_confidence: bool = False
confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())
loop_labels: bool = True
use_cuda_graph_decoder: bool = False
use_cuda_graph_decoder: bool = True

def __post_init__(self):
# OmegaConf.structured ensures that post_init check is always executed
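A quick way to observe the flipped default via the structured config; a sketch that assumes the dataclass is constructible with defaults alone:

```python
from omegaconf import OmegaConf

from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig

cfg = OmegaConf.structured(GreedyBatchedRNNTInferConfig())
assert cfg.use_cuda_graph_decoder  # now True unless overridden
assert cfg.loop_labels             # Label-Looping remains the default algorithm
```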
@@ -2580,7 +2636,7 @@ def _greedy_decode(
return hypothesis


class GreedyBatchedTDTInfer(_GreedyRNNTInfer):
class GreedyBatchedTDTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs):
"""A batch level greedy TDT decoder.
Batch level greedy decoding, performed auto-regressively.
Args:
@@ -2652,7 +2708,7 @@ def __init__(
preserve_frame_confidence: bool = False,
include_duration_confidence: bool = False,
confidence_method_cfg: Optional[DictConfig] = None,
use_cuda_graph_decoder: bool = False,
use_cuda_graph_decoder: bool = True,
):
super().__init__(
decoder_model=decoder_model,
@@ -2759,3 +2815,13 @@ def _greedy_decode_blank_as_pad_loop_labels(
for hyp, state in zip(hyps, self.decoder.batch_split_states(last_decoder_state)):
hyp.dec_state = state
return hyps

def disable_cuda_graphs(self):
"""Disable CUDA graphs (e.g., for decoding in training)"""
if self._decoding_computer is not None:
self._decoding_computer.disable_cuda_graphs()

def maybe_enable_cuda_graphs(self):
"""Enable CUDA graphs (if allowed)"""
if self._decoding_computer is not None:
self._decoding_computer.maybe_enable_cuda_graphs()
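Both `GreedyBatchedRNNTInfer` and `GreedyBatchedTDTInfer` now satisfy the same toggle protocol that the trainer hooks in `asr_model.py` drive. A brief sketch of exercising it outside a trainer, using only the API shown in this diff (the evaluation body is a placeholder):

```python
from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs


def evaluate_with_cuda_graphs(model):
    # Mirror the validation hooks: enable the fast path for inference, then
    # disable it again so a surrounding training loop is unaffected.
    WithOptionalCudaGraphs.enable_cuda_graphs_recursive(model, attribute_path="decoding.decoding")
    try:
        ...  # run validation / transcription here (placeholder)
    finally:
        WithOptionalCudaGraphs.disable_cuda_graphs_recursive(model, attribute_path="decoding.decoding")
```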
