Handle multiple tied parameters (#1241)
* Handle multiple tied parameters

* Add tests

* Ensure backward compatibility with Transformers

* Update src/accelerate/utils/modeling.py

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

* Gate test requiring Transformers

---------

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
sgugger and LysandreJik authored Mar 24, 2023
1 parent 1fe27e7 commit a826e44
Showing 2 changed files with 237 additions and 52 deletions.
167 changes: 125 additions & 42 deletions src/accelerate/utils/modeling.py
@@ -183,6 +183,20 @@ def named_module_tensors(module: nn.Module, include_buffers: bool = True, recurs
yield named_buffer


class FindTiedParametersResult(list):
"""
This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on it being anything
other than a plain list, or on the `values` method, as both will be removed in the future.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def values(self):
# TODO: at the next Transformers release (4.28.0) issue a deprecation warning here.
return sum([x[1:] for x in self], [])
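
A minimal sketch of the shim's behaviour, using made-up parameter names (not part of the diff above):

```py
>>> res = FindTiedParametersResult([["linear1.weight", "linear2.weight", "linear3.weight"]])
>>> isinstance(res, list)  # the result behaves like a plain list of tied groups
True
>>> res.values()  # legacy access used by older Transformers: every name in a group except the first
['linear2.weight', 'linear3.weight']
```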


def find_tied_parameters(model: nn.Module, **kwargs):
"""
Find the tied parameters in a given model.
@@ -198,7 +212,7 @@ def find_tied_parameters(model: nn.Module, **kwargs):
model (`torch.nn.Module`): The model to inspect.
Returns:
Dict[str, str]: A dictionary mapping tied parameter names to the name of the parameter they are tied to.
List[List[str]]: A list of lists of parameter names being all tied together.
Example:
@@ -207,9 +221,9 @@ def find_tied_parameters(model: nn.Module, **kwargs):
>>> import torch.nn as nn
>>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
>>> model.linear2.weight = test_model.linear1.weight
>>> find_tied_parameters(test_model)
{'linear1.weight': 'linear2.weight'}
>>> model.linear2.weight = model.linear1.weight
>>> find_tied_parameters(model)
[['linear1.weight', 'linear2.weight']]
```
"""
# Initialize result and named_parameters before recursing.
@@ -229,14 +243,16 @@ def find_tied_parameters(model: nn.Module, **kwargs):
# When we find one, it has to be one of the existing parameters.
for new_name, new_param in named_parameters.items():
if new_param is parameter:
result[new_name] = full_name
if new_name not in result:
result[new_name] = []
result[new_name].append(full_name)

# Once we have treated direct parameters, we move to the child modules.
for name, child in model.named_children():
child_name = name if prefix == "" else f"{prefix}.{name}"
find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result)

return result
return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()])
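
A hedged sketch of what this commit enables: the old `Dict[str, str]` return type could not represent a weight tied to two other parameters, while the new list-of-groups form can (toy model below, not taken from the diff):

```py
>>> from collections import OrderedDict
>>> import torch.nn as nn
>>> from accelerate.utils.modeling import find_tied_parameters

>>> model = nn.Sequential(
...     OrderedDict(
...         [("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4)), ("linear3", nn.Linear(4, 4))]
...     )
... )
>>> model.linear2.weight = model.linear1.weight
>>> model.linear3.weight = model.linear1.weight
>>> find_tied_parameters(model)
[['linear1.weight', 'linear2.weight', 'linear3.weight']]
```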


def retie_parameters(model, tied_params):
@@ -246,17 +262,21 @@ def retie_parameters(model, tied_params):
Args:
model (`torch.nn.Module`):
The model in which to retie parameters.
tied_params (`Dict[str, str]`):
tied_params (`List[List[str]]`):
A list of lists of parameter names being all tied together, as obtained by `find_tied_parameters`.
"""
for param_name, tied_param_name in tied_params.items():
param = model
for split in param_name.split("."):
param = getattr(param, split)
tied_module = model
for split in tied_param_name.split(".")[:-1]:
tied_module = getattr(tied_module, split)
setattr(tied_module, tied_param_name.split(".")[-1], param)
for tied_group in tied_params:
param_to_tie = None
# The first iteration of the loop sets param_to_tie, the next ones tie the other parameters to it
for param_name in tied_group:
module = model
splits = param_name.split(".")
for split in splits[:-1]:
module = getattr(module, split)
if param_to_tie is None:
param_to_tie = getattr(module, splits[-1])
else:
setattr(module, splits[-1], param_to_tie)
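
A small usage sketch for `retie_parameters`, reusing the hypothetical three-way tie from the example above (not part of the diff); it restores a tie that was broken, e.g. after the parameters were re-created:

```py
>>> import torch
>>> import torch.nn as nn
>>> from accelerate.utils.modeling import retie_parameters

>>> tied_params = find_tied_parameters(model)  # [['linear1.weight', 'linear2.weight', 'linear3.weight']]
>>> model.linear2.weight = nn.Parameter(torch.randn(4, 4))  # break one of the ties
>>> retie_parameters(model, tied_params)
>>> model.linear2.weight is model.linear1.weight and model.linear3.weight is model.linear1.weight
True
```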


def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype:
@@ -508,6 +528,7 @@ def infer_auto_device_map(
no_split_module_classes: Optional[List[str]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
verbose: bool = False,
):
"""
Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
@@ -539,6 +560,8 @@ def infer_auto_device_map(
special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*):
If provided, special dtypes to consider for some specific weights (will override dtype used as default for
all weights).
verbose (`bool`, *optional*, defaults to `False`):
Whether or not to provide debugging statements as the function builds the device_map.
"""
# Get default / clean up max_memory
max_memory = get_max_memory(max_memory)
@@ -574,6 +597,8 @@ def infer_auto_device_map(
# Ready ? This is going to be a bit messy.
while len(modules_to_treat) > 0:
name, module = modules_to_treat.pop(0)
if verbose:
print(f"\nTreating module {name}.")
# Max size in the remaining layers may have changed since we took one, so we maybe update it.
max_layer_names = [n for n in max_layer_names if not n.startswith(name)]
if len(max_layer_names) == 0:
@@ -584,11 +609,20 @@ def infer_auto_device_map(
)
# Assess size needed
module_size = module_sizes[name]
# We keep relevant tied parameters only: one of the tied parameters is inside the current module and the other
# is not.
tied_params = [v for k, v in tied_parameters.items() if name in k and name not in v]
# We ignore parameters that are tied to more than one other parameter
tied_param = tied_params[0] if len(tied_params) == 1 else None

# We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module
# and the other is not.
tied_param_groups = [
tied_group
for tied_group in tied_parameters
if any(name in k for k in tied_group) and not all(name in k for k in tied_group)
]
if verbose and len(tied_param_groups) > 0:
print(f" Found the relevant tied param groups {tied_param_groups}")
# Then we keep track of all the parameters that are tied to the current module, but not in the current module
tied_params = sum([[p for p in tied_group if name not in p] for tied_group in tied_param_groups], [])
if verbose and len(tied_params) > 0:
print(f" So those parameters need to be taken into account {tied_params}")

device = devices[current_device]
current_max_size = max_memory[device] if device != "disk" else None
@@ -599,13 +633,22 @@ def infer_auto_device_map(
if current_max_size is not None and current_memory_used + module_size > current_max_size:
# Split or not split?
modules_children = list(module.named_children())
if verbose:
print(
f"Not enough space on {devices[current_device]} to put {name} (space available "
f"{current_max_size-current_memory_used}, module size {module_size})."
)
if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
# -> no split, we go to the next device
if verbose:
print("This module cannot be split, going to the next device.")
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
current_memory_used = 0
else:
# -> split, we replace the module studied by its children + parameters
if verbose:
print(f"Splitting {name}.")
modules_children = list(module.named_parameters(recurse=False)) + modules_children
modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat
# Update the max layer size.
@@ -616,24 +659,57 @@ def infer_auto_device_map(
)

# Case 2, it fits! We're not entirely out of the wood though, because we may have some tied parameters.
elif tied_param is not None:
# Determine the size occupied by this module + the module containing the tied parameter
tied_module_size = module_size
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0]
tied_module_name, tied_module = modules_to_treat[tied_module_index]
tied_module_size += module_sizes[tied_module_name] - module_sizes[tied_param]
if current_max_size is not None and current_memory_used + tied_module_size > current_max_size:
# Split or not split?
tied_module_children = list(tied_module.named_children())
if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes:
# If the tied module is not split, we go to the next device
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
current_memory_used = 0
else:
# Otherwise, we replace the tied module by its children.
elif len(tied_params) > 0:
# First locate all tied modules
tied_module_names = []
tied_modules = []
for tied_param in tied_params:
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0]
tied_module_names.append(modules_to_treat[tied_module_index][0])
tied_modules.append(modules_to_treat[tied_module_index][1])
if verbose:
print(
f" It looks like {name} is going to fit on {devices[current_device]} but we have tied "
f"parameters to account for.\n - Names {tied_params}\n - Module names {tied_module_names}"
)

# Let's see if it all fits first
module_size_with_ties = module_size
for tied_param, tied_module_name in zip(tied_params, tied_module_names):
module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param]

if current_max_size is None or current_memory_used + module_size_with_ties <= current_max_size:
# We really really fit!
if verbose:
print(f"Putting {name} and {tied_module_names} on {devices[current_device]}.")
current_memory_used += module_size_with_ties
device_map[name] = devices[current_device]
for tied_module_name in tied_module_names:
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]
modules_to_treat.pop(tied_module_index)
device_map[tied_module_name] = devices[current_device]

else:
# We don't fit with the tied modules. Next question is: can we split one of the tied modules to make it
# smaller or do we need to go on the next device?
if verbose:
print(
f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
f"available {current_max_size-current_memory_used}, needed size {module_size_with_ties})."
)
split_happened = False
for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
tied_module_children = list(tied_module.named_children())
if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes:
# can't break this one.
continue

if verbose:
print(f"Splitting {tied_module_name}.")
tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children
tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children]
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]

modules_to_treat = (
[(name, module)]
+ modules_to_treat[:tied_module_index]
@@ -646,13 +722,20 @@ def infer_auto_device_map(
module_sizes,
no_split_module_classes,
)
else:
# We really really fit!
current_memory_used += tied_module_size
device_map[name] = devices[current_device]
modules_to_treat.pop(tied_module_index)
device_map[tied_module_name] = devices[current_device]
split_happened = True
break

if not split_happened:
# If the tied module is not split, we go to the next device
if verbose:
print("None of the tied module can be split, going to the next device.")
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
current_memory_used = 0

else:
if verbose:
print(f"Putting {name} on {devices[current_device]}.")
current_memory_used += module_size
device_map[name] = devices[current_device]
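
A hedged sketch of the new `verbose` flag (the model and memory limits below are made up for illustration; the printed statements are the ones added in this diff):

```py
>>> from accelerate import infer_auto_device_map

>>> # Arbitrary limits; in real use you would list each GPU, e.g. {0: "10GiB", 1: "10GiB", "cpu": "30GiB"}.
>>> device_map = infer_auto_device_map(
...     model,
...     max_memory={0: "1GB", "cpu": "4GB"},
...     verbose=True,  # prints placement decisions and any relevant tied parameter groups along the way
... )
```

Modules holding parameters tied together are placed on the same device, or split as shown above when they do not fit.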

