diff --git a/python/pyproject.toml b/python/pyproject.toml index 800ce0837e..f68a152980 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -26,7 +26,8 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu "outlines>=0.0.44", "modelscope"] # xpu is not enabled in public vllm and torch whl, # need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm -srt = ["sglang[runtime_common]", "torch", "vllm==0.5.5"] + +srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"] srt_xpu = ["sglang[runtime_common]"] openai = ["openai>=1.0", "tiktoken"] diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index fa300b25f0..d7602964d4 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -133,6 +133,22 @@ def get_chat_template_by_model_path(model_path): ) ) +# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example +register_chat_template( + ChatTemplate( + name="qwen2-vl", + default_system_prompt="You are a helpful assistant.", + role_prefix_and_suffix={ + "system": ("<|im_start|>system\n", "<|im_end|>\n"), + "user": ("<|im_start|>user\n", "<|im_end|>\n"), + "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"), + }, + style=ChatTemplateStyle.PLAIN, + stop_str=("<|im_end|>"), + image_token="<|vision_start|><|image_pad|><|vision_end|>", + ) +) + register_chat_template( ChatTemplate( diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index 9e74366709..600b58e493 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -1,5 +1,8 @@ from sglang.srt.configs.exaone import ExaoneConfig +from sglang.srt.configs.qwen2vl import Qwen2VLConfig, Qwen2VLVisionConfig __all__ = [ "ExaoneConfig", + "Qwen2VLConfig", + "Qwen2VLVisionConfig", ] diff --git a/python/sglang/srt/configs/qwen2vl.py b/python/sglang/srt/configs/qwen2vl.py new file mode 100644 index 0000000000..4d30c741e9 --- /dev/null +++ b/python/sglang/srt/configs/qwen2vl.py @@ -0,0 +1,133 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
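The "qwen2-vl" template registered in chat_template.py above follows the ChatML layout used by the Qwen2 family, with a single <|vision_start|><|image_pad|><|vision_end|> placeholder per image. Below is a minimal standalone sketch of the prompt it is expected to produce; the function name is made up for illustration and this is not the sglang rendering code itself.

```python
# Illustrative only: mirrors the role_prefix_and_suffix / image_token fields of
# the "qwen2-vl" ChatTemplate registered above; not part of the sglang API.
IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"


def render_qwen2_vl_prompt(user_text: str, num_images: int = 1) -> str:
    system = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    # One placeholder per image; the runtime later expands it to the actual
    # number of visual tokens for that image.
    user = (
        "<|im_start|>user\n"
        + IMAGE_TOKEN * num_images
        + user_text
        + "<|im_end|>\n"
    )
    # Generation continues from the assistant prefix and stops at <|im_end|>.
    return system + user + "<|im_start|>assistant\n"


if __name__ == "__main__":
    print(render_qwen2_vl_prompt("Describe this image."))
```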
+"""Qwen2VL model configuration""" + +import os +from typing import Union + +from transformers import PretrainedConfig + + +class Qwen2VLVisionConfig(PretrainedConfig): + model_type = "qwen2_vl" + + def __init__( + self, + depth=32, + embed_dim=1280, + hidden_size=3584, + hidden_act="quick_gelu", + mlp_ratio=4, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + + if config_dict.get("model_type") == "qwen2_vl": + config_dict = config_dict["vision_config"] + + return cls.from_dict(config_dict, **kwargs) + + +class Qwen2VLConfig(PretrainedConfig): + model_type = "qwen2_vl" + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = Qwen2VLVisionConfig(**vision_config) + elif vision_config is None: + self.vision_config = Qwen2VLVisionConfig() + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # NOTE: the following section from original transformers config + # for Qwen2-VL is commented out to address rope config loading issue + # + # if self.rope_scaling is not None and "type" in self.rope_scaling: + # if self.rope_scaling["type"] == "mrope": + # self.rope_scaling["type"] = "default" + # self.rope_scaling["rope_type"] = self.rope_scaling["type"] + # rope_config_validation(self) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index bff58fc14d..73bbc1e2ee 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -530,3 +530,17 @@ def generate_chat_conv( stop_str=["<|im_end|>", "<|action_end|>"], ) ) + +# Reference: 
https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example +register_conv_template( + Conversation( + name="qwen2-vl", + system_message="You are a helpful assistant.", + system_template="<|im_start|>system\n{system_message}", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep="<|im_end|>\n", + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + stop_str=["<|im_end|>"], + image_token="<|vision_start|><|image_pad|><|vision_end|>", + ) +) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index eb3025421a..0f9f94dcac 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -33,12 +33,13 @@ try: from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig - from sglang.srt.configs import ExaoneConfig + from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { ChatGLMConfig.model_type: ChatGLMConfig, DbrxConfig.model_type: DbrxConfig, ExaoneConfig.model_type: ExaoneConfig, + Qwen2VLConfig.model_type: Qwen2VLConfig, } except ImportError: # We want this file to run without vllm dependency diff --git a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py index 79ad35e447..c90aac1cc4 100644 --- a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py @@ -50,6 +50,7 @@ def _fwd_kernel( BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, Lk: tl.constexpr, ): cur_batch = tl.program_id(0) @@ -78,7 +79,9 @@ def _fwd_kernel( mask_d = offs_d < Lk q = tl.load( - Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d), other=0.0 + Q + off_q, + mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]), + other=0.0, ) k_ptrs = K + off_k @@ -91,7 +94,12 @@ def _fwd_kernel( block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + end_n = ( + cur_batch_seq_len + if not IS_CAUSAL + else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len) + ) + for start_n in range(0, block_mask * end_n, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- k = tl.load( @@ -104,7 +112,18 @@ def _fwd_kernel( qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + if IS_CAUSAL: + qk += tl.where( + (start_n + offs_n[None, :] < cur_batch_seq_len) + & (offs_m[:, None] >= (start_n + offs_n[None, :])), + 0, + float("-inf"), + ) + else: + qk += tl.where( + (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf") + ) # -- compute m_ij, p, l_ij m_ij = tl.max(qk, 1) @@ -146,7 +165,9 @@ def _fwd_kernel( ) -def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): +def context_attention_fwd( + q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True +): if is_cuda_available and CUDA_CAPABILITY[0] >= 8: BLOCK = 128 else: @@ -181,6 +202,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): BLOCK_M=BLOCK, BLOCK_DMODEL=triton.next_power_of_2(Lk), BLOCK_N=BLOCK, + IS_CAUSAL=is_causal, num_warps=num_warps, num_stages=1, Lk=Lk, diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py new file mode 
100644 index 0000000000..9af9285feb --- /dev/null +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -0,0 +1,145 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""MRotaryEmbedding""" +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch + + +class MRotaryEmbedding: + """Rotary Embedding with Multimodal Sections.""" + + @staticmethod + def get_input_positions( + input_tokens: List[int], + image_grid_thw: Union[List[List[int]], torch.Tensor], + video_grid_thw: Union[List[List[int]], torch.Tensor], + image_token_id: int, + video_token_id: int, + vision_start_token_id: int, + vision_end_token_id: int, + spatial_merge_size: int, + context_len: int = 0, + extend_prefix_len: int = 0, + ) -> Tuple[List[List[int]], int]: + """Get mrope input positions and delta value.""" + + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + if isinstance(video_grid_thw, torch.Tensor): + video_grid_thw = video_grid_thw.tolist() + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id + ).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = 
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:] + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions += extend_prefix_len + + return llm_positions.tolist(), mrope_position_delta + + @staticmethod + def get_next_input_positions( + mrope_position_delta: int, + context_len: int, + seq_len: int, + ) -> List[List[int]]: + return [ + list( + range( + context_len + mrope_position_delta, seq_len + mrope_position_delta + ) + ) + for _ in range(3) + ] diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py index e1e54af7fc..b958ab89bc 100644 --- a/python/sglang/srt/managers/image_processor.py +++ b/python/sglang/srt/managers/image_processor.py @@ -177,10 +177,127 @@ async def process_images_async( } +class Qwen2VLImageProcessor(BaseImageProcessor): + def __init__(self, hf_config, server_args, _image_processor): + self.hf_config = hf_config + self._image_processor = _image_processor + self.executor = concurrent.futures.ProcessPoolExecutor( + initializer=init_global_processor, + mp_context=mp.get_context("fork"), + initargs=(server_args,), + max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()), + ) + + @staticmethod + def _process_single_image_task( + image_data: Union[str, bytes], + image_processor=None, + ): + image_processor = image_processor or global_processor.image_processor + + try: + image, image_size = load_image(image_data) + if image_size is not None: + # It is a video with multiple images + image_hash = hash(image_data) + process_result = image_processor(image) + pixel_values, image_grid_thws = ( + process_result["pixel_values"], + process_result["image_grid_thw"][0], + ) + for _ in range(len(pixel_values)): + pixel_values[_] = pixel_values[_].astype(np.float16) + pixel_values = np.stack(pixel_values, axis=0) + image_grid_thws = np.stack(image_grid_thws, axis=0) + return pixel_values, image_hash, image_size, image_grid_thws + else: + # It is an image + image_hash = hash(image_data) + process_result = image_processor(image) + pixel_values, image_grid_thws = ( + process_result["pixel_values"], + process_result["image_grid_thw"][0], + ) + if isinstance(pixel_values, np.ndarray): + pixel_values = pixel_values.astype(np.float16) + + return pixel_values, image_hash, image.size, image_grid_thws + except Exception: + logger.error("Exception in TokenizerManager:\n" + get_exception_traceback()) + + async def _process_single_image(self, image_data: Union[bytes, str]): + if self.executor is not None: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self.executor, + Qwen2VLImageProcessor._process_single_image_task, + image_data, + ) + else: + return self._process_single_image_task(image_data) + + async def process_images_async( + self, image_data: List[Union[str, bytes]], request_obj + ): + if not image_data: + return None + + if isinstance(image_data, list) and len(image_data) > 0: + # Multiple images + if len(image_data) > 1: + pixel_values, image_hashes, image_sizes, image_grid_thws = ( + [], + [], + [], + [], + ) + res = [] + for img_data in image_data: + res.append(self._process_single_image(img_data)) + res = await asyncio.gather(*res) + for pixel_v, image_h, image_s, image_thw in res: + 
pixel_values.append(pixel_v) + image_hashes.append(image_h) + image_sizes.append(image_s) + image_grid_thws.append(image_thw) + + if isinstance(pixel_values[0], np.ndarray): + pixel_values = np.concatenate(pixel_values, axis=0) + else: + # A single image + pixel_values, image_hash, image_size, image_grid_thw = ( + await self._process_single_image(image_data[0]) + ) + image_hashes = [image_hash] + image_sizes = [image_size] + image_grid_thws = [image_grid_thw] + elif isinstance(image_data, str): + # A single image + pixel_values, image_hash, image_size, image_grid_thw = ( + await self._process_single_image(image_data) + ) + image_hashes = [image_hash] + image_sizes = [image_size] + image_grid_thws = [image_grid_thw] + else: + raise ValueError(f"Invalid image data: {image_data}") + + return { + "pixel_values": pixel_values, + "image_hashes": image_hashes, + "image_sizes": image_sizes, + "modalities": request_obj.modalities, + "image_grid_thws": image_grid_thws, + } + + def get_image_processor( hf_config, server_args: ServerArgs, _image_processor ) -> BaseImageProcessor: - return LlavaImageProcessor(hf_config, server_args, _image_processor) + if "Qwen2VLForConditionalGeneration" in hf_config.architectures: + return Qwen2VLImageProcessor(hf_config, server_args, _image_processor) + else: + return LlavaImageProcessor(hf_config, server_args, _image_processor) def get_dummy_image_processor(): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 0eeb3359e8..742ac39768 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -127,6 +127,8 @@ class ImageInputs: image_embeds: Optional[List[torch.Tensor]] = None aspect_ratio_ids: Optional[List[torch.Tensor]] = None aspect_ratio_mask: Optional[List[torch.Tensor]] = None + # QWen2-VL related + image_grid_thws: List[Tuple[int, int, int]] = None @staticmethod def from_dict(obj, vocab_size): @@ -134,6 +136,7 @@ def from_dict(obj, vocab_size): ret = ImageInputs( pixel_values=obj["pixel_values"], image_hash=hash(tuple(obj["image_hashes"])), + image_grid_thws=obj.get("image_grid_thws"), ) image_hash = ret.image_hash ret.pad_values = [ @@ -235,6 +238,9 @@ def __init__( self.regex_fsm_state: int = 0 self.jump_forward_map: JumpForwardMap = None + # For Qwen2-VL + self.mrope_position_delta = [] # use mutable object + # whether request reached finished condition def finished(self) -> bool: return self.finished_reason is not None @@ -848,6 +854,8 @@ def get_model_worker_batch(self): global bid bid += 1 + mrope_positions_delta = [req.mrope_position_delta for req in self.reqs] + return ModelWorkerBatch( bid=bid, forward_mode=self.forward_mode, @@ -863,6 +871,7 @@ def get_model_worker_batch(self): image_inputs=image_inputs, lora_paths=lora_paths, sampling_info=self.sampling_info, + mrope_positions_delta=mrope_positions_delta, ) def copy(self): @@ -917,6 +926,9 @@ class ModelWorkerBatch: # Sampling info sampling_info: SamplingBatchInfo + # For Qwen2-VL + mrope_positions_delta: List[List[int]] + def copy(self): return ModelWorkerBatch( bid=self.bid, @@ -933,4 +945,5 @@ def copy(self): image_inputs=self.image_inputs, lora_paths=self.lora_paths, sampling_info=self.sampling_info.copy(), + mrope_positions_delta=self.mrope_positions_delta, ) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index be6a72afd2..b0a1f5fba8 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ 
b/python/sglang/srt/model_executor/forward_batch_info.py @@ -36,6 +36,8 @@ import numpy as np import torch +from sglang.srt.layers.rotary_embedding import MRotaryEmbedding + if TYPE_CHECKING: from sglang.srt.layers.attention import AttentionBackend from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch @@ -112,14 +114,88 @@ class ForwardBatch: token_to_kv_pool: BaseTokenToKVPool = None attn_backend: AttentionBackend = None + # For Qwen2-VL + mrope_positions: torch.Tensor = None + + def compute_mrope_positions( + self, model_runner: ModelRunner, batch: ModelWorkerBatch + ): + device = model_runner.device + hf_config = model_runner.model_config.hf_config + mrope_positions_list = [None] * self.seq_lens.shape[0] + if self.forward_mode.is_decode(): + for i, _ in enumerate(mrope_positions_list): + mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions( + batch.mrope_positions_delta[i][0], + int(self.seq_lens[i]) - 1, + int(self.seq_lens[i]), + ) + elif self.forward_mode.is_extend(): + for i, image_inputs in enumerate(batch.image_inputs): + if image_inputs is None: + # text only + mrope_positions = [[i for i in range(self.seq_lens[i])]] * 3 + mrope_position_delta = 0 + else: + extend_start_loc, extend_seq_len, extend_prefix_len = ( + self.extend_start_loc[i], + self.extend_seq_lens[i], + self.extend_prefix_lens[i], + ) + mrope_positions, mrope_position_delta = ( + MRotaryEmbedding.get_input_positions( + input_tokens=self.input_ids[ + extend_start_loc : extend_start_loc + extend_seq_len + ].tolist(), + image_grid_thw=image_inputs.image_grid_thws, + video_grid_thw=None, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + spatial_merge_size=hf_config.vision_config.spatial_merge_size, + context_len=0, + extend_prefix_len=extend_prefix_len.item(), + ) + ) + mrope_positions_list[i] = mrope_positions + batch.mrope_positions_delta[i].append(mrope_position_delta) + + self.mrope_positions = torch.tensor( + np.concatenate( + [np.array(pos) for pos in mrope_positions_list], + axis=1, + ), + device=device, + ) + self.mrope_positions = self.mrope_positions.to(torch.int64) + + def compute_positions(self, model_runner: ModelRunner, batch: ModelWorkerBatch): + device = model_runner.device + if self.forward_mode.is_decode(): + self.positions = (self.seq_lens - 1).to(torch.int64) + else: + self.positions = torch.tensor( + np.concatenate( + [ + np.arange(prefix_len, prefix_len + extend_len) + for prefix_len, extend_len in zip( + batch.extend_prefix_lens, batch.extend_seq_lens + ) + ], + axis=0, + ), + device=device, + ).to(torch.int64) + @classmethod def init_new( cls, batch: ModelWorkerBatch, model_runner: ModelRunner, ): - device = model_runner.device + device = model_runner.device ret = cls( forward_mode=batch.forward_mode, batch_size=len(batch.seq_lens), @@ -133,23 +209,7 @@ def init_new( sampling_info=batch.sampling_info, ) - # Init position information - if ret.forward_mode.is_decode(): - ret.positions = (ret.seq_lens - 1).to(torch.int64) - else: - ret.positions = torch.tensor( - np.concatenate( - [ - np.arange(prefix_len, prefix_len + extend_len) - for prefix_len, extend_len in zip( - batch.extend_prefix_lens, batch.extend_seq_lens - ) - ], - axis=0, - ), - device=device, - ).to(torch.int64) - + if not batch.forward_mode.is_decode(): ret.image_inputs = batch.image_inputs ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, 
device=device) ret.extend_prefix_lens = torch.tensor( @@ -160,6 +220,13 @@ def init_new( ret.extend_seq_lens_cpu = batch.extend_seq_lens ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens + # Init position information + is_mrope = model_runner.model_is_mrope + if is_mrope: + ret.compute_mrope_positions(model_runner, batch) + else: + ret.compute_positions(model_runner, batch) + # Init attention information ret.req_to_token_pool = model_runner.req_to_token_pool ret.token_to_kv_pool = model_runner.token_to_kv_pool diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d583dcd34f..30c1ffc2c8 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -122,6 +122,11 @@ def __init__( ) server_args.chunked_prefill_size = None server_args.mem_fraction_static *= 0.95 + # TODO: qwen2-vl does not support cuda graph now, set disable-graph=True automatically + if self.model_config.hf_config.architectures == [ + "Qwen2VLForConditionalGeneration" + ]: + server_args.disable_cuda_graph = True # Global vars if server_args.show_time_cost: @@ -613,6 +618,15 @@ def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchIn return logits + @property + def model_is_mrope(self) -> bool: + """Detect if the model has "mrope" rope_scaling type. + mrope requires keep "rope_deltas" between prompt and decoding phases.""" + rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {}) + if rope_scaling is None: + return False + return rope_scaling.get("type", None) == "mrope" + @lru_cache() def import_model_classes(): diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py new file mode 100644 index 0000000000..ae2d4f58c6 --- /dev/null +++ b/python/sglang/srt/models/qwen2_vl.py @@ -0,0 +1,720 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" +from functools import lru_cache, partial +from typing import Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from vllm.config import CacheConfig, MultiModalConfig +from vllm.distributed import parallel_state +from vllm.distributed import utils as dist_utils +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import QuickGELU +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsMultiModal + +from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig +from sglang.srt.hf_transformers_utils import get_processor +from sglang.srt.layers.attention.triton_ops.prefill_attention import ( + context_attention_fwd, +) +from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.schedule_batch import ImageInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.qwen2 import Qwen2Model + +logger = init_logger(__name__) + +# === Vision Inputs === # + + +class Qwen2VLImageInputs(TypedDict): + pixel_values: torch.Tensor + """Shape: + `(num_patches, num_channels * patch_size * patch_size)` + """ + + image_grid_thw: torch.Tensor + """Shape: `(num_images, 3)` + + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +class Qwen2VLVideoInputs(TypedDict): + pixel_values_videos: torch.Tensor + """Shape: + `(num_patches, + num_channels * temporal_patch_size * patch_size * patch_size)` + """ + + video_grid_thw: torch.Tensor + """Shape: `(num_videos, 3)` + + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +# === Vision Encoder === # + + +class Qwen2VisionMLP(nn.Module): + + def __init__( + self, + in_features: int, + hidden_features: int = None, + act_layer: Type[nn.Module] = QuickGELU, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.fc1 = ColumnParallelLinear( + in_features, hidden_features, quant_config=quant_config + ) + self.act = act_layer() + self.fc2 = RowParallelLinear( + hidden_features, in_features, quant_config=quant_config + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_parallel, _ = self.fc1(x) + x_parallel = self.act(x_parallel) + x, _ = self.fc2(x_parallel) + return x + + +def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) + + +def apply_rotary_emb_torch( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False +) -> torch.Tensor: + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) + sin = repeat( + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)" + ) + return torch.cat( + [ + x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], + ], + dim=-1, + ) + + +def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + output = apply_rotary_emb_torch(t_, cos, sin).type_as(t) + return output + + +class Qwen2VisionAttention(nn.Module): + + def __init__( + self, + embed_dim: Optional[int] = None, + num_heads: Optional[int] = None, + projection_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + # Per attention head and per partition values. + world_size = parallel_state.get_tensor_model_parallel_world_size() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads + ) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, world_size + ) + + self.qkv = ColumnParallelLinear( + input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + ) + self.proj = RowParallelLinear( + input_size=projection_size, output_size=embed_dim, quant_config=quant_config + ) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor = None, + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + + # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim] + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + x = x.view(*new_x_shape) + + # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim] + q, k, v = dist_utils.split_tensor_along_last_dim(x, 3) + batch_size = q.shape[1] + + q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)] + if rotary_pos_emb is not None: + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = (seq_lens).max().item() + q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] + + output = torch.empty_like(q) + context_attention_fwd( + q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False + ) + + context_layer = rearrange(output, "(b s) ... 
-> b s ...", b=batch_size) + context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous() + + output, _ = self.proj(context_layer) + return output + + +class Qwen2VisionBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float, + act_layer: Type[nn.Module] = QuickGELU, + norm_layer: Type[nn.Module] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + self.attn = Qwen2VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + ) + self.mlp = Qwen2VisionMLP( + dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config + ) + + def forward( + self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor + ) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + ) + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen2VisionPatchEmbed(nn.Module): + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_chans: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d( + in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) + x = self.proj(x).view(L, self.embed_dim) + return x + + +class Qwen2VisionPatchMerger(nn.Module): + + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: Type[nn.Module] = None, + spatial_merge_size: int = 2, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.ln_q = norm_layer(context_dim) + self.mlp = nn.ModuleList( + [ + ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + ), + nn.GELU(), + RowParallelLinear( + self.hidden_size, d_model, bias=True, quant_config=quant_config + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln_q(x) + x = x.view(-1, self.hidden_size) + + mlp_fc1, mlp_act, mlp_fc2 = self.mlp + x_parallel, _ = mlp_fc1(x) + x_parallel = mlp_act(x_parallel) + out, _ = mlp_fc2(x_parallel) + return out + + +class Qwen2VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._freqs_cached = None + + def update_freqs_cache(self, seqlen: int) -> None: + if seqlen > self._seq_len_cached: + seqlen *= 2 + self._seq_len_cached = seqlen + self.inv_freq = 1.0 / ( + self.theta + ** ( + torch.arange( + 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device + ) + / self.dim + ) + ) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + freqs = torch.outer(seq, self.inv_freq) 
+ self._freqs_cached = freqs + + def forward(self, seqlen: int) -> torch.Tensor: + self.update_freqs_cache(seqlen) + return self._freqs_cached[:seqlen] + + +class Qwen2VisionTransformer(nn.Module): + + def __init__( + self, + vision_config: Qwen2VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + patch_size: int = vision_config.patch_size + temporal_patch_size: int = vision_config.temporal_patch_size + spatial_merge_size: int = vision_config.spatial_merge_size + in_chans: int = vision_config.in_chans + hidden_size: int = vision_config.hidden_size + embed_dim: int = vision_config.embed_dim + depth: int = vision_config.depth + num_heads: int = vision_config.num_heads + mlp_ratio: float = vision_config.mlp_ratio + + self.spatial_merge_size = spatial_merge_size + + self.patch_embed = Qwen2VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = embed_dim // num_heads + self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [ + Qwen2VisionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + ) + for _ in range(depth) + ] + ) + self.merger = Qwen2VisionPatchMerger( + d_model=hidden_size, + context_dim=embed_dim, + norm_layer=norm_layer, + quant_config=quant_config, + ) + + @property + def dtype(self) -> torch.dtype: + return self.blocks[0].mlp.fc2.weight.dtype + + @property + def device(self) -> torch.device: + return self.blocks[0].mlp.fc2.weight.device + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = ( + hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + wpos_ids = ( + wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + # patchify + x = x.to(device=self.device, dtype=self.dtype) + x = self.patch_embed(x) + + # compute position embedding + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + # compute cu_seqlens + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + + # transformers + x = x.unsqueeze(1) + for blk in self.blocks: + x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) + + # adapter + x = self.merger(x) + return x + + +cached_get_processor = lru_cache(get_processor) + + +class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): + def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]): + processor = cached_get_processor(self.config._name_or_path) + grid_t, grid_h, 
grid_w = image_grid_thw + num_image_tokens = ( + grid_t + * grid_h + * grid_w + // processor.image_processor.merge_size + // processor.image_processor.merge_size + ) + return num_image_tokens + + # Use grid_t * grid_w * grid_h to pad tokens for each image + # and replaced padding by unique image hash + def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): + image_grid_thws = image_inputs.image_grid_thws + pad_values = image_inputs.pad_values + + image_indices = [ + idx + for idx, token in enumerate(input_ids) + if token == self.config.image_token_id + ] + image_inputs.image_offsets = [] + + input_ids_with_image = [] + for image_cnt, _ in enumerate(image_grid_thws): + num_image_tokens = self.calculate_num_image_tokens( + image_grid_thws[image_cnt] + ) + if image_cnt == 0: + non_image_tokens = input_ids[: image_indices[image_cnt]] + else: + non_image_tokens = input_ids[ + image_indices[image_cnt - 1] + 1 : image_indices[image_cnt] + ] + input_ids_with_image.extend(non_image_tokens) + image_inputs.image_offsets.append(len(input_ids_with_image)) + pad_ids = pad_values * ( + (num_image_tokens + len(pad_values)) // len(pad_values) + ) + input_ids_with_image.extend(pad_ids[:num_image_tokens]) + input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :]) + + return input_ids_with_image + + def __init__( + self, + config: Qwen2VLConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + self.visual = Qwen2VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + # NOTE: Qwen2-VL vision encoder does not support any + # quantization method now. + quant_config=None, + ) + + self.model = Qwen2Model(config, quant_config) + + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + + self.logits_processor = LogitsProcessor(config) + + def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"]) + return image_embeds + + def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor: + pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype) + video_embeds = self.visual( + pixel_values_videos, grid_thw=video_input["video_grid_thw"] + ) + return video_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ): + """Run forward pass for Qwen2-VL. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for Qwen2-VL + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). + (Use input_metadata.mrope_positions to replace it) + pixel_values: Pixel values to be fed to a model. + `None` if no images are passed. + image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. + `None` if no images are passed. 
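+
+        The image embeddings produced by the vision tower are written back
+        into `inputs_embeds` at the offsets recorded by `pad_input_ids`;
+        images whose placeholder starts inside the cached prefix are skipped.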
+ """ + image_inputs = None + if forward_batch.image_inputs is not None: + image_inputs = [ + img for img in forward_batch.image_inputs if img is not None + ] + + positions = forward_batch.mrope_positions + if image_inputs is None or len(image_inputs) == 0: + inputs_embeds = self.model.embed_tokens(input_ids) + else: + if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}" + ) + + inputs_embeds = self.model.embed_tokens(input_ids) + extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() + prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy() + for i, image in enumerate(forward_batch.image_inputs): + if image == None: + continue + start_idx = extend_start_loc_cpu[i] + prefix_len = prefix_lens_cpu[i] + + pixel_values = torch.tensor(image.pixel_values, device="cuda") + image_grid_thws = torch.tensor( + np.array(image.image_grid_thws), device="cuda" + ) + image_offsets = image.image_offsets + image_input = Qwen2VLImageInputs( + pixel_values=pixel_values, image_grid_thw=image_grid_thws + ) + image_embeds = self._process_image_input(image_input) + + image_embeds_offset = 0 + for idx, image_offset in enumerate(image_offsets): + if image_offset < prefix_len: + continue + num_image_tokens = self.calculate_num_image_tokens( + image_grid_thws[idx] + ) + left_idx = start_idx + (image_offset - prefix_len) + right_idx = ( + start_idx + (image_offset - prefix_len) + num_image_tokens + ) + inputs_embeds[left_idx:right_idx] = image_embeds[ + image_embeds_offset : image_embeds_offset + num_image_tokens + ] + image_embeds_offset += num_image_tokens + + input_ids = None + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + forward_batch=forward_batch, + input_embeds=inputs_embeds, + ) + return self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, forward_batch + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "up_proj", 1), + ("gate_up_proj", "gate_proj", 0), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "visual" in name and "qkv.weight" in name: + visual_num_heads = self.config.vision_config.num_heads + visual_embed_dim = self.config.vision_config.embed_dim + head_size = visual_embed_dim // visual_num_heads + loaded_weight = loaded_weight.view( + 3, visual_num_heads, head_size, visual_embed_dim + ) + loaded_weight = loaded_weight.transpose(0, 1) + loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) + elif "visual" in name and "qkv.bias" in name: + visual_num_heads = self.config.vision_config.num_heads + visual_embed_dim = self.config.vision_config.embed_dim + head_size = visual_embed_dim // visual_num_heads + loaded_weight = loaded_weight.view(3, visual_num_heads, head_size) + loaded_weight = loaded_weight.transpose(0, 1) + loaded_weight = loaded_weight.reshape(-1) + try: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + except KeyError: + print(params_dict.keys()) + raise + + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = Qwen2VLForConditionalGeneration diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 11ce25940d..d72e375a0b 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -209,6 +209,7 @@ def is_multimodal_model(model_architectures): or "LlavaQwenForCausalLM" in model_architectures or "LlavaMistralForCausalLM" in model_architectures or "LlavaVidForCausalLM" in model_architectures + or "Qwen2VLForConditionalGeneration" in model_architectures ): return True else: diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index 727f5774ca..3b142ff6a2 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -289,5 +289,24 @@ def test_regex(self): assert isinstance(js_obj["number_of_cars"], int) +class TestQWen2VLServer(TestOpenAIVisionServer): + @classmethod + def setUpClass(cls): + cls.model = "Qwen/Qwen2-VL-7B-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=[ + "--chat-template", + "qwen2-vl", + ], + ) + cls.base_url += "/v1" + + if __name__ == "__main__": unittest.main()