[Hardware][TPU] Support parallel sampling & Swapping (vllm-project#5855)
WoosukKwon authored and prashantgupta24 committed Jun 28, 2024
1 parent 51c6b12 commit 27f91da
Showing 3 changed files with 147 additions and 56 deletions.
30 changes: 22 additions & 8 deletions vllm/attention/backends/pallas.py
@@ -28,21 +28,35 @@ def get_kv_cache_shape(
) -> Tuple[int, ...]:
return (num_kv_heads, num_blocks, block_size, head_size)

@torch.compile(backend="openxla")
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
src_to_dst: Tuple[torch.Tensor, torch.Tensor],
) -> None:
raise NotImplementedError("swap_blocks is not implemented.")
src_k_cache, src_v_cache = src_kv_cache
dst_k_cache, dst_v_cache = dst_kv_cache
torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)

device = dst_k_cache.device
src_indices, dst_indices = src_to_dst
dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)

@torch.compile(backend="openxla")
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
src_to_dists: Tuple[torch.Tensor, torch.Tensor],
) -> None:
# TODO(woosuk): Implement this.
raise NotImplementedError("copy_blocks is not implemented.")
src_indices, dst_indices = src_to_dists
for k_cache, v_cache in kv_caches:
torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
k_cache[:, dst_indices] = k_cache[:, src_indices]
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
v_cache[:, dst_indices] = v_cache[:, src_indices]


@dataclass
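The rewritten swap_blocks and copy_blocks rely on advanced indexing along the block dimension (dim 1 of the (num_kv_heads, num_blocks, block_size, head_size) cache returned by get_kv_cache_shape). The following stand-alone sketch, which is not part of the diff, reproduces the same indexing on plain CPU tensors with toy sizes; the XLA buffer-donor calls are omitted:

import torch

# Toy cache shape matching get_kv_cache_shape():
# (num_kv_heads, num_blocks, block_size, head_size)
num_kv_heads, num_blocks, block_size, head_size = 2, 8, 4, 16

src_k = torch.randn(num_kv_heads, num_blocks, block_size, head_size)
src_v = torch.randn_like(src_k)
dst_k = torch.zeros_like(src_k)
dst_v = torch.zeros_like(src_v)

# src_to_dst is a pair of index tensors: copy block 0 -> 3 and block 1 -> 7.
src_indices = torch.tensor([0, 1], dtype=torch.int64)
dst_indices = torch.tensor([3, 7], dtype=torch.int64)

# The same gather/scatter used by swap_blocks: index along the block dim.
device = dst_k.device
dst_k[:, dst_indices] = src_k[:, src_indices].to(device)
dst_v[:, dst_indices] = src_v[:, src_indices].to(device)
assert torch.equal(dst_k[:, 3], src_k[:, 0])

# copy_blocks applies the same pattern in place, once per layer.
src_k[:, dst_indices] = src_k[:, src_indices]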
76 changes: 50 additions & 26 deletions vllm/worker/tpu_model_runner.py
@@ -22,6 +22,9 @@
_PAD_SLOT_ID = 0 # FIXME(woosuk)
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128


class TPUModelRunner:
@@ -143,8 +146,9 @@ def _dummy_run(
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)

# Dummy run.
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
self.model(token_ids, position_ids, kv_caches, attn_metadata,
input_lens, t, p)
input_lens, t, p, num_samples)

def warmup_model(
self,
@@ -268,14 +272,11 @@ def _prepare_decode(
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
context_lens: List[int] = []
num_seq_groups = len(seq_group_metadata_list)
batch_size = _get_padded_batch_size(num_seq_groups)

for i, seq_group_metadata in enumerate(seq_group_metadata_list):
batch_idx = 0
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt

seq_ids = list(seq_group_metadata.seq_data.keys())

for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
@@ -288,14 +289,16 @@

assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
self.block_tables[i, :len(block_table)] = block_table
self.block_tables[batch_idx, :len(block_table)] = block_table
batch_idx += 1

block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])

num_paddings = batch_size - num_seq_groups
batch_size = _get_padded_batch_size(batch_idx)
num_paddings = batch_size - batch_idx
input_tokens = input_tokens + [[0]] * num_paddings
input_positions = input_positions + [[0]] * num_paddings
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
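With n > 1, a single sequence group can contribute several decode sequences, so the batch is now padded per sequence (batch_idx) rather than per group. A rough sketch of the padding step; the round-up-to-a-multiple-of-8 helper below is an assumption standing in for the real _get_padded_batch_size():

from typing import List

_PAD_SLOT_ID = 0  # Same placeholder slot id used by the TPU model runner.

def pad_to_multiple_of_8(n: int) -> int:
    # Hypothetical stand-in for _get_padded_batch_size().
    return ((n + 7) // 8) * 8

# Two sequence groups with n=3 and n=2 yield five decode sequences.
batch_idx = 5
input_tokens: List[List[int]] = [[11], [12], [13], [21], [22]]
slot_mapping: List[List[int]] = [[40], [41], [42], [80], [81]]

batch_size = pad_to_multiple_of_8(batch_idx)   # -> 8
num_paddings = batch_size - batch_idx          # -> 3
input_tokens = input_tokens + [[0]] * num_paddings
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
assert len(input_tokens) == len(slot_mapping) == batch_size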
@@ -333,14 +336,13 @@ def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
padded_batch_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
assert len(seq_group_metadata_list) > 0
t = []
p = []
best_of = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.sampling_params is not None
sampling_params = seq_group_metadata.sampling_params

# NOTE(woosuk): Here we mimic argmax sampling by applying a very
# low temperature. This is not accurate.
t.append(sampling_params.temperature
@@ -354,10 +356,11 @@
raise NotImplementedError(
"Top-k sampling is currently disabled for the TPU backend "
"due to performance issues.")
if sampling_params.best_of > 1:
if sampling_params.best_of > _MAX_NUM_SAMPLES:
raise NotImplementedError(
"best_of > 1 is not currently supported by the TPU "
f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
"backend.")
best_of.append(sampling_params.best_of)
if sampling_params.use_beam_search:
raise NotImplementedError(
"Beam search is not supported by the TPU backend.")
@@ -369,13 +372,19 @@
"prompt_logprobs is not currently supported by the TPU "
"backend.")

num_paddings = padded_batch_size - len(seq_group_metadata_list)
# Repeat the sampling params if the seq group has multiple seqs.
num_seqs = len(seq_group_metadata.seq_data)
t += [t[-1]] * (num_seqs - 1)
p += [p[-1]] * (num_seqs - 1)
best_of += [best_of[-1]] * (num_seqs - 1)

num_paddings = padded_batch_size - len(t)
t += [1.0] * num_paddings
p += [1.0] * num_paddings

t = torch.tensor(t, dtype=torch.float32, device=self.device)
p = torch.tensor(p, dtype=torch.float32, device=self.device)
return t, p
return t, p, best_of
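Since a group with n > 1 owns several running sequences at decode time, the per-group temperature, top-p, and best_of values are repeated once per sequence before padding. A small illustrative sketch of that bookkeeping (the numbers are invented):

# One entry per sequence group, as collected by the loop above.
t = [0.7, 1.0]                # temperatures
p = [0.9, 1.0]                # top-p values
best_of = [1, 4]              # n / best_of per group
num_seqs_per_group = [1, 4]   # the second group decodes 4 parallel seqs
padded_batch_size = 8

expanded_t, expanded_p, expanded_best_of = [], [], []
for ti, pi, bi, n in zip(t, p, best_of, num_seqs_per_group):
    # Repeat the group's parameters for each of its sequences.
    expanded_t += [ti] * n
    expanded_p += [pi] * n
    expanded_best_of += [bi] * n

# Pad with neutral values (temperature 1.0, top-p 1.0) up to the batch size.
num_paddings = padded_batch_size - len(expanded_t)
expanded_t += [1.0] * num_paddings
expanded_p += [1.0] * num_paddings
assert len(expanded_t) == len(expanded_p) == padded_batch_size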

def _execute_model(
self,
@@ -392,28 +401,41 @@ def _execute_model(
else:
inputs = self._prepare_decode(seq_group_metadata_list)
padded_batch_size = inputs[0].shape[0]
t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)
t, p, best_of = self._prepare_sample(seq_group_metadata_list,
padded_batch_size)
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

# Execute the model.
next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
*inputs[2:], t, p)
*inputs[2:], t, p, num_samples)
# Retrieve the outputs to CPU.
next_token_ids = next_token_ids.cpu().tolist()

# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
# does not support the advanced sampling parameters such as logprobs.
i = 0
zero_logprob = Logprob(0.0)
batch_idx = 0
sampler_outputs = []
for seq_group_metadata in seq_group_metadata_list:
seq_outputs = []
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
next_token_id = next_token_ids[i]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: Logprob(0.0)}))
i += 1
if is_prompt:
assert len(seq_ids) == 1
seq_id = seq_ids[0]
for i in range(best_of[batch_idx]):
next_token_id = next_token_ids[batch_idx][i]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
batch_idx += 1
else:
for seq_id in seq_ids:
next_token_id = next_token_ids[batch_idx][0]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
batch_idx += 1
sampler_outputs.append(
CompletionSequenceGroupOutput(seq_outputs, None))
return sampler_outputs
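With parallel sampling the model returns a [padded_batch_size, num_samples] matrix of token ids instead of a flat vector, and the loop above consumes it row by row: a prompt row fans out into best_of SequenceOutputs for its single sequence, while each decode row contributes only its first sample. A toy walkthrough of that slicing with invented data:

# Pretend model output for a prompt batch of two groups, num_samples = 4.
next_token_ids = [
    [101, 102, 103, 104],   # row for prompt sequence group 0
    [201, 202, 203, 204],   # row for prompt sequence group 1
]
best_of = [2, 1]            # group 0 asked for best_of=2, group 1 for 1

kept = []
for batch_idx, row in enumerate(next_token_ids):
    # Prompt path: keep only the first best_of[batch_idx] samples.
    kept.append(row[:best_of[batch_idx]])

print(kept)  # [[101, 102], [201]]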
@@ -458,6 +480,7 @@ def forward(
input_lens: torch.Tensor,
t: torch.Tensor,
p: torch.Tensor,
num_samples: int,
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
@@ -520,8 +543,9 @@ def forward(
if _ENABLE_TOP_P:
logits = _apply_top_p(logits, p.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
# FIXME(woosuk): best_of > 1 is not supported.
next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
next_token_ids = torch.multinomial(probs,
num_samples,
replacement=True)
return next_token_ids
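The sampler now draws num_samples tokens per row in a single torch.multinomial call, with replacement so the same token can be drawn more than once across parallel samples. A shape-only sketch with random data (temperature and top-p handling omitted):

import torch

padded_batch_size, vocab_size, num_samples = 4, 32000, 128

logits = torch.randn(padded_batch_size, vocab_size)
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)

# One call draws num_samples token ids per row; replacement=True allows
# duplicates, which parallel sampling (n > 1) expects.
next_token_ids = torch.multinomial(probs, num_samples, replacement=True)
print(next_token_ids.shape)  # torch.Size([4, 128])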


97 changes: 75 additions & 22 deletions vllm/worker/tpu_worker.py
@@ -1,5 +1,5 @@
import os
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import torch
import torch_xla.core.xla_model as xm
@@ -117,19 +117,26 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
# Synchronize before measuring the memory usage.
xm.wait_device_ops()

dtype_btyes = get_dtype_size(self.cache_dtype)
block_size = self.cache_config.block_size
block_size_bytes = (dtype_btyes * block_size * num_layers * 2 *
head_size * num_kv_heads)

# Calculate the TPU KV cache size based on profiling.
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
usable_memory_size = int(total_memory_size *
self.cache_config.gpu_memory_utilization)
profiled = m["bytes_used"] # Weights + intermediate activations.
kv_cache_bytes = max(usable_memory_size - profiled, 0)
dtype_btyes = get_dtype_size(self.cache_dtype)
block_size = self.cache_config.block_size
num_tpu_blocks = (kv_cache_bytes //
(dtype_btyes * block_size * num_layers * 2 *
head_size * num_kv_heads))
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8.
return num_tpu_blocks, 0

# Calculate the CPU KV cache size based on the config.
num_cpu_blocks = (self.cache_config.swap_space_bytes //
block_size_bytes)
num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8.
return num_tpu_blocks, num_cpu_blocks
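The CPU swap cache is sized with the same per-block byte count as the TPU cache, only driven by the configured swap space instead of profiled device memory. A worked example with made-up dimensions (fp16 cache, 32 layers, 8 KV heads, head size 128, block size 16, 4 GiB of swap space):

dtype_bytes = 2                      # fp16
block_size = 16
num_layers = 32
num_kv_heads = 8
head_size = 128

# Bytes needed to hold one block across all layers, for both K and V.
block_size_bytes = (dtype_bytes * block_size * num_layers * 2 *
                    head_size * num_kv_heads)
print(block_size_bytes)              # 2097152 bytes = 2 MiB per block

swap_space_bytes = 4 * 1024**3       # 4 GiB of host swap space
num_cpu_blocks = swap_space_bytes // block_size_bytes
num_cpu_blocks = (num_cpu_blocks // 8) * 8   # round down to a multiple of 8
print(num_cpu_blocks)                # 2048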

def initialize_cache(
self,
@@ -145,15 +152,19 @@ def initialize_cache(
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
head_size = self.model_config.get_head_size()

self.cpu_cache = []
self.tpu_cache = []
tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
num_gpu_blocks, self.block_size, num_kv_heads, head_size)
for _ in range(num_layers):
key_cache = torch.zeros(tpu_cache_shape,
dtype=dtype,
device=self.device)
value_cache = torch.zeros_like(key_cache)
self.tpu_cache.append((key_cache, value_cache))
tpu_k_cache = torch.zeros(tpu_cache_shape,
dtype=dtype,
device=self.device)
tpu_v_cache = torch.zeros_like(tpu_k_cache)
self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
cpu_k_cache = torch.zeros_like(tpu_k_cache, device="cpu")
cpu_v_cache = torch.zeros_like(tpu_v_cache, device="cpu")
self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
self._warmup_model()
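After this change the worker holds two parallel per-layer caches, one on the TPU and a host-side twin for swapping, each entry being a (key_cache, value_cache) pair shaped (num_kv_heads, num_blocks, block_size, head_size). A small sketch of the resulting structure, with toy sizes and ordinary CPU tensors standing in for the XLA device tensors:

from typing import List, Tuple

import torch

# Toy dimensions; the real values come from the model and cache configs.
num_layers, num_kv_heads, num_blocks, block_size, head_size = 2, 4, 32, 16, 64
dtype = torch.float16

tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
for _ in range(num_layers):
    # Device-resident cache (a CPU tensor here stands in for the XLA tensor).
    k_cache = torch.zeros(num_kv_heads, num_blocks, block_size, head_size,
                          dtype=dtype)
    v_cache = torch.zeros_like(k_cache)
    tpu_cache.append((k_cache, v_cache))
    # Host-side swap cache mirroring the same layout.
    cpu_cache.append((torch.zeros_like(k_cache, device="cpu"),
                      torch.zeros_like(v_cache, device="cpu")))

# The block dimension is dim 1, which is why swap_blocks and copy_blocks
# index the caches as cache[:, block_indices].
assert tpu_cache[0][0].shape[1] == num_blocks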

def _warmup_model(self) -> None:
@@ -187,26 +198,68 @@ def execute_model(
if not self.is_driver_worker:
self._execute_model_non_driver()
return []

assert execute_model_req is not None
# Currently, TPUWorker does not support swapping.
# TODO(woosuk): Support block copying.
assert len(execute_model_req.blocks_to_swap_in) == 0, (
"Swapping is not supported for the TPU backend.")
assert len(execute_model_req.blocks_to_swap_out) == 0, (
"Swapping is not supported for the TPU backend.")
assert len(execute_model_req.blocks_to_copy) == 0

# Issue cache operations.
self.cache_swap(
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy,
)
# Run the model.
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
assert len(seq_group_metadata_list) > 0
output = self.model_runner.execute_model(seq_group_metadata_list,
self.tpu_cache)
return [output]

def cache_swap(
self,
blocks_to_swap_in: List[Tuple[int, int]],
blocks_to_swap_out: List[Tuple[int, int]],
blocks_to_copy: List[Tuple[int, int]],
) -> None:
attn_backend = self.model_runner.attn_backend
num_layers = self.model_config.get_num_layers(self.parallel_config)

if blocks_to_swap_in:
# Swap from CPU to TPU.
src_to_dst = _make_src_to_dst(blocks_to_swap_in, "cpu",
self.device)
for i in range(num_layers):
attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i],
src_to_dst)
if blocks_to_swap_out:
# Swap from TPU to CPU.
src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device,
"cpu")
for i in range(num_layers):
attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i],
src_to_dst)
if blocks_to_copy:
src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
self.device)
attn_backend.copy_blocks(self.tpu_cache, src_to_dst)

def start_worker_execution_loop(self) -> None:
while self._execute_model_non_driver():
pass

def _execute_model_non_driver(self) -> bool:
self.model_runner.execute_model(None, self.tpu_cache)
return True


def _make_src_to_dst(
mapping: List[Tuple[int, int]],
src_device: Union[torch.device, str],
dst_device: Union[torch.device, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
src_indices = [i for i, _ in mapping]
dst_indices = [i for _, i in mapping]
src_indices = torch.tensor(src_indices,
device=src_device,
dtype=torch.int64)
dst_indices = torch.tensor(dst_indices,
device=dst_device,
dtype=torch.int64)
return src_indices, dst_indices
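A quick usage example for the new helper; the block-number pairs are invented, "cpu" is used for both devices so the snippet runs anywhere, and the helper body is repeated so the example is self-contained:

from typing import List, Tuple, Union

import torch

def _make_src_to_dst(
    mapping: List[Tuple[int, int]],
    src_device: Union[torch.device, str],
    dst_device: Union[torch.device, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Same logic as the helper above.
    src_indices = torch.tensor([i for i, _ in mapping],
                               device=src_device, dtype=torch.int64)
    dst_indices = torch.tensor([i for _, i in mapping],
                               device=dst_device, dtype=torch.int64)
    return src_indices, dst_indices

# Swap source blocks 0 and 1 into destination blocks 3 and 7.
src, dst = _make_src_to_dst([(0, 3), (1, 7)], "cpu", "cpu")
print(src.tolist(), dst.tolist())  # [0, 1] [3, 7]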
