Skip to content

Commit

Permalink
[CustomDevice] fix get_paddle_place
Browse files Browse the repository at this point in the history
  • Loading branch information
ronny1996 committed Jul 10, 2023
1 parent 2fc429f commit d658880
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -7623,8 +7623,15 @@ def _get_paddle_place(place):
device_id = int(device_id)
return core.IPUPlace(device_id)

place_info_list = place.split(':', 1)
device_type = place_info_list[0]
if device_type in core.get_all_custom_device_type():
device_id = place_info_list[1]
device_id = int(device_id)
return core.CustomPlace(device_type, device_id)

raise ValueError(
f"Paddle supports CPUPlace, CUDAPlace, CUDAPinnedPlace, XPUPlace and IPUPlace, but received {place}."
f"Paddle supports CPUPlace, CUDAPlace, CUDAPinnedPlace, XPUPlace, IPUPlace and CustomPlace, but received {place}."
)


Expand Down

0 comments on commit d658880

Please sign in to comment.