Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cuda graph support multi-stream for new executor #51389

Conversation

pangyoki
Copy link
Contributor

@pangyoki pangyoki commented Mar 8, 2023

PR types

New features

PR changes

Others

Describe

新执行器cuda graph支持多流。

基本思想

capture cuda graph前使用create_cuda_graph_stream判断是否使用新执行器、是否使用多流(或者新执行器没有使用默认流),在多流场景下会新建一个stream专门用于capture cuda graph。
使用eventWait让这个新建的cuda graph stream与其他所有stream建立依赖关系,可让所有stream都处于被capture的状态。最终capture得到一个cuda graph。

基本实现

  • 新执行器不支持第一轮就使用capture cuda graph、不支持更新feed fetch数据。第一轮Convert阶段分析多流时,会将多流信息记录到CUDAGraphContextManager的capturing_ctxs_中。不支持memcpy d2h与h2d。
  • begin_capture前,先新建一个stream,专门用于capture cuda graph。为所有stream做cudnn的一些初始化操作。
  • begin capture后,让记录下来的所有stream全部eventWait这个新建的stream,让所有stream都处于被capture的状态。且为所有stream设置cuda graph allocator。
  • end capture时,让新建的cuda graph stream eventWait所有其他stream。清理steam。

Pcard-66979

@paddle-bot
Copy link

paddle-bot bot commented Mar 8, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@pangyoki pangyoki force-pushed the cuda_graph_support_multi_stream_for_new_exe branch from 7db9f61 to 5925e4b Compare March 8, 2023 16:07
Can only be used for new executor in static mode, that is,
FLAGS_new_executor_use_cuda_graph needs to be set to True.
The default value of create_cuda_graph_stream is False.
"""
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

添加一个判断,只有在静态图模式下、新执行器开关打开时,才能将create_cuda_graph_stream设置为True。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create_cuda_graph_stream属性已删除,在cuda_graph_with_memory_pool中自动判断是否需要生成一个新的stream来capture cuda_graph,不需要由用户指定。

From00
From00 previously approved these changes Mar 13, 2023
Copy link
Contributor

@From00 From00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

ops[op_index].dist_attr.stream_priority = -1

def run_program(self, use_cuda_graph=False, apply_custom_stream=False):
# paddle.seed(2022)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be deleted.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -55,8 +55,19 @@ def __init__(self, place=None, mode="thread_local"):
assert mode in ALL_MODES
self._mode = ALL_MODES.index(mode)

def capture_begin(self):
CoreCUDAGraph.begin_capture(self._place, self._mode)
def capture_begin(self, create_cuda_graph_stream=False):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个属性没有必要加,可以在cuda_graph_with_memory_pool里自动判断是否使用多流方案。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@pangyoki pangyoki force-pushed the cuda_graph_support_multi_stream_for_new_exe branch from 0a34652 to 6a4ab46 Compare March 13, 2023 07:29
@pangyoki pangyoki force-pushed the cuda_graph_support_multi_stream_for_new_exe branch from 6a4ab46 to 137dbd9 Compare March 13, 2023 08:55
platform::IsCUDAGraphCapturing(),
false,
platform::errors::PermissionDenied("CUDA Graph is not allowed to capture "
"when running the first batch."));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

=> CUDA Graph is not allowed to capture before prepare.
first batch is not clear and confusing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in PR #51648

if (op_type == interpreter::kMemcpyD2H ||
op_type == interpreter::kMemcpyH2D) {
PADDLE_THROW(paddle::platform::errors::Fatal(
"op_type can't be memcpy d2h or h2d while using cuda graph."));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check the dev_ctx_ is CUDAContext, and change the error msg to "Cuda Memory copy d2h/h2d is not allowed while using cuda graph".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in PR #51648

Copy link
Contributor

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, to be refined in the next pr.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants