diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index 076f65b201b3..5f926834dda2 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -20,6 +20,8 @@ from copy import deepcopy import signal import time +from typing import Tuple, List, Dict +from collections import defaultdict import shlex from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner, IMPIRunner @@ -263,6 +265,31 @@ def _stable_remove_duplicates(data): return new_list +def parse_node_config(node_config: str) -> Tuple[str, List[int]]: + SLOT_LIST_START = ':' + SLOT_SEP = ',' + + if SLOT_LIST_START not in node_config: + return node_config, [] + + hostname, slots = node_config.split(SLOT_LIST_START) + slots = [int(x) for x in slots.split(SLOT_SEP)] + + return hostname, slots + + +def parse_node_config_list(node_config_list: List[str]) -> Dict[str, List[int]]: + NODE_SEP = '@' + + node_configs = defaultdict(list) + + for node_config in node_config_list.split(NODE_SEP): + hostname, slots = parse_node_config(node_config) + node_configs[hostname] += slots + + return {k: sorted(list(set(v))) for k, v in node_configs.items()} + + def parse_resource_filter(host_info, include_str="", exclude_str=""): '''Parse an inclusion or exclusion string and filter a hostfile dictionary. @@ -277,11 +304,6 @@ def parse_resource_filter(host_info, include_str="", exclude_str=""): slot 0 on worker-1. ''' - # Constants that define our syntax - NODE_SEP = '@' - SLOT_LIST_START = ':' - SLOT_SEP = ',' - # Ensure include/exclude are mutually exclusive if (include_str != "") and (exclude_str != ""): raise ValueError('include_str and exclude_str are mutually exclusive.') @@ -299,12 +321,9 @@ def parse_resource_filter(host_info, include_str="", exclude_str=""): parse_str = exclude_str # foreach node in the list - for node_config in parse_str.split(NODE_SEP): + for hostname, slots in parse_node_config_list(parse_str).items(): # Node can either be alone or node:slot,slot,slot - if SLOT_LIST_START in node_config: - hostname, slots = node_config.split(SLOT_LIST_START) - slots = [int(x) for x in slots.split(SLOT_SEP)] - + if len(slots) > 0: # sanity checks if hostname not in host_info: raise ValueError(f"Hostname '{hostname}' not found in hostfile") @@ -322,7 +341,6 @@ def parse_resource_filter(host_info, include_str="", exclude_str=""): # User just specified the whole node else: - hostname = node_config # sanity check hostname if hostname not in host_info: raise ValueError(f"Hostname '{hostname}' not found in hostfile") @@ -355,8 +373,10 @@ def parse_resource_filter(host_info, include_str="", exclude_str=""): def parse_inclusion_exclusion(resource_pool, inclusion, exclusion): active_resources = collections.OrderedDict() + node_configs = parse_node_config_list(inclusion) + for hostname, slots in resource_pool.items(): - active_resources[hostname] = list(range(slots)) + active_resources[hostname] = node_configs[hostname] if hostname in node_configs else list(range(slots)) return parse_resource_filter(active_resources, include_str=inclusion, exclude_str=exclusion)