diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index 725b782ca1a..fe5eee0a6e5 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -259,19 +259,36 @@ def retie_parameters(model, tied_params):
         setattr(tied_module, tied_param_name.split(".")[-1], param)
 
 
-def compute_module_sizes(model: nn.Module, dtype: Optional[Union[str, torch.device]] = None):
+def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype:
     """
-    Compute the size of each submodule of a given model.
+    Just does torch.dtype(dtype) if necessary.
     """
     if isinstance(dtype, str):
         # We accept "torch.float16" or just "float16"
         dtype = dtype.replace("torch.", "")
         dtype = getattr(torch, dtype)
+    return dtype
+
+
+def compute_module_sizes(
+    model: nn.Module,
+    dtype: Optional[Union[str, torch.device]] = None,
+    special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
+):
+    """
+    Compute the size of each submodule of a given model.
+    """
     if dtype is not None:
+        dtype = _get_proper_dtype(dtype)
         dtype_size = dtype_byte_size(dtype)
+    if special_dtypes is not None:
+        special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
+        special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
     module_sizes = defaultdict(int)
     for name, tensor in named_module_tensors(model, recurse=True):
-        if dtype is None:
+        if special_dtypes is not None and name in special_dtypes:
+            size = tensor.numel() * special_dtypes_size[name]
+        elif dtype is None:
             size = tensor.numel() * dtype_byte_size(tensor.dtype)
         else:
             size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
@@ -394,6 +411,7 @@ def get_balanced_memory(
     max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
     no_split_module_classes: Optional[List[str]] = None,
     dtype: Optional[Union[str, torch.dtype]] = None,
+    special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
     low_zero: bool = False,
 ):
     """
@@ -416,6 +434,9 @@ def get_balanced_memory(
             residual connection).
         dtype (`str` or `torch.dtype`, *optional*):
             If provided, the weights will be converted to that type when loaded.
+        special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*):
+            If provided, special dtypes to consider for some specific weights (will override dtype used as default for
+            all weights).
         low_zero (`bool`, *optional*):
             Minimizes the number of weights on GPU 0, which is convenient when it's used for other operations (like
             the Transformers generate function).
@@ -427,7 +448,7 @@ def get_balanced_memory(
         return max_memory
 
     num_devices = len([d for d in max_memory if torch.device(d).type == "cuda" and max_memory[d] > 0])
-    module_sizes = compute_module_sizes(model, dtype=dtype)
+    module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
     per_gpu = module_sizes[""] // (num_devices - 1 if low_zero else num_devices)
 
     # We can't just set the memory to model_size // num_devices as it will end being too small: each GPU will get
@@ -486,6 +507,7 @@ def infer_auto_device_map(
     max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
     no_split_module_classes: Optional[List[str]] = None,
     dtype: Optional[Union[str, torch.dtype]] = None,
+    special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
 ):
     """
     Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
@@ -514,6 +536,9 @@ def infer_auto_device_map(
             residual connection).
         dtype (`str` or `torch.dtype`, *optional*):
             If provided, the weights will be converted to that type when loaded.
+        special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*):
+            If provided, special dtypes to consider for some specific weights (will override dtype used as default for
+            all weights).
     """
     # Get default / clean up max_memory
     max_memory = get_max_memory(max_memory)
@@ -530,7 +555,7 @@ def infer_auto_device_map(
     # Devices that need to keep space for a potential offloaded layer.
    main_devices = [gpus[0], "cpu"] if len(gpus) > 0 else ["cpu"]
 
-    module_sizes = compute_module_sizes(model, dtype=dtype)
+    module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
    tied_parameters = find_tied_parameters(model)
 
     device_map = {}
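
For reviewers, a minimal sketch of how the new `special_dtypes` argument could be exercised. The toy model, parameter names, and memory limits below are made up for illustration; only `infer_auto_device_map`, `dtype`, and `special_dtypes` come from the existing API plus this patch:

    import torch
    from torch import nn

    from accelerate import infer_auto_device_map

    # Hypothetical two-layer model; named_module_tensors() yields
    # "0.weight", "0.bias", "1.weight", "1.bias" for it.
    model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

    device_map = infer_auto_device_map(
        model,
        max_memory={0: "10MB", "cpu": "20MB"},  # made-up limits for illustration
        dtype="float16",  # default dtype used when sizing every weight
        # Per the new code path, these parameters are sized with their own
        # dtype instead of the min(dtype, tensor.dtype) rule used for the rest.
        # Both torch.dtype objects and strings are accepted (_get_proper_dtype
        # converts strings).
        special_dtypes={"1.weight": torch.float32, "1.bias": "float32"},
    )
    print(device_map)

Note that for a name listed in `special_dtypes` the size is `tensor.numel() * dtype_byte_size(special_dtype)` with no `min` against the tensor's stored dtype, so a special dtype can make a weight count as larger than its current storage; that matches the override semantics stated in the new docstring entries.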