Skip to content

Commit

Permalink
[core] [2/N] refactor worker_base input preparation for multi-step (v…
Browse files Browse the repository at this point in the history
  • Loading branch information
SolitaryThinker authored and fialhocoelho committed Aug 22, 2024
1 parent 4eaccac commit 305c70d
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 31 deletions.
2 changes: 2 additions & 0 deletions vllm/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_steps = execute_model_req.num_steps
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync.
Expand All @@ -286,6 +287,7 @@ def prepare_worker_input(
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
num_steps=num_steps,
)

@torch.inference_mode()
Expand Down
92 changes: 61 additions & 31 deletions vllm/worker/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class WorkerInput:
blocks_to_swap_out: Optional[torch.Tensor] = None
blocks_to_copy: Optional[torch.Tensor] = None
virtual_engine: int = 0
num_steps: int = 1

@classmethod
def from_broadcasted_tensor_dict(
Expand All @@ -145,6 +146,7 @@ def from_broadcasted_tensor_dict(
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"),
)

def as_broadcastable_tensor_dict(
Expand All @@ -158,6 +160,7 @@ def as_broadcastable_tensor_dict(
"blocks_to_swap_out": self.blocks_to_swap_out,
"blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
"num_steps": self.num_steps,
}

return tensor_dict
Expand Down Expand Up @@ -216,13 +219,50 @@ def execute_worker(self, worker_input: WorkerInput) -> None:
"""
raise NotImplementedError

def execute_model(
def _get_worker_input_from_broadcast(
        self) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
    """Receive this step's inputs broadcast by the driver worker.

    Only valid on non-driver workers with metadata broadcast enabled.
    Returns None when the driver broadcast an empty dict, which signals
    that there is no work this step (stop condition).
    """
    assert self.do_metadata_broadcast
    assert not self.is_driver_worker

    data = broadcast_tensor_dict(src=0)
    if not data:
        # Empty broadcast from the driver: nothing to execute.
        return None

    # NOTE: WorkerInput must be extracted first — its constructor pops
    # the worker-level keys out of `data`, leaving only the model-runner
    # keys for the call below.
    worker_input = WorkerInput.from_broadcasted_tensor_dict(data)
    runner = self.model_runner
    model_input = runner.make_model_input_from_broadcasted_tensor_dict(data)

    return model_input, worker_input

def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
) -> Tuple[ModelRunnerInputBase, WorkerInput]:
    """Build this step's inputs on the driver and broadcast them.

    Only valid on the driver worker. When metadata broadcast is enabled,
    the combined worker/model-runner tensor dict is broadcast to all
    other workers before returning.
    """
    assert self.is_driver_worker

    worker_input: WorkerInput = self.prepare_worker_input(
        execute_model_req=execute_model_req)
    model_input: ModelRunnerInputBase = (
        self.model_runner.prepare_model_input(
            execute_model_req.seq_group_metadata_list,
            execute_model_req.virtual_engine,
            execute_model_req.finished_requests_ids))

    if self.do_metadata_broadcast:
        # Merge both tensor dicts into a single payload so the workers
        # need only one broadcast per step.
        payload = worker_input.as_broadcastable_tensor_dict()
        payload.update(model_input.as_broadcastable_tensor_dict())
        broadcast_tensor_dict(payload, src=0)

    return model_input, worker_input

def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
start_time = time.perf_counter()
) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
"""
Prepare the inputs to ModelRunner and workers.
"""
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
Expand All @@ -233,34 +273,24 @@ def execute_model(
# notify all other workers to stop their execution loop.
broadcast_tensor_dict({}, src=0)
return None

worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
num_steps = execute_model_req.num_steps

if self.do_metadata_broadcast:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(
model_input.as_broadcastable_tensor_dict())
broadcast_data["num_steps"] = num_steps
broadcast_tensor_dict(broadcast_data, src=0)
return self._get_driver_input_and_broadcast(execute_model_req)
else:
assert self.do_metadata_broadcast
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None
return self._get_worker_input_from_broadcast()

def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
start_time = time.perf_counter()

inputs = self.prepare_input(execute_model_req)
if inputs is None:
return None

num_steps = broadcast_data.pop("num_steps")
worker_input = WorkerInput.from_broadcasted_tensor_dict(
broadcast_data)
model_input = (
self.model_runner.
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
model_input, worker_input = inputs
num_steps = worker_input.num_steps

self.execute_worker(worker_input)

Expand Down

0 comments on commit 305c70d

Please sign in to comment.