diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 8e4be08f3aba..9a20a8df07a9 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -23,7 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" # Run basic model test docker exec cpu-test bash -c " pip install pytest matplotlib einops transformers_stream_generator - pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported + pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_xlmroberta.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported # online inference docker exec cpu-test bash -c " diff --git a/examples/hf_bge.py b/examples/hf_bge.py new file mode 100644 index 000000000000..0711b78feb3e --- /dev/null +++ b/examples/hf_bge.py @@ -0,0 +1,33 @@ +from typing import List, Tuple, Union + +import torch +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +model_name_or_path = "BAAI/bge-reranker-base" +cache_dir = None +max_length = 512 + +sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]] = \ + [("hello world", "nice to meet you"), ("head north", "head south")] +tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, + cache_dir=cache_dir) +# XLMRobertaForSequenceClassification +model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, + cache_dir=cache_dir) +model = model.to("cuda") +model.eval() + +inputs = tokenizer( + sentence_pairs, + padding=True, + truncation=True, + return_tensors='pt', + max_length=max_length, +).to("cuda") + +all_scores = [] +with torch.no_grad(): + logits = model(**inputs, return_dict=True).logits + scores = logits.view(-1, ).float() + all_scores.extend(scores.cpu().numpy().tolist()) +print(all_scores) diff --git a/examples/offline_inference_xlmroberta.py b/examples/offline_inference_xlmroberta.py new file mode 100644 index 000000000000..93aea17729e9 --- /dev/null +++ b/examples/offline_inference_xlmroberta.py @@ -0,0 +1,31 @@ +from typing import List, Tuple, Union + +from transformers import AutoTokenizer + +from vllm import LLM + +model = "BAAI/bge-reranker-base" +llm = LLM(model=model, tensor_parallel_size=1) + +prompt = "this is a useless prompt." 
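+# NOTE: the prompt string above is only a placeholder so that a request can be
+# built; it is not used for scoring. The actual reranker inputs are the
+# tokenized sentence pairs passed below via multi_modal_data["xlmroberta"].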
+sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]] = \ + [("hello world", "nice to meet you"), ("head north", "head south")] +tokenizer = AutoTokenizer.from_pretrained(model, cache_dir=None) + +inputs = tokenizer( + sentence_pairs, + padding=True, + truncation=True, + return_tensors='pt', + max_length=512, +).to("cuda") +outputs = llm.process([{ + "prompt": prompt, + "multi_modal_data": { + "xlmroberta": inputs, + } +}], + use_tqdm=False) + +for output in outputs: + print(output.outputs.result) diff --git a/tests/conftest.py b/tests/conftest.py index 6e033e76964b..08ea80bc3e50 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,8 +6,8 @@ import tempfile from collections import UserList from enum import Enum -from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict, - TypeVar, Union) +from typing import (Any, Callable, Dict, List, Optional, Sequence, Tuple, + TypedDict, TypeVar, Union) import pytest import torch @@ -25,8 +25,9 @@ from vllm.connections import global_http_connection from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) -from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, - to_enc_dec_tuple_list, zip_enc_dec_prompts) +from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs, + TextPrompt, to_enc_dec_tuple_list, + zip_enc_dec_prompts) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sequence import SampleLogprobs @@ -201,6 +202,7 @@ def __init__( is_embedding_model: bool = False, is_vision_model: bool = False, is_encoder_decoder_model: bool = False, + is_simple_model: bool = False, postprocess_inputs: Callable[[BatchEncoding], BatchEncoding] = identity, ) -> None: @@ -221,6 +223,9 @@ def __init__( auto_cls = AutoModelForVision2Seq elif is_encoder_decoder_model: auto_cls = AutoModelForSeq2SeqLM + elif is_simple_model: + from transformers import AutoModelForSequenceClassification + auto_cls = AutoModelForSequenceClassification else: auto_cls = AutoModelForCausalLM @@ -513,6 +518,17 @@ def generate_encoder_decoder_greedy_logprobs_limit( def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: return self.model.encode(prompts) + def process( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + with torch.no_grad(): + req_outputs = self.model(input_ids, + attention_mask, + return_dict=True) + return req_outputs + def __enter__(self): return self @@ -711,6 +727,14 @@ def encode(self, prompts: List[str]) -> List[List[float]]: outputs.append(embedding) return outputs + def process( + self, + prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], + Optional[Union[str, List[str]]]] = None, + ) -> torch.Tensor: + req_outputs = self.model.process(prompts) + return req_outputs + def __enter__(self): return self diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 3783b7cd66a6..2f44f432fdf4 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -19,7 +19,8 @@ class MockModelConfig: tokenizer_mode = "auto" max_model_len = 100 tokenizer_revision = None - embedding_mode = False + # refer vllm.model_executor.models.ModelMode + model_mode = False @dataclass diff --git a/tests/models/test_xlmroberta.py b/tests/models/test_xlmroberta.py new file mode 100644 index 000000000000..e601605c8f40 --- /dev/null +++ b/tests/models/test_xlmroberta.py @@ -0,0 +1,65 @@ 
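+# Compares the scores returned by vLLM's SIMPLE-mode path (LLM.process with
+# "xlmroberta" multi-modal inputs) against the logits produced by the
+# HuggingFace XLMRobertaForSequenceClassification reference for the same
+# tokenized sentence pairs.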
+from typing import List, Optional, Tuple, Type, Union + +import pytest +import torch +from transformers import AutoTokenizer + +from ..conftest import HfRunner, VllmRunner + +models = ["BAAI/bge-reranker-base"] + + +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + model: str, + *, + dtype: str, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm.""" + + prompt = "this is a useless prompt." + sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]] = \ + [("hello world", "nice to meet you"), ("head north", "head south")] + tokenizer = AutoTokenizer.from_pretrained(model, cache_dir=None) + inputs = tokenizer( + sentence_pairs, + padding=True, + truncation=True, + return_tensors='pt', + max_length=512, + ).to("cuda") + + with vllm_runner(model, + dtype=dtype, + max_model_len=512, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + vllm_outputs = vllm_model.process([{ + "prompt": prompt, + "multi_modal_data": { + "xlmroberta": inputs, + } + }]) + + with hf_runner(model, dtype=dtype, is_simple_model=True) as hf_model: + hf_outputs = hf_model.process(**inputs) + + print(vllm_outputs[0].outputs.result, hf_outputs.logits.view(-1, )) + assert torch.allclose(vllm_outputs[0].outputs.result, + hf_outputs.logits.view(-1, )) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models(hf_runner, vllm_runner, model, dtype: str) -> None: + run_test( + hf_runner, + vllm_runner, + model, + dtype=dtype, + tensor_parallel_size=1, + ) diff --git a/vllm/config.py b/vllm/config.py index a5a9984a0114..7b98c6dd7579 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -10,7 +10,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models import ModelMode, ModelRegistry from vllm.platforms import current_platform from vllm.tracing import is_otel_installed from vllm.transformers_utils.config import get_config, get_hf_text_config @@ -167,6 +167,8 @@ def __init__( code_revision, rope_scaling, rope_theta) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + architectures = getattr(self.hf_config, "architectures", []) + self.model_mode = ModelRegistry.get_model_mode(architectures) # Choose a default enforce_eager value if the user did not specify # a value (enforce_eager is None) @@ -217,7 +219,6 @@ def __init__( limit_mm_per_prompt) if not self.skip_tokenizer_init: self._verify_tokenizer_mode() - self._verify_embedding_mode() self._verify_quantization() self._verify_cuda_graph() @@ -244,11 +245,6 @@ def _verify_tokenizer_mode(self) -> None: "either 'auto' or 'slow'.") self.tokenizer_mode = tokenizer_mode - def _verify_embedding_mode(self) -> None: - architectures = getattr(self.hf_config, "architectures", []) - self.embedding_mode = any( - ModelRegistry.is_embedding_model(arch) for arch in architectures) - def _parse_quant_hf_config(self): quant_cfg = getattr(self.hf_config, "quantization_config", None) if quant_cfg is None: @@ -496,16 +492,6 @@ def get_multimodal_config(self) -> "MultiModalConfig": return self.multimodal_config - @property - def is_encoder_decoder_model(self) -> bool: - """Extract the HF encoder/decoder 
model flag.""" - return getattr(self.hf_config, "is_encoder_decoder", False) - - @property - def is_embedding_model(self) -> bool: - """Extract the embedding model flag.""" - return self.embedding_mode - class CacheConfig: """Configuration for the KV cache. @@ -860,7 +846,8 @@ class SchedulerConfig: prompt latency) before scheduling next prompt. enable_chunked_prefill: If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. - embedding_mode: Whether the running model is for embedding. + model_mode: one of [DECODER, ENCODER, ENCODER_DECODER, EMBEDDING, + SIMPLE] preemption_mode: Whether to perform preemption by swapping or recomputation. If not specified, we determine the mode as follows: We use recomputation by default since it incurs lower overhead than @@ -882,7 +869,7 @@ def __init__(self, num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, - embedding_mode: Optional[bool] = False, + model_mode: ModelMode = ModelMode.DECODER, preemption_mode: Optional[str] = None, num_scheduler_steps: int = 1, send_delta_data: bool = False) -> None: @@ -893,14 +880,19 @@ def __init__(self, # It is the values that have the best balance between ITL # and TTFT on A100. Note it is not optimized for throughput. self.max_num_batched_tokens = 512 - elif embedding_mode: - # For embedding, choose specific value for higher throughput - self.max_num_batched_tokens = max( - max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS) else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. - self.max_num_batched_tokens = max(max_model_len, 2048) + max_num_batched_tokens = max(max_model_len, 2048) + max_num_batched_tokens_for_mode = \ + ModelMode.get_model_max_num_batched_tokens(model_mode) + if max_num_batched_tokens_for_mode is not None: + max_num_batched_tokens = max( + max_num_batched_tokens, + max_num_batched_tokens_for_mode) + + self.max_num_batched_tokens = max_num_batched_tokens + if enable_chunked_prefill: logger.info( "Chunked prefill is enabled with max_num_batched_tokens=%d.", @@ -912,7 +904,7 @@ def __init__(self, self.num_lookahead_slots = num_lookahead_slots self.delay_factor = delay_factor self.chunked_prefill_enabled = enable_chunked_prefill - self.embedding_mode = embedding_mode + self.model_mode = model_mode self.preemption_mode = preemption_mode self.num_scheduler_steps = num_scheduler_steps self.send_delta_data = send_delta_data diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 8759ee06795b..e966754f52dc 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -23,25 +23,6 @@ class AllocStatus(enum.Enum): class BlockSpaceManager(ABC): - @staticmethod - def get_block_space_manager_class(version: str): - version = version.lower() - - if version == "v1": - from vllm.core.block_manager_v1 import BlockSpaceManagerV1 - return BlockSpaceManagerV1 - - if version == "v2": - from vllm.core.block_manager_v2 import BlockSpaceManagerV2 - return BlockSpaceManagerV2 - - if version == "embedding": - from vllm.core.embedding_model_block_manager import ( - EmbeddingModelBlockSpaceManager) - return EmbeddingModelBlockSpaceManager - - raise ValueError(f"Unknown version {version=}") - @abstractmethod def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: pass diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 802359d2283f..e8e0ef593124 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -6,8 +6,8 @@ from dataclasses import dataclass, 
field from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig -from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.config import CacheConfig, LoRAConfig, ModelMode, SchedulerConfig +from vllm.core.interfaces import AllocStatus from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest @@ -307,14 +307,9 @@ def __init__( # LoRAs. This should be improved in the future. self.lora_config = lora_config - version = "v1" - if self.scheduler_config.use_v2_block_manager: - version = "v2" - if self.scheduler_config.embedding_mode: - version = "embedding" - - BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( - version) + BlockSpaceManagerImpl = ModelMode.get_block_space_manager_impl( + self.scheduler_config.use_v2_block_manager, + self.scheduler_config.model_mode) num_gpu_blocks = cache_config.num_gpu_blocks if num_gpu_blocks: diff --git a/vllm/core/simple_model_block_manager.py b/vllm/core/simple_model_block_manager.py new file mode 100644 index 000000000000..288034448e33 --- /dev/null +++ b/vllm/core/simple_model_block_manager.py @@ -0,0 +1,83 @@ +from typing import List, Tuple + +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.sequence import Sequence, SequenceGroup + + +class SimpleModelBlockSpaceManager(BlockSpaceManager): + """A simplified version of BlockSpaceManager for models (other than + decoder and embedding models) that do not require block management. + + This class provides the same interface as BlockSpaceManager, but its + methods either perform no action or return trivial values such as True. + It is designed for scenarios where the overhead of block management is + unnecessary, such as when serving simple sequence-classification models.
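+    No block tables are maintained: allocation and swap checks always report
+    success, and the free-block counters simply return a constant.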
+ """ + + def __init__( + self, + **kwargs, + ) -> None: + pass + + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + # Always return OK for dummy purposes + return AllocStatus.OK + + def allocate(self, seq_group: SequenceGroup) -> None: + # No actual allocation logic needed + pass + + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + return True + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + return None # type: ignore + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + pass + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + return AllocStatus.OK + + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + return True + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def free(self, seq: Sequence) -> None: + # No operation on free + return + + def get_block_table(self, seq: Sequence) -> List[int]: + return None # type: ignore + + def get_num_free_gpu_blocks(self) -> int: + return 1 + + def get_num_free_cpu_blocks(self) -> int: + return 1 + + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + pass + + def get_common_computed_block_ids(self, + seq_group: SequenceGroup) -> List[int]: + return None # type: ignore + + def mark_blocks_as_computed(self, seq_group: SequenceGroup): + pass diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8fca2cc04995..5e5c321ffdf8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -903,7 +903,7 @@ def create_engine_config(self, ) -> EngineConfig: num_lookahead_slots=num_lookahead_slots, delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, - embedding_mode=model_config.embedding_mode, + model_mode=model_config.model_mode, preemption_mode=self.preemption_mode, num_scheduler_steps=self.num_scheduler_steps, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fcf45a38b942..6eb1d614c427 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -10,7 +10,7 @@ import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, + ModelMode, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, @@ -29,16 +29,19 @@ from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models.utils import (is_embedding_model_config, + is_encoder_decoder_model_config, + is_simple_model_config) from vllm.multimodal import MultiModalDataDict from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, - RequestOutputFactory) + RequestOutputFactory, SimpleRequestOutput) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - PoolerOutput, SamplerOutput, Sequence, - SequenceGroup, 
SequenceGroupMetadata, - SequenceStatus) +from vllm.sequence import (ExecuteModelRequest, PoolerOutput, SamplerOutput, + Sequence, SequenceGroup, SequenceGroupMetadata, + SequenceStatus, SimpleOutput) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -67,7 +70,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: return config.to_diff_dict() -_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) +_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput, SimpleRequestOutput) PromptComponents = Tuple[Optional[str], List[int], Optional[MultiModalDataDict]] @@ -278,7 +281,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: observability_config=self.observability_config, ) - if not self.model_config.embedding_mode: + if ModelRegistry.need_initialize_kv_caches( + self.model_config.model_mode): self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. @@ -1160,9 +1164,7 @@ def has_unfinished_requests_for_virtual_engine( def _process_sequence_group_outputs( self, seq_group: SequenceGroup, - outputs: List[EmbeddingSequenceGroupOutput], ) -> None: - seq_group.embeddings = outputs[0].embeddings for seq in seq_group.get_seqs(): seq.status = SequenceStatus.FINISHED_STOPPED @@ -1171,11 +1173,13 @@ def _process_sequence_group_outputs( def _process_model_outputs( self, - output: GenericSequence[Union[SamplerOutput, PoolerOutput]], + output: GenericSequence[Union[SamplerOutput, PoolerOutput, + SimpleOutput]], scheduled_seq_groups: List[ScheduledSequenceGroup], ignored_seq_groups: List[SequenceGroup], seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + ) -> List[Union[RequestOutput, EmbeddingRequestOutput, + SimpleRequestOutput]]: """Apply the model output to the sequences in the scheduled seq groups. Returns RequestOutputs that can be returned to the client. @@ -1211,8 +1215,13 @@ def _process_model_outputs( else: seq_group.metrics.model_execute_time = ( o.model_execute_time) - if self.model_config.embedding_mode: - self._process_sequence_group_outputs(seq_group, outputs) + if self.model_config.model_mode is ModelMode.EMBEDDING: + seq_group.embeddings = outputs[0].embeddings + self._process_sequence_group_outputs(seq_group) + continue + if self.model_config.model_mode is ModelMode.SIMPLE: + seq_group.result = outputs[0].result + self._process_sequence_group_outputs(seq_group) continue self.output_processor.process_prompt_logprob(seq_group, outputs) @@ -1224,8 +1233,8 @@ def _process_model_outputs( scheduler.free_finished_seq_groups() # Create the outputs. - request_outputs: List[Union[RequestOutput, - EmbeddingRequestOutput]] = [] + request_outputs: List[Union[RequestOutput, EmbeddingRequestOutput, + SimpleRequestOutput]] = [] for scheduled_seq_group in scheduled_seq_groups: seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) @@ -1236,7 +1245,10 @@ def _process_model_outputs( request_outputs.append(request_output) return request_outputs - def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + def step( + self + ) -> List[Union[RequestOutput, EmbeddingRequestOutput, + SimpleRequestOutput]]: """Performs one decoding iteration and returns newly generated results. .. 
figure:: https://i.imgur.com/sv2HssD.png @@ -1617,7 +1629,10 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: metrics.model_execute_time) def is_encoder_decoder_model(self): - return self.model_config.is_encoder_decoder_model + return is_encoder_decoder_model_config(self.model_config) def is_embedding_model(self): - return self.model_config.is_embedding_model + return is_embedding_model_config(self.model_config) + + def is_simple_model(self): + return is_simple_model_config(self.model_config) diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py index 57cc33d91118..a9e61737fc16 100644 --- a/vllm/engine/output_processor/util.py +++ b/vllm/engine/output_processor/util.py @@ -2,11 +2,13 @@ from typing import Sequence as GenericSequence from typing import Union -from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput +from vllm.sequence import (PoolerOutput, SamplerOutput, SequenceGroupOutput, + SimpleOutput) def create_output_by_sequence_group( - outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]], + outputs: GenericSequence[Union[SamplerOutput, PoolerOutput, + SimpleOutput]], num_seq_groups: int) -> List[List[SequenceGroupOutput]]: """Helper method which transforms a 2d list organized by [step][sequence group] into [sequence group][step]. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ecd6dc64d343..1a338d70ceea 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -4,6 +4,7 @@ from tqdm.auto import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from vllm.config import ModelMode from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, @@ -16,7 +17,8 @@ from vllm.model_executor.guided_decoding import ( GuidedDecodingRequest, get_local_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, + SimpleRequestOutput) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -312,7 +314,9 @@ def generate( considered legacy and may be deprecated in the future. You should instead pass them via the ``inputs`` parameter. """ - if self.llm_engine.model_config.embedding_mode: + if self.llm_engine.model_config.model_mode not in [ + ModelMode.DECODER, ModelMode.ENCODER_DECODER + ]: raise ValueError( "LLM.generate() is only supported for (conditional) generation " "models (XForCausalLM, XForConditionalGeneration).") @@ -520,7 +524,8 @@ def encode( considered legacy and may be deprecated in the future. You should instead pass them via the ``inputs`` parameter. """ - if not self.llm_engine.model_config.embedding_mode: + + if self.llm_engine.model_config.model_mode is not ModelMode.EMBEDDING: raise ValueError( "LLM.encode() is only supported for embedding models (XModel)." 
) @@ -547,6 +552,43 @@ def encode( outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput) + def process( + self, + inputs: Union[PromptInputs, Sequence[PromptInputs]], + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + ) -> List[SimpleRequestOutput]: + """Processing the simple model, like XLMRoberta* + + Args: + inputs: The inputs to the LLM. You may pass a sequence of inputs for + batch inference. See :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + + Returns: + A list of `SimpleRequestOutput` objects containing the + generated simple result in the same order as the input data. + + Note: + Only ``inputs`` reserved for simple model. + """ + if self.llm_engine.model_config.model_mode is not ModelMode.SIMPLE: + raise ValueError( + "LLM.process() is only supported for simple models.") + + pooling_params = PoolingParams() + + self._validate_and_add_requests( + inputs=inputs, + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=None, + ) + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, SimpleRequestOutput) + # LEGACY def _convert_v1_inputs( self, @@ -712,3 +754,6 @@ def _is_encoder_decoder_model(self): def _is_embedding_model(self): return self.llm_engine.is_embedding_model() + + def _is_simple_model(self): + return self.llm_engine.is_simple_model() \ No newline at end of file diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d79238e08d54..ba7240d8c795 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -43,6 +43,7 @@ from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.logger import init_logger +from vllm.model_executor.models import ModelMode from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path from vllm.version import __version__ as VLLM_VERSION @@ -70,7 +71,7 @@ def model_is_embedding(model_name: str, trust_remote_code: bool, trust_remote_code=trust_remote_code, quantization=quantization, seed=0, - dtype="auto").embedding_mode + dtype="auto").model_mode == ModelMode.EMBEDDING @asynccontextmanager diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 0dc3c3bc7d15..7f8ae5f90e89 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -7,7 +7,7 @@ import numpy as np from fastapi import Request -from vllm.config import ModelConfig +from vllm.config import ModelConfig, ModelMode from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingRequest, @@ -71,7 +71,9 @@ def __init__( lora_modules=None, prompt_adapters=None, request_logger=request_logger) - self._enabled = self._check_embedding_mode(model_config.embedding_mode) + + self._check_embedding_mode(model_config.model_mode) + self._enabled = model_config.model_mode == ModelMode.EMBEDDING async def create_embedding( self, @@ -175,10 +177,9 @@ async def create_embedding( return response - def _check_embedding_mode(self, embedding_mode: bool): - if not embedding_mode: + def _check_embedding_mode(self, model_mode: 
ModelMode): + if model_mode is not ModelMode.EMBEDDING: logger.warning( - "embedding_mode is False. Embedding API will not work.") + "model_mode is not EMBEDDING. Embedding API will not work.") else: logger.info("Activating the server engine with embedding enabled.") - return embedding_mode diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index b009ad8c882d..fb156cd36c92 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -15,7 +15,6 @@ import vllm.envs as envs from vllm.config import ModelConfig, ParallelConfig from vllm.engine.arg_utils import EngineArgs -from vllm.engine.llm_engine import LLMEngine from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -459,6 +458,8 @@ def tensorize_vllm_model(engine_args: EngineArgs, ) as stream: stream.write(encryption_params.key) + # Avoid circular import, move 'import LLMEngine' here + from vllm.engine.llm_engine import LLMEngine engine = LLMEngine.from_engine_args(engine_args) if tensorizer_config._is_sharded: # if the engine is a distributed engine (for tensor parallel) then each diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 32cafa845a6e..c668327b86a3 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,3 +1,4 @@ +import enum import functools import importlib from typing import Dict, List, Optional, Tuple, Type @@ -89,13 +90,87 @@ "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), } +_SIMPLE_MODELS = { + "XLMRobertaForSequenceClassification": + ("xlmroberta", "XLMRobertaForSequenceClassification"), +} + _MODELS = { **_GENERATION_MODELS, **_EMBEDDING_MODELS, **_MULTIMODAL_MODELS, **_CONDITIONAL_GENERATION_MODELS, + **_SIMPLE_MODELS } +_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 +_SIMPLE_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 + + +class ModelMode(enum.Enum): + """ + 1. DECODER: decoder model, like GPT2* + 2. ENCODER: encoder model, like BERT* + 3. EMBEDDING: embedding model, like MistralModel + 4. ENCODER_DECODER: encoder-decoder model, like BART* + 5. 
SIMPLE: simple model, like XLMRoberta* + """ + DECODER = enum.auto() + ENCODER = enum.auto() + ENCODER_DECODER = enum.auto() + EMBEDDING = enum.auto() + SIMPLE = enum.auto() + + @staticmethod + def get_model_runner_cls(model_mode: "ModelMode"): + + if model_mode == ModelMode.EMBEDDING: + from vllm.worker.embedding_model_runner import EmbeddingModelRunner + return EmbeddingModelRunner + + if model_mode == ModelMode.SIMPLE: + from vllm.worker.simple_model_runner import SimpleModelRunner + return SimpleModelRunner + + if model_mode == ModelMode.ENCODER_DECODER: + from vllm.worker.enc_dec_model_runner import ( + EncoderDecoderModelRunner) + return EncoderDecoderModelRunner + + from vllm.worker.model_runner import ModelRunner + return ModelRunner + + @staticmethod + def get_block_space_manager_impl(use_v2_block_manager: bool, + model_mode: "ModelMode"): + + if use_v2_block_manager: + from vllm.core.block_manager_v2 import BlockSpaceManagerV2 + return BlockSpaceManagerV2 + + if model_mode == ModelMode.EMBEDDING: + from vllm.core.embedding_model_block_manager import ( + EmbeddingModelBlockSpaceManager) + return EmbeddingModelBlockSpaceManager + + if model_mode == ModelMode.SIMPLE: + from vllm.core.simple_model_block_manager import ( + SimpleModelBlockSpaceManager) + return SimpleModelBlockSpaceManager + + from vllm.core.block_manager_v1 import BlockSpaceManagerV1 + return BlockSpaceManagerV1 + + @staticmethod + def get_model_max_num_batched_tokens( + model_mode: "ModelMode") -> Optional[int]: + if model_mode == ModelMode.EMBEDDING: + return _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS + if model_mode == ModelMode.SIMPLE: + return _SIMPLE_MODEL_MAX_NUM_BATCHED_TOKENS + return None + + # Architecture -> type. # out of tree models _OOT_MODELS: Dict[str, Type[nn.Module]] = {} @@ -181,8 +256,27 @@ def register_model(model_arch: str, model_cls: Type[nn.Module]): _OOT_MODELS[model_arch] = model_cls @staticmethod - def is_embedding_model(model_arch: str) -> bool: - return model_arch in _EMBEDDING_MODELS + def get_model_mode(architectures: List[str]) -> ModelMode: + + if any(arch in _EMBEDDING_MODELS for arch in architectures): + return ModelMode.EMBEDDING + + if any(arch in _CONDITIONAL_GENERATION_MODELS + for arch in architectures): + return ModelMode.ENCODER_DECODER + + if any(arch in _SIMPLE_MODELS for arch in architectures): + return ModelMode.SIMPLE + + return ModelMode.DECODER + + @staticmethod + def need_initialize_kv_caches(model_mode: ModelMode) -> bool: + if model_mode == ModelMode.EMBEDDING: + return False + if model_mode == ModelMode.SIMPLE: + return False + return True @staticmethod def is_multimodal_model(model_arch: str) -> bool: @@ -191,7 +285,7 @@ def is_multimodal_model(model_arch: str) -> bool: # use `supports_multimodal` to determine if a model is multimodal # model_cls = ModelRegistry._try_load_model_cls(model_arch) # from vllm.model_executor.models.interfaces import supports_multimodal - return model_arch in _MULTIMODAL_MODELS + return model_arch in _MULTIMODAL_MODELS or model_arch in _SIMPLE_MODELS __all__ = [ diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 91b414b1fd91..80993b961088 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,3 +1,4 @@ +from functools import lru_cache from typing import Dict, Iterable, List, Optional, Protocol, Tuple import torch @@ -9,7 +10,7 @@ SchedulerConfig) from vllm.model_executor.layers.quantization import QuantizationConfig from 
vllm.model_executor.model_loader.loader import build_model -from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models import ModelMode, ModelRegistry from vllm.multimodal import BatchedTensors from vllm.utils import is_pin_memory_available @@ -227,3 +228,27 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: if name.startswith(missing_layer_name): return True return False + + +@lru_cache(maxsize=None) +def is_encoder_decoder_model_config(model_config) -> bool: + """Check model is encoder-decoder model or not.""" + + return model_config is not None and \ + model_config.model_mode == ModelMode.ENCODER_DECODER + + +@lru_cache(maxsize=None) +def is_embedding_model_config(model_config) -> bool: + """Check model is embedding model or not""" + + return model_config is not None and \ + model_config.model_mode == ModelMode.EMBEDDING + + +@lru_cache(maxsize=None) +def is_simple_model_config(model_config) -> bool: + """Check model is simple model or not""" + + return model_config is not None and \ + model_config.model_mode == ModelMode.SIMPLE diff --git a/vllm/model_executor/models/xlmroberta.py b/vllm/model_executor/models/xlmroberta.py new file mode 100644 index 000000000000..0180ae04427f --- /dev/null +++ b/vllm/model_executor/models/xlmroberta.py @@ -0,0 +1,950 @@ +import math +from typing import Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.xlm_roberta.configuration_xlm_roberta import ( + XLMRobertaConfig) +from transformers.pytorch_utils import (apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import SimpleOutput, SimpleSequenceGroupOutput + +logger = init_logger(__name__) + + +class XLMRobertaEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. 
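+    The tweak: position ids are derived from the input token ids so that
+    padding tokens keep ``padding_idx`` as their position and real tokens are
+    numbered starting from ``padding_idx + 1``.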
+ """ + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model + # variable name and be able to load any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + # position_ids (1, len position emb) is contiguous in memory + # and exported when serialized + self.position_embedding_type = getattr(config, + "position_embedding_type", + "absolute") + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).expand((1, -1)), + persistent=False) + self.register_buffer("token_type_ids", + torch.zeros(self.position_ids.size(), + dtype=torch.long), + persistent=False) + + # End copy + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size, + padding_idx=self.padding_idx) + + def forward(self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. + # Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids( + input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds( + inputs_embeds) + + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = \ + buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, + device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange(self.padding_idx + 1, + sequence_length + self.padding_idx + 1, + dtype=torch.long, + device=inputs_embeds.device) + return position_ids.unsqueeze(0).expand(input_shape) + + +class XLMRobertaSelfAttention(nn.Module): + + def __init__(self, + config, + quant_config: Optional[QuantizationConfig] = None, + position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple " + f"of the number of attention " + f"heads ({config.num_attention_heads})") + + self.num_attention_heads = 
config.num_attention_heads + + self.attention_head_size = int(config.hidden_size / + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or \ + self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + + self.is_decoder = config.is_decoder + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.attention_head_size, + self.num_attention_heads, + bias=True, + quant_config=quant_config, + ) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + new_x_shape = x.size()[:-1] + (self.num_attention_heads // + tensor_model_parallel_world_size, + self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + qkv, _ = self.qkv_proj(hidden_states) + mixed_query_layer, _key_layer, _value_layer = qkv.chunk(chunks=3, + dim=-1) + + key_layer = self.transpose_for_scores(_key_layer) + value_layer = self.transpose_for_scores(_value_layer) + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to + # get the raw attention scores. 
+ attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or \ + self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor( + key_length - 1, + dtype=torch.long, + device=hidden_states.device).view(-1, 1) + else: + position_ids_l = torch.arange( + query_length, + dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, + dtype=torch.long, + device=hidden_states.device).view( + 1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + \ + relative_position_scores_query + \ + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in + # XLMRobertaModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
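+        # Unlike the HF reference implementation, no dropout is applied to the
+        # attention probabilities in this port, since vLLM only runs inference.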
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size // tensor_model_parallel_world_size, ) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + if self.is_decoder: + outputs = outputs + (past_key_value, ) + return outputs + + +class XLMRobertaSelfOutput(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.dense = RowParallelLinear( + config.hidden_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + ) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +XLM_ROBERTA_SELF_ATTENTION_CLASSES = { + "eager": XLMRobertaSelfAttention, +} + + +class XLMRobertaAttention(nn.Module): + + def __init__(self, + config, + quant_config: Optional[QuantizationConfig] = None, + position_embedding_type=None): + super().__init__() + self.self = XLM_ROBERTA_SELF_ATTENTION_CLASSES[ + config._attn_implementation]( + config, + quant_config=quant_config, + position_embedding_type=position_embedding_type) + self.output = XLMRobertaSelfOutput(config, quant_config=quant_config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = \ + self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class XLMRobertaIntermediate(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + 
super().__init__() + self.dense = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + ) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class XLMRobertaOutput(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.dense = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + ) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class XLMRobertaLayer(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = XLMRobertaAttention(config, quant_config=quant_config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model" + "if cross attention is added") + self.crossattention = XLMRobertaAttention( + config, + quant_config=quant_config, + position_embedding_type="absolute") + self.intermediate = XLMRobertaIntermediate(config, + quant_config=quant_config) + self.output = XLMRobertaOutput(config, quant_config=quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_attn_past_key_value = past_key_value[:2] if past_key_value \ + is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[ + 1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be " + " instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`") + + cross_attn_past_key_value = past_key_value[ + -2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + 
output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class XLMRobertaEncoder(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.layer = nn.ModuleList([ + XLMRobertaLayer(config, quant_config=quant_config) + for _ in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + ) if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient " + "checkpointing. 
Setting `use_cache=False`...") + use_cache = False + else: + pass + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + ( + layer_outputs[2], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class XLMRobertaPooler(nn.Module): + + def __init__(self, config): + super().__init__() + # unused. + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class XLMRobertaPreTrainedModel(PreTrainedModel): + + config_class = XLMRobertaConfig + base_model_prefix = "roberta" + supports_gradient_checkpointing = True + _no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses + # truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, + std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, + std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class XLMRobertaModel(XLMRobertaPreTrainedModel): + + def __init__(self, + config, + quant_config: Optional[QuantizationConfig] = None, + add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = XLMRobertaEmbeddings(config) + self.encoder = XLMRobertaEncoder(config, quant_config=quant_config) + + self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None + + #self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of
+            {layer_num: list of heads to prune in this layer}. See the base
+            class PreTrainedModel.
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor],
+               BaseModelOutputWithPoolingAndCrossAttentions]:
+        output_attentions = output_attentions if output_attentions is not None \
+            else self.config.output_attentions
+        output_hidden_states = (output_hidden_states
+                                if output_hidden_states is not None else
+                                self.config.output_hidden_states)
+        return_dict = return_dict if return_dict is not None \
+            else self.config.use_return_dict
+
+        if self.config.is_decoder:
+            use_cache = use_cache if use_cache is not None \
+                else self.config.use_cache
+        else:
+            use_cache = False
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds "
+                "at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids,
+                                                       attention_mask)
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError(
+                "You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+        device = input_ids.device if input_ids is not None \
+            else inputs_embeds.device
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[
+            2] if past_key_values is not None else 0
+
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                ((batch_size, seq_length + past_key_values_length)),
+                device=device)
+
+        if token_type_ids is None:
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = \
+                    self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = \
+                    buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape,
+                                             dtype=torch.long,
+                                             device=device)
+
+        # We can provide a self-attention mask of dimensions
+        # [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to
+        # all heads.
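+        # NOTE: get_extended_attention_mask() below is inherited from
+        # transformers' ModuleUtilsMixin (via PreTrainedModel). For a 2D
+        # padding mask it is expected to return a tensor of shape
+        # [batch_size, 1, 1, seq_length] with 0.0 at kept positions and a
+        # large negative value at masked positions, so it can be added
+        # directly to the raw attention scores.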
+ extended_attention_mask: torch.Tensor = \ + self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to + # [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = \ + encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, + device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or + # [num_hidden_layers x num_heads] + # and head_mask is converted to shape + # [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@MULTIMODAL_REGISTRY.register_xlmroberta_input_mapper() +class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel): + + def __init__(self, + config: XLMRobertaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.quant_config = quant_config + + self.roberta = XLMRobertaModel(config, + quant_config=quant_config, + add_pooling_layer=False) + self.classifier = XLMRobertaClassificationHead( + config, quant_config=quant_config) + + # Initialize weights and apply final processing + #self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + positions: Optional[torch.Tensor] = None, + 
kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> Optional[SimpleOutput]: + + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + + logits = self.classifier(sequence_output) + + logits = logits.reshape([1, -1]) + + outputs = [SimpleSequenceGroupOutput(data) for data in logits] + + return SimpleOutput(outputs=outputs) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + ("qkv_proj", "query", "q"), + ("qkv_proj", "key", "k"), + ("qkv_proj", "value", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "lm_head.weight" in name: + continue + if name.startswith("decoder."): + name = "model." + name + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name.endswith(".position_ids") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + +class XLMRobertaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.dense = ColumnParallelLinear( + config.hidden_size, + config.hidden_size, + bias=True, + quant_config=None, + ) + self.out_proj = RowParallelLinear( + config.hidden_size, + config.num_labels, + bias=True, + quant_config=None, + ) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x, _ = self.dense(x) + x = torch.tanh(x) + x, _ = self.out_proj(x) + return x + + +def create_position_ids_from_input_ids(input_ids, + padding_idx, + past_key_values_length=0): + + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + + past_key_values_length) * mask + return incremental_indices.long() + padding_idx diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index cd16cdcbd890..3461e8c08ef2 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -9,6 +9,7 @@ from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs, MultiModalPlugin, MultiModalTokensCalc, NestedTensors) from .image import ImagePlugin +from .xlmroberta import XLMRobertaPlugin logger = init_logger(__name__) @@ -34,7 +35,7 @@ class MultiModalRegistry: :class:`~vllm.multimodal.MultiModalPlugin` for each modality. 
""" - DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin()) + DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), XLMRobertaPlugin()) def __init__( self, @@ -72,6 +73,17 @@ def _get_plugin(self, data_type_key: str): msg = f"Unknown multi-modal data type: {data_type_key}" raise NotImplementedError(msg) + def register_xlmroberta_input_mapper( + self, + mapper: Optional[MultiModalInputMapper] = None, + ): + """ + Register an input mapper for xlmroberta data to a model class. + + See :meth:`MultiModalPlugin.register_input_mapper` for more details. + """ + return self.register_input_mapper("xlmroberta", mapper) + def register_input_mapper( self, data_type_key: str, diff --git a/vllm/multimodal/xlmroberta.py b/vllm/multimodal/xlmroberta.py new file mode 100644 index 000000000000..ee9955243157 --- /dev/null +++ b/vllm/multimodal/xlmroberta.py @@ -0,0 +1,26 @@ +from vllm.inputs.registry import InputContext +from vllm.logger import init_logger + +from .base import MultiModalInputs, MultiModalPlugin + +logger = init_logger(__name__) + + +class XLMRobertaPlugin(MultiModalPlugin): + + def get_data_key(self) -> str: + return "xlmroberta" + + def _default_input_mapper(self, ctx: InputContext, + data: object) -> MultiModalInputs: + + input_ids = data['input_ids'] # type: ignore + attention_mask = data['attention_mask'] # type: ignore + + return MultiModalInputs({ + "input_ids": input_ids, + "attention_mask": attention_mask + }) + + def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: + raise NotImplementedError diff --git a/vllm/outputs.py b/vllm/outputs.py index e091b576f597..d22159a67e2b 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -4,6 +4,8 @@ from typing import Sequence as GenericSequence from typing import Union +import torch + from vllm.lora.request import LoRARequest from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, SequenceGroup, SequenceStatus) @@ -66,6 +68,21 @@ def __repr__(self) -> str: f"embedding={len(self.embedding)})") +@dataclass +class SimpleOutput: + """The output data of one completion output of a request. + + Args: + result: The result of Simple model, like XLMRoberta* + """ + + result: torch.Tensor + + def __repr__(self) -> str: + return (f"SimpleOutput(" + f"result_shape={self.result.shape})") + + class RequestOutput: """The output data of a completion request to the LLM. @@ -227,11 +244,59 @@ def __repr__(self): f"finished={self.finished})") +class SimpleRequestOutput: + """ + The output data of an simple model request to the LLM. + + Args: + request_id (str): A unique identifier for the simple model request. + outputs (SimpleOutput): Results for the given input. + prompt_token_ids (List[int]): A list of token IDs used in the prompt. + finished (bool): A flag indicating whether the result is completed. + """ + + def __init__(self, request_id: str, outputs: "SimpleOutput", + prompt_token_ids: List[int], finished: bool): + self.request_id = request_id + self.prompt_token_ids = prompt_token_ids + self.finished = finished + self.outputs = outputs + + @classmethod + def from_seq_group(cls, + seq_group: 'SequenceGroup') -> "SimpleRequestOutput": + if seq_group.result is None: + raise ValueError( + "result is missing in seq_group for SimpleRequest.") + output = SimpleOutput(seq_group.result) + prompt_token_ids = seq_group.prompt_token_ids + finished = seq_group.is_finished() + + return cls(seq_group.request_id, output, prompt_token_ids, finished) + + def __repr__(self): + """ + Returns a string representation of an SimpleRequestOutput instance. 
+
+        The representation includes the request_id and the number of outputs,
+        providing a quick overview of the simple model request's results.
+
+        Returns:
+            str: A string representation of the SimpleRequestOutput instance.
+        """
+        return (f"SimpleRequestOutput(request_id='{self.request_id}', "
+                f"outputs={repr(self.outputs)}, "
+                f"prompt_token_ids={self.prompt_token_ids}, "
+                f"finished={self.finished})")
+
+
 class RequestOutputFactory:
 
     @staticmethod
     def create(seq_group):
         # Determine the type based on a condition, for example:
+        if hasattr(seq_group, 'result') and seq_group.result is not None:
+            return SimpleRequestOutput.from_seq_group(seq_group)
         if hasattr(seq_group,
                    'embeddings') and seq_group.embeddings is not None:
             return EmbeddingRequestOutput.from_seq_group(seq_group)
diff --git a/vllm/sequence.py b/vllm/sequence.py
index b15955cde76c..848112e8a667 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -595,6 +595,7 @@ def __init__(
         sampling_params: Optional[SamplingParams] = None,
         lora_request: Optional[LoRARequest] = None,
         embeddings: Optional[List[float]] = None,
+        result: Optional[torch.Tensor] = None,
         pooling_params: Optional[PoolingParams] = None,
         encoder_seq: Optional[Sequence] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
@@ -615,6 +616,7 @@ def __init__(
         self.prompt_logprobs: Optional[PromptLogprobs] = None
         self.state = SequenceGroupState()
         self.embeddings = embeddings
+        self.result = result
         self.pooling_params = pooling_params
         self.prompt_adapter_request = prompt_adapter_request
         self.encoder_seq = encoder_seq
@@ -1031,6 +1033,25 @@ def __eq__(self, other: object) -> bool:
         return self.embeddings == other.embeddings
 
 
+class SimpleSequenceGroupOutput(SequenceGroupOutput):
+    """The model output associated with a result tensor."""
+
+    def __init__(
+        self,
+        data: torch.Tensor,
+    ) -> None:
+        self.result = data
+
+    def __repr__(self) -> str:
+        return (f"SimpleSequenceGroupOutput("
+                f"result_shape={len(self.result)})")
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, SimpleSequenceGroupOutput):
+            raise NotImplementedError()
+        return self.result == other.result
+
+
 class IntermediateTensors(
         msgspec.Struct,
         omit_defaults=True,  # type: ignore[call-arg]
@@ -1147,6 +1168,24 @@ def __eq__(self, other: object):
                           self.__class__) and self.outputs == other.outputs
 
 
+@dataclass
+class SimpleOutput:
+    outputs: List[SimpleSequenceGroupOutput]
+
+    def __getitem__(self, idx: int):
+        return self.outputs[idx]
+
+    def __setitem__(self, idx: int, value):
+        self.outputs[idx] = value
+
+    def __len__(self):
+        return len(self.outputs)
+
+    def __eq__(self, other: object):
+        return isinstance(other,
+                          self.__class__) and self.outputs == other.outputs
+
+
 def get_all_seq_ids(
         seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
     """Given a list of SequenceGroupMetadata, create a list of all
diff --git a/vllm/worker/simple_model_runner.py b/vllm/worker/simple_model_runner.py
new file mode 100644
index 000000000000..2cf0f2344990
--- /dev/null
+++ b/vllm/worker/simple_model_runner.py
@@ -0,0 +1,112 @@
+from typing import Any, Dict, List, Optional, Type
+
+import torch
+
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
+                         ModelConfig, ObservabilityConfig, ParallelConfig,
+                         PromptAdapterConfig, SchedulerConfig)
+from vllm.logger import init_logger
+from vllm.sequence import (IntermediateTensors, SequenceGroupMetadata,
+                           SimpleOutput)
+from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
+                                      ModelInputForGPUBuilder)
+
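+# NOTE: SimpleModelRunner (defined below) serves "simple" models such as
+# XLMRobertaForSequenceClassification. It reuses the standard GPU input
+# preparation, but passes a list of per-layer None placeholders instead of
+# real KV caches, and the driver worker returns the model's raw output
+# (a SimpleOutput) without any sampling step; multi-step execution is
+# rejected.
+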
+logger = init_logger(__name__) + + +class SimpleModelRunner(GPUModelRunnerBase[ModelInputForGPU]): + _model_input_cls: Type[ModelInputForGPU] = (ModelInputForGPU) + _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + observability_config: Optional[ObservabilityConfig] = None, + ): + super().__init__(model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config=lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, + observability_config=observability_config) + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForGPU, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SimpleOutput]]: + if num_steps > 1: + raise ValueError( + "SimpleModelRunner does not support multi-step execution.") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + # Currently cuda graph is only supported by the decode phase. + assert model_input.attn_metadata is not None + prefill_meta = model_input.attn_metadata.prefill_metadata + decode_meta = model_input.attn_metadata.decode_metadata + virtual_engine = model_input.virtual_engine + if prefill_meta is None and decode_meta.use_cuda_graph: + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[virtual_engine][ + graph_batch_size] + else: + model_executable = self.model + + num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [None] * num_layers + + execute_model_kwargs = { + "input_ids": model_input.input_tokens, + "positions": model_input.input_positions, + "kv_caches": kv_caches, + "attn_metadata": model_input.attn_metadata, + **(model_input.multi_modal_kwargs or {}), + } + + hidden_states = model_executable(**execute_model_kwargs) + if not self.is_driver_worker: + return [] + + return [hidden_states] + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> ModelInputForGPU: + return ModelInputForGPU.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForGPU: + assert seq_group_metadata_list is not None + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + assert model_input.seq_lens is not None + return model_input diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 97be68934be4..f90b59fbd554 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,8 +7,8 @@ import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, 
SchedulerConfig, + ModelConfig, ModelMode, ObservabilityConfig, + ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, @@ -16,15 +16,16 @@ from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.model_executor.models.utils import (is_embedding_model_config, + is_encoder_decoder_model_config, + is_simple_model_config) from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SamplerOutput, SequenceGroupMetadata, SequenceGroupMetadataDelta) from vllm.worker.cache_engine import CacheEngine -from vllm.worker.embedding_model_runner import EmbeddingModelRunner -from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner -from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner +from vllm.worker.model_runner import GPUModelRunnerBase from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput @@ -85,13 +86,8 @@ def __init__( not in ["medusa", "mlp_speculator"]) \ else {"return_hidden_states": True} - ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner - if model_runner_cls is not None: - ModelRunnerClass = model_runner_cls - elif self._is_embedding_model(): - ModelRunnerClass = EmbeddingModelRunner - elif self._is_encoder_decoder_model(): - ModelRunnerClass = EncoderDecoderModelRunner + ModelRunnerClass = model_runner_cls or ModelMode.get_model_runner_cls( + self.model_config.model_mode) self.model_runner: GPUModelRunnerBase = ModelRunnerClass( model_config, parallel_config, @@ -114,10 +110,13 @@ def __init__( self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} def _is_encoder_decoder_model(self): - return self.model_config.is_encoder_decoder_model + return is_encoder_decoder_model_config(self.model_config) def _is_embedding_model(self): - return self.model_config.is_embedding_model + return is_embedding_model_config(self.model_config) + + def _is_simple_model(self): + return is_simple_model_config(self.model_config) def init_device(self) -> None: if self.device_config.device.type == "cuda":