Commit
SolitaryThinker committed Aug 9, 2024
1 parent 5fd889e commit c1b0e0a
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions vllm/worker/multi_step_model_runner.py
@@ -274,6 +274,10 @@ def execute_model(
self._base_model_runner.model.sampler.include_gpu_probs_tensor = True
if frozen_model_input.sampling_metadata:
frozen_model_input.sampling_metadata.skip_sampler_cpu_output = True
# TODO(will) We need to benchmark and inspect the torch profiler to find
# the exact location where this call belongs. If the CPU is far ahead, it
# does not matter whether we call this before or after model execution,
# since the CPU will block anyway.
for model_output in model_input.outputs:
model_output.maybe_pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
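The `_copy_stream` and `pinned_sampled_token_ids` arguments above suggest the usual overlapped GPU→CPU transfer pattern: enqueue the copy on a side stream into pinned host memory, and only block when the Python-side objects are actually needed. A minimal sketch of that pattern (hypothetical names, not vLLM's actual implementation) illustrates why, as the TODO notes, the call site barely matters when the CPU is far ahead:

```python
import torch

def async_copy_to_pinned(gpu_tokens: torch.Tensor,
                         pinned_buf: torch.Tensor,
                         copy_stream: torch.cuda.Stream) -> torch.cuda.Event:
    """Enqueue a non-blocking GPU->CPU copy; returns a completion event."""
    # The copy stream must wait until the sampler output has been produced
    # on the current (compute) stream.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        pinned_buf.copy_(gpu_tokens, non_blocking=True)
        done = torch.cuda.Event()
        done.record(copy_stream)
    return done

# Usage: start the copy right after sampling, keep enqueueing GPU work,
# and synchronize only when the tokens are read on the host. If the CPU
# is far ahead of the GPU, synchronize() blocks either way, so whether
# the copy is kicked off before or after execute_model changes little.
#
#   done = async_copy_to_pinned(sampled_ids, pinned_buf, copy_stream)
#   ...enqueue next forward pass...
#   done.synchronize()
```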
@@ -288,12 +292,22 @@ def execute_model(
# don't clobber any GPU tensors still in use
current_stream = torch.cuda.current_stream()
if model_input.is_first_multi_step:
# TODO(will) Double-check that reusing sampling tensors is impossible
# here due to changing batch sizes; once confirmed, remove this TODO and
# potentially leave a comment for future optimization.
if frozen_model_input.sampling_metadata:
frozen_model_input.sampling_metadata.reuse_sampling_tensors = False
else:
# This is not needed for the flashattn backend, but for other attention
# backends such as flashinfer we may need to synchronize any CPU
# operations that might clobber enqueued forwards (this prevents the
# CPU from running too far ahead if needed).
model_input.wait_previous_step()
model_input = self._advance_step(
model_input, model_input.outputs[-1].sampler_output)
# TODO(will) Double-check that reusing sampling tensors is impossible
# here due to changing batch sizes; once confirmed, remove this TODO and
# potentially leave a comment for future optimization.
if frozen_model_input.sampling_metadata:
frozen_model_input.sampling_metadata.reuse_sampling_tensors = False
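The comment in the hunk above describes `wait_previous_step` as a guard that keeps host code from clobbering buffers still read by enqueued forwards. One common way to implement such a guard is a CUDA event fence; a minimal sketch with hypothetical names (this is not vLLM's actual `wait_previous_step`):

```python
import torch

class StepFence:
    """Event fence between multi-step iterations (illustrative only)."""

    def __init__(self) -> None:
        self._event: torch.cuda.Event | None = None

    def record_step(self) -> None:
        # Call right after enqueueing this step's kernels.
        self._event = torch.cuda.Event()
        self._event.record(torch.cuda.current_stream())

    def wait_previous_step(self) -> None:
        # Block the host until the previous step's GPU work completes,
        # before mutating CPU-side state those kernels may still read.
        if self._event is not None:
            self._event.synchronize()
```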
