Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
SolitaryThinker committed Aug 9, 2024
1 parent 76055b9 commit 5fd889e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 14 deletions.
1 change: 0 additions & 1 deletion vllm/worker/multi_step_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ def execute_model(

# event for the pythonization so that we only pythonize if the
# tensors are ready. May be able to be combined with the step event
# torch.cuda.synchronize()
output_ready_event = torch.cuda.Event()
output_ready_event.record(current_stream)
if self.parallel_config.pipeline_parallel_size > 1:
Expand Down
17 changes: 4 additions & 13 deletions vllm/worker/multi_step_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,28 +135,19 @@ def prepare_input(
pass
# cache the worker input and model input for the next steps
# TODO(will) see below

# self.multi_step_states[virtual_engine] = MultiStepState(
# worker_input=worker_input, model_input=model_input)
else:
# TODO(will) possible to also use the cached worker input and model input
# this can be done if we want to optimize the broadcast to only send
# the last sampled token ids for non-first multi steps
# TODO(will) possible to also use the cached worker input and
# model input this can be done if we want to optimize the
# broadcast to only send the last sampled token ids for
# non-first multi steps

# multi_step_state = self.multi_step_states[virtual_engine]
# cached_model_input = multi_step_state.model_input
# cached_worker_input = multi_step_state.worker_input
assert isinstance(
model_input, MutableModelInputForGPUWithMultiStepMetadata)
# we need to update the last sampled token ids in the model input
# for the workers so that they can run inplace advance_step
model_input.add_sampler_output(
SamplerOutput(outputs=[], sampled_token_ids=None),
model_input.last_sampled_token_ids)
# self.multi_step_states[virtual_engine] = MultiStepState(
# worker_input=worker_input, model_input=model_input)
# model_input = cached_model_input
# worker_input = cached_worker_input

assert model_input is not None
assert worker_input is not None
Expand Down

0 comments on commit 5fd889e

Please sign in to comment.