[BugFix] Fix use of per-request seed with pipeline parallel (#6698)
njhill authored Jul 30, 2024
1 parent f058403 commit 5cf9254
Showing 21 changed files with 222 additions and 137 deletions.
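For orientation before the per-file diffs: per-request seeds now reach the rejection sampler as a sparse seeded_seqs: Dict[int, torch.Generator] keyed by batch row index, replacing the previous dense generators: List[Optional[torch.Generator]] argument with one (mostly None) entry per row. The sketch below is illustrative only; to_seeded_seqs is a hypothetical helper, not part of the commit, showing how the old representation maps onto the new one.

from typing import Dict, List, Optional

import torch


def to_seeded_seqs(
        generators: List[Optional[torch.Generator]]
) -> Dict[int, torch.Generator]:
    # Keep only the rows that actually carry a per-request seed.
    return {i: g for i, g in enumerate(generators) if g is not None}


# Example: in a batch of 4 sequences, only row 2 is seeded.
generators = [None, None, torch.Generator().manual_seed(42), None]
seeded_seqs = to_seeded_seqs(generators)
assert list(seeded_seqs) == [2]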
23 changes: 8 additions & 15 deletions tests/samplers/test_rejection_sampler.py
@@ -150,10 +150,9 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
generators = [None] * batch_size

rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, generators)
draft_token_ids)


@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
@@ -185,14 +184,13 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,

results = []
for _ in range(n_rep):
generators = [
torch.Generator(
device=device).manual_seed(i) if seeded_mask[i] else None
for i in range(batch_size)
]
seeded_seqs = {
i: torch.Generator(device=device).manual_seed(i)
for i in range(batch_size) if seeded_mask[i]
}
results.append(
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, generators))
draft_token_ids, seeded_seqs))

for i in range(batch_size):
if seeded_mask[i]:
@@ -242,11 +240,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
raise AssertionError()

oob_token_ids[0][0] = rogue_token_id
generators = [None] * batch_size

with pytest.raises(AssertionError):
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, generators)
draft_token_ids)


@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@@ -417,15 +414,11 @@ def _estimate_rejection_sampling_pdf(
dtype=torch.int64,
device="cuda").repeat(num_samples, 1)

# unseeded
generators = [None]

# Get output tokens via rejection sampling.
output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
bonus_token_ids.to("cuda"),
draft_probs.to("cuda"),
draft_token_ids.to("cuda"),
generators)
draft_token_ids.to("cuda"))

# Remove bonus tokens
output_token_ids = output_token_ids[:, :-1].flatten()
5 changes: 4 additions & 1 deletion tests/samplers/test_sampler.py
@@ -510,13 +510,16 @@ def test_sampler_mixed(seed: int, device: str):
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

generators: Dict[str, torch.Generator] = {}

def test_sampling():
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=is_pin_memory_available())
pin_memory=is_pin_memory_available(),
generators=generators)
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)

54 changes: 53 additions & 1 deletion tests/spec_decode/e2e/test_mlp_correctness.py
@@ -21,7 +21,8 @@

import pytest

from .conftest import run_greedy_equality_correctness_test
from .conftest import (run_equality_correctness_test,
run_greedy_equality_correctness_test)

# main model
MAIN_MODEL = "JackFram/llama-160m"
@@ -77,6 +78,57 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
# Speculative model
"speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize("seed", [None])
def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int,
temperature: float):
"""Verify seeded runs produce the same output."""
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
temperature=temperature,
seeded=True,
force_output_len=True)

# Ensure this same test does fail if we _don't_ include per-request seeds
with pytest.raises(AssertionError):
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
temperature=temperature,
seeded=False,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
2 changes: 1 addition & 1 deletion tests/spec_decode/e2e/test_seed.py
@@ -29,7 +29,7 @@
"output_len",
[
# Use smaller output len for fast test.
10,
20,
])
@pytest.mark.parametrize("seed", [None])
def test_seeded_consistency(baseline_llm_generator, test_llm_generator,
1 change: 1 addition & 0 deletions tests/spec_decode/test_batch_expansion.py
@@ -86,6 +86,7 @@ def test_create_single_target_seq_group_metadata(k: int):
input_seq_id,
target_seq_id,
token_ids,
input_seq_group_metadata.sampling_params,
)

assert output.request_id == input_seq_group_metadata.request_id
31 changes: 31 additions & 0 deletions tests/utils.py
@@ -178,6 +178,37 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
"usage": completion.usage,
})

# test seeded random sampling
completion = client.completions.create(model=model,
prompt=prompt,
max_tokens=5,
seed=33,
temperature=1.0)

results.append({
"test": "seeded_sampling",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})

# test seeded random sampling with multiple prompts
completion = client.completions.create(model=model,
prompt=[prompt, prompt],
max_tokens=5,
seed=33,
temperature=1.0)

results.append({
"test":
"seeded_sampling",
"text": [choice.text for choice in completion.choices],
"finish_reason":
[choice.finish_reason for choice in completion.choices],
"usage":
completion.usage,
})

# test simple list
batch = client.completions.create(
model=model,
1 change: 0 additions & 1 deletion vllm/core/scheduler.py
@@ -1029,7 +1029,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,
state=seq_group.state,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
95 changes: 39 additions & 56 deletions vllm/model_executor/layers/rejection_sampler.py
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import torch
import torch.jit
@@ -36,7 +36,7 @@ def forward(
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
@@ -66,6 +66,9 @@ def forward(
probabilities.
shape = [batch_size, num_speculative_tokens]
seeded_seqs: Dict of batch row index to torch generator, for
sequences using seeded generation.
Returns:
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
@@ -83,7 +86,7 @@ def forward(
target_probs,
draft_probs,
draft_token_ids,
generators,
seeded_seqs,
))

output_token_ids = self._create_output(
@@ -100,7 +103,7 @@ def _batch_modified_rejection_sampling(
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence.
@@ -117,23 +120,17 @@ def _batch_modified_rejection_sampling(

# shape [batch_size, k]
accepted = self._get_accepted(target_probs, draft_probs,
draft_token_ids, generators)
draft_token_ids, seeded_seqs)

recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)

seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators, k=k)

# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(
recovered_probs,
num_samples=1,
k=k,
generators=generators,
seed_indices=seed_indices,
# this arg is unused when None but torch.jit requires a list
non_seed_indices=non_seed_indices or [],
seeded_seqs=seeded_seqs or {},
).reshape(batch_size, k)

return accepted, recovered_token_ids
@@ -143,7 +140,7 @@ def _get_accepted(
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]],
) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
@@ -178,24 +175,26 @@ def _get_accepted(
selected_target_probs = target_probs[batch_indices, probs_indicies,
draft_token_ids]

seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators)

if len(seed_indices) == 0:
if not seeded_seqs:
uniform_rand = torch.rand_like(selected_target_probs)
else:
uniform_rand = torch.empty_like(selected_target_probs)

for idx in seed_indices:
uniform_rand[idx, :] = torch.rand(1,
k,
dtype=self.probs_dtype,
device=target_probs.device,
generator=generators[idx])

if non_seed_indices:
uniform_rand[non_seed_indices, :] = torch.rand(
len(non_seed_indices),
non_seeded_indices = []
for idx in range(batch_size):
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.append(idx)
else:
uniform_rand[idx, :] = torch.rand(
1,
k,
dtype=self.probs_dtype,
device=target_probs.device,
generator=generator)
if non_seeded_indices:
uniform_rand[non_seeded_indices, :] = torch.rand(
len(non_seeded_indices),
k,
dtype=self.probs_dtype,
device=target_probs.device)
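The new _get_accepted body above partitions the batch once: rows with a generator draw their uniform randoms individually, so they stay reproducible, while all remaining rows share a single batched torch.rand call. Below is a standalone CPU sketch of that pattern; uniform_per_row is a hypothetical helper for illustration, not code from this commit, and it skips the dtype/device handling the real method does.

from typing import Dict

import torch


def uniform_per_row(batch_size: int, k: int,
                    seeded_seqs: Dict[int, torch.Generator]) -> torch.Tensor:
    out = torch.empty(batch_size, k)
    non_seeded = [i for i in range(batch_size) if i not in seeded_seqs]
    for i, gen in seeded_seqs.items():
        # Seeded rows: one small draw per row from that row's own generator.
        out[i] = torch.rand(k, generator=gen)
    if non_seeded:
        # Unseeded rows: a single batched draw.
        out[non_seeded] = torch.rand(len(non_seeded), k)
    return out


a = uniform_per_row(4, 3, {1: torch.Generator().manual_seed(7)})
b = uniform_per_row(4, 3, {1: torch.Generator().manual_seed(7)})
assert torch.equal(a[1], b[1])  # the seeded row repeats exactly across calls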
@@ -272,27 +271,6 @@ def _smallest_positive_value(self) -> float:
"""
return torch.finfo(self.probs_dtype).tiny

# partition batch into indices for which a generator is provided
# and indicies for which no generator is provided
@staticmethod
def _split_batch_by_seeded(
generators: List[Optional[torch.Generator]],
k: int = 1,
) -> Tuple[List[int], Optional[List[int]]]:

if all(generator is None for generator in generators):
seed_indices: List[int] = []
non_seed_indices: Optional[List[int]] = None
else:
seed_indices, non_seed_indices = [], []
for i, generator in enumerate(generators):
if generator is None:
non_seed_indices.extend(range(k * i, k * (i + 1)))
else:
seed_indices.extend(range(k * i, k * (i + 1)))

return seed_indices, non_seed_indices


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
@@ -304,9 +282,7 @@ def _multinomial(
probs: torch.Tensor,
num_samples: int,
k: int,
generators: List[Optional[torch.Generator]],
seed_indices: List[int],
non_seed_indices: List[int],
seeded_seqs: Dict[int, torch.Generator],
) -> torch.Tensor:

if num_samples > 1:
@@ -315,13 +291,20 @@ def _multinomial(
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])

q = torch.empty_like(probs)
if len(seed_indices) == 0:
if not seeded_seqs:
q.exponential_(1.0)
else:
q[non_seed_indices].exponential_(1.0)
for idx in seed_indices:
q[idx].exponential_(1.0, generator=generators[idx // k])
non_seeded_indices: List[int] = []
start = 0
for idx in range(len(q) // k):
end = start + k
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.extend(list(range(start, end)))
else:
q[start:end].exponential_(1.0, generator=generator)
start = end
q[non_seeded_indices].exponential_(1.0)

return probs.div_(q).argmax(dim=1).view(-1, num_samples)
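The comment above _multinomial notes that torch.multinomial forces a GPU<->CPU sync, so the sampler instead divides the probabilities by i.i.d. Exponential(1) noise and takes the argmax, which draws from the same categorical distribution (an exponential-race variant of the Gumbel-max trick). A small standalone check of that equivalence, not vLLM code:

import torch

probs = torch.tensor([0.1, 0.2, 0.7])
gen = torch.Generator().manual_seed(0)
counts = torch.zeros(3)
for _ in range(100_000):
    # Exponential race: argmax of p_i / q_i with q_i ~ Exp(1) lands on i
    # with probability proportional to p_i.
    q = torch.empty_like(probs).exponential_(1.0, generator=gen)
    counts[(probs / q).argmax()] += 1
print(counts / counts.sum())  # empirically close to [0.1, 0.2, 0.7]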