Support special mapping of dtypes when preparing device map (#1179)
sgugger committed Mar 13, 2023
1 parent 8dec01a commit 3b3605e
Showing 1 changed file with 30 additions and 5 deletions.
35 changes: 30 additions & 5 deletions src/accelerate/utils/modeling.py
@@ -259,19 +259,36 @@ def retie_parameters(model, tied_params):
         setattr(tied_module, tied_param_name.split(".")[-1], param)
 
 
-def compute_module_sizes(model: nn.Module, dtype: Optional[Union[str, torch.device]] = None):
+def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype:
     """
-    Compute the size of each submodule of a given model.
+    Just does torch.dtype(dtype) if necessary.
     """
     if isinstance(dtype, str):
         # We accept "torch.float16" or just "float16"
         dtype = dtype.replace("torch.", "")
         dtype = getattr(torch, dtype)
+    return dtype
+
+
+def compute_module_sizes(
+    model: nn.Module,
+    dtype: Optional[Union[str, torch.device]] = None,
+    special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
+):
+    """
+    Compute the size of each submodule of a given model.
+    """
     if dtype is not None:
+        dtype = _get_proper_dtype(dtype)
         dtype_size = dtype_byte_size(dtype)
+    if special_dtypes is not None:
+        special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
+        special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
     module_sizes = defaultdict(int)
     for name, tensor in named_module_tensors(model, recurse=True):
-        if dtype is None:
+        if special_dtypes is not None and name in special_dtypes:
+            size = tensor.numel() * special_dtypes_size[name]
+        elif dtype is None:
             size = tensor.numel() * dtype_byte_size(tensor.dtype)
         else:
             size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
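To make the refactor concrete, here is a standalone sketch (not part of the diff) of the normalization `_get_proper_dtype` performs; the assertions are illustrative only:

import torch

def _get_proper_dtype(dtype):
    # Accept a torch.dtype, "float16", or "torch.float16" and return a torch.dtype.
    if isinstance(dtype, str):
        dtype = dtype.replace("torch.", "")
        dtype = getattr(torch, dtype)
    return dtype

assert _get_proper_dtype("torch.float16") is torch.float16
assert _get_proper_dtype("bfloat16") is torch.bfloat16
assert _get_proper_dtype(torch.float32) is torch.float32

With this in place, `compute_module_sizes` sizes each tensor with its matching entry in `special_dtypes` when one exists, and otherwise falls back to `dtype` (or to the tensor's own dtype when no default is given).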
@@ -394,6 +411,7 @@ def get_balanced_memory(
     max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
     no_split_module_classes: Optional[List[str]] = None,
     dtype: Optional[Union[str, torch.dtype]] = None,
+    special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
     low_zero: bool = False,
 ):
     """
@@ -416,6 +434,9 @@
             residual connection).
         dtype (`str` or `torch.dtype`, *optional*):
             If provided, the weights will be converted to that type when loaded.
+        special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*):
+            If provided, special dtypes to consider for some specific weights (will override dtype used as default for
+            all weights).
         low_zero (`bool`, *optional*):
             Minimizes the number of weights on GPU 0, which is convenient when it's used for other operations (like the
             Transformers generate function).
@@ -427,7 +448,7 @@
         return max_memory
 
     num_devices = len([d for d in max_memory if torch.device(d).type == "cuda" and max_memory[d] > 0])
-    module_sizes = compute_module_sizes(model, dtype=dtype)
+    module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
     per_gpu = module_sizes[""] // (num_devices - 1 if low_zero else num_devices)
 
     # We can't just set the memory to model_size // num_devices as it will end being too small: each GPU will get
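As a usage sketch of the new argument (hypothetical toy model; the parameter names "1.weight" and "1.bias" are assumptions for this example, and at least one CUDA device is assumed), a caller can size most weights in one dtype while keeping specific weights in another:

import torch
from torch import nn
from accelerate.utils import get_balanced_memory

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))  # toy model
max_memory = get_balanced_memory(
    model,
    dtype=torch.float16,  # size all weights as fp16 by default
    special_dtypes={"1.weight": torch.float32, "1.bias": torch.float32},  # last layer stays fp32
)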
@@ -486,6 +507,7 @@ def infer_auto_device_map(
     max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
     no_split_module_classes: Optional[List[str]] = None,
     dtype: Optional[Union[str, torch.dtype]] = None,
+    special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
 ):
     """
     Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
@@ -514,6 +536,9 @@ def infer_auto_device_map(
             residual connection).
         dtype (`str` or `torch.dtype`, *optional*):
             If provided, the weights will be converted to that type when loaded.
+        special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*):
+            If provided, special dtypes to consider for some specific weights (will override dtype used as default for
+            all weights).
     """
     # Get default / clean up max_memory
     max_memory = get_max_memory(max_memory)
@@ -530,7 +555,7 @@
     # Devices that need to keep space for a potential offloaded layer.
     main_devices = [gpus[0], "cpu"] if len(gpus) > 0 else ["cpu"]
 
-    module_sizes = compute_module_sizes(model, dtype=dtype)
+    module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
     tied_parameters = find_tied_parameters(model)
 
     device_map = {}
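And an end-to-end sketch for `infer_auto_device_map` (again a hypothetical toy model; the printed map is only an example of the shape of the output, not a guaranteed result):

import torch
from torch import nn
from accelerate import infer_auto_device_map

model = nn.Sequential(nn.Embedding(1000, 512), nn.Linear(512, 512), nn.Linear(512, 1000))
device_map = infer_auto_device_map(
    model,
    max_memory={0: "1GB", "cpu": "4GB"},  # cap GPU 0, spill the rest to CPU
    dtype=torch.float16,  # size weights as fp16 by default
    special_dtypes={"2.weight": torch.float32, "2.bias": torch.float32},  # head sized as fp32
)
print(device_map)  # e.g. {"0": 0, "1": 0, "2": "cpu"}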
