diff --git a/tests/conftest.py b/tests/conftest.py index ae362b228d9d..c56a3059a17e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -205,8 +205,14 @@ class HfRunner: def wrap_device(self, input: _T) -> _T: if not is_cpu(): + # Check if the input is already on the GPU + if hasattr(input, 'device') and input.device.type == "cuda": + return input # Already on GPU, no need to move return input.to("cuda") else: + # Check if the input is already on the CPU + if hasattr(input, 'device') and input.device.type == "cpu": + return input # Already on CPU, no need to move return input.to("cpu") def __init__( diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index b760e9ccb6b7..3f0c6cbc051a 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -2,85 +2,115 @@ Run `pytest tests/quantization/test_bitsandbytes.py`. ''' + +import gc + import pytest import torch from tests.quantization.utils import is_quant_method_supported -from vllm import SamplingParams -models_to_test = [ +models_4bit_to_test = [ ('huggyllama/llama-7b', 'quantize model inflight'), - ('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'), ] +models_pre_qaunt_4bit_to_test = [ + ('lllyasviel/omost-llama-3-8b-4bits', + 'read pre-quantized 4-bit NF4 model'), + ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed', + 'read pre-quantized 4-bit FP4 model'), +] + +models_pre_quant_8bit_to_test = [ + ('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'), +] + + +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.parametrize("model_name, description", models_4bit_to_test) +def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, + model_name, description) -> None: + + hf_model_kwargs = {"load_in_4bit": True} + validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], + model_name, hf_model_kwargs) + + +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.parametrize("model_name, description", + models_pre_qaunt_4bit_to_test) +def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, + model_name, description) -> None: + + validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], + model_name) + @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') -@pytest.mark.parametrize("model_name, description", models_to_test) -def test_load_bnb_model(vllm_runner, model_name, description) -> None: +@pytest.mark.parametrize("model_name, description", + models_pre_quant_8bit_to_test) +def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, + model_name, description) -> None: + + validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], + model_name) + + +def log_generated_texts(prompts, outputs, runner_name): + logged_texts = [] + for i, (_, generated_text) in enumerate(outputs): + log_entry = { + "prompt": prompts[i], + "runner_name": runner_name, + "generated_text": generated_text, + } + logged_texts.append(log_entry) + return logged_texts + + +def validate_generated_texts(hf_runner, + vllm_runner, + prompts, + model_name, + hf_model_kwargs=None): + + if hf_model_kwargs is None: + hf_model_kwargs = {} + + # Run with HF runner + with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm: + hf_outputs = 
llm.generate_greedy(prompts, 8) + hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner") + + # Clean up the GPU memory for the next test + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + + #Run with vLLM runner with vllm_runner(model_name, quantization='bitsandbytes', load_format='bitsandbytes', - enforce_eager=True) as llm: - model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 - - # check the weights in MLP & SelfAttention are quantized to torch.uint8 - qweight = model.model.layers[0].mlp.gate_up_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}') - - qweight = model.model.layers[0].mlp.down_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}') - - qweight = model.model.layers[0].self_attn.o_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}') - - qweight = model.model.layers[0].self_attn.qkv_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}') - - # some weights should not be quantized - weight = model.lm_head.weight - assert weight.dtype != torch.uint8, ( - 'lm_head weight dtype should not be torch.uint8') - - weight = model.model.embed_tokens.weight - assert weight.dtype != torch.uint8, ( - 'embed_tokens weight dtype should not be torch.uint8') - - weight = model.model.layers[0].input_layernorm.weight - assert weight.dtype != torch.uint8, ( - 'input_layernorm weight dtype should not be torch.uint8') - - weight = model.model.layers[0].post_attention_layernorm.weight - assert weight.dtype != torch.uint8, ( - 'input_layernorm weight dtype should not be torch.uint8') - - # check the output of the model is expected - sampling_params = SamplingParams(temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=8) - - prompts = ['That which does not kill us', 'To be or not to be,'] - expected_outputs = [ - 'That which does not kill us makes us stronger.', - 'To be or not to be, that is the question.' 
- ] - outputs = llm.generate(prompts, sampling_params=sampling_params) - assert len(outputs) == len(prompts) - - for index in range(len(outputs)): - # compare the first line of the output - actual_output = outputs[index][1][0].split('\n', 1)[0] - expected_output = expected_outputs[index].split('\n', 1)[0] - - assert len(actual_output) >= len(expected_output), ( - f'Actual {actual_output} should be larger than or equal to ' - f'expected {expected_output}') - actual_output = actual_output[:len(expected_output)] - - assert actual_output == expected_output, ( - f'Expected: {expected_output}, but got: {actual_output}') + enforce_eager=True, + gpu_memory_utilization=0.8) as llm: + vllm_outputs = llm.generate_greedy(prompts, 8) + vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") + + # Clean up the GPU memory for the next test + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + + # Compare the generated strings + for hf_log, vllm_log in zip(hf_logs, vllm_logs): + hf_str = hf_log["generated_text"] + vllm_str = vllm_log["generated_text"] + prompt = hf_log["prompt"] + assert hf_str == vllm_str, (f"Model: {model_name}" + f"Mismatch between HF and vLLM outputs:\n" + f"Prompt: {prompt}\n" + f"HF Output: '{hf_str}'\n" + f"vLLM Output: '{vllm_str}'") diff --git a/vllm/config.py b/vllm/config.py index 4cbdde5e113a..48187d57a29b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -353,6 +353,8 @@ def verify_with_parallel_config( raise ValueError( "BitAndBytes quantization with TP or PP is not supported yet.") + # Remove the constraint after the bitsandbytes issue is fixed: + # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308 if self.quantization == "bitsandbytes" and self.enforce_eager is False: logger.warning("CUDA graph is not supported on BitAndBytes yet, " "fallback to the eager mode.") diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index e5b40a64abc4..e3956f847c54 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -35,9 +35,9 @@ def adjust_marlin_shard(param, shard_size, shard_offset): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size -def adjust_bitsandbytes_shard(param: Parameter, - qkv_offsets: Dict[str, Tuple[int, int]], - loaded_shard_id: str) -> Tuple[int, int]: +def adjust_bitsandbytes_4bit_shard(param: Parameter, + qkv_offsets: Dict[str, Tuple[int, int]], + loaded_shard_id: str) -> Tuple[int, int]: """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" total, _ = qkv_offsets["total"] @@ -506,8 +506,9 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - use_bitsandbytes = getattr(param, "use_bitsandbytes", False) - if use_bitsandbytes: + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + if use_bitsandbytes_4bit: shard_size = loaded_weight.shape[output_dim] shard_offset = loaded_weight.shape[output_dim] * \ loaded_shard_id @@ -859,8 +860,9 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - use_bitsandbytes = getattr(param, "use_bitsandbytes", False) - if use_bitsandbytes: + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + if use_bitsandbytes_4bit: orig_qkv_offsets = { "q": (0, self.num_heads * self.head_size), "k": (self.num_heads * self.head_size, @@ -872,7 +874,7 @@ def weight_loader(self, ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, 
0) } - shard_size, shard_offset = adjust_bitsandbytes_shard( + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( param, orig_qkv_offsets, loaded_shard_id) if is_gguf_weight: diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index c143d1a8f2bc..66bc5395dbd7 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -1,7 +1,6 @@ from typing import Any, Dict, List, Optional import torch -from torch.nn.parameter import Parameter from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) @@ -15,8 +14,28 @@ class BitsAndBytesConfig(QuantizationConfig): Reference: https://arxiv.org/abs/2305.14314 """ - def __init__(self, ) -> None: - pass + def __init__( + self, + load_in_8bit: bool = False, + load_in_4bit: bool = True, + bnb_4bit_compute_dtype: str = "float32", + bnb_4bit_quant_type: str = "fp4", + bnb_4bit_use_double_quant: bool = False, + llm_int8_enable_fp32_cpu_offload: bool = False, + llm_int8_has_fp16_weight: bool = False, + llm_int8_skip_modules: Optional[Any] = None, + llm_int8_threshold: float = 0.0, + ) -> None: + + self.load_in_8bit = load_in_8bit + self.load_in_4bit = load_in_4bit + self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype + self.bnb_4bit_quant_type = bnb_4bit_quant_type + self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant + self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload + self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight + self.llm_int8_skip_modules = llm_int8_skip_modules + self.llm_int8_threshold = llm_int8_threshold def __repr__(self) -> str: return "BitsAndBytesConfig" @@ -41,7 +60,46 @@ def get_config_filenames() -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig": - return cls() + + def get_safe_value(config, keys, default_value=None): + try: + value = cls.get_from_keys(config, keys) + return value if value is not None else default_value + except ValueError: + return default_value + + load_in_8bit = get_safe_value(config, ["load_in_8bit"], + default_value=False) + load_in_4bit = get_safe_value(config, ["load_in_4bit"], + default_value=True) + bnb_4bit_compute_dtype = get_safe_value(config, + ["bnb_4bit_compute_dtype"], + default_value="float32") + bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"], + default_value="fp4") + bnb_4bit_use_double_quant = get_safe_value( + config, ["bnb_4bit_use_double_quant"], default_value=False) + llm_int8_enable_fp32_cpu_offload = get_safe_value( + config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False) + llm_int8_has_fp16_weight = get_safe_value(config, + ["llm_int8_has_fp16_weight"], + default_value=False) + llm_int8_skip_modules = get_safe_value(config, + ["llm_int8_skip_modules"], + default_value=[]) + llm_int8_threshold = get_safe_value(config, ["llm_int8_threshold"], + default_value=0.0) + + return cls( + load_in_8bit=load_in_8bit, + load_in_4bit=load_in_4bit, + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_use_double_quant=bnb_4bit_use_double_quant, + llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload, + llm_int8_has_fp16_weight=llm_int8_has_fp16_weight, + llm_int8_skip_modules=llm_int8_skip_modules, + llm_int8_threshold=llm_int8_threshold) def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["BitsAndBytesLinearMethod"]: @@ 
-78,39 +136,58 @@ def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - quant_ratio = 0 - if params_dtype.is_floating_point: - quant_ratio = torch.finfo(params_dtype).bits // torch.iinfo( - torch.uint8).bits + from bitsandbytes.nn import Int8Params + + def calculate_quant_ratio(dtype): + if dtype.is_floating_point: + return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits + else: + return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits + + def create_qweight_for_8bit(): + qweight = Int8Params( + data=torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight, + requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 0, + "pack_factor": 1, + "use_bitsandbytes_8bit": True, + "generation": 0 + }) + return qweight + + def create_qweight_for_4bit(): + quant_ratio = calculate_quant_ratio(params_dtype) + + total_size = input_size_per_partition * sum(output_partition_sizes) + if total_size % quant_ratio != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape.") + + qweight = torch.nn.Parameter(torch.empty(total_size // quant_ratio, + 1, + dtype=torch.uint8), + requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 0, + "pack_factor": quant_ratio, + "use_bitsandbytes_4bit": True + }) + return qweight + + if self.quant_config.load_in_8bit: + qweight = create_qweight_for_8bit() else: - quant_ratio = torch.iinfo(params_dtype).bits // torch.iinfo( - torch.uint8).bits - - if input_size_per_partition * sum( - output_partition_sizes) % quant_ratio != 0: - raise ValueError( - "The input size is not aligned with the quantized " - "weight shape. 
") - qweight = Parameter( - torch.empty( - input_size_per_partition * sum(output_partition_sizes) // - quant_ratio, - 1, - dtype=torch.uint8, - ), - requires_grad=False, - ) - - set_weight_attrs( - qweight, - { - "input_dim": 0, - # In bitsandbytes, a tensor of shape [n,m] is quantized to - #[n*m/pack_ratio, 1],so the output_dim is 0 - "output_dim": 0, - "pack_factor": quant_ratio, - "use_bitsandbytes": True, - }) + qweight = create_qweight_for_4bit() + layer.register_parameter("qweight", qweight) set_weight_attrs(qweight, extra_weight_attrs) @@ -119,6 +196,88 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.quant_config.load_in_8bit: + return self._apply_8bit_weight(layer, x, bias) + else: + return self._apply_4bit_weight(layer, x, bias) + + def _apply_8bit_weight( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + # only load the bitsandbytes module when needed + from bitsandbytes import MatmulLtState, matmul + + original_type = x.dtype + bf_x = x.to(torch.bfloat16) + + qweight = layer.qweight + offsets = qweight.bnb_shard_offsets + quant_states = qweight.bnb_quant_state + matmul_states = qweight.matmul_state + generation = qweight.generation + + out_dim_0 = x.shape[0] + out_dim_1 = sum( + [quant_state[1].shape[0] for quant_state in quant_states.items()]) + out = torch.empty(out_dim_0, + out_dim_1, + dtype=torch.float16, + device=x.device) + + current_index = 0 + for i in range(len(quant_states)): + output_size = quant_states[i].shape[0] + + # in profile_run or the first generation of inference, + # create new matmul_states + if generation == 0 or generation == 1: + matmul_states[i] = MatmulLtState() + matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]] + matmul_states[i].SCB = quant_states[i] + matmul_states[i].threshold = ( + self.quant_config.llm_int8_threshold) + matmul_states[i].has_fp16_weights = ( + self.quant_config.llm_int8_has_fp16_weight) + matmul_states[i].is_training = False + if matmul_states[i].threshold > 0.0 and not matmul_states[ + i].has_fp16_weights: + matmul_states[i].use_pool = True + + new_x = bf_x.unsqueeze(0) + + out[:, current_index:current_index + output_size] = matmul( + new_x, + qweight[offsets[i]:offsets[i + 1]], + state=matmul_states[i]) + + current_index += output_size + + # only update the matmul_states if it is not profile_run + if (generation > 0 + and not self.quant_config.llm_int8_has_fp16_weight + and matmul_states[i].CB is not None + and matmul_states[i].CxB is not None): + del matmul_states[i].CB + qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB + + out = out.to(original_type) + + if bias is not None: + out += bias + + qweight.generation += 1 + + return out + + def _apply_4bit_weight( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + # only load the bitsandbytes module when needed from bitsandbytes import matmul_4bit diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 2f6cdbc6ce3e..553fa848489b 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -771,7 +771,11 @@ def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): return pt_weights_iterator(hf_weights_files) def _get_quantized_weights_iterator( - self, model_name_or_path: str, revision: Optional[str], pre_quant: bool + self, + model_name_or_path: str, + revision: Optional[str], + 
pre_quant: bool, + load_8bit: bool, ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, Any]]: """Get an iterator to the model weights with bitsandbytes quantization, @@ -780,11 +784,9 @@ def _get_quantized_weights_iterator( # only load the bitsandbytes module when needed try: import bitsandbytes - from bitsandbytes.functional import QuantState if bitsandbytes.__version__ < "0.42.0": raise ImportError("bitsandbytes version is wrong. Please " "install bitsandbytes>=0.42.0.") - from bitsandbytes.functional import quantize_4bit except ImportError as err: raise ImportError("Please install bitsandbytes>=0.42.0 via " "`pip install bitsandbytes>=0.42.0` to use " @@ -793,80 +795,111 @@ def _get_quantized_weights_iterator( hf_weights_files, use_safetensors = self._prepare_weights( model_name_or_path, revision) - quant_state_dict = {} - - def quantized_checkpoint() -> Generator: - # First iterate over all quant state weights - weight_iterator = self._hf_weight_iter(hf_weights_files, - use_safetensors) - temp_state_dict = {} - for weight_name, weight_tensor in weight_iterator: - if weight_name.endswith(".weight"): - continue - # TODO: only nf4 quantization is supported for now - if weight_name.endswith(".quant_state.bitsandbytes__fp4"): - raise NotImplementedError( - "Only bitsandbytes_nf4 quantization" - f"is supported for now. {weight_name} is fp4 quantized" - ) - temp_state_dict[weight_name] = weight_tensor + quant_state_dict: Dict[str, Any] = {} - # Closure to parse quant_state for each prequant weight - def _parse_quant_state(param_name: str, - temp_state_dict: Dict) -> QuantState: - quant_state = {} - for k in temp_state_dict: - if param_name + "." in k: - quant_state[k] = temp_state_dict[k] - # bitsandbytes library requires - # weight.quant_state.bitsandbytes__nf4 in CPU - quant_state[param_name + - ".quant_state.bitsandbytes__nf4"] = quant_state[ - param_name + - ".quant_state.bitsandbytes__nf4"].cpu().data - return QuantState.from_dict(quant_state, device="cuda") - - # Second iterate over all prequant and normal weights - # pre quantized weights would have a quant_state - for weight_name, weight_tensor in self._hf_weight_iter( - hf_weights_files, use_safetensors): - # Filter out all weights whose suffix is not ".weight" - if not weight_name.endswith(".weight"): - continue - if weight_name + ".quant_state.bitsandbytes__nf4" \ - in temp_state_dict: - quant_state = _parse_quant_state(weight_name, - temp_state_dict) - weight_name = weight_name.replace(".weight", ".qweight") - quant_state_dict[weight_name] = quant_state - yield weight_name.replace(".weight", - ".qweight"), weight_tensor - else: - yield weight_name, weight_tensor - - def generator() -> Generator: - for weight_name, weight_tensor in self._hf_weight_iter( - hf_weights_files, use_safetensors): - if any(target_module in weight_name - for target_module in self.target_modules): - weight_name = weight_name.replace(".weight", ".qweight") - # bitsandbytes requires data in GPU - loaded_weight = weight_tensor.cuda().data - with set_default_torch_dtype(torch.float32): - processed_weight, quant_state = quantize_4bit( - loaded_weight, - compress_statistics=True, - quant_type="nf4") - - quant_state_dict[weight_name] = quant_state - else: - processed_weight = weight_tensor + if pre_quant: + if load_8bit: + return self._quantized_8bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + else: + return self._quantized_4bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), 
quant_state_dict - yield weight_name, processed_weight + return self._unquantized_generator(hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict - if pre_quant: - return quantized_checkpoint(), quant_state_dict - return generator(), quant_state_dict + def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + if not weight_name.lower().endswith(".scb"): + continue + + weight_key = weight_name.lower().replace(".scb", ".qweight") + quant_state_dict[weight_key] = weight_tensor + + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + + if not weight_name.endswith(".weight"): + continue + + qweight_name = weight_name.replace(".weight", ".qweight") + if qweight_name in quant_state_dict: + set_weight_attrs(weight_tensor, {"load_in_8bit": True}) + yield qweight_name, weight_tensor + else: + yield weight_name, weight_tensor + + def _quantized_4bit_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + from bitsandbytes.functional import QuantState + + # First iterate over all quant state weights + weight_iterator = self._hf_weight_iter(hf_weights_files, + use_safetensors) + temp_state_dict = {} + for weight_name, weight_tensor in weight_iterator: + if weight_name.endswith(".weight"): + continue + # bitsandbytes library requires + # weight.quant_state.bitsandbytes__* in CPU + if "quant_state.bitsandbytes" in weight_name: + temp_state_dict[weight_name] = weight_tensor.cpu().data + else: + temp_state_dict[weight_name] = weight_tensor + + # Closure to parse quant_state for each prequant weight + def _parse_quant_state(param_name: str, + temp_state_dict: Dict) -> QuantState: + quant_state = {} + for k in temp_state_dict: + if param_name + "." 
in k: + quant_state[k] = temp_state_dict[k] + + return QuantState.from_dict(quant_state, device="cuda") + + # Second iterate over all prequant and normal weights + # pre quantized weights would have a quant_state + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + # Filter out all weights whose suffix is not ".weight" + if not weight_name.endswith(".weight"): + continue + if (f"{weight_name}.quant_state.bitsandbytes__nf4" \ + in temp_state_dict) or \ + (f"{weight_name}.quant_state.bitsandbytes__fp4" \ + in temp_state_dict): + quant_state = _parse_quant_state(weight_name, temp_state_dict) + weight_name = weight_name.replace(".weight", ".qweight") + quant_state_dict[weight_name] = quant_state + yield weight_name.replace(".weight", ".qweight"), weight_tensor + else: + yield weight_name, weight_tensor + + def _unquantized_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + from bitsandbytes.functional import quantize_4bit + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + if any(target_module in weight_name + for target_module in self.target_modules): + weight_name = weight_name.replace(".weight", ".qweight") + # bitsandbytes requires data in GPU + loaded_weight = weight_tensor.cuda().data + with set_default_torch_dtype(torch.float32): + processed_weight, quant_state = quantize_4bit( + loaded_weight, + compress_statistics=True, + quant_type="nf4") + + quant_state_dict[weight_name] = quant_state + else: + processed_weight = weight_tensor + + yield weight_name, processed_weight def _load_weights(self, model_config: ModelConfig, model: nn.Module) -> None: @@ -883,16 +916,26 @@ def _load_weights(self, model_config: ModelConfig, logger.info("Loading weights with BitsAndBytes quantization. " " May take a while ...") - is_quantized_checkpoint = False quant_config = getattr(model_config.hf_config, "quantization_config", None) - if quant_config is not None and quant_config.get( - 'quant_method') == "bitsandbytes": - is_quantized_checkpoint = True + + pre_quant = False + if quant_config is not None: + quant_method = quant_config.get('quant_method') + if quant_method == "bitsandbytes": + pre_quant = True + else: + raise ValueError( + f"BitsAndBytes loader does not support {quant_method} " + "quantization") + + load_8bit = False + if pre_quant: + load_8bit = quant_config.get('load_in_8bit', False) qweight_iterator, quant_state_dict = \ self._get_quantized_weights_iterator( - model_config.model, model_config.revision, is_quantized_checkpoint) + model_config.model, model_config.revision, pre_quant, load_8bit) model.load_weights(qweight_iterator) @@ -942,6 +985,10 @@ def _load_weights(self, model_config: ModelConfig, offsets = np.concatenate(([0], np.cumsum(num_elements))) set_weight_attrs(param, {"bnb_shard_offsets": offsets}) + if load_8bit: + set_weight_attrs( + param, {"matmul_state": [None] * len(quant_states)}) + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig],
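
As a usage illustration of the loader and config changes above (a minimal sketch, not part of the patch), this is how a pre-quantized bitsandbytes checkpoint can be loaded offline through vLLM. The model name is taken from the new test lists, the prompt and sampling settings are placeholders, and the flags mirror the options the tests pass to vllm_runner:

    # Minimal sketch, not part of the patch: load a pre-quantized 8-bit
    # bitsandbytes checkpoint through the BitsAndBytes loader path above.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="meta-llama/Llama-Guard-3-8B-INT8",  # pre-quantized 8-bit model from the test list
        quantization="bitsandbytes",
        load_format="bitsandbytes",
        enforce_eager=True,                 # CUDA graphs are not supported with bitsandbytes yet
        gpu_memory_utilization=0.8,
    )
    params = SamplingParams(temperature=0.0, max_tokens=8)  # greedy, as in the tests
    outputs = llm.generate(["The capital of France is"], params)
    print(outputs[0].outputs[0].text)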
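
The 4-bit path in create_weights packs an [n, m] floating-point weight into an [n*m/pack_factor, 1] uint8 tensor, which is why output_dim stays 0. Below is a standalone sketch (not part of the patch) of that pack-factor arithmetic, assuming only torch; the layer shape is an arbitrary example:

    import torch

    def calculate_quant_ratio(dtype: torch.dtype) -> int:
        # Same rule as the helper above: the dtype's byte width, used as the
        # packing factor when sizing the uint8 buffer (fp16/bf16 -> 2, fp32 -> 4).
        if dtype.is_floating_point:
            return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
        return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits

    n, m = 4096, 11008                      # arbitrary example shape
    ratio = calculate_quant_ratio(torch.float16)
    assert (n * m) % ratio == 0             # the loader raises ValueError otherwise
    print(ratio, (n * m // ratio, 1))       # 2 (22544384, 1)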