[Model] Explicit interface for vLLM models and support OOT embedding models (#9108)
DarkLight1337 authored Oct 7, 2024
1 parent 18b296f commit 8c6de96
Showing 10 changed files with 342 additions and 37 deletions.
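
The end-to-end flow these changes enable — activate an out-of-tree plugin, let it register an embedding model with the registry, then serve it through the usual LLM API — is condensed below from the tests in this commit. This is an illustrative sketch: the plugin name and load format come from the diff, while the model path stands in for the fixture-managed directory created in tests/conftest.py.

import os

from vllm import LLM, PoolingParams

# Hypothetical local path; the tests build it via the dummy_gemma2_embedding_path
# fixture added to tests/conftest.py below.
dummy_model_path = "/path/to/dummy_gemma2_embedding"

# Activate the out-of-tree test plugin so its register() hook runs,
# then load the dummy embedding model it registers.
os.environ["VLLM_PLUGINS"] = "register_dummy_model"
llm = LLM(model=dummy_model_path, load_format="dummy")
outputs = llm.encode(["Hello, my name is"], PoolingParams())
print(outputs[0].outputs.embedding[:4])  # all zeros for the dummy model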
20 changes: 20 additions & 0 deletions tests/conftest.py
@@ -871,6 +871,7 @@ def num_gpus_available():
temp_dir = tempfile.gettempdir()
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")


@pytest.fixture
@@ -909,3 +910,22 @@ def dummy_llava_path():
with open(json_path, "w") as f:
json.dump(config, f)
return _dummy_llava_path


@pytest.fixture
def dummy_gemma2_embedding_path():
json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
if not os.path.exists(_dummy_gemma2_embedding_path):
snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
local_dir=_dummy_gemma2_embedding_path,
ignore_patterns=[
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
"*.msgpack"
])
assert os.path.exists(json_path)
with open(json_path, "r") as f:
config = json.load(f)
config["architectures"] = ["MyGemma2Embedding"]
with open(json_path, "w") as f:
json.dump(config, f)
return _dummy_gemma2_embedding_path
18 changes: 15 additions & 3 deletions tests/models/test_oot_registration.py
@@ -2,7 +2,7 @@

import pytest

-from vllm import LLM, SamplingParams
+from vllm import LLM, PoolingParams, SamplingParams
from vllm.assets.image import ImageAsset

from ..utils import fork_new_process_for_each_test
@@ -17,7 +17,7 @@ def test_plugin(dummy_opt_path):


@fork_new_process_for_each_test
-def test_oot_registration(dummy_opt_path):
+def test_oot_registration_text_generation(dummy_opt_path):
os.environ["VLLM_PLUGINS"] = "register_dummy_model"
prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = SamplingParams(temperature=0)
@@ -32,11 +32,23 @@ def test_oot_registration(dummy_opt_path):
assert rest == ""


@fork_new_process_for_each_test
def test_oot_registration_embedding(dummy_gemma2_embedding_path):
os.environ["VLLM_PLUGINS"] = "register_dummy_model"
prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = PoolingParams()
llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
outputs = llm.encode(prompts, sampling_params)

for output in outputs:
assert all(v == 0 for v in output.outputs.embedding)


image = ImageAsset("cherry_blossom").pil_image.convert("RGB")


@fork_new_process_for_each_test
-def test_oot_multimodal_registration(dummy_llava_path):
+def test_oot_registration_multimodal(dummy_llava_path):
os.environ["VLLM_PLUGINS"] = "register_dummy_model"
prompts = [{
"prompt": "What's in the image?<image>",
24 changes: 22 additions & 2 deletions tests/models/test_registry.py
@@ -3,7 +3,14 @@
import pytest
import torch.cuda

-from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.models import (is_embedding_model,
+                                         is_text_generation_model,
+                                         supports_multimodal)
+from vllm.model_executor.models.registry import (_EMBEDDING_MODELS,
+                                                  _MULTIMODAL_MODELS,
+                                                  _SPECULATIVE_DECODING_MODELS,
+                                                  _TEXT_GENERATION_MODELS,
+                                                  ModelRegistry)
from vllm.platforms import current_platform

from ..utils import fork_new_process_for_each_test
@@ -12,7 +19,20 @@
@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch):
# Ensure all model classes can be imported successfully
-    ModelRegistry.resolve_model_cls(model_arch)
+    model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)

if model_arch in _SPECULATIVE_DECODING_MODELS:
pass # Ignore these models which do not have a unified format
else:
assert is_text_generation_model(model_cls) is (
model_arch in _TEXT_GENERATION_MODELS
or model_arch in _MULTIMODAL_MODELS)

assert is_embedding_model(model_cls) is (model_arch
in _EMBEDDING_MODELS)

assert supports_multimodal(model_cls) is (model_arch
in _MULTIMODAL_MODELS)


@fork_new_process_for_each_test
@@ -9,6 +9,12 @@ def register():
ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)

# Test passing lazy model
if "MyGemma2Embedding" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model(
"MyGemma2Embedding",
"vllm_add_dummy_model.my_gemma_embedding:MyGemma2Embedding",
)

if "MyLlava" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model("MyLlava",
"vllm_add_dummy_model.my_llava:MyLlava")
@@ -0,0 +1,34 @@
from typing import List, Optional, Union

import torch

from vllm.attention import AttentionMetadata
from vllm.model_executor.models.gemma2_embedding import Gemma2EmbeddingModel
from vllm.sequence import IntermediateTensors


class MyGemma2Embedding(Gemma2EmbeddingModel):

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = super().forward(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)

if isinstance(hidden_states, IntermediateTensors):
return hidden_states

# Return all-zero embeddings
return torch.zeros_like(hidden_states)
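
Because MyGemma2Embedding subclasses Gemma2EmbeddingModel, the new interface checks should classify it as an embedding model rather than a text-generation model. A small illustrative check, assuming the parent class satisfies the embedding interface introduced in interfaces_base.py:

from vllm.model_executor.models import (is_embedding_model,
                                         is_text_generation_model)
from vllm_add_dummy_model.my_gemma_embedding import MyGemma2Embedding

# Assumption: Gemma2EmbeddingModel implements VllmModelForEmbedding but not
# VllmModelForTextGeneration, so the dummy subclass is classified the same way.
assert is_embedding_model(MyGemma2Embedding)
assert not is_text_generation_model(MyGemma2Embedding)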
7 changes: 7 additions & 0 deletions vllm/model_executor/models/__init__.py
@@ -1,10 +1,17 @@
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
SupportsPP, has_inner_state, supports_lora,
supports_multimodal, supports_pp)
from .interfaces_base import (VllmModelForEmbedding,
VllmModelForTextGeneration, is_embedding_model,
is_text_generation_model)
from .registry import ModelRegistry

__all__ = [
"ModelRegistry",
"VllmModelForEmbedding",
"is_embedding_model",
"VllmModelForTextGeneration",
"is_text_generation_model",
"HasInnerState",
"has_inner_state",
"SupportsLoRA",
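
With these names exported from vllm.model_executor.models, downstream code can query a model class's capabilities without importing private registry tables. A small illustrative use, mirroring the assertions in tests/models/test_registry.py above (OPTForCausalLM is used here only as an example architecture):

from vllm.model_executor.models import (ModelRegistry, is_embedding_model,
                                         is_text_generation_model)

# Resolve a registered architecture to its implementing class, then ask
# which of the new explicit interfaces it satisfies.
model_cls, _ = ModelRegistry.resolve_model_cls("OPTForCausalLM")
assert is_text_generation_model(model_cls)
assert not is_embedding_model(model_cls)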
28 changes: 7 additions & 21 deletions vllm/model_executor/models/interfaces.py
@@ -1,14 +1,13 @@
import inspect
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Protocol, Type, Union, overload, runtime_checkable)

import torch
from typing_extensions import TypeIs

from vllm.logger import init_logger
from vllm.utils import supports_kw

if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.sequence import IntermediateTensors

@@ -142,9 +141,7 @@ def supports_lora(
return result


-def _supports_lora(
-    model: Union[Type[object], object],
-) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
+def _supports_lora(model: Union[Type[object], object]) -> bool:
if isinstance(model, type):
return isinstance(model, _SupportsLoRAType)

@@ -175,10 +172,7 @@ def make_empty_intermediate_tensors(

def forward(
self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: "AttentionMetadata",
+        *,
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
"""
@@ -205,10 +199,7 @@ def make_empty_intermediate_tensors(

def forward(
self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: "AttentionMetadata",
+        *,
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
...
@@ -257,24 +248,19 @@ def supports_pp(
return supports_attributes and supports_inspect


-def _supports_pp_attributes(
-    model: Union[Type[object], object],
-) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
+def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
if isinstance(model, type):
return isinstance(model, _SupportsPPType)

return isinstance(model, SupportsPP)


-def _supports_pp_inspect(
-    model: Union[Type[object], object],
-) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
+def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
model_forward = getattr(model, "forward", None)
if not callable(model_forward):
return False

-    forward_params = inspect.signature(model_forward).parameters
-    return "intermediate_tensors" in forward_params
+    return supports_kw(model_forward, "intermediate_tensors")


@runtime_checkable
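
The _supports_pp_inspect simplification above delegates the signature check to supports_kw from vllm.utils (imported at the top of this file). A rough, self-contained equivalent is sketched below purely to illustrate the idea; vLLM's actual helper may differ in edge cases.

import inspect


def supports_kw_sketch(callable_obj, kw_name: str) -> bool:
    # Rough stand-in for vllm.utils.supports_kw: report whether a callable
    # names the keyword explicitly or accepts arbitrary **kwargs.
    try:
        params = inspect.signature(callable_obj).parameters
    except (TypeError, ValueError):
        return False
    if kw_name in params:
        return True
    return any(p.kind is inspect.Parameter.VAR_KEYWORD
               for p in params.values())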
(Remaining file diffs not shown.)
