From 480d44162e90825a2998c0268bab6d30373f4790 Mon Sep 17 00:00:00 2001 From: Chen Date: Sat, 31 Aug 2024 13:21:54 -0400 Subject: [PATCH 01/13] feat: feat: Add warning for unassigned main devices --- src/accelerate/utils/modeling.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index bd562236162..c094e728e7f 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1247,6 +1247,7 @@ def infer_auto_device_map( current_memory_used = 0 device_memory_used = {} device_buffer_sizes = {} + device_minimum_assignment_memory = {} # Direct submodules and parameters modules_to_treat = ( @@ -1318,7 +1319,8 @@ def infer_auto_device_map( # -> no split, we go to the next device if verbose: print("This module cannot be split, going to the next device.") - + if current_memory_used == 0: + device_minimum_assignment_memory[device] = module_size + current_memory_reserved device_memory_used[device] = current_memory_used + current_memory_reserved current_device += 1 modules_to_treat = [(name, module)] + modules_to_treat @@ -1417,6 +1419,8 @@ def infer_auto_device_map( # If the tied module is not split, we go to the next device if verbose: print("None of the tied module can be split, going to the next device.") + if current_memory_used == 0: + device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved device_memory_used[device] = current_memory_used + current_memory_reserved current_device += 1 @@ -1465,6 +1469,13 @@ def infer_auto_device_map( f"offload_buffers=True." ) + for device, mem in device_minimum_assignment_memory.items(): + warnings.warn( + f"No modules could be assigned to {device} as the minimum memory required is {mem} " + f"for the current calculation, which is higher than the available memory {max_memory[device]}." + f"Consider increasing the memory available." + ) + return device_map From 4f89c26373335a529ae6690b3d1f26ff3d75dd87 Mon Sep 17 00:00:00 2001 From: Chen Date: Mon, 2 Sep 2024 10:02:00 -0400 Subject: [PATCH 02/13] refactor: Improve warning for unassigned main devices --- src/accelerate/utils/modeling.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index c094e728e7f..b4ccbe425f1 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1469,13 +1469,17 @@ def infer_auto_device_map( f"offload_buffers=True." ) - for device, mem in device_minimum_assignment_memory.items(): + if device_minimum_assignment_memory: + devices_info = "\n".join( + f" - {device}: {mem} bytes required" for device, mem in device_minimum_assignment_memory.items() + ) warnings.warn( - f"No modules could be assigned to {device} as the minimum memory required is {mem} " - f"for the current calculation, which is higher than the available memory {max_memory[device]}." - f"Consider increasing the memory available." + f"Based on the current allocation process, no modules could be assigned to the following devices due to" + f"insufficient memory:\n" + f"{devices_info}\n" + f"These minimum requirements are specific to this allocation attempt and may vary. Consider increasing" + f"the available memory for these devices to at least the specified minimum, or adjusting the model config." ) - return device_map From aefd770fa1e270caecea0dae908016e4fee882e6 Mon Sep 17 00:00:00 2001 From: Chen Date: Fri, 6 Sep 2024 11:14:49 -0400 Subject: [PATCH 03/13] feat: impl fallback_allocate; fix output format --- src/accelerate/utils/modeling.py | 107 ++++++++++++++++++++++++++++++- 1 file changed, 105 insertions(+), 2 deletions(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index b4ccbe425f1..1b6dcd43bb0 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1165,6 +1165,109 @@ def calculate_maximum_sizes(model: torch.nn.Module): return total_size, largest_layer +def fallback_allocate( + modules: List[Tuple[str, nn.Module]], + module_sizes: Dict[str, int], + size_limit: Union[int, str], + no_split_module_classes: Optional[List[str]] = None, + tied_parameters: Optional[List[List[str]]] = None, +) -> Tuple[Optional[str], Optional[nn.Module], List[Tuple[str, nn.Module]]]: + """ + Find a module that fits in the size limit using BFS and return it with its name and the remaining modules. + + Args: + modules (`List[Tuple[str, nn.Module]]`): + The list of named modules to search in. + module_sizes (`Dict[str, int]`): + A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`). + size_limit (`Union[int, str]`): + The maximum size a module can have. + no_split_module_classes (`Optional[List[str]]`, *optional*): + A list of class names for layers we don't want to be split. + tied_parameters (`Optional[List[List[str]]`, *optional*): + A list of lists of parameter names being all tied together. + + Returns: + `Tuple[Optional[str], Optional[nn.Module], List[Tuple[str, nn.Module]]]`: + The name of the module that fits in the + """ + size_limit = convert_file_size_to_int(size_limit) + + if no_split_module_classes is None: + no_split_module_classes = [] + + if tied_parameters is None: + tied_parameters = [] + + modules_to_search = modules.copy() + module_found = False + + while modules_to_search: + name, module = modules_to_search.pop(0) + if module_sizes[name] <= size_limit: + tied_param_groups = [ + tied_group + for tied_group in tied_parameters + if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group) + ] + + tied_params = sum( + [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], [] + ) + + if not tied_params: + module_found = True + break + + tied_module_names = [] + for tied_param in tied_params: + tied_module_indices = [i for i, (n, _) in enumerate(modules_to_search) if n in tied_param] + if tied_module_indices: + tied_module_names.append(modules_to_search[tied_module_indices[0]][0]) + + module_size_with_ties = module_sizes[name] + for tied_module_name in tied_module_names: + module_size_with_ties += module_sizes[tied_module_name] + + if module_size_with_ties <= size_limit: + module_found = True + break + + if not isinstance(module, nn.Module) or module.__class__.__name__ in no_split_module_classes: + continue + + modules_children = list(module.named_children()) + if not modules_children: + continue + + modules_to_search = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_search + + if not module_found: + return None, None, modules + + # Prepare the module list for removal of the found module + current_names = [n for n, _ in modules] + dot_idx = [i for i, c in enumerate(name) if c == '.'] + + for dot_index in dot_idx[:-1]: + parent_name = name[:dot_index] + if parent_name in current_names: + parent_module_idx = current_names.index(parent_name) + _, parent_module = modules[parent_module_idx] + module_children = (list(parent_module.named_parameters(recurse=False)) + + list(parent_module.named_children())) + modules = (modules[:parent_module_idx] + + [(f"{parent_name}.{n}", v) for n, v in module_children] + + modules[parent_module_idx + 1:]) + current_names = [n for n, _ in modules] + + # Now the target module should be directly in the list + target_idx = current_names.index(name) + name, module = modules.pop(target_idx) + + return name, module, modules + + def infer_auto_device_map( model: nn.Module, max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, @@ -1474,10 +1577,10 @@ def infer_auto_device_map( f" - {device}: {mem} bytes required" for device, mem in device_minimum_assignment_memory.items() ) warnings.warn( - f"Based on the current allocation process, no modules could be assigned to the following devices due to" + f"Based on the current allocation process, no modules could be assigned to the following devices due to " f"insufficient memory:\n" f"{devices_info}\n" - f"These minimum requirements are specific to this allocation attempt and may vary. Consider increasing" + f"These minimum requirements are specific to this allocation attempt and may vary. Consider increasing " f"the available memory for these devices to at least the specified minimum, or adjusting the model config." ) return device_map From dc6641c5855c54a157616e2c128da18d4881afb1 Mon Sep 17 00:00:00 2001 From: Chen Date: Sat, 14 Sep 2024 23:51:41 -0400 Subject: [PATCH 04/13] fix: include last dot index in the iteration --- src/accelerate/utils/modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 1b6dcd43bb0..e1e61382955 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1249,7 +1249,7 @@ def fallback_allocate( current_names = [n for n, _ in modules] dot_idx = [i for i, c in enumerate(name) if c == '.'] - for dot_index in dot_idx[:-1]: + for dot_index in dot_idx: parent_name = name[:dot_index] if parent_name in current_names: parent_module_idx = current_names.index(parent_name) From d607bfb530517478b90aa89c2a87a03c318a2e58 Mon Sep 17 00:00:00 2001 From: Chen Date: Wed, 25 Sep 2024 23:01:06 -0400 Subject: [PATCH 05/13] feat: incorporate fallback allocation into infer_auto_device_map --- src/accelerate/utils/modeling.py | 39 ++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index e1e61382955..99d969dd136 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1277,6 +1277,7 @@ def infer_auto_device_map( verbose: bool = False, clean_result: bool = True, offload_buffers: bool = False, + fallback_allocation: bool = False, ): """ Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk, @@ -1351,6 +1352,7 @@ def infer_auto_device_map( device_memory_used = {} device_buffer_sizes = {} device_minimum_assignment_memory = {} + fallback_attempted = False # Direct submodules and parameters modules_to_treat = ( @@ -1422,8 +1424,27 @@ def infer_auto_device_map( # -> no split, we go to the next device if verbose: print("This module cannot be split, going to the next device.") + + if fallback_allocation and devices[current_device] in main_devices and \ + current_memory_used == 0 and not fallback_attempted: + fallback_module_name, fallback_module, remaining_modules = fallback_allocate( + modules_to_treat, + module_sizes, + current_max_size - current_memory_used, + no_split_module_classes, + tied_parameters, + ) + + # use the next iteration to put the fallback module on the next device to avoid code duplication + if fallback_module is not None: + modules_to_treat = [(fallback_module_name, fallback_module)]\ + + [(name, module)]\ + + remaining_modules + continue + if current_memory_used == 0: device_minimum_assignment_memory[device] = module_size + current_memory_reserved + device_memory_used[device] = current_memory_used + current_memory_reserved current_device += 1 modules_to_treat = [(name, module)] + modules_to_treat @@ -1522,6 +1543,24 @@ def infer_auto_device_map( # If the tied module is not split, we go to the next device if verbose: print("None of the tied module can be split, going to the next device.") + + if fallback_allocation and devices[current_device] in main_devices and \ + current_memory_used == 0 and not fallback_attempted: + fallback_module_name, fallback_module, remaining_modules = fallback_allocate( + modules_to_treat, + module_sizes, + current_max_size - current_memory_used, + no_split_module_classes, + tied_parameters, + ) + + # use the next iteration to put the fallback module on the next device to avoid code duplication + if fallback_module is not None: + modules_to_treat = [(fallback_module_name, fallback_module)] \ + + [(name, module)] \ + + remaining_modules + continue + if current_memory_used == 0: device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved From f040302700152f3054f315865a0bd42ff55aa44e Mon Sep 17 00:00:00 2001 From: Chen Date: Sun, 6 Oct 2024 17:48:56 -0400 Subject: [PATCH 06/13] Revert "feat: incorporate fallback allocation into infer_auto_device_map" This reverts commit d607bfb530517478b90aa89c2a87a03c318a2e58. --- src/accelerate/utils/modeling.py | 39 -------------------------------- 1 file changed, 39 deletions(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 99d969dd136..e1e61382955 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1277,7 +1277,6 @@ def infer_auto_device_map( verbose: bool = False, clean_result: bool = True, offload_buffers: bool = False, - fallback_allocation: bool = False, ): """ Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk, @@ -1352,7 +1351,6 @@ def infer_auto_device_map( device_memory_used = {} device_buffer_sizes = {} device_minimum_assignment_memory = {} - fallback_attempted = False # Direct submodules and parameters modules_to_treat = ( @@ -1424,27 +1422,8 @@ def infer_auto_device_map( # -> no split, we go to the next device if verbose: print("This module cannot be split, going to the next device.") - - if fallback_allocation and devices[current_device] in main_devices and \ - current_memory_used == 0 and not fallback_attempted: - fallback_module_name, fallback_module, remaining_modules = fallback_allocate( - modules_to_treat, - module_sizes, - current_max_size - current_memory_used, - no_split_module_classes, - tied_parameters, - ) - - # use the next iteration to put the fallback module on the next device to avoid code duplication - if fallback_module is not None: - modules_to_treat = [(fallback_module_name, fallback_module)]\ - + [(name, module)]\ - + remaining_modules - continue - if current_memory_used == 0: device_minimum_assignment_memory[device] = module_size + current_memory_reserved - device_memory_used[device] = current_memory_used + current_memory_reserved current_device += 1 modules_to_treat = [(name, module)] + modules_to_treat @@ -1543,24 +1522,6 @@ def infer_auto_device_map( # If the tied module is not split, we go to the next device if verbose: print("None of the tied module can be split, going to the next device.") - - if fallback_allocation and devices[current_device] in main_devices and \ - current_memory_used == 0 and not fallback_attempted: - fallback_module_name, fallback_module, remaining_modules = fallback_allocate( - modules_to_treat, - module_sizes, - current_max_size - current_memory_used, - no_split_module_classes, - tied_parameters, - ) - - # use the next iteration to put the fallback module on the next device to avoid code duplication - if fallback_module is not None: - modules_to_treat = [(fallback_module_name, fallback_module)] \ - + [(name, module)] \ - + remaining_modules - continue - if current_memory_used == 0: device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved From 9db401e7b9efa0fdf7d9ac603eefbcc34e76e517 Mon Sep 17 00:00:00 2001 From: Chen Date: Tue, 8 Oct 2024 22:37:54 -0400 Subject: [PATCH 07/13] refactor: add helper functions and eliminate redundant variables The fallback allocation will be reintroduced once the branching logic is fully refactored. This commit prepares the function infer_auto_device_map for further refactoring. --- src/accelerate/utils/modeling.py | 167 +++++++++++++++++++------------ 1 file changed, 105 insertions(+), 62 deletions(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index e1e61382955..8b863221972 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1268,6 +1268,80 @@ def fallback_allocate( return name, module, modules +def init_infer_auto_device_map( + model: nn.Module, + max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, + no_split_module_classes: Optional[List[str]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, +) -> Tuple: + max_memory = get_max_memory(max_memory) + if no_split_module_classes is None: + no_split_module_classes = [] + elif not isinstance(no_split_module_classes, (list, tuple)): + no_split_module_classes = [no_split_module_classes] + + devices = list(max_memory.keys()) + if "disk" not in devices: + devices.append("disk") + gpus = [device for device in devices if device not in ["cpu", "disk"]] + + # Devices that need to keep space for a potential offloaded layer. + if "mps" in gpus: + main_devices = ["mps"] + elif len(gpus) > 0: + main_devices = [gpus[0], "cpu"] + else: + main_devices = ["cpu"] + + module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes) + tied_parameters = find_tied_parameters(model) + + if check_tied_parameters_in_config(model) and len(tied_parameters) == 0: + logger.warn( + "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function." + ) + + # Direct submodules and parameters + modules_to_treat = ( + list(model.named_parameters(recurse=False)) + + list(model.named_children()) + + list(model.named_buffers(recurse=False)) + ) + + return ( + devices, + main_devices, + gpus, + module_sizes, + tied_parameters, + no_split_module_classes, + modules_to_treat, + ) + + +def get_module_size_with_ties( + tied_params, + module_size, + module_sizes, + modules_to_treat, +) -> int: + if not tied_params: + return module_size, [], [] + tied_module_names = [] + tied_modules = [] + for tied_param in tied_params: + tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0] + tied_module_names.append(modules_to_treat[tied_module_index][0]) + tied_modules.append(modules_to_treat[tied_module_index][1]) + + module_size_with_ties = module_size + for tied_param, tied_module_name in zip(tied_params, tied_module_names): + module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param] + + return module_size_with_ties, tied_module_names, tied_modules + + def infer_auto_device_map( model: nn.Module, max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, @@ -1317,47 +1391,24 @@ def infer_auto_device_map( In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as well as the parameters. """ - # Get default / clean up max_memory - max_memory = get_max_memory(max_memory) - if no_split_module_classes is None: - no_split_module_classes = [] - elif not isinstance(no_split_module_classes, (list, tuple)): - no_split_module_classes = [no_split_module_classes] - - devices = list(max_memory.keys()) - if "disk" not in devices: - devices.append("disk") - gpus = [device for device in devices if device not in ["cpu", "disk"]] - - # Devices that need to keep space for a potential offloaded layer. - if "mps" in gpus: - main_devices = ["mps"] - elif len(gpus) > 0: - main_devices = [gpus[0], "cpu"] - else: - main_devices = ["cpu"] - - module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes) - tied_parameters = find_tied_parameters(model) - if check_tied_parameters_in_config(model) and len(tied_parameters) == 0: - logger.warn( - "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function." - ) + # Initialize the variables + ( + devices, + main_devices, + gpus, + module_sizes, + tied_parameters, + no_split_module_classes, + modules_to_treat, + ) = init_infer_auto_device_map(model, max_memory, no_split_module_classes, dtype, special_dtypes) device_map = OrderedDict() current_device = 0 - current_memory_used = 0 - device_memory_used = {} + device_memory_used = {device: 0 for device in devices} device_buffer_sizes = {} device_minimum_assignment_memory = {} - # Direct submodules and parameters - modules_to_treat = ( - list(model.named_parameters(recurse=False)) - + list(model.named_children()) - + list(model.named_buffers(recurse=False)) - ) # Initialize maximum largest layer, to know which space to keep in memory max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes) @@ -1381,18 +1432,18 @@ def infer_auto_device_map( # and the other is not. # Note: If we are currently processing the name `compute.weight`, an other parameter named e.g. `compute.weight_submodule.parameter` # needs to be considered outside the current module, hence the check with additional dots. - tied_param_goups = [ + tied_param_groups = [ tied_group for tied_group in tied_parameters if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group) ] - if verbose and len(tied_param_goups) > 0: - print(f" Found the relevant tied param groups {tied_param_goups}") + if verbose and len(tied_param_groups) > 0: + print(f" Found the relevant tied param groups {tied_param_groups}") # Then we keep track of all the parameters that are tied to the current module, but not in the current module tied_params = sum( - [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_goups], [] + [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], [] ) if verbose and len(tied_params) > 0: @@ -1405,8 +1456,9 @@ def infer_auto_device_map( if devices[current_device] in main_devices: current_max_size = current_max_size - max_layer_size current_memory_reserved = max_layer_size + # Case 1 -> We're too big! - if current_max_size is not None and current_memory_used + module_size > current_max_size: + if current_max_size is not None and device_memory_used[device] + module_size > current_max_size: # Split or not split? modules_children = ( [] @@ -1416,18 +1468,17 @@ def infer_auto_device_map( if verbose: print( f"Not enough space on {devices[current_device]} to put {name} (space available " - f"{current_max_size - current_memory_used}, module size {module_size})." + f"{current_max_size - device_memory_used[device]}, module size {module_size})." ) if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes: # -> no split, we go to the next device if verbose: print("This module cannot be split, going to the next device.") - if current_memory_used == 0: + if device_memory_used[device] == 0: device_minimum_assignment_memory[device] = module_size + current_memory_reserved - device_memory_used[device] = current_memory_used + current_memory_reserved + device_memory_used[device] = device_memory_used[device] + current_memory_reserved current_device += 1 modules_to_treat = [(name, module)] + modules_to_treat - current_memory_used = 0 else: # -> split, we replace the module studied by its children + parameters if verbose: @@ -1444,12 +1495,7 @@ def infer_auto_device_map( # Case 2, it fits! We're not entirely out of the wood though, because we may have some tied parameters. elif len(tied_params) > 0: # First locate all tied modules - tied_module_names = [] - tied_modules = [] - for tied_param in tied_params: - tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0] - tied_module_names.append(modules_to_treat[tied_module_index][0]) - tied_modules.append(modules_to_treat[tied_module_index][1]) + module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties(tied_params, module_size, module_sizes, modules_to_treat) if verbose: print( f" It looks like {name} is going to fit on {devices[current_device]} but we have tied " @@ -1457,15 +1503,11 @@ def infer_auto_device_map( ) # Let's see if it all fits first - module_size_with_ties = module_size - for tied_param, tied_module_name in zip(tied_params, tied_module_names): - module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param] - - if current_max_size is None or current_memory_used + module_size_with_ties <= current_max_size: + if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size: # We really really fit! if verbose: print(f"Putting {name} and {tied_module_names} on {devices[current_device]}.") - current_memory_used += module_size_with_ties + device_memory_used[device] += module_size_with_ties device_map[name] = devices[current_device] for tied_module_name in tied_module_names: if tied_module_name in [m[0] for m in modules_to_treat]: @@ -1488,7 +1530,7 @@ def infer_auto_device_map( if verbose: print( f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space " - f"available {current_max_size - current_memory_used}, needed size {module_size_with_ties})." + f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})." ) split_happened = False for tied_module_name, tied_module in zip(tied_module_names, tied_modules): @@ -1522,13 +1564,13 @@ def infer_auto_device_map( # If the tied module is not split, we go to the next device if verbose: print("None of the tied module can be split, going to the next device.") - if current_memory_used == 0: + if device_memory_used[device] == 0: device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved - device_memory_used[device] = current_memory_used + current_memory_reserved + device_memory_used[device] = device_memory_used[device] + current_memory_reserved current_device += 1 modules_to_treat = [(name, module)] + modules_to_treat - current_memory_used = 0 + device_memory_used[device] = 0 else: if verbose: @@ -1537,10 +1579,9 @@ def infer_auto_device_map( else: print( f"Putting {name} (size={module_size}) on {devices[current_device]} " - f"(available={current_max_size - current_memory_used})." + f"(available={current_max_size - device_memory_used[device]})." ) - current_memory_used += module_size - device_memory_used[device] = current_memory_used + current_memory_reserved + device_memory_used[device] += module_size device_map[name] = devices[current_device] if not offload_buffers and isinstance(module, nn.Module): @@ -1549,6 +1590,8 @@ def infer_auto_device_map( ) device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size + device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0} + if clean_result: device_map = clean_device_map(device_map) From 0b52bdcf70d586082b1511322506b94e9ba6b3d4 Mon Sep 17 00:00:00 2001 From: Chen Date: Sat, 12 Oct 2024 00:16:24 -0400 Subject: [PATCH 08/13] refactor: simplify allocation logic by removing duplicates and reducing nesting --- src/accelerate/utils/modeling.py | 201 +++++++++++++++---------------- 1 file changed, 100 insertions(+), 101 deletions(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 8b863221972..222b5c758ca 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1326,7 +1326,7 @@ def get_module_size_with_ties( module_sizes, modules_to_treat, ) -> int: - if not tied_params: + if len(tied_params) < 1: return module_size, [], [] tied_module_names = [] tied_modules = [] @@ -1457,8 +1457,105 @@ def infer_auto_device_map( current_max_size = current_max_size - max_layer_size current_memory_reserved = max_layer_size - # Case 1 -> We're too big! - if current_max_size is not None and device_memory_used[device] + module_size > current_max_size: + ( + module_size_with_ties, + tied_module_names, + tied_modules + ) = get_module_size_with_ties(tied_params, module_size, module_sizes, modules_to_treat) + + # the module and its tied modules fit on the current device + if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size: + if verbose: + output = f"Putting {name}" + + if tied_module_names: + output += f" and {tied_module_names}" + else: + output += f" (size={module_size})" + + if current_max_size is not None: + output += f" (available={current_max_size - device_memory_used[device]})" + + output += f" on {device}." + print(output) + + device_memory_used[device] += module_size_with_ties + + # Assign the primary module to the device + device_map[name] = device + + # Assign tied modules if any + for tied_module_name in tied_module_names: + if tied_module_name in [m[0] for m in modules_to_treat]: + # Find the index of the tied module in the list + tied_module_index = next( + i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name + ) + # Remove the tied module from the list to prevent reprocessing + modules_to_treat.pop(tied_module_index) + + # Assign the tied module to the device + device_map[tied_module_name] = device + + # Buffer Handling + if not offload_buffers and isinstance(module, nn.Module): + # Compute the total buffer size for the module + current_buffer_size = compute_module_total_buffer_size( + module, dtype=dtype, special_dtypes=special_dtypes + ) + # Update the buffer size on the device + device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size + + # if the current module itself fits, we try to split the tied modules + elif len(tied_params) > 0 and device_memory_used[device] + module_size <= current_max_size: + # can we split one of the tied modules to make it smaller or do we need to go on the next device? + if verbose: + print( + f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space " + f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})." + ) + split_happened = False + for tied_module_name, tied_module in zip(tied_module_names, tied_modules): + tied_module_children = list(tied_module.named_children()) + if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes: + # can't break this one. + continue + + if verbose: + print(f"Splitting {tied_module_name}.") + tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children + tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children] + tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0] + + modules_to_treat = ( + [(name, module)] + + modules_to_treat[:tied_module_index] + + tied_module_children + + modules_to_treat[tied_module_index + 1:] + ) + # Update the max layer size. + max_layer_size, max_layer_names = get_max_layer_size( + [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], + module_sizes, + no_split_module_classes, + ) + split_happened = True + break + + if not split_happened: + # If the tied module is not split, we go to the next device + if verbose: + print("None of the tied module can be split, going to the next device.") + if device_memory_used[device] == 0: + device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved + + device_memory_used[device] = device_memory_used[device] + current_memory_reserved + current_device += 1 + modules_to_treat = [(name, module)] + modules_to_treat + device_memory_used[device] = 0 + + # the current module itself doesn't fit, so we have to split it or go to the next device + else: # Split or not split? modules_children = ( [] @@ -1492,104 +1589,6 @@ def infer_auto_device_map( no_split_module_classes, ) - # Case 2, it fits! We're not entirely out of the wood though, because we may have some tied parameters. - elif len(tied_params) > 0: - # First locate all tied modules - module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties(tied_params, module_size, module_sizes, modules_to_treat) - if verbose: - print( - f" It looks like {name} is going to fit on {devices[current_device]} but we have tied " - f"parameters to account for.\n - Names {tied_params}\n - Module names {tied_module_names}" - ) - - # Let's see if it all fits first - if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size: - # We really really fit! - if verbose: - print(f"Putting {name} and {tied_module_names} on {devices[current_device]}.") - device_memory_used[device] += module_size_with_ties - device_map[name] = devices[current_device] - for tied_module_name in tied_module_names: - if tied_module_name in [m[0] for m in modules_to_treat]: - # The module may have been removed by a previous iteration of this loop. - tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][ - 0 - ] - modules_to_treat.pop(tied_module_index) - device_map[tied_module_name] = devices[current_device] - - if not offload_buffers and isinstance(module, nn.Module): - current_buffer_size = compute_module_total_buffer_size( - module, dtype=dtype, special_dtypes=special_dtypes - ) - device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size - - else: - # We don't fit with the tied modules. Next question is: can we split one of the tied modules to make it - # smaller or do we need to go on the next device? - if verbose: - print( - f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space " - f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})." - ) - split_happened = False - for tied_module_name, tied_module in zip(tied_module_names, tied_modules): - tied_module_children = list(tied_module.named_children()) - if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes: - # can't break this one. - continue - - if verbose: - print(f"Splitting {tied_module_name}.") - tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children - tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children] - tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0] - - modules_to_treat = ( - [(name, module)] - + modules_to_treat[:tied_module_index] - + tied_module_children - + modules_to_treat[tied_module_index + 1 :] - ) - # Update the max layer size. - max_layer_size, max_layer_names = get_max_layer_size( - [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], - module_sizes, - no_split_module_classes, - ) - split_happened = True - break - - if not split_happened: - # If the tied module is not split, we go to the next device - if verbose: - print("None of the tied module can be split, going to the next device.") - if device_memory_used[device] == 0: - device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved - - device_memory_used[device] = device_memory_used[device] + current_memory_reserved - current_device += 1 - modules_to_treat = [(name, module)] + modules_to_treat - device_memory_used[device] = 0 - - else: - if verbose: - if current_max_size is None: - print(f"Putting {name} (size={module_size}) on {devices[current_device]}.") - else: - print( - f"Putting {name} (size={module_size}) on {devices[current_device]} " - f"(available={current_max_size - device_memory_used[device]})." - ) - device_memory_used[device] += module_size - device_map[name] = devices[current_device] - - if not offload_buffers and isinstance(module, nn.Module): - current_buffer_size = compute_module_total_buffer_size( - module, dtype=dtype, special_dtypes=special_dtypes - ) - device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size - device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0} if clean_result: From 7ff9f2679cbbd30b962e42e5bb0fdc2265ba4d68 Mon Sep 17 00:00:00 2001 From: Chen Date: Sun, 13 Oct 2024 17:07:43 -0400 Subject: [PATCH 09/13] feat: incorporate fallback allocation into infer_auto_device_map Implemented fallback allocation to allow modules to be allocated to devices using BFS when regular allocation fails. This enhancement improves the allocation process by ensuring that at least one module is assigned to the device, even under tight memory constraints. --- src/accelerate/utils/modeling.py | 271 ++++++++++++++++++------------- tests/test_modeling_utils.py | 135 +++++++++++++++ 2 files changed, 289 insertions(+), 117 deletions(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 222b5c758ca..487a35d1a1d 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1165,6 +1165,97 @@ def calculate_maximum_sizes(model: torch.nn.Module): return total_size, largest_layer +def _init_infer_auto_device_map( + model: nn.Module, + max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, + no_split_module_classes: Optional[List[str]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, +) -> Tuple: + """ + Initialize variables required for computing the device map for model allocation. + """ + max_memory = get_max_memory(max_memory) + if no_split_module_classes is None: + no_split_module_classes = [] + elif not isinstance(no_split_module_classes, (list, tuple)): + no_split_module_classes = [no_split_module_classes] + + devices = list(max_memory.keys()) + if "disk" not in devices: + devices.append("disk") + gpus = [device for device in devices if device not in ["cpu", "disk"]] + + # Devices that need to keep space for a potential offloaded layer. + if "mps" in gpus: + main_devices = ["mps"] + elif len(gpus) > 0: + main_devices = [gpus[0], "cpu"] + else: + main_devices = ["cpu"] + + module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes) + tied_parameters = find_tied_parameters(model) + + if check_tied_parameters_in_config(model) and len(tied_parameters) == 0: + logger.warn( + "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function." + ) + + # Direct submodules and parameters + modules_to_treat = ( + list(model.named_parameters(recurse=False)) + + list(model.named_children()) + + list(model.named_buffers(recurse=False)) + ) + + return ( + devices, + main_devices, + gpus, + module_sizes, + tied_parameters, + no_split_module_classes, + modules_to_treat, + ) + + +def get_module_size_with_ties( + tied_params, + module_size, + module_sizes, + modules_to_treat, +) -> Tuple[int, List[str], List[nn.Module]]: + """ + Calculate the total size of a module, including its tied parameters. + + Args: + tied_params (`List[str]`): The list of tied parameters. + module_size (`int`): The size of the module without tied parameters. + module_sizes (`Dict[str, int]`): A dictionary mapping each layer name to its size. + modules_to_treat (`List[Tuple[str, nn.Module]]`): The list of named modules to treat. + + Returns: + `Tuple[int, List[str], List[nn.Module]]`: The total size of the module, the names of the tied modules, and the + tied modules. + """ + if len(tied_params) < 1: + return module_size, [], [] + tied_module_names = [] + tied_modules = [] + + for tied_param in tied_params: + tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0] + tied_module_names.append(modules_to_treat[tied_module_index][0]) + tied_modules.append(modules_to_treat[tied_module_index][1]) + + module_size_with_ties = module_size + for tied_param, tied_module_name in zip(tied_params, tied_module_names): + module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param] + + return module_size_with_ties, tied_module_names, tied_modules + + def fallback_allocate( modules: List[Tuple[str, nn.Module]], module_sizes: Dict[str, int], @@ -1189,9 +1280,15 @@ def fallback_allocate( Returns: `Tuple[Optional[str], Optional[nn.Module], List[Tuple[str, nn.Module]]]`: - The name of the module that fits in the + A tuple containing: + - The name of the module that fits within the size limit. + - The module itself. + - The list of remaining modules after the found module is removed. """ - size_limit = convert_file_size_to_int(size_limit) + try: + size_limit = convert_file_size_to_int(size_limit) + except ValueError: + return None, None, modules if no_split_module_classes is None: no_split_module_classes = [] @@ -1215,19 +1312,9 @@ def fallback_allocate( [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], [] ) - if not tied_params: - module_found = True - break - - tied_module_names = [] - for tied_param in tied_params: - tied_module_indices = [i for i, (n, _) in enumerate(modules_to_search) if n in tied_param] - if tied_module_indices: - tied_module_names.append(modules_to_search[tied_module_indices[0]][0]) - - module_size_with_ties = module_sizes[name] - for tied_module_name in tied_module_names: - module_size_with_ties += module_sizes[tied_module_name] + module_size_with_ties, _, _ = get_module_size_with_ties( + tied_params, module_sizes[name], module_sizes, modules_to_search + ) if module_size_with_ties <= size_limit: module_found = True @@ -1247,7 +1334,7 @@ def fallback_allocate( # Prepare the module list for removal of the found module current_names = [n for n, _ in modules] - dot_idx = [i for i, c in enumerate(name) if c == '.'] + dot_idx = [i for i, c in enumerate(name) if c == "."] for dot_index in dot_idx: parent_name = name[:dot_index] @@ -1268,80 +1355,6 @@ def fallback_allocate( return name, module, modules -def init_infer_auto_device_map( - model: nn.Module, - max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, - no_split_module_classes: Optional[List[str]] = None, - dtype: Optional[Union[str, torch.dtype]] = None, - special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, -) -> Tuple: - max_memory = get_max_memory(max_memory) - if no_split_module_classes is None: - no_split_module_classes = [] - elif not isinstance(no_split_module_classes, (list, tuple)): - no_split_module_classes = [no_split_module_classes] - - devices = list(max_memory.keys()) - if "disk" not in devices: - devices.append("disk") - gpus = [device for device in devices if device not in ["cpu", "disk"]] - - # Devices that need to keep space for a potential offloaded layer. - if "mps" in gpus: - main_devices = ["mps"] - elif len(gpus) > 0: - main_devices = [gpus[0], "cpu"] - else: - main_devices = ["cpu"] - - module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes) - tied_parameters = find_tied_parameters(model) - - if check_tied_parameters_in_config(model) and len(tied_parameters) == 0: - logger.warn( - "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function." - ) - - # Direct submodules and parameters - modules_to_treat = ( - list(model.named_parameters(recurse=False)) - + list(model.named_children()) - + list(model.named_buffers(recurse=False)) - ) - - return ( - devices, - main_devices, - gpus, - module_sizes, - tied_parameters, - no_split_module_classes, - modules_to_treat, - ) - - -def get_module_size_with_ties( - tied_params, - module_size, - module_sizes, - modules_to_treat, -) -> int: - if len(tied_params) < 1: - return module_size, [], [] - tied_module_names = [] - tied_modules = [] - for tied_param in tied_params: - tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0] - tied_module_names.append(modules_to_treat[tied_module_index][0]) - tied_modules.append(modules_to_treat[tied_module_index][1]) - - module_size_with_ties = module_size - for tied_param, tied_module_name in zip(tied_params, tied_module_names): - module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param] - - return module_size_with_ties, tied_module_names, tied_modules - - def infer_auto_device_map( model: nn.Module, max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, @@ -1351,6 +1364,7 @@ def infer_auto_device_map( verbose: bool = False, clean_result: bool = True, offload_buffers: bool = False, + fallback_allocation: bool = False, ): """ Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk, @@ -1390,6 +1404,8 @@ def infer_auto_device_map( offload_buffers (`bool`, *optional*, defaults to `False`): In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as well as the parameters. + fallback_allocation (`bool`, *optional*, defaults to `False`): + When regular allocation fails, try to allocate a module that fits in the size limit using BFS. """ # Initialize the variables @@ -1401,7 +1417,7 @@ def infer_auto_device_map( tied_parameters, no_split_module_classes, modules_to_treat, - ) = init_infer_auto_device_map(model, max_memory, no_split_module_classes, dtype, special_dtypes) + ) = _init_infer_auto_device_map(model, max_memory, no_split_module_classes, dtype, special_dtypes) device_map = OrderedDict() current_device = 0 @@ -1430,7 +1446,8 @@ def infer_auto_device_map( # We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module # and the other is not. - # Note: If we are currently processing the name `compute.weight`, an other parameter named e.g. `compute.weight_submodule.parameter` + # Note: If we are currently processing the name `compute.weight`, an other parameter named + # e.g. `compute.weight_submodule.parameter` # needs to be considered outside the current module, hence the check with additional dots. tied_param_groups = [ tied_group @@ -1457,13 +1474,11 @@ def infer_auto_device_map( current_max_size = current_max_size - max_layer_size current_memory_reserved = max_layer_size - ( - module_size_with_ties, - tied_module_names, - tied_modules - ) = get_module_size_with_ties(tied_params, module_size, module_sizes, modules_to_treat) + module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties( + tied_params, module_size, module_sizes, modules_to_treat + ) - # the module and its tied modules fit on the current device + # The module and its tied modules fit on the current device. if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size: if verbose: output = f"Putting {name}" @@ -1481,10 +1496,10 @@ def infer_auto_device_map( device_memory_used[device] += module_size_with_ties - # Assign the primary module to the device + # Assign the primary module to the device. device_map[name] = device - # Assign tied modules if any + # Assign tied modules if any. for tied_module_name in tied_module_names: if tied_module_name in [m[0] for m in modules_to_treat]: # Find the index of the tied module in the list @@ -1506,7 +1521,9 @@ def infer_auto_device_map( # Update the buffer size on the device device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size - # if the current module itself fits, we try to split the tied modules + continue + + # If the current module itself fits, we try to split the tied modules. elif len(tied_params) > 0 and device_memory_used[device] + module_size <= current_max_size: # can we split one of the tied modules to make it smaller or do we need to go on the next device? if verbose: @@ -1542,19 +1559,14 @@ def infer_auto_device_map( split_happened = True break - if not split_happened: - # If the tied module is not split, we go to the next device - if verbose: - print("None of the tied module can be split, going to the next device.") - if device_memory_used[device] == 0: - device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved + if split_happened: + continue - device_memory_used[device] = device_memory_used[device] + current_memory_reserved - current_device += 1 - modules_to_treat = [(name, module)] + modules_to_treat - device_memory_used[device] = 0 + # If the tied module is not split, we go to the next device + if verbose: + print("None of the tied module can be split, going to the next device.") - # the current module itself doesn't fit, so we have to split it or go to the next device + # The current module itself doesn't fit, so we have to split it or go to the next device. else: # Split or not split? modules_children = ( @@ -1571,11 +1583,7 @@ def infer_auto_device_map( # -> no split, we go to the next device if verbose: print("This module cannot be split, going to the next device.") - if device_memory_used[device] == 0: - device_minimum_assignment_memory[device] = module_size + current_memory_reserved - device_memory_used[device] = device_memory_used[device] + current_memory_reserved - current_device += 1 - modules_to_treat = [(name, module)] + modules_to_treat + else: # -> split, we replace the module studied by its children + parameters if verbose: @@ -1589,6 +1597,35 @@ def infer_auto_device_map( no_split_module_classes, ) + continue + + # Neither the current module nor any tied modules can be split, so we move to the next device + if device_memory_used[device] == 0 and fallback_allocation and device != "disk": + # We try to allocate a module that fits in the size limit using BFS. + # Recompute the current max size as we need to consider the current module as well. + current_max_size = max_memory[device] - max(max_layer_size, module_size_with_ties) + + fallback_module_name, fallback_module, remaining_modules = fallback_allocate( + modules_to_treat, + module_sizes, + current_max_size - device_memory_used[device], + no_split_module_classes, + tied_parameters, + ) + # use the next iteration to put the fallback module on the next device to avoid code duplication + if fallback_module is not None: + modules_to_treat = [(fallback_module_name, fallback_module)] \ + + [(name, module)] \ + + remaining_modules + continue + + if device_memory_used[device] == 0: + device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved + + device_memory_used[device] = device_memory_used[device] + current_memory_reserved + current_device += 1 + modules_to_treat = [(name, module)] + modules_to_treat + device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0} if clean_result: diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 50dd06cd61a..271ebab270e 100644 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -709,6 +709,141 @@ def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self): assert len(w) == 0 assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"} + def test_infer_auto_device_map_with_fallback_allocation(self): + # Create a model where modules cannot be allocated without fallback_allocation + model = nn.Sequential( + OrderedDict([ + ("module", nn.Sequential(OrderedDict([ + ("linear1", nn.Linear(10, 4)), + ("linear2", nn.Linear(4, 4)), + ("linear3", nn.Linear(4, 8)) + ]))) + ]) + ) + + max_memory = {0: 256} + + # Without fallback_allocation + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + device_map = infer_auto_device_map(model, max_memory=max_memory, fallback_allocation=False) + # No module should be assigned to device 0 + assert all(device != 0 for device in device_map.values()) + # Check for warning about insufficient memory + assert any("insufficient memory" in str(warn.message).lower() for warn in w) + + # With fallback_allocation + device_map = infer_auto_device_map(model, max_memory=max_memory, fallback_allocation=True) + # At least one submodule should be assigned to device 0 + assert any(device == 0 for device in device_map.values()) + + expected_device_map = { + "module.linear1": "disk", + "module.linear2": 0, + "module.linear3": "disk" + } + assert device_map == expected_device_map + + def test_infer_auto_device_map_with_fallback_allocation_no_fit(self): + # Create a model where even the smallest submodules cannot fit + model = nn.Sequential( + OrderedDict([ + ("module", nn.Sequential(OrderedDict([ + ("linear1", nn.Linear(10, 10)), + ("linear2", nn.Linear(10, 10)), + ("linear3", nn.Linear(10, 10)) + ]))) + ]) + ) + + max_memory = {0: 30} + + # With fallback_allocation + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + device_map = infer_auto_device_map(model, max_memory=max_memory, fallback_allocation=True) + # No module should be assigned to device 0 + assert all(device != 0 for device in device_map.values()) + # Check for warning about insufficient memory + assert any("insufficient memory" in str(warn.message).lower() for warn in w) + + def test_infer_auto_device_map_with_fallback_allocation_partial_fit(self): + # Create a model with deeper hierarchy + class CustomModule(nn.Module): + def __init__(self): + super().__init__() + self.submodule1 = nn.Linear(20, 20) + self.submodule2 = nn.Linear(20, 20) + + model = nn.Sequential( + OrderedDict([ + ("module1", CustomModule()), + ("module2", CustomModule()), + ("module3", CustomModule()) + ]) + ) + + max_memory = {0: 5000} + + # With fallback_allocation + device_map = infer_auto_device_map(model, max_memory=max_memory, fallback_allocation=True) + # Check that at least some parameters are assigned to device 0 + assigned_to_device_0 = [name for name, device in device_map.items() if device == 0] + assert len(assigned_to_device_0) > 0 + + def test_infer_auto_device_map_with_fallback_allocation_tied_weights(self): + # Create a model with tied weights + class TiedWeightsModel(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(10, 10) + self.linear2 = nn.Linear(10, 10) + self.linear2.weight = self.linear1.weight + + model = TiedWeightsModel() + + max_memory = {0: 600} + + # With fallback_allocation + device_map = infer_auto_device_map(model, max_memory=max_memory, fallback_allocation=True) + # Check that tied modules are assigned correctly + expected_device_map = { + "": 0 + } + assert device_map == expected_device_map + + def test_infer_auto_device_map_with_fallback_allocation_and_buffers(self): + # Create a model with buffers + model = nn.Sequential( + OrderedDict([ + ("linear1", nn.Linear(10, 10)), + ("batchnorm", nn.BatchNorm1d(10)), + ("linear2", nn.Linear(10, 10)) + ]) + ) + model.linear1.register_buffer("buffer1", torch.zeros(5)) + model.batchnorm.register_buffer("buffer2", torch.zeros(5)) + model.linear2.register_buffer("buffer3", torch.zeros(5)) + + max_memory = {0: 678} + + # With fallback_allocation and offload_buffers=False + with self.assertWarns(Warning) as cm: + device_map = infer_auto_device_map( + model, + max_memory=max_memory, + fallback_allocation=True, + offload_buffers=False + ) + + # Check that the warning contains the expected message + warning_message = str(cm.warning) + assert "offload_buffers" in warning_message or "Current model requires" in warning_message + + # Verify that the entire model is assigned to device 0 + expected_device_map = {"batchnorm": 0, "linear1": "disk", "linear2": "disk"} + assert device_map == expected_device_map + @require_cuda def test_get_balanced_memory(self): model = ModelForTest() From be385b08abf5bb9ebc57ebd3e6a54dfb3c8d89f9 Mon Sep 17 00:00:00 2001 From: Chen Date: Sun, 13 Oct 2024 21:35:54 -0400 Subject: [PATCH 10/13] fix: fix module splitting logic --- src/accelerate/utils/modeling.py | 45 ++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 487a35d1a1d..d2326d74115 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1301,32 +1301,39 @@ def fallback_allocate( while modules_to_search: name, module = modules_to_search.pop(0) - if module_sizes[name] <= size_limit: - tied_param_groups = [ - tied_group - for tied_group in tied_parameters - if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group) - ] - tied_params = sum( - [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], [] - ) + tied_param_groups = [ + tied_group + for tied_group in tied_parameters + if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group) + ] - module_size_with_ties, _, _ = get_module_size_with_ties( - tied_params, module_sizes[name], module_sizes, modules_to_search - ) + tied_params = sum( + [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], [] + ) - if module_size_with_ties <= size_limit: - module_found = True - break + module_size_with_ties, _, _ = get_module_size_with_ties( + tied_params, module_sizes[name], module_sizes, modules_to_search + ) - if not isinstance(module, nn.Module) or module.__class__.__name__ in no_split_module_classes: - continue + # If the module fits in the size limit, we found it. + if module_size_with_ties <= size_limit: + module_found = True + break + + # The module is too big, we need to split it if possible. + modules_children = ( + [] + if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor) + else list(module.named_children()) + ) - modules_children = list(module.named_children()) - if not modules_children: + # Split fails, move to the next module + if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes: continue + # split is possible, add the children to the list of modules to search + modules_children = list(module.named_parameters(recurse=False)) + modules_children modules_to_search = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_search if not module_found: From c6025ee91f208426045d1f26482e5f6e8323cae8 Mon Sep 17 00:00:00 2001 From: Chen Date: Mon, 14 Oct 2024 10:47:05 -0400 Subject: [PATCH 11/13] styles: fix styling errors --- src/accelerate/utils/modeling.py | 48 ++++++++++---------- tests/test_modeling_utils.py | 75 +++++++++++++++++--------------- 2 files changed, 62 insertions(+), 61 deletions(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index d2326d74115..0db7f7a2a43 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1204,9 +1204,9 @@ def _init_infer_auto_device_map( # Direct submodules and parameters modules_to_treat = ( - list(model.named_parameters(recurse=False)) - + list(model.named_children()) - + list(model.named_buffers(recurse=False)) + list(model.named_parameters(recurse=False)) + + list(model.named_children()) + + list(model.named_buffers(recurse=False)) ) return ( @@ -1257,11 +1257,11 @@ def get_module_size_with_ties( def fallback_allocate( - modules: List[Tuple[str, nn.Module]], - module_sizes: Dict[str, int], - size_limit: Union[int, str], - no_split_module_classes: Optional[List[str]] = None, - tied_parameters: Optional[List[List[str]]] = None, + modules: List[Tuple[str, nn.Module]], + module_sizes: Dict[str, int], + size_limit: Union[int, str], + no_split_module_classes: Optional[List[str]] = None, + tied_parameters: Optional[List[List[str]]] = None, ) -> Tuple[Optional[str], Optional[nn.Module], List[Tuple[str, nn.Module]]]: """ Find a module that fits in the size limit using BFS and return it with its name and the remaining modules. @@ -1279,8 +1279,7 @@ def fallback_allocate( A list of lists of parameter names being all tied together. Returns: - `Tuple[Optional[str], Optional[nn.Module], List[Tuple[str, nn.Module]]]`: - A tuple containing: + `Tuple[Optional[str], Optional[nn.Module], List[Tuple[str, nn.Module]]]`: A tuple containing: - The name of the module that fits within the size limit. - The module itself. - The list of remaining modules after the found module is removed. @@ -1348,11 +1347,14 @@ def fallback_allocate( if parent_name in current_names: parent_module_idx = current_names.index(parent_name) _, parent_module = modules[parent_module_idx] - module_children = (list(parent_module.named_parameters(recurse=False)) - + list(parent_module.named_children())) - modules = (modules[:parent_module_idx] - + [(f"{parent_name}.{n}", v) for n, v in module_children] - + modules[parent_module_idx + 1:]) + module_children = list(parent_module.named_parameters(recurse=False)) + list( + parent_module.named_children() + ) + modules = ( + modules[:parent_module_idx] + + [(f"{parent_name}.{n}", v) for n, v in module_children] + + modules[parent_module_idx + 1 :] + ) current_names = [n for n, _ in modules] # Now the target module should be directly in the list @@ -1510,9 +1512,7 @@ def infer_auto_device_map( for tied_module_name in tied_module_names: if tied_module_name in [m[0] for m in modules_to_treat]: # Find the index of the tied module in the list - tied_module_index = next( - i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name - ) + tied_module_index = next(i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name) # Remove the tied module from the list to prevent reprocessing modules_to_treat.pop(tied_module_index) @@ -1552,10 +1552,10 @@ def infer_auto_device_map( tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0] modules_to_treat = ( - [(name, module)] - + modules_to_treat[:tied_module_index] - + tied_module_children - + modules_to_treat[tied_module_index + 1:] + [(name, module)] + + modules_to_treat[:tied_module_index] + + tied_module_children + + modules_to_treat[tied_module_index + 1 :] ) # Update the max layer size. max_layer_size, max_layer_names = get_max_layer_size( @@ -1621,9 +1621,7 @@ def infer_auto_device_map( ) # use the next iteration to put the fallback module on the next device to avoid code duplication if fallback_module is not None: - modules_to_treat = [(fallback_module_name, fallback_module)] \ - + [(name, module)] \ - + remaining_modules + modules_to_treat = [(fallback_module_name, fallback_module)] + [(name, module)] + remaining_modules continue if device_memory_used[device] == 0: diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 271ebab270e..1402a7ab5f1 100644 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -712,13 +712,22 @@ def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self): def test_infer_auto_device_map_with_fallback_allocation(self): # Create a model where modules cannot be allocated without fallback_allocation model = nn.Sequential( - OrderedDict([ - ("module", nn.Sequential(OrderedDict([ - ("linear1", nn.Linear(10, 4)), - ("linear2", nn.Linear(4, 4)), - ("linear3", nn.Linear(4, 8)) - ]))) - ]) + OrderedDict( + [ + ( + "module", + nn.Sequential( + OrderedDict( + [ + ("linear1", nn.Linear(10, 4)), + ("linear2", nn.Linear(4, 4)), + ("linear3", nn.Linear(4, 8)), + ] + ) + ), + ) + ] + ) ) max_memory = {0: 256} @@ -737,23 +746,28 @@ def test_infer_auto_device_map_with_fallback_allocation(self): # At least one submodule should be assigned to device 0 assert any(device == 0 for device in device_map.values()) - expected_device_map = { - "module.linear1": "disk", - "module.linear2": 0, - "module.linear3": "disk" - } + expected_device_map = {"module.linear1": "disk", "module.linear2": 0, "module.linear3": "disk"} assert device_map == expected_device_map def test_infer_auto_device_map_with_fallback_allocation_no_fit(self): # Create a model where even the smallest submodules cannot fit model = nn.Sequential( - OrderedDict([ - ("module", nn.Sequential(OrderedDict([ - ("linear1", nn.Linear(10, 10)), - ("linear2", nn.Linear(10, 10)), - ("linear3", nn.Linear(10, 10)) - ]))) - ]) + OrderedDict( + [ + ( + "module", + nn.Sequential( + OrderedDict( + [ + ("linear1", nn.Linear(10, 10)), + ("linear2", nn.Linear(10, 10)), + ("linear3", nn.Linear(10, 10)), + ] + ) + ), + ) + ] + ) ) max_memory = {0: 30} @@ -776,11 +790,7 @@ def __init__(self): self.submodule2 = nn.Linear(20, 20) model = nn.Sequential( - OrderedDict([ - ("module1", CustomModule()), - ("module2", CustomModule()), - ("module3", CustomModule()) - ]) + OrderedDict([("module1", CustomModule()), ("module2", CustomModule()), ("module3", CustomModule())]) ) max_memory = {0: 5000} @@ -807,19 +817,15 @@ def __init__(self): # With fallback_allocation device_map = infer_auto_device_map(model, max_memory=max_memory, fallback_allocation=True) # Check that tied modules are assigned correctly - expected_device_map = { - "": 0 - } + expected_device_map = {"": 0} assert device_map == expected_device_map def test_infer_auto_device_map_with_fallback_allocation_and_buffers(self): # Create a model with buffers model = nn.Sequential( - OrderedDict([ - ("linear1", nn.Linear(10, 10)), - ("batchnorm", nn.BatchNorm1d(10)), - ("linear2", nn.Linear(10, 10)) - ]) + OrderedDict( + [("linear1", nn.Linear(10, 10)), ("batchnorm", nn.BatchNorm1d(10)), ("linear2", nn.Linear(10, 10))] + ) ) model.linear1.register_buffer("buffer1", torch.zeros(5)) model.batchnorm.register_buffer("buffer2", torch.zeros(5)) @@ -830,10 +836,7 @@ def test_infer_auto_device_map_with_fallback_allocation_and_buffers(self): # With fallback_allocation and offload_buffers=False with self.assertWarns(Warning) as cm: device_map = infer_auto_device_map( - model, - max_memory=max_memory, - fallback_allocation=True, - offload_buffers=False + model, max_memory=max_memory, fallback_allocation=True, offload_buffers=False ) # Check that the warning contains the expected message From a35776ccb5e88e304236c941d56cdc24a50ec865 Mon Sep 17 00:00:00 2001 From: Chen Date: Mon, 14 Oct 2024 14:24:02 -0400 Subject: [PATCH 12/13] test: add test coverage for no-warning cases test_infer_auto_device_map and test_infer_auto_device_map_with_fallback_allocation now each have a no-warning test case. Simplified and rewrote code sections that were made unreadable by the linter. --- tests/test_modeling_utils.py | 53 +++++++++++++----------------------- 1 file changed, 19 insertions(+), 34 deletions(-) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 1402a7ab5f1..0e1edc1acc5 100644 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -519,7 +519,10 @@ def test_infer_auto_device_map(self): model = ModelForTest() # model has size 236: linear1 64, batchnorm 72, linear2 100 - device_map = infer_auto_device_map(model, max_memory={0: 200, 1: 200}) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + device_map = infer_auto_device_map(model, max_memory={0: 200, 1: 200}) + assert len(w) == 0, f"Unexpected warnings: {[str(warning.message) for warning in w]}" # only linear1 fits on device 0 as we keep memory available for the maximum layer in case of offload assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": 1} @@ -711,25 +714,14 @@ def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self): def test_infer_auto_device_map_with_fallback_allocation(self): # Create a model where modules cannot be allocated without fallback_allocation - model = nn.Sequential( - OrderedDict( - [ - ( - "module", - nn.Sequential( - OrderedDict( - [ - ("linear1", nn.Linear(10, 4)), - ("linear2", nn.Linear(4, 4)), - ("linear3", nn.Linear(4, 8)), - ] - ) - ), - ) - ] - ) + # Define the inner module with its layers + inner_module = nn.Sequential( + OrderedDict([("linear1", nn.Linear(10, 4)), ("linear2", nn.Linear(4, 4)), ("linear3", nn.Linear(4, 8))]) ) + # Wrap the inner module in another module + model = nn.Sequential(OrderedDict([("module", inner_module)])) + max_memory = {0: 256} # Without fallback_allocation @@ -742,7 +734,10 @@ def test_infer_auto_device_map_with_fallback_allocation(self): assert any("insufficient memory" in str(warn.message).lower() for warn in w) # With fallback_allocation - device_map = infer_auto_device_map(model, max_memory=max_memory, fallback_allocation=True) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + device_map = infer_auto_device_map(model, max_memory=max_memory, fallback_allocation=True) + assert len(w) == 0, f"Unexpected warnings: {[str(warning.message) for warning in w]}" # At least one submodule should be assigned to device 0 assert any(device == 0 for device in device_map.values()) @@ -751,25 +746,15 @@ def test_infer_auto_device_map_with_fallback_allocation(self): def test_infer_auto_device_map_with_fallback_allocation_no_fit(self): # Create a model where even the smallest submodules cannot fit - model = nn.Sequential( + inner_module = nn.Sequential( OrderedDict( - [ - ( - "module", - nn.Sequential( - OrderedDict( - [ - ("linear1", nn.Linear(10, 10)), - ("linear2", nn.Linear(10, 10)), - ("linear3", nn.Linear(10, 10)), - ] - ) - ), - ) - ] + [("linear1", nn.Linear(10, 10)), ("linear2", nn.Linear(10, 10)), ("linear3", nn.Linear(10, 10))] ) ) + # Wrap the inner module in another module + model = nn.Sequential(OrderedDict([("module", inner_module)])) + max_memory = {0: 30} # With fallback_allocation From 73dbc76879ee04d0bde8e8dcf40875a55d00940c Mon Sep 17 00:00:00 2001 From: Chen Date: Mon, 14 Oct 2024 14:41:04 -0400 Subject: [PATCH 13/13] refactor: simplify control flow in infer_auto_device_map Added complete return type hinting for _init_infer_auto_device_map --- src/accelerate/utils/modeling.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 0db7f7a2a43..9c3dd232cc3 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1171,7 +1171,15 @@ def _init_infer_auto_device_map( no_split_module_classes: Optional[List[str]] = None, dtype: Optional[Union[str, torch.dtype]] = None, special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, -) -> Tuple: +) -> Tuple[ + List[Union[int, str]], + List[Union[int, str]], + List[int], + Dict[str, int], + List[List[str]], + List[str], + List[Tuple[str, nn.Module]], +]: """ Initialize variables required for computing the device map for model allocation. """ @@ -1530,8 +1538,8 @@ def infer_auto_device_map( continue - # If the current module itself fits, we try to split the tied modules. - elif len(tied_params) > 0 and device_memory_used[device] + module_size <= current_max_size: + # The current module itself fits, so we try to split the tied modules. + if len(tied_params) > 0 and device_memory_used[device] + module_size <= current_max_size: # can we split one of the tied modules to make it smaller or do we need to go on the next device? if verbose: print( @@ -1574,7 +1582,7 @@ def infer_auto_device_map( print("None of the tied module can be split, going to the next device.") # The current module itself doesn't fit, so we have to split it or go to the next device. - else: + if device_memory_used[device] + module_size >= current_max_size: # Split or not split? modules_children = ( [] @@ -1603,7 +1611,6 @@ def infer_auto_device_map( module_sizes, no_split_module_classes, ) - continue # Neither the current module nor any tied modules can be split, so we move to the next device