diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 445b30b8c6e9..76ccb3dfe0a6 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -11,6 +11,7 @@ class PoolingType(IntEnum): """Enumeration for different types of pooling methods.""" LAST = 0 + ALL = 1 class Pooler(nn.Module): @@ -43,6 +44,12 @@ def forward( if self.pooling_type == PoolingType.LAST: last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 pooled_data = hidden_states[last_token_flat_indices] + elif self.pooling_type == PoolingType.ALL: + offset = 0 + pooled_data = [] + for prompt_len in prompt_lens: + pooled_data.append(hidden_states[offset:offset + prompt_len]) + offset += prompt_len else: raise ValueError(f"Invalid pooling type: {self.pooling_type}") diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 3a6fa9e26ff4..682a2e71a1db 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -74,6 +74,7 @@ _EMBEDDING_MODELS = { "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), + "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), } _MULTIMODAL_MODELS = { diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py new file mode 100644 index 000000000000..51cef5c47c4d --- /dev/null +++ b/vllm/model_executor/models/qwen2_rm.py @@ -0,0 +1,162 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +"""Inference-only Qwen2-RM model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import Qwen2Config + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.qwen2 import Qwen2Model +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + +from .utils import is_pp_missing_parameter + + +class ReLU(nn.Module): + + def __init__(self): + super().__init__() + self.activation = nn.ReLU() + + def forward(self, input): + input, _ = input + return self.activation(input) + + +class Qwen2ForRewardModel(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + raise ValueError("Sliding window for some but all layers is not " + "supported. This model uses sliding window " + "but `max_window_layers` = %s is less than " + "`num_hidden_layers` = %s. Please open an issue " + "to discuss this feature." % ( + config.max_window_layers, + config.num_hidden_layers, + )) + + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2Model(config, cache_config, quant_config) + + self.score = nn.Sequential( + ColumnParallelLinear(config.hidden_size, + config.hidden_size, + quant_config=quant_config), + ReLU(), + RowParallelLinear(config.hidden_size, 1, + quant_config=quant_config), + ) + self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + logits, _ = self.score(hidden_states) + return logits + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + # Skip loading lm_head for embedding model + if name == "lm_head.weight": + continue + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight)