Add offload for 8-bit model (#1699)
* Add offload for 8-bit model

* fix saved 8bit model offload and add tests

* Update src/accelerate/utils/modeling.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/accelerate/utils/modeling.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add doc on how offload works

* remove enable_offload

* make style doc

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
SunMarc and sgugger authored Jul 11, 2023
1 parent c769883 commit 27d2908
Showing 6 changed files with 298 additions and 82 deletions.
21 changes: 21 additions & 0 deletions docs/source/usage_guides/quantization.md
@@ -102,6 +102,27 @@ quantized_model_from_saved = load_and_quantize_model(empty_model, weights_locati

Note that 4-bit model serialization is currently not supported.

### Offload modules to cpu and disk

You can offload some modules to CPU or disk if you don't have enough GPU memory to store the entire model on your GPUs.
This uses big model inference under the hood. Check this [documentation](https://huggingface.co/docs/accelerate/usage_guides/big_modeling) for more details.

For 8-bit quantization, the selected modules will be converted to 8-bit precision.

For 4-bit quantization, the selected modules will be kept in the `torch_dtype` that the user passed in `BnbQuantizationConfig`. We will add support for converting these offloaded modules to 4-bit once 4-bit serialization is possible.

You just need to pass a custom `device_map` in order to offload modules to CPU or disk. The offloaded modules will be dispatched to the GPU when needed. Here's an example:

```py
device_map = {
    "transformer.wte": 0,
    "transformer.wpe": 0,
    "transformer.drop": 0,
    "transformer.h": "cpu",
    "transformer.ln_f": "disk",
    "lm_head": "disk",
}
```
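
Building on the `empty_model`, `weights_location`, and `bnb_quantization_config` objects defined earlier in this guide, a call that uses this `device_map` could look like the following sketch (the `offload_folder` value is just an illustrative choice for where the disk-offloaded weights are written):

```py
from accelerate.utils import load_and_quantize_model

quantized_model = load_and_quantize_model(
    empty_model,
    weights_location=weights_location,
    bnb_quantization_config=bnb_quantization_config,
    device_map=device_map,  # the custom map defined above
    offload_folder="offload",  # assumption: any writable folder for the "disk" weights
)
```
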
### Fine-tune a quantized model

With the official support of adapters in the Hugging Face ecosystem, you can fine-tune quantized models. Please have a look at the [peft](https://github.com/huggingface/peft) library for more details.
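
As a minimal sketch, assuming a causal language model whose attention projections are named `q_proj`/`v_proj` (the `target_modules` list and the hyperparameters below are illustrative assumptions, not fixed values), a LoRA setup on the `quantized_model` from the example above could look like:

```py
from peft import LoraConfig, get_peft_model

# r, lora_alpha, and target_modules are illustrative; adapt them to your model
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

peft_model = get_peft_model(quantized_model, lora_config)
peft_model.print_trainable_parameters()
```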
11 changes: 10 additions & 1 deletion src/accelerate/hooks.py
@@ -279,7 +279,13 @@ def pre_forward(self, module, *args, **kwargs):
for name, _ in named_module_tensors(
module, include_buffers=self.offload_buffers, recurse=self.place_submodules
):
set_module_tensor_to_device(module, name, self.execution_device, value=self.weights_map[name])
fp16_statistics = None
if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys():
if self.weights_map[name].dtype == torch.int8:
fp16_statistics = self.weights_map[name.replace("weight", "SCB")]
set_module_tensor_to_device(
module, name, self.execution_device, value=self.weights_map[name], fp16_statistics=fp16_statistics
)

return send_to_device(args, self.execution_device), send_to_device(
kwargs, self.execution_device, skip_keys=self.skip_keys
@@ -291,6 +297,9 @@ def post_forward(self, module, output):
module, include_buffers=self.offload_buffers, recurse=self.place_submodules
):
set_module_tensor_to_device(module, name, "meta")
if type(module).__name__ == "Linear8bitLt":
module.state.SCB = None
module.state.CxB = None

if self.io_same_device and self.input_device is not None:
output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)
86 changes: 64 additions & 22 deletions src/accelerate/utils/bnb.py
@@ -28,7 +28,14 @@

from ..big_modeling import dispatch_model, init_empty_weights
from .dataclasses import BnbQuantizationConfig
from .modeling import find_tied_parameters, get_balanced_memory, infer_auto_device_map, load_checkpoint_in_model
from .modeling import (
find_tied_parameters,
get_balanced_memory,
infer_auto_device_map,
load_checkpoint_in_model,
offload_weight,
set_module_tensor_to_device,
)


if is_bnb_available():
@@ -98,24 +105,20 @@ def load_and_quantize_model(
)

modules_on_cpu = []
# custom device map
if isinstance(device_map, dict) and len(device_map.keys()) > 1:
modules_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
if len(modules_on_cpu) > 0 and not bnb_quantization_config.enable_fp32_cpu_offload:
raise ValueError(
"If you want to offload some keys to `cpu` or `disk`, you need to set "
" `enable_fp32_cpu_offload=True`. Note that these modules will not be "
" converted to 8-bit but kept in 32-bit."
)

# We keep some modules such as the lm_head in their original dtype for numerical stability reasons
if bnb_quantization_config.skip_modules is None:
bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)

# add cpu modules to skip modules (after looking into the code on transformers, we don't really keep the cpu module in fp32)
bnb_quantization_config.skip_modules.extend(modules_on_cpu)
# add cpu modules to skip modules only for 4-bit modules
if load_in_4bit:
bnb_quantization_config.skip_modules.extend(modules_on_cpu)
modules_to_not_convert = bnb_quantization_config.skip_modules
# We add the modules we want to keep in full precision

# We add the modules we want to keep in full precision
if bnb_quantization_config.keep_in_fp32_modules is None:
bnb_quantization_config.keep_in_fp32_modules = []
keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules
@@ -176,6 +179,8 @@ def load_and_quantize_model(
if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
offload_state_dict = True

offload = any(x in list(device_map.values()) for x in ["cpu", "disk"])

load_checkpoint_in_model(
model,
weights_location,
Expand All @@ -184,6 +189,7 @@ def load_and_quantize_model(
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
keep_in_fp32_modules=bnb_quantization_config.keep_in_fp32_modules,
offload_8bit_bnb=load_in_8bit and offload,
)
return dispatch_model(model, device_map=device_map, offload_dir=offload_folder)

@@ -247,18 +253,22 @@ def get_quantized_model_device_map(
}
for device in ["cpu", "disk"]:
if device in device_map_without_some_modules.values():
raise ValueError(
"""
Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the
quantized model. If you want to dispatch the model on the CPU or the disk while keeping these
modules in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom
`device_map` to `from_pretrained`. Check
https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu
for more details.
"""
)
if bnb_quantization_config.load_in_4bit:
raise ValueError(
"""
Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
the quantized model. If you want to dispatch the model on the CPU or the disk while keeping
these modules in `torch_dtype`, you need to pass a custom `device_map` to
`load_and_quantize_model`. Check
https://huggingface.co/docs/accelerate/main/en/usage_guides/quantization#offload-modules-to-cpu-and-disk
for more details.
"""
)
else:
logger.info(
"Some modules are are offloaded to the CPU or the disk. Note that these modules will be converted to 8-bit"
)
del device_map_without_some_modules

return device_map


@@ -348,9 +358,10 @@ def _replace_with_bnb_layers(
setattr(model, name, bnb_module)
has_been_replaced = True
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_bnb_layers(
_, _has_been_replaced = _replace_with_bnb_layers(
module, bnb_quantization_config, modules_to_not_convert, current_key_name
)
has_been_replaced = has_been_replaced | _has_been_replaced
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
@@ -418,3 +429,34 @@ def has_4bit_bnb_layers(model):

def get_parameter_device(parameter: nn.Module):
return next(parameter.parameters()).device


def quantize_and_offload_8bit(model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics):
# if it is not quantized, we quantize and offload the quantized weights and the SCB stats
if fp16_statistics is None:
set_module_tensor_to_device(model, param_name, 0, dtype=new_dtype, value=param)
tensor_name = param_name
module = model
if "." in tensor_name:
splits = tensor_name.split(".")
for split in splits[:-1]:
new_module = getattr(module, split)
if new_module is None:
raise ValueError(f"{module} has no attribute {split}.")
module = new_module
tensor_name = splits[-1]
# offload weights
module._parameters[tensor_name].requires_grad = False
offload_weight(module._parameters[tensor_name], param_name, offload_folder, index=offload_index)
if hasattr(module._parameters[tensor_name], "SCB"):
offload_weight(
module._parameters[tensor_name].SCB,
param_name.replace("weight", "SCB"),
offload_folder,
index=offload_index,
)
else:
offload_weight(param, param_name, offload_folder, index=offload_index)
offload_weight(fp16_statistics, param_name.replace("weight", "SCB"), offload_folder, index=offload_index)

set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype, value=torch.empty(*param.size()))
17 changes: 0 additions & 17 deletions src/accelerate/utils/dataclasses.py
@@ -1378,17 +1378,6 @@ class BnbQuantizationConfig:
metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."},
)

# we will see if it will be useful
enable_fp32_cpu_offload: bool = field(
default=False,
metadata={
"help": """ this flag is used for advanced use cases and users that are aware of this feature. If you want to split
your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use
this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8
operations will not be run on CPU."""
},
)

def __post_init__(self):
"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
@@ -1408,9 +1397,6 @@ def __post_init__(self):
if not isinstance(self.llm_int8_threshold, (int, float)):
raise ValueError("llm_int8_threshold must be a float or an int")

if not isinstance(self.enable_fp32_cpu_offload, bool):
raise ValueError("enable_fp32_cpu_offload must be a boolean")

if not isinstance(self.bnb_4bit_quant_type, str):
raise ValueError("bnb_4bit_quant_type must be a string")
elif self.bnb_4bit_quant_type not in ["fp4", "nf4"]:
@@ -1439,9 +1425,6 @@ def __post_init__(self):
if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list):
raise ValueError("keep_in_fp_32_modules must be a list of strings")

if not isinstance(self.enable_fp32_cpu_offload, bool):
raise ValueError("enable_fp32_cpu_offload must be a boolean")

if self.load_in_4bit:
self.target_dtype = CustomDtype.INT4

87 changes: 66 additions & 21 deletions src/accelerate/utils/modeling.py
@@ -279,7 +279,13 @@ def set_module_tensor_to_device(
device_quantization = None
with torch.no_grad():
# leave it on cpu first before moving them to cuda
if param is not None and param.device.type != "cuda" and param_cls.__name__ in ["Int8Params", "FP4Params"]:
# fix the case where the device is meta: we don't want to put it on cpu because there is no data
if (
param is not None
and param.device.type != "cuda"
and torch.device(device).type == "cuda"
and param_cls.__name__ in ["Int8Params", "FP4Params"]
):
device_quantization = device
device = "cpu"
if value is None:
@@ -303,15 +309,25 @@
if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32:
# downcast to fp16 if any - needed for 8bit serialization
new_value = new_value.to(torch.float16)
new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device)
# quantize module that are going to stay on the cpu so that we offload quantized weights
if device == "cpu" and param_cls.__name__ == "Int8Params":
new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(0).to("cpu")
new_value.CB = new_value.CB.to("cpu")
new_value.SCB = new_value.SCB.to("cpu")
else:
new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device)
else:
new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
module._parameters[tensor_name] = new_value

if fp16_statistics is not None:
setattr(module.weight, "SCB", fp16_statistics.to(device))

if module.__class__.__name__ == "Linear8bitLt" and getattr(module.weight, "SCB", None) is None:
setattr(module._parameters[tensor_name], "SCB", fp16_statistics.to(device))
del fp16_statistics
# as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight
if (
module.__class__.__name__ == "Linear8bitLt"
and getattr(module.weight, "SCB", None) is None
and str(module.weight.device) != "meta"
):
# quantize only if necessary
device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
if not getattr(module.weight, "SCB", None) and device_index is not None:
Expand All @@ -326,6 +342,8 @@ def set_module_tensor_to_device(
device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
if not getattr(module.weight, "quant_state", None) and device_index is not None:
module.weight = module.weight.cuda(device_index)
# clean pre and post forward hook
torch.cuda.empty_cache()


def named_module_tensors(module: nn.Module, include_buffers: bool = True, recurse: bool = False):
@@ -660,11 +678,18 @@ def load_offloaded_weights(model, index, offload_folder):
if index is None or len(index) == 0:
# Nothing to do
return

for param_name, metadata in index.items():
if "SCB" in param_name:
continue
fp16_statistics = None
if "weight" in param_name and param_name.replace("weight", "SCB") in index.keys():
weight_name = param_name.replace("weight", "SCB")
fp16_statistics = load_offloaded_weight(
os.path.join(offload_folder, f"{weight_name}.dat"), index[weight_name]
)
tensor_file = os.path.join(offload_folder, f"{param_name}.dat")
weight = load_offloaded_weight(tensor_file, metadata)
set_module_tensor_to_device(model, param_name, "cpu", value=weight)
set_module_tensor_to_device(model, param_name, "cpu", value=weight, fp16_statistics=fp16_statistics)


def get_balanced_memory(
@@ -1137,6 +1162,7 @@ def load_checkpoint_in_model(
offload_state_dict: bool = False,
offload_buffers: bool = False,
keep_in_fp32_modules: List[str] = None,
offload_8bit_bnb: bool = False,
):
"""
Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are
@@ -1171,8 +1197,13 @@
Whether or not to include the buffers in the weights offloaded to disk.
keep_in_fp32_modules(`List[str]`, *optional*):
A list of the modules that we keep in `torch.float32` dtype.
offload_8bit_bnb (`bool`, *optional*):
Whether or not to enable offload of 8-bit modules on cpu/disk.
"""
if offload_8bit_bnb:
from .bnb import quantize_and_offload_8bit

tied_params = find_tied_parameters(model)

if check_tied_parameters_in_config(model) and len(tied_params) == 0:
@@ -1239,6 +1270,10 @@
model.load_state_dict(checkpoint, strict=False)
else:
for param_name, param in checkpoint.items():
# skip SCB parameter (for 8-bit serialization)
if "SCB" in param_name:
continue

module_name = param_name

while len(module_name) > 0 and module_name not in device_map:
Expand Down Expand Up @@ -1268,23 +1303,33 @@ def load_checkpoint_in_model(
if offload_buffers or param_name not in buffer_names:
if new_dtype is None:
new_dtype = param.dtype
set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype)
offload_weight(param, param_name, offload_folder, index=offload_index)
if offload_8bit_bnb:
quantize_and_offload_8bit(
model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics
)
continue
else:
set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype)
offload_weight(param, param_name, offload_folder, index=offload_index)
elif param_device == "cpu" and offload_state_dict:
if new_dtype is None:
new_dtype = param.dtype
set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype)
offload_weight(param, param_name, state_dict_folder, index=state_dict_index)
else:
if "SCB" not in param_name:
set_module_tensor_to_device(
model,
param_name,
param_device,
value=param,
dtype=new_dtype,
fp16_statistics=fp16_statistics,
if offload_8bit_bnb:
quantize_and_offload_8bit(
model, param, param_name, new_dtype, state_dict_folder, state_dict_index, fp16_statistics
)
else:
set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype)
offload_weight(param, param_name, state_dict_folder, index=state_dict_index)
else:
set_module_tensor_to_device(
model,
param_name,
param_device,
value=param,
dtype=new_dtype,
fp16_statistics=fp16_statistics,
)

# Force Python to clean up.
del checkpoint
