
Commit

refactor: simplify allocation logic by removing duplicates and reducing nesting
Nech-C committed Oct 12, 2024
1 parent 9db401e commit 0b52bdc
Showing 1 changed file with 100 additions and 101 deletions.
201 changes: 100 additions & 101 deletions src/accelerate/utils/modeling.py
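The restructured loop in infer_auto_device_map calls get_module_size_with_ties once per iteration, before the placement decision, and then handles three cases in order: the module fits together with its tied modules; it fits only on its own, so a tied module is split if possible and otherwise allocation moves to the next device; or it does not fit at all, so the module itself is split or allocation moves on. The device assignment and buffer accounting that previously appeared in both the tied and the untied branches now occur only once. The sketch below illustrates that control flow only; the allocate and split names, the single shared max_size, the flat size numbers, and the error raised for an unplaceable module are simplifications, not the actual implementation.

def allocate(modules, max_size, split):
    """modules: list of (name, size, tied_extra); split(name) returns child tuples or None."""
    device, used, device_map, queue = 0, 0, {}, list(modules)
    while queue:
        name, size, tied_extra = queue.pop(0)
        size_with_ties = size + tied_extra  # roughly what get_module_size_with_ties computes

        if used + size_with_ties <= max_size:
            # Case 1: the module and its tied modules fit on the current device.
            device_map[name] = device
            used += size_with_ties
            continue

        if tied_extra and used + size <= max_size:
            # Case 2: the module fits alone but not with its tied modules. The real
            # code first tries to split one of the tied modules; this sketch simply
            # retries on the next device.
            children = None
        else:
            # Case 3: the module itself does not fit, so try to split it into children.
            children = split(name)

        if children:
            queue = children + queue
        elif used > 0:
            queue.insert(0, (name, size, tied_extra))
            device, used = device + 1, 0  # retry the module on the next, empty device
        else:
            raise ValueError(f"cannot place {name} within max_size")
    return device_map

A tiny run, assuming two modules where the first carries one extra unit of tied weight:

device_map = allocate([("embed", 2, 1), ("block", 3, 0)], max_size=4, split=lambda name: None)
# -> {'embed': 0, 'block': 1}: "embed" plus its tie fills device 0, so "block" moves to device 1.

Computing module_size_with_ties up front is what lets the three cases share a single assignment and buffer-tracking path instead of the near-identical copies removed below.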
@@ -1326,7 +1326,7 @@ def get_module_size_with_ties(
module_sizes,
modules_to_treat,
) -> int:
if not tied_params:
if len(tied_params) < 1:
return module_size, [], []
tied_module_names = []
tied_modules = []
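This hunk only touches the early-exit guard: len(tied_params) < 1 and not tied_params are interchangeable for a list, because an empty list is falsy in Python. A quick illustration (the parameter name is made up):

tied_params = []
assert (len(tied_params) < 1) == (not tied_params)   # both True while nothing is tied

tied_params = ["lm_head.weight"]                      # hypothetical tied-parameter name
assert (len(tied_params) < 1) == (not tied_params)   # both False once a tie exists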
@@ -1457,8 +1457,105 @@ def infer_auto_device_map(
current_max_size = current_max_size - max_layer_size
current_memory_reserved = max_layer_size

# Case 1 -> We're too big!
if current_max_size is not None and device_memory_used[device] + module_size > current_max_size:
(
module_size_with_ties,
tied_module_names,
tied_modules
) = get_module_size_with_ties(tied_params, module_size, module_sizes, modules_to_treat)

# the module and its tied modules fit on the current device
if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size:
if verbose:
output = f"Putting {name}"

if tied_module_names:
output += f" and {tied_module_names}"
else:
output += f" (size={module_size})"

if current_max_size is not None:
output += f" (available={current_max_size - device_memory_used[device]})"

output += f" on {device}."
print(output)

device_memory_used[device] += module_size_with_ties

# Assign the primary module to the device
device_map[name] = device

# Assign tied modules if any
for tied_module_name in tied_module_names:
if tied_module_name in [m[0] for m in modules_to_treat]:
# Find the index of the tied module in the list
tied_module_index = next(
i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name
)
# Remove the tied module from the list to prevent reprocessing
modules_to_treat.pop(tied_module_index)

# Assign the tied module to the device
device_map[tied_module_name] = device

# Buffer Handling
if not offload_buffers and isinstance(module, nn.Module):
# Compute the total buffer size for the module
current_buffer_size = compute_module_total_buffer_size(
module, dtype=dtype, special_dtypes=special_dtypes
)
# Update the buffer size on the device
device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size

# if the current module itself fits, we try to split the tied modules
elif len(tied_params) > 0 and device_memory_used[device] + module_size <= current_max_size:
# can we split one of the tied modules to make it smaller or do we need to go on the next device?
if verbose:
print(
f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})."
)
split_happened = False
for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
tied_module_children = list(tied_module.named_children())
if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes:
# can't break this one.
continue

if verbose:
print(f"Splitting {tied_module_name}.")
tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children
tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children]
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]

modules_to_treat = (
[(name, module)]
+ modules_to_treat[:tied_module_index]
+ tied_module_children
+ modules_to_treat[tied_module_index + 1:]
)
# Update the max layer size.
max_layer_size, max_layer_names = get_max_layer_size(
[(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
module_sizes,
no_split_module_classes,
)
split_happened = True
break

if not split_happened:
# If the tied module is not split, we go to the next device
if verbose:
print("None of the tied module can be split, going to the next device.")
if device_memory_used[device] == 0:
device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved

device_memory_used[device] = device_memory_used[device] + current_memory_reserved
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
device_memory_used[device] = 0

# the current module itself doesn't fit, so we have to split it or go to the next device
else:
# Split or not split?
modules_children = (
[]
@@ -1492,104 +1589,6 @@ def infer_auto_device_map(
no_split_module_classes,
)

# Case 2, it fits! We're not entirely out of the wood though, because we may have some tied parameters.
elif len(tied_params) > 0:
# First locate all tied modules
module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties(tied_params, module_size, module_sizes, modules_to_treat)
if verbose:
print(
f" It looks like {name} is going to fit on {devices[current_device]} but we have tied "
f"parameters to account for.\n - Names {tied_params}\n - Module names {tied_module_names}"
)

# Let's see if it all fits first
if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size:
# We really really fit!
if verbose:
print(f"Putting {name} and {tied_module_names} on {devices[current_device]}.")
device_memory_used[device] += module_size_with_ties
device_map[name] = devices[current_device]
for tied_module_name in tied_module_names:
if tied_module_name in [m[0] for m in modules_to_treat]:
# The module may have been removed by a previous iteration of this loop.
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][
0
]
modules_to_treat.pop(tied_module_index)
device_map[tied_module_name] = devices[current_device]

if not offload_buffers and isinstance(module, nn.Module):
current_buffer_size = compute_module_total_buffer_size(
module, dtype=dtype, special_dtypes=special_dtypes
)
device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size

else:
# We don't fit with the tied modules. Next question is: can we split one of the tied modules to make it
# smaller or do we need to go on the next device?
if verbose:
print(
f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})."
)
split_happened = False
for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
tied_module_children = list(tied_module.named_children())
if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes:
# can't break this one.
continue

if verbose:
print(f"Splitting {tied_module_name}.")
tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children
tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children]
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]

modules_to_treat = (
[(name, module)]
+ modules_to_treat[:tied_module_index]
+ tied_module_children
+ modules_to_treat[tied_module_index + 1 :]
)
# Update the max layer size.
max_layer_size, max_layer_names = get_max_layer_size(
[(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
module_sizes,
no_split_module_classes,
)
split_happened = True
break

if not split_happened:
# If the tied module is not split, we go to the next device
if verbose:
print("None of the tied module can be split, going to the next device.")
if device_memory_used[device] == 0:
device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved

device_memory_used[device] = device_memory_used[device] + current_memory_reserved
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
device_memory_used[device] = 0

else:
if verbose:
if current_max_size is None:
print(f"Putting {name} (size={module_size}) on {devices[current_device]}.")
else:
print(
f"Putting {name} (size={module_size}) on {devices[current_device]} "
f"(available={current_max_size - device_memory_used[device]})."
)
device_memory_used[device] += module_size
device_map[name] = devices[current_device]

if not offload_buffers and isinstance(module, nn.Module):
current_buffer_size = compute_module_total_buffer_size(
module, dtype=dtype, special_dtypes=special_dtypes
)
device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size

device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0}

if clean_result:

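For context, infer_auto_device_map is the public helper whose allocation loop this commit reorganizes. A typical call looks roughly like the sketch below; the checkpoint name, memory budgets, and no_split_module_classes entry are placeholders rather than values taken from this commit.

import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Build the model structure without materializing any weights.
config = AutoConfig.from_pretrained("facebook/opt-1.3b")
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

device_map = infer_auto_device_map(
    model,
    max_memory={0: "2GiB", "cpu": "8GiB"},        # per-device budgets (placeholder values)
    no_split_module_classes=["OPTDecoderLayer"],  # keep each decoder block on a single device
    dtype=torch.float16,
    verbose=True,                                 # prints the placement decisions this diff touches
)
print(device_map)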