diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py
index 8d109b2c81503..0b76f466702fc 100644
--- a/tests/lora/test_lora_manager.py
+++ b/tests/lora/test_lora_manager.py
@@ -1,3 +1,4 @@
+import json
 import os
 from typing import Dict, List
 
@@ -13,6 +14,7 @@
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager)
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                       WorkerLoRAManager)
@@ -30,18 +32,68 @@
 ]
 
 
+def test_peft_helper(sql_lora_files):
+    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
+    with open(lora_config_path) as f:
+        config = json.load(f)
+    peft_helper = PEFTHelper.from_dict(config)
+    assert peft_helper.r == 8
+    assert peft_helper.lora_alpha == 16
+    assert peft_helper.target_modules == [
+        "q_proj",
+        "v_proj",
+        "k_proj",
+        "o_proj",
+        "gate_proj",
+        "up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+
+    expected_error = "vLLM only supports modules_to_save being None."
+    with pytest.raises(ValueError, match=expected_error):
+        config = dict(
+            r=8,
+            lora_alpha=16,
+            target_modules=["gate_proj"],
+            modules_to_save=["lm_head"],
+        )
+        PEFTHelper.from_dict(config)
+    expected_error = "vLLM does not yet support RSLoRA."
+    with pytest.raises(ValueError, match=expected_error):
+        config = dict(r=8,
+                      lora_alpha=16,
+                      target_modules=["gate_proj"],
+                      use_rslora=True)
+        PEFTHelper.from_dict(config)
+
+    expected_error = "vLLM does not yet support DoRA."
+    with pytest.raises(ValueError, match=expected_error):
+        config = dict(r=8,
+                      lora_alpha=16,
+                      target_modules=["gate_proj"],
+                      use_dora=True)
+        PEFTHelper.from_dict(config)
+
+
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_from_lora_tensors(sql_lora_files, device):
     tensors = load_file(
         os.path.join(sql_lora_files, "adapter_model.safetensors"))
     new_embeddings = load_file(
         os.path.join(sql_lora_files, "new_embeddings.safetensors"))
+
+    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
+    with open(lora_config_path) as f:
+        config = json.load(f)
+
+    peft_helper = PEFTHelper.from_dict(config)
     lora_model = LoRAModel.from_lora_tensors(
         1,
-        8,
-        16,
         tensors,
-        device,
+        peft_helper=peft_helper,
+        device=device,
         embeddings=new_embeddings,
         embedding_modules=EMBEDDING_MODULES,
         embedding_padding_modules=EMBEDDING_PADDING_MODULES)
diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py
index b648312ba76ec..dde347b78bf81 100644
--- a/vllm/lora/lora.py
+++ b/vllm/lora/lora.py
@@ -4,6 +4,7 @@
 import torch
 import torch.types
 
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.utils import is_pin_memory_available
 
 
@@ -59,6 +60,23 @@ def extra_vocab_size(self) -> int:
         return self.embeddings_tensor.shape[
             0] if self.embeddings_tensor is not None else 0
 
+    @classmethod
+    def from_config(
+        cls,
+        module_name: str,
+        peft_helper: PEFTHelper,
+        embeddings_tensor: Optional[torch.Tensor] = None,
+    ) -> "LoRALayerWeights":
+        return cls(
+            module_name,
+            peft_helper.r,
+            peft_helper.lora_alpha,
+            None,
+            None,
+            None,
+            embeddings_tensor,
+        )
+
     @classmethod
     def create_dummy_lora_weights(
         cls,
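A minimal sketch of how the new `LoRALayerWeights.from_config` classmethod is meant to be used (the module name is hypothetical, and the constructor is assumed to expose `rank` and `lora_alpha` attributes as it does today):

```python
from vllm.lora.lora import LoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper

# Hypothetical minimal adapter config; only the required PEFT keys are set.
peft_helper = PEFTHelper.from_dict({
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj"],
})

# lora_a/lora_b/bias start out as None; the actual tensors are attached
# later while from_lora_tensors walks the checkpoint.
weights = LoRALayerWeights.from_config("model.layers.0.self_attn.q_proj",
                                       peft_helper)
assert weights.rank == 8 and weights.lora_alpha == 16
```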
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 49cd9f0c236ad..70806a77b9fff 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -21,6 +21,7 @@
                              LinearScalingRotaryEmbeddingWithLora,
                              LoRAMapping)
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                              is_regex_target_modules,
@@ -104,14 +105,12 @@ def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
     def from_lora_tensors(
         cls,
         lora_model_id: int,
-        rank: int,
-        lora_alpha: int,
         tensors: Dict[str, torch.Tensor],
+        peft_helper: PEFTHelper,
         device: str = "cuda",
         dtype: Optional[torch.dtype] = None,
         embeddings: Optional[Dict[str, torch.Tensor]] = None,
         target_embedding_padding: Optional[int] = None,
-        scaling_factor: Optional[float] = None,
         embedding_modules: Optional[Dict[str, str]] = None,
         embedding_padding_modules: Optional[List[str]] = None,
     ) -> "LoRAModel":
@@ -135,10 +134,9 @@ def from_lora_tensors(
                 if pin_memory:
                     lora_embeddings_tensor = (
                         lora_embeddings_tensor.pin_memory())
-                loras[module_name] = LoRALayerWeights(module_name, rank,
-                                                      lora_alpha, None, None,
-                                                      None,
-                                                      lora_embeddings_tensor)
+                loras[module_name] = LoRALayerWeights.from_config(
+                    module_name, peft_helper, lora_embeddings_tensor)
+
             if is_bias:
                 loras[module_name].bias = tensor.to(device=device,
                                                     dtype=dtype).t()
@@ -170,7 +168,11 @@ def from_lora_tensors(
         for lora in loras.values():
             lora.optimize()
-        return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)
+
+        return cls(lora_model_id,
+                   peft_helper.r,
+                   loras,
+                   scaling_factor=peft_helper.vllm_scaling_factor)
 
     @classmethod
     def from_local_checkpoint(
         cls,
@@ -212,6 +214,9 @@
                                                     "new_embeddings.bin")
         with open(lora_config_path) as f:
             config = json.load(f)
+
+        config["vllm_max_position_embeddings"] = max_position_embeddings
+        peft_helper = PEFTHelper.from_dict(config)
         if os.path.isfile(lora_tensor_path):
             tensors: Dict[str, torch.Tensor] = {}
             # Find unexpected modules.
@@ -242,7 +247,7 @@
             # When a bin file is provided, we rely on config to find unexpected
             # modules.
             unexpected_modules = []
-            target_modules = config["target_modules"]
+            target_modules = peft_helper.target_modules
             if not isinstance(target_modules, list):
                 target_modules = [target_modules]
             for module in target_modules:
@@ -256,7 +261,7 @@
             # https://github.com/vllm-project/vllm/pull/5909. But there's no
             # other better mechanism.
             if unexpected_modules and not is_regex_target_modules(
-                    config["target_modules"], expected_lora_modules):
+                    peft_helper.target_modules, expected_lora_modules):
                 raise ValueError(
                     f"While loading {lora_dir}, expected"
                     f" target modules in {expected_lora_modules}"
@@ -274,30 +279,17 @@
             embeddings = torch.load(new_embeddings_bin_file_path,
                                     map_location=device)
 
-        rank = config["r"]
-        lora_alpha = config["lora_alpha"]
-        context_length = config.get("context_length", None)
-        scaling_factor = None
-        if context_length:
-            if max_position_embeddings is None:
-                max_position_embeddings = context_length
-            scaling_factor = float(
-                math.ceil(context_length / max_position_embeddings))
-
         return cls.from_lora_tensors(
             lora_model_id=get_lora_id()
             if lora_model_id is None else lora_model_id,
-            rank=rank,
-            lora_alpha=lora_alpha,
             tensors=tensors,
+            peft_helper=peft_helper,
             device=device,
             dtype=dtype,
             embeddings=embeddings,
             target_embedding_padding=target_embedding_padding,
-            scaling_factor=scaling_factor,
             embedding_modules=embedding_modules,
-            embedding_padding_modules=embedding_padding_modules,
-        )
+            embedding_padding_modules=embedding_padding_modules)
 
 
 class LoRAModelManager(AdapterModelManager):
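The long-context handling deleted from `from_local_checkpoint` above is not dropped; it moves into `PEFTHelper.__post_init__` in the new file below. A worked example of the same `math.ceil` arithmetic, with hypothetical values:

```python
import math

# Hypothetical long-LoRA adapter: trained at a 16k context while the base
# model is capped at 4k positions, so the rope scaling factor is 4.0.
context_length = 16384
vllm_max_position_embeddings = 4096
vllm_scaling_factor = float(
    math.ceil(context_length / vllm_max_position_embeddings))
assert vllm_scaling_factor == 4.0
```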
diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
new file mode 100644
index 0000000000000..edf4ba5659575
--- /dev/null
+++ b/vllm/lora/peft_helper.py
@@ -0,0 +1,70 @@
+# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py
+
+import math
+from dataclasses import MISSING, dataclass, field, fields
+from typing import Literal, Optional, Union
+
+
+@dataclass
+class PEFTHelper:
+    # Required fields
+    r: int
+    lora_alpha: int
+    target_modules: Union[list[str], str]
+
+    bias: Literal["none", "all", "lora_only"] = field(default="none")
+    modules_to_save: Optional[list[str]] = field(default=None)
+    use_rslora: bool = field(default=False)
+    use_dora: bool = field(default=False)
+    # Long-LoRA field
+    context_length: int = field(default=0)
+    # Extra vLLM fields, prefixed with 'vllm_' to avoid conflicts
+    vllm_max_position_embeddings: Optional[int] = field(default=None)
+    vllm_scaling_factor: Optional[float] = field(default=None)
+
+    def _validate_features(self):
+        error_msg = []
+
+        if self.modules_to_save:
+            error_msg.append("vLLM only supports modules_to_save being None.")
+        if self.use_rslora:
+            error_msg.append("vLLM does not yet support RSLoRA.")
+
+        if self.use_dora:
+            error_msg.append("vLLM does not yet support DoRA.")
+
+        if error_msg:
+            raise ValueError(f"{', '.join(error_msg)}")
+
+    def __post_init__(self):
+        self._validate_features()
+        if self.context_length:
+            if self.vllm_max_position_embeddings is None:
+                self.vllm_max_position_embeddings = self.context_length
+            self.vllm_scaling_factor = float(
+                math.ceil(self.context_length /
+                          self.vllm_max_position_embeddings))
+
+    @classmethod
+    def from_dict(cls, config_dict: dict) -> "PEFTHelper":
+        # Get all field information from the class
+        class_fields = {f.name: f for f in fields(cls)}
+        # Check for required fields
+        required_fields = {
+            name
+            for name, f in class_fields.items()
+            if f.default is MISSING and f.default_factory is MISSING
+        }
+
+        # Identify any missing required fields
+        missing_fields = required_fields - set(config_dict.keys())
+        if missing_fields:
+            raise ValueError(
+                f"Missing required configuration fields: {missing_fields}")
+
+        # Filter out fields that aren't defined in the class
+        filtered_dict = {
+            k: v
+            for k, v in config_dict.items() if k in class_fields
+        }
+        return cls(**filtered_dict)
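End to end, adapter loading now funnels `adapter_config.json` through this helper. A minimal sketch of that flow (the adapter directory and the 4096 position limit are hypothetical; `from_local_checkpoint` performs the same wiring internally):

```python
import json
import os

from vllm.lora.peft_helper import PEFTHelper

# Hypothetical checkpoint directory containing adapter_config.json.
adapter_dir = "/path/to/lora_adapter"
with open(os.path.join(adapter_dir, "adapter_config.json")) as f:
    config = json.load(f)

# As in from_local_checkpoint: inject the base model's position limit so
# that __post_init__ can derive vllm_scaling_factor for long-context LoRAs.
config["vllm_max_position_embeddings"] = 4096
peft_helper = PEFTHelper.from_dict(config)
print(peft_helper.r, peft_helper.lora_alpha, peft_helper.vllm_scaling_factor)
```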