From a826e4441d69146475744f593fbaf142cc158787 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 24 Mar 2023 09:53:29 -0400 Subject: [PATCH] Handle multiple tied parameters (#1241) * Handle multiple tied parameters * Add tests * Ensure backward compatibility with Transformers * Update src/accelerate/utils/modeling.py Co-authored-by: Lysandre Debut * Gate test requiring Transformers --------- Co-authored-by: Lysandre Debut --- src/accelerate/utils/modeling.py | 167 +++++++++++++++++++++++-------- tests/test_modeling_utils.py | 122 ++++++++++++++++++++-- 2 files changed, 237 insertions(+), 52 deletions(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index f1455899860..d7d6cbcf922 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -183,6 +183,20 @@ def named_module_tensors(module: nn.Module, include_buffers: bool = True, recurs yield named_buffer +class FindTiedParametersResult(list): + """ + This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not + a list or on the `values` method as in the future this will be removed. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def values(self): + # TODO: at the next Transformers release (4.28.0) issue a deprecation warning here. + return sum([x[1:] for x in self], []) + + def find_tied_parameters(model: nn.Module, **kwargs): """ Find the tied parameters in a given model. @@ -198,7 +212,7 @@ def find_tied_parameters(model: nn.Module, **kwargs): model (`torch.nn.Module`): The model to inspect. Returns: - Dict[str, str]: A dictionary mapping tied parameter names to the name of the parameter they are tied to. + List[List[str]]: A list of lists of parameter names being all tied together. Example: @@ -207,9 +221,9 @@ def find_tied_parameters(model: nn.Module, **kwargs): >>> import torch.nn as nn >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))])) - >>> model.linear2.weight = test_model.linear1.weight - >>> find_tied_parameters(test_model) - {'linear1.weight': 'linear2.weight'} + >>> model.linear2.weight = model.linear1.weight + >>> find_tied_parameters(model) + [['linear1.weight', 'linear2.weight']] ``` """ # Initialize result and named_parameters before recursing. @@ -229,14 +243,16 @@ def find_tied_parameters(model: nn.Module, **kwargs): # When we find one, it has to be one of the existing parameters. for new_name, new_param in named_parameters.items(): if new_param is parameter: - result[new_name] = full_name + if new_name not in result: + result[new_name] = [] + result[new_name].append(full_name) # Once we have treated direct parameters, we move to the child modules. for name, child in model.named_children(): child_name = name if prefix == "" else f"{prefix}.{name}" find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result) - return result + return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()]) def retie_parameters(model, tied_params): @@ -246,17 +262,21 @@ def retie_parameters(model, tied_params): Args: model (`torch.nn.Module`): The model in which to retie parameters. - tied_params (`Dict[str, str]`): + tied_params (`List[List[str]]`): A mapping parameter name to tied parameter name as obtained by `find_tied_parameters`. 
""" - for param_name, tied_param_name in tied_params.items(): - param = model - for split in param_name.split("."): - param = getattr(param, split) - tied_module = model - for split in tied_param_name.split(".")[:-1]: - tied_module = getattr(tied_module, split) - setattr(tied_module, tied_param_name.split(".")[-1], param) + for tied_group in tied_params: + param_to_tie = None + # First iteration of the loop will set param_to_tie, next ones will tie it to the others + for param_name in tied_group: + module = model + splits = param_name.split(".") + for split in splits[:-1]: + module = getattr(module, split) + if param_to_tie is None: + param_to_tie = getattr(module, splits[-1]) + else: + setattr(module, splits[-1], param_to_tie) def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype: @@ -508,6 +528,7 @@ def infer_auto_device_map( no_split_module_classes: Optional[List[str]] = None, dtype: Optional[Union[str, torch.dtype]] = None, special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None, + verbose: bool = False, ): """ Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk, @@ -539,6 +560,8 @@ def infer_auto_device_map( special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*): If provided, special dtypes to consider for some specific weights (will override dtype used as default for all weights). + verbose (`bool`, *optional*, defaults to `False`): + Whether or not to provide debugging statements as the function builds the device_map. """ # Get default / clean up max_memory max_memory = get_max_memory(max_memory) @@ -574,6 +597,8 @@ def infer_auto_device_map( # Ready ? This is going to be a bit messy. while len(modules_to_treat) > 0: name, module = modules_to_treat.pop(0) + if verbose: + print(f"\nTreating module {name}.") # Max size in the remaining layers may have changed since we took one, so we maybe update it. max_layer_names = [n for n in max_layer_names if not n.startswith(name)] if len(max_layer_names) == 0: @@ -584,11 +609,20 @@ def infer_auto_device_map( ) # Assess size needed module_size = module_sizes[name] - # We keep relevant tied parameters only: once of the tied parameters is inside the current module and the other - # is not. - tied_params = [v for k, v in tied_parameters.items() if name in k and name not in v] - # We ignore parameters that are tied when they're tied to > 1 one - tied_param = tied_params[0] if len(tied_params) == 1 else None + + # We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module + # and the other is not. + tied_param_goups = [ + tied_group + for tied_group in tied_parameters + if any(name in k for k in tied_group) and not all(name in k for k in tied_group) + ] + if verbose and len(tied_param_goups) > 0: + print(f" Found the relevant tied param groups {tied_param_goups}") + # Then we keep track of all the parameters that are tied to the current module, but not in the current module + tied_params = sum([[p for p in tied_group if name not in p] for tied_group in tied_param_goups], []) + if verbose and len(tied_params) > 0: + print(f" So those parameters need to be taken into account {tied_params}") device = devices[current_device] current_max_size = max_memory[device] if device != "disk" else None @@ -599,13 +633,22 @@ def infer_auto_device_map( if current_max_size is not None and current_memory_used + module_size > current_max_size: # Split or not split? 
modules_children = list(module.named_children()) + if verbose: + print( + f"Not enough space on {devices[current_device]} to put {name} (space available " + f"{current_max_size-current_memory_used}, module size {module_size})." + ) if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes: # -> no split, we go to the next device + if verbose: + print("This module cannot be split, going to the next device.") current_device += 1 modules_to_treat = [(name, module)] + modules_to_treat current_memory_used = 0 else: # -> split, we replace the module studied by its children + parameters + if verbose: + print(f"Splitting {name}.") modules_children = list(module.named_parameters(recurse=False)) + modules_children modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat # Update the max layer size. @@ -616,24 +659,57 @@ def infer_auto_device_map( ) # Case 2, it fits! We're not entirely out of the wood though, because we may have some tied parameters. - elif tied_param is not None: - # Determine the sized occupied by this module + the module containing the tied parameter - tied_module_size = module_size - tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0] - tied_module_name, tied_module = modules_to_treat[tied_module_index] - tied_module_size += module_sizes[tied_module_name] - module_sizes[tied_param] - if current_max_size is not None and current_memory_used + tied_module_size > current_max_size: - # Split or not split? - tied_module_children = list(tied_module.named_children()) - if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes: - # If the tied module is not split, we go to the next device - current_device += 1 - modules_to_treat = [(name, module)] + modules_to_treat - current_memory_used = 0 - else: - # Otherwise, we replace the tied module by its children. + elif len(tied_params) > 0: + # First locate all tied modules + tied_module_names = [] + tied_modules = [] + for tied_param in tied_params: + tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0] + tied_module_names.append(modules_to_treat[tied_module_index][0]) + tied_modules.append(modules_to_treat[tied_module_index][1]) + if verbose: + print( + f" It looks like {name} is going to fit on {devices[current_device]} but we have tied " + f"parameters to account for.\n - Names {tied_params}\n - Module names {tied_module_names}" + ) + + # Let's see if it all fits first + module_size_with_ties = module_size + for tied_param, tied_module_name in zip(tied_params, tied_module_names): + module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param] + + if current_max_size is None or current_memory_used + module_size_with_ties <= current_max_size: + # We really really fit! + if verbose: + print(f"Putting {name} and {tied_module_names} on {devices[current_device]}.") + current_memory_used += module_size_with_ties + device_map[name] = devices[current_device] + for tied_module_name in tied_module_names: + tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0] + modules_to_treat.pop(tied_module_index) + device_map[tied_module_name] = devices[current_device] + + else: + # We don't fit with the tied modules. Next question is: can we split one of the tied modules to make it + # smaller or do we need to go on the next device? 
+ if verbose: + print( + f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space " + f"available {current_max_size-current_memory_used}, needed size {module_size_with_ties})." + ) + split_happened = False + for tied_module_name, tied_module in zip(tied_module_names, tied_modules): + tied_module_children = list(tied_module.named_children()) + if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes: + # can't break this one. + continue + + if verbose: + print(f"Splitting {tied_module_name}.") tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children] + tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0] + modules_to_treat = ( [(name, module)] + modules_to_treat[:tied_module_index] @@ -646,13 +722,20 @@ def infer_auto_device_map( module_sizes, no_split_module_classes, ) - else: - # We really really fit! - current_memory_used += tied_module_size - device_map[name] = devices[current_device] - modules_to_treat.pop(tied_module_index) - device_map[tied_module_name] = devices[current_device] + split_happened = True + break + + if not split_happened: + # If the tied module is not split, we go to the next device + if verbose: + print("None of the tied module can be split, going to the next device.") + current_device += 1 + modules_to_treat = [(name, module)] + modules_to_treat + current_memory_used = 0 + else: + if verbose: + print(f"Putting {name} on {devices[current_device]}.") current_memory_used += module_size device_map[name] = devices[current_device] diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 16f6bc75cb4..22de17d2841 100644 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -16,11 +16,13 @@ import os import tempfile import unittest +from collections import OrderedDict import torch import torch.nn as nn -from accelerate.test_utils import require_cuda, require_multi_gpu, require_safetensors +from accelerate import init_empty_weights +from accelerate.test_utils import require_cuda, require_huggingface_suite, require_multi_gpu, require_safetensors from accelerate.test_utils.testing import require_torch_min_version from accelerate.utils.modeling import ( check_device_map, @@ -32,6 +34,7 @@ load_checkpoint_in_model, load_state_dict, named_module_tensors, + retie_parameters, set_module_tensor_to_device, ) @@ -47,6 +50,11 @@ def forward(self, x): return self.linear2(self.batchnorm(self.linear1(x))) +def sequential_model(num_layers): + layers = OrderedDict([(f"linear{i}", nn.Linear(1000, 1000)) for i in range(1, num_layers + 1)]) + return nn.Sequential(layers) + + @require_torch_min_version(version="1.9.0") class ModelingUtilsTester(unittest.TestCase): def check_set_module_tensor_for_device(self, model, device1, device2): @@ -170,10 +178,52 @@ def test_named_tensors(self): ) def test_find_tied_parameters(self): - model = ModelForTest() - self.assertDictEqual(find_tied_parameters(model), {}) + model = sequential_model(4) + self.assertListEqual(find_tied_parameters(model), []) + model.linear2.weight = model.linear1.weight - self.assertDictEqual(find_tied_parameters(model), {"linear1.weight": "linear2.weight"}) + self.assertListEqual(find_tied_parameters(model), [["linear1.weight", "linear2.weight"]]) + + model.linear4.weight = model.linear1.weight + self.assertListEqual(find_tied_parameters(model), 
[["linear1.weight", "linear2.weight", "linear4.weight"]]) + + model = sequential_model(5) + model.linear1.weight = model.linear4.weight + model.linear2.weight = model.linear3.weight + model.linear5.weight = model.linear2.weight + tied_params = sorted(find_tied_parameters(model), key=lambda x: len(x)) + self.assertListEqual( + tied_params, [["linear1.weight", "linear4.weight"], ["linear2.weight", "linear3.weight", "linear5.weight"]] + ) + + model = nn.Sequential(OrderedDict([("block1", sequential_model(4)), ("block2", sequential_model(4))])) + model.block1.linear1.weight = model.block2.linear1.weight + self.assertListEqual(find_tied_parameters(model), [["block1.linear1.weight", "block2.linear1.weight"]]) + + def test_retie_parameters(self): + model = sequential_model(2) + retie_parameters(model, [["linear1.weight", "linear2.weight"]]) + self.assertIs(model.linear1.weight, model.linear2.weight) + + model = sequential_model(3) + retie_parameters(model, [["linear1.weight", "linear2.weight", "linear3.weight"]]) + + self.assertIs(model.linear1.weight, model.linear2.weight) + self.assertIs(model.linear1.weight, model.linear3.weight) + + model = sequential_model(5) + retie_parameters( + model, [["linear1.weight", "linear4.weight"], ["linear2.weight", "linear3.weight", "linear5.weight"]] + ) + + self.assertIs(model.linear1.weight, model.linear4.weight) + self.assertIs(model.linear2.weight, model.linear3.weight) + self.assertIs(model.linear2.weight, model.linear5.weight) + + model = nn.Sequential(OrderedDict([("block1", sequential_model(4)), ("block2", sequential_model(4))])) + retie_parameters(model, [["block1.linear1.weight", "block2.linear1.weight"]]) + + self.assertIs(model.block1.linear1.weight, model.block2.linear1.weight) def test_compute_module_sizes(self): model = ModelForTest() @@ -384,15 +434,67 @@ def test_infer_auto_device_map(self): ) self.assertDictEqual(device_map, {"0": 0, "1": 1, "2": 1}) - # Now if we have weights tied inside submodules, tied weights are on the same device. 
- model = nn.Sequential(ModelForTest(), ModelForTest(), ModelForTest()) - layer0 = getattr(model, "0") - layer2 = getattr(model, "2") - layer0.linear2.weight = layer2.linear2.weight + def test_infer_auto_device_map_with_tied_weights(self): + model = nn.Sequential( + OrderedDict([("layer1", ModelForTest()), ("layer2", ModelForTest()), ("layer3", ModelForTest())]) + ) + model.layer3.linear2.weight = model.layer1.linear2.weight + device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 500}) + expected = {"layer1": 0, "layer3.linear2": 0, "layer2": 1, "layer3.linear1": 1, "layer3.batchnorm": 1} + self.assertDictEqual(device_map, expected) + + # With three weights tied together + model.layer2.linear2.weight = model.layer1.linear2.weight + device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 500}) + expected = { + "layer1": 0, + "layer2.linear2": 0, + "layer3.linear2": 0, + "layer2.linear1": 1, + "layer2.batchnorm": 1, + "layer3.linear1": 1, + "layer3.batchnorm": 1, + } + self.assertDictEqual(device_map, expected) + + # With two groups of weights tied together + model.layer2.linear1.weight = model.layer1.linear1.weight device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 500}) - expected = {"0": 0, "2.linear2": 0, "1": 1, "2.linear1": 1, "2.batchnorm": 1} + expected = { + "layer1": 0, + "layer2.linear1": 0, + "layer2.linear2": 0, + "layer3.linear2": 0, + "layer2.batchnorm": 1, + "layer3.linear1": 1, + "layer3.batchnorm": 1, + } self.assertDictEqual(device_map, expected) + @require_huggingface_suite + def test_infer_auto_device_map_on_t0pp(self): + from transformers import AutoConfig, AutoModelForSeq2SeqLM + + config = AutoConfig.from_pretrained("bigscience/T0pp") + with init_empty_weights(): + model = AutoModelForSeq2SeqLM.from_config(config) + model.tie_weights() + + special_dtypes = {n: torch.float32 for n, _ in model.named_parameters() if "wo" in n} + max_memory = {0: 10**10, 1: 10**10, "cpu": 10**10} + device_map = infer_auto_device_map( + model, + no_split_module_classes=["T5Block"], + dtype=torch.float16, + max_memory=max_memory, + special_dtypes=special_dtypes, + ) + + # The 3 tied weights should all be on device 0 + self.assertEqual(device_map["shared"], 0) + self.assertEqual(device_map["encoder.embed_tokens"], 0) + self.assertEqual(device_map["decoder.embed_tokens"], 0) + @require_cuda def test_get_balanced_memory(self): model = ModelForTest()
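
A small usage sketch of the grouped `find_tied_parameters` format and the `retie_parameters` helper changed above. The behavior mirrors the docstring example and the new tests in this patch; the three-layer toy model and the deliberate un-tying step are only illustrative.

```py
from collections import OrderedDict

import torch.nn as nn

from accelerate.utils.modeling import find_tied_parameters, retie_parameters

# Toy model with three linear layers whose weights are all tied together.
model = nn.Sequential(OrderedDict([(f"linear{i}", nn.Linear(4, 4)) for i in range(1, 4)]))
model.linear2.weight = model.linear1.weight
model.linear3.weight = model.linear1.weight

tied = find_tied_parameters(model)
# New format: one sorted list per group of tied parameters.
print(tied)  # [['linear1.weight', 'linear2.weight', 'linear3.weight']]

# `FindTiedParametersResult.values()` keeps old dict-style consumers (Transformers) working:
# it flattens every group minus its first entry.
print(tied.values())  # ['linear2.weight', 'linear3.weight']

# If the tie is broken (e.g. weights reloaded one module at a time),
# `retie_parameters` points every name in a group at the same parameter again.
model.linear2.weight = nn.Parameter(model.linear2.weight.detach().clone())
retie_parameters(model, tied)
assert model.linear2.weight is model.linear1.weight
assert model.linear3.weight is model.linear1.weight
```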
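A similar sketch for the reworked tied-parameter handling in `infer_auto_device_map`, in the spirit of the new `test_infer_auto_device_map_with_tied_weights` test. The `ModelForTest` layer sizes are an assumption (the class body is not part of this diff); the device map quoted in the comment is the one the new test asserts, and `verbose=True` is the debugging flag added above.

```py
from collections import OrderedDict

import torch.nn as nn

from accelerate.utils.modeling import infer_auto_device_map


class ModelForTest(nn.Module):
    # Assumed layer sizes; the real class lives in tests/test_modeling_utils.py.
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(3, 4)
        self.batchnorm = nn.BatchNorm1d(4)
        self.linear2 = nn.Linear(4, 5)

    def forward(self, x):
        return self.linear2(self.batchnorm(self.linear1(x)))


model = nn.Sequential(
    OrderedDict([("layer1", ModelForTest()), ("layer2", ModelForTest()), ("layer3", ModelForTest())])
)
# Tie a weight across submodules: layer3.linear2 now shares layer1.linear2's weight.
model.layer3.linear2.weight = model.layer1.linear2.weight

# `verbose=True` prints each placement decision, including the tied groups it accounts for.
device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 500}, verbose=True)

# The new test asserts that the tied half of layer3 lands with layer1 on device 0:
# {"layer1": 0, "layer3.linear2": 0, "layer2": 1, "layer3.linear1": 1, "layer3.batchnorm": 1}
print(device_map)
```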