Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RNN-T and TDT inference: use CUDA graphs by default #8972

Merged
merged 39 commits into from
May 3, 2024

Conversation

artbataev
Copy link
Collaborator

@artbataev artbataev commented Apr 18, 2024

What does this PR do ?

  • enable CUDA graphs by default for inference (transcribe_speech.py, speech_to_text_eval.py)
  • add fallback behavior with partial CUDA graphs for Label-Looping algorithm implementation (TDT, RNN-T)
    • Use Cuda graphs with while loops when all requirements are met (driver 545+, Cuda 12.3+, cuda-python package)
    • Use Cuda graphs without while loops (pure PyTorch functionality) in most other cases. The idea: use several separate Cuda graphs for parts of the decoding algorithm (<graph before outer loop> -> while loop (python) -> <graph before inner loop> -> inner while loop (python) -> <graph for inner loop code> etc.)

On my local machine, FastConformer L, LibriSpeech test-other decoding time, bfloat16, bs=16

no graphs partial graphs full graph
unsorted batch 34s 21s 19s
sorted batch 23s 15s 14s

Collection: [ASR]

Changelog

  • Add specific line by line info of high level changes in this PR.

Usage

python examples/asr/speech_to_text_eval.py <with rnnt/tdt model>

Jenkins CI

To run Jenkins, a NeMo User with write access must comment jenkins on the PR.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
@github-actions github-actions bot added the ASR label Apr 18, 2024
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
@artbataev artbataev changed the title RNN-T and TDT decoding: use CUDA graphs by default RNN-T and TDT inference: use CUDA graphs by default Apr 18, 2024
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
@artbataev
Copy link
Collaborator Author

jenkins

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
@artbataev
Copy link
Collaborator Author

jenkins

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
@artbataev artbataev marked this pull request as ready for review April 25, 2024 19:06
@artbataev artbataev requested review from galv and titu1994 April 25, 2024 19:06
@artbataev
Copy link
Collaborator Author

artbataev commented Apr 25, 2024

@galv, @titu1994 Please, review the PR

The only issue I see for now is that the Frame-Looping algorithm does not have fallback behavior, and setting loop_labels=false without use_cuda_graph_decoder=false (default - true) can cause failures (if not all requirements are met for Frame-Looping+CUDA graphs).

On the other hand, this case is not important. Most users will use the Label-Looping algorithm since it is default and produces the same result as the Frame-Looping algorithm.

@galv
Copy link
Collaborator

galv commented Apr 26, 2024

There is something I'm not quite understanding. Is there a reason why you did not set the default value of these values to true?

use_cuda_graph_decoder: bool = False,

and

use_cuda_graph_decoder: bool = False

and

use_cuda_graph_decoder: bool = False,

It looks like you are turning cuda graphs on by default only when someone runs transcribe_speech.py or transcribe_speech_parallel.py right now.

if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH:
self.full_graph.replay()
elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS:
self.separate_graphs.before_outer_loop.replay()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is cool. I would like to speak to you about a way we could possibly do this more easily using torch.cond and torch.while_loop. They're not in the newest versions of pytorch yet.

As we improve our implementations of beam search, I don't think it is realistic to keep doing special case code like this. I'm thinking we can get the right level of abstraction with torch.cond and torch.while_loop, such that people can write things more naturally.

