[BugFix] Fix use of per-request seed with pipeline parallel #6698

Merged · 9 commits · Jul 30, 2024
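Summary of the interface change: the rejection sampler previously took a per-row `List[Optional[torch.Generator]]`; this PR replaces it with an optional `Dict[int, torch.Generator]` (`seeded_seqs`) that maps batch row index to generator and contains only the seeded rows. A minimal sketch of the two forms, mirroring the change in `tests/samplers/test_rejection_sampler.py` (the seeding pattern and CPU device here are illustrative, not from the PR):

```python
import torch

batch_size = 4
device = "cpu"  # the tests use a CUDA device; CPU keeps this sketch self-contained
seeded_mask = [True, False, True, False]  # illustrative: which rows are seeded

# Before this PR: one entry per batch row, None for unseeded rows.
generators = [
    torch.Generator(device=device).manual_seed(i) if seeded_mask[i] else None
    for i in range(batch_size)
]

# After this PR: a sparse mapping of row index -> generator; unseeded rows are absent.
seeded_seqs = {
    i: torch.Generator(device=device).manual_seed(i)
    for i in range(batch_size) if seeded_mask[i]
}
```

Per the diff, a missing key (or `seeded_seqs=None`) means an unseeded row, so fully unseeded batches can simply omit the argument.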
23 changes: 8 additions & 15 deletions tests/samplers/test_rejection_sampler.py
@@ -150,10 +150,9 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
generators = [None] * batch_size

rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, generators)
draft_token_ids)


@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
@@ -185,14 +184,13 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,

results = []
for _ in range(n_rep):
generators = [
torch.Generator(
device=device).manual_seed(i) if seeded_mask[i] else None
for i in range(batch_size)
]
seeded_seqs = {
i: torch.Generator(device=device).manual_seed(i)
for i in range(batch_size) if seeded_mask[i]
}
results.append(
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, generators))
draft_token_ids, seeded_seqs))

for i in range(batch_size):
if seeded_mask[i]:
@@ -242,11 +240,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
raise AssertionError()

oob_token_ids[0][0] = rogue_token_id
generators = [None] * batch_size

with pytest.raises(AssertionError):
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, generators)
draft_token_ids)


@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@@ -417,15 +414,11 @@ def _estimate_rejection_sampling_pdf(
dtype=torch.int64,
device="cuda").repeat(num_samples, 1)

# unseeded
generators = [None]

# Get output tokens via rejection sampling.
output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
bonus_token_ids.to("cuda"),
draft_probs.to("cuda"),
draft_token_ids.to("cuda"),
generators)
draft_token_ids.to("cuda"))

# Remove bonus tokens
output_token_ids = output_token_ids[:, :-1].flatten()
5 changes: 4 additions & 1 deletion tests/samplers/test_sampler.py
@@ -510,13 +510,16 @@ def test_sampler_mixed(seed: int, device: str):
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

generators: Dict[str, torch.Generator] = {}

def test_sampling():
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=is_pin_memory_available())
pin_memory=is_pin_memory_available(),
generators=generators)
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)

54 changes: 53 additions & 1 deletion tests/spec_decode/e2e/test_mlp_correctness.py
@@ -21,7 +21,8 @@

import pytest

from .conftest import run_greedy_equality_correctness_test
from .conftest import (run_equality_correctness_test,
run_greedy_equality_correctness_test)

# main model
MAIN_MODEL = "JackFram/llama-160m"
@@ -77,6 +78,57 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,

# Print spec metrics.
"disable_log_stats": False,

# Precision
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,

# Speculative model
"speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize("seed", [None])
def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int,
temperature: float):
"""Verify seeded runs produce the same output."""
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
temperature=temperature,
seeded=True,
force_output_len=True)

# Ensure this same test does fail if we _don't_ include per-request seeds
with pytest.raises(AssertionError):
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
temperature=temperature,
seeded=False,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
2 changes: 1 addition & 1 deletion tests/spec_decode/e2e/test_seed.py
@@ -29,7 +29,7 @@
"output_len",
[
# Use smaller output len for fast test.
10,
20,
])
@pytest.mark.parametrize("seed", [None])
def test_seeded_consistency(baseline_llm_generator, test_llm_generator,
1 change: 1 addition & 0 deletions tests/spec_decode/test_batch_expansion.py
@@ -86,6 +86,7 @@ def test_create_single_target_seq_group_metadata(k: int):
input_seq_id,
target_seq_id,
token_ids,
input_seq_group_metadata.sampling_params,
)

assert output.request_id == input_seq_group_metadata.request_id
31 changes: 31 additions & 0 deletions tests/utils.py
@@ -178,6 +178,37 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
"usage": completion.usage,
})

# test seeded random sampling
completion = client.completions.create(model=model,
prompt=prompt,
max_tokens=5,
seed=33,
temperature=1.0)

results.append({
"test": "seeded_sampling",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})

# test seeded random sampling with multiple prompts
completion = client.completions.create(model=model,
prompt=[prompt, prompt],
max_tokens=5,
seed=33,
temperature=1.0)

results.append({
"test":
"seeded_sampling",
"text": [choice.text for choice in completion.choices],
"finish_reason":
[choice.finish_reason for choice in completion.choices],
"usage":
completion.usage,
})

# test simple list
batch = client.completions.create(
model=model,
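As a usage note for the seeded-sampling checks added to `tests/utils.py` above, the same request pattern can be used to verify seeded determinism end to end against a running OpenAI-compatible vLLM server. A hypothetical standalone sketch (the base URL, API key, and model name are placeholders, not part of the PR):

```python
import openai

# Placeholder connection details for a locally running OpenAI-compatible server.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

MODEL = "facebook/opt-125m"       # illustrative model name
PROMPT = "Hello, my name is"      # illustrative prompt


def seeded_completion() -> str:
    # Same call shape as the new test above: fixed seed, nonzero temperature.
    completion = client.completions.create(model=MODEL,
                                           prompt=PROMPT,
                                           max_tokens=5,
                                           seed=33,
                                           temperature=1.0)
    return completion.choices[0].text


# With an identical seed, repeated requests should return identical text.
assert seeded_completion() == seeded_completion()
```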
1 change: 0 additions & 1 deletion vllm/core/scheduler.py
@@ -1029,7 +1029,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,
state=seq_group.state,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
95 changes: 39 additions & 56 deletions vllm/model_executor/layers/rejection_sampler.py
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import torch
import torch.jit
@@ -36,7 +36,7 @@ def forward(
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
@@ -66,6 +66,9 @@ def forward(
probabilities.
shape = [batch_size, num_speculative_tokens]

seeded_seqs: Dict of batch row index to torch generator, for
sequences using seeded generation.

Returns:
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
@@ -83,7 +86,7 @@ def forward(
target_probs,
draft_probs,
draft_token_ids,
generators,
seeded_seqs,
))

output_token_ids = self._create_output(
Expand All @@ -100,7 +103,7 @@ def _batch_modified_rejection_sampling(
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence.

@@ -117,23 +120,17 @@ def _batch_modified_rejection_sampling(

# shape [batch_size, k]
accepted = self._get_accepted(target_probs, draft_probs,
draft_token_ids, generators)
draft_token_ids, seeded_seqs)

recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)

seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators, k=k)

# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(
recovered_probs,
num_samples=1,
k=k,
generators=generators,
seed_indices=seed_indices,
# this arg is unused when None but torch.jit requires a list
non_seed_indices=non_seed_indices or [],
seeded_seqs=seeded_seqs or {},
).reshape(batch_size, k)

return accepted, recovered_token_ids
@@ -143,7 +140,7 @@ def _get_accepted(
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]],
) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
@@ -178,24 +175,26 @@ def _get_accepted(
selected_target_probs = target_probs[batch_indices, probs_indicies,
draft_token_ids]

seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators)

if len(seed_indices) == 0:
if not seeded_seqs:
uniform_rand = torch.rand_like(selected_target_probs)
else:
uniform_rand = torch.empty_like(selected_target_probs)

for idx in seed_indices:
uniform_rand[idx, :] = torch.rand(1,
k,
dtype=self.probs_dtype,
device=target_probs.device,
generator=generators[idx])

if non_seed_indices:
uniform_rand[non_seed_indices, :] = torch.rand(
len(non_seed_indices),
non_seeded_indices = []
for idx in range(batch_size):
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.append(idx)
else:
uniform_rand[idx, :] = torch.rand(
1,
k,
dtype=self.probs_dtype,
device=target_probs.device,
generator=generator)
if non_seeded_indices:
uniform_rand[non_seeded_indices, :] = torch.rand(
len(non_seeded_indices),
k,
dtype=self.probs_dtype,
device=target_probs.device)
@@ -272,27 +271,6 @@ def _smallest_positive_value(self) -> float:
"""
return torch.finfo(self.probs_dtype).tiny

# partition batch into indices for which a generator is provided
# and indicies for which no generator is provided
@staticmethod
def _split_batch_by_seeded(
generators: List[Optional[torch.Generator]],
k: int = 1,
) -> Tuple[List[int], Optional[List[int]]]:

if all(generator is None for generator in generators):
seed_indices: List[int] = []
non_seed_indices: Optional[List[int]] = None
else:
seed_indices, non_seed_indices = [], []
for i, generator in enumerate(generators):
if generator is None:
non_seed_indices.extend(range(k * i, k * (i + 1)))
else:
seed_indices.extend(range(k * i, k * (i + 1)))

return seed_indices, non_seed_indices


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
@@ -304,9 +282,7 @@ def _multinomial(
probs: torch.Tensor,
num_samples: int,
k: int,
generators: List[Optional[torch.Generator]],
seed_indices: List[int],
non_seed_indices: List[int],
seeded_seqs: Dict[int, torch.Generator],
) -> torch.Tensor:

if num_samples > 1:
@@ -315,13 +291,20 @@
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])

q = torch.empty_like(probs)
if len(seed_indices) == 0:
if not seeded_seqs:
q.exponential_(1.0)
else:
q[non_seed_indices].exponential_(1.0)
for idx in seed_indices:
q[idx].exponential_(1.0, generator=generators[idx // k])
non_seeded_indices: List[int] = []
start = 0
for idx in range(len(q) // k):
end = start + k
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.extend(list(range(start, end)))
else:
q[start:end].exponential_(1.0, generator=generator)
start = end
q[non_seeded_indices].exponential_(1.0)

return probs.div_(q).argmax(dim=1).view(-1, num_samples)
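
To make the new control flow concrete, here is a small self-contained sketch of the seeded/unseeded split that the updated `_get_accepted` performs: rows listed in `seeded_seqs` draw their uniform samples from their own generator, and all remaining rows are filled with a single batched `torch.rand` call. The function name and defaults are illustrative, not part of the PR:

```python
from typing import Dict, List

import torch


def split_uniform_rand(batch_size: int,
                       k: int,
                       seeded_seqs: Dict[int, torch.Generator],
                       dtype: torch.dtype = torch.float32,
                       device: str = "cpu") -> torch.Tensor:
    """Build the [batch_size, k] uniform samples used by the accept/reject test.

    Rows with an entry in seeded_seqs use that row's generator (deterministic);
    everything else is drawn in one batched call, as in the updated _get_accepted.
    """
    if not seeded_seqs:
        return torch.rand(batch_size, k, dtype=dtype, device=device)

    uniform_rand = torch.empty(batch_size, k, dtype=dtype, device=device)
    non_seeded_indices: List[int] = []
    for idx in range(batch_size):
        generator = seeded_seqs.get(idx)
        if generator is None:
            non_seeded_indices.append(idx)
        else:
            # Deterministic per-row draw from this sequence's generator.
            uniform_rand[idx, :] = torch.rand(k,
                                              dtype=dtype,
                                              device=device,
                                              generator=generator)
    if non_seeded_indices:
        # One batched draw from the global RNG for all unseeded rows.
        uniform_rand[non_seeded_indices, :] = torch.rand(
            len(non_seeded_indices), k, dtype=dtype, device=device)
    return uniform_rand
```

Re-running this with freshly constructed generators seeded with the same values reproduces the seeded rows exactly while unseeded rows vary, which is the property `test_deterministic_when_seeded` asserts.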