From 5cf9254a9cfd1611e5c2fcd1b5011b4bdb56947f Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 30 Jul 2024 10:40:08 -0700 Subject: [PATCH] [BugFix] Fix use of per-request seed with pipeline parallel (#6698) --- tests/samplers/test_rejection_sampler.py | 23 ++--- tests/samplers/test_sampler.py | 5 +- tests/spec_decode/e2e/test_mlp_correctness.py | 54 ++++++++++- tests/spec_decode/e2e/test_seed.py | 2 +- tests/spec_decode/test_batch_expansion.py | 1 + tests/utils.py | 31 ++++++ vllm/core/scheduler.py | 1 - .../layers/rejection_sampler.py | 95 ++++++++----------- .../layers/spec_decode_base_sampler.py | 4 +- vllm/model_executor/sampling_metadata.py | 20 ++-- vllm/sequence.py | 12 --- vllm/spec_decode/batch_expansion.py | 37 +++++--- vllm/spec_decode/medusa_worker.py | 4 +- vllm/spec_decode/mlp_speculator_worker.py | 4 +- vllm/spec_decode/ngram_worker.py | 3 +- vllm/spec_decode/spec_decode_worker.py | 25 +++-- vllm/worker/cpu_model_runner.py | 3 +- vllm/worker/model_runner.py | 14 ++- vllm/worker/model_runner_base.py | 15 +++ vllm/worker/neuron_model_runner.py | 3 +- vllm/worker/xpu_model_runner.py | 3 +- 21 files changed, 222 insertions(+), 137 deletions(-) diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index b6330a5e5f7c..8f6c292620c2 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -150,10 +150,9 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, high=vocab_size, size=(batch_size, k), dtype=torch.int64) - generators = [None] * batch_size rejection_sampler(target_probs, bonus_token_ids, draft_probs, - draft_token_ids, generators) + draft_token_ids) @pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0]) @@ -185,14 +184,13 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int, results = [] for _ in range(n_rep): - generators = [ - torch.Generator( - device=device).manual_seed(i) if seeded_mask[i] else None - for i in range(batch_size) - ] + seeded_seqs = { + i: torch.Generator(device=device).manual_seed(i) + for i in range(batch_size) if seeded_mask[i] + } results.append( rejection_sampler(target_probs, bonus_token_ids, draft_probs, - draft_token_ids, generators)) + draft_token_ids, seeded_seqs)) for i in range(batch_size): if seeded_mask[i]: @@ -242,11 +240,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, raise AssertionError() oob_token_ids[0][0] = rogue_token_id - generators = [None] * batch_size with pytest.raises(AssertionError): rejection_sampler(target_probs, bonus_token_ids, draft_probs, - draft_token_ids, generators) + draft_token_ids) @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False]) @@ -417,15 +414,11 @@ def _estimate_rejection_sampling_pdf( dtype=torch.int64, device="cuda").repeat(num_samples, 1) - # unseeded - generators = [None] - # Get output tokens via rejection sampling. 
output_token_ids = self.rejection_sampler(target_probs.to("cuda"), bonus_token_ids.to("cuda"), draft_probs.to("cuda"), - draft_token_ids.to("cuda"), - generators) + draft_token_ids.to("cuda")) # Remove bonus tokens output_token_ids = output_token_ids[:, :-1].flatten() diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 9572588ce6e5..bf062e4a5c09 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -510,13 +510,16 @@ def test_sampler_mixed(seed: int, device: str): )) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + generators: Dict[str, torch.Generator] = {} + def test_sampling(): sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens=seq_lens, device=device, - pin_memory=is_pin_memory_available()) + pin_memory=is_pin_memory_available(), + generators=generators) sampler_output = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index e310941afacf..20f50888dab5 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -21,7 +21,8 @@ import pytest -from .conftest import run_greedy_equality_correctness_test +from .conftest import (run_equality_correctness_test, + run_greedy_equality_correctness_test) # main model MAIN_MODEL = "JackFram/llama-160m" @@ -77,6 +78,57 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, force_output_len=True) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + + # Speculative model + "speculative_model": SPEC_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}]) +@pytest.mark.parametrize("output_len", [64]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("temperature", [0.1, 1.0]) +@pytest.mark.parametrize("seed", [None]) +def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int, + temperature: float): + """Verify seeded runs produce the same output.""" + run_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + temperature=temperature, + seeded=True, + force_output_len=True) + + # Ensure this same test does fail if we _don't_ include per-request seeds + with pytest.raises(AssertionError): + run_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + temperature=temperature, + seeded=False, + force_output_len=True) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_seed.py b/tests/spec_decode/e2e/test_seed.py index 394a53f03ed4..f84c346c1d31 100644 --- a/tests/spec_decode/e2e/test_seed.py +++ b/tests/spec_decode/e2e/test_seed.py @@ -29,7 +29,7 @@ "output_len", [ # Use smaller output len for fast test. 
- 10, + 20, ]) @pytest.mark.parametrize("seed", [None]) def test_seeded_consistency(baseline_llm_generator, test_llm_generator, diff --git a/tests/spec_decode/test_batch_expansion.py b/tests/spec_decode/test_batch_expansion.py index c350a2c55396..0d6aaa449d85 100644 --- a/tests/spec_decode/test_batch_expansion.py +++ b/tests/spec_decode/test_batch_expansion.py @@ -86,6 +86,7 @@ def test_create_single_target_seq_group_metadata(k: int): input_seq_id, target_seq_id, token_ids, + input_seq_group_metadata.sampling_params, ) assert output.request_id == input_seq_group_metadata.request_id diff --git a/tests/utils.py b/tests/utils.py index bf36d96108d8..1086591464d4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -178,6 +178,37 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]): "usage": completion.usage, }) + # test seeded random sampling + completion = client.completions.create(model=model, + prompt=prompt, + max_tokens=5, + seed=33, + temperature=1.0) + + results.append({ + "test": "seeded_sampling", + "text": completion.choices[0].text, + "finish_reason": completion.choices[0].finish_reason, + "usage": completion.usage, + }) + + # test seeded random sampling with multiple prompts + completion = client.completions.create(model=model, + prompt=[prompt, prompt], + max_tokens=5, + seed=33, + temperature=1.0) + + results.append({ + "test": + "seeded_sampling", + "text": [choice.text for choice in completion.choices], + "finish_reason": + [choice.finish_reason for choice in completion.choices], + "usage": + completion.usage, + }) + # test simple list batch = client.completions.create( model=model, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 6e59c5e0f74f..5b7b569c3e08 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1029,7 +1029,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: token_chunk_size=token_chunk_size, lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums, - state=seq_group.state, # `multi_modal_data` will only be present for the 1st comm # between engine and worker. # the subsequent comms can still use delta, but diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index b4994083c797..533b43634441 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch import torch.jit @@ -36,7 +36,7 @@ def forward( bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, - generators: List[Optional[torch.Generator]], + seeded_seqs: Optional[Dict[int, torch.Generator]] = None, ) -> torch.Tensor: """Sample token ids using rejection sampling. This accepts or rejects tokens proposed by the draft model using the probability of each token @@ -66,6 +66,9 @@ def forward( probabilities. shape = [batch_size, num_speculative_tokens] + seeded_seqs: Dict of batch row index to torch generator, for + sequences using seeded generation. 
+ Returns: output_token_ids: The token ids sampled via rejection sampling, or -1 if unable to sample a token because the previous token @@ -83,7 +86,7 @@ def forward( target_probs, draft_probs, draft_token_ids, - generators, + seeded_seqs, )) output_token_ids = self._create_output( @@ -100,7 +103,7 @@ def _batch_modified_rejection_sampling( target_probs: torch.Tensor, # [batch_size, k, vocab_size] draft_probs: torch.Tensor, # [batch_size, k, vocab_size] draft_token_ids: torch.Tensor, # [batch_size, k] - generators: List[Optional[torch.Generator]], + seeded_seqs: Optional[Dict[int, torch.Generator]], ) -> Tuple[torch.Tensor, torch.Tensor]: """Perform modified rejection sampling on each sequence. @@ -117,23 +120,17 @@ def _batch_modified_rejection_sampling( # shape [batch_size, k] accepted = self._get_accepted(target_probs, draft_probs, - draft_token_ids, generators) + draft_token_ids, seeded_seqs) recovered_probs = self._get_recovered_probs( target_probs, draft_probs).reshape(batch_size * k, vocab_size) - seed_indices, non_seed_indices = self._split_batch_by_seeded( - generators, k=k) - # NOTE: the recovered_probs are overwritten by this method. recovered_token_ids = _multinomial( recovered_probs, num_samples=1, k=k, - generators=generators, - seed_indices=seed_indices, - # this arg is unused when None but torch.jit requires a list - non_seed_indices=non_seed_indices or [], + seeded_seqs=seeded_seqs or {}, ).reshape(batch_size, k) return accepted, recovered_token_ids @@ -143,7 +140,7 @@ def _get_accepted( target_probs: torch.Tensor, # [batch_size, k, vocab_size] draft_probs: torch.Tensor, # [batch_size, k, vocab_size] draft_token_ids: torch.Tensor, # [batch_size, k] - generators: List[Optional[torch.Generator]], + seeded_seqs: Optional[Dict[int, torch.Generator]], ) -> torch.Tensor: r"""Create bool matrix over the proposed draft tokens. 
If True, then a token can be accepted, else it should be @@ -178,24 +175,26 @@ def _get_accepted( selected_target_probs = target_probs[batch_indices, probs_indicies, draft_token_ids] - seed_indices, non_seed_indices = self._split_batch_by_seeded( - generators) - - if len(seed_indices) == 0: + if not seeded_seqs: uniform_rand = torch.rand_like(selected_target_probs) else: uniform_rand = torch.empty_like(selected_target_probs) - for idx in seed_indices: - uniform_rand[idx, :] = torch.rand(1, - k, - dtype=self.probs_dtype, - device=target_probs.device, - generator=generators[idx]) - - if non_seed_indices: - uniform_rand[non_seed_indices, :] = torch.rand( - len(non_seed_indices), + non_seeded_indices = [] + for idx in range(batch_size): + generator = seeded_seqs.get(idx) + if generator is None: + non_seeded_indices.append(idx) + else: + uniform_rand[idx, :] = torch.rand( + 1, + k, + dtype=self.probs_dtype, + device=target_probs.device, + generator=generator) + if non_seeded_indices: + uniform_rand[non_seeded_indices, :] = torch.rand( + len(non_seeded_indices), k, dtype=self.probs_dtype, device=target_probs.device) @@ -272,27 +271,6 @@ def _smallest_positive_value(self) -> float: """ return torch.finfo(self.probs_dtype).tiny - # partition batch into indices for which a generator is provided - # and indicies for which no generator is provided - @staticmethod - def _split_batch_by_seeded( - generators: List[Optional[torch.Generator]], - k: int = 1, - ) -> Tuple[List[int], Optional[List[int]]]: - - if all(generator is None for generator in generators): - seed_indices: List[int] = [] - non_seed_indices: Optional[List[int]] = None - else: - seed_indices, non_seed_indices = [], [] - for i, generator in enumerate(generators): - if generator is None: - non_seed_indices.extend(range(k * i, k * (i + 1))) - else: - seed_indices.extend(range(k * i, k * (i + 1))) - - return seed_indices, non_seed_indices - # torch.multinomial forces a GPU<->CPU sync. # Therefore, we use an optimized implementation instead that skips the sync. 
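[Editor's note, not part of the patch] For reference, a minimal sketch of how the new seeded_seqs mapping drives the per-row uniform draws in _get_accepted: rows with an entry in the dict draw from their own torch.Generator, and all remaining rows share one batched rand call. The helper name, shapes, and CPU device below are assumptions for illustration only.

import torch
from typing import Dict, Optional

def draw_uniform_per_row(
    batch_size: int,
    k: int,
    seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
    device: str = "cpu",
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Illustrative stand-in for the uniform sampling step in _get_accepted."""
    if not seeded_seqs:
        # No seeded rows: a single batched draw is enough.
        return torch.rand(batch_size, k, dtype=dtype, device=device)

    uniform_rand = torch.empty(batch_size, k, dtype=dtype, device=device)
    non_seeded = []
    for idx in range(batch_size):
        generator = seeded_seqs.get(idx)
        if generator is None:
            non_seeded.append(idx)
        else:
            # Seeded rows draw from their own generator so results are
            # reproducible across runs and across pipeline stages.
            uniform_rand[idx, :] = torch.rand(
                1, k, dtype=dtype, device=device, generator=generator)
    if non_seeded:
        uniform_rand[non_seeded, :] = torch.rand(
            len(non_seeded), k, dtype=dtype, device=device)
    return uniform_rand

# Example: rows 0 and 2 are seeded, row 1 is not.
seeded = {i: torch.Generator().manual_seed(i) for i in (0, 2)}
print(draw_uniform_per_row(batch_size=3, k=4, seeded_seqs=seeded))
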
@@ -304,9 +282,7 @@ def _multinomial( probs: torch.Tensor, num_samples: int, k: int, - generators: List[Optional[torch.Generator]], - seed_indices: List[int], - non_seed_indices: List[int], + seeded_seqs: Dict[int, torch.Generator], ) -> torch.Tensor: if num_samples > 1: @@ -315,13 +291,20 @@ def _multinomial( probs = probs[:, None, :].expand(probs.shape[0], num_samples, probs.shape[1]).contiguous().view( -1, probs.shape[1]) - q = torch.empty_like(probs) - if len(seed_indices) == 0: + if not seeded_seqs: q.exponential_(1.0) else: - q[non_seed_indices].exponential_(1.0) - for idx in seed_indices: - q[idx].exponential_(1.0, generator=generators[idx // k]) + non_seeded_indices: List[int] = [] + start = 0 + for idx in range(len(q) // k): + end = start + k + generator = seeded_seqs.get(idx) + if generator is None: + non_seeded_indices.extend(list(range(start, end))) + else: + q[start:end].exponential_(1.0, generator=generator) + start = end + q[non_seeded_indices].exponential_(1.0) return probs.div_(q).argmax(dim=1).view(-1, num_samples) diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 08191da49d52..3091e639727b 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import List, Optional +from typing import Dict, Optional import torch import torch.jit @@ -237,6 +237,6 @@ def forward( bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, - generators: List[Optional[torch.Generator]], + seeded_seqs: Optional[Dict[int, torch.Generator]] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 1caf9aa01d8c..59cfec9ec893 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -118,6 +118,7 @@ def prepare( query_lens: Optional[List[int]], device: str, pin_memory: bool, + generators: Optional[Dict[str, torch.Generator]] = None, ) -> "SamplingMetadata": ( seq_groups, @@ -125,7 +126,7 @@ def prepare( categorized_sample_indices, num_prompts, ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, - device) + device, generators) selected_token_indices = async_tensor_h2d(selected_token_indices, dtype=torch.long, target_device=device, @@ -160,6 +161,7 @@ def _prepare_seq_groups( seq_lens: List[int], query_lens: Optional[List[int]], device: str, + generators: Optional[Dict[str, torch.Generator]] = None, ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ SamplingType, List[Tuple[int, int]]], int]: """Prepare sequence groups and indices for sampling. @@ -170,8 +172,10 @@ def _prepare_seq_groups( Index of prompt len should match with seq_group_metadata_list. query_lens: A list of query lengths. Prompt lens include the length of entire prompt tokens, and it could be shorter. - device: A device to use for random number generator, + device: A device to use for random number generators, `SequenceGroupToSample.generator`. + generators: A store of per-request random number generators used + for seeded requests. Returns: seq_groups: A list of sequence group to sample. 
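[Editor's note, not part of the patch] The new generators argument to SamplingMetadata.prepare implements a store-and-reuse pattern, sketched below with a toy request class standing in for SequenceGroupMetadata: on the prompt step a generator is created from sampling_params.seed and stored in a runner-owned dict keyed by request_id; on later decode steps it is looked up instead of being re-created, so its state advances across steps even though the sequence-group state is no longer shipped from the scheduler. The names ToyRequest and get_or_create_generator are hypothetical.

import torch
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class ToyRequest:            # stand-in for SequenceGroupMetadata
    request_id: str
    is_prompt: bool
    seed: Optional[int]      # stand-in for sampling_params.seed

def get_or_create_generator(
    req: ToyRequest,
    generators: Dict[str, torch.Generator],
    device: str = "cpu",
) -> Optional[torch.Generator]:
    """Create the generator on the prompt step, reuse it on decode steps."""
    if req.seed is None:
        return None
    if req.is_prompt:
        gen = torch.Generator(device=device).manual_seed(req.seed)
        generators[req.request_id] = gen
        return gen
    return generators.get(req.request_id)

# Runner-owned store that survives across scheduler steps.
store: Dict[str, torch.Generator] = {}

g1 = get_or_create_generator(ToyRequest("req-0", is_prompt=True, seed=33), store)
g2 = get_or_create_generator(ToyRequest("req-0", is_prompt=False, seed=33), store)
assert g1 is g2  # same generator object, so its state carries over
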
@@ -217,8 +221,10 @@ def _prepare_seq_groups( if seq_group_metadata.is_prompt: if sampling_params.seed is not None: - seq_group_metadata.state.generator = torch.Generator( - device=device).manual_seed(sampling_params.seed) + generator = torch.Generator(device=device).manual_seed( + sampling_params.seed) + if generators is not None: + generators[seq_group_metadata.request_id] = generator num_prompts += 1 num_prefill_sample = len(seq_ids) @@ -235,6 +241,9 @@ def _prepare_seq_groups( prompt_logprob_len = 0 sample_len = len(seq_ids) if do_sample else 0 + if sampling_params.seed is not None and generators is not None: + generator = generators.get(seq_group_metadata.request_id) + # Update indices to select from the model output. """ This blocks computes selected_token_indices which is used in the @@ -279,9 +288,6 @@ def sample(logits): logit_idx += sample_len sample_idx += sample_len - if sampling_params.seed is not None: - generator = seq_group_metadata.state.generator - seq_groups.append( SequenceGroupToSample( seq_ids=seq_ids, diff --git a/vllm/sequence.py b/vllm/sequence.py index 72821ecea0f4..ab50cfdfd29a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -411,14 +411,6 @@ def __repr__(self) -> str: f"num_blocks={self.n_blocks}, ") -@dataclass -class SequenceGroupState: - """Mutable state tied to a specific sequence group""" - - # torch.Generator used in seeded sampling - generator: Optional = None # type: ignore - - class SequenceGroup: """A group of sequences that are generated from the same prompt. @@ -461,7 +453,6 @@ def __init__( time_in_queue=None) self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None - self.state = SequenceGroupState() self.embeddings = embeddings self.pooling_params = pooling_params self.prompt_adapter_request = prompt_adapter_request @@ -648,7 +639,6 @@ class SequenceGroupMetadata: lora_request: LoRA request. computed_block_nums: The block numbers that are already computed, used in prefix caching. - state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. encoder_seq_data: Optional sequence data for encoder prompt (SequenceGroup.encoder_seq). 
Should be None @@ -674,7 +664,6 @@ def __init__( token_chunk_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, - state: Optional[SequenceGroupState] = None, multi_modal_data: Optional["MultiModalDataDict"] = None, encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, @@ -690,7 +679,6 @@ def __init__( self.prompt_adapter_request = prompt_adapter_request self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data - self.state = SequenceGroupState() if state is None else state self.encoder_seq_data = encoder_seq_data self.cross_block_table = cross_block_table self._token_chunk_size = token_chunk_size diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 41f0aebf3c01..45eaeb51c5c0 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -3,9 +3,9 @@ import torch +from vllm import SamplingParams from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData, - SequenceGroupMetadata, SequenceGroupState, - get_all_seq_ids) + SequenceGroupMetadata, get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch, @@ -16,6 +16,8 @@ TargetSeqId = int TokenId = int +DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams() + class BatchExpansionTop1Scorer(SpeculativeScorer): """Implements a speculative scorer that uses batch expansion to get @@ -247,24 +249,39 @@ def _create_target_seq_group_metadata( token_ids_to_score = self._get_token_ids_to_score( proposal_token_ids[batch_index]) + # Use simpler sampling parameters apart from for final token + # (in particular don't do seeded sampling) since those sampled tokens + # aren't used. + # We don't replace the sampling_params in the greedy case because + # this also controls whether the probs get modified in the sampler + # (see use of _modify_greedy_probs_inplace there). + sampling_params = input_seq_group_metadata.sampling_params + non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \ + if sampling_params.temperature else sampling_params + target_seq_group_metadata_list: List[SequenceGroupMetadata] = [] - for token_ids in token_ids_to_score: + last_index = len(token_ids_to_score) - 1 + for i, token_ids in enumerate(token_ids_to_score): + target_sampling_params = sampling_params if i == last_index \ + else non_bonus_sampling_params target_seq_group_metadata_list.append( self._create_single_target_seq_group_metadata( input_seq_group_metadata, input_seq_id, next(target_seq_ids_iter), token_ids, + sampling_params=target_sampling_params, )) return target_seq_group_metadata_list + @staticmethod def _create_single_target_seq_group_metadata( - self, seq_group_metadata: SequenceGroupMetadata, seq_id: SeqId, target_seq_id: TargetSeqId, token_ids: List[TokenId], + sampling_params: SamplingParams, ) -> SequenceGroupMetadata: """Create a single target SequenceGroupMetadata. 
@@ -293,26 +310,16 @@ def _create_single_target_seq_group_metadata( for data in new_seq_data_dict.values(): data.update_num_computed_tokens(data.get_len() - 1) - if (seq_group_metadata.state is not None - and seq_group_metadata.state.generator is not None): - generator = torch.Generator( - device=seq_group_metadata.state.generator.device) - generator.set_state(seq_group_metadata.state.generator.get_state()) - state = SequenceGroupState(generator=generator) - else: - state = None - return SequenceGroupMetadata( request_id=seq_group_metadata.request_id, is_prompt=seq_group_metadata.is_prompt, seq_data=new_seq_data_dict, - sampling_params=seq_group_metadata.sampling_params, + sampling_params=sampling_params, block_tables={ target_seq_id: seq_group_metadata.block_tables[seq_id], }, lora_request=None, token_chunk_size=1, - state=state, ) def _split_scoring_output( diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py index 041ce41e91d0..4b82f7bf92ba 100644 --- a/vllm/spec_decode/medusa_worker.py +++ b/vllm/spec_decode/medusa_worker.py @@ -57,9 +57,11 @@ def sampler_output( seq_lens, query_lens = self._prepare_input_tensors( seq_group_metadata_list) + generators = self.model_runner.get_generators( + execute_model_req.finished_requests_ids) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens, self.device, - self.model_runner.pin_memory) + self.model_runner.pin_memory, generators) model_outputs = self.model_runner.model.generate_proposals( previous_hidden_states=execute_model_req.previous_hidden_states. diff --git a/vllm/spec_decode/mlp_speculator_worker.py b/vllm/spec_decode/mlp_speculator_worker.py index 308573348d44..76e444387816 100644 --- a/vllm/spec_decode/mlp_speculator_worker.py +++ b/vllm/spec_decode/mlp_speculator_worker.py @@ -38,9 +38,11 @@ def sampler_output( (input_tokens, seq_lens, query_lens) = self._prepare_input_tensors(seq_group_metadata_list) + generators = self.model_runner.get_generators( + execute_model_req.finished_requests_ids) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens, self.device, - self.model_runner.pin_memory) + self.model_runner.pin_memory, generators) model_outputs = self.model_runner.model.generate_proposals( input_ids=input_tokens, diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index a21222fec269..806480b5c892 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -7,10 +7,9 @@ from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer -from vllm.worker.worker_base import LoraNotSupportedWorkerBase -class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase): +class NGramWorker(NonLLMProposerWorkerBase): """NGramWorker provides a light drafter without need for model. 
Current NGramWorker only implements prompt lookup decoding, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 98960b88f719..ad8c0cee0b5b 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -213,6 +213,9 @@ def __init__( """ self.proposer_worker = proposer_worker self.scorer_worker = scorer_worker + scorer_runner = getattr(self.scorer_worker, "model_runner", None) + self.generators = scorer_runner.get_generators( + ) if scorer_runner else None self.disable_by_batch_size = disable_by_batch_size or float("inf") self.spec_decode_sampler = spec_decode_sampler self._allow_zero_draft_token_step = allow_zero_draft_token_step @@ -591,20 +594,14 @@ def _verify_tokens( proposal_token_ids = proposals.proposal_token_ids[spec_indices] # Sampler arguments - sampler_extra_kwargs = {} - if isinstance(self.spec_decode_sampler, - SpecDecodeStochasticBaseSampler): - - # Get sequence group state - generators = [] - for seq_group_metadata in seq_group_metadata_list: - if (seq_group_metadata.state is not None - and seq_group_metadata.state.generator is not None): - generators.append(seq_group_metadata.state.generator) - else: - generators.append(None) - - sampler_extra_kwargs["generators"] = generators + sampler_extra_kwargs: Dict[str, Any] = {} + if self.generators and isinstance(self.spec_decode_sampler, + SpecDecodeStochasticBaseSampler): + sampler_extra_kwargs["seeded_seqs"] = { + idx: self.generators[sgm.request_id] + for idx, sgm in enumerate(seq_group_metadata_list) + if sgm.sampling_params.seed is not None + } accepted_token_ids = self.spec_decode_sampler( target_probs=proposal_verifier_probs, diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 71763c08ec45..c1dee444da51 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -337,7 +337,8 @@ def prepare_model_input( # just use seq_lens instead. seq_lens, self.device, - pin_memory=False) + pin_memory=False, + generators=self.get_generators(finished_requests_ids)) return CPUModelInput( input_tokens=input_tokens, input_positions=input_positions, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 86d26b4a84c3..4010c45e1026 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1264,11 +1264,15 @@ def prepare_model_input( """ model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) - sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, - model_input.seq_lens, - model_input.query_lens, - self.device, - self.pin_memory) + if get_pp_group().is_last_rank: + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, model_input.seq_lens, + model_input.query_lens, self.device, self.pin_memory, + generators) + else: + sampling_metadata = None is_prompt = (seq_group_metadata_list[0].is_prompt if seq_group_metadata_list else None) return dataclasses.replace(model_input, diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 5fb97025af5c..46ac16b504bf 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -139,6 +139,9 @@ class ModelRunnerBase(ABC, Generic[T]): ModelRunnerInputBase subclass. 
""" + # Map of request_id -> generator used for seeded random sampling + generators: Dict[str, torch.Generator] = {} + @abstractmethod def make_model_input_from_broadcasted_tensor_dict( self, @@ -176,3 +179,15 @@ def execute_model( Execute the model on the given input. """ raise NotImplementedError + + def get_generators(self, finished_request_ids: Optional[List[str]] = None): + """ + Return dict of per-request generators used for random sampling. + """ + + # Clean up generators from completed requests + if finished_request_ids: + for request_id in finished_request_ids: + self.generators.pop(request_id, None) + + return self.generators diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 651319ab1454..243e2ece56fe 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -219,7 +219,8 @@ def prepare_model_input( # just use seq_lens instead. seq_lens, self.device, - self.pin_memory) + self.pin_memory, + generators=self.get_generators(finished_requests_ids)) return ModelInputForNeuron(input_tokens=input_tokens, input_positions=input_positions, diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 2f0ca42316e1..98462f0f7f38 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -246,7 +246,8 @@ def prepare_model_input( # just use seq_lens instead. seq_lens, self.device, - pin_memory=False) + pin_memory=False, + generators=self.get_generators(finished_requests_ids)) # Broadcast the metadata. metadata_dict = { "input_tokens": input_tokens,