
Fix for warning as default stream was used in enqueueV3 #3191

Merged
keehyuna merged 2 commits into pytorch:main on Oct 15, 2024

Conversation

keehyuna (Collaborator)

Description

torch.cuda.current_stream() / c10::cuda::getCurrentCUDAStream() always returns the default stream, which leads to enqueueV3() running on the default stream.
torch.cuda.set_stream() / c10::cuda::setCurrentCUDAStream() must be called to make a new stream current after it is acquired from the pool (see the sketch below).
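
A minimal sketch of the behavior described above (not the PR's code; requires a CUDA device):

    import torch

    # Acquiring a new stream does NOT make it current: current_stream()
    # still reports the default stream until set_stream() is called.
    side_stream = torch.cuda.Stream()
    assert torch.cuda.current_stream() == torch.cuda.default_stream()

    # The fix: make the side stream current so enqueueV3() no longer
    # runs on the default stream.
    torch.cuda.set_stream(side_stream)
    assert torch.cuda.current_stream() == side_stream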

Fixes #3190

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified

keehyuna self-assigned this on Sep 27, 2024
github-actions bot added the component: core, component: api [Python], component: runtime, and component: dynamo labels on Sep 27, 2024
narendasan (Collaborator)

@keehyuna there is some code related to cudagraphs; can you check how it handles streams, and perhaps write some docs on how the runtime is supposed to handle streams in all cases?

peri044 (Collaborator) commented on Oct 4, 2024:

  • Test GPT2 and Llama model output numerics after applying this patch
  • Add a doc (runtime) that explains how streams work in our runtime (2nd part of this task)

keehyuna (Collaborator, Author) left a comment

This is when the problem happened: the capture stream flips between the default and a non-default stream.
[screenshot]

This is when torch.cuda.set_stream() is used: the non-default stream is used for the CUDA graph / enqueueV3(), but the stream is not restored after Forward().
[screenshot]

This is the proposed fix, which keeps the side stream for the CUDA graph / enqueueV3() and restores the caller's stream afterwards (see the sketch below).
[screenshot]
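
A hedged sketch of the set-and-restore pattern described above (illustrative names, not the PR's actual code); torch.cuda.stream() makes the side stream current and restores the caller's stream on exit:

    import torch

    engine_stream = torch.cuda.Stream()          # side stream for the engine
    caller_stream = torch.cuda.current_stream()  # stream to restore afterwards

    # Inside the block the side stream is current, so graph capture and
    # enqueueV3() run on it; on exit the caller's stream is restored and
    # does not leak into ops that run after Forward().
    with torch.cuda.stream(engine_stream):
        assert torch.cuda.current_stream() == engine_stream
        # ... capture the CUDA graph / call enqueueV3() here ...

    assert torch.cuda.current_stream() == caller_stream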

@@ -333,7 +331,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
      if (need_cudagraphs_record) {
        // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph
        c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream;
        compiled_engine->cudagraph.capture_begin(compiled_engine->cudagraph_mempool_id);
keehyuna (Collaborator, Author) commented on the diff:

This was reverted to fix the assert below from torch. Since we don't share memory across captures, I think we can use the internally created pool (see the sketch after the traceback).

https://pytorch.org/docs/stable/notes/cuda.html#graph-memory-management

  File "/root/trt/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 274, in forward
    outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine(
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_ops.py", line 1113, in __call__
    return self._op(*args, **(kwargs or {}))
RuntimeError: it->second->use_count > 0 INTERNAL ASSERT FAILED at "../c10/cuda/CUDACachingAllocator.cpp":2056, please report a bug to PyTorch
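
A rough Python analogue of the pooling behavior described above (illustrative, not the C++ runtime code): omitting the pool argument of torch.cuda.graph() lets each capture allocate from its own internally created pool rather than one shared across captures:

    import torch

    g = torch.cuda.CUDAGraph()
    static_input = torch.zeros(8, device="cuda")

    # No explicit pool= argument: this capture uses its own internally
    # created memory pool instead of sharing one with other captures.
    with torch.cuda.graph(g):
        static_output = static_input * 2

    static_input.fill_(3.0)
    g.replay()  # static_output now holds 6.0 in place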

@@ -464,7 +461,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         if new_shape_key != self.shape_key:
             logger.debug(f"Resetting Cudagraph on new shape key {new_shape_key}")
             self.shape_key = new_shape_key
-            self.cudagraph.reset()  # type: ignore
+            if self.cudagraph:
+                self.cudagraph.reset()
keehyuna (Collaborator, Author) commented on the diff:

self.cudagraph can be None when the torch_compile backend is used.

self.cudagraph is initialized when cudagraphs mode is enabled, but that initialization was called at compile() (a minimal sketch of the guard follows below):
https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py#L144-L145
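
A minimal, self-contained sketch of the None guard (class and method names are hypothetical):

    from typing import Optional

    import torch

    class EngineState:
        """Holds the lazily recorded CUDA graph for one engine."""

        def __init__(self) -> None:
            # Stays None until cudagraphs mode actually records a graph,
            # e.g. under the torch_compile backend before the first run.
            self.cudagraph: Optional[torch.cuda.CUDAGraph] = None

        def on_new_shape_key(self) -> None:
            # Guard the reset: calling reset() on None would raise.
            if self.cudagraph:
                self.cudagraph.reset()
            self.cudagraph = None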

narendasan (Collaborator) left a comment

Good catch, LGTM

narendasan (Collaborator)

@lanluo-nvidia please cherry-pick

lanluo-nvidia added a commit that referenced this pull request Oct 14, 2024
keehyuna merged commit 743fdbd into pytorch:main on Oct 15, 2024
75 checks passed
Labels
  • cla signed
  • component: api [Python] (Issues re: Python API)
  • component: core (Issues re: The core compiler)
  • component: dynamo (Issues relating to the `torch.compile` or `torch._dynamo.export` paths)
  • component: runtime
Development

Successfully merging this pull request may close these issues.

🐛 [Bug] Warning as default stream was used in enqueueV3()
4 participants