Move generators dict to model runner
njhill committed Jul 23, 2024
1 parent d654d55 commit 5e89191
Showing 11 changed files with 53 additions and 52 deletions.
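
In short: the per-request `torch.Generator` map used for seeded sampling now lives on the model runner (`ModelRunnerBase.generators`, pruned via `_get_generators`) instead of on the worker, and is passed into `SamplingMetadata.prepare`, which creates a seeded generator on the prompt step and looks it up on decode steps; the `SequenceGroupMetadata.generator` field and the worker-side bookkeeping are removed. The resulting lifecycle, as a simplified standalone sketch (illustration only, not the actual vLLM code):

```python
# Simplified sketch of the generator lifecycle after this commit
# (illustration only; the real logic is in _prepare_seq_groups and
# ModelRunnerBase._get_generators shown in the diff below).
from typing import Dict, List, Optional
import torch

generators: Dict[str, torch.Generator] = {}  # owned by the model runner


def pick_generator(request_id: str, seed: Optional[int], is_prompt: bool,
                   device: str = "cpu") -> Optional[torch.Generator]:
    if seed is None:
        return None
    if is_prompt:
        # Prefill step: create a seeded generator and remember it.
        gen = torch.Generator(device=device).manual_seed(seed)
        generators[request_id] = gen
        return gen
    # Decode steps: reuse the generator stored for this request.
    return generators.get(request_id)


def on_requests_finished(finished_request_ids: Optional[List[str]]) -> None:
    # Mirrors _get_generators: drop generators of completed requests.
    for request_id in finished_request_ids or []:
        generators.pop(request_id, None)
```
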
5 changes: 4 additions & 1 deletion tests/samplers/test_sampler.py
@@ -510,13 +510,16 @@ def test_sampler_mixed(seed: int, device: str):
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

generators: Dict[str, torch.Generator] = {}

def test_sampling():
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=is_pin_memory_available())
pin_memory=is_pin_memory_available(),
generators=generators)
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)

20 changes: 17 additions & 3 deletions vllm/model_executor/sampling_metadata.py
@@ -117,14 +117,15 @@ def prepare(
query_lens: Optional[List[int]],
device: str,
pin_memory: bool,
generators: Optional[Dict[str, torch.Generator]] = None,
) -> "SamplingMetadata":
(
seq_groups,
selected_token_indices,
categorized_sample_indices,
num_prompts,
) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
device)
device, generators)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=device,
@@ -159,6 +160,7 @@ def _prepare_seq_groups(
seq_lens: List[int],
query_lens: Optional[List[int]],
device: str,
generators: Optional[Dict[str, torch.Generator]] = None,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
SamplingType, List[Tuple[int, int]]], int]:
"""Prepare sequence groups and indices for sampling.
@@ -169,8 +171,10 @@
Index of prompt len should match with seq_group_metadata_list.
query_lens: A list of query lengths. Prompt lens include the length
of entire prompt tokens, and it could be shorter.
device: A device to use for random number generator,
device: A device to use for random number generators,
`SequenceGroupToSample.generator`.
generators: A store of per-request random number generators used
for seeded requests.
Returns:
seq_groups: A list of sequence group to sample.
@@ -206,6 +210,7 @@
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
is_prompt = seq_group_metadata.is_prompt
generator: Optional[torch.Generator] = None
# If the current seq group is in decode stage, it is None.
seq_len: Optional[int] = None
query_len: Optional[int] = None
@@ -214,6 +219,12 @@
do_sample = seq_group_metadata.do_sample

if seq_group_metadata.is_prompt:
if sampling_params.seed is not None:
generator = torch.Generator(device=device).manual_seed(
sampling_params.seed)
if generators is not None:
generators[seq_group_metadata.request_id] = generator

num_prompts += 1
num_prefill_sample = len(seq_ids)
assert num_prefill_sample == 1
@@ -229,6 +240,9 @@
prompt_logprob_len = 0
sample_len = len(seq_ids) if do_sample else 0

if sampling_params.seed is not None and generators is not None:
generator = generators.get(seq_group_metadata.request_id)

# Update indices to select from the model output.
"""
This blocks computes selected_token_indices which is used in the
@@ -280,7 +294,7 @@ def sample(logits):
seq_data=seq_group_metadata.seq_data,
seq_len=seq_len,
query_len=query_len,
generator=seq_group_metadata.generator,
generator=generator,
is_prompt=is_prompt,
prompt_logprob_indices=list(prompt_logprob_indices),
sample_indices=list(sample_indices)))
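
Context for the hunks above: a seeded request keeps one `torch.Generator` across all of its steps, which is what makes its sampled tokens reproducible. A self-contained illustration in plain PyTorch (not vLLM code):

```python
import torch


def sample_run(seed: int, probs: torch.Tensor, steps: int) -> list:
    # One generator per "request", reused across steps -- the role played by
    # the runner-owned generators dict after this commit.
    gen = torch.Generator().manual_seed(seed)
    return [torch.multinomial(probs, 1, generator=gen).item()
            for _ in range(steps)]


probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
# Same seed, same generator state progression => identical token sequence.
assert sample_run(1234, probs, 5) == sample_run(1234, probs, 5)
```
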
3 changes: 0 additions & 3 deletions vllm/sequence.py
@@ -630,7 +630,6 @@ class SequenceGroupMetadata:
lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
generator: Optional torch.Generator to use for random sampling.
multi_modal_data: Multi modal data.
encoder_seq_data: Optional sequence data for encoder prompt
(SequenceGroup.encoder_seq). Should be None
@@ -656,7 +655,6 @@ def __init__(
token_chunk_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
generator: Optional[torch.Generator] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None,
@@ -672,7 +670,6 @@ def __init__(
self.prompt_adapter_request = prompt_adapter_request
self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data
self.generator = generator
self.encoder_seq_data = encoder_seq_data
self.cross_block_table = cross_block_table
self._token_chunk_size = token_chunk_size
7 changes: 0 additions & 7 deletions vllm/spec_decode/batch_expansion.py
@@ -292,12 +292,6 @@ def _create_single_target_seq_group_metadata(
for data in new_seq_data_dict.values():
data.update_num_computed_tokens(data.get_len() - 1)

generator = seq_group_metadata.generator
if generator is not None:
orig_generator = generator
generator = torch.Generator(device=orig_generator.device)
generator.set_state(orig_generator.get_state())

return SequenceGroupMetadata(
request_id=seq_group_metadata.request_id,
is_prompt=seq_group_metadata.is_prompt,
@@ -308,7 +302,6 @@ def _split_scoring_output(
},
lora_request=None,
token_chunk_size=1,
generator=generator,
)

def _split_scoring_output(
6 changes: 2 additions & 4 deletions vllm/spec_decode/spec_decode_worker.py
@@ -591,13 +591,11 @@ def _verify_tokens(
proposal_token_ids = proposals.proposal_token_ids[spec_indices]

# Sampler arguments
sampler_extra_kwargs = {}
sampler_extra_kwargs: Dict[str, Any] = {}
if isinstance(self.spec_decode_sampler,
SpecDecodeStochasticBaseSampler):
# Get sequence group state
sampler_extra_kwargs["generators"] = [
sgm.generator for sgm in seq_group_metadata_list
]
sampler_extra_kwargs["generators"] = [] #TODO

accepted_token_ids = self.spec_decode_sampler(
target_probs=proposal_verifier_probs,
3 changes: 2 additions & 1 deletion vllm/worker/cpu_model_runner.py
@@ -336,7 +336,8 @@ def prepare_model_input(
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False)
pin_memory=False,
generators=self._get_generators(finished_requests_ids))
return CPUModelInput(
input_tokens=input_tokens,
input_positions=input_positions,
14 changes: 9 additions & 5 deletions vllm/worker/model_runner.py
@@ -1224,11 +1224,15 @@ def prepare_model_input(
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
self.pin_memory)
if get_pp_group().is_last_rank:
# Sampling metadata is only required for the final pp group
generators = self._get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, model_input.seq_lens,
model_input.query_lens, self.device, self.pin_memory,
generators)
else:
sampling_metadata = None
is_prompt = (seq_group_metadata_list[0].is_prompt
if seq_group_metadata_list else None)
return dataclasses.replace(model_input,
16 changes: 15 additions & 1 deletion vllm/worker/model_runner_base.py
@@ -139,7 +139,8 @@ class ModelRunnerBase(ABC, Generic[T]):
ModelRunnerInputBase subclass.
"""

device: torch.device
# Map of request_id -> generator used for seeded random sampling
generators: Dict[str, torch.Generator] = {}

@abstractmethod
def make_model_input_from_broadcasted_tensor_dict(
@@ -178,3 +179,16 @@ def execute_model(
Execute the model on the given input.
"""
raise NotImplementedError

def _get_generators(self,
finished_request_ids: Optional[List[str]] = None):
"""
Return dict of per-request generators used for random sampling.
"""

# Clean up generators from completed requests
if finished_request_ids:
for request_id in finished_request_ids:
self.generators.pop(request_id, None)

return self.generators
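
For orientation, the call pattern that the GPU, CPU, Neuron, and XPU runners in this commit now share, condensed into one hypothetical helper (the wrapper function and its name are illustrative, not part of the diff):

```python
# Hypothetical condensation of the runner-side call pattern (sketch only).
from typing import List, Optional

from vllm.model_executor.sampling_metadata import SamplingMetadata


def build_sampling_metadata(runner, seq_group_metadata_list, seq_lens,
                            query_lens,
                            finished_requests_ids: Optional[List[str]] = None):
    # Prune generators of finished requests, then pass the live map so that
    # seeded requests keep sampling from their own generator.
    generators = runner._get_generators(finished_requests_ids)
    return SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=query_lens,
        device=runner.device,
        pin_memory=runner.pin_memory,
        generators=generators)
```
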
3 changes: 2 additions & 1 deletion vllm/worker/neuron_model_runner.py
@@ -219,7 +219,8 @@ def prepare_model_input(
# just use seq_lens instead.
seq_lens,
self.device,
self.pin_memory)
self.pin_memory,
generators=self._get_generators(finished_requests_ids))

return ModelInputForNeuron(input_tokens=input_tokens,
input_positions=input_positions,
25 changes: 0 additions & 25 deletions vllm/worker/worker_base.py
@@ -172,8 +172,6 @@ class LocalOrDistributedWorkerBase(WorkerBase):
"""
is_driver_worker: bool
model_runner: ModelRunnerBase
# Map of request_id -> generator used for seeded random sampling
generators: Dict[str, torch.Generator] = {}

@property
@abstractmethod
@@ -232,9 +230,6 @@ def execute_model(
broadcast_tensor_dict({}, src=0)
return None

self._update_sequence_group_metadata_with_generators(
execute_model_req)

worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
@@ -314,26 +309,6 @@ def _execute_model_spmd(
model_input, self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None)

def _update_sequence_group_metadata_with_generators(
self, execute_model_req: ExecuteModelRequest):
# This only needs to be done in the last PP rank, where the sampling
# is done
if get_pp_group().is_last_rank:
# Clean up generators from completed requests
for request_id in execute_model_req.finished_requests_ids:
self.generators.pop(request_id, None)

# Set generator in SequenceGroupMetadata from worker state
for sgm in execute_model_req.seq_group_metadata_list:
if sgm.sampling_params.seed is not None:
if sgm.is_prompt:
sgm.generator = torch.Generator(
device=self.model_runner.device).manual_seed(
sgm.sampling_params.seed)
self.generators[sgm.request_id] = sgm.generator
else:
sgm.generator = self.generators.get(sgm.request_id)


class WorkerWrapperBase:
"""
Expand Down
3 changes: 2 additions & 1 deletion vllm/worker/xpu_model_runner.py
@@ -246,7 +246,8 @@ def prepare_model_input(
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False)
pin_memory=False,
generators=self._get_generators(finished_requests_ids))
# Broadcast the metadata.
metadata_dict = {
"input_tokens": input_tokens,
