Skip to content

Commit

Permalink
ExternalMultiModalDataDict
Browse files Browse the repository at this point in the history
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
  • Loading branch information
xwjiang2010 committed Jun 28, 2024
1 parent f84b793 commit a934663
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 17 deletions.
6 changes: 3 additions & 3 deletions vllm/inputs/data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import (TYPE_CHECKING, Dict, List, Literal, Optional, Sequence,
from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
TypedDict, Union, cast, overload)

from typing_extensions import NotRequired

if TYPE_CHECKING:
from vllm.multimodal import EXTERNAL_MM_DATA_TYPE, MultiModalData
from vllm.multimodal import ExternalMultiModalDataDict, MultiModalData


class ParsedText(TypedDict):
Expand Down Expand Up @@ -136,7 +136,7 @@ class LLMInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available.
"""

multi_modal_data: NotRequired[Optional[Dict[str, "EXTERNAL_MM_DATA_TYPE"]]]
multi_modal_data: NotRequired[Optional["ExternalMultiModalDataDict"]]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
Expand Down
4 changes: 2 additions & 2 deletions vllm/multimodal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import EXTERNAL_MM_DATA_TYPE, MultiModalData, MultiModalPlugin
from .base import ExternalMultiModalDataDict, MultiModalData, MultiModalPlugin
from .registry import MultiModalRegistry

MULTIMODAL_REGISTRY = MultiModalRegistry()
Expand All @@ -15,5 +15,5 @@
"MultiModalPlugin",
"MULTIMODAL_REGISTRY",
"MultiModalRegistry",
"EXTERNAL_MM_DATA_TYPE",
"ExternalMultiModalDataDict",
]
18 changes: 13 additions & 5 deletions vllm/multimodal/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Tuple,
Type, TypeVar, Union)
from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, Optional,
Tuple, Type, TypedDict, TypeVar, Union)

from vllm.config import ModelConfig
from vllm.inputs import InputContext
Expand All @@ -18,6 +18,8 @@ class MultiModalData:
"""
Base class that contains multi-modal data.
This is for internal use.
To add a new modality, add a new file under ``multimodal`` directory.
In this new file, subclass :class:`~MultiModalData` and
Expand All @@ -34,7 +36,14 @@ class MultiModalData:
D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])

EXTERNAL_MM_DATA_TYPE = Union["Image.Image", "torch.Tensor"]

class ExternalMultiModalDataBuiltins(TypedDict, total=False):
image: Union["Image.Image", "torch.Tensor"]


ExternalMultiModalDataDict = Union[ExternalMultiModalDataBuiltins, Dict[str,
Any]]

MultiModalInputMapper = Callable[[InputContext, D], Dict[str, "torch.Tensor"]]
"""Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
Expand Down Expand Up @@ -65,8 +74,7 @@ def get_internal_data_type(self) -> Type[D]:
raise NotImplementedError

@abstractmethod
def get_external_data_type(
self) -> Tuple[str, Type[EXTERNAL_MM_DATA_TYPE]]:
def get_external_data_type(self) -> Tuple[str, Type[Any]]:
"""The data type that this plugin handles.
For `LLM.generate(multi_modal_data={"key": value})` will
Expand Down
7 changes: 3 additions & 4 deletions vllm/multimodal/registry.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import functools
from typing import Any, Dict, Optional, Sequence, Type, TypeVar, Union
from typing import Any, Optional, Sequence, Type, TypeVar, Union

from torch import nn

from vllm.config import ModelConfig
from vllm.logger import init_logger

from .base import (EXTERNAL_MM_DATA_TYPE, MultiModalData,
from .base import (ExternalMultiModalDataDict, MultiModalData,
MultiModalInputMapper, MultiModalPlugin)
from .image import ImageData, ImagePlugin

Expand Down Expand Up @@ -113,8 +113,7 @@ def register_image_input(self,
return self.register_input_mapper(ImageData, mapper)

def map_input(self, model_config: ModelConfig,
data: Union[MultiModalData, Dict[str,
EXTERNAL_MM_DATA_TYPE]]):
data: Union[MultiModalData, ExternalMultiModalDataDict]):
"""
Apply an input mapper to a :class:`~MultiModalData` instance passed
to the model.
Expand Down
6 changes: 3 additions & 3 deletions vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

if TYPE_CHECKING:
from vllm.inputs import LLMInputs
from vllm.multimodal import EXTERNAL_MM_DATA_TYPE, MultiModalData
from vllm.multimodal import ExternalMultiModalDataDict, MultiModalData
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics


Expand Down Expand Up @@ -258,7 +258,7 @@ def prompt_token_ids(self) -> List[int]:
return self.inputs["prompt_token_ids"]

@property
def multi_modal_data(self) -> Dict[str, "EXTERNAL_MM_DATA_TYPE"]:
def multi_modal_data(self) -> "ExternalMultiModalDataDict":
return self.inputs.get("multi_modal_data") or {}

@property
Expand Down Expand Up @@ -617,7 +617,7 @@ def __init__(
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
multi_modal_data: Optional[Dict[str, "EXTERNAL_MM_DATA_TYPE"]] = None,
multi_modal_data: Optional["ExternalMultiModalDataDict"] = None,
encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None,
) -> None:
Expand Down

0 comments on commit a934663

Please sign in to comment.