From 126e816b6975e825c4c3bc42bb06bd38015857e2 Mon Sep 17 00:00:00 2001 From: "qian.chen" Date: Wed, 29 May 2024 05:18:47 +0000 Subject: [PATCH] [model] support bitsandbytes/QLoRA (#4033) --- examples/lora_with_quantization_inference.py | 140 ++++++++++ requirements-dev.txt | 3 + tests/quantization/test_bitsandbytes.py | 80 ++++++ vllm/config.py | 9 +- vllm/engine/arg_utils.py | 38 ++- vllm/model_executor/layers/linear.py | 41 ++- .../layers/quantization/__init__.py | 3 + .../layers/quantization/bitsandbytes.py | 171 ++++++++++++ vllm/model_executor/model_loader/loader.py | 248 +++++++++++++++++- .../model_loader/weight_utils.py | 16 +- vllm/model_executor/models/llama.py | 8 + 11 files changed, 749 insertions(+), 8 deletions(-) create mode 100644 examples/lora_with_quantization_inference.py create mode 100644 tests/quantization/test_bitsandbytes.py create mode 100644 vllm/model_executor/layers/quantization/bitsandbytes.py diff --git a/examples/lora_with_quantization_inference.py b/examples/lora_with_quantization_inference.py new file mode 100644 index 000000000000..3b2347c1115e --- /dev/null +++ b/examples/lora_with_quantization_inference.py @@ -0,0 +1,140 @@ +""" +This example shows how to use LoRA with different quantization techniques +for offline inference. + +Requires HuggingFace credentials for access. +""" + +import gc +from typing import List, Optional, Tuple + +import torch +from huggingface_hub import snapshot_download + +from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams +from vllm.lora.request import LoRARequest + + +def create_test_prompts( + lora_path: str +) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]: + return [ + # this is an example of using quantization without LoRA + ("My name is", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128), None), + # the next three examples use quantization with LoRA + ("my name is", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128), + LoRARequest("lora-test-1", 1, lora_path)), + ("The capital of USA is", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128), + LoRARequest("lora-test-2", 1, lora_path)), + ("The capital of France is", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128), + LoRARequest("lora-test-3", 1, lora_path)), + ] + + +def process_requests(engine: LLMEngine, + test_prompts: List[Tuple[str, SamplingParams, + Optional[LoRARequest]]]): + """Continuously process a list of prompts and handle the outputs.""" + request_id = 0 + + while test_prompts or engine.has_unfinished_requests(): + if test_prompts: + prompt, sampling_params, lora_request = test_prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + lora_request=lora_request) + request_id += 1 + + request_outputs: List[RequestOutput] = engine.step() + for request_output in request_outputs: + if request_output.finished: + print("----------------------------------------------------") + print(f"Prompt: {request_output.prompt}") + print(f"Output: {request_output.outputs[0].text}") + + +def initialize_engine(model: str, quantization: str, + lora_repo: Optional[str]) -> LLMEngine: + """Initialize the LLMEngine.""" + + if quantization == "bitsandbytes": + # QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique. + # It quantizes the model when loading, with some config info from the + # LoRA adapter repo. So need to set the parameter of load_format and + # qlora_adapter_name_or_path as below. + engine_args = EngineArgs( + model=model, + quantization=quantization, + qlora_adapter_name_or_path=lora_repo, + load_format="bitsandbytes", + enable_lora=True, + max_lora_rank=64, + # set it only in GPUs of limited memory + enforce_eager=True) + else: + engine_args = EngineArgs( + model=model, + quantization=quantization, + enable_lora=True, + max_loras=4, + # set it only in GPUs of limited memory + enforce_eager=True) + return LLMEngine.from_engine_args(engine_args) + + +def main(): + """Main function that sets up and runs the prompt processing.""" + + test_configs = [{ + "name": "qlora_inference_example", + 'model': "huggyllama/llama-7b", + 'quantization': "bitsandbytes", + 'lora_repo': 'timdettmers/qlora-flan-7b' + }, { + "name": "AWQ_inference_with_lora_example", + 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ', + 'quantization': "awq", + 'lora_repo': 'jashing/tinyllama-colorist-lora' + }, { + "name": "GPTQ_inference_with_lora_example", + 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ', + 'quantization': "gptq", + 'lora_repo': 'jashing/tinyllama-colorist-lora' + }] + + for test_config in test_configs: + print( + f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~" + ) + engine = initialize_engine(test_config['model'], + test_config['quantization'], + test_config['lora_repo']) + lora_path = snapshot_download(repo_id=test_config['lora_repo']) + test_prompts = create_test_prompts(lora_path) + process_requests(engine, test_prompts) + + # Clean up the GPU memory for the next test + del engine + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == '__main__': + main() diff --git a/requirements-dev.txt b/requirements-dev.txt index cf2bb9bef22d..2c6b33ea813a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -35,3 +35,6 @@ aiohttp # Multimodal pillow + +# quantization +bitsandbytes==0.42.0 diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py new file mode 100644 index 000000000000..4e9feb3c4814 --- /dev/null +++ b/tests/quantization/test_bitsandbytes.py @@ -0,0 +1,80 @@ +'''Tests whether bitsandbytes computation is enabled correctly. + +Run `pytest tests/quantization/test_bitsandbytes.py`. +''' +import pytest +import torch + +from vllm import SamplingParams +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] + + +@pytest.mark.skipif( + capability < QUANTIZATION_METHODS['bitsandbytes'].get_min_capability(), + reason='bitsandbytes is not supported on this GPU type.') +def test_load_bnb_model(vllm_runner) -> None: + llm = vllm_runner('huggyllama/llama-7b', + quantization='bitsandbytes', + load_format='bitsandbytes', + enforce_eager=True) + + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model + + # check the weights in MLP & SelfAttention are quantized to torch.uint8 + qweight = model.model.layers[0].mlp.gate_up_proj.qweight + assert qweight.dtype == torch.uint8, ( + f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}') + + qweight = model.model.layers[0].mlp.down_proj.qweight + assert qweight.dtype == torch.uint8, ( + f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}') + + qweight = model.model.layers[0].self_attn.o_proj.qweight + assert qweight.dtype == torch.uint8, ( + f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}') + + qweight = model.model.layers[0].self_attn.qkv_proj.qweight + assert qweight.dtype == torch.uint8, ( + f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}') + + # some weights should not be quantized + weight = model.lm_head.weight + assert weight.dtype != torch.uint8, ( + 'lm_head weight dtype should not be torch.uint8') + + weight = model.model.embed_tokens.weight + assert weight.dtype != torch.uint8, ( + 'embed_tokens weight dtype should not be torch.uint8') + + weight = model.model.layers[0].input_layernorm.weight + assert weight.dtype != torch.uint8, ( + 'input_layernorm weight dtype should not be torch.uint8') + + weight = model.model.layers[0].post_attention_layernorm.weight + assert weight.dtype != torch.uint8, ( + 'input_layernorm weight dtype should not be torch.uint8') + + # check the output of the model is expected + sampling_params = SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=8) + + prompts = ['That which does not kill us', 'To be or not to be,'] + expected_outputs = [ + 'That which does not kill us makes us stronger.', + 'To be or not to be, that is the question.' + ] + outputs = llm.generate(prompts, sampling_params=sampling_params) + + assert len(outputs) == len(prompts) + + for index in range(len(outputs)): + # compare the first line of the output + actual_output = outputs[index][1][0].split('\n', 1)[0] + expected_output = expected_outputs[index].split('\n', 1)[0] + assert actual_output == expected_output, ( + f'Expected: {expected_output}, but got: {actual_output}') diff --git a/vllm/config.py b/vllm/config.py index 4b256d00a32d..74ea558722eb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -229,6 +229,12 @@ def verify_with_parallel_config( "must be divisible by pipeline parallel size " f"({pipeline_parallel_size}).") + if self.quantization == "bitsandbytes" and ( + parallel_config.tensor_parallel_size > 1 + or parallel_config.pipeline_parallel_size > 1): + raise ValueError( + "BitAndBytes quantization with TP or PP is not supported yet.") + def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled. """ @@ -315,7 +321,7 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: def get_num_attention_heads(self, parallel_config: "ParallelConfig") -> int: return self.hf_text_config.num_attention_heads // \ - parallel_config.tensor_parallel_size + parallel_config.tensor_parallel_size def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_text_config.num_hidden_layers @@ -475,6 +481,7 @@ class LoadFormat(str, enum.Enum): DUMMY = "dummy" TENSORIZER = "tensorizer" SHARDED_STATE = "sharded_state" + BITSANDBYTES = "bitsandbytes" @dataclass diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 11485aa2438c..1880ae4debfb 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -92,6 +92,8 @@ class EngineArgs: ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None + qlora_adapter_name_or_path: Optional[str] = None + def __post_init__(self): if self.tokenizer is None: self.tokenizer = self.model @@ -159,7 +161,8 @@ def add_cli_args( type=str, default=EngineArgs.load_format, choices=[ - 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer' + 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', + 'bitsandbytes' ], help='The format of the model weights to load.\n\n' '* "auto" will try to load the weights in the safetensors format ' @@ -173,7 +176,9 @@ def add_cli_args( 'which is mainly for profiling.\n' '* "tensorizer" will load the weights using tensorizer from ' 'CoreWeave. See the Tensorize vLLM Model script in the Examples' - 'section for more information.\n') + 'section for more information.\n' + '* "bitsandbytes" will load the weights using bitsandbytes ' + 'quantization.\n') parser.add_argument( '--dtype', type=str, @@ -543,7 +548,10 @@ def add_cli_args( "will also be used in `model_name` tag content of " "prometheus metrics, if multiple names provided, metrics" "tag will take the first one.") - + parser.add_argument('--qlora-adapter-name-or-path', + type=str, + default=None, + help='Name or path of the QLoRA adapter.') return parser @classmethod @@ -555,6 +563,23 @@ def from_cli_args(cls, args: argparse.Namespace): return engine_args def create_engine_config(self, ) -> EngineConfig: + + # bitsandbytes quantization needs a specific model loader + # so we make sure the quant method and the load format are consistent + if (self.quantization == "bitsandbytes" or + self.qlora_adapter_name_or_path is not None) and \ + self.load_format != "bitsandbytes": + raise ValueError( + "BitsAndBytes quantization only supports 'bitsandbytes' load " + f"format, but got {self.load_format}") + + if (self.load_format == "bitsandbytes" or + self.qlora_adapter_name_or_path is not None) and \ + self.quantization != "bitsandbytes": + raise ValueError( + "BitsAndBytes load format can only be used with bitsandbytes'" + f" quantization, but got {self.quantization}") + device_config = DeviceConfig(self.device) model_config = ModelConfig( self.model, self.tokenizer, self.tokenizer_mode, @@ -622,6 +647,13 @@ def create_engine_config(self, ) -> EngineConfig: max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None + if self.qlora_adapter_name_or_path is not None and \ + self.qlora_adapter_name_or_path != "": + if self.model_loader_extra_config is None: + self.model_loader_extra_config = {} + self.model_loader_extra_config[ + "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path + load_config = LoadConfig( load_format=self.load_format, download_dir=self.download_dir, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 34fbfa8e33ef..f5b6bdd9f7fd 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import List, Optional +from typing import Dict, List, Optional, Tuple import torch import torch.nn.functional as F @@ -26,6 +26,21 @@ def adjust_marlin_shard(param, shard_size, shard_offset): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size +def adjust_bitsandbytes_shard(param: Parameter, + qkv_offsets: Dict[str, Tuple[int, int]], + loaded_shard_id: str) -> Tuple[int, int]: + """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" + + total, _ = qkv_offsets["total"] + orig_offset, orig_size = qkv_offsets[loaded_shard_id] + + quantized_total = param.data.shape[0] + quantized_offset = orig_offset * quantized_total // total + quantized_size = orig_size * quantized_total // total + + return quantized_size, quantized_offset + + class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @@ -37,7 +52,7 @@ def create_weights(self, layer: torch.nn.Module, **extra_weight_attrs): """Create weights for a linear layer. The weights will be set as attributes of the layer. - + Args: layer: The layer that is using the LinearMethodBase factory. input_size_per_partition: Size of the weight input dim on rank X. @@ -416,6 +431,12 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + use_bitsandbytes = getattr(param, "use_bitsandbytes", False) + if use_bitsandbytes: + shard_size = loaded_weight.shape[output_dim] + shard_offset = loaded_weight.shape[output_dim] * \ + loaded_shard_id + param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = tp_rank * shard_size @@ -615,6 +636,22 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + use_bitsandbytes = getattr(param, "use_bitsandbytes", False) + if use_bitsandbytes: + orig_qkv_offsets = { + "q": (0, self.num_heads * self.head_size), + "k": (self.num_heads * self.head_size, + self.num_kv_heads * self.head_size), + "v": + ((self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads * self.head_size), + "total": + ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, + 0) + } + shard_size, shard_offset = adjust_bitsandbytes_shard( + param, orig_qkv_offsets, loaded_shard_id) + param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 7b9abe1b629a..0bc42beb6625 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,6 +4,8 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.bitsandbytes import ( + BitsAndBytesConfig) from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsConfig) from vllm.model_executor.layers.quantization.deepspeedfp import ( @@ -30,6 +32,7 @@ "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, "sparseml": CompressedTensorsConfig, + "bitsandbytes": BitsAndBytesConfig, } diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py new file mode 100644 index 000000000000..65708774ef24 --- /dev/null +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -0,0 +1,171 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + + +class BitsAndBytesConfig(QuantizationConfig): + """Config class for BitsAndBytes Quantization. + + Reference: https://arxiv.org/abs/2305.14314 + """ + + def __init__( + self, + adapter_name_or_path: str, + target_modules: List[str], + ) -> None: + + self.adapter_name_or_path = adapter_name_or_path + self.target_modules = target_modules + + def __repr__(self) -> str: + return ( + f"BitsAndBytesConfig(adapter_name_or_path={self.adapter_name_or_path}" + ) + + @classmethod + def get_name(self) -> str: + return "bitsandbytes" + + @classmethod + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(self) -> int: + return 70 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "adapter_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig": + adapter_name = cls.get_from_keys(config, ["adapter_name_or_path"]) + default_target_modules = [ + "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", + "o_proj" + ] + if adapter_name == "": + target_modules = default_target_modules + else: + target_modules = cls.get_from_keys(config, ["target_modules"]) + return cls(adapter_name, target_modules) + + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["BitsAndBytesLinearMethod"]: + if isinstance(layer, LinearBase): + return BitsAndBytesLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] + + +class BitsAndBytesLinearMethod(LinearMethodBase): + """Linear method for BitsAndBytes. + + Args: + quant_config: The BitsAndBytes quantization config. + """ + + def __init__(self, quant_config: BitsAndBytesConfig): + try: + import bitsandbytes + if bitsandbytes.__version__ < "0.42.0": + raise ImportError("bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.42.0.") + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.42.0 via " + "`pip install bitsandbytes>=0.42.0` to use " + "bitsandbytes quantizer.") from err + + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + quant_ratio = 0 + if params_dtype.is_floating_point: + quant_ratio = torch.finfo(params_dtype).bits // torch.iinfo( + torch.uint8).bits + else: + quant_ratio = torch.iinfo(params_dtype).bits // torch.iinfo( + torch.uint8).bits + + if input_size_per_partition * sum( + output_partition_sizes) % quant_ratio != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. ") + qweight = Parameter( + torch.empty( + input_size_per_partition * sum(output_partition_sizes) // + quant_ratio, + 1, + dtype=torch.uint8, + ), + requires_grad=False, + ) + + set_weight_attrs( + qweight, + { + "input_dim": 0, + # In bitsandbytes, a tensor of shape [n,m] is qunatized to + #[n*m/pack_ratio, 1],so the output_dim is 0 + "output_dim": 0, + "pack_factor": quant_ratio, + "use_bitsandbytes": True, + }) + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + # only load the bitsandbytes module when needed + from bitsandbytes import matmul_4bit + + orginal_type = x.dtype + bf_x = x.to(torch.bfloat16) + + qweight = layer.qweight + quant_states = qweight.bnb_quant_state + offsets = qweight.bnb_shard_offsets + + out_dim_0 = x.shape[0] + out_dim_1 = sum( + [quant_state[1].shape[0] for quant_state in quant_states.items()]) + out = torch.empty(out_dim_0, + out_dim_1, + dtype=torch.bfloat16, + device=x.device) + + current_index = 0 + for i in range(len(quant_states)): + output_size = quant_states[i].shape[0] + out[:, current_index:current_index + output_size] = matmul_4bit( + bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i]) + + current_index += output_size + + out = out.to(orginal_type) + + if bias is not None: + out += bias + + return out diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index b7b5b5e7695f..a198b4c692c5 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1,13 +1,18 @@ # ruff: noqa: SIM117 import collections import copy +import fnmatch import glob +import json +import math import os from abc import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Tuple, Type import huggingface_hub +import numpy as np import torch +from huggingface_hub import HfApi, hf_hub_download from torch import nn from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, @@ -28,6 +33,7 @@ get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) from vllm.model_executor.models.vlm_base import VisionLanguageModelBase +from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -125,7 +131,7 @@ def __init__(self, load_config: LoadConfig): def _maybe_download_from_modelscope( self, model: str, revision: Optional[str]) -> Optional[str]: """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. - + Returns the path to the downloaded model, or None if the model is not downloaded from ModelScope.""" if VLLM_USE_MODELSCOPE: @@ -247,6 +253,7 @@ def load_model(self, *, model_config: ModelConfig, model, "fall_back_to_pt_during_load", True)), ) + for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) if quant_method is not None: @@ -539,6 +546,242 @@ def save_model( ) +class BitsAndBytesModelLoader(BaseModelLoader): + """Model loader to load model weights with BitAndBytes quabntization.""" + + default_target_modules = [ + "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", + "o_proj" + ] + + possible_config_file_names = ["adapter_config.json"] + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + + # we don't need to quantize the whole model, only the target modules + # that are specified in the adapter config file. If the adapter config + # file is not provided, we will quantize the default modules. + if (not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" + not in load_config.model_loader_extra_config): + self.target_modules = self.default_target_modules + return + + qlora_adapter = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path"] + + config_file_path = self._get_config_file(qlora_adapter) + + with open(config_file_path, "r") as f: + config = json.load(f) + self.target_modules = config["target_modules"] + + def _get_config_file(self, qlora_adapter: str) -> str: + is_local = os.path.isdir(qlora_adapter) + config_file_path = None + if is_local: + for file in self.possible_config_file_names: + config_file_path = os.path.join(qlora_adapter, file) + if os.path.exists(config_file_path): + break + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=qlora_adapter) + for file in self.possible_config_file_names: + if file in repo_files: + config_file_path = hf_hub_download(repo_id=qlora_adapter, + filename=file) + break + + if not config_file_path: + raise ValueError( + f"Cannot find adapter config file in {qlora_adapter}") + + return config_file_path + + def _get_weight_files( + self, + model_name_or_path: str, + allowed_patterns: List[str], + revision: Optional[str] = None) -> Tuple[List[str], str]: + """Retrieve weight files. Download the files if necessary. + + Return the weight files and the file pattern.""" + is_local = os.path.isdir(model_name_or_path) + + if is_local: + for pattern in allowed_patterns: + weight_files = glob.glob( + os.path.join(model_name_or_path, pattern)) + if weight_files: + return weight_files, pattern + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) + for pattern in allowed_patterns: + matching_files = fnmatch.filter(repo_files, pattern) + if matching_files: + hf_folder = download_weights_from_hf( + model_name_or_path, self.load_config.download_dir, + [pattern], revision) + return glob.glob(os.path.join(hf_folder, pattern)), pattern + + raise RuntimeError( + f"No model weights found in: `{model_name_or_path}`") + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]) -> Tuple[List[str], bool]: + """Prepare weight files for the model.""" + + allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] + + hf_weights_files, matched_pattern = self._get_weight_files( + model_name_or_path, allowed_patterns, revision) + + if matched_pattern != "*.safetensors": + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_weights_files, matched_pattern == "*.safetensors" + + def _get_quantized_weights_iterator( + self, model_name_or_path: str, revision: Optional[str] + ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, + Any]]: + """Get an iterator to the model weights with bitsandbytes quantization, + + as well as the quantization state dictionary.""" + + # only load the bitsandbytes module when needed + try: + import bitsandbytes + if bitsandbytes.__version__ < "0.42.0": + raise ImportError("bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.42.0.") + from bitsandbytes.functional import quantize_4bit + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.42.0 via " + "`pip install bitsandbytes>=0.42.0` to use " + "bitsandbytes quantizer.") from err + + hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision) + + quant_state_dict = {} + if use_safetensors: + weight_iterator = safetensors_weights_iterator(hf_weights_files) + else: + weight_iterator = pt_weights_iterator(hf_weights_files) + + def generator(): + for weight_name, weight_tensor in weight_iterator: + if any(target_module in weight_name + for target_module in self.target_modules): + weight_name = weight_name.replace(".weight", ".qweight") + # bitsandbytes requires data in GPU + loaded_weight = weight_tensor.cuda().data + with set_default_torch_dtype(torch.float32): + processed_weight, quant_state = quantize_4bit( + loaded_weight, + compress_statistics=True, + quant_type="nf4") + + quant_state_dict[weight_name] = quant_state + else: + processed_weight = weight_tensor + + yield weight_name, processed_weight + + return generator(), quant_state_dict + + def _load_weights(self, model_config: ModelConfig, + model: nn.Module) -> None: + if not hasattr(model, 'load_weights'): + raise AttributeError( + "The required method 'load_weights' is not defined in class" + f" {type(self).__name__}.") + + if not hasattr(model, 'bitsandbytes_stacked_params_mapping'): + raise AttributeError( + f"Model {type(self).__name__} does not support BitsAndBytes " + "quantization yet.") + + logger.info("Loading weights with BitsAndBytes quantization. " + " May take a while ...") + + qweight_iterator, quant_state_dict = ( + self._get_quantized_weights_iterator(model_config.model, + model_config.revision)) + + model.load_weights(qweight_iterator) + + param_dict = dict(model.named_parameters()) + stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {} + for quant_param_name in quant_state_dict: + non_stacked_param_name = quant_param_name + + shard_index = 0 + for shard_name, ( + weight_name, index + ) in model.bitsandbytes_stacked_params_mapping.items(): + if shard_name in quant_param_name: + shard_index = index + quant_param_name = quant_param_name.replace( + shard_name, weight_name) + break + + if quant_param_name not in param_dict: + raise ValueError( + f"Parameter {quant_param_name} not found in the model.") + + if quant_param_name not in stacked_quant_state_dict: + stacked_quant_state_dict[quant_param_name] = {} + + stacked_quant_state_dict[quant_param_name][shard_index] = ( + quant_state_dict[non_stacked_param_name]) + + # save quant_states and offsets as the attributes of the parameters + for param_name, param in param_dict.items(): + if param_name in stacked_quant_state_dict: + quant_states = stacked_quant_state_dict[param_name] + set_weight_attrs(param, {"bnb_quant_state": quant_states}) + + pack_ratio = getattr(param, "pack_factor", -1) + if pack_ratio == -1: + raise ValueError( + f"pack_factor not set for parameter {param_name}.") + + num_elements = [0] * len(quant_states) + for seq, quant_state in enumerate(quant_states.items()): + num_elements[seq] = math.prod( + quant_state[1].shape) // pack_ratio + + offsets = np.concatenate(([0], np.cumsum(num_elements))) + set_weight_attrs(param, {"bnb_shard_offsets": offsets}) + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, vision_language_config, + cache_config) + + self._load_weights(model_config, model) + + return model.eval() + + def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: """Get a model loader based on the load format.""" @@ -554,4 +797,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_format == LoadFormat.SHARDED_STATE: return ShardedStateLoader(load_config) + if load_config.load_format == LoadFormat.BITSANDBYTES: + return BitsAndBytesModelLoader(load_config) + return DefaultModelLoader(load_config) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 53e21eba8fae..6174f0a97471 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -130,7 +130,17 @@ def get_quant_config(model_config: ModelConfig, if hf_quant_config is not None: return quant_cls.from_config(hf_quant_config) - model_name_or_path = model_config.model + # In case of bitsandbytes/QLoRA, get quant config from the adapter model. + if model_config.quantization == "bitsandbytes": + if (not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" + not in load_config.model_loader_extra_config): + return quant_cls.from_config({"adapter_name_or_path": ""}) + model_name_or_path = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path"] + + else: + model_name_or_path = model_config.model is_local = os.path.isdir(model_name_or_path) if not is_local: # Download the config files. @@ -169,6 +179,10 @@ def get_quant_config(model_config: ModelConfig, quant_config_file = quant_config_files[0] with open(quant_config_file, "r") as f: config = json.load(f) + + if model_config.quantization == "bitsandbytes": + config["adapter_name_or_path"] = model_name_or_path + return quant_cls.from_config(config) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2ca55f9270fc..d83ee9a201c0 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -319,6 +319,14 @@ class LlamaForCausalLM(nn.Module): "lm_head": "output_embeddings", } embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } def __init__( self,