diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index bcd170411e7c..c53cda16d471 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -739,6 +739,9 @@ void paged_attention_v1_launcher( // NOTE(woosuk): To reduce the compilation time, we only compile for the // head sizes that we use in the model. However, we can easily extend this // to support any head size which is a multiple of 16. + case 32: + LAUNCH_PAGED_ATTENTION_V1(32); + break; case 64: LAUNCH_PAGED_ATTENTION_V1(64); break; @@ -903,6 +906,9 @@ void paged_attention_v2_launcher( // NOTE(woosuk): To reduce the compilation time, we only compile for the // head sizes that we use in the model. However, we can easily extend this // to support any head size which is a multiple of 16. + case 32: + LAUNCH_PAGED_ATTENTION_V2(32); + break; case 64: LAUNCH_PAGED_ATTENTION_V2(64); break; diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index abb4e3bea14b..55921c05711f 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -375,6 +375,9 @@ void paged_attention_v1_impl_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { + case 32: + LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); + break; case 64: LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); break; @@ -692,6 +695,9 @@ void paged_attention_v2_impl_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { + case 32: + LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); + break; case 64: LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); break; diff --git a/examples/offline_inference_bert_embedding.py b/examples/offline_inference_bert_embedding.py new file mode 100644 index 000000000000..7cf6e3fb5933 --- /dev/null +++ b/examples/offline_inference_bert_embedding.py @@ -0,0 +1,16 @@ +from vllm import LLM + +# Sample prompts. +prompts = [ + "This is an example sentence.", + "Another example sentence.", +] + +# Create an LLM. +model = LLM(model="bert-base-uncased", enforce_eager=True) +outputs = model.encode(prompts) + +# Print the outputs. +for output in outputs: + print(output.outputs.embedding) # list of 768 floats + print(len(output.outputs.embedding)) diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py index 7d5ef128bc8e..013d2d6bb735 100644 --- a/examples/offline_inference_embedding.py +++ b/examples/offline_inference_embedding.py @@ -1,15 +1,22 @@ from vllm import LLM +from vllm.inputs import build_decoder_prompts # Sample prompts. -prompts = [ +prompts = build_decoder_prompts([ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", -] +]) # Create an LLM. -model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True) +model = LLM( + model="intfloat/e5-mistral-7b-instruct", + enforce_eager=True, + # NOTE: sliding_window is not supported by encoder_decoder_model + disable_sliding_window=True, + gpu_memory_utilization=0.95, +) # Generate embedding. The output is a list of EmbeddingRequestOutputs. outputs = model.encode(prompts) # Print the outputs.
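For readers of this patch: a minimal usage sketch (not part of the diff) of what build_decoder_prompts wraps each prompt into, assuming ExplicitEncoderDecoderPrompt keeps its existing TypedDict shape from vllm/inputs/data.py. Decoder-only embedding models such as e5-mistral now flow through the encoder/decoder input path, so each plain string is paired with an empty encoder prompt:

from vllm.inputs import build_decoder_prompts

prompts = build_decoder_prompts(["Hello, my name is"])
# Each element is an explicit encoder/decoder prompt with an empty encoder
# side; the original text stays on the decoder side, e.g. roughly:
#   {"encoder_prompt": "", "decoder_prompt": "Hello, my name is"}
print(prompts[0]["decoder_prompt"])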
diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index be316c6e12da..3a644479dd3b 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -6,9 +6,22 @@ import torch import torch.nn.functional as F +from vllm.inputs import build_decoder_prompts + MODELS = [ - "intfloat/e5-mistral-7b-instruct", - "BAAI/bge-multilingual-gemma2", + { + "name": "intfloat/e5-mistral-7b-instruct", + "is_decoder_only": True + }, + { + "name": "BAAI/bge-multilingual-gemma2", + "is_decoder_only": True + }, + { + "name": "bert-base-uncased", + "is_decoder_only": False, + "max_model_len": 512 + }, ] @@ -26,7 +39,7 @@ def test_models( hf_runner, vllm_runner, example_prompts, - model: str, + model: dict, dtype: str, ) -> None: # The example_prompts has ending "\n", for example: @@ -37,11 +50,22 @@ def test_models( # So we need to strip the input texts to avoid test failing. example_prompts = [str(s).strip() for s in example_prompts] - with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model: + model_name = model["name"] + is_decoder_only = model["is_decoder_only"] + max_model_len = model.get("max_model_len", 1024) + with hf_runner(model_name, dtype=dtype, + is_embedding_model=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.encode(example_prompts) + with vllm_runner( + model_name, + dtype=dtype, + disable_sliding_window=True, + max_model_len=max_model_len, + ) as vllm_model: + prompt_inputs = build_decoder_prompts( + example_prompts) if is_decoder_only else example_prompts + vllm_outputs = vllm_model.encode(prompt_inputs) similarities = compare_embeddings(hf_outputs, vllm_outputs) all_similarities = torch.stack(similarities) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 92023d5b75f5..076f151ffcb6 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -34,7 +34,7 @@ class PagedAttention: @staticmethod def get_supported_head_sizes() -> List[int]: - return [64, 80, 96, 112, 120, 128, 192, 256] + return [32, 64, 80, 96, 112, 120, 128, 192, 256] @staticmethod def get_kv_cache_shape( diff --git a/vllm/config.py b/vllm/config.py index 7a3248f4087a..a81584b21088 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -566,6 +566,10 @@ def is_encoder_decoder_model(self) -> bool: (hasattr(self.hf_config, "text_config") and getattr( self.hf_config.text_config, "is_encoder_decoder", False))) + @property + def is_encoder_model(self) -> bool: + return ModelRegistry.is_encoder_model(self.hf_config.architectures) + @property def is_embedding_model(self) -> bool: """Extract the embedding model flag.""" diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py index a337392bbed5..e5bff765694b 100644 --- a/vllm/core/placeholder_block_space_manager.py +++ b/vllm/core/placeholder_block_space_manager.py @@ -63,9 +63,16 @@ def free(self, seq: Sequence) -> None: # No operation on free return + def free_cross(self, seq: Sequence) -> None: + # No operation on free + return + def get_block_table(self, seq: Sequence) -> List[int]: return None # type: ignore + def get_cross_block_table(self, seq: Sequence) -> List[int]: + return None # type: ignore + def get_num_free_gpu_blocks(self) -> int: return 1 diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 
a8c8672cb5fe..b22350beefd4 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,8 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, LLMInputs, PromptType, SingletonPrompt, TextPrompt, - TokensPrompt, build_explicit_enc_dec_prompt, - to_enc_dec_tuple_list, zip_enc_dec_prompts) + TokensPrompt, build_decoder_prompt, build_decoder_prompts, + build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, + zip_enc_dec_prompts) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -21,6 +22,8 @@ "ExplicitEncoderDecoderPrompt", "LLMInputs", "EncoderDecoderLLMInputs", + "build_decoder_prompt", + "build_decoder_prompts", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 724cdd2e6e80..797b58e86c8a 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -228,6 +228,18 @@ def to_enc_dec_tuple_list( for enc_dec_prompt in enc_dec_prompts] +def build_decoder_prompt( + prompt: _T2, ) -> ExplicitEncoderDecoderPrompt[SingletonPrompt, _T2]: + return build_explicit_enc_dec_prompt(encoder_prompt="", + decoder_prompt=prompt) + + +def build_decoder_prompts( + prompts: Iterable[_T2], +) -> List[ExplicitEncoderDecoderPrompt[SingletonPrompt, _T2]]: + return [build_decoder_prompt(prompt) for prompt in prompts] + + def __getattr__(name: str): if name == "PromptInput": import warnings diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 64387fd2fa47..e86ddaef7d65 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -25,6 +25,7 @@ DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], Optional["MultiModalDataDict"], Optional[Dict[str, Any]]] +_DEFAULT_BOS_TOKEN_ID = 1 class InputPreprocessor: @@ -54,7 +55,13 @@ def get_bos_token_id(self, "is not initialized") return None - return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id + bos_token_id = self.tokenizer.get_lora_tokenizer( + lora_request).bos_token_id + + if bos_token_id is None and self.model_config.is_encoder_model: + bos_token_id = _DEFAULT_BOS_TOKEN_ID + + return bos_token_id def get_eos_token_id(self, lora_request: Optional[LoRARequest] = None @@ -86,9 +93,10 @@ def get_decoder_start_token_id(self) -> Optional[int]: dec_start_token_id = getattr(self.model_config.hf_config, 'decoder_start_token_id', None) if dec_start_token_id is None: - print_warning_once("Falling back on <BOS> for decoder start token " - "id because decoder start token id is not " - "available.") + if not self.model_config.is_encoder_model: + logger.warning( + "Falling back on <BOS> for decoder start token id " + "because decoder start token id is not available.") dec_start_token_id = self.get_bos_token_id() return dec_start_token_id @@ -577,4 +585,5 @@ async def preprocess_async( ) def is_encoder_decoder_model(self): - return self.model_config.is_encoder_decoder_model + return self.model_config.is_encoder_decoder_model \ + or self.model_config.is_encoder_model diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 76ccb3dfe0a6..1fe0593c99bf 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -12,6 +12,7 @@ class PoolingType(IntEnum): """Enumeration for different types of pooling methods.""" LAST = 0 ALL = 1 + MEAN = 2 class Pooler(nn.Module): @@ -50,6 +51,17 @@ def forward( for prompt_len in prompt_lens: pooled_data.append(hidden_states[offset:offset + prompt_len]) offset +=
prompt_len + elif self.pooling_type == PoolingType.MEAN: + # Calculate mean pooling + cumsum = torch.cumsum(hidden_states, dim=0) + start_indices = torch.cat([ + torch.tensor([0], device=hidden_states.device), + torch.cumsum(prompt_lens[:-1], dim=0) + ]) + end_indices = torch.cumsum(prompt_lens, dim=0) + pooled_data = ( + cumsum[end_indices - 1] - cumsum[start_indices] + + hidden_states[start_indices]) / prompt_lens.unsqueeze(1) else: raise ValueError(f"Invalid pooling type: {self.pooling_type}") diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py new file mode 100644 index 000000000000..50c8607a33d4 --- /dev/null +++ b/vllm/model_executor/models/bert_embedding.py @@ -0,0 +1,445 @@ +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import BertConfig + +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.config import CacheConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import QKVParallelLinear +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + + +class BertModel(nn.Module): + + def __init__( + self, + config: BertConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embeddings = BertEmbedding(config) + self.encoder = BertEncoder(config, cache_config, quant_config) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embeddings(input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + hidden_states = self.encoder(hidden_states, kv_caches, attn_metadata) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + return hidden_states + + +class BertEmbedding(nn.Module): + + def __init__(self, config: BertConfig): + super().__init__() + self.size = config.hidden_size + self.word_embeddings = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + self.position_embedding_type = config.position_embedding_type + if self.position_embedding_type != "absolute": + raise ValueError("Only 'absolute' position_embedding_type" + + "
is supported") + + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + device = input_ids.device + else: + input_shape = inputs_embeds.size()[:-1] + device = inputs_embeds.device + seq_length = input_shape[0] + + # input embeddings + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # position embeddings + if position_ids is None: + position_ids = torch.arange(seq_length, + dtype=torch.long, + device=device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + # token type embeddings + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, + device=device) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + embeddings += position_embeddings + embeddings = self.layernorm(embeddings) + return embeddings + + +class BertEncoder(nn.Module): + + def __init__( + self, + config: BertConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.layer = nn.ModuleList([ + BertLayer(config=config, + cache_config=cache_config, + quant_config=quant_config) + for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + for i in range(len(self.layer)): + layer = self.layer[i] + hidden_states = layer( + hidden_states, + kv_caches[i], + attn_metadata, + ) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__( + self, + config: BertConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super(BertLayer, self).__init__() + self.attention = BertAttention(config=config, + cache_config=cache_config, + quant_config=quant_config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata, + ): + self_attention_outputs = self.attention( + hidden_states, + kv_cache, + attn_metadata, + ) + + output = self.feed_forward(self_attention_outputs) + return output + + def feed_forward(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertAttention(nn.Module): + + def __init__( + self, + config: BertConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.self = BertSelfAttention(config=config, + cache_config=cache_config, + quant_config=quant_config) + self.output = BertSelfOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + self_outputs = self.self(hidden_states, kv_cache, attn_metadata) + attn_output = self.output(self_outputs, hidden_states) + return attn_output + + +class BertSelfAttention(nn.Module): + + def __init__( + self, + config: BertConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = 
None, + ): + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = self.total_num_heads + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=True, + quant_config=quant_config) + + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.ENCODER) + return output + + +class BertSelfOutput(nn.Module): + + def __init__(self, config: BertConfig): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.layernorm(hidden_states + input_tensor) + return hidden_states + + +class BertIntermediate(nn.Module): + + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + input_tensor: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.layernorm(hidden_states + input_tensor) + return hidden_states + + +class BertPooler(nn.Module): + + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertEmbeddingModel(nn.Module): + """A model that uses Bert to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. 
+ + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + stacked_params_mapping = { + "query": { + "param_name": "qkv_proj", + "shard_id": "q", + }, + "key": { + "param_name": "qkv_proj", + "shard_id": "k", + }, + "value": { + "param_name": "qkv_proj", + "shard_id": "v", + }, + } + + params_mapping = { + "beta": "bias", + "gamma": "weight", + "LayerNorm": "layernorm", + } + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__() + self.base_model_prefix = "bert" + self.model = BertModel(config=kwargs["config"], + cache_config=kwargs.get("cache_config", None), + quant_config=kwargs.get("quant_config", None)) + self._pooler = Pooler(pooling_type=PoolingType.MEAN, normalize=False) + # self._pooler = BertPooler(config=kwargs["config"]) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + encoder_input_ids: Optional[torch.Tensor], + encoder_positions: Optional[torch.Tensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model(input_ids=encoder_input_ids, + position_ids=encoder_positions, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + attn_metadata=attn_metadata) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + params_dict = dict(self.model.named_parameters()) + + for name, loaded_weight in weights: + name = self._rename_key(name) + name, shard_id = self._rename_stacked_param(name) + + # Skip the specific downstream task weight. + if name.startswith('cls.'): + continue + # use Pooler instead. + if name.startswith('pooler.'): + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if shard_id: + weight_loader(param, loaded_weight, shard_id) + else: + weight_loader(param, loaded_weight) + + def _rename_key(self, key: str): + prefix = f"{self.base_model_prefix}." 
+ key = key[len(prefix):] if key.startswith(prefix) else key + + for src, dst in self.params_mapping.items(): + key = key.replace(src, dst) + + return key + + def _rename_stacked_param( + self, + name: str, + ) -> Tuple[str, Optional[str]]: + for key, mapping in self.stacked_params_mapping.items(): + if key in name: + name = name.replace(key, mapping["param_name"]) + return name, mapping["shard_id"] + return name, None diff --git a/vllm/model_executor/models/llama_embedding.py b/vllm/model_executor/models/llama_embedding.py index 13574e84d7aa..85acc9ac4b47 100644 --- a/vllm/model_executor/models/llama_embedding.py +++ b/vllm/model_executor/models/llama_embedding.py @@ -37,6 +37,8 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, + encoder_input_ids: Optional[torch.Tensor], + encoder_positions: Optional[torch.Tensor], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b06d3d612dbc..3643a310a2e8 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -89,8 +89,19 @@ "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), "Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"), + "BertForMaskedLM": ("bert_embedding", "BertEmbeddingModel"), + "BertModel": ("bert_embedding", "BertEmbeddingModel"), + "RobertaForMaskedLM": ("roberta_embedding", "RobertaEmbeddingModel"), + "RobertaModel": ("roberta_embedding", "RobertaEmbeddingModel"), } +_ENCODER_MODELS = [ + "BertForMaskedLM", + "BertModel", + "RobertaForMaskedLM", + "RobertaModel", +] + _MULTIMODAL_MODELS = { # [Decoder-only] "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), @@ -377,6 +388,15 @@ def is_embedding_model( ) -> bool: return self.inspect_model_cls(architectures).is_embedding_model + @staticmethod + def is_encoder_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + return any(arch in _ENCODER_MODELS for arch in architectures) + def is_multimodal_model( self, architectures: Union[str, List[str]], diff --git a/vllm/model_executor/models/roberta_embedding.py b/vllm/model_executor/models/roberta_embedding.py new file mode 100644 index 000000000000..8f0a7d8f9a58 --- /dev/null +++ b/vllm/model_executor/models/roberta_embedding.py @@ -0,0 +1,78 @@ +from typing import Optional + +from torch import nn +from transformers import RobertaConfig + +from vllm.config import CacheConfig +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.models.bert_embedding import (BertEmbedding, + BertEmbeddingModel, + BertEncoder, BertModel) + + +class RobertaModel(BertModel): + + def __init__( + self, + config: RobertaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + # Skip BertModel.__init__() + nn.Module.__init__(self) + self.embeddings = RobertaEmbedding(config) + self.encoder = BertEncoder(config, cache_config, quant_config) + + +class RobertaEmbedding(BertEmbedding): + + def 
__init__(self, config: RobertaConfig): + # Skip BertEmbedding.__init__() + nn.Module.__init__(self) + self.size = config.hidden_size + self.word_embeddings = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size, + padding_idx=self.padding_idx) + + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + self.position_embedding_type = config.position_embedding_type + if self.position_embedding_type != "absolute": + raise ValueError("Only 'absolute' position_embedding_type" + + " is supported") + + +class RobertaEmbeddingModel(BertEmbeddingModel): + """A model that uses Roberta to provide embedding functionalities. + + This class encapsulates the RobertaModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of RobertaModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__( + self, + **kwargs, + ) -> None: + # Skip BertEmbeddingModule.__init__() + nn.Module.__init__(self) + self.base_model_prefix = "roberta" + self.model = RobertaModel( + config=kwargs["config"], + cache_config=kwargs.get("cache_config", None), + quant_config=kwargs.get("quant_config", None)) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + # self._pooler = BertPooler(config=kwargs["config"]) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index a7f5b2d4fdd1..815f89097638 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -1,8 +1,9 @@ import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast import torch +from vllm.attention.backends.abstract import AttentionBackend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -14,24 +15,33 @@ from vllm.pooling_params import PoolingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, SequenceGroupMetadata) -from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU, - ModelInputForGPUBuilder) +from vllm.worker.enc_dec_model_runner import (EncoderDecoderModelInput, + EncoderDecoderModelRunnerBase) +from vllm.worker.model_runner import ModelInputForGPUBuilder logger = init_logger(__name__) @dataclasses.dataclass(frozen=True) -class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): +class EmbeddingModelInput(EncoderDecoderModelInput): """ Used by the EmbeddingModelRunner. 
""" pooling_metadata: Optional["PoolingMetadata"] = None + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "EmbeddingModelInput": + return cast( + EmbeddingModelInput, + super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) -class EmbeddingModelRunner( - GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): - _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = ( - ModelInputForGPUWithPoolingMetadata) + +class EmbeddingModelRunner(EncoderDecoderModelRunnerBase[EmbeddingModelInput]): + _model_input_cls: Type[EmbeddingModelInput] = EmbeddingModelInput _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder def __init__( @@ -63,7 +73,7 @@ def __init__( @torch.inference_mode() def execute_model( self, - model_input: ModelInputForGPUWithPoolingMetadata, + model_input: EmbeddingModelInput, kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, @@ -119,6 +129,8 @@ def execute_model( hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, + encoder_input_ids=model_input.encoder_input_tokens, + encoder_positions=model_input.encoder_input_positions, kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, @@ -158,10 +170,8 @@ def execute_model( ] def make_model_input_from_broadcasted_tensor_dict( - self, - tensor_dict: Dict[str, - Any]) -> ModelInputForGPUWithPoolingMetadata: - return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> EmbeddingModelInput: + return EmbeddingModelInput.from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend, ) @@ -171,14 +181,34 @@ def prepare_model_input( seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], virtual_engine: int = 0, finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForGPUWithPoolingMetadata: + ) -> EmbeddingModelInput: assert seq_group_metadata_list is not None model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) + + ( + attn_metadata, + encoder_input_tokens_tensor, + encoder_input_positions_tensor, + encoder_seq_lens, + ) = super()._prepare_encoder_model_input_tensors( + seq_group_metadata_list, model_input) + + model_input = dataclasses.replace( + model_input, + attn_metadata=attn_metadata, + encoder_input_tokens=encoder_input_tokens_tensor, + encoder_input_positions=encoder_input_positions_tensor, + ) + # Prepare PoolingMetadata. 
- assert model_input.seq_lens is not None + seq_lens = model_input.seq_lens\ + if not self.model_config.is_encoder_model \ + else encoder_seq_lens + assert seq_lens is not None, "model is_encoder_model: "\ + f"{self.model_config.is_encoder_model}" pooling_metadata = self._prepare_pooling(seq_group_metadata_list, - model_input.seq_lens) + seq_lens) return dataclasses.replace(model_input, pooling_metadata=pooling_metadata) @@ -190,7 +220,7 @@ def _prepare_pooling( ) -> PoolingMetadata: """Prepare PoolingMetadata for the sequence group metadata list.""" seq_groups: List[Tuple[List[int], PoolingParams]] = [] - for i, seq_group_metadata in enumerate(seq_group_metadata_list): + for seq_group_metadata in seq_group_metadata_list: seq_ids = list(seq_group_metadata.seq_data.keys()) pooling_params = seq_group_metadata.pooling_params seq_groups.append((seq_ids, pooling_params)) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 6a00444f5098..3451c0b16c13 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -1,6 +1,6 @@ import dataclasses import itertools -from typing import Any, Dict, List, Optional, Tuple, Type, cast +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, cast import torch import torch.distributed @@ -28,7 +28,7 @@ from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata, - _get_graph_batch_size) + TModelInputForGPU, _get_graph_batch_size) from vllm.worker.model_runner_base import ( _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict) @@ -36,6 +36,9 @@ logger = init_logger(__name__) +TEncoderDecoderModelInput = TypeVar('TEncoderDecoderModelInput', + bound="EncoderDecoderModelInput") + @dataclasses.dataclass(frozen=True) class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): @@ -44,6 +47,7 @@ class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): """ encoder_input_tokens: Optional[torch.Tensor] = None encoder_input_positions: Optional[torch.Tensor] = None + encoder_seq_lens: Optional[List[int]] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -72,10 +76,7 @@ def from_broadcasted_tensor_dict( super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) -class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): - _model_input_cls: Type[EncoderDecoderModelInput] = ( - EncoderDecoderModelInput) - _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder) +class EncoderDecoderModelRunnerBase(GPUModelRunnerBase[TModelInputForGPU]): def __init__( self, @@ -171,120 +172,6 @@ def _empty_int32_tensor(self) -> torch.Tensor: def _empty_long_tensor(self) -> torch.Tensor: return self._list_to_long_tensor([]) - @torch.inference_mode() - def execute_model( - self, - model_input: EncoderDecoderModelInput, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[List[PoolerOutput]]: - if num_steps > 1: - raise ValueError("num_steps > 1 is not supported in " - "EncoderDecoderModelRunner") - - if (model_input.attn_metadata is not None - and model_input.attn_metadata.prefill_metadata is None - and model_input.attn_metadata.decode_metadata.use_cuda_graph): - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[ - model_input.virtual_engine][graph_batch_size] - 
else: - model_executable = self.model - - seqlen_agnostic_kwargs = { - "finished_requests_ids": model_input.finished_requests_ids, - "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } if self.has_inner_state else {} - - multi_modal_kwargs = model_input.multi_modal_kwargs or {} - with set_forward_context(model_input.attn_metadata): - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - encoder_input_ids=model_input.encoder_input_tokens, - encoder_positions=model_input.encoder_input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) - - logits = self.model.compute_logits(hidden_or_intermediate_states, - model_input.sampling_metadata) - - if not self.is_driver_worker: - return [] - - if model_input.async_callback is not None: - model_input.async_callback() - - # Sample the next token. - output: SamplerOutput = self.model.sample( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - - return [output] - - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput: - return EncoderDecoderModelInput.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - ) - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> EncoderDecoderModelInput: - """Prepare the model input based on a given sequence group, including - metadata for the sampling step. - - Since chunked prefill is not supported for encoder/decoder models, - `input_tokens` is assumed to be either entirely prefill tokens or - entirely decode tokens. - - """ - model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_requests_ids) - ( - attn_metadata, - encoder_input_tokens_tensor, - encoder_input_positions_tensor, - ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list, - model_input)) - # Inject attn_metadata encoder/cross-attention fields & - # encoder input tokens/positions into model_input. - # Frozen dataclass fields cannot be modified, so use - # dataclasses.replace to construct a new model input - # instance. - model_input = dataclasses.replace( - model_input, - attn_metadata=attn_metadata, - encoder_input_tokens=encoder_input_tokens_tensor, - encoder_input_positions=encoder_input_positions_tensor, - ) - - generators = self.get_generators(finished_requests_ids) - sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, - model_input.seq_lens, - model_input.query_lens, - self.device, - self.pin_memory, - generators=generators) - is_prompt = (seq_group_metadata_list[0].is_prompt - if seq_group_metadata_list else None) - return dataclasses.replace(model_input, - sampling_metadata=sampling_metadata, - is_prompt=is_prompt, - virtual_engine=virtual_engine) - - @torch.inference_mode() def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. 
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) @@ -363,12 +250,12 @@ def profile_run(self) -> None: def _prepare_encoder_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - model_input: EncoderDecoderModelInput, + model_input: TEncoderDecoderModelInput, ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], - Optional[torch.Tensor]]: + Optional[torch.Tensor], List[int]]: """Helper method to prepare the encoder- and cross-attn-related model inputs based on a given sequence group. These additional inputs - are used to augment an already-computed `EncoderDecoderModelInput` + are used to augment an already-computed `TEncoderDecoderModelInput` data structure which already has decoder-related model inputs populated. @@ -402,7 +289,7 @@ def _prepare_encoder_model_input_tensors( """ if len(seq_group_metadata_list) == 0: - return (model_input.attn_metadata, None, None) + return (model_input.attn_metadata, None, None, []) # Since we are not supporting chunked prefill either the entire # batch is prefill or it is decode @@ -442,10 +329,12 @@ def _prepare_encoder_model_input_tensors( cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len) else: for i in range(0, seq_len): - block_number = seq_group_metadata.cross_block_table[ - i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset + slot = PAD_SLOT_ID + if seq_group_metadata.cross_block_table is not None: + block_number = seq_group_metadata.cross_block_table[ + i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset cross_slot_mapping.append(slot) # Build encoder input tokens @@ -540,4 +429,126 @@ def _prepare_encoder_model_input_tensors( ) return (attn_metadata, encoder_input_tokens_tensor, - encoder_input_positions_tensor) + encoder_input_positions_tensor, encoder_seq_lens) + + +class EncoderDecoderModelRunner( + EncoderDecoderModelRunnerBase[EncoderDecoderModelInput]): + + _model_input_cls: Type[EncoderDecoderModelInput] = ( + EncoderDecoderModelInput) + _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder) + + @torch.inference_mode() + def execute_model( + self, + model_input: EncoderDecoderModelInput, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[PoolerOutput]]: + if num_steps > 1: + raise ValueError("num_steps > 1 is not supported in " + "EncoderDecoderModelRunner") + + if (model_input.attn_metadata is not None + and model_input.attn_metadata.prefill_metadata is None + and model_input.attn_metadata.decode_metadata.use_cuda_graph): + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[ + model_input.virtual_engine][graph_batch_size] + else: + model_executable = self.model + + seqlen_agnostic_kwargs = { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } if self.has_inner_state else {} + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + with set_forward_context(model_input.attn_metadata): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + encoder_input_ids=model_input.encoder_input_tokens, + encoder_positions=model_input.encoder_input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + 
intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) + + logits = self.model.compute_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) + + if not self.is_driver_worker: + return [] + + if model_input.async_callback is not None: + model_input.async_callback() + + # Sample the next token. + output: SamplerOutput = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + + return [output] + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput: + return EncoderDecoderModelInput.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> EncoderDecoderModelInput: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + Since chunked prefill is not supported for encoder/decoder models, + `input_tokens` is assumed to be either entirely prefill tokens or + entirely decode tokens. + + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + ( + attn_metadata, + encoder_input_tokens_tensor, + encoder_input_positions_tensor, + _, + ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list, + model_input)) + # Inject attn_metadata encoder/cross-attention fields & + # encoder input tokens/positions into model_input. + # Frozen dataclass fields cannot be modified, so use + # dataclasses.replace to construct a new model input + # instance. + model_input = dataclasses.replace( + model_input, + attn_metadata=attn_metadata, + encoder_input_tokens=encoder_input_tokens_tensor, + encoder_input_positions=encoder_input_positions_tensor, + ) + + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + self.pin_memory, + generators=generators) + is_prompt = (seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + is_prompt=is_prompt, + virtual_engine=virtual_engine)
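A side note on the _DEFAULT_BOS_TOKEN_ID fallback added in vllm/inputs/preprocess.py: BERT-style tokenizers define [CLS]/[SEP] rather than a BOS token, so bos_token_id is None and the preprocessor needs a stand-in for encoder-only models. A quick standalone check (illustrative only; requires transformers and a cached copy of bert-base-uncased):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print(tokenizer.bos_token_id)  # None -> triggers the _DEFAULT_BOS_TOKEN_ID fallback
print(tokenizer.cls_token_id)  # 101, the token BERT actually prepends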
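The PoolingType.MEAN branch added to Pooler.forward computes per-prompt means over the flattened hidden_states with a single cumulative sum instead of a Python loop over prompts. A small self-contained check of that formulation (variable names mirror the patch; the snippet is illustrative, not vLLM code):

import torch

# Two prompts of lengths 2 and 3, flattened into 5 token vectors of dim 4.
prompt_lens = torch.tensor([2, 3])
hidden_states = torch.randn(5, 4)

cumsum = torch.cumsum(hidden_states, dim=0)
start_indices = torch.cat(
    [torch.tensor([0]), torch.cumsum(prompt_lens[:-1], dim=0)])
end_indices = torch.cumsum(prompt_lens, dim=0)
pooled_data = (cumsum[end_indices - 1] - cumsum[start_indices] +
               hidden_states[start_indices]) / prompt_lens.unsqueeze(1)

# Reference: split the flat tensor per prompt and average each chunk.
reference = torch.stack(
    [chunk.mean(dim=0) for chunk in hidden_states.split(prompt_lens.tolist())])
assert torch.allclose(pooled_data, reference, atol=1e-6)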
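Finally, an illustration of how BertEmbeddingModel.load_weights is expected to map Hugging Face BERT checkpoint names onto this module's parameters via params_mapping and stacked_params_mapping. The helper below reimplements the renaming logic standalone for clarity; the example names come from bert-base-uncased and are not exhaustive:

params_mapping = {"beta": "bias", "gamma": "weight", "LayerNorm": "layernorm"}
stacked_params = {"query": ("qkv_proj", "q"),
                  "key": ("qkv_proj", "k"),
                  "value": ("qkv_proj", "v")}

def rename(name: str, prefix: str = "bert."):
    # Strip the base-model prefix, apply the plain renames, then fold the
    # separate q/k/v projections into the fused qkv_proj parameter.
    name = name[len(prefix):] if name.startswith(prefix) else name
    for src, dst in params_mapping.items():
        name = name.replace(src, dst)
    for src, (dst, shard_id) in stacked_params.items():
        if src in name:
            return name.replace(src, dst), shard_id
    return name, None

print(rename("bert.embeddings.LayerNorm.weight"))
# ('embeddings.layernorm.weight', None)
print(rename("bert.encoder.layer.0.attention.self.query.weight"))
# ('encoder.layer.0.attention.self.qkv_proj.weight', 'q')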