Skip to content

Commit

Permalink
[Hardware][intel GPU] add async output process for xpu (vllm-project#…
Browse files Browse the repository at this point in the history
…8897)

Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
  • Loading branch information
jikunshang authored and garg-amit committed Oct 28, 2024
1 parent afa0a09 commit efae0ab
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
4 changes: 2 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,9 @@ def verify_async_output_proc(self, parallel_config, speculative_config,

# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if device_config.device_type not in ("cuda", "tpu"):
if device_config.device_type not in ("cuda", "tpu", "xpu"):
logger.warning(
"Async output processing is only supported for CUDA or TPU. "
"Async output processing is only supported for CUDA, TPU, XPU. "
"Disabling it for other platforms.")
self.use_async_output_proc = False
return
Expand Down
8 changes: 6 additions & 2 deletions vllm/worker/xpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import time
import weakref
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
TypeVar)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, TypeVar)

import torch
import torch.nn as nn
Expand Down Expand Up @@ -57,6 +57,7 @@ class ModelInputForXPU(ModelRunnerInputBase):
virtual_engine: Optional[int] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
async_callback: Optional[Callable] = None

def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
Expand Down Expand Up @@ -598,6 +599,9 @@ def execute_model(
if not self.is_driver_worker:
return []

if model_input.async_callback is not None:
model_input.async_callback()

# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
Expand Down

0 comments on commit efae0ab

Please sign in to comment.