-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Custom Device]add run_check support for custom device #56318
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,7 +81,7 @@ def _is_xpu_available(): | |
return False | ||
|
||
|
||
def _run_dygraph_single(use_cuda, use_xpu): | ||
def _run_dygraph_single(use_cuda, use_xpu, use_custom, custom_device_name): | ||
""" | ||
Testing the simple network in dygraph mode using one CPU/GPU/XPU. | ||
|
||
|
@@ -94,6 +94,8 @@ def _run_dygraph_single(use_cuda, use_xpu): | |
paddle.set_device('gpu') | ||
elif use_xpu: | ||
paddle.set_device('xpu') | ||
elif use_custom: | ||
paddle.set_device(custom_device_name) | ||
else: | ||
paddle.set_device('cpu') | ||
weight_attr = paddle.ParamAttr( | ||
|
@@ -116,7 +118,7 @@ def _run_dygraph_single(use_cuda, use_xpu): | |
opt.step() | ||
|
||
|
||
def _run_static_single(use_cuda, use_xpu): | ||
def _run_static_single(use_cuda, use_xpu, use_custom, custom_device_name): | ||
""" | ||
Testing the simple network with executor running directly, using one CPU/GPU/XPU. | ||
|
||
|
@@ -139,6 +141,8 @@ def _run_static_single(use_cuda, use_xpu): | |
place = paddle.CUDAPlace(0) | ||
elif use_xpu: | ||
place = paddle.XPUPlace(0) | ||
elif use_custom: | ||
place = paddle.CustomPlace(custom_device_name, 0) | ||
else: | ||
place = paddle.CPUPlace() | ||
|
||
|
@@ -229,29 +233,53 @@ def run_check(): | |
|
||
use_cuda = False | ||
use_xpu = False | ||
use_custom = False | ||
custom_device_name = None | ||
|
||
if paddle.is_compiled_with_cuda(): | ||
use_cuda = _is_cuda_available() | ||
elif paddle.is_compiled_with_xpu(): | ||
use_xpu = _is_xpu_available() | ||
elif len(paddle.framework.core.get_all_custom_device_type()) > 0: | ||
use_custom = True | ||
if len(paddle.framework.core.get_all_custom_device_type()) > 1: | ||
logging.warning( | ||
"More than one kind of custom devices detected, but run check would only be executed on {}.".format( | ||
paddle.framework.core.get_all_custom_device_type()[0] | ||
) | ||
) | ||
|
||
if use_cuda: | ||
device_str = "GPU" | ||
device_list = paddle.static.cuda_places() | ||
elif use_xpu: | ||
device_str = "XPU" | ||
device_list = paddle.static.xpu_places() | ||
elif use_custom: | ||
device_str = paddle.framework.core.get_all_custom_device_type()[0] | ||
custom_device_name = device_str | ||
device_list = list( | ||
range( | ||
paddle.framework.core.get_custom_device_count( | ||
custom_device_name | ||
) | ||
) | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这里默认只跑 device[0],判断一下如果有多个device注册,这里加点warning message提示下只对你device[0]进行检测 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. done |
||
else: | ||
device_str = "CPU" | ||
device_list = paddle.static.cpu_places(device_count=1) | ||
device_count = len(device_list) | ||
|
||
_run_static_single(use_cuda, use_xpu) | ||
_run_dygraph_single(use_cuda, use_xpu) | ||
_run_static_single(use_cuda, use_xpu, use_custom, custom_device_name) | ||
_run_dygraph_single(use_cuda, use_xpu, use_custom, custom_device_name) | ||
print(f"PaddlePaddle works well on 1 {device_str}.") | ||
|
||
try: | ||
if len(device_list) > 1: | ||
if use_custom: | ||
import os | ||
|
||
os.environ['PADDLE_DISTRI_BACKEND'] = "xccl" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 分布式通过读取环境变量 |
||
_run_parallel(device_list) | ||
print( | ||
"PaddlePaddle works well on {} {}s.".format( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,这里修改为支持所有通过custom device注册的硬件类型,不要通过字符串判断。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done