Skip to content

Commit

Permalink
[Hardware][CPU] Refactor CPU model runner (vllm-project#8729)
Browse files Browse the repository at this point in the history
  • Loading branch information
Isotr0py authored and MengqingCao committed Sep 30, 2024
1 parent 0bbf5a6 commit 32564cc
Showing 1 changed file with 193 additions and 109 deletions.
302 changes: 193 additions & 109 deletions vllm/worker/cpu_model_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import dataclasses
import weakref
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

Expand All @@ -17,7 +19,7 @@
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
Expand All @@ -32,16 +34,17 @@


@dataclass(frozen=True)
class CPUModelInput(ModelRunnerInputBase):
class ModelInputForCPU(ModelRunnerInputBase):
"""
Used by the CPUModelRunner.
Base class contains metadata needed for the base model forward pass on CPU
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
virtual_engine: Optional[int] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None

def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
Expand All @@ -51,88 +54,96 @@ def as_broadcastable_tensor_dict(
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)

return tensor_dict

@classmethod
def from_broadcasted_tensor_dict(
cls: Type["CPUModelInput"],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None
) -> "CPUModelInput":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
cls: Type["ModelInputForCPU"],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None
) -> "ModelInputForCPU":
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)


class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
"""
Used by the ModelRunner.
"""
sampling_metadata: Optional["SamplingMetadata"] = None

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
*args,
**kwargs,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
# Currently, CPU worker doesn't support chunked prefill.
assert self.scheduler_config.chunked_prefill_enabled is False
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict

self.device = self.device_config.device
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForCPUWithSamplingMetadata":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)

self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
)

# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.multi_modal_input_mapper = self.mm_registry \
.create_input_mapper(self.model_config)
self.mm_registry.init_mm_limits_per_prompt(self.model_config)
class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):

# Lazy initialization.
self.model: nn.Module # Set after init_Model
def __init__(self,
runner: "CPUModelRunner",
finished_requests_ids: Optional[List[str]] = None) -> None:
super().__init__()
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
self.runner = runner
self.model_input_cls = self.runner._model_input_cls
self.attn_backend = self.runner.attn_backend
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.device = self.runner.device
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper

if self.model_config.is_encoder_decoder_model:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
self.seq_group_metadata_list.append(seq_group_metadata)

def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
def build(self) -> ModelInputForCPU:
multi_modal_kwargs = None
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = self.seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs) = self._prepare_prompt(
self.seq_group_metadata_list)
else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(
self.seq_group_metadata_list)
seq_lens = []

return self.model_input_cls(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
multi_modal_kwargs=multi_modal_kwargs,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens=seq_lens,
query_lens=seq_lens,
)

def _prepare_prompt(
self,
Expand Down Expand Up @@ -165,8 +176,7 @@ def _prepare_prompt(
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, seq_len)))

mm_data = seq_group_metadata.multi_modal_data
if mm_data:
if (mm_data := seq_group_metadata.multi_modal_data):
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)

Expand Down Expand Up @@ -302,56 +312,130 @@ def _prepare_decode(
attn_metadata,
)


class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
_model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
ModelInputForCPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
*args,
**kwargs,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
# Currently, CPU worker doesn't support chunked prefill.
assert self.scheduler_config.chunked_prefill_enabled is False
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker

self.device = self.device_config.device

self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
)

# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.multi_modal_input_mapper = self.mm_registry \
.create_input_mapper(self.model_config)
self.mm_registry.init_mm_limits_per_prompt(self.model_config)

# Lazy initialization.
self.model: nn.Module # Set after init_Model

if self.model_config.is_encoder_decoder_model:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])

def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)

def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> CPUModelInput:
return CPUModelInput.from_broadcasted_tensor_dict(
) -> ModelInputForCPU:
return ModelInputForCPU.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)

def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForCPUWithSamplingMetadata:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
"""
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list:
builder.add_seq_group(seq_group_metadata)

return builder.build() # type: ignore

def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> CPUModelInput:
multi_modal_kwargs = None
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
seq_lens = []
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False,
generators=self.get_generators(finished_requests_ids))
return CPUModelInput(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
sampling_metadata=sampling_metadata,
multi_modal_kwargs=multi_modal_kwargs,
)
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForCPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators)

return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
virtual_engine=virtual_engine)

@torch.no_grad()
def execute_model(
self,
model_input: CPUModelInput,
model_input: ModelInputForCPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
Expand Down

0 comments on commit 32564cc

Please sign in to comment.