# cuda graphs are allowed
# check basic requirements for cuda graphs
if self.max_symbols is None:
logging.warning("Max symbols is None, which is not allowed with Cuda graphs.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

People don't check warnings often enough. I recommend you throw an exception here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Max symbols is none on some older models if i recall, that means out of the box they will crash to infer due to error for unrelated reason of cuda graphs not supporting None. It should remain a warning, but i agree that instead of crashing, set it to large default value for cuda graphs..

Ie warn that its not supported so instead a default of 10 timesteps or higher is being used for cuda graphs optimization

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with Som's suggestion, actually. Using a large value like 10 when someone passes in None seems like the right move.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fully agree, thanks! I fixed the behavior.


self.state: Optional[LoopLabelsState] = None
def force_cuda_graphs_mode(self, mode: Optional[Union[str, CudaGraphsMode]]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this? I don't see any usages of it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I forgot to add an explicit test. Fixed.
This is useful for debugging to set no_graphs mode (since it's impossible to debug CUDA graphs directly).

@galv
Copy link
Collaborator

galv commented Apr 26, 2024

On the other hand, this case is not important. Most users will use the Label-Looping algorithm since it is default and produces the same result as the Frame-Looping algorithm.

I agree. That is was my first attempt. It was very educational, but I don't believe that we need to do the work to add a feature to it that won't be used very much.

@galv
Copy link
Collaborator

galv commented Apr 26, 2024

I didn't hit approve yet. The changes seem good to me. But I want to give some time to defer to @titu1994

Copy link
Collaborator

@titu1994 titu1994 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some important comments for logging, otherwise looks good

@@ -161,7 +162,9 @@ class TranscriptionConfig:
ctc_decoding: CTCDecodingConfig = CTCDecodingConfig()

# Decoding strategy for RNNT models
rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1)
rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(
fused_batch_size=-1, greedy=GreedyBatchedRNNTInferConfig(use_cuda_graph_decoder=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can set the default config inside GreedyBatchedRNNTInferConfig to have this set to True by default rather than do it with this explicit override

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue as with setting use_cuda_graph_decoder=True as default in the classes.
I added comments above this line.
RNNTDecodingConfig is used, e.g., in change_vocabulary. After this operation (useful for finetuning), the model will have the field in the config use_cuda_graph_decoder=True, and further training will use the CUDA decoder. Since it is not compatible with bucketing (pre-allocated memory for maximum batch_size*sequence_length can be too large in this case), I prefer to conservatively enable it only for transcription.

The alternative is to enable it everywhere by default, but in the training loop, explicitly use the decoder without CUDA graphs. However, this can make the code too complicated.

If you see that there is a more straightforward solution for this issue, please let's discuss it!

Copy link
Collaborator Author

@artbataev artbataev Apr 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The alternative is to enable it everywhere by default, but in the training loop, explicitly use the decoder without CUDA graphs. However, this can make the code too complicated.

Moved to this solution. I think it is much cleaner. So, set everywhere use_cuda_graph_decoder=True by default

# cuda graphs are allowed
# check basic requirements for cuda graphs
if self.max_symbols is None:
logging.warning("Max symbols is None, which is not allowed with Cuda graphs.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Max symbols is none on some older models if i recall, that means out of the box they will crash to infer due to error for unrelated reason of cuda graphs not supporting None. It should remain a warning, but i agree that instead of crashing, set it to large default value for cuda graphs..

Ie warn that its not supported so instead a default of 10 timesteps or higher is being used for cuda graphs optimization

check_cuda_python_cuda_graphs_conditional_nodes_supported()
self.cuda_graphs_mode = self.CudaGraphsMode.FULL_GRAPH
except (ImportError, ModuleNotFoundError) as e:
logging.warning(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im wondering if this should be visible to users. The problem is the vast majority of users will NOT be on latest driver and cudapython install - ie the vast majority wont be using Cuda graphs and will log this warning - repeatedly polluting the inference in a loop call to transcribe().

We need to make this log just once - see loggermode and how to pass it inside of logging messages.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I see it, the transcribe does not change the decoding strategy. Since this is logged only when instantiating the class (when instantiating the model or changing the decoding strategy), this should be fine (no repeated logs).


from nemo.collections.asr.models import ASRModel
from nemo.core.utils.cuda_python_utils import skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported


@pytest.fixture(scope="module")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldnt this be with a marker for with_download

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pytest marks (with_downloads) are useful only for tests, not fixtures. All the tests that use these fixtures are marked with with_downloads tag.

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
@artbataev artbataev requested a review from titu1994 April 30, 2024 18:15
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
@@ -171,6 +173,51 @@ def on_after_backward(self):
logging.warning(f'detected inf or nan values in gradients! Setting gradients to zero.')
self.zero_grad()

@classmethod
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two methods should be class methods of WithOptionalCYDAGraphs

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, generally, this is not a good idea to introduce a 2-way dependency WithOptionalCudaGraphs <-> ASRModel (actually, EncDecRNNTModel, since decoding is only in this model).
I made the method more abstract to separate the logic, separating the path in the model and the lookup logic.

nemo/collections/asr/models/asr_model.py Show resolved Hide resolved
…UDA graphs in `ASRModel`

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
@artbataev artbataev requested a review from titu1994 May 2, 2024 14:40
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Copy link
Collaborator

@titu1994 titu1994 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, @galv for final approval

Copy link
Collaborator

@galv galv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue you raised about memory usage when doing inference during training is a very good one. I believe we can figure out a solution that is less instrusive in the future if we point it out to the right people (follow up with me!).

EncDecRNNTModel.decoding.decoding is the inference class with CUDA graphs.
"""
WithOptionalCudaGraphs.disable_cuda_graphs_recursive(self, attribute_path="decoding.decoding")
return super().on_validation_epoch_end()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you call the superclass's implementation only for this method, but not for the others?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These hooks return None in the PyTorch-Lightning interface, and basically, there is no code in such hooks. But ModelPT defines the on_validation_epoch_end hook for all models with the customized return type, so I need to call it.

if not self.use_cuda_graph_decoder:
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
else:
if self.preserve_alignments:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprsied that you would silently change the behavior on lines 630 to 639 rather than throw an exception in these cases, to be honest.

Meanwhile, the situation where we set the symbols_per_step to 10 if it is None seems okay because it is unlikely to change the results, since 10 is such a large number.

I'm not going to hold up merging this because of this concern, anyway, since it is a code path most people won't see.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made this fallback behavior to prevent crashes when the user wants to change some parameters since use_cuda_graph_decoder is True by default now. Since it's only about speed (not quality), it is acceptable to switch silently between implementations instead of requiring the user to understand all the nuances of the available parameter combinations.
LoopLabelsComputer(s) are designed to handle all situations without explicit errors (e.g., when cuda is unavailable, etc.).

RNNTGreedyDecodeCudaGraph,
)

self._greedy_decode = RNNTGreedyDecodeCudaGraph(max_symbols_per_step, self)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not certain, but it currently looks like we will throw an exception if max_symbols_per_step is None, rather than overriding it to 10 for the frame-loop decoder right now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, thanks for catching this. I will address this in a follow-up PR

self.state.alignments.add_results_masked_no_checks_(
active_mask=self.state.active_mask,
time_indices=self.state.time_indices_current_labels,
logits=logits if self.preserve_alignments else None,
labels=self.state.labels if self.preserve_alignments else None,
confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1))
confidence=self._get_confidence_tensor(F.log_softmax(logits, dim=-1)).to(dtype=float_dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a point of clarification, why did you need to add this to() call? It doesn't seem to be related to the rest of the changes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be happy to avoid this, but without casting, the code will fail with mixed bf16 precision:

  • log_softmax returns the value of float32 type in bfloat16 mixed precision (amp)
  • alignments storage is initialized with bf16 type, and adding confidence values inside add_results_masked_no_checks_ will fail

This is the same issue I observed when reviewing Sasha's PR related to TDT confidence #8982
Since I enabled computing confidence in tests for RNN-T, I caught this bug and fixed it here https://github.com/NVIDIA/NeMo/pull/8972/files#diff-d8ba9ce8e77769e06174cf0d16842d130debb4a289e92fea5296b081f5a4deabR133

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I understand! I am planning to redo #9086 so that we will do inference in pure bfloat16 or float16, rather than using AMP. Basically, running in AMP can actually slow you down compared to running in float32 in inference mode, because it caches the down-casted versions of parameters only "requires_grad=False" for those parameters.

It should be safe to do softmax with fp16 inputs and outputs at inference time. The accumulations are done in fp32, which is the important part.

After we move away from AMP for inference (which might take a while since NeMo was written with that assumption for a long time), we can get rid of the need for the cast.

@artbataev artbataev merged commit 894e502 into main May 3, 2024
148 of 254 checks passed
@artbataev artbataev deleted the rnnt_cuda_graphs_default branch May 3, 2024 11:10
rohitrango pushed a commit to rohitrango/NeMo that referenced this pull request Jun 25, 2024
* Use Cuda graphs by default for RNN-T and TDT

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>

---------

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants