Skip to content

Commit

Permalink
[Hardware][CPU] Enable mrope and support Qwen2-VL on CPU backend (vll…
Browse files Browse the repository at this point in the history
  • Loading branch information
Isotr0py authored and dtrifiro committed Sep 27, 2024
1 parent b19b0ea commit 3fbd98e
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 9 deletions.
16 changes: 16 additions & 0 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import get_processor
from vllm.utils import is_cpu

from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory)
Expand Down Expand Up @@ -281,6 +282,21 @@ def forward(
context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
elif is_cpu():
seq_length = q.size(1)
q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
attention_mask = torch.zeros([1, seq_length, seq_length],
device=q.device,
dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
cu_seqlens[i - 1]:cu_seqlens[i]] = True
output = F.scaled_dot_product_attention(q,
k,
v,
attention_mask,
dropout_p=0.0)
context_layer = rearrange(output, "b h s d -> b s h d ")
else:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
Expand Down
92 changes: 83 additions & 9 deletions vllm/worker/cpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
Expand Down Expand Up @@ -145,6 +147,38 @@ def build(self) -> ModelInputForCPU:
query_lens=seq_lens,
)

def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data,
                               computed_len: int):
    """Map multi-modal data to model kwargs and, for mrope models, positions.

    Args:
        seq_data: Sequence whose token ids are read; when the runner reports
            an mrope model, its ``mrope_position_delta`` attribute is
            overwritten with the delta returned by
            ``MRotaryEmbedding.get_input_positions``.
        mm_data: Raw multi-modal data handed to
            ``self.multi_modal_input_mapper``.
        computed_len: Forwarded as ``context_len`` to
            ``MRotaryEmbedding.get_input_positions``.

    Returns:
        A ``(mm_kwargs, mrope_positions)`` tuple. ``mrope_positions`` is
        ``None`` when ``self.runner.model_is_mrope`` is falsy.
    """
    mapped_inputs = self.multi_modal_input_mapper(mm_data)

    # Non-mrope models need no extra position handling.
    if not self.runner.model_is_mrope:
        return mapped_inputs, None

    # mrope positions are derived from the image/video grid descriptors, so
    # at least one of them has to come back from the input mapper.
    image_grid = mapped_inputs.get("image_grid_thw", None)
    video_grid = mapped_inputs.get("video_grid_thw", None)
    assert not (image_grid is None and video_grid is None), (
        "mrope embedding type requires multi-modal input mapper "
        "returns 'image_grid_thw' or 'video_grid_thw'.")

    config = self.runner.model_config.hf_config
    positions, delta = MRotaryEmbedding.get_input_positions(
        seq_data.get_token_ids(),
        image_grid_thw=image_grid,
        video_grid_thw=video_grid,
        image_token_id=config.image_token_id,
        video_token_id=config.video_token_id,
        vision_start_token_id=config.vision_start_token_id,
        vision_end_token_id=config.vision_end_token_id,
        spatial_merge_size=config.vision_config.spatial_merge_size,
        context_len=computed_len,
    )
    # Stored on the sequence so it can be read back after the prompt phase.
    seq_data.mrope_position_delta = delta
    return mapped_inputs, positions

def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
Expand All @@ -153,6 +187,8 @@ def _prepare_prompt(
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
input_mrope_positions: List[List[int]] = [[] for _ in range(3)]

slot_mapping: List[int] = []
seq_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
Expand All @@ -171,14 +207,20 @@ def _prepare_prompt(
seq_lens.append(seq_len) # Prompt token num
input_tokens.extend(prompt_tokens) # Token ids

mrope_positions = None
if (mm_data := seq_group_metadata.multi_modal_data):
mm_kwargs, mrope_positions = self._compute_multi_modal_input(
seq_data, mm_data, computed_len)
multi_modal_inputs_list.append(mm_kwargs)

# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, seq_len)))

if (mm_data := seq_group_metadata.multi_modal_data):
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
if mrope_positions:
for idx in range(3):
input_mrope_positions[idx].extend(mrope_positions[idx])
else:
input_positions.extend(list(range(computed_len, seq_len)))

# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
Expand All @@ -202,12 +244,18 @@ def _prepare_prompt(
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)

if any(input_mrope_positions):
input_positions = None # type: ignore
else:
input_mrope_positions = None # type: ignore

num_prompt_tokens = len(input_tokens)

input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device) # type: ignore
input_positions = torch.tensor(input_positions,
input_positions = torch.tensor(input_positions
or input_mrope_positions,
dtype=torch.long,
device=self.device) # type: ignore
slot_mapping = torch.tensor(slot_mapping,
Expand Down Expand Up @@ -238,6 +286,7 @@ def _prepare_decode(
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
slot_mapping: List[int] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
Expand All @@ -255,7 +304,17 @@ def _prepare_decode(

seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append(position)
if seq_data.mrope_position_delta is not None:
context_len = seq_data.get_num_computed_tokens()
next_pos = MRotaryEmbedding.get_next_input_positions(
seq_data.mrope_position_delta,
context_len,
seq_len,
)
for idx in range(3):
input_mrope_positions[idx].extend(next_pos[idx])
else:
input_positions.append(position)

seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
Expand All @@ -273,12 +332,18 @@ def _prepare_decode(
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)

if any(input_mrope_positions):
input_positions = None # type: ignore
else:
input_mrope_positions = None # type: ignore

max_decode_seq_len = max(seq_lens)

input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
input_positions = torch.tensor(input_positions
or input_mrope_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
Expand Down Expand Up @@ -373,6 +438,15 @@ def __init__(
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])

@property
def model_is_mrope(self) -> bool:
    """Report whether the model's ``rope_scaling`` type is "mrope".

    ``rope_scaling`` is read from ``self.model_config.hf_config``. A missing
    attribute is treated as an empty mapping, and an explicit ``None`` means
    the model is not mrope.

    mrope requires keeping "rope_deltas" between the prompt and decoding
    phases.
    """
    scaling_cfg = getattr(self.model_config.hf_config, "rope_scaling", {})
    return (scaling_cfg is not None
            and scaling_cfg.get("type", None) == "mrope")

def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,
Expand Down

0 comments on commit 3fbd98e

Please sign in to comment.