Commit 2a633f2
Merge remote-tracking branch 'upstream/main' into openvino-2024.3.0-dev
ilya-lavrenov committed Jun 28, 2024
2 parents: 6bad9bf + 74d55c0
Showing 3 changed files with 46 additions and 32 deletions.
vllm/attention/backends/pallas.py: 4 additions & 12 deletions

@@ -28,21 +28,13 @@ def get_kv_cache_shape(
     ) -> Tuple[int, ...]:
         return (num_kv_heads, num_blocks, block_size, head_size)
 
-    @torch.compile(backend="openxla")
     @staticmethod
     def swap_blocks(
-        src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
-        dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
-        src_to_dst: Tuple[torch.Tensor, torch.Tensor],
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
     ) -> None:
-        src_k_cache, src_v_cache = src_kv_cache
-        dst_k_cache, dst_v_cache = dst_kv_cache
-        src_indices, dst_indices = src_to_dst
-        device = dst_k_cache.device
-        torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
-        torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)
-        dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
-        dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)
+        raise RuntimeError("swap_blocks is not used for the TPU backend.")
 
     @torch.compile(backend="openxla")
     @staticmethod
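Note on this change: the TPU backend previously performed the copy itself inside a compiled graph, but the TPU cache stores each layer as a (key, value) tuple rather than a single tensor, so the swap now lives in vllm/worker/tpu_worker.py (below) and the backend method simply raises. The sketch below illustrates the indexing the swap relies on, given the (num_kv_heads, num_blocks, block_size, head_size) layout returned by `get_kv_cache_shape`; the sizes are made-up example values, not values from the diff.

```python
# Illustrative sketch (not part of the commit): with the cache laid out as
# (num_kv_heads, num_blocks, block_size, head_size), cache[:, block_indices]
# selects whole blocks across all heads. Sizes are arbitrary examples.
import torch

num_kv_heads, num_blocks, block_size, head_size = 2, 8, 16, 64
k_cache = torch.zeros(num_kv_heads, num_blocks, block_size, head_size)

dst_indices = torch.tensor([1, 2, 4])  # block slots to overwrite
staging = torch.randn(num_kv_heads, len(dst_indices), block_size, head_size)

k_cache[:, dst_indices] = staging  # three full blocks copied in one assignment
```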
vllm/distributed/parallel_state.py: 10 additions & 10 deletions

@@ -45,7 +45,7 @@ class GraphCaptureContext:


 def _split_tensor_dict(
-        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
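The key type narrows from `Any` to `str` here because nested dictionaries are flattened into string keys built from prefixes, which only works if every key is a string. A minimal sketch of such a flattening follows, assuming a "%" separator and a simplified metadata tuple (both illustrative, not code from this commit):

```python
# Hedged sketch of prefix-based flattening; the "%" separator and the
# simplified metadata tuple are assumptions, not code from this commit.
from typing import Any, Dict, List, Tuple, Union

import torch


def split_tensor_dict_sketch(
    tensor_dict: Dict[str, Union[torch.Tensor, Any]],
    prefix: str = "",
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    metadata_list: List[Tuple[str, Any]] = []
    tensor_list: List[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Keep lightweight metadata with the key; the tensor itself
            # travels separately over the process group.
            metadata_list.append((prefix + key, (value.dtype, value.size())))
            tensor_list.append(value)
        elif isinstance(value, dict):
            # Recurse with an extended string prefix: this is why the
            # keys must be str, not Any.
            inner_meta, inner_tensors = split_tensor_dict_sketch(
                value, prefix + key + "%")
            metadata_list.extend(inner_meta)
            tensor_list.extend(inner_tensors)
        else:
            metadata_list.append((prefix + key, value))
    return metadata_list, tensor_list


meta, tensors = split_tensor_dict_sketch(
    {"scale": 0.5, "cache": {"k": torch.zeros(2, 3)}})
# meta: [("scale", 0.5), ("cache%k", (torch.float32, torch.Size([2, 3])))]
```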
@@ -473,11 +473,11 @@ def recv_object(self, src: int) -> Any:

     def broadcast_tensor_dict(
         self,
-        tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
         src: int = 0,
         group: Optional[ProcessGroup] = None,
         metadata_group: Optional[ProcessGroup] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Broadcast the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
@@ -558,9 +558,9 @@ def broadcast_tensor_dict(

     def send_tensor_dict(
         self,
-        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
         """
@@ -599,7 +599,7 @@ def send_tensor_dict(
     def recv_tensor_dict(
         self,
         src: Optional[int] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
@@ -615,15 +615,15 @@ def recv_tensor_dict(
         assert src < self.world_size, f"Invalid src rank ({src})"
 
         recv_metadata_list = self.recv_object(src=src)
-        tensor_dict = {}
+        tensor_dict: Dict[str, Any] = {}
         for key, value in recv_metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
                                      dtype=value.dtype,
                                      device=value.device)
                 if tensor.numel() == 0:
                     # Skip broadcasting empty tensors.
-                    tensor_dict[key] = tensor
+                    _update_nested_dict(tensor_dict, key, tensor)
                     continue
                 if tensor.is_cpu:
                     # use metadata_group for CPU tensors
@@ -633,9 +633,9 @@ recv_tensor_dict(
                 else:
                     # use group for GPU tensors
                     torch.distributed.recv(tensor, src=src, group=group)
-                tensor_dict[key] = tensor
+                _update_nested_dict(tensor_dict, key, tensor)
             else:
-                tensor_dict[key] = value
+                _update_nested_dict(tensor_dict, key, value)
         return tensor_dict
 
     def barrier(self):
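The receive path above now routes every value through `_update_nested_dict`, the inverse of the flattening: a "%"-joined key is expanded back into nested dictionaries before assignment. A minimal sketch under the same separator assumption:

```python
# Hedged sketch of what _update_nested_dict does; the "%" separator is an
# assumption consistent with string-prefixed flattening, not this commit.
from typing import Any, Dict


def update_nested_dict_sketch(nested_dict: Dict[str, Any],
                              flattened_key: str, value: Any) -> None:
    # "cache%k" addresses nested_dict["cache"]["k"], creating levels on demand.
    *parents, leaf = flattened_key.split("%")
    cur = nested_dict
    for part in parents:
        cur = cur.setdefault(part, {})
    cur[leaf] = value


d: Dict[str, Any] = {}
update_nested_dict_sketch(d, "cache%k", 1)
update_nested_dict_sketch(d, "scale", 0.5)
assert d == {"cache": {"k": 1}, "scale": 0.5}
```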
vllm/worker/tpu_worker.py: 32 additions & 10 deletions

@@ -3,6 +3,7 @@

 import torch
 import torch_xla.core.xla_model as xm
+import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401
 import torch_xla.runtime as xr
 
 import vllm.envs as envs
@@ -152,8 +153,8 @@ def initialize_cache(
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
 
-        self.cpu_cache = []
-        self.tpu_cache = []
+        self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
+        self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
         tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
             num_gpu_blocks, self.block_size, num_kv_heads, head_size)
         cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
@@ -227,18 +228,25 @@ def cache_swap(

         if blocks_to_swap_in:
             # Swap from CPU to TPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_in, "cpu",
-                                          self.device)
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_in, "cpu", self.device)
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                k = cpu_k_cache[:, src_indices].to(self.device)
+                v = cpu_v_cache[:, src_indices].to(self.device)
+                _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
 
         if blocks_to_swap_out:
             # Swap from TPU to CPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device,
-                                          "cpu")
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_out, self.device, "cpu")
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu()
+                cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu()
 
         if blocks_to_copy:
             src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
                                           self.device)
@@ -267,3 +275,17 @@ def _make_src_to_dst(
                                device=dst_device,
                                dtype=torch.int64)
     return src_indices, dst_indices
+
+
+@torch.compile(backend="openxla")
+def _insert_kv(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    indices: torch.Tensor,
+    tpu_k_cache: torch.Tensor,
+    tpu_v_cache: torch.Tensor,
+) -> None:
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
+    tpu_k_cache[:, indices] = k
+    tpu_v_cache[:, indices] = v
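`_insert_kv` is compiled with the `openxla` backend and marks both cache tensors as donated buffers, letting XLA alias the existing cache storage for the updated result instead of allocating a second cache-sized buffer. The swap-out direction writes into the CPU cache with plain indexing, so it needs neither donation nor compilation. A minimal standalone sketch of the same donation pattern, assuming `torch_xla` is installed and an XLA device (e.g. a TPU) is available:

```python
# Minimal sketch of the buffer-donation pattern used by _insert_kv above;
# assumes torch_xla is installed and an XLA device is available.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401


@torch.compile(backend="openxla")
def scatter_blocks(dst: torch.Tensor, src: torch.Tensor,
                   indices: torch.Tensor) -> None:
    # Donating dst tells XLA it may reuse dst's buffer for the output of
    # the compiled graph, so the in-place update does not double memory.
    torch.ops.xla.dynamo_set_buffer_donor_(dst, True)
    dst[:, indices] = src


device = xm.xla_device()
dst = torch.zeros(2, 8, 4, device=device)
src = torch.ones(2, 3, 4, device=device)
scatter_blocks(dst, src, torch.tensor([1, 5, 6], device=device))
```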
