Skip to content

Commit

Permalink
Fix device selection using CUDA_VISIBLE_DEVICES (#6530)
Browse files Browse the repository at this point in the history
This PR addresses #5818.
Instead of contiguous numbers based on the device count, this PR uses
device indices in `--include`.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 8, 2024
1 parent f74ea69 commit 5cbbff4
Showing 1 changed file with 32 additions and 12 deletions.
44 changes: 32 additions & 12 deletions deepspeed/launcher/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from copy import deepcopy
import signal
import time
from typing import Tuple, List, Dict
from collections import defaultdict
import shlex

from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner, IMPIRunner
Expand Down Expand Up @@ -263,6 +265,31 @@ def _stable_remove_duplicates(data):
return new_list


def parse_node_config(node_config: str) -> Tuple[str, List[int]]:
    """Split one ``hostname[:slot,slot,...]`` entry into hostname and slot indices.

    Args:
        node_config: A single node entry such as ``"worker-0"`` (whole node)
            or ``"worker-1:0,2"`` (explicit slots).

    Returns:
        A ``(hostname, slots)`` tuple. ``slots`` is an empty list when the
        entry carries no slot list.

    Raises:
        ValueError: If the entry contains more than one ``:`` or any slot is
            not an integer (including an empty slot such as ``"host:"``).
    """
    SLOT_LIST_START = ':'
    SLOT_SEP = ','

    hostname, separator, slot_spec = node_config.partition(SLOT_LIST_START)

    # No ':' at all means the user selected the whole node.
    if not separator:
        return node_config, []

    # A second ':' is always a syntax error; report it with the full entry so
    # the user can find it in a long include/exclude string.
    if SLOT_LIST_START in slot_spec:
        raise ValueError(f"Invalid node config '{node_config}': expected at most one '{SLOT_LIST_START}'")

    try:
        slots = [int(x) for x in slot_spec.split(SLOT_SEP)]
    except ValueError as err:
        raise ValueError(f"Invalid slot list in node config '{node_config}': "
                         f"expected integers separated by '{SLOT_SEP}'") from err

    return hostname, slots


def parse_node_config_list(node_config_list: str) -> Dict[str, List[int]]:
    """Parse an ``@``-separated include/exclude string into per-host slot lists.

    Args:
        node_config_list: A string of node entries joined by ``@``, e.g.
            ``"worker-0@worker-1:0,2"``. Despite the parameter name this is a
            single string, not a list; it is split on ``@`` here.

    Returns:
        A dict mapping each hostname to its sorted, de-duplicated slot
        indices. Slots from repeated entries for the same host are merged.
        A host listed without slots maps to an empty list. An empty input
        string yields ``{"": []}``.

    Raises:
        ValueError: Propagated from ``parse_node_config`` for malformed entries.
    """
    NODE_SEP = '@'

    node_configs = defaultdict(list)

    for node_config in node_config_list.split(NODE_SEP):
        hostname, slots = parse_node_config(node_config)
        node_configs[hostname] += slots

    return {hostname: sorted(set(slots)) for hostname, slots in node_configs.items()}


def parse_resource_filter(host_info, include_str="", exclude_str=""):
'''Parse an inclusion or exclusion string and filter a hostfile dictionary.
Expand All @@ -277,11 +304,6 @@ def parse_resource_filter(host_info, include_str="", exclude_str=""):
slot 0 on worker-1.
'''

# Constants that define our syntax
NODE_SEP = '@'
SLOT_LIST_START = ':'
SLOT_SEP = ','

# Ensure include/exclude are mutually exclusive
if (include_str != "") and (exclude_str != ""):
raise ValueError('include_str and exclude_str are mutually exclusive.')
Expand All @@ -299,12 +321,9 @@ def parse_resource_filter(host_info, include_str="", exclude_str=""):
parse_str = exclude_str

# foreach node in the list
for node_config in parse_str.split(NODE_SEP):
for hostname, slots in parse_node_config_list(parse_str).items():
# Node can either be alone or node:slot,slot,slot
if SLOT_LIST_START in node_config:
hostname, slots = node_config.split(SLOT_LIST_START)
slots = [int(x) for x in slots.split(SLOT_SEP)]

if len(slots) > 0:
# sanity checks
if hostname not in host_info:
raise ValueError(f"Hostname '{hostname}' not found in hostfile")
Expand All @@ -322,7 +341,6 @@ def parse_resource_filter(host_info, include_str="", exclude_str=""):

# User just specified the whole node
else:
hostname = node_config
# sanity check hostname
if hostname not in host_info:
raise ValueError(f"Hostname '{hostname}' not found in hostfile")
Expand Down Expand Up @@ -355,8 +373,10 @@ def parse_resource_filter(host_info, include_str="", exclude_str=""):

def parse_inclusion_exclusion(resource_pool, inclusion, exclusion):
    """Build the per-host active slot lists and apply the include/exclude filter.

    Args:
        resource_pool: Mapping of hostname to slot count (iterated with
            ``.items()`` and passed to ``range``).
        inclusion: Include string in ``host[:slot,...]@host...`` syntax, or ``""``.
        exclusion: Exclude string in the same syntax, or ``""``.

    Returns:
        Whatever ``parse_resource_filter`` returns for the active resources.
    """
    active_resources = collections.OrderedDict()
    node_configs = parse_node_config_list(inclusion)

    for hostname, slots in resource_pool.items():
        # Use the explicitly included device indices when the user gave some.
        # A host that is absent from the include string, or listed without a
        # slot list (parsed as []), keeps the full contiguous range; otherwise
        # including a whole node would leave it with no slots at all.
        active_resources[hostname] = node_configs.get(hostname) or list(range(slots))

    return parse_resource_filter(active_resources, include_str=inclusion, exclude_str=exclusion)

Expand Down

0 comments on commit 5cbbff4

Please sign in to comment.