diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index 8b863221972..222b5c758ca 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -1326,7 +1326,7 @@ def get_module_size_with_ties(
     module_sizes,
     modules_to_treat,
 ) -> int:
-    if not tied_params:
+    if len(tied_params) < 1:
         return module_size, [], []
     tied_module_names = []
     tied_modules = []
@@ -1457,8 +1457,105 @@ def infer_auto_device_map(
             current_max_size = current_max_size - max_layer_size
             current_memory_reserved = max_layer_size
-        # Case 1 -> We're too big!
-        if current_max_size is not None and device_memory_used[device] + module_size > current_max_size:
+        (
+            module_size_with_ties,
+            tied_module_names,
+            tied_modules
+        ) = get_module_size_with_ties(tied_params, module_size, module_sizes, modules_to_treat)
+
+        # the module and its tied modules fit on the current device
+        if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size:
+            if verbose:
+                output = f"Putting {name}"
+
+                if tied_module_names:
+                    output += f" and {tied_module_names}"
+                else:
+                    output += f" (size={module_size})"
+
+                if current_max_size is not None:
+                    output += f" (available={current_max_size - device_memory_used[device]})"
+
+                output += f" on {device}."
+                print(output)
+
+            device_memory_used[device] += module_size_with_ties
+
+            # Assign the primary module to the device
+            device_map[name] = device
+
+            # Assign tied modules if any
+            for tied_module_name in tied_module_names:
+                if tied_module_name in [m[0] for m in modules_to_treat]:
+                    # Find the index of the tied module in the list
+                    tied_module_index = next(
+                        i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name
+                    )
+                    # Remove the tied module from the list to prevent reprocessing
+                    modules_to_treat.pop(tied_module_index)
+
+                # Assign the tied module to the device
+                device_map[tied_module_name] = device
+
+            # Buffer Handling
+            if not offload_buffers and isinstance(module, nn.Module):
+                # Compute the total buffer size for the module
+                current_buffer_size = compute_module_total_buffer_size(
+                    module, dtype=dtype, special_dtypes=special_dtypes
+                )
+                # Update the buffer size on the device
+                device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size
+
+        # if the current module itself fits, we try to split the tied modules
+        elif len(tied_params) > 0 and device_memory_used[device] + module_size <= current_max_size:
+            # can we split one of the tied modules to make it smaller or do we need to go on the next device?
+            if verbose:
+                print(
+                    f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
+                    f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})."
+                )
+            split_happened = False
+            for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
+                tied_module_children = list(tied_module.named_children())
+                if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes:
+                    # can't break this one.
+                    continue
+
+                if verbose:
+                    print(f"Splitting {tied_module_name}.")
+                tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children
+                tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children]
+                tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]
+
+                modules_to_treat = (
+                    [(name, module)]
+                    + modules_to_treat[:tied_module_index]
+                    + tied_module_children
+                    + modules_to_treat[tied_module_index + 1:]
+                )
+                # Update the max layer size.
+                max_layer_size, max_layer_names = get_max_layer_size(
+                    [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
+                    module_sizes,
+                    no_split_module_classes,
+                )
+                split_happened = True
+                break
+
+            if not split_happened:
+                # If the tied module is not split, we go to the next device
+                if verbose:
+                    print("None of the tied module can be split, going to the next device.")
+                if device_memory_used[device] == 0:
+                    device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved
+
+                device_memory_used[device] = device_memory_used[device] + current_memory_reserved
+                current_device += 1
+                modules_to_treat = [(name, module)] + modules_to_treat
+                device_memory_used[device] = 0
+
+        # the current module itself doesn't fit, so we have to split it or go to the next device
+        else:
             # Split or not split?
             modules_children = (
                 []
                 if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor)
                 else list(module.named_children())
             )
@@ -1492,104 +1589,6 @@
                     no_split_module_classes,
                 )

-        # Case 2, it fits! We're not entirely out of the wood though, because we may have some tied parameters.
-        elif len(tied_params) > 0:
-            # First locate all tied modules
-            module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties(tied_params, module_size, module_sizes, modules_to_treat)
-            if verbose:
-                print(
-                    f" It looks like {name} is going to fit on {devices[current_device]} but we have tied "
-                    f"parameters to account for.\n - Names {tied_params}\n - Module names {tied_module_names}"
-                )
-
-            # Let's see if it all fits first
-            if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size:
-                # We really really fit!
-                if verbose:
-                    print(f"Putting {name} and {tied_module_names} on {devices[current_device]}.")
-                device_memory_used[device] += module_size_with_ties
-                device_map[name] = devices[current_device]
-                for tied_module_name in tied_module_names:
-                    if tied_module_name in [m[0] for m in modules_to_treat]:
-                        # The module may have been removed by a previous iteration of this loop.
-                        tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][
-                            0
-                        ]
-                        modules_to_treat.pop(tied_module_index)
-                    device_map[tied_module_name] = devices[current_device]
-
-                if not offload_buffers and isinstance(module, nn.Module):
-                    current_buffer_size = compute_module_total_buffer_size(
-                        module, dtype=dtype, special_dtypes=special_dtypes
-                    )
-                    device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size
-
-            else:
-                # We don't fit with the tied modules. Next question is: can we split one of the tied modules to make it
-                # smaller or do we need to go on the next device?
-                if verbose:
-                    print(
-                        f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
-                        f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})."
-                    )
-                split_happened = False
-                for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
-                    tied_module_children = list(tied_module.named_children())
-                    if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes:
-                        # can't break this one.
-                        continue
-
-                    if verbose:
-                        print(f"Splitting {tied_module_name}.")
-                    tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children
-                    tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children]
-                    tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]
-
-                    modules_to_treat = (
-                        [(name, module)]
-                        + modules_to_treat[:tied_module_index]
-                        + tied_module_children
-                        + modules_to_treat[tied_module_index + 1 :]
-                    )
-                    # Update the max layer size.
-                    max_layer_size, max_layer_names = get_max_layer_size(
-                        [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
-                        module_sizes,
-                        no_split_module_classes,
-                    )
-                    split_happened = True
-                    break
-
-                if not split_happened:
-                    # If the tied module is not split, we go to the next device
-                    if verbose:
-                        print("None of the tied module can be split, going to the next device.")
-                    if device_memory_used[device] == 0:
-                        device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved
-
-                    device_memory_used[device] = device_memory_used[device] + current_memory_reserved
-                    current_device += 1
-                    modules_to_treat = [(name, module)] + modules_to_treat
-                    device_memory_used[device] = 0
-
-        else:
-            if verbose:
-                if current_max_size is None:
-                    print(f"Putting {name} (size={module_size}) on {devices[current_device]}.")
-                else:
-                    print(
-                        f"Putting {name} (size={module_size}) on {devices[current_device]} "
-                        f"(available={current_max_size - device_memory_used[device]})."
-                    )
-            device_memory_used[device] += module_size
-            device_map[name] = devices[current_device]
-
-            if not offload_buffers and isinstance(module, nn.Module):
-                current_buffer_size = compute_module_total_buffer_size(
-                    module, dtype=dtype, special_dtypes=special_dtypes
-                )
-                device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size
-
     device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0}

     if clean_result:
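For context, a minimal sketch of how the refactored code path is reached through the public infer_auto_device_map API. The toy model, its tied embedding/head weights, and the memory budget below are illustrative choices, not part of the PR; the max_memory key 0 assumes a visible CUDA device (on a CPU-only machine, keep only the "cpu" key).

    import torch
    from torch import nn
    from accelerate import infer_auto_device_map

    class TinyTiedLM(nn.Module):
        """Hypothetical toy model whose embedding and output head share one weight."""

        def __init__(self, vocab_size=100, hidden=32):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden)
            self.block = nn.Linear(hidden, hidden)
            self.head = nn.Linear(hidden, vocab_size, bias=False)
            # Weight tying: this is what get_module_size_with_ties has to account for.
            self.head.weight = self.embed.weight

        def forward(self, input_ids):
            return self.head(self.block(self.embed(input_ids)))

    model = TinyTiedLM()
    # Illustrative budget: device 0 is roomy enough that the tied embed/head pair
    # can be placed together, with "cpu" as the fallback device.
    device_map = infer_auto_device_map(
        model,
        max_memory={0: "1MB", "cpu": "1GB"},
        verbose=True,
    )
    print(device_map)

With verbose=True, the log lines introduced in this diff ("Putting {name} and {tied_module_names} ...") show the tied pair being handled as a single placement decision.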