fix saved 8bit model offload and add tests
SunMarc committed Jul 10, 2023
1 parent b19f9dc commit 4a0b3c7
Showing 3 changed files with 91 additions and 24 deletions.
50 changes: 28 additions & 22 deletions src/accelerate/utils/bnb.py
@@ -446,26 +446,32 @@ def get_parameter_device(parameter: nn.Module):
     return next(parameter.parameters()).device


-def quantize_and_offload_8bit(model, param, param_name, new_dtype, offload_folder, offload_index):
-    set_module_tensor_to_device(model, param_name, 0, dtype=new_dtype, value=param)
-    tensor_name = param_name
-    module = model
-    if "." in tensor_name:
-        splits = tensor_name.split(".")
-        for split in splits[:-1]:
-            new_module = getattr(module, split)
-            if new_module is None:
-                raise ValueError(f"{module} has no attribute {split}.")
-            module = new_module
-        tensor_name = splits[-1]
-    # offload weights
-    module._parameters[tensor_name].requires_grad = False
-    offload_weight(module._parameters[tensor_name], param_name, offload_folder, index=offload_index)
-    if hasattr(module._parameters[tensor_name], "SCB"):
-        offload_weight(
-            module._parameters[tensor_name].SCB,
-            param_name.replace("weight", "SCB"),
-            offload_folder,
-            index=offload_index,
-        )
+def quantize_and_offload_8bit(model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics):
+    # if it is not quantized, we quantize and offload the quantized weights and the SCB stats
+    if fp16_statistics is None:
+        set_module_tensor_to_device(model, param_name, 0, dtype=new_dtype, value=param)
+        tensor_name = param_name
+        module = model
+        if "." in tensor_name:
+            splits = tensor_name.split(".")
+            for split in splits[:-1]:
+                new_module = getattr(module, split)
+                if new_module is None:
+                    raise ValueError(f"{module} has no attribute {split}.")
+                module = new_module
+            tensor_name = splits[-1]
+        # offload weights
+        module._parameters[tensor_name].requires_grad = False
+        offload_weight(module._parameters[tensor_name], param_name, offload_folder, index=offload_index)
+        if hasattr(module._parameters[tensor_name], "SCB"):
+            offload_weight(
+                module._parameters[tensor_name].SCB,
+                param_name.replace("weight", "SCB"),
+                offload_folder,
+                index=offload_index,
+            )
+    else:
+        offload_weight(param, param_name, offload_folder, index=offload_index)
+        offload_weight(fp16_statistics, param_name.replace("weight", "SCB"), offload_folder, index=offload_index)
+
     set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype, value=torch.empty(*param.size()))
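Note: the core of the fix is the new fp16_statistics argument. When it is provided (the checkpoint was saved already quantized in 8-bit), the weight and its SCB scale statistics are offloaded as-is instead of being re-quantized on GPU. A minimal, hypothetical sketch of what that else branch writes into the offload index, using accelerate.utils.offload_weight with stand-in tensors (the folder name, shapes, and values below are illustrative and not part of the commit):

    import os
    import torch
    from accelerate.utils import offload_weight

    offload_folder = "offload_sketch"  # hypothetical folder for this sketch
    os.makedirs(offload_folder, exist_ok=True)
    index = {}

    # stand-ins for a saved int8 weight and its SCB scales from an 8-bit checkpoint
    param = torch.randint(-128, 127, (4, 4), dtype=torch.int8)
    fp16_statistics = torch.rand(4, dtype=torch.float16)

    param_name = "transformer.h.0.mlp.dense_4h_to_h.weight"
    # mirrors the new `else` branch: offload the quantized weight and its SCB entry directly
    index = offload_weight(param, param_name, offload_folder, index=index)
    index = offload_weight(fp16_statistics, param_name.replace("weight", "SCB"), offload_folder, index=index)

    print(sorted(index))  # both the weight and its ".SCB" companion are registered in the index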
4 changes: 2 additions & 2 deletions src/accelerate/utils/modeling.py
@@ -1305,7 +1305,7 @@ def load_checkpoint_in_model(
                     new_dtype = param.dtype
                 if offload_8bit_bnb:
                     quantize_and_offload_8bit(
-                        model, param, param_name, new_dtype, offload_folder, offload_index
+                        model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics
                     )
                     continue
                 else:
@@ -1316,7 +1316,7 @@
                 new_dtype = param.dtype
             if offload_8bit_bnb:
                 quantize_and_offload_8bit(
-                    model, param, param_name, new_dtype, state_dict_folder, state_dict_index
+                    model, param, param_name, new_dtype, state_dict_folder, state_dict_index, fp16_statistics
                 )
             else:
                 set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype)
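Note: load_checkpoint_in_model now threads fp16_statistics through to both offload call sites, the disk-offload path and the CPU state-dict-offload path. A rough sketch of how SCB statistics can be paired with their int8 weights while walking a saved 8-bit state dict (the dict contents and variable names here are illustrative only, not the loader's actual code):

    import torch

    state_dict = {
        "transformer.h.0.mlp.dense_4h_to_h.weight": torch.randint(-128, 127, (4, 4), dtype=torch.int8),
        "transformer.h.0.mlp.dense_4h_to_h.SCB": torch.rand(4, dtype=torch.float16),
    }

    for param_name, param in state_dict.items():
        if param_name.endswith(".SCB"):
            continue  # statistics are handled together with their weight
        # None for checkpoints that were not saved pre-quantized
        fp16_statistics = state_dict.get(param_name.replace("weight", "SCB"))
        print(param_name, "has SCB:", fp16_statistics is not None)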
61 changes: 61 additions & 0 deletions tests/test_quantization.py
@@ -393,6 +393,67 @@ def test_int8_serialization(self):

         self.check_inference_correctness(model_8bit_from_saved)

+    def test_int8_serialization_offload(self):
+        r"""
+        Test whether it is possible to serialize a model in 8-bit and offload weights to cpu/disk
+        """
+
+        from bitsandbytes.nn import Int8Params
+        from transformers import AutoConfig, AutoModelForCausalLM
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            # saving state dict for now but will save config and other in the future
+            self.accelerate.save_model(self.model_8bit, tmpdirname)
+
+            with init_empty_weights():
+                # let's suppose that we can get the right config
+                model_8bit_from_saved = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name))
+            model_8bit_from_saved.tie_weights()
+            bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True, enable_offload=True)
+            device_map = {
+                "transformer.word_embeddings": "cpu",
+                "transformer.word_embeddings_layernorm": 0,
+                "lm_head": "cpu",
+                "transformer.h.0": "cpu",
+                "transformer.h.1": "cpu",
+                "transformer.h.2": "cpu",
+                "transformer.h.3": "disk",
+                "transformer.h.4": "disk",
+                "transformer.h.5": "disk",
+                "transformer.h.6": 0,
+                "transformer.h.7": 0,
+                "transformer.h.8": 0,
+                "transformer.h.9": 1,
+                "transformer.h.10": 0,
+                "transformer.h.11": 1,
+                "transformer.h.12": 0,
+                "transformer.h.13": 0,
+                "transformer.h.14": 1,
+                "transformer.h.15": 0,
+                "transformer.h.16": 0,
+                "transformer.h.17": 1,
+                "transformer.h.18": 1,
+                "transformer.h.19": 0,
+                "transformer.h.20": 1,
+                "transformer.h.21": 1,
+                "transformer.h.22": 0,
+                "transformer.h.23": 0,
+                "transformer.ln_f": 1,
+            }
+            model_8bit_from_saved = load_and_quantize_model(
+                model_8bit_from_saved,
+                bnb_quantization_config,
+                weights_location=tmpdirname + "/pytorch_model.bin",
+                device_map=device_map,
+                no_split_module_classes=["BloomBlock"],
+                offload_folder=tmpdirname + "/tmp",
+                offload_state_dict=True,
+            )
+
+            self.assertTrue(model_8bit_from_saved.transformer.h[4].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
+            self.assertTrue(model_8bit_from_saved.transformer.h[5].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
+            self.check_inference_correctness(model_8bit_from_saved)
+
     def test_int8_serialization_shard(self):
         r"""
         Test whether it is possible to serialize a model in 8-bit.
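Note: the new test's device_map deliberately mixes GPU indices with "cpu" and "disk" placements so that both offload branches of quantize_and_offload_8bit are exercised. Assuming a machine with bitsandbytes, transformers, and at least two visible GPUs, a run limited to the new test could look roughly like the following sketch (the pytest selection expression is an assumption, not part of the commit):

    # Hypothetical invocation: run only the new offload serialization test.
    import pytest

    pytest.main(["tests/test_quantization.py", "-k", "test_int8_serialization_offload", "-v"])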
