From 77cde6d4513588dcd78c5de9e4523501aa55c02c Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 27 Jun 2024 16:03:04 +0800 Subject: [PATCH] [Model] Add base class for LoRA-supported models (#5018) Signed-off-by: Alvant --- docs/source/models/lora.rst | 3 + vllm/lora/lora.py | 3 +- vllm/lora/models.py | 6 +- vllm/model_executor/model_loader/loader.py | 20 ++-- vllm/model_executor/models/baichuan.py | 11 +- vllm/model_executor/models/chatglm.py | 11 +- vllm/model_executor/models/decilm.py | 4 +- vllm/model_executor/models/gemma.py | 10 +- vllm/model_executor/models/gpt_bigcode.py | 9 +- vllm/model_executor/models/interfaces.py | 130 +++++++++++++++++++++ vllm/model_executor/models/llama.py | 9 +- vllm/model_executor/models/llava.py | 22 ++-- vllm/model_executor/models/llava_next.py | 20 ++-- vllm/model_executor/models/minicpm.py | 12 +- vllm/model_executor/models/mixtral.py | 9 +- vllm/model_executor/models/phi.py | 22 ++-- vllm/model_executor/models/qwen2.py | 10 +- vllm/model_executor/models/vlm_base.py | 12 -- vllm/model_executor/models/xverse.py | 11 +- vllm/worker/model_runner.py | 11 +- 20 files changed, 270 insertions(+), 75 deletions(-) create mode 100644 vllm/model_executor/models/interfaces.py delete mode 100644 vllm/model_executor/models/vlm_base.py diff --git a/docs/source/models/lora.rst b/docs/source/models/lora.rst index 2278640481a91..934887a607a6a 100644 --- a/docs/source/models/lora.rst +++ b/docs/source/models/lora.rst @@ -4,6 +4,9 @@ Using LoRA adapters =================== This document shows you how to use `LoRA adapters `_ with vLLM on top of a base model. + +LoRA adapters can be used with any vLLM model that implements :class:`~vllm.model_executor.models.interfaces.SupportsLoRA`. + Adapters can be efficiently served on a per request basis with minimal overhead. 
First we download the adapter(s) and save them locally with diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py index 8f3c7f76932af..14081b5ba441c 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora.py @@ -2,6 +2,7 @@ from typing import Sequence as GenericSequence import torch +import torch.types from vllm.utils import is_pin_memory_available @@ -64,7 +65,7 @@ def create_dummy_lora_weights( output_dim: int, rank: int, dtype: torch.dtype, - device: torch.device, + device: torch.types.Device, embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": pin_memory = str(device) == "cpu" and is_pin_memory_available() lora_a = torch.zeros([input_dim, rank], diff --git a/vllm/lora/models.py b/vllm/lora/models.py index afb9ba4550671..0a1fc7c021781 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -18,6 +18,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) +from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.utils import LRUCache, is_pin_memory_available logger = init_logger(__name__) @@ -363,7 +364,7 @@ class LoRAModelManager: def __init__( self, - model: nn.Module, + model: SupportsLoRA, max_num_seqs: int, max_num_batched_tokens: int, vocab_size: int, @@ -411,7 +412,7 @@ def __init__( # embeddings_indices self.indices_len: List[Optional[int]] = [None] * 4 - self.model: nn.Module = model + self.model = model if hasattr(self.model, "supported_lora_modules"): self.supported_lora_modules = copy.deepcopy( self.model.supported_lora_modules) @@ -428,7 +429,6 @@ def __init__( self._active_loras: Dict[int, None] = {} self._last_mapping: Optional[LoRAMapping] = None self._create_lora_modules() - self.model.lora_manager = self @property def capacity(self) -> int: diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index d3babcf9c3451..e91bf7cf35b41 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -32,7 +32,8 @@ filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) -from vllm.model_executor.models.vlm_base import VisionLanguageModelBase +from vllm.model_executor.models.interfaces import (supports_lora, + supports_vision) from vllm.model_executor.utils import set_weight_attrs from vllm.utils import is_tpu @@ -64,12 +65,15 @@ def _get_quantization_config( def _get_model_initialization_kwargs( - model_class: Type[nn.Module], lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig] + model_class: Type[nn.Module], + lora_config: Optional[LoRAConfig], + vlm_config: Optional[VisionLanguageConfig], ) -> Dict[str, Any]: """Get extra kwargs for model initialization.""" extra_kwargs: Dict[str, Any] = {} - if hasattr(model_class, "supported_lora_modules"): + + if supports_lora(model_class): + # lora_config=None is used to disable LoRA extra_kwargs["lora_config"] = lora_config elif lora_config: raise ValueError( @@ -77,13 +81,15 @@ def _get_model_initialization_kwargs( "but LoRA is enabled. Support for this model may " "be added in the future. 
If this is important to you, " "please open an issue on github.") - elif issubclass(model_class, VisionLanguageModelBase): - if vision_language_config is None: + + if supports_vision(model_class): + if vlm_config is None: raise ValueError("Provide `image_input_type` and other vision " "related configurations through LLM entrypoint " "or engine arguments.") - extra_kwargs["vision_language_config"] = vision_language_config + extra_kwargs["vlm_config"] = vlm_config + return extra_kwargs diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index babb92e7cdcef..abaefa3cf7781 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -45,6 +45,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from .interfaces import SupportsLoRA + def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) @@ -292,7 +294,9 @@ def forward( return hidden_states -class BaiChuanBaseForCausalLM(nn.Module): +class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "W_pack": ["W_pack"], "gate_up_proj": [ @@ -312,14 +316,17 @@ class BaiChuanBaseForCausalLM(nn.Module): def __init__( self, - config, + config: PretrainedConfig, position_embedding: str, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): super().__init__() + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config self.model = BaiChuanModel(config, position_embedding, cache_config, quant_config) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index e3a5e43e23e1c..bf64538ef54a3 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -28,6 +28,8 @@ from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig +from .interfaces import SupportsLoRA + class GLMAttention(nn.Module): @@ -322,7 +324,9 @@ def forward( return hidden_states -class ChatGLMForCausalLM(nn.Module): +class ChatGLMForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "query_key_value": ["query_key_value"], "dense_h_to_4h": ["dense_h_to_4h"] @@ -345,7 +349,10 @@ def __init__( lora_config: Optional[LoRAConfig] = None, ): super().__init__() - self.config: ChatGLMConfig = config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config self.max_position_embeddings = getattr(config, "max_sequence_length", 8192) diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index e293ee491908d..65b409a2a15a0 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -26,7 +26,7 @@ from typing import Iterable, Optional, Tuple import torch -from transformers import PretrainedConfig +from transformers import LlamaConfig from vllm.config import CacheConfig, LoRAConfig from vllm.model_executor.layers.quantization.base_config import ( @@ -55,7 +55,7 @@ class DeciLMForCausalLM(LlamaForCausalLM): def __init__( self, - config: Optional[PretrainedConfig] = None, + config: LlamaConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, diff --git a/vllm/model_executor/models/gemma.py 
b/vllm/model_executor/models/gemma.py index 65f4ebec5bcf0..9e071a155061b 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -41,6 +41,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from .interfaces import SupportsLoRA + logger = init_logger(__name__) @@ -288,7 +290,9 @@ def forward( return hidden_states -class GemmaForCausalLM(nn.Module): +class GemmaForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -319,9 +323,11 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: - del lora_config # Unused. super().__init__() + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config self.model = GemmaModel(config, cache_config, quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index b15ed11988c27..009d7b1498c22 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -41,6 +41,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from .interfaces import SupportsLoRA + class GPTBigCodeAttention(nn.Module): @@ -230,7 +232,9 @@ def forward( return hidden_states -class GPTBigCodeForCausalLM(nn.Module): +class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = {"c_attn": ["c_attn"]} supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"] @@ -250,7 +254,10 @@ def __init__( lora_config: Optional[LoRAConfig] = None, ): super().__init__() + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config self.transformer = GPTBigCodeModel(config, cache_config, quant_config, lora_config) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py new file mode 100644 index 0000000000000..a9eb397a5a97f --- /dev/null +++ b/vllm/model_executor/models/interfaces.py @@ -0,0 +1,130 @@ +from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type, + Union, overload, runtime_checkable) + +from typing_extensions import TypeGuard + +from vllm.config import LoRAConfig, VisionLanguageConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@runtime_checkable +class SupportsVision(Protocol): + """The interface required for all vision language models (VLMs).""" + + supports_vision: ClassVar[Literal[True]] + + def __init__(self, *, vlm_config: VisionLanguageConfig) -> None: + ... + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsVisionType(Protocol): + supports_vision: Literal[True] + + def __call__(self, *, vlm_config: VisionLanguageConfig) -> None: + ... + + +@overload +def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]: + ... + + +@overload +def supports_vision(model: object) -> TypeGuard[SupportsVision]: + ... 
+ + +def supports_vision( + model: Union[Type[object], object], +) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]: + if isinstance(model, type): + return isinstance(model, _SupportsVisionType) + + return isinstance(model, SupportsVision) + + +@runtime_checkable +class SupportsLoRA(Protocol): + """The interface required for all models that support LoRA.""" + + supports_lora: ClassVar[Literal[True]] + + packed_modules_mapping: ClassVar[Dict[str, List[str]]] + supported_lora_modules: ClassVar[List[str]] + embedding_modules: ClassVar[Dict[str, str]] + embedding_padding_modules: ClassVar[List[str]] + + # lora_config is None when LoRA is not enabled + def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: + ... + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsLoRAType(Protocol): + supports_lora: Literal[True] + + packed_modules_mapping: Dict[str, List[str]] + supported_lora_modules: List[str] + embedding_modules: Dict[str, str] + embedding_padding_modules: List[str] + + def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: + ... + + +@overload +def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]: + ... + + +@overload +def supports_lora(model: object) -> TypeGuard[SupportsLoRA]: + ... + + +def supports_lora( + model: Union[Type[object], object], +) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]: + result = _supports_lora(model) + + if not result: + lora_attrs = ( + "packed_modules_mapping", + "supported_lora_modules", + "embedding_modules", + "embedding_padding_modules", + ) + missing_attrs = tuple(attr for attr in lora_attrs + if not hasattr(model, attr)) + + if getattr(model, "supports_lora", False): + if missing_attrs: + logger.warning( + "The model (%s) sets `supports_lora=True`, " + "but is missing LoRA-specific attributes: %s", + model, + missing_attrs, + ) + else: + if not missing_attrs: + logger.warning( + "The model (%s) contains all LoRA-specific attributes, " + "but does not set `supports_lora=True`.", model) + + return result + + +def _supports_lora( + model: Union[Type[object], object], +) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]: + if isinstance(model, type): + return isinstance(model, _SupportsLoRAType) + + return isinstance(model, SupportsLoRA) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d83ee9a201c0b..f4918cbfef294 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -49,6 +49,8 @@ from vllm.sequence import SamplerOutput from vllm.utils import is_hip, print_warning_once +from .interfaces import SupportsLoRA + class LlamaMLP(nn.Module): @@ -296,7 +298,9 @@ def forward( return hidden_states -class LlamaForCausalLM(nn.Module): +class LlamaForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -336,7 +340,10 @@ def __init__( lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() + self.config = config + self.lora_config = lora_config + self.model = LlamaModel(config, cache_config, quant_config, diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 8e36c54b1c511..8e18b42b76734 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -20,7 +20,7 @@ from vllm.multimodal.image import 
get_dummy_image_data from vllm.sequence import SamplerOutput -from .vlm_base import VisionLanguageModelBase +from .interfaces import SupportsVision _KEYS_TO_MODIFY_MAPPING = { "language_model.lm_head": "lm_head", @@ -86,18 +86,21 @@ class LlavaImageFeatureInputs(TypedDict): @MULTIMODAL_REGISTRY.register_image_feature_input() @MULTIMODAL_REGISTRY.register_image_pixel_input() @MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data) -class LlavaForConditionalGeneration(VisionLanguageModelBase): +class LlavaForConditionalGeneration(nn.Module, SupportsVision): + + supports_vision = True def __init__(self, config: LlavaConfig, - vision_language_config: VisionLanguageConfig, + vlm_config: VisionLanguageConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: - super().__init__(vision_language_config) + super().__init__() self.config = config + self.vlm_config = vlm_config - if self.vision_language_config.image_input_type == ( + if self.vlm_config.image_input_type == ( VisionLanguageConfig.ImageInputType.PIXEL_VALUES): self.vision_tower = CLIPVisionModel(config.vision_config) else: @@ -122,11 +125,10 @@ def __init__(self, self.sampler = Sampler() def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor: - if list(data.shape[1:]) != list( - self.vision_language_config.image_input_shape[1:]): + if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]): raise ValueError( f"The expected image tensor shape is batch dimension plus " - f"{self.vision_language_config.image_input_shape[1:]}. " + f"{self.vlm_config.image_input_shape[1:]}. " f"You supplied {data.shape}. " f"If you are using vLLM's entrypoint, make sure your " f"supplied image input is consistent with " @@ -139,7 +141,7 @@ def _parse_and_validate_image_input( pixel_values = kwargs.pop("pixel_values", None) image_features = kwargs.pop("image_features", None) - expected_input_type = self.vision_language_config.image_input_type + expected_input_type = self.vlm_config.image_input_type ImageInputType = VisionLanguageConfig.ImageInputType if expected_input_type == ImageInputType.PIXEL_VALUES: @@ -273,7 +275,7 @@ def forward( inputs_embeds = merge_vision_embeddings( input_ids, inputs_embeds, vision_embeddings, - self.vision_language_config.image_token_id) + self.vlm_config.image_token_id) input_ids = None else: diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c1158c933c88b..5c03fb3705561 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -25,8 +25,8 @@ from vllm.multimodal.image import ImagePixelData, get_dummy_image_data from vllm.sequence import SamplerOutput, SequenceData +from .interfaces import SupportsVision from .llava import LlavaMultiModalProjector, merge_vision_embeddings -from .vlm_base import VisionLanguageModelBase logger = init_logger(__name__) @@ -106,19 +106,21 @@ def _image_pixel_processor( @MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor) @MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data) -class LlavaNextForConditionalGeneration(VisionLanguageModelBase): +class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): + + supports_vision = True def __init__(self, config: LlavaNextConfig, - vision_language_config: VisionLanguageConfig, + vlm_config: VisionLanguageConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: - 
super().__init__(vision_language_config) + super().__init__() - # Update the type annotation from that of its superclass self.config = config + self.vlm_config = vlm_config - if self.vision_language_config.image_input_type == ( + if self.vlm_config.image_input_type == ( VisionLanguageConfig.ImageInputType.PIXEL_VALUES): self.vision_tower = CLIPVisionModel(config=config.vision_config) else: @@ -146,7 +148,7 @@ def __init__(self, torch.empty(config.text_config.hidden_size)) def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor: - _, num_channels, _, _ = self.vision_language_config.image_input_shape + _, num_channels, _, _ = self.vlm_config.image_input_shape # Note that this is different from that of vLLM vision_language_config # since the image is resized by the HuggingFace preprocessor @@ -177,7 +179,7 @@ def _parse_and_validate_image_input( image_sizes = kwargs.pop("image_sizes", None) image_features = kwargs.pop("image_features", None) - expected_input_type = self.vision_language_config.image_input_type + expected_input_type = self.vlm_config.image_input_type ImageInputType = VisionLanguageConfig.ImageInputType if expected_input_type == ImageInputType.PIXEL_VALUES: @@ -386,7 +388,7 @@ def forward( inputs_embeds = merge_vision_embeddings( input_ids, inputs_embeds, vision_embeddings, - self.vision_language_config.image_token_id) + self.vlm_config.image_token_id) input_ids = None else: diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 59fbf8e1b35f2..ae17309bd5223 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -26,6 +26,7 @@ import torch from torch import nn +from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig @@ -51,6 +52,8 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import SamplerOutput +from .interfaces import SupportsLoRA + class MiniCPMMoE(nn.Module): """A tensor-parallel MoE implementation that shards each expert @@ -388,7 +391,9 @@ def forward( return hidden_states -class MiniCPMForCausalLM(nn.Module): +class MiniCPMForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -418,13 +423,16 @@ class MiniCPMForCausalLM(nn.Module): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() + self.config = config + self.lora_config = lora_config + self.num_experts = getattr(self.config, "num_experts", 0) self.quant_config = quant_config self.model = MiniCPMModel(config, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 3faf54d292b99..0bdcb21e514fd 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -54,6 +54,8 @@ from vllm.sequence import SamplerOutput from vllm.utils import print_warning_once +from .interfaces import SupportsLoRA + class MixtralMoE(nn.Module): """A tensor-parallel MoE implementation for Mixtral that shards each expert @@ -472,7 +474,9 @@ def forward( return hidden_states -class MixtralForCausalLM(nn.Module): +class MixtralForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + fall_back_to_pt_during_load = False packed_modules_mapping = { @@ -504,7 +508,10 @@ def __init__( lora_config: Optional[LoRAConfig] = None, ) -> 
None: super().__init__() + self.config = config + self.lora_config = lora_config + self.model = MixtralModel(config, cache_config, quant_config, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index c8e61735a9bb6..d288bdd9d78f5 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -39,7 +39,7 @@ import torch from torch import nn -from transformers import PretrainedConfig +from transformers import PhiConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig @@ -59,11 +59,13 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from .interfaces import SupportsLoRA + class PhiAttention(nn.Module): def __init__(self, - config: PretrainedConfig, + config: PhiConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() @@ -131,7 +133,7 @@ def forward( class PhiMLP(nn.Module): def __init__(self, - config: PretrainedConfig, + config: PhiConfig, quant_config: Optional[QuantizationConfig] = None): super().__init__() @@ -160,7 +162,7 @@ def forward(self, hidden_states): class PhiLayer(nn.Module): def __init__(self, - config: PretrainedConfig, + config: PhiConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() @@ -192,7 +194,7 @@ def forward( class PhiModel(nn.Module): def __init__(self, - config: PretrainedConfig, + config: PhiConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() @@ -229,7 +231,9 @@ def forward( return hidden_states -class PhiForCausalLM(nn.Module): +class PhiForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -250,14 +254,16 @@ class PhiForCausalLM(nn.Module): def __init__( self, - config: PretrainedConfig, + config: PhiConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - del lora_config # Unused. 
super().__init__() + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config self.model = PhiModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index b5d13bb6b937c..d351adcefc974 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -48,6 +48,8 @@ from vllm.sequence import SamplerOutput from vllm.utils import print_warning_once +from .interfaces import SupportsLoRA + class Qwen2MLP(nn.Module): @@ -263,7 +265,9 @@ def forward( return hidden_states -class Qwen2ForCausalLM(nn.Module): +class Qwen2ForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -293,7 +297,6 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: - del lora_config # TODO (@robertgshaw2): see if this can be moved out if (cache_config.sliding_window is not None and hasattr(config, "max_window_layers")): @@ -307,7 +310,10 @@ def __init__( )) super().__init__() + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config self.model = Qwen2Model(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/vlm_base.py b/vllm/model_executor/models/vlm_base.py deleted file mode 100644 index eb0aa96e50d59..0000000000000 --- a/vllm/model_executor/models/vlm_base.py +++ /dev/null @@ -1,12 +0,0 @@ -from torch import nn - -from vllm.config import VisionLanguageConfig - - -class VisionLanguageModelBase(nn.Module): - """Base class for all vision language models (VLMs).""" - - def __init__(self, vision_language_config: VisionLanguageConfig) -> None: - super().__init__() - - self.vision_language_config = vision_language_config diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 1e5280dde3ff9..639c3443bc369 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -45,6 +45,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from .interfaces import SupportsLoRA + class XverseMLP(nn.Module): @@ -266,7 +268,9 @@ def forward( return hidden_states -class XverseForCausalLM(nn.Module): +class XverseForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -299,10 +303,13 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - lora_config=None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config self.model = XverseModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ac820bbcbca33..181442490a82c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -22,6 +22,7 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.model_executor.models.interfaces import supports_lora from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata @@ -225,14 
+226,8 @@ def load_model(self) -> None: self.model_memory_usage / float(2**30)) if self.lora_config: - assert hasattr(self.model, "supported_lora_modules" - ) and self.model.supported_lora_modules, ( - "Model does not support LoRA") - assert hasattr( - self.model, - "embedding_modules"), "Model does not have embedding_modules" - assert hasattr(self.model, "embedding_padding_modules" - ), "Model does not have embedding_padding_modules" + assert supports_lora(self.model), "Model does not support LoRA" + self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens,
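Note (illustration, not part of the patch): the SupportsLoRA protocol and the supports_lora() helper added in vllm/model_executor/models/interfaces.py are opted into exactly as the model diffs above do. Below is a minimal sketch of that pattern; the interface names and required class attributes come from this patch, while ToyForCausalLM and its attribute values are hypothetical placeholders.

from typing import Optional

import torch.nn as nn

from vllm.config import LoRAConfig
from vllm.model_executor.models.interfaces import SupportsLoRA, supports_lora


class ToyForCausalLM(nn.Module, SupportsLoRA):
    # Class-level flag checked at runtime by supports_lora()
    supports_lora = True

    # LoRA-specific class attributes required by the SupportsLoRA protocol
    # (the values here are illustrative only)
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    supported_lora_modules = ["qkv_proj", "o_proj"]
    embedding_modules = {"embed_tokens": "input_embeddings"}
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
        super().__init__()
        # lora_config is None when LoRA is disabled
        self.lora_config = lora_config


# Both the class and its instances pass the runtime checks used above by the
# model loader (_get_model_initialization_kwargs) and model runner (load_model).
assert supports_lora(ToyForCausalLM)
assert supports_lora(ToyForCausalLM())
assert not supports_lora(nn.Linear(2, 2))

The same pattern applies to vision-language models via SupportsVision and supports_vision(), except that those models take a required vlm_config keyword argument instead of an optional lora_config.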