Use NCCL instead of ray for control-plane communication to remove serialization overhead #2221

Merged
merged 35 commits on Jan 3, 2024

Changes from 4 commits

Commits (35)
7265829  small test (zhuohan123, Dec 18, 2023)
20274cc  test ray_pg (zhuohan123, Dec 19, 2023)
1b73dd7  update ray test (zhuohan123, Dec 19, 2023)
0d89354  implement driver worker (zhuohan123, Dec 20, 2023)
e0c4c4e  broadcast swap info (zhuohan123, Dec 20, 2023)
1baf87b  Broadcast inputmetadata as well (zhuohan123, Dec 20, 2023)
c947fa0  fix bugs (zhuohan123, Dec 20, 2023)
761584b  fix comments (zhuohan123, Dec 25, 2023)
19110fb  remove unused files (zhuohan123, Dec 25, 2023)
7b05ec6  fix async llm engine (zhuohan123, Dec 26, 2023)
5f90351  fix format (zhuohan123, Dec 26, 2023)
6f7ea32  Merge branch 'main' into remove-serialization-overhead (zhuohan123, Dec 26, 2023)
966e366  [BUGFIX] Fix API server test (zhuohan123, Dec 26, 2023)
fe2c29a  fix and remove print (zhuohan123, Dec 26, 2023)
5557cdb  fix test_cache (zhuohan123, Dec 26, 2023)
d92b38d  Merge branch 'fix-test-api-server' into remove-serialization-overhead (zhuohan123, Dec 26, 2023)
c7f6c21  fix api test (zhuohan123, Dec 26, 2023)
332d370  [BUGFIX] Fix the path of test prompts (zhuohan123, Dec 26, 2023)
9a8c16f  Merge branch 'fix-test-prompt-path' into remove-serialization-overhead (zhuohan123, Dec 26, 2023)
6ea2a42  fix test_model_runner (zhuohan123, Dec 26, 2023)
0434a76  Merge branch 'main' into remove-serialization-overhead (zhuohan123, Dec 27, 2023)
95bb1d3  Fix async llm engine (zhuohan123, Dec 27, 2023)
de4c8d2  [BUGFIX] Fix communication test (zhuohan123, Dec 27, 2023)
89d7cfd  Merge branch 'fix-comm-test-2' into remove-serialization-overhead (zhuohan123, Dec 27, 2023)
2b4863a  style (zhuohan123, Dec 27, 2023)
3096c56  Fix smaller review comments (zhuohan123, Dec 28, 2023)
dc4a4c2  fix (zhuohan123, Dec 28, 2023)
f2b8e88  remove unused files (zhuohan123, Dec 28, 2023)
83c2735  fix review comments (zhuohan123, Dec 28, 2023)
3d3a547  allgather -> gather (zhuohan123, Jan 3, 2024)
680c8d9  fix (zhuohan123, Jan 3, 2024)
5280a61  fix and revert unnecessary changes (zhuohan123, Jan 3, 2024)
03b2734  fix (zhuohan123, Jan 3, 2024)
0ca5e07  fix (zhuohan123, Jan 3, 2024)
ddb0795  fix review comments (zhuohan123, Jan 3, 2024)
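
The diffs below implement the change in the PR title: the driver (rank 0) broadcasts per-step control data, such as the block swap/copy mappings and the input metadata, to the remaining workers over the torch.distributed group instead of passing it through Ray remote-call arguments, which removes Ray's per-call pickling from the critical path. As orientation only, here is a minimal sketch of that broadcast pattern with made-up payload values and the CPU-only gloo backend (the real code runs NCCL for GPU tensors); this is not the PR's code:

import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    # Each process joins the same process group; rank 0 plays the driver.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29555",
        rank=rank,
        world_size=world_size,
    )
    if rank == 0:
        # Driver: pack the control-plane message for this step.
        control = [{
            "num_seq_groups": 2,
            "blocks_to_swap_in": {0: 5},
            "blocks_to_copy": {1: [2, 3]},
        }]
    else:
        # Workers: provide a placeholder list of the same length to receive into.
        control = [None]
    dist.broadcast_object_list(control, src=0)
    print(f"rank {rank} received {control[0]}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
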
6 changes: 2 additions & 4 deletions csrc/cache.h
@@ -8,14 +8,12 @@
void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::vector<int64_t>& src_block_numbers,
const std::vector<int64_t>& dst_block_numbers);
const std::map<int64_t, int64_t>& block_mapping);

void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const std::vector<int64_t>& src_block_numbers,
const std::vector<int64_t>& dst_block_numbers);
const std::map<int64_t, std::vector<int64_t>>& block_mapping);

void reshape_and_cache(
torch::Tensor& key,

27 changes: 13 additions & 14 deletions csrc/cache_kernels.cu
@@ -12,9 +12,7 @@
void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::vector<int64_t>& src_block_numbers,
const std::vector<int64_t>& dst_block_numbers) {
assert(src_block_numbers.size() == dst_block_numbers.size());
const std::map<int64_t, int64_t>& block_mapping) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type;
@@ -37,9 +35,9 @@ void swap_blocks(
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large.
for (int64_t i = 0; i < src_block_numbers.size(); ++i) {
int64_t src_block_number = src_block_numbers[i];
int64_t dst_block_number = dst_block_numbers[i];
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
int64_t dst_block_number = pair.second;
int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;
cudaMemcpyAsync(
@@ -87,8 +85,7 @@ __global__ void copy_blocks_kernel(
void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const std::vector<int64_t>& src_block_numbers,
const std::vector<int64_t>& dst_block_numbers) {
const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
@@ -107,10 +104,12 @@ void copy_blocks(
}
// Create block mapping array.
std::vector<int64_t> block_mapping_vec;
assert(src_block_numbers.size() == dst_block_numbers.size());
for (int i = 0; i < src_block_numbers.size(); ++i) {
block_mapping_vec.push_back(src_block_numbers[i]);
block_mapping_vec.push_back(dst_block_numbers[i]);
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
for (int64_t dst_block_number : pair.second) {
block_mapping_vec.push_back(src_block_number);
block_mapping_vec.push_back(dst_block_number);
}
}
int64_t* block_mapping_array = block_mapping_vec.data();
int num_pairs = block_mapping_vec.size() / 2;
@@ -253,12 +252,12 @@ __global__ void gather_cached_kv_kernel(
for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
const int tgt_key_idx = token_idx * key_stride + i;
const int tgt_value_idx = token_idx * value_stride + i;

const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension
const int x_offset = head_offset % x;

const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ head_idx * (head_size / x) * block_size * x
+ x_idx * block_size * x

29 changes: 13 additions & 16 deletions vllm/engine/llm_engine.py
@@ -155,38 +155,35 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
self.driver_dummy_worker: RayWorkerVllm = None
self.workers: List[RayWorkerVllm] = []

driver_ip = get_ip()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
continue

scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)

if (bundle.get("node:__internal_head__", 0) > 0
and self.driver_dummy_worker is None):
self.driver_dummy_worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerVllm).remote()
continue

worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerVllm).remote(self.model_config.trust_remote_code)
self.workers.append(worker)

worker_ip = ray.get(worker.get_node_ip.remote())
Collaborator:
As a minor optimization you could make this another loop so that the workers are initialized in a non-blocking fashion, but considering that nothing really happens in __init__, I think it's OK to leave it as is (though it is an anti-pattern). (A sketch of this two-pass pattern follows this file's diff.)

Member Author:
Yeah, and this only happens once, so it should not affect performance.

if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
else:
self.workers.append(worker)

if self.driver_dummy_worker is None:
raise ValueError(
"Placement group must have a bundle with host resources for "
"the driver process.")
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node.")

driver_node_id, driver_gpu_ids = ray.get(
self.driver_dummy_worker.get_node_and_gpu_ids.remote())
@@ -210,7 +207,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
worker.set_cuda_visible_devices.remote(node_gpus[node_id])

distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"

# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
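
The reviewer's suggestion above (create the Ray workers in one pass, then block on their node IPs afterwards) would look roughly like the following. This is a hypothetical sketch of that two-pass pattern, not the merged code; the Worker actor here is a stand-in rather than vLLM's RayWorkerVllm, which the diff below extends with a similar get_node_ip method.

import socket
from typing import List

import ray


@ray.remote(num_cpus=0)
class Worker:
    """Stand-in actor used only for this illustration."""

    def get_node_ip(self) -> str:
        return socket.gethostbyname(socket.gethostname())


ray.init()
# Pass 1: create all actors without blocking on any of them.
workers: List["ray.actor.ActorHandle"] = [Worker.remote() for _ in range(4)]
# Pass 2: resolve all node IPs in one batched ray.get instead of a blocking
# ray.get per worker inside the creation loop.
worker_ips = ray.get([w.get_node_ip.remote() for w in workers])
print(worker_ips)
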
19 changes: 5 additions & 14 deletions vllm/engine/ray_utils.py
@@ -2,7 +2,7 @@

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import is_hip, set_cuda_visible_devices
from vllm.utils import is_hip, set_cuda_visible_devices, get_ip

logger = init_logger(__name__)

@@ -29,6 +29,9 @@ def execute_method(self, method, *args, **kwargs):
executor = getattr(self, method)
return executor(*args, **kwargs)

def get_node_ip(self) -> str:
return get_ip()

def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
node_id = ray.get_runtime_context().get_node_id()
gpu_ids = ray.get_gpu_ids()
@@ -92,37 +95,25 @@ def initialize_cluster(
bundles = current_placement_group.bundle_specs
# Verify that we can use the placement group.
gpu_bundles = 0
have_host_bundle = False
for bundle in bundles:
bundle_gpus = bundle.get("GPU", 0)
if bundle_gpus > 1:
raise ValueError(
"Placement group bundle cannot have more than 1 GPU.")
if bundle_gpus:
gpu_bundles += 1
if bundle.get("node:__internal_head__", 0) > 0:
have_host_bundle = True
if parallel_config.world_size > gpu_bundles:
raise ValueError(
"The number of required GPUs exceeds the total number of "
"available GPUs in the placement group.")
if not have_host_bundle:
raise ValueError(
"Placement group must have a bundle with host resources for "
"the driver process.")
else:
num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
if parallel_config.world_size > num_gpus_in_cluster:
raise ValueError(
"The number of required GPUs exceeds the total number of "
"available GPUs in the cluster.")
# Create a new placement group
placement_group_specs = ([{
"GPU": 1,
"node:__internal_head__": 0.01
}] + [{
"GPU": 1
}] * (parallel_config.world_size - 1))
placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
current_placement_group = ray.util.placement_group(
placement_group_specs)
# Wait until PG is ready - this will block until all

1 change: 0 additions & 1 deletion vllm/model_executor/parallel_utils/communication_op.py
@@ -63,7 +63,6 @@ def tensor_model_parallel_gather(input_, dst=0, dim=-1):
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# Allocate output tensor.
if get_tensor_model_parallel_rank() == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
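
For context on the "allgather -> gather" commit that touches this file: with torch.distributed.gather, only the destination rank supplies a gather_list and ends up with every rank's tensor, whereas all_gather materializes the result on every rank. A minimal, single-process sketch of the call shape (a simplified assumption for illustration, not vLLM's tensor_model_parallel_gather itself):

import torch
import torch.distributed as dist

# Single-process group just to make the call shape runnable; in vLLM this
# would be the tensor-model-parallel group.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29556",
                        rank=0, world_size=1)
world_size = dist.get_world_size()
local_output = torch.arange(4)
# Only the destination rank allocates receive buffers; other ranks pass None.
gather_list = ([torch.empty_like(local_output) for _ in range(world_size)]
               if dist.get_rank() == 0 else None)
dist.gather(local_output, gather_list=gather_list, dst=0)
if gather_list is not None:
    print(torch.cat(gather_list))
dist.destroy_process_group()
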
28 changes: 10 additions & 18 deletions vllm/worker/cache_engine.py
@@ -1,5 +1,5 @@
"""CacheEngine class for managing the KV cache."""
from typing import List, Tuple
from typing import Dict, List, Tuple

import torch

@@ -113,39 +113,31 @@ def _swap(
self,
src: List[KVCache],
dst: List[KVCache],
src_block_numbers: List[int],
dst_block_numbers: List[int],
src_to_dst: Dict[int, int],
) -> None:
with torch.cuda.stream(self.cache_stream):
for i in range(self.num_layers):
src_key_cache, src_value_cache = src[i]
dst_key_cache, dst_value_cache = dst[i]
# Copy the key blocks.
cache_ops.swap_blocks(src_key_cache, dst_key_cache,
src_block_numbers, dst_block_numbers)
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
# Copy the value blocks.
cache_ops.swap_blocks(src_value_cache, dst_value_cache,
src_block_numbers, dst_block_numbers)
src_to_dst)
event = self.events[i]
event.record(stream=self.cache_stream)

def swap_in(self, src_block_numbers: List[int],
dst_block_numbers: List[int]) -> None:
self._swap(self.cpu_cache, self.gpu_cache, src_block_numbers,
dst_block_numbers)
def swap_in(self, src_to_dst: Dict[int, int]) -> None:
self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

def swap_out(self, src_block_numbers: List[int],
dst_block_numbers: List[int]) -> None:
self._swap(self.gpu_cache, self.cpu_cache, src_block_numbers,
dst_block_numbers)
def swap_out(self, src_to_dst: Dict[int, int]) -> None:
self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

def copy(self, src_block_numbers: List[int],
dst_block_numbers: List[int]) -> None:
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
key_caches = [key_cache for key_cache, _ in self.gpu_cache]
value_caches = [value_cache for _, value_cache in self.gpu_cache]
# NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
cache_ops.copy_blocks(key_caches, value_caches, src_block_numbers,
dst_block_numbers)
cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)

@staticmethod
def get_cache_block_size(
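
The CacheEngine interface above now takes the block mappings directly as dictionaries (src block to dst block for swaps, src block to a list of dst blocks for copies), matching what worker.py broadcasts and what the ops declared in csrc/cache.h accept. A small, hypothetical illustration of those shapes, with made-up values, plus the flattening into interleaved (src, dst) pairs that csrc/cache_kernels.cu performs for copy_blocks:

from typing import Dict, List

# Swap mappings: one destination block per source block.
blocks_to_swap_in: Dict[int, int] = {0: 5, 1: 7}    # CPU block -> GPU block
# Copy mapping: a source block may be copied to several destinations.
blocks_to_copy: Dict[int, List[int]] = {2: [8, 9]}  # GPU src block -> dst blocks


def flatten_copy_mapping(block_mapping: Dict[int, List[int]]) -> List[int]:
    """Flatten {src: [dst, ...]} into an interleaved [src, dst, src, dst, ...] list."""
    flat: List[int] = []
    for src, dsts in block_mapping.items():
        for dst in dsts:
            flat.extend([src, dst])
    return flat


assert flatten_copy_mapping(blocks_to_copy) == [2, 8, 2, 9]
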
1 change: 1 addition & 0 deletions vllm/worker/model_runner.py
@@ -369,6 +369,7 @@ def get_size_or_none(x: Optional[torch.Tensor]):
sampling_metadata.selected_token_indices.size(),
}
broadcast_object_list([py_data], src=0)
# TODO(zhuohan): Combine the broadcasts or set async_op=True.
broadcast(input_tokens, src=0)
broadcast(input_positions, src=0)
if input_metadata.slot_mapping is not None:
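
The TODO above suggests combining the tensor broadcasts or issuing them with async_op=True. A sketch of the async variant under a single-process gloo group (an assumption for illustration, not the PR's implementation): each broadcast returns a work handle, and the caller waits on all of them once instead of blocking per tensor.

import torch
import torch.distributed as dist

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29557",
                        rank=0, world_size=1)

tensors = [torch.arange(8), torch.ones(4, dtype=torch.long)]
# Kick off every broadcast without blocking ...
handles = [dist.broadcast(t, src=0, async_op=True) for t in tensors]
# ... then synchronize once after all of them have been issued.
for handle in handles:
    handle.wait()
dist.destroy_process_group()
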
51 changes: 20 additions & 31 deletions vllm/worker/worker.py
@@ -127,22 +127,22 @@ def warm_up_model(self) -> None:
# the model initialization and profiling.
set_random_seed(self.model_config.seed)

def cache_swap(self, swap_in_src: List[int], swap_in_dst: List[int],
swap_out_src: List[int], swap_out_dst: List[int],
copy_src: List[int], copy_dst: List[int]) -> None:
def cache_swap(
self,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
# Issue cache operations.
issued_cache_op = False
if len(swap_in_src) > 0:
assert len(swap_in_src) == len(swap_in_dst)
self.cache_engine.swap_in(swap_in_src, swap_in_dst)
if blocks_to_swap_in:
self.cache_engine.swap_in(blocks_to_swap_in)
issued_cache_op = True
if len(swap_out_src) > 0:
assert len(swap_out_src) == len(swap_out_dst)
self.cache_engine.swap_out(swap_out_src, swap_out_dst)
if blocks_to_swap_out:
self.cache_engine.swap_out(blocks_to_swap_out)
issued_cache_op = True
if len(copy_src) > 0:
assert len(copy_src) == len(copy_dst)
self.cache_engine.copy(copy_src, copy_dst)
if blocks_to_copy:
self.cache_engine.copy(blocks_to_copy)
issued_cache_op = True

cache_events = self.cache_events if issued_cache_op else None
@@ -167,31 +167,20 @@ def execute_model(
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
# Turn the dictionaries into lists for communication.
swap_in_src = list(blocks_to_swap_in.keys())
swap_in_dst = list(blocks_to_swap_in.values())
swap_out_src = list(blocks_to_swap_out.keys())
swap_out_dst = list(blocks_to_swap_out.values())
copy_src = []
copy_dst = []
for src, dst_list in blocks_to_copy.items():
copy_src.extend([src] * len(dst_list))
copy_dst.extend(dst_list)
swapping_block_numbers = [
swap_in_src, swap_in_dst, swap_out_src, swap_out_dst, copy_src,
copy_dst
block_swapping_info = [
blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy
]
broadcast_object_list([num_seq_groups] + swapping_block_numbers,
broadcast_object_list([num_seq_groups] + block_swapping_info,
src=0)
else:
# num_seq_groups, swap_in_src, swap_in_dst, swap_out_src,
# swap_out_dst, copy_src, copy_dst (7 elements)
recv_data = [None] * 7
# num_seq_groups, blocks_to_swap_in, blocks_to_swap_out,
# blocks_to_copy (4 elements)
recv_data = [None] * 4
broadcast_object_list(recv_data, src=0)
num_seq_groups = recv_data[0]
swapping_block_numbers = recv_data[1:]
block_swapping_info = recv_data[1:]

self.cache_swap(*swapping_block_numbers)
self.cache_swap(*block_swapping_info)

# If there is no input, we don't need to execute the model.
if num_seq_groups == 0: