Skip to content

Commit

Permalink
Fix typical acceptance sampler with correct recovered token ids (vllm…
Browse files Browse the repository at this point in the history
  • Loading branch information
jiqing-feng authored Sep 23, 2024
1 parent 3dd5053 commit 49ecfe9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 28 deletions.
17 changes: 8 additions & 9 deletions tests/samplers/test_typical_acceptance_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def test_accept_tokens_partially(seed: int, device: str):
# Next only keep the first 2 draft tokens same as the zero temperature
# tokens. For the remaining 3 choose some other tokens. In the
# response we will expect the first 2 tokens to be the same as the
# draft tokens and the rest as -1
# draft tokens, followed by the recovered token, and the rest as -1
draft_token_ids_to_replace = get_draft_token_ids(
batch_size, k, vocab_size, zero_temperature_token_ids)
draft_token_ids = torch.cat(
Expand All @@ -378,6 +378,8 @@ def test_accept_tokens_partially(seed: int, device: str):
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
assert torch.all(
output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2])
assert torch.all(output_token_ids[:, -3:] == -1)


Expand Down Expand Up @@ -443,14 +445,14 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_replacement_token_ids(seed: int, device: str):
def test_get_recovered_token_ids(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler's method for generating
replacement token IDs.
This test verifies that the `_replacement_token_ids` method of the
This test verifies that the `_get_recovered_token_ids` method of the
TypicalAcceptanceSampler correctly identifies the token IDs to be used
as replacements based on the target probability distribution.
as recovered token IDs based on the target probability distribution.
Specifically, it ensures that the method correctly identifies the
tokens with the highest probability for each sequence in the batch.
"""
Expand All @@ -462,10 +464,7 @@ def test_replacement_token_ids(seed: int, device: str):
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
expected_replacement_tokens = -torch.ones(
(batch_size, k), dtype=torch.long)
expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :],
dim=1)
expected_replacement_tokens = torch.argmax(target_probs, dim=-1)
actual_replacement_tokens = (
typical_acceptance_sampler._replacement_token_ids(target_probs))
typical_acceptance_sampler._get_recovered_token_ids(target_probs))
assert torch.all(expected_replacement_tokens == actual_replacement_tokens)
28 changes: 9 additions & 19 deletions vllm/model_executor/layers/typical_acceptance_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def forward(
target_probs = target_with_bonus_probs[:, :-1]
accepted = self._evaluate_accepted_tokens(target_probs,
draft_token_ids)
recovered_token_ids = self._replacement_token_ids(target_probs)
recovered_token_ids = self._get_recovered_token_ids(target_probs)
output_token_ids = self._create_output(accepted, recovered_token_ids,
draft_token_ids,
bonus_token_ids)
Expand Down Expand Up @@ -148,16 +148,10 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
accepted_mask = candidates_prob > threshold
return accepted_mask

def _replacement_token_ids(self, target_probs):
def _get_recovered_token_ids(self, target_probs):
"""
Generate one replacement token ID for each sequence based on target
probabilities. The replacement token is used as the fallback option
if typical acceptance sampling does not accept any draft tokens for
that particular sequence.
This method computes the token IDs to be replaced by selecting the
token with the highest probability for each sequence in the first
position. The rest of the output is filled with -1.
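Compute the recovered token IDs: for every position, the token with
the highest target probability. The recovered token at the first
rejected draft position replaces that draft token in the output.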
Parameters
----------
Expand All @@ -168,13 +162,9 @@ def _replacement_token_ids(self, target_probs):
Returns
-------
torch.Tensor
A tensor of shape (batch_size, k) with the replacement
token IDs. Only the first column is set, and the rest of the
columns are filled with -1.
A tensor of shape (batch_size, k) holding the recovered token
IDs, i.e. the argmax of the target probabilities at each position.
"""
max_indices = torch.argmax(target_probs[:, 0, :], dim=1)
output = -torch.ones((target_probs.shape[0], target_probs.shape[1]),
dtype=self.token_id_dtype,
device=target_probs.device)
output[:, 0] = max_indices
return output
max_indices = torch.argmax(target_probs, dim=-1)

return max_indices

0 comments on commit 49ecfe9

Please sign in to comment.