Skip to content

Commit

Permalink
Support UUID in CUDA_VISIBLE_DEVICES (#437)
Browse files Browse the repository at this point in the history
* Allow parsing CUDA_VISIBLE_DEVICES with UUID

Add new parse_cuda_visible_device utility function to parse UUIDs

* Add test CUDA_VISIBLE_DEVICES parsing with UUID

* Move cuda_visible_devices to utils.py

* Fix formatting

* Fix parse_cuda_visible_device doc typo

Co-authored-by: Mads R. B. Kristensen <madsbk@gmail.com>
  • Loading branch information
pentschev and madsbk authored Nov 3, 2020
1 parent b4e134d commit 3cefcc0
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 25 deletions.
2 changes: 1 addition & 1 deletion dask_cuda/cuda_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@

from .device_host_file import DeviceHostFile
from .initialize import initialize
from .local_cuda_cluster import cuda_visible_devices
from .utils import (
CPUAffinity,
RMMSetup,
cuda_visible_devices,
get_cpu_affinity,
get_device_total_memory,
get_n_gpus,
Expand Down
32 changes: 8 additions & 24 deletions dask_cuda/local_cuda_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,15 @@
from .utils import (
CPUAffinity,
RMMSetup,
cuda_visible_devices,
get_cpu_affinity,
get_device_total_memory,
get_n_gpus,
get_ucx_config,
get_ucx_net_devices,
parse_cuda_visible_device,
)


def cuda_visible_devices(i, visible=None):
"""Cycling values for CUDA_VISIBLE_DEVICES environment variable
Examples
--------
>>> cuda_visible_devices(0, range(4))
'0,1,2,3'
>>> cuda_visible_devices(3, range(8))
'3,4,5,6,7,0,1,2'
"""
if visible is None:
try:
visible = map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(","))
except KeyError:
visible = range(get_n_gpus())
visible = list(visible)

L = visible[i:] + visible[:i]
return ",".join(map(str, L))


class LocalCUDACluster(LocalCluster):
"""A variant of LocalCluster that uses one GPU per process
Expand Down Expand Up @@ -159,7 +139,9 @@ def __init__(
CUDA_VISIBLE_DEVICES = cuda_visible_devices(0)
if isinstance(CUDA_VISIBLE_DEVICES, str):
CUDA_VISIBLE_DEVICES = CUDA_VISIBLE_DEVICES.split(",")
CUDA_VISIBLE_DEVICES = list(map(int, CUDA_VISIBLE_DEVICES))
CUDA_VISIBLE_DEVICES = list(
map(parse_cuda_visible_device, CUDA_VISIBLE_DEVICES)
)
if n_workers is None:
n_workers = len(CUDA_VISIBLE_DEVICES)
self.host_memory_limit = parse_memory_limit(
Expand Down Expand Up @@ -283,7 +265,9 @@ def new_worker_spec(self):
visible_devices = cuda_visible_devices(worker_count, self.cuda_visible_devices)
spec["options"].update(
{
"env": {"CUDA_VISIBLE_DEVICES": visible_devices,},
"env": {
"CUDA_VISIBLE_DEVICES": visible_devices,
},
"plugins": {
CPUAffinity(get_cpu_affinity(worker_count)),
RMMSetup(self.rmm_pool_size, self.rmm_managed_memory),
Expand Down
38 changes: 38 additions & 0 deletions dask_cuda/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
from numba import cuda

from dask_cuda.utils import (
cuda_visible_devices,
get_cpu_affinity,
get_device_total_memory,
get_gpu_count,
get_n_gpus,
get_preload_options,
get_ucx_config,
get_ucx_net_devices,
parse_cuda_visible_device,
unpack_bitmask,
)

Expand Down Expand Up @@ -181,3 +184,38 @@ def test_get_ucx_config(enable_tcp_over_ucx, enable_infiniband, net_devices):
pass
elif net_devices == "":
assert "net-device" not in ucx_config


def test_parse_visible_devices():
pynvml = pytest.importorskip("pynvml")
pynvml.nvmlInit()
indices = []
uuids = []
for index in range(get_gpu_count()):
handle = pynvml.nvmlDeviceGetHandleByIndex(index)
uuid = pynvml.nvmlDeviceGetUUID(handle).decode("utf-8")

assert parse_cuda_visible_device(index) == index
assert parse_cuda_visible_device(uuid) == uuid

indices.append(str(index))
uuids.append(pynvml.nvmlDeviceGetUUID(handle).decode("utf-8"))

index_devices = ",".join(indices)
os.environ["CUDA_VISIBLE_DEVICES"] = index_devices
for index in range(get_gpu_count()):
visible = cuda_visible_devices(index)
assert visible.split(",")[0] == str(index)

uuid_devices = ",".join(uuids)
os.environ["CUDA_VISIBLE_DEVICES"] = uuid_devices
for index in range(get_gpu_count()):
visible = cuda_visible_devices(index)
assert visible.split(",")[0] == str(uuids[index])

with pytest.raises(ValueError):
parse_cuda_visible_device("Foo")

with pytest.raises(TypeError):
parse_cuda_visible_device(None)
parse_cuda_visible_device([])
54 changes: 54 additions & 0 deletions dask_cuda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,57 @@ def f(x):

def all_to_all(client):
return client.sync(_all_to_all, client=client, asynchronous=client.asynchronous)


def parse_cuda_visible_device(dev):
"""Parses a single CUDA device identifier
A device identifier must either be an integer, a string containing an
integer or a string containing the device's UUID, beginning with prefix
'GPU-' or 'MIG-GPU'.
>>> parse_cuda_visible_device(2)
2
>>> parse_cuda_visible_device('2')
2
>>> parse_cuda_visible_device('GPU-9baca7f5-0f2f-01ac-6b05-8da14d6e9005')
'GPU-9baca7f5-0f2f-01ac-6b05-8da14d6e9005'
>>> parse_cuda_visible_device('Foo')
Traceback (most recent call last):
...
ValueError: Devices in CUDA_VISIBLE_DEVICES must be comma-separated integers or
strings beginning with 'GPU-' or 'MIG-GPU-' prefixes.
"""
try:
return int(dev)
except ValueError:
if any(dev.startswith(prefix) for prefix in ["GPU-", "MIG-GPU-"]):
return dev
else:
raise ValueError(
"Devices in CUDA_VISIBLE_DEVICES must be comma-separated integers "
"or strings beginning with 'GPU-' or 'MIG-GPU-' prefixes."
)


def cuda_visible_devices(i, visible=None):
"""Cycling values for CUDA_VISIBLE_DEVICES environment variable
Examples
--------
>>> cuda_visible_devices(0, range(4))
'0,1,2,3'
>>> cuda_visible_devices(3, range(8))
'3,4,5,6,7,0,1,2'
"""
if visible is None:
try:
visible = map(
parse_cuda_visible_device, os.environ["CUDA_VISIBLE_DEVICES"].split(",")
)
except KeyError:
visible = range(get_n_gpus())
visible = list(visible)

L = visible[i:] + visible[:i]
return ",".join(map(str, L))

0 comments on commit 3cefcc0

Please sign in to comment.