From 8e770bdec73efed5439a540c80d5ccc8f512743b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 6 Oct 2024 15:22:31 +0000 Subject: [PATCH 1/8] Add and test interfaces by task type --- tests/models/test_registry.py | 24 +++- vllm/model_executor/models/__init__.py | 7 + vllm/model_executor/models/interfaces_base.py | 133 ++++++++++++++++++ vllm/model_executor/models/registry.py | 42 ++++-- 4 files changed, 193 insertions(+), 13 deletions(-) create mode 100644 vllm/model_executor/models/interfaces_base.py diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 299aeacb9f33..638452be700b 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -3,7 +3,14 @@ import pytest import torch.cuda -from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models import (supports_embedding, + supports_multimodal, + supports_text_generation) +from vllm.model_executor.models.registry import (_EMBEDDING_MODELS, + _MULTIMODAL_MODELS, + _SPECULATIVE_DECODING_MODELS, + _TEXT_GENERATION_MODELS, + ModelRegistry) from vllm.platforms import current_platform from ..utils import fork_new_process_for_each_test @@ -12,7 +19,20 @@ @pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs()) def test_registry_imports(model_arch): # Ensure all model classes can be imported successfully - ModelRegistry.resolve_model_cls(model_arch) + model_cls, _ = ModelRegistry.resolve_model_cls(model_arch) + + if model_arch in _SPECULATIVE_DECODING_MODELS: + pass # Ignore these models which do not have a unified format + else: + assert supports_text_generation(model_cls) is ( + model_arch in _TEXT_GENERATION_MODELS + or model_arch in _MULTIMODAL_MODELS) + + assert supports_embedding(model_cls) is (model_arch + in _EMBEDDING_MODELS) + + assert supports_multimodal(model_cls) is (model_arch + in _MULTIMODAL_MODELS) @fork_new_process_for_each_test diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 51054a147a06..68969f41e344 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,10 +1,17 @@ from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, SupportsPP, has_inner_state, supports_lora, supports_multimodal, supports_pp) +from .interfaces_base import (VllmModelForEmbedding, + VllmModelForTextGeneration, supports_embedding, + supports_text_generation) from .registry import ModelRegistry __all__ = [ "ModelRegistry", + "VllmModelForEmbedding", + "supports_embedding", + "VllmModelForTextGeneration", + "supports_text_generation", "HasInnerState", "has_inner_state", "SupportsLoRA", diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py new file mode 100644 index 000000000000..aeafb25928f4 --- /dev/null +++ b/vllm/model_executor/models/interfaces_base.py @@ -0,0 +1,133 @@ +from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union, + overload, runtime_checkable) + +import torch +from transformers import PretrainedConfig +from typing_extensions import TypeIs, TypeVar + +if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.config import CacheConfig + from vllm.model_executor.layers.pooler import PoolerOutput + from vllm.model_executor.layers.quantization import QuantizationConfig + from vllm.model_executor.layers.sampler import SamplerOutput + from vllm.model_executor.pooling_metadata import PoolingMetadata + from vllm.model_executor.sampling_metadata import 
SamplingMetadata + +# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags +# for the base interfaces to avoid breaking OOT registration + +# The type of hidden states +# Currently, T = torch.Tensor for all models except for Medusa +# which has T = List[torch.Tensor] +T = TypeVar("T", default=torch.Tensor) + + +@runtime_checkable +class VllmModelForTextGeneration(Protocol[T]): + + def __init__( + self, + config: PretrainedConfig, + *, + cache_config: Optional["CacheConfig"], + quant_config: Optional["QuantizationConfig"], + ) -> None: + ... + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: "AttentionMetadata", + ) -> T: + ... + + def compute_logits( + self, + hidden_states: T, + sampling_metadata: "SamplingMetadata", + ) -> Optional[T]: + """Return `None` if TP rank > 0.""" + ... + + def sample( + self, + logits: T, + sampling_metadata: "SamplingMetadata", + ) -> "SamplerOutput": + """Only called on TP rank 0.""" + ... + + +@overload +def supports_text_generation( + model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]: + ... + + +@overload +def supports_text_generation( + model: object) -> TypeIs[VllmModelForTextGeneration]: + ... + + +def supports_text_generation( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[VllmModelForTextGeneration]], + TypeIs[VllmModelForTextGeneration]]: + if isinstance(model, type): + return isinstance(model, VllmModelForTextGeneration) + + return isinstance(model, VllmModelForTextGeneration) + + +@runtime_checkable +class VllmModelForEmbedding(Protocol[T]): + + def __init__( + self, + config: PretrainedConfig, + *, + cache_config: Optional["CacheConfig"], + quant_config: Optional["QuantizationConfig"], + ) -> None: + ... + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: "AttentionMetadata", + ) -> T: + ... + + def pooler( + self, + hidden_states: T, + pooling_metadata: "PoolingMetadata", + ) -> "PoolerOutput": + """Only called on TP rank 0.""" + ... + + +@overload +def supports_embedding( + model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]: + ... + + +@overload +def supports_embedding(model: object) -> TypeIs[VllmModelForEmbedding]: + ... 
+ + +def supports_embedding( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]: + if isinstance(model, type): + return isinstance(model, VllmModelForEmbedding) + + return isinstance(model, VllmModelForEmbedding) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index ccb0e155ff4a..d9629f56c7db 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -12,10 +12,12 @@ from vllm.utils import is_hip from .interfaces import supports_multimodal, supports_pp +from .interfaces_base import supports_embedding, supports_text_generation logger = init_logger(__name__) -_GENERATION_MODELS = { +_TEXT_GENERATION_MODELS = { + # [Decoder-only] "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), @@ -74,10 +76,9 @@ "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "SolarForCausalLM": ("solar", "SolarForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), - # NOTE: The below models are for speculative decoding only - "MedusaModel": ("medusa", "Medusa"), - "EAGLEModel": ("eagle", "EAGLE"), - "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), + # [Encoder-decoder] + "BartModel": ("bart", "BartForConditionalGeneration"), + "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), } _EMBEDDING_MODELS = { @@ -114,16 +115,18 @@ "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), } -_CONDITIONAL_GENERATION_MODELS = { - "BartModel": ("bart", "BartForConditionalGeneration"), - "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), + +_SPECULATIVE_DECODING_MODELS = { + "EAGLEModel": ("eagle", "EAGLE"), + "MedusaModel": ("medusa", "Medusa"), + "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), } _MODELS = { - **_GENERATION_MODELS, + **_TEXT_GENERATION_MODELS, **_EMBEDDING_MODELS, **_MULTIMODAL_MODELS, - **_CONDITIONAL_GENERATION_MODELS, + **_SPECULATIVE_DECODING_MODELS, } # Architecture -> type or (module, class). 
@@ -317,6 +320,19 @@ def _check_stateless( return result.returncode == 0 + @staticmethod + def is_text_generation_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + is_txt_gen = partial(ModelRegistry._check_stateless, + supports_text_generation, + default=False) + + return any(is_txt_gen(arch) for arch in architectures) + @staticmethod def is_embedding_model(architectures: Union[str, List[str]]) -> bool: if isinstance(architectures, str): @@ -324,7 +340,11 @@ def is_embedding_model(architectures: Union[str, List[str]]) -> bool: if not architectures: logger.warning("No model architectures are specified") - return any(arch in _EMBEDDING_MODELS for arch in architectures) + is_emb = partial(ModelRegistry._check_stateless, + supports_embedding, + default=False) + + return any(is_emb(arch) for arch in architectures) @staticmethod def is_multimodal_model(architectures: Union[str, List[str]]) -> bool: From 7a7272c9eca4c25d88e525923c73967396564017 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 6 Oct 2024 15:24:27 +0000 Subject: [PATCH 2/8] Add test --- tests/conftest.py | 20 +++++++++++++++++++ tests/models/test_oot_registration.py | 18 ++++++++++++++--- .../vllm_add_dummy_model/__init__.py | 6 ++++++ .../my_gemma_embedding.py | 14 +++++++++++++ 4 files changed, 55 insertions(+), 3 deletions(-) create mode 100644 tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py diff --git a/tests/conftest.py b/tests/conftest.py index 5de3f1f2a2b9..c042160cbc44 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -885,6 +885,7 @@ def num_gpus_available(): temp_dir = tempfile.gettempdir() _dummy_opt_path = os.path.join(temp_dir, "dummy_opt") _dummy_llava_path = os.path.join(temp_dir, "dummy_llava") +_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding") @pytest.fixture @@ -923,3 +924,22 @@ def dummy_llava_path(): with open(json_path, "w") as f: json.dump(config, f) return _dummy_llava_path + + +@pytest.fixture +def dummy_gemma2_embedding_path(): + json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json") + if not os.path.exists(_dummy_gemma2_embedding_path): + snapshot_download(repo_id="BAAI/bge-multilingual-gemma2", + local_dir=_dummy_gemma2_embedding_path, + ignore_patterns=[ + "*.bin", "*.bin.index.json", "*.pt", "*.h5", + "*.msgpack" + ]) + assert os.path.exists(json_path) + with open(json_path, "r") as f: + config = json.load(f) + config["architectures"] = ["MyGemma2Embedding"] + with open(json_path, "w") as f: + json.dump(config, f) + return _dummy_gemma2_embedding_path diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index ee3f8911f318..94be215258f8 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -2,7 +2,7 @@ import pytest -from vllm import LLM, SamplingParams +from vllm import LLM, PoolingParams, SamplingParams from vllm.assets.image import ImageAsset from ..utils import fork_new_process_for_each_test @@ -17,7 +17,7 @@ def test_plugin(dummy_opt_path): @fork_new_process_for_each_test -def test_oot_registration(dummy_opt_path): +def test_oot_registration_text_generation(dummy_opt_path): os.environ["VLLM_PLUGINS"] = "register_dummy_model" prompts = ["Hello, my name is", "The text does not matter"] sampling_params = SamplingParams(temperature=0) @@ -32,11 +32,23 @@ def 
test_oot_registration(dummy_opt_path):
     assert rest == ""
 
 
+@fork_new_process_for_each_test
+def test_oot_registration_embedding(dummy_gemma2_embedding_path):
+    os.environ["VLLM_PLUGINS"] = "register_dummy_model"
+    prompts = ["Hello, my name is", "The text does not matter"]
+    pooling_params = PoolingParams()
+    llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
+    outputs = llm.encode(prompts, pooling_params)
+
+    for output in outputs:
+        assert all(v == 0 for v in output.outputs.embedding)
+
+
 image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
 
 
 @fork_new_process_for_each_test
-def test_oot_multimodal_registration(dummy_llava_path):
+def test_oot_registration_multimodal(dummy_llava_path):
     os.environ["VLLM_PLUGINS"] = "register_dummy_model"
     prompts = [{
         "prompt": "What's in the image?",
diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py
index 022ba66e38cc..62a8f871fa51 100644
--- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py
+++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py
@@ -9,6 +9,12 @@ def register():
         ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
 
     # Test passing lazy model
+    if "MyGemma2Embedding" not in ModelRegistry.get_supported_archs():
+        ModelRegistry.register_model(
+            "MyGemma2Embedding",
+            "vllm_add_dummy_model.my_gemma_embedding:MyGemma2Embedding",
+        )
+
     if "MyLlava" not in ModelRegistry.get_supported_archs():
         ModelRegistry.register_model("MyLlava",
                                      "vllm_add_dummy_model.my_llava:MyLlava")
diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
new file mode 100644
index 000000000000..71c824c6a91c
--- /dev/null
+++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
@@ -0,0 +1,14 @@
+import torch
+
+from vllm.model_executor.models.gemma2_embedding import Gemma2EmbeddingModel
+
+
+class MyGemma2Embedding(Gemma2EmbeddingModel):
+
+    def forward(self, *args, **kwargs) -> torch.Tensor:
+        hidden_states = super().forward(*args, **kwargs)
+
+        # We assume PP isn't used in the test
+        assert isinstance(hidden_states, torch.Tensor)
+
+        return torch.zeros_like(hidden_states)

From 0b3731f8d9f36cd9b488236fd834da3829b9311d Mon Sep 17 00:00:00 2001
From: DarkLight1337 
Date: Sun, 6 Oct 2024 15:30:25 +0000
Subject: [PATCH 3/8] Update docs

---
 .../vllm_add_dummy_model/my_gemma_embedding.py | 1 +
 vllm/model_executor/models/interfaces_base.py  | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
index 71c824c6a91c..0a7bbfdc23cf 100644
--- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
+++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
@@ -11,4 +11,5 @@ def forward(self, *args, **kwargs) -> torch.Tensor:
         # We assume PP isn't used in the test
         assert isinstance(hidden_states, torch.Tensor)
 
+        # Return all-zero embeddings
         return torch.zeros_like(hidden_states)
diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py
index aeafb25928f4..6577488f1235 100644
--- a/vllm/model_executor/models/interfaces_base.py
+++ b/vllm/model_executor/models/interfaces_base.py
@@ -15,7 +15,8 @@
     from 
vllm.model_executor.sampling_metadata import SamplingMetadata # NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags -# for the base interfaces to avoid breaking OOT registration +# for the base interfaces to avoid breaking OOT registration for existing models +# that don't inherit from the base interface classes # The type of hidden states # Currently, T = torch.Tensor for all models except for Medusa From ca44d814cf88464932368a8f2a394e8a50dc87f4 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 7 Oct 2024 03:21:32 +0000 Subject: [PATCH 4/8] Improve interface checks --- tests/models/test_registry.py | 8 +- vllm/model_executor/models/__init__.py | 8 +- vllm/model_executor/models/interfaces.py | 28 ++--- vllm/model_executor/models/interfaces_base.py | 105 ++++++++++++------ vllm/model_executor/models/registry.py | 6 +- vllm/utils.py | 9 ++ 6 files changed, 101 insertions(+), 63 deletions(-) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 638452be700b..2a09f826cc00 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -3,9 +3,9 @@ import pytest import torch.cuda -from vllm.model_executor.models import (supports_embedding, +from vllm.model_executor.models import (is_embedding_model, supports_multimodal, - supports_text_generation) + is_text_generation_model) from vllm.model_executor.models.registry import (_EMBEDDING_MODELS, _MULTIMODAL_MODELS, _SPECULATIVE_DECODING_MODELS, @@ -24,11 +24,11 @@ def test_registry_imports(model_arch): if model_arch in _SPECULATIVE_DECODING_MODELS: pass # Ignore these models which do not have a unified format else: - assert supports_text_generation(model_cls) is ( + assert is_text_generation_model(model_cls) is ( model_arch in _TEXT_GENERATION_MODELS or model_arch in _MULTIMODAL_MODELS) - assert supports_embedding(model_cls) is (model_arch + assert is_embedding_model(model_cls) is (model_arch in _EMBEDDING_MODELS) assert supports_multimodal(model_cls) is (model_arch diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 68969f41e344..eaa2b93eb333 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -2,16 +2,16 @@ SupportsPP, has_inner_state, supports_lora, supports_multimodal, supports_pp) from .interfaces_base import (VllmModelForEmbedding, - VllmModelForTextGeneration, supports_embedding, - supports_text_generation) + VllmModelForTextGeneration, is_embedding_model, + is_text_generation_model) from .registry import ModelRegistry __all__ = [ "ModelRegistry", "VllmModelForEmbedding", - "supports_embedding", + "is_embedding_model", "VllmModelForTextGeneration", - "supports_text_generation", + "is_text_generation_model", "HasInnerState", "has_inner_state", "SupportsLoRA", diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 298174fa0596..278dfc52078e 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,4 +1,3 @@ -import inspect from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, Protocol, Type, Union, overload, runtime_checkable) @@ -6,9 +5,9 @@ from typing_extensions import TypeIs from vllm.logger import init_logger +from vllm.utils import supports_kw if TYPE_CHECKING: - from vllm.attention import AttentionMetadata from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.sequence import IntermediateTensors @@ -142,9 +141,7 @@ def supports_lora( 
return result -def _supports_lora( - model: Union[Type[object], object], -) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]: +def _supports_lora(model: Union[Type[object], object]) -> bool: if isinstance(model, type): return isinstance(model, _SupportsLoRAType) @@ -175,10 +172,7 @@ def make_empty_intermediate_tensors( def forward( self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: "AttentionMetadata", + *, intermediate_tensors: Optional["IntermediateTensors"], ) -> Union[torch.Tensor, "IntermediateTensors"]: """ @@ -205,10 +199,7 @@ def make_empty_intermediate_tensors( def forward( self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: "AttentionMetadata", + *, intermediate_tensors: Optional["IntermediateTensors"], ) -> Union[torch.Tensor, "IntermediateTensors"]: ... @@ -257,24 +248,19 @@ def supports_pp( return supports_attributes and supports_inspect -def _supports_pp_attributes( - model: Union[Type[object], object], -) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: +def _supports_pp_attributes(model: Union[Type[object], object]) -> bool: if isinstance(model, type): return isinstance(model, _SupportsPPType) return isinstance(model, SupportsPP) -def _supports_pp_inspect( - model: Union[Type[object], object], -) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: +def _supports_pp_inspect(model: Union[Type[object], object]) -> bool: model_forward = getattr(model, "forward", None) if not callable(model_forward): return False - forward_params = inspect.signature(model_forward).parameters - return "intermediate_tensors" in forward_params + return supports_kw(model_forward, "intermediate_tensors") @runtime_checkable diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 6577488f1235..4d2a42d73903 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -2,9 +2,13 @@ overload, runtime_checkable) import torch +import torch.nn as nn from transformers import PretrainedConfig from typing_extensions import TypeIs, TypeVar +from vllm.logger import init_logger +from vllm.utils import supports_kw + if TYPE_CHECKING: from vllm.attention import AttentionMetadata from vllm.config import CacheConfig @@ -14,22 +18,28 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags -# for the base interfaces to avoid breaking OOT registration for existing models -# that don't inherit from the base interface classes +logger = init_logger(__name__) + +# The type of HF config +C_co = TypeVar("C_co", bound=PretrainedConfig, covariant=True) # The type of hidden states # Currently, T = torch.Tensor for all models except for Medusa # which has T = List[torch.Tensor] T = TypeVar("T", default=torch.Tensor) +T_co = TypeVar("T_co", default=torch.Tensor, covariant=True) + +# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags +# for the base interfaces to avoid breaking OOT registration for existing models +# that don't inherit from the base interface classes @runtime_checkable -class VllmModelForTextGeneration(Protocol[T]): +class VllmModel(Protocol[C_co, T_co]): def __init__( self, - config: PretrainedConfig, + config: C_co, *, cache_config: Optional["CacheConfig"], quant_config: 
Optional["QuantizationConfig"], @@ -42,9 +52,54 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: "AttentionMetadata", - ) -> T: + ) -> T_co: ... + +def _check_vllm_model_init(model: Union[Type[object], object]) -> bool: + model_init = model.__init__ + vllm_kws = ("cache_config", "quant_config") + missing_kws = tuple(kw for kw in vllm_kws + if not supports_kw(model_init, kw)) + + if missing_kws and isinstance(model, nn.Module): + logger.warning( + "The model (%s) is missing " + "vLLM-specific keywords from its initializer: %s", + model, + missing_kws, + ) + + return len(missing_kws) == 0 + + +def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: + model_forward = getattr(model, "forward", None) + if not callable(model_forward): + return False + + vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata") + missing_kws = tuple(kw for kw in vllm_kws + if not supports_kw(model_forward, kw)) + + if missing_kws and isinstance(model, nn.Module): + logger.warning( + "The model (%s) is missing " + "vLLM-specific keywords from its initializer: %s", + model, + missing_kws, + ) + + return len(missing_kws) == 0 + + +def is_vllm_model(model: Union[Type[object], object]) -> bool: + return _check_vllm_model_init(model) and _check_vllm_model_forward(model) + + +@runtime_checkable +class VllmModelForTextGeneration(VllmModel[C_co, T], Protocol[C_co, T]): + def compute_logits( self, hidden_states: T, @@ -63,21 +118,24 @@ def sample( @overload -def supports_text_generation( +def is_text_generation_model( model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]: ... @overload -def supports_text_generation( +def is_text_generation_model( model: object) -> TypeIs[VllmModelForTextGeneration]: ... -def supports_text_generation( +def is_text_generation_model( model: Union[Type[object], object], ) -> Union[TypeIs[Type[VllmModelForTextGeneration]], TypeIs[VllmModelForTextGeneration]]: + if not is_vllm_model(model): + return False + if isinstance(model, type): return isinstance(model, VllmModelForTextGeneration) @@ -85,25 +143,7 @@ def supports_text_generation( @runtime_checkable -class VllmModelForEmbedding(Protocol[T]): - - def __init__( - self, - config: PretrainedConfig, - *, - cache_config: Optional["CacheConfig"], - quant_config: Optional["QuantizationConfig"], - ) -> None: - ... - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: "AttentionMetadata", - ) -> T: - ... +class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]): def pooler( self, @@ -115,19 +155,22 @@ def pooler( @overload -def supports_embedding( +def is_embedding_model( model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]: ... @overload -def supports_embedding(model: object) -> TypeIs[VllmModelForEmbedding]: +def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]: ... 
-def supports_embedding( +def is_embedding_model( model: Union[Type[object], object], ) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]: + if not is_vllm_model(model): + return False + if isinstance(model, type): return isinstance(model, VllmModelForEmbedding) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index d9629f56c7db..46c69f17f447 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -12,7 +12,7 @@ from vllm.utils import is_hip from .interfaces import supports_multimodal, supports_pp -from .interfaces_base import supports_embedding, supports_text_generation +from .interfaces_base import is_embedding_model, is_text_generation_model logger = init_logger(__name__) @@ -328,7 +328,7 @@ def is_text_generation_model(architectures: Union[str, List[str]]) -> bool: logger.warning("No model architectures are specified") is_txt_gen = partial(ModelRegistry._check_stateless, - supports_text_generation, + is_text_generation_model, default=False) return any(is_txt_gen(arch) for arch in architectures) @@ -341,7 +341,7 @@ def is_embedding_model(architectures: Union[str, List[str]]) -> bool: logger.warning("No model architectures are specified") is_emb = partial(ModelRegistry._check_stateless, - supports_embedding, + is_embedding_model, default=False) return any(is_emb(arch) for arch in architectures) diff --git a/vllm/utils.py b/vllm/utils.py index e44365fa2499..bec2f951d69d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1277,6 +1277,15 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, return await task(*args, **kwargs) +def supports_kw(callable: Callable[..., object], kw_name: str) -> bool: + params = inspect.signature(callable).parameters + if kw_name in params: + return True + + return any(param.kind == inspect.Parameter.VAR_KEYWORD + for param in params.values()) + + def get_allowed_kwarg_only_overrides( callable: Callable[..., object], overrides: Optional[Dict[str, Any]], From 7ecd6f4a493fbc5b8f2292f642347fb1dd20f80a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 7 Oct 2024 03:22:27 +0000 Subject: [PATCH 5/8] format --- tests/models/test_registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 2a09f826cc00..a2194fa15f90 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -4,8 +4,8 @@ import torch.cuda from vllm.model_executor.models import (is_embedding_model, - supports_multimodal, - is_text_generation_model) + is_text_generation_model, + supports_multimodal) from vllm.model_executor.models.registry import (_EMBEDDING_MODELS, _MULTIMODAL_MODELS, _SPECULATIVE_DECODING_MODELS, From 179781a64efbba7361e8535230d5a9aab55878f5 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 7 Oct 2024 03:24:18 +0000 Subject: [PATCH 6/8] Add overloads --- vllm/model_executor/models/interfaces_base.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 4d2a42d73903..9c918125a9c3 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -93,7 +93,19 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: return len(missing_kws) == 0 -def is_vllm_model(model: Union[Type[object], object]) -> bool: +@overload +def is_vllm_model(model: 
Type[object]) -> TypeIs[Type[VllmModel]]:
+    ...
+
+
+@overload
+def is_vllm_model(model: object) -> TypeIs[VllmModel]:
+    ...
+
+
+def is_vllm_model(
+    model: Union[Type[object], object],
+) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]:
     return _check_vllm_model_init(model) and _check_vllm_model_forward(model)
 
 

From 7d88a10f5deb47bf95d1f8bbee2c2de3ce9be2a3 Mon Sep 17 00:00:00 2001
From: DarkLight1337 
Date: Mon, 7 Oct 2024 03:32:20 +0000
Subject: [PATCH 7/8] Support PP

---
 .../my_gemma_embedding.py                     | 27 ++++++++++++++++---
 1 file changed, 23 insertions(+), 4 deletions(-)

diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
index 0a7bbfdc23cf..1d61f6b74f52 100644
--- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
+++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
@@ -1,15 +1,34 @@
+from typing import List, Optional, Union
+
 import torch
 
+from vllm.attention import AttentionMetadata
 from vllm.model_executor.models.gemma2_embedding import Gemma2EmbeddingModel
+from vllm.sequence import IntermediateTensors
 
 
 class MyGemma2Embedding(Gemma2EmbeddingModel):
 
-    def forward(self, *args, **kwargs) -> torch.Tensor:
-        hidden_states = super().forward(*args, **kwargs)
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = super().forward(
+            input_ids,
+            positions,
+            kv_caches,
+            attn_metadata,
+            intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
 
-        # We assume PP isn't used in the test
-        assert isinstance(hidden_states, torch.Tensor)
+        if isinstance(hidden_states, IntermediateTensors):
+            return hidden_states
 
         # Return all-zero embeddings
         return torch.zeros_like(hidden_states)

From 62bbd7a70de7f72cce08025fac6c1bb478d4e092 Mon Sep 17 00:00:00 2001
From: DarkLight1337 
Date: Mon, 7 Oct 2024 04:03:51 +0000
Subject: [PATCH 8/8] Fix interface checks

---
 vllm/model_executor/models/interfaces_base.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py
index 9c918125a9c3..8d2d422f9891 100644
--- a/vllm/model_executor/models/interfaces_base.py
+++ b/vllm/model_executor/models/interfaces_base.py
@@ -62,7 +62,8 @@ def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
     missing_kws = tuple(kw for kw in vllm_kws
                         if not supports_kw(model_init, kw))
 
-    if missing_kws and isinstance(model, nn.Module):
+    if missing_kws and (isinstance(model, type)
+                        and issubclass(model, nn.Module)):
         logger.warning(
             "The model (%s) is missing "
             "vLLM-specific keywords from its initializer: %s",
@@ -82,10 +83,11 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
     missing_kws = tuple(kw for kw in vllm_kws
                         if not supports_kw(model_forward, kw))
 
-    if missing_kws and isinstance(model, nn.Module):
+    if missing_kws and (isinstance(model, type)
+                        and issubclass(model, nn.Module)):
         logger.warning(
             "The model (%s) is missing "
-            "vLLM-specific keywords from its initializer: %s",
+            "vLLM-specific keywords from its forward method: %s",
             model,
             missing_kws,
         )
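
For reference, the following sketch (not part of the patch series; ToyEmbeddingModel and its attributes are hypothetical) illustrates what the checks introduced above accept once all eight patches are applied: is_vllm_model inspects the __init__ and forward signatures via supports_kw, and the per-task helpers then match the runtime-checkable protocols.

# Sketch only: ToyEmbeddingModel is illustrative and is never registered.
import torch
import torch.nn as nn

from vllm.model_executor.models.interfaces_base import (
    is_embedding_model, is_text_generation_model, is_vllm_model)


class ToyEmbeddingModel(nn.Module):

    def __init__(self, config, *, cache_config=None, quant_config=None):
        super().__init__()
        self.config = config

    def forward(self, input_ids, positions, kv_caches, attn_metadata):
        # A real model would run its transformer stack here.
        return torch.zeros(input_ids.shape[0], 8)

    def pooler(self, hidden_states, pooling_metadata):
        # A real model would pool per-sequence hidden states here.
        raise NotImplementedError


assert is_vllm_model(ToyEmbeddingModel)  # __init__/forward keywords match
assert is_embedding_model(ToyEmbeddingModel)  # pooler() satisfies the protocol
assert not is_text_generation_model(ToyEmbeddingModel)  # no compute_logits/sample

The registry-level helpers build on the same functions: ModelRegistry.is_embedding_model(architectures) resolves each architecture and applies is_embedding_model through _check_stateless in a separate process, so lazily registered out-of-tree models such as MyGemma2Embedding above are classified without importing them into the current process.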