
RNN-T and TDT inference: use CUDA graphs by default #8972

Merged — 39 commits from rnnt_cuda_graphs_default into main on May 3, 2024. Changes shown from 35 commits.

Commits (39)
7441a2f
Use Cuda graphs by default for transcription
artbataev Apr 18, 2024
4761e68
RNN-T Loop Labels + Cuda graphs user-friendly
artbataev Apr 18, 2024
2b13f31
Fix Cuda graphs mode
artbataev Apr 18, 2024
d391e98
Fuse graphs
artbataev Apr 18, 2024
0205101
Enable by default Cuda graphs for TDT
artbataev Apr 18, 2024
0cac599
Merge branch 'main' into rnnt_cuda_graphs_default
artbataev Apr 19, 2024
cb68701
Merge branch 'main' into rnnt_cuda_graphs_default
artbataev Apr 19, 2024
b38ff6f
Merge branch 'main' into rnnt_cuda_graphs_default
artbataev Apr 25, 2024
25235bb
Add test
artbataev Apr 25, 2024
f175c8a
Speedup init state
artbataev Apr 25, 2024
1672739
Add comments
artbataev Apr 25, 2024
9e18150
Speedup tests
artbataev Apr 25, 2024
94ba6af
Add comments
artbataev Apr 25, 2024
7b9d619
Fix tests for alignments
artbataev Apr 25, 2024
2d3b083
Fix test
artbataev Apr 25, 2024
7e805bc
Merge branch 'main' into rnnt_cuda_graphs_default
artbataev Apr 25, 2024
43c01ac
Test decoder in forced mode
artbataev Apr 27, 2024
030be86
Set max_symbols to 10 if None. Add comments
artbataev Apr 27, 2024
100fd9c
Fix issue with confidence + bfloat16
artbataev Apr 27, 2024
6b5d1d2
Test with confidence
artbataev Apr 27, 2024
07bd665
Add comment about setting variables in config
artbataev Apr 27, 2024
638823e
Merge branch 'main' into rnnt_cuda_graphs_default
artbataev Apr 27, 2024
464dd51
Enable CUDA graphs everywhere. Disable explicitly in training pipeline.
artbataev Apr 29, 2024
2b7cd73
Revert redundant changes
artbataev Apr 29, 2024
e730f91
Fix comment
artbataev Apr 29, 2024
cc06bf0
Fix typo
artbataev Apr 29, 2024
cf38241
Fix enabling CUDA graphs
artbataev Apr 29, 2024
19ca09d
Instantiate RNNTGreedyDecodeCudaGraph only when all conditions are met
artbataev Apr 29, 2024
d98b8fc
Fix hybrid ASR-TTS model
artbataev Apr 30, 2024
05eb103
Merge branch 'main' into rnnt_cuda_graphs_default
artbataev Apr 30, 2024
7c6f7f0
Move toggling CUDA graphs to `ASRModel`
artbataev Apr 30, 2024
cb6d500
Remove redundant import
artbataev Apr 30, 2024
35564df
Clean up
artbataev Apr 30, 2024
7192acc
Clean up
artbataev Apr 30, 2024
d4a27f6
Clean up
artbataev Apr 30, 2024
c0877f2
Extract toggling CUDA graphs logic to `WithOptionalCudaGraphs`. Fix C…
artbataev May 2, 2024
b880510
Fix unused imports
artbataev May 2, 2024
82f83dc
Merge branch 'main' into rnnt_cuda_graphs_default
artbataev May 2, 2024
4e47010
Fix hook (failing tests)
artbataev May 2, 2024
47 changes: 47 additions & 0 deletions nemo/collections/asr/models/asr_model.py
@@ -16,7 +16,9 @@
from typing import List

import torch
import torch.nn as nn

from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.classes.exportable import Exportable
@@ -171,6 +173,51 @@ def on_after_backward(self):
logging.warning(f'detected inf or nan values in gradients! Setting gradients to zero.')
self.zero_grad()

@classmethod
Collaborator:
These two methods should be class methods of WithOptionalCudaGraphs.

Collaborator (Author):
I think that, generally, it is not a good idea to introduce a two-way dependency WithOptionalCudaGraphs <-> ASRModel (actually EncDecRNNTModel, since decoding exists only in that model).
I made the method more abstract to separate the concerns: the path to the decoder inside the model and the lookup logic itself.

def disable_cuda_graphs_in_decoder(cls, module: nn.Module):
"""
Disable CUDA graphs for decoding to avoid excessive memory usage.
Should be used in the training pipeline.
"""
# in RNN-T model, `model.decoding` is not an instance of nn.Module,
# we need to check `model.decoding.decoding` explicitly
for submodule in module.modules():
if (
hasattr(submodule, "decoding")
and hasattr(submodule.decoding, "decoding")
and isinstance(submodule.decoding.decoding, WithOptionalCudaGraphs)
):
submodule.decoding.decoding.disable_cuda_graphs()

@classmethod
def enable_cuda_graphs_in_decoder(cls, module: nn.Module):
"""Enable CUDA graphs for decoding (validation/testing)"""
# in RNN-T model, model.decoding is not an instance of nn.Module,
# we need to check `model.decoding.decoding` explicitly
for submodule in module.modules():
if (
hasattr(submodule, "decoding")
and hasattr(submodule.decoding, "decoding")
and isinstance(submodule.decoding.decoding, WithOptionalCudaGraphs)
):
submodule.decoding.decoding.maybe_enable_cuda_graphs()

def on_train_epoch_start(self) -> None:
"""Decoder with CUDA graphs does not release memory, so we disable it for the training epoch"""
self.disable_cuda_graphs_in_decoder(self)

def on_train_epoch_end(self) -> None:
self.enable_cuda_graphs_in_decoder(self)

def on_validation_epoch_start(self) -> None:
self.enable_cuda_graphs_in_decoder(self)

def on_test_epoch_start(self) -> None:
self.enable_cuda_graphs_in_decoder(self)

def on_predict_epoch_start(self) -> None:
self.enable_cuda_graphs_in_decoder(self)


class ExportableEncDecModel(Exportable):
"""
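For context, here is a minimal sketch of the WithOptionalCudaGraphs interface that the new hooks above rely on. This is a hedged reconstruction from the calls made in this diff, not the PR's actual implementation; the real class lives in nemo/collections/common/parts/optional_cuda_graphs.py.

from abc import ABC, abstractmethod


class WithOptionalCudaGraphs(ABC):
    """Base class for decoders that can optionally use CUDA graphs (sketch)."""

    @abstractmethod
    def maybe_enable_cuda_graphs(self):
        """Enable CUDA graphs if the configuration and hardware allow it."""

    @abstractmethod
    def disable_cuda_graphs(self):
        """Disable CUDA graphs, falling back to a graph-free decoding loop."""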
4 changes: 3 additions & 1 deletion nemo/collections/asr/modules/rnnt.py
@@ -312,7 +312,9 @@ def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]:
batch = y.size(0)
# state contains context_size - 1 elements for each utterance in batch,
# consistent with the state returned from StatelessNet.forward
state = [torch.ones([batch, self.context_size - 1], dtype=torch.long, device=y.device) * self.blank_idx]
state = [
torch.full([batch, self.context_size - 1], fill_value=self.blank_idx, dtype=torch.long, device=y.device)
]
return state

def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]):
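The diff above replaces a fill-and-multiply with a single torch.full call (the "Speedup init state" commit). A standalone illustration of the equivalence, not NeMo code:

import torch

batch, context_size, blank_idx = 4, 2, 1024

# Before: fill with ones, then multiply -- two kernels and a temporary tensor.
slow = torch.ones([batch, context_size - 1], dtype=torch.long) * blank_idx
# After: write the blank index directly in one kernel.
fast = torch.full([batch, context_size - 1], fill_value=blank_idx, dtype=torch.long)
assert torch.equal(slow, fast)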
@@ -292,14 +292,21 @@ def __call__(
partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None,
):
if partial_hypotheses is not None:
raise NotImplementedError("`partial_hypotheses` support is not available with cuda graphs (but could be)")
raise NotImplementedError(
"`partial_hypotheses` support is not available"
" with Frame-Looping algorithm with CUDA graphs (not implemented yet)"
)

if self.caller.preserve_alignments:
raise NotImplementedError("`preserve_alignments` support is not available with cuda graphs (but could be)")
raise NotImplementedError(
"`preserve_alignments` support is not available"
" with Frame-Looping algorithm with CUDA graphs (not implemented yet)"
)

if self.caller.preserve_frame_confidence:
raise NotImplementedError(
"`preserve_frame_confidence` support is not available with cuda graphs (but could be)"
"`preserve_frame_confidence` support is not available"
" with Frame-Looping algorithm with CUDA graphs (not implemented yet)"
)

batch_size = x.shape[0]
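A subtlety in multi-line messages like the ones above: adjacent Python string literals concatenate with no separator, so omitting the space between parts silently produces a run-on word. A minimal illustration:

msg = (
    "`preserve_alignments` support is not available"
    "with Frame-Looping algorithm"
)
assert "availablewith" in msg  # no space is inserted between adjacent literals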
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
@@ -331,7 +331,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_method_cfg=self.confidence_method_cfg,
loop_labels=self.cfg.greedy.get('loop_labels', True),
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False),
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True),
)
else:
self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer(
@@ -347,7 +347,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
preserve_frame_confidence=self.preserve_frame_confidence,
include_duration_confidence=self.tdt_include_duration_confidence,
confidence_method_cfg=self.confidence_method_cfg,
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False),
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True),
)

else:
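Since both branches now default use_cuda_graph_decoder to True, opting out requires setting it explicitly. A hedged sketch of how a user might do that — change_decoding_strategy is the usual NeMo entry point, and asr_model is a hypothetical model instance, not something defined in this PR:

from omegaconf import OmegaConf

decoding_cfg = OmegaConf.create(
    {"strategy": "greedy_batch", "greedy": {"use_cuda_graph_decoder": False}}
)
# asr_model.change_decoding_strategy(decoding_cfg)  # asr_model: hypothetical instance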
98 changes: 82 additions & 16 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -38,6 +38,7 @@
from nemo.collections.asr.parts.submodules.tdt_loop_labels_computer import GreedyBatchedTDTLoopLabelsComputer
from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin
from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
from nemo.collections.common.parts.rnn import label_collate
from nemo.core.classes import Typing, typecheck
from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType
@@ -508,7 +509,7 @@
return hypothesis


class GreedyBatchedRNNTInfer(_GreedyRNNTInfer):
class GreedyBatchedRNNTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs):
"""A batch level greedy transducer decoder.

Batch level greedy decoding, performed auto-regressively.
@@ -589,7 +590,7 @@
preserve_frame_confidence: bool = False,
confidence_method_cfg: Optional[DictConfig] = None,
loop_labels: bool = True,
use_cuda_graph_decoder: bool = False,
use_cuda_graph_decoder: bool = True,
):
super().__init__(
decoder_model=decoder_model,
@@ -602,13 +603,14 @@
)

self.use_cuda_graph_decoder = use_cuda_graph_decoder
self.loop_labels = loop_labels

# Depending on availability of `blank_as_pad` support
# switch between more efficient batch decoding technique
self._decoding_computer = None
if self.decoder.blank_as_pad:
if loop_labels:
# default (faster) algo: loop over labels
if self.loop_labels:
# Label-Looping algorithm (default, faster)
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_labels
self._decoding_computer = GreedyBatchedRNNTLoopLabelsComputer(
decoder=self.decoder,
@@ -618,20 +620,74 @@
preserve_alignments=preserve_alignments,
preserve_frame_confidence=preserve_frame_confidence,
confidence_method_cfg=confidence_method_cfg,
allow_cuda_graphs=use_cuda_graph_decoder,
allow_cuda_graphs=self.use_cuda_graph_decoder,
)
elif use_cuda_graph_decoder:
from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import (
RNNTGreedyDecodeCudaGraph,
)

self._greedy_decode = RNNTGreedyDecodeCudaGraph(max_symbols_per_step, self)
else:
# previous algo: loop over frames
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
# Frame-Looping algorithm
if not self.use_cuda_graph_decoder:
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
else:
if self.preserve_alignments:
Collaborator:
I'm surprised that you would silently change the behavior on lines 630 to 639 rather than throw an exception in these cases, to be honest.

Meanwhile, the situation where we set the symbols_per_step to 10 if it is None seems okay, because it is unlikely to change the results, since 10 is such a large number.

I'm not going to hold up merging this over this concern, anyway, since it is a code path most people won't see.

Collaborator (Author):
I made this fallback behavior to prevent crashes when the user wants to change some parameters, since use_cuda_graph_decoder is now True by default. Since it's only about speed (not quality), it is acceptable to switch silently between implementations instead of requiring the user to understand all the nuances of the available parameter combinations.
The LoopLabelsComputer classes are designed to handle all situations without explicit errors (e.g., when CUDA is unavailable).

logging.warning("`preserve_alignments` is not implemented for Frame-Looping + CUDA graphs")
self.use_cuda_graph_decoder = False
if self.preserve_frame_confidence:
logging.warning(
"`preserve_frame_confidence` is not implemented for Frame-Looping + CUDA graphs"
)
self.use_cuda_graph_decoder = False
if not torch.cuda.is_available():
self.use_cuda_graph_decoder = False

if self.use_cuda_graph_decoder:
try:
from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import (
RNNTGreedyDecodeCudaGraph,
)

self._greedy_decode = RNNTGreedyDecodeCudaGraph(max_symbols_per_step, self)
Collaborator:
I am not certain, but it currently looks like we will throw an exception if max_symbols_per_step is None, rather than overriding it to 10 for the frame-loop decoder.

Collaborator (Author):
Yep, thanks for catching this. I will address it in a follow-up PR.

except (ImportError, ModuleNotFoundError, ValueError) as e:
self.use_cuda_graph_decoder = False
logging.warning(f"Cannot use decoder with CUDA graphs, reason: {e}")  # ValueError has no `.msg` attribute
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
else:
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
else:
self._greedy_decode = self._greedy_decode_masked
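# Dispatch summary for the branches above (descriptive comment only): with
# blank_as_pad, the Label-Looping computer (default) handles CUDA graphs
# internally, while the Frame-Looping path falls back to the pure PyTorch
# frame loop whenever CUDA graphs are unavailable or unsupported for the
# chosen options; without blank_as_pad, masked decoding is always used and
# CUDA graphs never apply.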

def disable_cuda_graphs(self):
"""Disable CUDA graphs (e.g., for decoding in training)"""
if not self.use_cuda_graph_decoder:
# CUDA graphs not allowed, nothing to do
return

if not self.decoder.blank_as_pad:
# without blank-as-pad support, masked decoding is used, which never uses CUDA graphs
return

if self.loop_labels:
# Label-Looping implementation
self._decoding_computer.disable_cuda_graphs()
else:
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames

def maybe_enable_cuda_graphs(self):
"""Enable CUDA graphs (if allowed)"""
if not self.use_cuda_graph_decoder:
# CUDA graphs not allowed, nothing to do
return

if not self.decoder.blank_as_pad:
# without blank-as-pad support, masked decoding is used, which never uses CUDA graphs
return

if self.loop_labels:
# Label-Looping implementation
self._decoding_computer.maybe_enable_cuda_graphs()
else:
from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import RNNTGreedyDecodeCudaGraph

self._greedy_decode = RNNTGreedyDecodeCudaGraph(self.max_symbols, self)

@typecheck()
def forward(
self,
@@ -1829,9 +1885,9 @@
# Depending on availability of `blank_as_pad` support
# switch between more efficient batch decoding technique
if self.decoder.blank_as_pad:
self._greedy_decode = self._greedy_decode_blank_as_pad

Check warning — Code scanning / CodeQL: Overwriting attribute in super-class or sub-class. Assignment overwrites attribute _greedy_decode, which was previously defined in superclass GreedyBatchedRNNTInfer.
else:
self._greedy_decode = self._greedy_decode_masked

Check warning — Code scanning / CodeQL: Overwriting attribute in super-class or sub-class. Assignment overwrites attribute _greedy_decode, which was previously defined in superclass GreedyBatchedRNNTInfer.
self._SOS = blank_index - len(big_blank_durations)

def _greedy_decode_blank_as_pad(
@@ -2302,7 +2358,7 @@
tdt_include_duration_confidence: bool = False
confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())
loop_labels: bool = True
use_cuda_graph_decoder: bool = False
use_cuda_graph_decoder: bool = True

def __post_init__(self):
# OmegaConf.structured ensures that post_init check is always executed
@@ -2580,7 +2636,7 @@
return hypothesis


class GreedyBatchedTDTInfer(_GreedyRNNTInfer):
class GreedyBatchedTDTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs):
"""A batch level greedy TDT decoder.
Batch level greedy decoding, performed auto-regressively.
Args:
@@ -2652,7 +2708,7 @@
preserve_frame_confidence: bool = False,
include_duration_confidence: bool = False,
confidence_method_cfg: Optional[DictConfig] = None,
use_cuda_graph_decoder: bool = False,
use_cuda_graph_decoder: bool = True,
):
super().__init__(
decoder_model=decoder_model,
@@ -2759,3 +2815,13 @@
for hyp, state in zip(hyps, self.decoder.batch_split_states(last_decoder_state)):
hyp.dec_state = state
return hyps

def disable_cuda_graphs(self):
"""Disable CUDA graphs (e.g., for decoding in training)"""
if self._decoding_computer is not None:
self._decoding_computer.disable_cuda_graphs()

def maybe_enable_cuda_graphs(self):
"""Enable CUDA graphs (if allowed)"""
if self._decoding_computer is not None:
self._decoding_computer.maybe_enable_cuda_graphs()
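Taken together, the hooks in asr_model.py and the WithOptionalCudaGraphs implementations above toggle CUDA graph decoding automatically around training. A hedged usage sketch of the same calls made manually — the model class and checkpoint name are illustrative assumptions, not taken from this PR:

import nemo.collections.asr as nemo_asr

# Hypothetical model choice; any RNN-T/TDT model inheriting ASRModel would do.
asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
    "stt_en_conformer_transducer_large"  # assumed checkpoint name
)

# Equivalent to what on_train_epoch_start now does automatically:
asr_model.disable_cuda_graphs_in_decoder(asr_model)

# Equivalent to on_validation/test/predict_epoch_start:
asr_model.enable_cuda_graphs_in_decoder(asr_model)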