
support CUDA Graph for new executor #49708

Merged · 8 commits · Jan 17, 2023

Conversation

@pangyoki (Contributor) commented Jan 10, 2023

PR types: New features

PR changes: Others

Describe

Support CUDA Graph in the new (standalone) executor.
Adds the FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR switch: when set to 1, the new executor is used to run CUDA graphs; the default is 0 (disabled).

  • The new executor can only use fast GC; the EventQuery call used by event GC errors out directly in the new executor.
  • All output vars of the coalesce_tensor op are added to skip_gc_vars.
  • When using CUDA Graph, the new executor cannot assign the learning rate (lr) via an h2d copy.
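The skip_gc_vars policy in the second bullet can be modeled with a short sketch (illustrative only: the dict-based op representation and the helper name are not paddle APIs):

```python
def collect_skip_gc_vars(ops):
    """Gather every output var of coalesce_tensor ops so the executor's
    GC keeps them alive, matching the policy described above.

    `ops` is a list of {"type": ..., "outputs": [...]} dicts, a stand-in
    for the ops of a real program.
    """
    skip_gc_vars = set()
    for op in ops:
        if op["type"] == "coalesce_tensor":
            skip_gc_vars.update(op["outputs"])
    return skip_gc_vars
```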

Note:
When CUDA Graph support was added for dynamic graph mode (PR #42786), cuda_graph_with_memory_pool.cc enforced that FLAGS_use_stream_safe_cuda_allocator must be false: if dynamic graph mode uses CUDA Graph together with the stream_safe_cuda_allocator, the process core dumps on exit.
In the new executor, however, FLAGS_use_stream_safe_cuda_allocator must be true, because the new executor only supports fast GC when running CUDA graphs. This conflicts with the dynamic graph requirement.

Summary:

  • Dynamic graph: FLAGS_use_stream_safe_cuda_allocator must be false.
  • Static graph with the new executor: FLAGS_use_stream_safe_cuda_allocator must be true.

Because of this conflict, the original test_cuda_graph unit test was split into two files: one for dynamic graph and one for static graph.
A test file's allocator may be initialized only once, and whether the stream safe allocator is used can only be chosen through the FLAGS_use_stream_safe_cuda_allocator environment variable at that initialization.
If a single test file contained both dynamic and static graph test code, the required FLAGS_use_stream_safe_cuda_allocator settings would conflict.
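As a rough usage sketch (assuming, as the description implies, that these flags are read from the environment at allocator/executor initialization and must therefore be set before paddle is imported):

```python
import os

# Opt in to running CUDA graphs with the new (standalone) executor;
# the default is 0 (off).
os.environ["FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR"] = "1"

# The static-graph new-executor path requires the stream safe allocator,
# which can only be chosen when the allocator is first initialized.
os.environ["FLAGS_use_stream_safe_cuda_allocator"] = "True"

# import paddle  # must happen only after the flags are set
```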

paddle-bot commented Jan 10, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

```cpp
return (memory::allocation::AllocatorFacade::Instance()
            .IsStreamSafeCUDAAllocatorUsed() &&
        FLAGS_fast_eager_deletion_mode) ||
       FLAGS_new_executor_use_cuda_graph;
```
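The condition above can be restated as a small plain-Python model (for illustration only, not paddle code):

```python
def use_fast_gc(stream_safe_allocator_used, fast_eager_deletion, new_executor_cuda_graph):
    # Fast GC is selected when the stream safe allocator and fast eager
    # deletion mode are both on, or whenever the new executor runs a CUDA
    # graph (the new executor only supports fast GC in that case).
    return (stream_safe_allocator_used and fast_eager_deletion) or new_executor_cuda_graph
```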
Review comment (Contributor):

When IsStreamSafeCUDAAllocatorUsed() == false but FLAGS_new_executor_use_cuda_graph is enabled, it is recommended to add an error check here.

@pangyoki (Contributor, author) replied Jan 16, 2023:

Done.
However, as described in the PR description, FLAGS_use_stream_safe_cuda_allocator behaves differently when dynamic vs. static graph mode uses CUDA Graph:
Dynamic graph with CUDA Graph requires FLAGS_use_stream_safe_cuda_allocator to be False.
The static graph new executor with CUDA Graph requires FLAGS_use_stream_safe_cuda_allocator to be True.
This is why the unit test was also split.
As a follow-up, the dynamic graph issue that forces FLAGS_use_stream_safe_cuda_allocator to be False needs to be located and fixed.

```cpp
    true,
    platform::errors::InvalidArgument(
        "CUDA Graph is only supported on NVIDIA GPU device."));
PADDLE_ENFORCE_EQ(FLAGS_sync_nccl_allreduce,
```
Review comment (Contributor):

It is recommended to add code comments explaining why FLAGS_sync_nccl_allreduce cannot be enabled, and why the outputs of coalesce_tensor need to be set persistable.

@pangyoki (Contributor, author) replied:

Done. These policies follow ParallelExecutor (PE) and align with its behavior.
According to the description of FLAGS_sync_nccl_allreduce, enabling it inserts a cudaStreamSynchronize after each allreduce, which CUDA Graph does not support.
If the fused output var of coalesce_tensor is garbage-collected, using CUDA Graph produces accuracy problems; the root cause still needs further analysis.

```python
    build_strategy is not None
    and build_strategy.allow_cuda_graph_capture
):
    build_strategy.allow_cuda_graph_capture = False
```
Review comment (Contributor):

This changes the value of allow_cuda_graph_capture; it is recommended to restore it to its original value after the PE construction finishes.

@pangyoki (Contributor, author) replied:

done
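The save/restore pattern the reviewer asked for can be sketched like this (BuildStrategy here is a minimal stand-in class, and construct_parallel_executor a hypothetical helper, not the real paddle API):

```python
class BuildStrategy:
    """Minimal stand-in for the real build strategy object."""
    def __init__(self):
        self.allow_cuda_graph_capture = True

def construct_parallel_executor(build_strategy):
    # Temporarily disable capture while constructing the executor, then
    # restore the caller's original value even if construction raises.
    original = build_strategy.allow_cuda_graph_capture
    build_strategy.allow_cuda_graph_capture = False
    try:
        pass  # ... executor construction would happen here ...
    finally:
        build_strategy.allow_cuda_graph_capture = original

strategy = BuildStrategy()
construct_parallel_executor(strategy)
assert strategy.allow_cuda_graph_capture is True
```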

```diff
@@ -1746,18 +1755,7 @@ def _can_use_interpreter_core(program, place):
    )
    return False

# Unsupported case 4: CUDA Graph
```
Review comment (Contributor):

The new executor's CUDA Graph support has so far only passed CI unit tests and has not been validated on models, so it should not be enabled by default. It is recommended to add a temporary switch so that CUDA Graph can manually opt in to the new executor, off by default, and to enable it only after thorough model validation.

@pangyoki (Contributor, author) replied Jan 16, 2023:

Done. Added the FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR switch: when set to 1, the new executor is used to run CUDA graphs. The default is 0.

@pangyoki force-pushed the new_exe_support_cuda_graph branch 2 times, most recently from 3e5233f to ce1e7d4 on January 15, 2023 19:40
@pangyoki force-pushed the new_exe_support_cuda_graph branch from ce1e7d4 to 595a1b0 on January 15, 2023 19:52
@pangyoki force-pushed the new_exe_support_cuda_graph branch from cf08aa7 to 5a93c60 on January 16, 2023 05:08
@pangyoki force-pushed the new_exe_support_cuda_graph branch from 9c6776c to 4e24edf on January 16, 2023 09:25
@PaddlePaddle PaddlePaddle locked and limited conversation to collaborators Jan 16, 2023
@PaddlePaddle PaddlePaddle unlocked this conversation Jan 16, 2023
@From00 (Contributor) left a comment:

LGTM

@zhiqiu (Contributor) left a comment:

LGTM

@zhwesky2010 (Contributor) left a comment:

LGTM for the change to parallel_UT_rule.py

@lanxianghit (Contributor) left a comment:

LGTM for the new flag

@pangyoki pangyoki merged commit 8e5ed04 into PaddlePaddle:develop Jan 17, 2023