diff --git a/tests/models/utils.py b/tests/models/utils.py index ff29a0ae81d6e..93ec03995094b 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,7 +1,7 @@ import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union -from vllm.sequence import SampleLogprobs +from vllm.sequence import Logprob, SampleLogprobs TokensText = Tuple[List[int], str] @@ -38,34 +38,39 @@ def check_outputs_equal( float]], SampleLogprobs]]] +# Allow for tokens to be represented as str's rather than IDs +TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]], + List[Dict[str, + Logprob]]]]] + def check_logprobs_close( *, - outputs_0_lst: Sequence[TokensTextLogprobs], - outputs_1_lst: Sequence[TokensTextLogprobs], + outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]], + outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]], name_0: str, name_1: str, num_outputs_0_skip_tokens: int = 0, warn_on_mismatch: bool = True, -): - """ - Compare the logprobs of two sequences generated by different models, + always_check_logprobs: bool = False, +) -> None: + """Compare the logprobs of two sequences generated by different models, which should be similar but not necessarily equal. - Arguments: - - * outputs_0_lst: First sequence to compare - * outputs_0_lst: Second sequence to compare - * name_0: sequence #0 name - * name_1: sequence #1 name - * num_outputs_0_skip_tokens: If > 0, specifies the number of initial + Args: + outputs_0_lst: First sequence to compare + outputs_1_lst: Second sequence to compare + name_0: sequence #0 name + name_1: sequence #1 name + num_outputs_0_skip_tokens: If > 0, specifies the number of initial sequence #0 tokens & logprobs to discard before comparison, i.e. all of sequence #1 will be compared to sequence #0 beginning at index num_outputs_0_skip_tokens - * warn_on_mismatch: Issue a warning if there is token-wise or text-wise + warn_on_mismatch: Issue a warning if there is token-wise or text-wise mismatch between the two sequences + always_check_logprobs: If true, check logprobs even when tokens match """ assert len(outputs_0_lst) == len(outputs_1_lst) @@ -94,8 +99,12 @@ def check_logprobs_close( for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): - # If generated tokens don't match, then - if output_id_0 != output_id_1: + is_tok_mismatch = output_id_0 != output_id_1 + + # If generated tokens don't match + # or it is desired to always check logprobs, + # then + if is_tok_mismatch or always_check_logprobs: logprobs_elem_0 = logprobs_0[idx] logprobs_elem_1 = logprobs_1[idx] @@ -111,7 +120,7 @@ def check_logprobs_close( assert output_id_0 in logprobs_elem_1, fail_msg assert output_id_1 in logprobs_elem_0, fail_msg - if warn_on_mismatch: + if warn_on_mismatch and is_tok_mismatch: with warnings.catch_warnings(): # This ensures that repeated warnings are shown # in the output, not just the first occurrence diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index ac04be3d9a689..d054ca341694a 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -1,10 +1,12 @@ # Test the AsyncLLMEngine with multi-step-decoding -from typing import List +from typing import List, Optional import pytest -from ..utils import RemoteOpenAIServer +from ..models.utils import check_logprobs_close +from ..utils import (completions_with_server_args, get_client_text_generations, +
get_client_text_logprob_generations) MODELS = [ "JackFram/llama-160m", @@ -23,22 +25,6 @@ ] -async def completions_with_server_args(prompts: List[str], model_name: str, - server_cli_args: List[str]): - - outputs = None - with RemoteOpenAIServer(model_name, server_cli_args) as server: - async with server.get_async_client() as client: - outputs = await client.completions.create(model=model_name, - prompt=prompts, - temperature=0, - stream=False, - max_tokens=5) - assert outputs is not None - - return outputs - - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize(("tp_size, pp_size"), [ (1, 1), @@ -47,12 +33,43 @@ async def completions_with_server_args(prompts: List[str], model_name: str, @pytest.mark.parametrize("eager_mode", [False, True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.parametrize("num_logprobs", [None, 5]) @pytest.mark.parametrize("is_async", [False, True]) @pytest.mark.asyncio -async def test_multi_step(example_prompts, model: str, tp_size: int, - pp_size: int, eager_mode: int, - num_scheduler_steps: int, num_prompts: int, - is_async: bool): +async def test_multi_step( + example_prompts, + model: str, + tp_size: int, + pp_size: int, + eager_mode: int, + num_scheduler_steps: int, + num_prompts: int, + is_async: bool, + num_logprobs: Optional[int], +) -> None: + """Test vLLM engine with multi-step scheduling in an OpenAI-protocol + client/server environment. + + Set up an engine with single-step scheduling as a ground-truth reference. + + Send a completions API request to both engines with the same prompts. + + Validate: + * Generated tokens match + * Generated logprobs are all very close + + Args: + example_prompts: test fixture providing example prompts + model: model under test (same for single- and multi-step engines) + tp_size: degree of tensor-parallelism + pp_size: degree of pipeline-parallelism + eager_mode + num_scheduler_steps: for multi-step scheduling, GPU-side steps per + GPU -> CPU output transfer + num_prompts: number of example prompts under test + num_logprobs: corresponds to the `logprobs` argument to the OpenAI + completions endpoint; `None` -> no logprobs + """ prompts = example_prompts if len(prompts) < num_prompts: @@ -77,14 +94,36 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, str(pp_size), ] + # Spin up client/server & issue completion API requests. + # Default `max_wait_seconds` is 240 but was empirically + # raised 3x to 720 *just for this test* due to + # observed timeouts in GHA CI ref_completions = await completions_with_server_args( - prompts, model, server_args + distributed_args) + prompts, + model, + server_args + distributed_args, + num_logprobs, + max_wait_seconds=3 * 240) test_completions = await completions_with_server_args( - prompts, model, ms_server_args + distributed_args) - - def get_text_generations(completions): - return [x.text for x in completions.choices] - - ref_generations = get_text_generations(ref_completions) - test_generations = get_text_generations(test_completions) + prompts, + model, + ms_server_args + distributed_args, + num_logprobs, + max_wait_seconds=3 * 240) + + # Assert multi-step scheduling produces identical tokens + # to single-step scheduling.
+ ref_generations = get_client_text_generations(ref_completions) + test_generations = get_client_text_generations(test_completions) assert ref_generations == test_generations + + # Assert multi-step scheduling produces nearly-identical logprobs + # to single-step scheduling. + ref_text_logprobs = get_client_text_logprob_generations(ref_completions) + test_text_logprobs = get_client_text_logprob_generations(test_completions) + check_logprobs_close( + outputs_0_lst=ref_text_logprobs, + outputs_1_lst=test_text_logprobs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index 36f610ba74f05..50c85df932e25 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -1,8 +1,10 @@ # Test the LLMEngine with multi-step-decoding +from typing import Optional + import pytest -from ..models.utils import check_outputs_equal +from ..models.utils import check_logprobs_close, check_outputs_equal MODELS = [ "JackFram/llama-160m", @@ -18,10 +20,45 @@ @pytest.mark.parametrize("enforce_eager", [True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) -def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, tp_size: int, max_tokens: int, - enforce_eager: int, num_scheduler_steps: int, - num_prompts: int) -> None: +@pytest.mark.parametrize("num_logprobs", [None, 5]) +def test_multi_step_llm( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + tp_size: int, + max_tokens: int, + enforce_eager: int, + num_scheduler_steps: int, + num_prompts: int, + num_logprobs: Optional[int], +) -> None: + """Test vLLM engine with multi-step scheduling via sync LLM Engine. + + Set up a HuggingFace (HF) transformers model as a ground-truth reference. + + Prompt them with the same example prompts. 
+ + Validate: + * Generated tokens match + * Generated logprobs are all very close + + Args: + hf_runner: HF transformers model runner fixture + vllm_runner: vLLM model runner fixture + example_prompts: test fixture providing example prompts + model: model under test (same for single- and multi-step engines) + dtype: tensor datatype for engine to utilize + tp_size: degree of tensor-parallelism + max_tokens: the maximum number of tokens to generate + enforce_eager + num_scheduler_steps: for multi-step scheduling, GPU-side steps per + GPU -> CPU output transfer + num_prompts: number of example prompts under test + num_logprobs: corresponds to the `logprobs` argument to the OpenAI + completions endpoint; `None` -> no logprobs + """ prompts = example_prompts if len(prompts) < num_prompts: @@ -29,21 +66,37 @@ def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str, prompts = prompts[:num_prompts] assert len(prompts) == num_prompts - with vllm_runner(model, - dtype=dtype, - enforce_eager=enforce_eager, - gpu_memory_utilization=0.7, - tensor_parallel_size=tp_size, - use_v2_block_manager=True, - num_scheduler_steps=num_scheduler_steps) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens) + with vllm_runner( + model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + use_v2_block_manager=True, + num_scheduler_steps=num_scheduler_steps, + ) as vllm_model: + vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens) + if num_logprobs is None else + vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs)) with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) + hf_outputs = (hf_model.generate_greedy(prompts, max_tokens) + if num_logprobs is None else + hf_model.generate_greedy_logprobs_limit( + prompts, max_tokens, num_logprobs)) + + if num_logprobs is None: + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + else: + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index ada6c37d9af8d..e7a0af4377630 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -5,9 +5,10 @@ import pytest import torch +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.utils import set_random_seed from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob, - SamplerOutput, get_all_seq_ids) + get_all_seq_ids) from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.top1_proposer import Top1Proposer diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 9ae1b4bc40f0f..cbaffee2f41e2 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -7,8 +7,9 @@ import pytest import torch +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.utils import set_random_seed -from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput +from vllm.sequence import 
ExecuteModelRequest, SequenceOutput from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.metrics import (AsyncMetricsCollector, SpecDecodeWorkerMetrics) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 60b36a33d9077..9075a433eb66e 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -8,12 +8,12 @@ import torch from vllm.engine.arg_utils import EngineArgs +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.utils import set_random_seed from vllm.sampling_params import SamplingParams from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, CompletionSequenceGroupOutput, Logprob, - SamplerOutput, SequenceData, SequenceGroupMetadata, - SequenceOutput) + SequenceData, SequenceGroupMetadata, SequenceOutput) from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 1ae349e808e0d..348ba7dd41d99 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -2,9 +2,10 @@ import pytest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, - CompletionSequenceGroupOutput, SamplerOutput, - SequenceData, SequenceOutput) + CompletionSequenceGroupOutput, SequenceData, + SequenceOutput) from .core.utils import create_dummy_prompt diff --git a/tests/utils.py b/tests/utils.py index de887bc8cf6fb..cd8d7b1f25905 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,9 +11,11 @@ import openai import requests +from openai.types.completion import Completion from transformers import AutoTokenizer from typing_extensions import ParamSpec +from tests.models.utils import TextTextLogprobs from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.engine.arg_utils import AsyncEngineArgs @@ -432,3 +434,61 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: f" args {args} and kwargs {kwargs}") return wrapper + + +async def completions_with_server_args( + prompts: List[str], + model_name: str, + server_cli_args: List[str], + num_logprobs: Optional[int], + max_wait_seconds: int = 240, +) -> Completion: + '''Construct a remote OpenAI server, obtain an async client to the + server & invoke the completions API to obtain completions. + + Args: + prompts: test prompts + model_name: model to spin up on the vLLM server + server_cli_args: CLI args for starting the server + num_logprobs: Number of logprobs to report (or `None`) + max_wait_seconds: timeout interval for bringing up server. + Default: 240sec + + Returns: + OpenAI Completion instance + ''' + + outputs = None + with RemoteOpenAIServer(model_name, + server_cli_args, + max_wait_seconds=max_wait_seconds) as server: + client = server.get_async_client() + outputs = await client.completions.create(model=model_name, + prompt=prompts, + temperature=0, + stream=False, + max_tokens=5, + logprobs=num_logprobs) + assert outputs is not None + + return outputs + + +def get_client_text_generations(completions: Completion) -> List[str]: + '''Extract generated tokens from the output of a + request made to an Open-AI-protocol completions endpoint. 
+ ''' + return [x.text for x in completions.choices] + + +def get_client_text_logprob_generations( + completions: Completion) -> List[TextTextLogprobs]: + '''Operates on the output of a request made to an Open-AI-protocol + completions endpoint; obtains top-rank logprobs for each token in + each :class:`SequenceGroup` + ''' + text_generations = get_client_text_generations(completions) + text = ''.join(text_generations) + return [(text_generations, text, + (None if x.logprobs is None else x.logprobs.top_logprobs)) + for x in completions.choices] diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3058214c50a5f..159281dabde4a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -22,11 +22,12 @@ from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import print_warning_once diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 59baf1ef40dfc..aa33933c668ed 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -33,6 +33,7 @@ from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MultiModalDataDict from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) @@ -40,8 +41,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - SamplerOutput, Sequence, SequenceGroup, - SequenceGroupMetadata, SequenceStatus) + Sequence, SequenceGroup, SequenceGroupMetadata, + SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 49a33ded5fcaa..0209b0adc9831 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -4,6 +4,8 @@ from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.interfaces import ( SequenceGroupOutputProcessor) +from vllm.engine.output_processor.single_step import ( + single_step_process_prompt_logprob) from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.sampling_params import SamplingParams @@ -46,9 +48,16 @@ def __init__( def process_prompt_logprob(self, seq_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: - # TODO(sang): Prompt logprob currently not implemented in multi step - # workers. - self._log_prompt_logprob_unsupported_warning_once() + """Process prompt logprobs associated with each step of a multi-step- + scheduled computation. 
+ + Args: + seq_group: the outputs are associated with this :class:`SequenceGroup` + outputs: the :class:`SequenceGroupOutput`s for all scheduler steps + """ + for output in outputs: + # Concatenate single-step prompt logprob processing results. + single_step_process_prompt_logprob(self, seq_group, output) @staticmethod @functools.lru_cache() diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 4b0c3f37a5e21..422e6d30522f5 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -15,6 +15,44 @@ logger = init_logger(__name__) +def single_step_process_prompt_logprob( + sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup, + output: SequenceGroupOutput) -> None: + """Process prompt logprobs associated with the :class:`SequenceGroupOutput` + for a given step. + + Do nothing if the output has no prompt logprobs. + + Account for the fact that transformers do not compute first-token logprobs. + + Args: + sg_output_proc: :class:`SequenceGroupOutputProcessor` instance + seq_group: the output is associated with this :class:`SequenceGroup` + output: the :class:`SequenceGroupOutput` for a single scheduler step + """ + prompt_logprobs = output.prompt_logprobs + + # If this is the first (or only) "chunk" of the prefill, we need + # to prepend None to the list of prompt logprobs. The reason for this + # is that for N prompt tokens, the Sampler will generate N-1 total + # prompt logprobs during prefill since the token at idx 0 will not + # have a logprob associated with it. + if prompt_logprobs is not None: + if not seq_group.prompt_logprobs: + prompt_logprobs = [None] + prompt_logprobs + seq_group.prompt_logprobs = [] + + assert hasattr(sg_output_proc, 'detokenizer') + if (seq_group.sampling_params.detokenize + and sg_output_proc.detokenizer): + sg_output_proc.detokenizer.decode_prompt_logprobs_inplace( + seq_group, + prompt_logprobs, + position_offset=len(seq_group.prompt_logprobs)) + + seq_group.prompt_logprobs.extend(prompt_logprobs) + + class SingleStepOutputProcessor(SequenceGroupOutputProcessor): """SequenceGroupOutputProcessor which handles "output processing" logic, which happens after the model returns generated token ids and before @@ -60,27 +98,16 @@ def process_outputs(self, sequence_group: SequenceGroup, def process_prompt_logprob(self, seq_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: + """Process prompt logprobs associated with one step of a single-step- + scheduled computation. + + Args: + seq_group: the output is associated with this :class:`SequenceGroup` + output: the :class:`SequenceGroupOutput` for a single scheduler step + """ assert len(outputs) == 1, ("Single step should only has 1 output.") output = outputs[0] - prompt_logprobs = output.prompt_logprobs - - # If this is the first (or only) "chunk" of the prefill, we need - # to prepend None to the list of prompt logprobs. The reason for this - # is that for N prompt tokens, the Sampler will generate N-1 total - # prompt logprobs during prefill since the token at idx 0 will not - # have a logprob associated with it. 
- if prompt_logprobs is not None: - if not seq_group.prompt_logprobs: - prompt_logprobs = [None] + prompt_logprobs - seq_group.prompt_logprobs = [] - - if seq_group.sampling_params.detokenize and self.detokenizer: - self.detokenizer.decode_prompt_logprobs_inplace( - seq_group, - prompt_logprobs, - position_offset=len(seq_group.prompt_logprobs)) - - seq_group.prompt_logprobs.extend(prompt_logprobs) + single_step_process_prompt_logprob(self, seq_group, output) def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput, diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py index 57cc33d911183..76782888031e3 100644 --- a/vllm/engine/output_processor/util.py +++ b/vllm/engine/output_processor/util.py @@ -2,7 +2,8 @@ from typing import Sequence as GenericSequence from typing import Union -from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import PoolerOutput, SequenceGroupOutput def create_output_by_sequence_group( diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 1deb75167bc72..34ae79f5fa8df 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -5,11 +5,11 @@ from vllm.core.scheduler import SchedulerOutputs from vllm.inputs.data import PromptInputs from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import SamplerOutput from vllm.transformers_utils.tokenizer import AnyTokenizer diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 37d12725bd1e4..21ad43f641685 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -11,8 +11,9 @@ ResultHandler, WorkerMonitor) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port, get_vllm_instance_id, make_async) from vllm.worker.worker_base import WorkerWrapperBase diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 1a35a7c3b8f75..ad84422ee2129 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -6,7 +6,8 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest logger = init_logger(__name__) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 422bef107f352..c96cb0f2c2981 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -6,8 +6,9 @@ PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import 
ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest class ExecutorBase(ABC): diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 795692195f84d..947776e5d6ef4 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -3,8 +3,9 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 02b2499be4656..9c6d4051eb3f8 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -14,7 +14,8 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, ResultHandler, WorkerMonitor) from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.triton_utils import maybe_set_triton_cache_manager from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, get_distributed_init_method, get_open_port, diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 02627de3e0be7..f2fcfa58b26e1 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -3,7 +3,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index 867859d8d3d79..78606e223aa7b 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -9,7 +9,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip, get_open_port, make_async) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 760c06cb6c06f..ab8844bcdafec 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -12,7 +12,8 @@ from vllm.executor.msgspec_utils import encode_hook from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (_run_task_with_lock, get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) 
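The executor, engine, and model-file hunks in this patch repeat a single mechanical change: `SamplerOutput` is now imported from `vllm.model_executor.layers.sampler`, where the class is defined in the `vllm/model_executor/layers/sampler.py` hunk further below, instead of from `vllm.sequence`. A minimal sketch of the before/after import pattern for downstream code follows; it is illustrative only, not itself a patch hunk, and both module paths are taken verbatim from the surrounding hunks:

    # Old pattern (removed "-" lines): SamplerOutput was imported from
    # the sequence module together with ExecuteModelRequest:
    #     from vllm.sequence import ExecuteModelRequest, SamplerOutput
    # New pattern (added "+" lines): SamplerOutput comes from the sampler
    # layer, while ExecuteModelRequest stays in vllm.sequence:
    from vllm.model_executor.layers.sampler import SamplerOutput
    from vllm.sequence import ExecuteModelRequest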
diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 7048d47980723..2a1fd35b65797 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -10,7 +10,8 @@ from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.executor.tpu_executor import TPUExecutor from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py index 253c8abdc1ada..0af8ba41e24d5 100644 --- a/vllm/executor/tpu_executor.py +++ b/vllm/executor/tpu_executor.py @@ -5,7 +5,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index 774204dd4612a..bada56068507a 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -9,7 +9,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import make_async from vllm.worker.worker_base import WorkerBase diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 7344d59e988f0..c00da106734ae 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1,13 +1,16 @@ """A layer that samples the next tokens from the model's outputs.""" import itertools import warnings +from dataclasses import dataclass from importlib.util import find_spec from math import inf -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union +import msgspec import torch import torch.nn as nn +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics from vllm.triton_utils import HAS_TRITON if HAS_TRITON: @@ -19,8 +22,7 @@ SequenceGroupToSample) from vllm.sampling_params import SamplingType from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - PromptLogprobs, SampleLogprobs, SamplerOutput, - SequenceOutput) + PromptLogprobs, SampleLogprobs, SequenceOutput) if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): import flashinfer.sampling @@ -35,6 +37,116 @@ # (num_token_ids, num_parent_ids) per sequence group. SampleResultType = List[Tuple[List[int], List[int]]] +# Types of temporary data structures used for +# computing sample_result +SampleMetadataType = Dict[SamplingType, Tuple[List[int], + List[SequenceGroupToSample]]] +MultinomialSamplesType = Dict[SamplingType, torch.Tensor] +SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]] + + +# Encapsulates temporary data structures for computing +# sample_result. 
+# +# * For multi-step scheduling: must be returned +# by `Sampler.forward()` and used later to compute the pythonized +# sample_result +# +# * For single-step scheduling: consumed immediately +# inside `Sampler.forward()` to compute pythonized sample_result. +@dataclass +class SampleResultArgsType: + sample_metadata: SampleMetadataType + multinomial_samples: MultinomialSamplesType + sample_results_dict: SampleResultsDictType + sampling_metadata: SamplingMetadata + greedy_samples: Optional[torch.Tensor] + beam_search_logprobs: Optional[torch.Tensor] + + +# Union of non-deferred (single-step scheduling) +# vs deferred (multi-step scheduling) +# sample result types +MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType] + +# Abbreviation of the _sample() return type +SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]] + + +class SamplerOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """For each sequence group, we generate a list of SequenceOutput object, + each of which contains one possible candidate for the next token. + + This data structure implements methods, so it can be used like a list, but + also has optional fields for device tensors. + """ + + outputs: List[CompletionSequenceGroupOutput] + + # On-device tensor containing probabilities of each token. + sampled_token_probs: Optional[torch.Tensor] = None + + # On-device tensor containing the logprobs of each token. + logprobs: Optional["torch.Tensor"] = None + + # Holds either (1) the pythonized sampler result (single-step scheduling) + # or (2) what will be arguments for later deferred pythonization of the + # sampler result (muliti-step scheduling) + deferred_sample_results_args: Optional[SampleResultArgsType] = None + + # On-device tensor containing the sampled token ids. + sampled_token_ids: Optional[torch.Tensor] = None + # CPU tensor containing the sampled token ids. Used during multi-step to + # return the sampled token ids from last rank to AsyncLLMEngine to be + # 'broadcasted' to all other PP ranks for next step. + sampled_token_ids_cpu: Optional[torch.Tensor] = None + + # Spec decode metrics populated by workers. + spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None + + # Optional last hidden states from the model. + hidden_states: Optional[torch.Tensor] = None + + # Optional prefill hidden states from the model + # (used for models like EAGLE). + prefill_hidden_states: Optional[torch.Tensor] = None + + # Time taken in the forward pass for this across all workers + model_forward_time: Optional[float] = None + + # Time taken in the model execute function. This will include model forward, + # block/sync across workers, cpu-gpu sync time and sampling time. + model_execute_time: Optional[float] = None + + def __getitem__(self, idx: int): + return self.outputs[idx] + + def __setitem__(self, idx: int, value): + self.outputs[idx] = value + + def __len__(self): + return len(self.outputs) + + def __eq__(self, other: object): + return isinstance(other, + self.__class__) and self.outputs == other.outputs + + def __repr__(self) -> str: + """Show the shape of a tensor instead of its values to reduce noise. 
+ """ + sampled_token_probs_repr = ("None" if self.sampled_token_probs is None + else self.sampled_token_probs.shape) + sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else + self.sampled_token_ids.shape) + return ( + f"SamplerOutput(outputs={self.outputs}, " + f"sampled_token_probs={sampled_token_probs_repr}, " + f"sampled_token_ids={sampled_token_ids_repr}, " + f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") + class Sampler(nn.Module): """Samples the next tokens from the model's outputs. @@ -98,6 +210,19 @@ def forward( sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: """ + Single-step scheduling: + * Perform GPU-side sampling computation & compute + GPU-side logprobs tensor + * Pythonize sampling result & logprobs tensor + + Multi-step scheduling: + * Perform GPU-side sampling computation & compute + GPU-side logprobs tensor + * Defer Pythonization of sampling result & logprobs + tensor + * Encapsulate arguments required for deferred Pythonization + in the :class:`SamplerOutput` structure + Args: logits: (num_tokens, vocab_size). sampling_metadata: Metadata for sampling. @@ -150,7 +275,7 @@ def forward( logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. - sample_results, maybe_sampled_tokens_tensor = _sample( + maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample( probs, logprobs, sampling_metadata, @@ -160,20 +285,28 @@ def forward( ) if self.include_gpu_probs_tensor: + # Since we will defer sampler result Pythonization, + # preserve GPU-side tensors in support of later + # deferred pythonization of logprobs assert maybe_sampled_tokens_tensor is not None on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) else: + # Since Pythonization has already happened, don't preserve + # GPU-side tensors. on_device_tensors = None # Get the logprobs query results. prompt_logprobs = None sample_logprobs = None if not sampling_metadata.skip_sampler_cpu_output: - prompt_logprobs, sample_logprobs = _get_logprobs( - logprobs, sampling_metadata, sample_results) + # Pythonize logprobs now (GPU -> CPU); do not defer. + assert not isinstance(maybe_deferred_sample_results, + SampleResultArgsType) + prompt_logprobs, sample_logprobs = get_logprobs( + logprobs, sampling_metadata, maybe_deferred_sample_results) return _build_sampler_output( - sample_results, + maybe_deferred_sample_results, sampling_metadata, prompt_logprobs, sample_logprobs, @@ -543,6 +676,60 @@ def _top_k_top_p_multinomial_with_flashinfer( return batch_next_token_ids.view(-1, num_samples) +def get_pythonized_sample_results( + sample_result_args: SampleResultArgsType) -> SampleResultType: + '''This function consumes GPU-side sampler results and computes + Pythonized CPU-side sampler results (GPU -> CPU sync.) + + Single-step scheduling: this function is invoked at sampling-time + for immediate Pythonization. + + Multi-step scheduling: Pythonization is deferred until after multiple + GPU-side steps have been completed. 
+ + Args: + sample_result_args: GPU-side inputs to the Pythonization process + + Returns: + Pythonized sampler results + ''' + + ( + sample_metadata, + sampling_metadata, + greedy_samples, + multinomial_samples, + beam_search_logprobs, + sample_results_dict, + ) = ( + sample_result_args.sample_metadata, + sample_result_args.sampling_metadata, + sample_result_args.greedy_samples, + sample_result_args.multinomial_samples, + sample_result_args.beam_search_logprobs, + sample_result_args.sample_results_dict, + ) + + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + (seq_group_id, seq_groups) = sample_metadata[sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(seq_groups, greedy_samples) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + sample_results = _random_sample(seq_groups, + multinomial_samples[sampling_type]) + elif sampling_type == SamplingType.BEAM: + sample_results = _beam_search_sample(seq_groups, + beam_search_logprobs) + sample_results_dict.update(zip(seq_group_id, sample_results)) + + return [ + sample_results_dict.get(i, ([], [])) + for i in range(len(sampling_metadata.seq_groups)) + ] + + def _sample_with_torch( probs: torch.Tensor, logprobs: torch.Tensor, @@ -550,7 +737,19 @@ def _sample_with_torch( sampling_tensors: SamplingTensors, include_gpu_probs_tensor: bool, modify_greedy_probs: bool, -) -> Tuple[SampleResultType, Optional[torch.Tensor]]: +) -> SampleReturnType: + '''Torch-oriented _sample() implementation. + + Single-step scheduling: + * Perform GPU-side sampling computation + * Immediately Pythonize sampling result + + Multi-step scheduling: + * Perform GPU-side sampling computation + * Defer Pythonization & preserve GPU-side + tensors required for Pythonization + ''' + categorized_seq_group_ids: Dict[SamplingType, List[int]] = {t: [] for t in SamplingType} @@ -560,10 +759,11 @@ def _sample_with_torch( sampling_type = sampling_params.sampling_type categorized_seq_group_ids[sampling_type].append(i) - sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} - sample_metadata: Dict[SamplingType, - Tuple[List[int], List[SequenceGroupToSample]]] = {} - multinomial_samples: Dict[SamplingType, torch.Tensor] = {} + sample_results_dict: SampleResultsDictType = {} + sample_metadata: SampleMetadataType = {} + multinomial_samples: MultinomialSamplesType = {} + greedy_samples: Optional[torch.Tensor] = None + beam_search_logprobs: Optional[torch.Tensor] = None # Create output tensor for sampled token ids. if include_gpu_probs_tensor: @@ -638,32 +838,29 @@ def _sample_with_torch( else: raise ValueError(f"Unsupported sampling type: {sampling_type}") - # GPU<->CPU sync happens in the loop below. - # This also converts the sample output to Python objects. + # Encapsulate arguments for computing Pythonized sampler + # results, whether deferred or otherwise. 
+ maybe_deferred_args = SampleResultArgsType( + sampling_metadata=sampling_metadata, + sample_metadata=sample_metadata, + multinomial_samples=multinomial_samples, + greedy_samples=greedy_samples, + beam_search_logprobs=beam_search_logprobs, + sample_results_dict=sample_results_dict) + if not sampling_metadata.skip_sampler_cpu_output: - for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) - elif sampling_type in (SamplingType.RANDOM, - SamplingType.RANDOM_SEED): - sample_results = _random_sample( - seq_groups, multinomial_samples[sampling_type]) - elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, - beam_search_logprobs) - sample_results_dict.update(zip(seq_group_id, sample_results)) - - sample_results = [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] + # GPU<->CPU sync happens here. + # This also converts the sampler output to a Python object. + # Return Pythonized sampler result & sampled token ids + return get_pythonized_sample_results( + maybe_deferred_args), sampled_token_ids_tensor else: - sample_results = [] - - return sample_results, sampled_token_ids_tensor + # Defer sampler result Pythonization; return deferred + # Pythonization args & sampled token ids + return ( + maybe_deferred_args, + sampled_token_ids_tensor, + ) def _sample_with_triton_kernel( @@ -755,7 +952,7 @@ def _sample( sampling_tensors: SamplingTensors, include_gpu_probs_tensor: bool, modify_greedy_probs: bool, -) -> Tuple[SampleResultType, Optional[torch.Tensor]]: +) -> SampleReturnType: """ Args: probs: (num_query_tokens_in_batch, num_vocab) @@ -803,7 +1000,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: return result.sum(1).add_(1) -def _get_logprobs( +def get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sample_results: SampleResultType, @@ -1126,7 +1323,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, def _build_sampler_output( - sample_results: SampleResultType, + maybe_deferred_sample_results: MaybeDeferredSampleResultType, sampling_metadata: SamplingMetadata, prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], sample_logprobs: Optional[List[SampleLogprobs]], @@ -1143,14 +1340,21 @@ def _build_sampler_output( speculative decoding rejection sampling. 
""" sampler_output: List[CompletionSequenceGroupOutput] = [] - if not skip_sampler_cpu_output: + + if skip_sampler_cpu_output: + assert isinstance(maybe_deferred_sample_results, SampleResultArgsType) + deferred_sample_results_args = maybe_deferred_sample_results + else: assert prompt_logprobs is not None assert sample_logprobs is not None + assert not isinstance(maybe_deferred_sample_results, + SampleResultArgsType) + deferred_sample_results_args = None for (seq_group, sample_result, group_prompt_logprobs, group_sample_logprobs) in zip(sampling_metadata.seq_groups, - sample_results, prompt_logprobs, - sample_logprobs): + maybe_deferred_sample_results, + prompt_logprobs, sample_logprobs): seq_ids = seq_group.seq_ids next_token_ids, parent_ids = sample_result seq_outputs: List[SequenceOutput] = [] @@ -1176,7 +1380,7 @@ def _build_sampler_output( sampled_token_probs=sampled_token_probs, sampled_token_ids=sampled_token_ids, logprobs=logprobs_tensor, - ) + deferred_sample_results_args=deferred_sample_results_args) def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index 24fa13d7e5fe5..7396ac833e782 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -10,9 +10,8 @@ from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput TORCH_DTYPE_TO_NEURON_AMP = { "auto": "f32", diff --git a/vllm/model_executor/model_loader/openvino.py b/vllm/model_executor/model_loader/openvino.py index 5c522a61732a4..3c1f6fa769894 100644 --- a/vllm/model_executor/model_loader/openvino.py +++ b/vllm/model_executor/model_loader/openvino.py @@ -15,9 +15,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import (LogitsProcessor, _prune_hidden_states) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput logger = init_logger(__name__) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 28f69cfbc46bd..efa044d0b5e92 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -23,13 +23,13 @@ from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig, DeepSpeedFPParameter) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.arctic import ArcticConfig logger = init_logger(__name__) diff --git a/vllm/model_executor/models/baichuan.py 
b/vllm/model_executor/models/baichuan.py index 73711d8eb5185..bdd76b11384c2 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -38,12 +38,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index f78400b0df7b3..9b4c4be7fcb09 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors logger = logging.get_logger(__name__) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 8be786fd3f6f5..0ed46f39cacd9 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -13,13 +13,13 @@ from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.opt import OPTModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SamplerOutput, SequenceData) + SequenceData) from .blip import (BlipVisionModel, dummy_image_for_blip, get_max_blip_image_tokens) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 07ee0e3c531d0..831b3f20457a9 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import 
SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index b25f5d521a9bf..47e020e8ecb73 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -22,7 +22,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -33,7 +33,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SamplerOutput, SequenceData) + SequenceData) from vllm.utils import print_warning_once from .interfaces import SupportsMultiModal diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 4949d0232fabb..35f1ed5ef5d33 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -20,12 +20,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index f63cf246e510a..be7f19d15b623 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -38,14 +38,14 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, row_parallel_weight_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors @torch.compile diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index dca959798e8b2..6160197dc19de 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -17,13 +17,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from 
vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.dbrx import DbrxConfig diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 7a27e1388e987..61cc917ab6207 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -43,12 +43,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class DeepseekMLP(nn.Module): diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index c7f3af0ccb266..8cbd9435ec7ca 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -43,12 +43,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 99c825ff63572..ad1ab0231d861 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -5,12 +5,13 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models import ModelRegistry from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.eagle import EAGLEConfig diff --git a/vllm/model_executor/models/falcon.py 
b/vllm/model_executor/models/falcon.py index 7b97b3d255dfa..b474d35baf89d 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -39,12 +39,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import RWConfig FalconConfig = Union[HF_FalconConfig, RWConfig] diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 6cdf331fed8b7..beeae14229575 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -31,6 +31,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -39,7 +40,7 @@ from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SamplerOutput, SequenceData) + SequenceData) from .interfaces import SupportsMultiModal from .utils import merge_multimodal_embeddings diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index e1041edf81b0a..36fd389831282 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 5e0f8b70d4b80..90449ec51ef0b 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -33,12 +33,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from 
vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index bfc231282952a..fb5a297661ddc 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index b93fb8d69b2d7..fe5ec10827608 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 4d52b448049b4..664d775c8ba40 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -33,12 +33,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class GPTJAttention(nn.Module): diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 2adecf7fa9ef8..5f6f1e3880547 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -33,12 +33,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from 
vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class GPTNeoXAttention(nn.Module): diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 499cdb43fc8b2..9b7cada187ce1 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -17,12 +17,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class InternLM2MLP(nn.Module): diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index ca4d773190e0f..5ca8d0b6a2922 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -18,13 +18,14 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.intern_vit import InternVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_num_patches) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index a550f7e6c97a1..b0fbb7e9829e0 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -35,12 +35,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import JAISConfig from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index ac3b59f95f7e0..73be7ffed0f89 100644 --- a/vllm/model_executor/models/jamba.py +++ 
b/vllm/model_executor/models/jamba.py @@ -27,14 +27,14 @@ selective_scan_fn, selective_state_update) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import HasInnerState from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 0c67a9b8e198b..e55c01316087c 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -42,13 +42,13 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_hip from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 490c93294d50f..43c485bdf3668 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -11,10 +11,11 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .clip import (CLIPVisionModel, dummy_image_for_clip, dummy_seq_data_for_clip, get_max_clip_image_tokens, diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 048ca16974e3c..5a179e9603710 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -15,10 +15,11 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from 
vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 55d42952cd0cc..619a5cd00d6b6 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -4,11 +4,11 @@ import torch.nn as nn from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.medusa import MedusaConfig diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index ff42bdefe0269..a135118bc748e 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -44,13 +44,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 6a3d5422e0ce4..dd10729b9ffb5 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -44,7 +44,7 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -57,7 +57,7 @@ from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SamplerOutput, SequenceData) + SequenceData) from .idefics2_vision_model import Idefics2VisionTransformer diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 413783ba4b259..e744e36ac08bf 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -39,13 +39,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from 
vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 8bdd52b343175..68471f6ac77d1 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -42,12 +42,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class MixtralMLP(nn.Module): diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 9b96ecb78a3c9..42ccd01298169 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -6,11 +6,10 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import MLPSpeculatorConfig SQRT2 = 2**0.5 diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 1a8e514a7ae83..0fcbf06e1a060 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -17,12 +17,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.mpt import MPTConfig diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 7d92a1ffe55df..e9ff12de2094e 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -37,13 +37,13 @@ from vllm.model_executor.layers.quantization.base_config 
import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronConfig from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 8de124cd034dc..97749725dd132 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -38,12 +38,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class OlmoAttention(nn.Module): diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index c0d2d537e731f..88d2bcb9f0c9d 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class OPTLearnedPositionalEmbedding(nn.Embedding): diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index fab35f0b882a7..b01ce87adfa46 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -21,12 +21,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class OrionMLP(nn.Module): diff --git a/vllm/model_executor/models/paligemma.py 
b/vllm/model_executor/models/paligemma.py index 46ee4c3208b7a..104b89e06fa5f 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -11,13 +11,13 @@ from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.gemma import GemmaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal from .siglip import (SiglipVisionModel, dummy_image_for_siglip, diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 3300939c7b102..f8fc1cd8ef1f0 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -37,12 +37,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class PersimmonMLP(nn.Module): diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index f31b5162aac96..15c21cfa2d8a8 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -52,12 +52,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index df01bfa3d8e6e..afc6fe9844ad6 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -16,12 +16,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from 
vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors def load_column_parallel_weight(param: torch.nn.Parameter, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index bec1d35388506..2fad3ec3e5651 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -31,7 +31,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.clip import CLIPVisionModel @@ -39,7 +39,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import dummy_image_for_clip, dummy_seq_data_for_clip diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index b7d017d5f3ea6..8298e3bac4465 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -22,12 +22,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index b95987c16ebca..a64e08c422bc3 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -40,13 +40,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 
6f838947fbf27..56129515ca8d1 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -45,12 +45,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index decbf89d27c7c..6236426dcd4e1 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -36,12 +36,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class StablelmMLP(nn.Module): diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index d1b1d210b727c..d3a3a83c8437f 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -35,12 +35,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class Starcoder2Attention(nn.Module): diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 03d6223225511..827a9493a70d2 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -27,6 +27,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.utils import (filter_weights, @@ -37,7 +38,7 @@ from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import (cached_get_tokenizer, 
repeat_and_pad_placeholder_tokens) -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SamplerOutput, SequenceData +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.transformers_utils.configs.ultravox import UltravoxConfig _AUDIO_PLACEHOLDER_TOKEN = 128002 diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index c0bafa9367e43..24cc3728f85e4 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -38,12 +38,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/sequence.py b/vllm/sequence.py index e7cde87f605a7..87b3d21fa7ae3 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1060,76 +1060,6 @@ def __repr__(self) -> str: return f"IntermediateTensors(tensors={self.tensors})" -class SamplerOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """For each sequence group, we generate a list of SequenceOutput object, - each of which contains one possible candidate for the next token. - - This data structure implements methods, so it can be used like a list, but - also has optional fields for device tensors. - """ - - outputs: List[CompletionSequenceGroupOutput] - - # On-device tensor containing probabilities of each token. - sampled_token_probs: Optional[torch.Tensor] = None - - # On-device tensor containing the logprobs of each token. - logprobs: Optional["torch.Tensor"] = None - - # On-device tensor containing the sampled token ids. - sampled_token_ids: Optional[torch.Tensor] = None - # CPU tensor containing the sampled token ids. Used during multi-step to - # return the sampled token ids from last rank to AsyncLLMEngine to be - # 'broadcasted' to all other PP ranks for next step. - sampled_token_ids_cpu: Optional[torch.Tensor] = None - - # Spec decode metrics populated by workers. - spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None - - # Optional last hidden states from the model. - hidden_states: Optional[torch.Tensor] = None - - # Optional prefill hidden states from the model - # (used for models like EAGLE). - prefill_hidden_states: Optional[torch.Tensor] = None - - # Time taken in the forward pass for this across all workers - model_forward_time: Optional[float] = None - - # Time taken in the model execute function. This will include model forward, - # block/sync across workers, cpu-gpu sync time and sampling time. - model_execute_time: Optional[float] = None - - def __getitem__(self, idx: int): - return self.outputs[idx] - - def __setitem__(self, idx: int, value): - self.outputs[idx] = value - - def __len__(self): - return len(self.outputs) - - def __eq__(self, other: object): - return isinstance(other, - self.__class__) and self.outputs == other.outputs - - def __repr__(self) -> str: - """Show the shape of a tensor instead of its values to reduce noise. 
- """ - sampled_token_probs_repr = ("None" if self.sampled_token_probs is None - else self.sampled_token_probs.shape) - sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else - self.sampled_token_ids.shape) - return ( - f"SamplerOutput(outputs={self.outputs}, " - f"sampled_token_probs={sampled_token_probs_repr}, " - f"sampled_token_ids={sampled_token_ids_repr}, " - f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") - - class PoolerOutput( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 8a691d65aaa06..b2204e8b27afd 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -5,8 +5,9 @@ import torch from vllm import SamplingParams +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest, - SamplerOutput, SequenceData, SequenceGroupMetadata, + SequenceData, SequenceGroupMetadata, get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index aedf0a83da07d..6e35e40294381 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -3,6 +3,7 @@ import torch from vllm import _custom_ops as ops +from vllm.model_executor.layers.sampler import SamplerOutput try: from vllm.attention.backends.flash_attn import FlashAttentionMetadata @@ -16,8 +17,7 @@ PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.multimodal import MultiModalInputs -from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, - SamplerOutput) +from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, ModelRunner) diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py index d1809e49c2a8f..0d233f393cb8c 100644 --- a/vllm/spec_decode/medusa_worker.py +++ b/vllm/spec_decode/medusa_worker.py @@ -4,8 +4,8 @@ import torch from vllm.model_executor import SamplingMetadata -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, - SequenceGroupMetadata) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer diff --git a/vllm/spec_decode/mlp_speculator_worker.py b/vllm/spec_decode/mlp_speculator_worker.py index 76e444387816f..fc41bb82ea340 100644 --- a/vllm/spec_decode/mlp_speculator_worker.py +++ b/vllm/spec_decode/mlp_speculator_worker.py @@ -3,8 +3,8 @@ import torch from vllm.model_executor import SamplingMetadata -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, - SequenceGroupMetadata) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 2dfbacfb7b759..4b53fbe056c47 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ 
b/vllm/spec_decode/multi_step_worker.py @@ -4,8 +4,9 @@ import torch -from vllm.sequence import (ExecuteModelRequest, HiddenStates, SamplerOutput, - SequenceData, SequenceGroupMetadata) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData, + SequenceGroupMetadata) from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 806480b5c892f..36e5e1774aa0d 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -3,7 +3,8 @@ import torch -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer diff --git a/vllm/spec_decode/proposer_worker_base.py b/vllm/spec_decode/proposer_worker_base.py index efb8ee25ba2f9..28a537593f26d 100644 --- a/vllm/spec_decode/proposer_worker_base.py +++ b/vllm/spec_decode/proposer_worker_base.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod from typing import List, Optional, Set, Tuple -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.interfaces import SpeculativeProposer from vllm.worker.worker_base import LoraNotSupportedWorkerBase diff --git a/vllm/spec_decode/smaller_tp_proposer_worker.py b/vllm/spec_decode/smaller_tp_proposer_worker.py index 215ede52fb812..8896b7dbc6b8a 100644 --- a/vllm/spec_decode/smaller_tp_proposer_worker.py +++ b/vllm/spec_decode/smaller_tp_proposer_worker.py @@ -6,7 +6,8 @@ init_model_parallel_group, patch_tensor_parallel_group) from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 9b1f21fcb4920..78beb2ce44773 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -8,12 +8,13 @@ from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) from vllm.model_executor.layers.typical_acceptance_sampler import ( TypicalAcceptanceSampler) from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, - HiddenStates, SamplerOutput, SequenceGroupMetadata, + HiddenStates, SequenceGroupMetadata, get_all_seq_ids, get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner diff --git a/vllm/spec_decode/top1_proposer.py 
b/vllm/spec_decode/top1_proposer.py index aa993e539b6d3..f6a52a516075d 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -2,8 +2,8 @@ import torch -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, - SequenceGroupMetadata) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 5d5f8767e5b6d..54e718bc49017 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -4,9 +4,9 @@ import torch +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SamplerOutput, SequenceGroupMetadata, - SequenceOutput) + SequenceGroupMetadata, SequenceOutput) SeqId = int diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index f69afa4c43149..7205b1a7beb8d 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -10,11 +10,11 @@ SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 5c700229660c0..d6189d82d51d9 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -16,9 +16,10 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput, +from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceGroupMetadata) from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.worker.model_runner import (GPUModelRunnerBase, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e022f7481ee51..8a3c99a45b149 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -29,6 +29,7 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata, SamplingMetadataCache +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models.interfaces import (supports_lora, @@ -41,8 +42,7 @@ from vllm.prompt_adapter.worker_manager import ( LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import 
IntermediateTensors, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_hip, is_pin_memory_available, supports_dynamo) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 90c39407d7266..f8fd9d801d289 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -5,9 +5,9 @@ import torch +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata if TYPE_CHECKING: from vllm.attention import AttentionMetadata diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 0abca9d9f4558..be0c75bc00dbd 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -1,7 +1,8 @@ import dataclasses import functools from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Union) try: from vllm.attention.backends.flash_attn import FlashAttentionMetadata @@ -15,9 +16,12 @@ from vllm import _custom_ops as ops from vllm.distributed import get_pp_group from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs, + SamplerOutput, + SamplingMetadata, get_logprobs, + get_pythonized_sample_results) from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - Logprob, SamplerOutput, SequenceGroupMetadata, - SequenceOutput) + Logprob, SequenceGroupMetadata, SequenceOutput) from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ( @@ -53,6 +57,8 @@ class ModelOutput: sampler_output_ready_event: torch.cuda.Event sampled_token_ids: Optional[torch.Tensor] = None pythonized: bool = False + # On-device tensor containing the logprobs of each token. + logprobs: Optional["torch.Tensor"] = None def pythonize(self, input_metadata: "StatefulModelInput", copy_stream: torch.cuda.Stream, @@ -78,7 +84,9 @@ def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", blocking: bool) -> bool: """ If blocking is set, will block until the forward pass for the output is - ready and pythonize the output. + ready and pythonize the output. Upon completing Pythonization, erases + self.logprobs. (A non-blocking call made before the sampler output is + ready does not erase self.logprobs.) """ assert self.sampled_token_ids is not None if not blocking and not self.sampler_output_ready_event.query(): return False @@ -89,7 +97,15 @@ def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", with torch.cuda.stream(copy_stream): _pythonize_sampler_output(input_metadata, self.sampler_output, pinned_sampled_token_buffer, - self.sampled_token_ids) + self.sampled_token_ids, self.logprobs) + + # Erase the GPU-side logprobs tensor. + # Although _pythonize_sampler_output() runs in its own CUDA stream, + # it cannot return until Pythonization is complete; therefore, by the + # time the CPU reaches this point, `self.logprobs` is no longer needed.
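The comment above describes a hold-and-release discipline for the deferred GPU-side logprobs tensor: keep the reference alive only until deferred Pythonization has consumed it, then drop it (the `self.logprobs = None` that follows). A minimal, self-contained sketch of the same pattern is shown below; the class name `DeferredGpuResult` and the `to_python` callback are illustrative stand-ins, not part of this patch.

import dataclasses
from typing import Callable, List, Optional

import torch


@dataclasses.dataclass
class DeferredGpuResult:
    """Holds a GPU tensor until its asynchronous producer has finished."""
    ready_event: torch.cuda.Event
    gpu_tensor: Optional[torch.Tensor] = None
    pythonized: bool = False

    def maybe_pythonize(self, blocking: bool,
                        to_python: Callable[[torch.Tensor], List]) -> bool:
        """Pythonize the tensor if it is ready; release it afterwards."""
        if self.pythonized:
            return True
        if not blocking and not self.ready_event.query():
            # Producer has not finished yet; keep holding the tensor.
            return False
        if blocking:
            self.ready_event.synchronize()
        assert self.gpu_tensor is not None
        _ = to_python(self.gpu_tensor)  # e.g. lambda t: t.cpu().tolist()
        # Drop the device-side reference so the tensor is neither
        # serialized nor copied to the CPU on a later step.
        self.gpu_tensor = None
        self.pythonized = True
        return True

In the patch itself, ModelOutput plays this role: `logprobs` is populated from the sampler's GPU-side tensor and released immediately after `_pythonize_sampler_output()` returns.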
+ self.logprobs = None return True @@ -350,11 +366,16 @@ def execute_model( 0].sampled_token_ids.cpu() model_input.cached_outputs.append( ModelOutput(output[0], output_ready_event, - output[0].sampled_token_ids, False)) - # make sure we dont try to serialize any GPU tensors + output[0].sampled_token_ids, False, + output[0].logprobs)) + + # These GPU tensors are not required by multi-step; + # erase them to ensure they are not pythonized or + # transferred to CPU output[0].sampled_token_ids = None output[0].sampled_token_probs = None output[0].logprobs = None + # Pythonize the output if CPU is ahead and the previous step is # ready. if not frozen_model_input.use_async_and_multi_step: @@ -464,12 +485,75 @@ def vocab_size(self) -> int: return self._base_model_runner.vocab_size -def _pythonize_sampler_output(model_input: StatefulModelInput, - output: SamplerOutput, - pinned_sampled_token_buffer: torch.Tensor, - sampled_token_ids: torch.Tensor) -> None: +DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]], + Optional[List[SampleLogprobs]]] + + +def deferred_pythonize_logprobs( + output: SamplerOutput, + sampling_metadata: SamplingMetadata, + logprobs_tensor: Optional[torch.Tensor], +) -> DeferredLogprobsReturnType: + """Perform deferred logprob Pythonization. + + 1. Pythonize GPU-side sampler result tensors into CPU-side sampler result. + 2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists, + utilizing the Pythonized sampler result computed in step 1. + + These deferred computations are not required for single-step scheduling + or the `profile_run()` phase of multi-step scheduling. + + Args: + output: sampler output (under deferred Pythonization) + sampling_metadata: metadata for the sequence groups being sampled + logprobs_tensor: GPU-side tensor of logprobs computed during sampling + + Returns: + prompt_logprobs (CPU), sample_logprobs (CPU) + """ + + # - Deferred pythonization of sample result + sampler_result = get_pythonized_sample_results( + output.deferred_sample_results_args) + + # - Erase the GPU-side deferred sample_result + # computation args to ensure it is never + # pythonized or transferred to CPU + output.deferred_sample_results_args = None + + # - Deferred pythonization of logprobs + ( + prompt_logprobs, + sample_logprobs, + ) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result) + assert len(prompt_logprobs) == len(sampling_metadata.seq_groups) + assert len(sample_logprobs) == len(sampling_metadata.seq_groups) + + return prompt_logprobs, sample_logprobs + + +def _pythonize_sampler_output( + model_input: StatefulModelInput, + output: SamplerOutput, + pinned_sampled_token_buffer: torch.Tensor, + sampled_token_ids: torch.Tensor, + logprobs_tensor: Optional[torch.Tensor], +) -> None: """ This function is only called when the output tensors are ready. - See ModelOutput + See :class:`ModelOutput`. + + Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place, + adding a Pythonized output data structure + (:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`. + + Args: + model_input: stateful multi-step model input + output: sampler output + pinned_sampled_token_buffer: CPU-side pinned memory + (receives copy of + GPU-side token buffer.)
+ sampled_token_ids: GPU-side token buffer + logprobs_tensor: GPU-side tensor containing + logprobs computed during sampling """ assert model_input.frozen_model_input is not None @@ -489,8 +573,51 @@ def _pythonize_sampler_output(model_input: StatefulModelInput, sampling_metadata = frozen_model_input.sampling_metadata - for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, - samples_list): + skip_sampler_cpu_output = ( + frozen_model_input.sampling_metadata.skip_sampler_cpu_output) + + # We are guaranteed output tensors are ready, so it is safe to + # pythonize the sampler output & obtain CPU-side logprobs. + # + # However this computation may be skipped entirely + # if no pythonization was deferred. + seq_groups = sampling_metadata.seq_groups + logprobs_are_requested = any([ + sg.sampling_params.logprobs is not None + or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups + ]) + do_pythonize_logprobs = (skip_sampler_cpu_output + and logprobs_are_requested) + ( + prompt_logprobs, + sample_logprobs, + ) = (deferred_pythonize_logprobs(output, sampling_metadata, + logprobs_tensor) + if do_pythonize_logprobs else (None, None)) + + for sgdx, (seq_group, + sample_result) in enumerate(zip(seq_groups, samples_list)): + + if do_pythonize_logprobs: + assert prompt_logprobs is not None + assert sample_logprobs is not None + + ( + group_prompt_logprobs, + group_sample_logprobs, + ) = ( # Utilize deferred pythonization results + prompt_logprobs[sgdx], + sample_logprobs[sgdx], + ) + elif logprobs_are_requested: + ( + group_prompt_logprobs, + group_sample_logprobs, + ) = ( + # profile_run: use already-computed logprobs + output.outputs[sgdx].prompt_logprobs, + [sample.logprobs for sample in output.outputs[sgdx].samples]) + seq_ids = seq_group.seq_ids next_token_ids = sample_result parent_ids = [0] @@ -498,11 +625,19 @@ def _pythonize_sampler_output(model_input: StatefulModelInput, if seq_group.sampling_params.logits_processors: assert len(seq_group.sampling_params.logits_processors) == 0, ( "Logits Processors are not supported in multi-step decoding") - for parent_id, next_token_id in zip(parent_ids, next_token_ids): - # TODO(will): support logprobs - # Hard coded logprob + for tdx, (parent_id, + next_token_id) in enumerate(zip(parent_ids, next_token_ids)): seq_outputs.append( SequenceOutput(seq_ids[parent_id], next_token_id, - {next_token_id: Logprob(logprob=-1)})) - output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None)) + (group_sample_logprobs[tdx] + if logprobs_are_requested else { + next_token_id: + Logprob(logprob=float('inf'), + rank=None, + decoded_token=None) + }))) + output.outputs.append( + CompletionSequenceGroupOutput( + seq_outputs, + (group_prompt_logprobs if logprobs_are_requested else None))) assert len(output.outputs) > 0 diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index e0e421942f409..517b0ab78c460 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -5,7 +5,8 @@ import torch from vllm.distributed import broadcast_tensor_dict, get_pp_group -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.worker.model_runner_base import BroadcastableModelInput from vllm.worker.multi_step_model_runner import (MultiStepModelRunner, StatefulModelInput) diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py
index 4f3fed2dbd723..f3defffdfa520 100644
--- a/vllm/worker/neuron_model_runner.py
+++ b/vllm/worker/neuron_model_runner.py
@@ -8,11 +8,11 @@
                          SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.neuron import get_neuron_model
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalInputs)
-from vllm.sequence import (IntermediateTensors, SamplerOutput,
-                           SequenceGroupMetadata)
+from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py
index a1d09a2f9e53e..f335e4e32efd4 100644
--- a/vllm/worker/openvino_model_runner.py
+++ b/vllm/worker/openvino_model_runner.py
@@ -11,10 +11,11 @@
                          SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.openvino import get_model
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalInputs)
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import SequenceGroupMetadata

 logger = init_logger(__name__)
diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py
index c47f9acc4423d..36339e175d7bb 100644
--- a/vllm/worker/openvino_worker.py
+++ b/vllm/worker/openvino_worker.py
@@ -14,7 +14,8 @@
                                     init_distributed_environment)
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
-from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.sequence import ExecuteModelRequest
 from vllm.worker.openvino_model_runner import OpenVINOModelRunner
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index a7ceb84effe91..ebb4b89cb4727 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -14,11 +14,11 @@
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 from vllm.logger import init_logger
+from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
-                           Logprob, SamplerOutput, SequenceGroupMetadata,
-                           SequenceOutput)
+                           Logprob, SequenceGroupMetadata, SequenceOutput)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
     _add_attn_metadata_broadcastable_dict,
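The remaining worker files below apply the same mechanical change as the hunks above: `SamplerOutput` is now imported from the sampler layer module rather than from `vllm.sequence`. In each file the pattern is the one sketched here (module names taken directly from the hunks above):

```python
# Before
# from vllm.sequence import ExecuteModelRequest, SamplerOutput

# After
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
```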
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 7ed609c3b447c..0ff559a9af53e 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -17,12 +17,12 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
+from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.platforms import current_platform
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
-                           SamplerOutput, SequenceGroupMetadata,
-                           SequenceGroupMetadataDelta)
+                           SequenceGroupMetadata, SequenceGroupMetadataDelta)
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.embedding_model_runner import EmbeddingModelRunner
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 012043673b094..6ba4f272315ce 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -11,9 +11,9 @@
 from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.platforms import current_platform
-from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
-                           SamplerOutput)
+from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import (enable_trace_function_call_for_thread,
                         update_environment_variables)
 from vllm.worker.model_runner_base import (BroadcastableModelInput,
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index 3894658a095f3..f9037625d4af9 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -15,12 +15,12 @@
 from vllm.distributed import get_pp_group
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
+from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalInputs, MultiModalRegistry)
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (IntermediateTensors, SamplerOutput,
-                           SequenceGroupMetadata)
+from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
 from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
 from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
 from vllm.worker.model_runner_base import (