Skip to content

Commit

Permalink
Support qwen2 vl model (#1546)
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhang2077 authored Oct 18, 2024
1 parent dd3809f commit 92376c0
Show file tree
Hide file tree
Showing 15 changed files with 1,311 additions and 25 deletions.
3 changes: 2 additions & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
"outlines>=0.0.44", "modelscope"]
# xpu is not enabled in public vllm and torch whl,
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm
srt = ["sglang[runtime_common]", "torch", "vllm==0.5.5"]

srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
srt_xpu = ["sglang[runtime_common]"]

openai = ["openai>=1.0", "tiktoken"]
Expand Down
16 changes: 16 additions & 0 deletions python/sglang/lang/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,22 @@ def get_chat_template_by_model_path(model_path):
)
)

# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
register_chat_template(
ChatTemplate(
name="qwen2-vl",
default_system_prompt="You are a helpful assistant.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>"),
image_token="<|vision_start|><|image_pad|><|vision_end|>",
)
)


register_chat_template(
ChatTemplate(
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/configs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.qwen2vl import Qwen2VLConfig, Qwen2VLVisionConfig

__all__ = [
"ExaoneConfig",
"Qwen2VLConfig",
"Qwen2VLVisionConfig",
]
133 changes: 133 additions & 0 deletions python/sglang/srt/configs/qwen2vl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2VL model configuration"""

import os
from typing import Union

from transformers import PretrainedConfig


class Qwen2VLVisionConfig(PretrainedConfig):
model_type = "qwen2_vl"

def __init__(
self,
depth=32,
embed_dim=1280,
hidden_size=3584,
hidden_act="quick_gelu",
mlp_ratio=4,
num_heads=16,
in_channels=3,
patch_size=14,
spatial_merge_size=2,
temporal_patch_size=2,
**kwargs,
):
super().__init__(**kwargs)

self.depth = depth
self.embed_dim = embed_dim
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size

@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)

config_dict, kwargs = cls.get_config_dict(
pretrained_model_name_or_path, **kwargs
)

if config_dict.get("model_type") == "qwen2_vl":
config_dict = config_dict["vision_config"]

return cls.from_dict(config_dict, **kwargs)


class Qwen2VLConfig(PretrainedConfig):
model_type = "qwen2_vl"

def __init__(
self,
vocab_size=152064,
hidden_size=8192,
intermediate_size=29568,
num_hidden_layers=80,
num_attention_heads=64,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
tie_word_embeddings=False,
rope_theta=1000000.0,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=80,
attention_dropout=0.0,
vision_config=None,
rope_scaling=None,
**kwargs,
):
if isinstance(vision_config, dict):
self.vision_config = Qwen2VLVisionConfig(**vision_config)
elif vision_config is None:
self.vision_config = Qwen2VLVisionConfig()

self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers

# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.rope_scaling = rope_scaling

# NOTE: the following section from original transformers config
# for Qwen2-VL is commented out to address rope config loading issue
#
# if self.rope_scaling is not None and "type" in self.rope_scaling:
# if self.rope_scaling["type"] == "mrope":
# self.rope_scaling["type"] = "default"
# self.rope_scaling["rope_type"] = self.rope_scaling["type"]
# rope_config_validation(self)

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
14 changes: 14 additions & 0 deletions python/sglang/srt/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,3 +530,17 @@ def generate_chat_conv(
stop_str=["<|im_end|>", "<|action_end|>"],
)
)

# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
register_conv_template(
Conversation(
name="qwen2-vl",
system_message="You are a helpful assistant.",
system_template="<|im_start|>system\n{system_message}",
roles=("<|im_start|>user", "<|im_start|>assistant"),
sep="<|im_end|>\n",
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
stop_str=["<|im_end|>"],
image_token="<|vision_start|><|image_pad|><|vision_end|>",
)
)
3 changes: 2 additions & 1 deletion python/sglang/srt/hf_transformers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@
try:
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig

from sglang.srt.configs import ExaoneConfig
from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig

_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ChatGLMConfig.model_type: ChatGLMConfig,
DbrxConfig.model_type: DbrxConfig,
ExaoneConfig.model_type: ExaoneConfig,
Qwen2VLConfig.model_type: Qwen2VLConfig,
}
except ImportError:
# We want this file to run without vllm dependency
Expand Down
30 changes: 26 additions & 4 deletions python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def _fwd_kernel(
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
Lk: tl.constexpr,
):
cur_batch = tl.program_id(0)
Expand Down Expand Up @@ -78,7 +79,9 @@ def _fwd_kernel(
mask_d = offs_d < Lk

q = tl.load(
Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d), other=0.0
Q + off_q,
mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
other=0.0,
)

k_ptrs = K + off_k
Expand All @@ -91,7 +94,12 @@ def _fwd_kernel(

block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
end_n = (
cur_batch_seq_len
if not IS_CAUSAL
else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
)
for start_n in range(0, block_mask * end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
Expand All @@ -104,7 +112,18 @@ def _fwd_kernel(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))

if IS_CAUSAL:
qk += tl.where(
(start_n + offs_n[None, :] < cur_batch_seq_len)
& (offs_m[:, None] >= (start_n + offs_n[None, :])),
0,
float("-inf"),
)
else:
qk += tl.where(
(start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
)

# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
Expand Down Expand Up @@ -146,7 +165,9 @@ def _fwd_kernel(
)


def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
def context_attention_fwd(
q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
):
if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
BLOCK = 128
else:
Expand Down Expand Up @@ -181,6 +202,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
BLOCK_M=BLOCK,
BLOCK_DMODEL=triton.next_power_of_2(Lk),
BLOCK_N=BLOCK,
IS_CAUSAL=is_causal,
num_warps=num_warps,
num_stages=1,
Lk=Lk,
Expand Down
Loading

0 comments on commit 92376c0

Please sign in to comment.