From c1b0e0a932d877b37c5d7e086a63a9bd9d379688 Mon Sep 17 00:00:00 2001
From: Will Lin
Date: Thu, 8 Aug 2024 22:33:27 -0700
Subject: [PATCH] comments

---
 vllm/worker/multi_step_model_runner.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py
index 97696a4bda34c..8da847987a147 100644
--- a/vllm/worker/multi_step_model_runner.py
+++ b/vllm/worker/multi_step_model_runner.py
@@ -274,6 +274,10 @@ def execute_model(
             self._base_model_runner.model.sampler.include_gpu_probs_tensor = True
         if frozen_model_input.sampling_metadata:
             frozen_model_input.sampling_metadata.skip_sampler_cpu_output = True
+        # TODO(will) Benchmark and check the torch profiler to find the exact
+        # place this should happen. If the CPU is far ahead of the GPU, it
+        # does not matter whether we call this before or after the model
+        # execution; the CPU will block either way.
         for model_output in model_input.outputs:
             model_output.maybe_pythonize(model_input, self._copy_stream,
                                          self.pinned_sampled_token_ids)
@@ -288,12 +292,22 @@ def execute_model(
         # don't clobber any GPU tensors still in use
         current_stream = torch.cuda.current_stream()
         if model_input.is_first_multi_step:
+            # TODO(will) Double-check that this is not possible due to
+            # changing batch sizes; will remove afterwards and potentially
+            # leave a comment for future optimization.
             if frozen_model_input.sampling_metadata:
                 frozen_model_input.sampling_metadata.reuse_sampling_tensors = False
         else:
+            # This is not needed for the flashattn backend, but for other
+            # attn backends such as flashinfer that do extra CPU work, we may
+            # need to synchronize any CPU operations that might clobber
+            # enqueued forwards (prevents the CPU from running too far ahead).
             model_input.wait_previous_step()
             model_input = self._advance_step(
                 model_input, model_input.outputs[-1].sampler_output)
+            # TODO(will) Double-check that this is not possible due to
+            # changing batch sizes; will remove afterwards and potentially
+            # leave a comment for future optimization.
             if frozen_model_input.sampling_metadata:
                 frozen_model_input.sampling_metadata.reuse_sampling_tensors = False
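
Note on the first TODO: the question of where to enqueue the pythonization copy only matters when the CPU is not already far ahead of the GPU. The sketch below illustrates the general pattern the comments reason about (a side copy stream, a pinned host buffer, and a CUDA event that is only synchronized when the Python-side values are needed). All names are hypothetical; this is a minimal standalone illustration, not the vLLM implementation.

# Sketch of the copy-stream pattern: overlap a device-to-host copy of
# sampled tokens with the next enqueued forward, and only block the CPU
# when the pythonized result is actually needed.
# Hypothetical names; not the vLLM API.
import torch


def fake_forward_and_sample(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a model forward + sampling on the current stream.
    return torch.argmax(x @ x.t(), dim=-1)


def main() -> None:
    assert torch.cuda.is_available(), "sketch requires a CUDA device"
    device = torch.device("cuda")
    compute_stream = torch.cuda.current_stream()
    copy_stream = torch.cuda.Stream()
    # Pinned host buffer so the D2H copy can run asynchronously.
    pinned_token_ids = torch.empty(128, dtype=torch.long, pin_memory=True)
    x = torch.randn(128, 128, device=device)

    sampled_gpu = fake_forward_and_sample(x)   # enqueued, CPU not blocked
    copy_done = torch.cuda.Event()
    copy_stream.wait_stream(compute_stream)    # order the copy after sampling
    with torch.cuda.stream(copy_stream):
        # Keep the allocator from reusing sampled_gpu while copy_stream reads it.
        sampled_gpu.record_stream(copy_stream)
        pinned_token_ids.copy_(sampled_gpu, non_blocking=True)
        copy_done.record()                     # records on copy_stream

    # The CPU is free to enqueue the next step here (advance-step analogue).
    _ = fake_forward_and_sample(x)

    # Block only when the Python-side values are required; if the CPU were
    # far ahead of the GPU, it would have blocked here anyway.
    copy_done.synchronize()
    print(pinned_token_ids[:4].tolist())


if __name__ == "__main__":
    main()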