From 00d74c0a84c990b4e7d6159b40faee2a72993f96 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 8 May 2024 19:40:10 +0000 Subject: [PATCH 01/37] Test commit --- examples/llm_engine_example.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/llm_engine_example.py b/examples/llm_engine_example.py index a81c4b3e399c3..a92777d92da39 100644 --- a/examples/llm_engine_example.py +++ b/examples/llm_engine_example.py @@ -1,4 +1,5 @@ import argparse + from typing import List, Tuple from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams From 81eb96646954afe23cf6913f72bde6f014c38d0e Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Thu, 30 May 2024 07:30:18 +0000 Subject: [PATCH 02/37] Refactoring some of the logic in rejection_sampling to a base class and adding new class for acceptance sampling --- .../test_typical_acceptance_sampler.py | 220 ++++++++++++++++++ .../layers/rejection_sampler.py | 169 +------------- .../layers/spec_decode_base_sampler.py | 211 +++++++++++++++++ .../layers/typical_acceptance_sampler.py | 128 ++++++++++ 4 files changed, 567 insertions(+), 161 deletions(-) create mode 100644 tests/samplers/test_typical_acceptance_sampler.py create mode 100644 vllm/model_executor/layers/spec_decode_base_sampler.py create mode 100644 vllm/model_executor/layers/typical_acceptance_sampler.py diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py new file mode 100644 index 0000000000000..f6ce11f500451 --- /dev/null +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -0,0 +1,220 @@ +"""Tests for rejection sampling.""" +from typing import List, Tuple + +import pytest +import torch +import torch.nn.functional as F + +from vllm.model_executor.layers.typical_acceptance_sampler import ( + TypicalAcceptanceSampler) +from vllm.model_executor.utils import set_random_seed + +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1) +] + +def get_zero_temperature_prob_dist(batch_size, k, vocab_size): + # Simulate temperature 0 probability distribution for target probabilities + # and create target probabilities such that only 1 token id has + # probability 1.0 + target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + probs = torch.rand(batch_size, k, vocab_size) + _, max_indices = torch.max(probs, dim=-1) + # set the probability of the tokens with ids in max_indices to 1 and + # the rest to 0. 
+ target_probs = torch.zeros_like(probs).scatter_( + -1, max_indices.unsqueeze(-1), 1.0) + return target_probs, max_indices + +def get_draft_token_ids( + batch_size: int, k: int, vocab_size: int, token_ids_to_exclude: torch.Tensor): + # Populate draft_token_ids such that they exclude the token_ids + # with probability = 1.0 + draft_token_ids = torch.empty(batch_size, k, dtype=torch.long) + for i in range(batch_size): + for j in range(k): + # Generate a random token ID excluding max_indices[i, j] + while True: + token_id = torch.randint(0, vocab_size, (1,)).item() + if token_id != token_ids_to_exclude[i, j]: + draft_token_ids[i, j] = token_id + break + return draft_token_ids + +@pytest.mark.parametrize("k", list(range(1, 6))) +@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) +@pytest.mark.parametrize("batch_size", list(range(1, 32))) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, + device: str): + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler() + typical_acceptance_sampler.init_gpu_tensors(rank=0) + target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64) + + typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) + + +@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"]) +@pytest.mark.parametrize("which_token_ids", + ["bonus_token_ids", "draft_token_ids"]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_raises_when_vocab_oob(above_or_below_vocab_range: str, + which_token_ids: str, device: str): + k = 3 + batch_size = 5 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64) + oob_token_ids = None + if which_token_ids == "bonus_token_ids": + oob_token_ids = bonus_token_ids + elif which_token_ids == "draft_token_ids": + oob_token_ids = draft_token_ids + else: + raise AssertionError() + + if above_or_below_vocab_range == "above": + rogue_token_id = vocab_size + 1 + elif above_or_below_vocab_range == "below": + rogue_token_id = -1 + else: + raise AssertionError() + + oob_token_ids[0][0] = rogue_token_id + + with pytest.raises(AssertionError): + typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_uniform_target_distribution_accepts_all_tokens( + seed: int, disable_bonus_tokens: bool, device: str): + set_random_seed(seed) + k = 3 + batch_size = 5 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + target_probs = torch.rand(batch_size, 
k, vocab_size, dtype=torch.float32) + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + output_token_ids = typical_acceptance_sampler( + target_probs, bonus_token_ids, draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + if disable_bonus_tokens: + assert torch.all(output_token_ids[:, -1] == -1) + else: + assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze()) + + assert torch.all(output_token_ids[:, : k] == draft_token_ids) + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_temperature_zero_target_distribution( + seed: int, disable_bonus_tokens: bool, device: str): + set_random_seed(seed) + k = 3 + batch_size = 5 + vocab_size = 30_000 + torch.set_default_device(device) + + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + # Simulate temperature 0 probability distribution for target probabilities + # and create target probabilities such that only 1 token id has + # probability 1.0 + target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist( + batch_size, k, vocab_size) + # Populate draft_token_ids such that they exclude the token_ids + # with probability = 1.0 + draft_token_ids = get_draft_token_ids( + batch_size, k, vocab_size, zero_temperature_token_ids) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + output_token_ids = typical_acceptance_sampler( + target_probs, bonus_token_ids, draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[:, -1] == -1) + assert torch.all( + output_token_ids[:, 0] == zero_temperature_token_ids[:, 0]) + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_mixed_target_distribution( + seed: int, disable_bonus_tokens: bool, device: str): + set_random_seed(seed) + k = 3 + batch_size = 4 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + # For batches 0 and 2 set the distribution to an uniform distribution. For + # batches 1 and 3 set it to a temperature 0 distribution. 
+ target_probs, zero_temperature_token_ids = ( + get_zero_temperature_prob_dist(batch_size, k, vocab_size)) + draft_token_ids = get_draft_token_ids( + batch_size, k, vocab_size, zero_temperature_token_ids) + # Create target_probs such that only one token_id has probability 1.0 + uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32) + target_probs[[1, 3]] = uniform_probs + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + output_token_ids = typical_acceptance_sampler( + target_probs, bonus_token_ids, draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[[0, 2], 1:] == -1) + assert ( + torch.all(output_token_ids[[0, 2], 0] == zero_temperature_token_ids[[0, 2], 0])) + assert torch.all(output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :]) + if disable_bonus_tokens: + assert torch.all(output_token_ids[[1, 3], -1] == -1) + else: + assert torch.all(output_token_ids[[1, 3], -1] != -1) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 1f2ab7e2870ca..301159ee0f20d 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -4,9 +4,10 @@ import torch import torch.jit import torch.nn as nn +from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler -class RejectionSampler(nn.Module): +class RejectionSampler(SpecDecodeBaseSampler, nn.Module): """Apply modified rejection sampling as described in "Accelerating Large Language Model Decoding with Speculative Sampling" https://arxiv.org/pdf/2302.01318.pdf. @@ -25,37 +26,9 @@ def __init__(self, during sampling. This catches correctness issues but adds nontrivial latency. """ - super().__init__() - self._disable_bonus_tokens = disable_bonus_tokens - self._strict_mode = strict_mode - - # NOTE: A "bonus token" is accepted iff all proposal tokens are - # accepted. There is always only one possible bonus token. We store this - # value in a variable for readability. - self._num_bonus_tokens = 1 - - self.num_accepted_tokens: Optional[torch.Tensor] = None - self.num_emitted_tokens: Optional[torch.Tensor] = None - self.num_draft_tokens: int = 0 - - def init_gpu_tensors(self, rank: int) -> None: - assert self.num_accepted_tokens is None - device = f"cuda:{rank}" - self.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device=device) - self.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device=device) - - @property - def probs_dtype(self): - return torch.float32 - - @property - def token_id_dtype(self): - return torch.int64 - + SpecDecodeBaseSampler.__init__(self, disable_bonus_tokens, strict_mode) + nn.Module.__init__(self) + def forward( self, target_probs: torch.Tensor, @@ -100,15 +73,10 @@ def forward( # Only perform shape/dtype/device checking in strict mode, as it adds # overhead. 
if self._strict_mode: - self._raise_if_incorrect_shape(target_probs, bonus_token_ids, - draft_probs, draft_token_ids) - self._raise_if_incorrect_dtype(target_probs, bonus_token_ids, + self._raise_if_incorrect_input(target_probs, bonus_token_ids, draft_probs, draft_token_ids) - self._raise_if_inconsistent_device(target_probs, bonus_token_ids, - draft_probs, draft_token_ids) - self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], - bonus_token_ids, - draft_token_ids) + + print('target_probs ' + str(target_probs)) accepted, recovered_token_ids = self._batch_modified_rejection_sampling( target_probs, @@ -272,127 +240,6 @@ def _smallest_positive_value(self) -> float: """ return torch.finfo(self.probs_dtype).tiny - def _create_output( - self, - accepted: torch.Tensor, # [batch_size, k] - recovered_token_ids: torch.Tensor, # [batch_size, k] - draft_token_ids: torch.Tensor, # [batch_size, k] - bonus_token_ids: torch.Tensor, # [batch_size] - ) -> torch.Tensor: - """Format output. Returns a matrix of token ids. When - a token is rejected via rejection sampling, all subsequent - token ids are set to -1 for the sequence. - - shape = [batch_size, k + num_bonus_tokens] - """ - bonus_token_ids = bonus_token_ids.squeeze() - batch_size, k = recovered_token_ids.shape - - # Determine the index of the first False value for each row. - limits = (accepted == 0).max(1).indices - limits[~(accepted == 0).any(1)] = k - - # Create masks using the indices. - indices = torch.arange(k, device=accepted.device).unsqueeze(0) - accepted_mask = indices < limits.unsqueeze(1) - after_false_mask = indices == limits.unsqueeze(1) - - # Create an extended output tensor - output_with_bonus_tokens = -torch.ones( - (batch_size, k + self._num_bonus_tokens), - dtype=self.token_id_dtype, - device=accepted.device) - output = output_with_bonus_tokens[:, :k] - - # Fill in the first k columns of the output tensor using masks and data - # tensors. - output[:, :k] = torch.where(accepted_mask, draft_token_ids, - -torch.ones_like(draft_token_ids)) - - # Fill the last column. - # We check output directly as accepted may have True values inconsistent - # with causal acceptance. - output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, - bonus_token_ids, -1) - - # We disable bonus tokens because it causes corrupt KV cache for - # proposal methods that require KV cache. We can fix it by "prefilling" - # the bonus token in the proposer. The following issue tracks the fix. - # https://github.com/vllm-project/vllm/issues/4212 - if self._disable_bonus_tokens: - output_with_bonus_tokens[:, -1] = -1 - - # Fill the recovered token ids. 
- output.mul_(~after_false_mask).add_( - recovered_token_ids.mul(after_false_mask)) - - self.num_accepted_tokens += accepted.sum() - self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() - self.num_draft_tokens += batch_size * k - - return output_with_bonus_tokens - - def _raise_if_incorrect_shape( - self, - target_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - (target_batch_size, num_target_probs, - target_vocab_size) = target_probs.shape - bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape - draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape - draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape - - assert draft_batch_size == target_batch_size - assert num_draft_probs == num_target_probs - assert (draft_vocab_size == target_vocab_size - ), f"{draft_vocab_size=} {target_vocab_size=}" - - assert draft_token_ids_batch_size == draft_batch_size - assert num_draft_token_ids == num_draft_probs - - assert bonus_batch_size == target_batch_size - assert num_bonus_tokens == self._num_bonus_tokens - - def _raise_if_incorrect_dtype( - self, - target_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - assert all(probs.dtype == self.probs_dtype - for probs in [target_probs, draft_probs]) - assert all(token_ids.dtype == self.token_id_dtype - for token_ids in [bonus_token_ids, draft_token_ids]) - - def _raise_if_inconsistent_device( - self, - target_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - devices = [ - t.device for t in - [target_probs, bonus_token_ids, draft_probs, draft_token_ids] - ] - assert all([devices[0] == device for device in devices]) - - def _raise_if_out_of_bounds_vocab( - self, - vocab_size: int, - bonus_token_ids: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - assert torch.all(bonus_token_ids < vocab_size) - assert torch.all(bonus_token_ids >= 0) - assert torch.all(draft_token_ids < vocab_size) - assert torch.all(draft_token_ids >= 0) - - # torch.multinomial forces a GPU<->CPU sync. # Therefore, we use an optimized implementation instead that skips the sync. # Note that we always sample with replacement. diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py new file mode 100644 index 0000000000000..64a7f6474644a --- /dev/null +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -0,0 +1,211 @@ +from abc import ABC, abstractmethod +from functools import cached_property +from typing import Optional, Tuple + +import torch +import torch.jit +import torch.nn as nn + + +class SpecDecodeBaseSampler(ABC): + """Apply modified rejection sampling as described in "Accelerating Large + Language Model Decoding with Speculative Sampling" + https://arxiv.org/pdf/2302.01318.pdf. + """ + + def __init__(self, + disable_bonus_tokens: bool = True, + strict_mode: bool = False): + """Create a rejection sampler. + + Args: + disable_bonus_tokens: Whether or not to disable the bonus token. + Require when bonus tokens will cause corrupt KV cache for + proposal methods that require KV cache. + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. 
+ """ + super().__init__() + self._disable_bonus_tokens = disable_bonus_tokens + self._strict_mode = strict_mode + + # NOTE: A "bonus token" is accepted iff all proposal tokens are + # accepted. There is always only one possible bonus token. We store this + # value in a variable for readability. + self._num_bonus_tokens = 1 + + self.num_accepted_tokens: Optional[torch.Tensor] = None + self.num_emitted_tokens: Optional[torch.Tensor] = None + self.num_draft_tokens: int = 0 + + def init_gpu_tensors(self, rank: int) -> None: + assert self.num_accepted_tokens is None + device = f"cuda:{rank}" + self.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + self.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + + @property + def probs_dtype(self): + return torch.float32 + + @property + def token_id_dtype(self): + return torch.int64 + + def _create_output( + self, + accepted: torch.Tensor, # [batch_size, k] + recovered_token_ids: torch.Tensor, # [batch_size, k] + draft_token_ids: torch.Tensor, # [batch_size, k] + bonus_token_ids: torch.Tensor, # [batch_size] + ) -> torch.Tensor: + """Format output. Returns a matrix of token ids. When + a token is rejected via rejection sampling, all subsequent + token ids are set to -1 for the sequence. + + shape = [batch_size, k + num_bonus_tokens] + """ + batch_size, k = recovered_token_ids.shape + bonus_token_ids = bonus_token_ids.squeeze() + # Determine the index of the first False value for each row. + limits = (accepted == 0).max(1).indices + limits[~(accepted == 0).any(1)] = k + + # Create masks using the indices. + indices = torch.arange(k, device=accepted.device).unsqueeze(0) + accepted_mask = indices < limits.unsqueeze(1) + after_false_mask = indices == limits.unsqueeze(1) + + # Create an extended output tensor + output_with_bonus_tokens = -torch.ones( + (batch_size, k + self._num_bonus_tokens), + dtype=self.token_id_dtype, + device=accepted.device) + output = output_with_bonus_tokens[:, :k] + + # Fill in the first k columns of the output tensor using masks and data + # tensors. + output[:, :k] = torch.where(accepted_mask, draft_token_ids, + -torch.ones_like(draft_token_ids)) + + # Fill the last column. + # We check output directly as accepted may have True values inconsistent + # with causal acceptance. + output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, + bonus_token_ids, -1) + + # We disable bonus tokens because it causes corrupt KV cache for + # proposal methods that require KV cache. We can fix it by "prefilling" + # the bonus token in the proposer. The following issue tracks the fix. + # https://github.com/vllm-project/vllm/issues/4212 + if self._disable_bonus_tokens: + output_with_bonus_tokens[:, -1] = -1 + + # Fill the recovered token ids. 
+ output.mul_(~after_false_mask).add_( + recovered_token_ids.mul(after_false_mask)) + + self.num_accepted_tokens += accepted.sum() + self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() + self.num_draft_tokens += batch_size * k + + return output_with_bonus_tokens + + def _raise_if_incorrect_input( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + self._raise_if_incorrect_shape( + target_probs, draft_token_ids, bonus_token_ids, draft_probs) + self._raise_if_incorrect_dtype( + target_probs, draft_token_ids, bonus_token_ids, draft_probs) + self._raise_if_inconsistent_device( + target_probs, draft_token_ids, bonus_token_ids, draft_probs) + self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], + draft_token_ids, + bonus_token_ids) + + + def _raise_if_incorrect_shape( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + (target_batch_size, num_target_probs, + target_vocab_size) = target_probs.shape + + # validate the shape of draft token ids. + draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape + assert draft_token_ids_batch_size == target_batch_size + assert num_draft_token_ids == num_target_probs + + # validate the shape of bonus token ids + bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape + assert bonus_batch_size == target_batch_size + assert num_bonus_tokens == self._num_bonus_tokens + + # validate the shape of draft probs if it is set + if draft_probs is not None: + draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape + assert draft_batch_size == target_batch_size + assert num_draft_probs == num_target_probs + assert (draft_vocab_size == target_vocab_size + ), f"{draft_vocab_size=} {target_vocab_size=}" + + def _raise_if_incorrect_dtype( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + assert target_probs.dtype == self.probs_dtype + assert draft_token_ids.dtype == self.token_id_dtype + assert bonus_token_ids.dtype == self.token_id_dtype + if draft_probs is not None: + assert draft_probs.dtype == self.probs_dtype + + def _raise_if_inconsistent_device( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + devices = [ + t.device for t in + [target_probs, bonus_token_ids, draft_probs, draft_token_ids] if t is not None + ] + assert all([devices[0] == device for device in devices]) + + + def _raise_if_out_of_bounds_vocab( + self, + vocab_size: int, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + ) -> None: + assert torch.all(bonus_token_ids < vocab_size) + assert torch.all(bonus_token_ids >= 0) + assert torch.all(draft_token_ids < vocab_size) + assert torch.all(draft_token_ids >= 0) + +def get_rejection_sampler(disable_bonus_tokens: bool = True, strict_mode: bool = False): + from vllm.model_executor.layers.rejection_sampler import RejectionSampler + return RejectionSampler(disable_bonus_tokens, strict_mode) + + +def get_typical_acceptance_sampler(disable_bonus_tokens: bool = True, strict_mode: bool = False): + from vllm.model_executor.layers.typical_acceptance_sampler import ( + TypicalAcceptanceSampler) + return TypicalAcceptanceSampler(disable_bonus_tokens, 
strict_mode) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py new file mode 100644 index 0000000000000..b4aac132e10a7 --- /dev/null +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -0,0 +1,128 @@ +from functools import cached_property +from typing import Optional, Tuple + +import torch +import torch.jit +import torch.nn as nn + +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler) + + +class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module): + """Apply modified rejection sampling as described in "Accelerating Large + Language Model Decoding with Speculative Sampling" + https://arxiv.org/pdf/2302.01318.pdf. + """ + + def __init__(self, disable_bonus_tokens: bool = False, strict_mode: bool = False): + """Create a rejection sampler. + + Args: + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + """ + super().__init__() + SpecDecodeBaseSampler.__init__( + self, disable_bonus_tokens=disable_bonus_tokens, strict_mode=strict_mode) + nn.Module.__init__(self) + + def forward( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> torch.Tensor: + """Sample token ids using rejection sampling. This accepts or rejects + tokens proposed by the draft model using the probability of each token + according to the draft and target models. + + In the worst case where all draft tokens are rejected, it is guaranteed + one correct token will be emitted. + + In the case where all draft tokens are accepted, a bonus token will be + accepted as its cheap to have the target model score this speculative + sequence. + + Args: + target_probs: The probability distribution over token ids given + context according to the target model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + bonus_token_ids: The "bonus" token ids that are accepted iff all + speculative tokens in a sequence are accepted. + shape = [batch_size, num_bonus_tokens] + + draft_probs: The probability distribution over token ids given + context according to the draft model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + draft_token_ids: The token ids that were sampled from the draft + probabilities. + shape = [batch_size, num_speculative_tokens] + + Returns: + output_token_ids: The token ids sampled via rejection sampling, + or -1 if unable to sample a token because the previous token + was rejected. + shape = [batch_size, num_speculative_tokens + num_bonus_tokens] + """ + # Only perform shape/dtype/device checking in strict mode, as it adds + # overhead. 
+ if self._strict_mode: + self._raise_if_incorrect_input( + target_probs, draft_token_ids, bonus_token_ids) + accepted = self._evaluate_posterior(target_probs, draft_token_ids) + recovered_token_ids = self._replacement_token_ids(target_probs) + output_token_ids = self._create_output( + accepted, + recovered_token_ids, + draft_token_ids, bonus_token_ids + ) + print('----test input----') + print('target_probs ' + str(target_probs)) + print('draft_token_ids ' + str(draft_token_ids)) + print('recovered_token_ids ' + str(recovered_token_ids)) + print(output_token_ids) + return output_token_ids + + def _evaluate_posterior( + self, target_probs, draft_token_ids, + posterior_threshold=0.3, posterior_alpha = 0.09): + + """Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate. + Depending on the temperature value, the function either uses greedy decoding or evaluates posterior + probabilities to select the best candidate. + + Args: + - logits (torch.Tensor): Predicted logits of shape (batch_size, sequence_length, vocab_size). + - candidates (torch.Tensor): Candidate token sequences. + - temperature (float): Softmax temperature for probability scaling. A value of 0 indicates greedy decoding. + - posterior_threshold (float): Threshold for posterior probability. + - posterior_alpha (float): Scaling factor for the threshold. + - top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8. + - sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'. + - fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. Defaults to False. + Returns: + - best_candidate (torch.Tensor): Index of the chosen best candidate. + - accept_length (int): Length of the accepted candidate sequence. + """ + candidates_prob = torch.gather( + target_probs, dim=-1, index=draft_token_ids.unsqueeze(-1) + ).squeeze(-1) + posterior_entropy = -torch.sum( + target_probs * torch.log(target_probs + 1e-5), dim=-1 + ) # torch.sum(torch.log(*)) is faster than torch.prod + threshold = torch.minimum( + torch.ones_like(posterior_entropy) * posterior_threshold, + torch.exp(-posterior_entropy) * posterior_alpha, + ) + posterior_mask = candidates_prob > threshold + return posterior_mask + + def _replacement_token_ids(self, target_probs): + max_indices = torch.argmax(target_probs[:, 0, :], dim=1) + output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), dtype=self.token_id_dtype) + output[:, 0] = max_indices + return output From b09a20f492dfdf7c33db54a43df75ecfa59580ac Mon Sep 17 00:00:00 2001 From: sroy745 Date: Fri, 31 May 2024 06:04:13 +0000 Subject: [PATCH 03/37] Adding comments and some new tests. 
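
The tests added here exercise the typical acceptance rule implemented in
_evaluate_posterior, which follows section 3.3.1 of the Medusa paper: a draft
token is kept when the probability the target model assigns to it exceeds
min(epsilon, alpha * exp(-H)), where H is the entropy of the target
distribution at that position. A minimal standalone sketch of that check
(hypothetical sizes; 0.3 and 0.09 mirror the defaults in _evaluate_posterior,
and nothing below depends on the vLLM classes):

    import torch

    batch_size, k, vocab_size = 2, 3, 8
    posterior_threshold, posterior_alpha = 0.3, 0.09
    target_probs = torch.softmax(torch.rand(batch_size, k, vocab_size), dim=-1)
    draft_token_ids = torch.randint(0, vocab_size, (batch_size, k))

    # Probability the target model assigns to each proposed draft token.
    candidates_prob = torch.gather(
        target_probs, dim=-1, index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
    # Entropy of the target distribution at each position.
    posterior_entropy = -torch.sum(
        target_probs * torch.log(target_probs + 1e-5), dim=-1)
    # Dynamic per-position threshold: low-entropy (confident) positions demand
    # a higher target probability before a draft token is accepted.
    threshold = torch.minimum(
        torch.ones_like(posterior_entropy) * posterior_threshold,
        torch.exp(-posterior_entropy) * posterior_alpha)
    accepted = candidates_prob > threshold  # boolean mask, [batch_size, k]
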
--- tests/samplers/test_rejection_sampler.py | 2 +- .../test_typical_acceptance_sampler.py | 128 +++++++++++++----- .../layers/rejection_sampler.py | 9 +- .../layers/spec_decode_base_sampler.py | 67 +++++---- .../layers/typical_acceptance_sampler.py | 118 +++++++++------- 5 files changed, 202 insertions(+), 122 deletions(-) diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index 00a2379502e6d..0cc4f73aaaf37 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -9,7 +9,7 @@ from vllm.model_executor.utils import set_random_seed CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 1) ] diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index f6ce11f500451..db1f6265bff1e 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -9,9 +9,8 @@ TypicalAcceptanceSampler) from vllm.model_executor.utils import set_random_seed -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1)] + def get_zero_temperature_prob_dist(batch_size, k, vocab_size): # Simulate temperature 0 probability distribution for target probabilities @@ -19,15 +18,16 @@ def get_zero_temperature_prob_dist(batch_size, k, vocab_size): # probability 1.0 target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) probs = torch.rand(batch_size, k, vocab_size) - _, max_indices = torch.max(probs, dim=-1) + _, zero_temperature_token_ids = torch.max(probs, dim=-1) # set the probability of the tokens with ids in max_indices to 1 and # the rest to 0. 
target_probs = torch.zeros_like(probs).scatter_( - -1, max_indices.unsqueeze(-1), 1.0) - return target_probs, max_indices + -1, zero_temperature_token_ids.unsqueeze(-1), 1.0) + return target_probs, zero_temperature_token_ids + -def get_draft_token_ids( - batch_size: int, k: int, vocab_size: int, token_ids_to_exclude: torch.Tensor): +def get_draft_token_ids(batch_size: int, k: int, vocab_size: int, + token_ids_to_exclude: torch.Tensor): # Populate draft_token_ids such that they exclude the token_ids # with probability = 1.0 draft_token_ids = torch.empty(batch_size, k, dtype=torch.long) @@ -35,12 +35,13 @@ def get_draft_token_ids( for j in range(k): # Generate a random token ID excluding max_indices[i, j] while True: - token_id = torch.randint(0, vocab_size, (1,)).item() + token_id = torch.randint(0, vocab_size, (1, )).item() if token_id != token_ids_to_exclude[i, j]: draft_token_ids[i, j] = token_id break return draft_token_ids + @pytest.mark.parametrize("k", list(range(1, 6))) @pytest.mark.parametrize("vocab_size", [30_000, 50_000]) @pytest.mark.parametrize("batch_size", list(range(1, 32))) @@ -104,7 +105,8 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, oob_token_ids[0][0] = rogue_token_id with pytest.raises(AssertionError): - typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) + typical_acceptance_sampler(target_probs, bonus_token_ids, + draft_token_ids) @pytest.mark.parametrize("seed", list(range(10))) @@ -112,7 +114,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_uniform_target_distribution_accepts_all_tokens( - seed: int, disable_bonus_tokens: bool, device: str): + seed: int, disable_bonus_tokens: bool, device: str): set_random_seed(seed) k = 3 batch_size = 5 @@ -130,24 +132,26 @@ def test_uniform_target_distribution_accepts_all_tokens( high=vocab_size, size=(batch_size, 1), dtype=torch.int64) - output_token_ids = typical_acceptance_sampler( - target_probs, bonus_token_ids, draft_token_ids) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) if disable_bonus_tokens: assert torch.all(output_token_ids[:, -1] == -1) else: assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze()) - - assert torch.all(output_token_ids[:, : k] == draft_token_ids) + + assert torch.all(output_token_ids[:, :k] == draft_token_ids) @pytest.mark.parametrize("seed", list(range(10))) @pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_temperature_zero_target_distribution( - seed: int, disable_bonus_tokens: bool, device: str): +def test_temperature_zero_target_distribution(seed: int, + disable_bonus_tokens: bool, + device: str): set_random_seed(seed) k = 3 batch_size = 5 @@ -164,27 +168,28 @@ def test_temperature_zero_target_distribution( batch_size, k, vocab_size) # Populate draft_token_ids such that they exclude the token_ids # with probability = 1.0 - draft_token_ids = get_draft_token_ids( - batch_size, k, vocab_size, zero_temperature_token_ids) + draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, + zero_temperature_token_ids) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64) - output_token_ids = typical_acceptance_sampler( - target_probs, bonus_token_ids, 
draft_token_ids) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, -1] == -1) - assert torch.all( - output_token_ids[:, 0] == zero_temperature_token_ids[:, 0]) + assert torch.all(output_token_ids[:, 0] == zero_temperature_token_ids[:, + 0]) @pytest.mark.parametrize("seed", list(range(10))) @pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_mixed_target_distribution( - seed: int, disable_bonus_tokens: bool, device: str): +def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, + device: str): set_random_seed(seed) k = 3 batch_size = 4 @@ -193,12 +198,12 @@ def test_mixed_target_distribution( typical_acceptance_sampler = TypicalAcceptanceSampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) - # For batches 0 and 2 set the distribution to an uniform distribution. For + # For batches 0 and 2 set the distribution to an uniform distribution. For # batches 1 and 3 set it to a temperature 0 distribution. - target_probs, zero_temperature_token_ids = ( - get_zero_temperature_prob_dist(batch_size, k, vocab_size)) - draft_token_ids = get_draft_token_ids( - batch_size, k, vocab_size, zero_temperature_token_ids) + target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( + batch_size, k, vocab_size)) + draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, + zero_temperature_token_ids) # Create target_probs such that only one token_id has probability 1.0 uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32) target_probs[[1, 3]] = uniform_probs @@ -206,15 +211,68 @@ def test_mixed_target_distribution( high=vocab_size, size=(batch_size, 1), dtype=torch.int64) - output_token_ids = typical_acceptance_sampler( - target_probs, bonus_token_ids, draft_token_ids) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[[0, 2], 1:] == -1) - assert ( - torch.all(output_token_ids[[0, 2], 0] == zero_temperature_token_ids[[0, 2], 0])) - assert torch.all(output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :]) + assert (torch.all(output_token_ids[[0, 2], + 0] == zero_temperature_token_ids[[0, 2], + 0])) + assert torch.all( + output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :]) if disable_bonus_tokens: assert torch.all(output_token_ids[[1, 3], -1] == -1) else: assert torch.all(output_token_ids[[1, 3], -1] != -1) + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, + device: str): + set_random_seed(seed) + k = 5 + batch_size = 1 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + # For batches 0 and 2 set the distribution to an uniform distribution. For + # batches 1 and 3 set it to a temperature 0 distribution. 
+ target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( + batch_size, k, vocab_size)) + draft_token_ids = zero_temperature_token_ids + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids) + if disable_bonus_tokens: + assert torch.all(output_token_ids[:, -1] == -1) + else: + assert torch.all(output_token_ids[:, -1] == bonus_token_ids) + # Next only keep the first 2 draft tokens same as the zero temperature + # tokens. For the remaining 3 choose some other tokens. In the + # response we will expect the first 2 tokens to be the same as the + # draft tokens and the rest as -1 + draft_token_ids_to_replace = get_draft_token_ids( + batch_size, k, vocab_size, zero_temperature_token_ids) + draft_token_ids = torch.cat( + (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2]) + assert torch.all(output_token_ids[:, -3:] == -1) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 301159ee0f20d..a4d51e948f5b0 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -23,12 +23,12 @@ def __init__(self, Require when bonus tokens will cause corrupt KV cache for proposal methods that require KV cache. strict_mode: Whether or not to perform shape/device/dtype checks - during sampling. This catches correctness issues but adds - nontrivial latency. + during sampling. This catches correctness issues but adds + nontrivial latency. """ SpecDecodeBaseSampler.__init__(self, disable_bonus_tokens, strict_mode) nn.Module.__init__(self) - + def forward( self, target_probs: torch.Tensor, @@ -76,8 +76,6 @@ def forward( self._raise_if_incorrect_input(target_probs, bonus_token_ids, draft_probs, draft_token_ids) - print('target_probs ' + str(target_probs)) - accepted, recovered_token_ids = self._batch_modified_rejection_sampling( target_probs, draft_probs, @@ -240,6 +238,7 @@ def _smallest_positive_value(self) -> float: """ return torch.finfo(self.probs_dtype).tiny + # torch.multinomial forces a GPU<->CPU sync. # Therefore, we use an optimized implementation instead that skips the sync. # Note that we always sample with replacement. diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 64a7f6474644a..8043358119496 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -8,15 +8,14 @@ class SpecDecodeBaseSampler(ABC): - """Apply modified rejection sampling as described in "Accelerating Large - Language Model Decoding with Speculative Sampling" - https://arxiv.org/pdf/2302.01318.pdf. + """Base class for samplers used for Speculative Decoding verification + step. """ def __init__(self, disable_bonus_tokens: bool = True, strict_mode: bool = False): - """Create a rejection sampler. + """Base class constructor. 
Args: disable_bonus_tokens: Whether or not to disable the bonus token. @@ -60,17 +59,29 @@ def token_id_dtype(self): def _create_output( self, accepted: torch.Tensor, # [batch_size, k] - recovered_token_ids: torch.Tensor, # [batch_size, k] + substitute_token_ids: torch.Tensor, # [batch_size, k] draft_token_ids: torch.Tensor, # [batch_size, k] bonus_token_ids: torch.Tensor, # [batch_size] ) -> torch.Tensor: """Format output. Returns a matrix of token ids. When - a token is rejected via rejection sampling, all subsequent - token ids are set to -1 for the sequence. + a token is rejected via sampling, all subsequent token ids are + set to -1 for the sequence. - shape = [batch_size, k + num_bonus_tokens] + Args: + accepted: A boolean tensor indicating if the corresponding + draft token in draft_token_ids should be accepted or not. + substitute_token_ids: A tensor of token_ids that can be used + as substitutes for the draft token ids if the proposed token + is rejected. + draft_token_ids: A tensor of token ids speculated by the + draft model. + bonus_token_ids: Token ids to use as the bonus token if + all the draft tokens are accepted. + Returns: + A tensor containing the accepted token ids. The shape of the + tensor is [batch_size, k + num_bonus_tokens] """ - batch_size, k = recovered_token_ids.shape + batch_size, k = substitute_token_ids.shape bonus_token_ids = bonus_token_ids.squeeze() # Determine the index of the first False value for each row. limits = (accepted == 0).max(1).indices @@ -108,7 +119,7 @@ def _create_output( # Fill the recovered token ids. output.mul_(~after_false_mask).add_( - recovered_token_ids.mul(after_false_mask)) + substitute_token_ids.mul(after_false_mask)) self.num_accepted_tokens += accepted.sum() self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() @@ -123,16 +134,14 @@ def _raise_if_incorrect_input( bonus_token_ids: torch.Tensor, draft_probs: Optional[torch.Tensor] = None, ) -> None: - self._raise_if_incorrect_shape( - target_probs, draft_token_ids, bonus_token_ids, draft_probs) - self._raise_if_incorrect_dtype( - target_probs, draft_token_ids, bonus_token_ids, draft_probs) - self._raise_if_inconsistent_device( - target_probs, draft_token_ids, bonus_token_ids, draft_probs) + self._raise_if_incorrect_shape(target_probs, draft_token_ids, + bonus_token_ids, draft_probs) + self._raise_if_incorrect_dtype(target_probs, draft_token_ids, + bonus_token_ids, draft_probs) + self._raise_if_inconsistent_device(target_probs, draft_token_ids, + bonus_token_ids, draft_probs) self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], - draft_token_ids, - bonus_token_ids) - + draft_token_ids, bonus_token_ids) def _raise_if_incorrect_shape( self, @@ -143,12 +152,12 @@ def _raise_if_incorrect_shape( ) -> None: (target_batch_size, num_target_probs, target_vocab_size) = target_probs.shape - + # validate the shape of draft token ids. 
draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape assert draft_token_ids_batch_size == target_batch_size assert num_draft_token_ids == num_target_probs - + # validate the shape of bonus token ids bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape assert bonus_batch_size == target_batch_size @@ -183,12 +192,12 @@ def _raise_if_inconsistent_device( draft_probs: Optional[torch.Tensor] = None, ) -> None: devices = [ - t.device for t in - [target_probs, bonus_token_ids, draft_probs, draft_token_ids] if t is not None + t.device for t in + [target_probs, bonus_token_ids, draft_probs, draft_token_ids] + if t is not None ] assert all([devices[0] == device for device in devices]) - def _raise_if_out_of_bounds_vocab( self, vocab_size: int, @@ -199,13 +208,3 @@ def _raise_if_out_of_bounds_vocab( assert torch.all(bonus_token_ids >= 0) assert torch.all(draft_token_ids < vocab_size) assert torch.all(draft_token_ids >= 0) - -def get_rejection_sampler(disable_bonus_tokens: bool = True, strict_mode: bool = False): - from vllm.model_executor.layers.rejection_sampler import RejectionSampler - return RejectionSampler(disable_bonus_tokens, strict_mode) - - -def get_typical_acceptance_sampler(disable_bonus_tokens: bool = True, strict_mode: bool = False): - from vllm.model_executor.layers.typical_acceptance_sampler import ( - TypicalAcceptanceSampler) - return TypicalAcceptanceSampler(disable_bonus_tokens, strict_mode) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index b4aac132e10a7..c914a52e2fb06 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -10,24 +10,40 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module): - """Apply modified rejection sampling as described in "Accelerating Large - Language Model Decoding with Speculative Sampling" - https://arxiv.org/pdf/2302.01318.pdf. + """Apply typical acceptance sampling as described in section 3.3.1 in + "MEDUSA: Simple LLM Inference Acceleration Framework with + Multiple Decoding Heads" + https://arxiv.org/pdf/2401.10774 """ - def __init__(self, disable_bonus_tokens: bool = False, strict_mode: bool = False): - """Create a rejection sampler. + def __init__( + self, + disable_bonus_tokens: bool = False, + strict_mode: bool = False, + posterior_threshold: float = 0.3, + ): + """Create a Typical Acceptance Sampler. Args: + disable_bonus_tokens: Whether or not to disable the bonus token. + Require when bonus tokens will cause corrupt KV cache for + proposal methods that require KV cache. strict_mode: Whether or not to perform shape/device/dtype checks - during sampling. This catches correctness issues but adds - nontrivial latency. + during sampling. This catches correctness issues but adds + nontrivial latency. + posterior_threshold : A threshold value that sets a lower bound + for the posterior probability. Default is 0.3. + + + """ super().__init__() SpecDecodeBaseSampler.__init__( - self, disable_bonus_tokens=disable_bonus_tokens, strict_mode=strict_mode) + self, + disable_bonus_tokens=disable_bonus_tokens, + strict_mode=strict_mode) nn.Module.__init__(self) - + def forward( self, target_probs: torch.Tensor, @@ -54,10 +70,6 @@ def forward( speculative tokens in a sequence are accepted. shape = [batch_size, num_bonus_tokens] - draft_probs: The probability distribution over token ids given - context according to the draft model. 
- shape = [batch_size, num_speculative_tokens, vocab_size] - draft_token_ids: The token ids that were sampled from the draft probabilities. shape = [batch_size, num_speculative_tokens] @@ -71,15 +83,14 @@ def forward( # Only perform shape/dtype/device checking in strict mode, as it adds # overhead. if self._strict_mode: - self._raise_if_incorrect_input( - target_probs, draft_token_ids, bonus_token_ids) - accepted = self._evaluate_posterior(target_probs, draft_token_ids) + self._raise_if_incorrect_input(target_probs, draft_token_ids, + bonus_token_ids) + accepted = self._evaluate_accepted_tokens(target_probs, + draft_token_ids) recovered_token_ids = self._replacement_token_ids(target_probs) - output_token_ids = self._create_output( - accepted, - recovered_token_ids, - draft_token_ids, bonus_token_ids - ) + output_token_ids = self._create_output(accepted, recovered_token_ids, + draft_token_ids, + bonus_token_ids) print('----test input----') print('target_probs ' + str(target_probs)) print('draft_token_ids ' + str(draft_token_ids)) @@ -87,42 +98,55 @@ def forward( print(output_token_ids) return output_token_ids - def _evaluate_posterior( - self, target_probs, draft_token_ids, - posterior_threshold=0.3, posterior_alpha = 0.09): + def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): + r""" + Evaluates and returns a mask of accepted tokens based on the posterior probabilities. + + Parameters: + ---------- + target_probs : torch.Tensor + A tensor of shape (batch_size, k, vocab_size) representing the probabilities + of each token in the vocabulary for each position in the proposed sequence. + This is the distribution generated by the target model. + draft_token_ids : torch.Tensor + A tensor of shape (batch_size, k) representing the proposed token ids. + + A draft token_id x_{n+k} is accepted if it satisifies the following condition - """Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate. - Depending on the temperature value, the function either uses greedy decoding or evaluates posterior - probabilities to select the best candidate. + .. math:: + p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > + \min \left( \epsilon, \delta \exp \left( -H(p_{\text{original}}(\cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) + + where :math:`p_{\text{original}}` corresponds to target_probs + + This method computes the posterior probabilities for the given draft token ids based on the + provided target probabilities. It calculates the entropy of the posterior distribution and + determines a dynamic threshold for each token position using the provided posterior threshold + and posterior alpha values. The method then returns a boolean mask indicating which tokens + have posterior probabilities exceeding the threshold. - Args: - - logits (torch.Tensor): Predicted logits of shape (batch_size, sequence_length, vocab_size). - - candidates (torch.Tensor): Candidate token sequences. - - temperature (float): Softmax temperature for probability scaling. A value of 0 indicates greedy decoding. - - posterior_threshold (float): Threshold for posterior probability. - - posterior_alpha (float): Scaling factor for the threshold. - - top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8. - - sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'. - - fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. 
Defaults to False. Returns: - - best_candidate (torch.Tensor): Index of the chosen best candidate. - - accept_length (int): Length of the accepted candidate sequence. + ------- + torch.Tensor + A boolean tensor of shape (batch_size, k) where each element indicates whether the + corresponding draft token has been accepted or rejected. + """ candidates_prob = torch.gather( - target_probs, dim=-1, index=draft_token_ids.unsqueeze(-1) - ).squeeze(-1) + target_probs, dim=-1, + index=draft_token_ids.unsqueeze(-1)).squeeze(-1) posterior_entropy = -torch.sum( - target_probs * torch.log(target_probs + 1e-5), dim=-1 - ) # torch.sum(torch.log(*)) is faster than torch.prod + target_probs * torch.log(target_probs + 1e-5), dim=-1) threshold = torch.minimum( - torch.ones_like(posterior_entropy) * posterior_threshold, - torch.exp(-posterior_entropy) * posterior_alpha, + torch.ones_like(posterior_entropy) * self._posterior_threshold, + torch.exp(-posterior_entropy) * self._posterior_alpha, ) - posterior_mask = candidates_prob > threshold - return posterior_mask + accepted_mask = candidates_prob > threshold + return accepted_mask def _replacement_token_ids(self, target_probs): max_indices = torch.argmax(target_probs[:, 0, :], dim=1) - output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), dtype=self.token_id_dtype) + output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), + dtype=self.token_id_dtype) output[:, 0] = max_indices return output From 9340244091018cdbfca3193eba075a20a31d4586 Mon Sep 17 00:00:00 2001 From: sroy745 Date: Fri, 31 May 2024 06:44:35 +0000 Subject: [PATCH 04/37] Formatting and comments. --- .../layers/spec_decode_base_sampler.py | 10 ++-- .../layers/typical_acceptance_sampler.py | 49 ++++++++----------- 2 files changed, 25 insertions(+), 34 deletions(-) diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 8043358119496..95f8e44f2f2de 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -1,13 +1,10 @@ -from abc import ABC, abstractmethod -from functools import cached_property -from typing import Optional, Tuple +from typing import Optional import torch import torch.jit -import torch.nn as nn -class SpecDecodeBaseSampler(ABC): +class SpecDecodeBaseSampler(): """Base class for samplers used for Speculative Decoding verification step. """ @@ -165,7 +162,8 @@ def _raise_if_incorrect_shape( # validate the shape of draft probs if it is set if draft_probs is not None: - draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape + (draft_batch_size, num_draft_probs, + draft_vocab_size) = draft_probs.shape assert draft_batch_size == target_batch_size assert num_draft_probs == num_target_probs assert (draft_vocab_size == target_vocab_size diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index c914a52e2fb06..64c6f36bc3322 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -1,6 +1,3 @@ -from functools import cached_property -from typing import Optional, Tuple - import torch import torch.jit import torch.nn as nn @@ -33,9 +30,8 @@ def __init__( nontrivial latency. posterior_threshold : A threshold value that sets a lower bound for the posterior probability. Default is 0.3. 
- - - + posterior_alpha : A threshold value that sets a lower bound + for the posterior probability. Default is 0.3. """ super().__init__() SpecDecodeBaseSampler.__init__( @@ -91,45 +87,42 @@ def forward( output_token_ids = self._create_output(accepted, recovered_token_ids, draft_token_ids, bonus_token_ids) - print('----test input----') - print('target_probs ' + str(target_probs)) - print('draft_token_ids ' + str(draft_token_ids)) - print('recovered_token_ids ' + str(recovered_token_ids)) - print(output_token_ids) return output_token_ids def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): r""" - Evaluates and returns a mask of accepted tokens based on the posterior probabilities. + Evaluates and returns a mask of accepted tokens based on the + posterior probabilities. Parameters: ---------- target_probs : torch.Tensor - A tensor of shape (batch_size, k, vocab_size) representing the probabilities - of each token in the vocabulary for each position in the proposed sequence. - This is the distribution generated by the target model. + A tensor of shape (batch_size, k, vocab_size) representing + the probabilities of each token in the vocabulary for each + position in the proposed sequence. This is the distribution + generated by the target model. draft_token_ids : torch.Tensor - A tensor of shape (batch_size, k) representing the proposed token ids. + A tensor of shape (batch_size, k) representing the proposed + token ids. - A draft token_id x_{n+k} is accepted if it satisifies the following condition - + A draft token_id x_{n+k} is accepted if it satisifies the + following condition + .. math:: p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > - \min \left( \epsilon, \delta \exp \left( -H(p_{\text{original}}(\cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) - - where :math:`p_{\text{original}}` corresponds to target_probs + \min \left( \epsilon, \delta * \exp \left( + -H(p_{\text{original}}(\cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) - This method computes the posterior probabilities for the given draft token ids based on the - provided target probabilities. It calculates the entropy of the posterior distribution and - determines a dynamic threshold for each token position using the provided posterior threshold - and posterior alpha values. The method then returns a boolean mask indicating which tokens - have posterior probabilities exceeding the threshold. + where :math:`p_{\text{original}}` corresponds to target_probs + and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters + specified using self._posterior_threshold_ and self._posterior_alpha_ Returns: ------- torch.Tensor - A boolean tensor of shape (batch_size, k) where each element indicates whether the - corresponding draft token has been accepted or rejected. + A boolean tensor of shape (batch_size, k) where each element + indicates whether the corresponding draft token has been accepted + or rejected. """ candidates_prob = torch.gather( From 7a7f9bd887cf4e9597308bd559c451bdad0da186 Mon Sep 17 00:00:00 2001 From: sroy745 Date: Fri, 31 May 2024 18:36:40 +0000 Subject: [PATCH 05/37] Added comments and some more tests. 
--- tests/samplers/test_rejection_sampler.py | 2 +- .../test_typical_acceptance_sampler.py | 78 +++++++++++++++---- .../layers/typical_acceptance_sampler.py | 51 ++++++++++-- 3 files changed, 110 insertions(+), 21 deletions(-) diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index 0cc4f73aaaf37..00a2379502e6d 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -9,7 +9,7 @@ from vllm.model_executor.utils import set_random_seed CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 1) + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index db1f6265bff1e..07f9498421f1b 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -1,9 +1,7 @@ """Tests for rejection sampling.""" -from typing import List, Tuple import pytest import torch -import torch.nn.functional as F from vllm.model_executor.layers.typical_acceptance_sampler import ( TypicalAcceptanceSampler) @@ -13,14 +11,22 @@ def get_zero_temperature_prob_dist(batch_size, k, vocab_size): + """ + Generates a fake temperature zero probablity distribution. + Returns: + 1. A fake temperature zero probablity distribution of shape + [batch_size, k, vocab_size] + 2. Tensor of shape [batch_size, k] containing the token ids + of the probability 1.0 tokens at each position. + """ # Simulate temperature 0 probability distribution for target probabilities # and create target probabilities such that only 1 token id has # probability 1.0 target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) probs = torch.rand(batch_size, k, vocab_size) _, zero_temperature_token_ids = torch.max(probs, dim=-1) - # set the probability of the tokens with ids in max_indices to 1 and - # the rest to 0. + # set the probability of the tokens with ids in zero_temperature_token_ids + # to 1 and the rest to 0. target_probs = torch.zeros_like(probs).scatter_( -1, zero_temperature_token_ids.unsqueeze(-1), 1.0) return target_probs, zero_temperature_token_ids @@ -28,12 +34,16 @@ def get_zero_temperature_prob_dist(batch_size, k, vocab_size): def get_draft_token_ids(batch_size: int, k: int, vocab_size: int, token_ids_to_exclude: torch.Tensor): - # Populate draft_token_ids such that they exclude the token_ids - # with probability = 1.0 + """ + Returns a tensor of shape [batch_size, k] of fake draft token ids + drawn randomly from a vocab of size vocab_size. We however ensure + that token_ids from token_ids_to_exclude are excluded at the + corresponding positions. + """ draft_token_ids = torch.empty(batch_size, k, dtype=torch.long) for i in range(batch_size): for j in range(k): - # Generate a random token ID excluding max_indices[i, j] + # Generate a random token ID excluding token_ids_to_exclude[i, j] while True: token_id = torch.randint(0, vocab_size, (1, )).item() if token_id != token_ids_to_exclude[i, j]: @@ -61,7 +71,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, high=vocab_size, size=(batch_size, k), dtype=torch.int64) - + # Verify that sampling succeeds for all cases. 
typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) @@ -87,6 +97,8 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, high=vocab_size, size=(batch_size, k), dtype=torch.int64) + # Verify that appropriate exceptions are thrown for out + # of bound vocabs. oob_token_ids = None if which_token_ids == "bonus_token_ids": oob_token_ids = bonus_token_ids @@ -135,6 +147,9 @@ def test_uniform_target_distribution_accepts_all_tokens( output_token_ids = typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) + # We are using a uniform target probability distribution. + # For a uniform distribution the entropy is very high and it + # should lead to all draft tokens being accepted. Verify that. assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) if disable_bonus_tokens: @@ -174,6 +189,11 @@ def test_temperature_zero_target_distribution(seed: int, high=vocab_size, size=(batch_size, 1), dtype=torch.int64) + # The target probaility distribution is a temperature zero distribution + # with zero entroy. Since our draft token ids don't match the probability + # 1.0 tokens in the target distribution we will reject all of them and + # fallback to the greedy sampling for selecting 1 token for each sequence. + # Verify the same. output_token_ids = typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) @@ -198,13 +218,13 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, typical_acceptance_sampler = TypicalAcceptanceSampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) - # For batches 0 and 2 set the distribution to an uniform distribution. For - # batches 1 and 3 set it to a temperature 0 distribution. + # For sequences 0 and 2 set the distribution to a temperature + # zero distribution. For sequences 1 and 3 set it to a uniform + # distribution. target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( batch_size, k, vocab_size)) draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, zero_temperature_token_ids) - # Create target_probs such that only one token_id has probability 1.0 uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32) target_probs[[1, 3]] = uniform_probs bonus_token_ids = torch.randint(low=0, @@ -214,12 +234,19 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, output_token_ids = typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) + # verify the shape of output_token_ids assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) + # For sequences 0 and 2 verify that only 1 token is accepted + # which is the token with probability 1.0 in the target distribution + # at position 0. assert torch.all(output_token_ids[[0, 2], 1:] == -1) assert (torch.all(output_token_ids[[0, 2], 0] == zero_temperature_token_ids[[0, 2], 0])) + # For sequences 1 and 3 verify that all tokens are accepted since the + # target probability distribution is uniform. In addition verify that + # if disable_bonus_tokens is false then we also accept the bonus tokens. 
assert torch.all( output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :]) if disable_bonus_tokens: @@ -242,8 +269,9 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, typical_acceptance_sampler = TypicalAcceptanceSampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) - # For batches 0 and 2 set the distribution to an uniform distribution. For - # batches 1 and 3 set it to a temperature 0 distribution. + # Create a temperature zero target probability distribution and ensure + # all draft token ids correspond to the tokens with 1.0 probability. + # Verify that all of them are accepted. target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( batch_size, k, vocab_size)) draft_token_ids = zero_temperature_token_ids @@ -276,3 +304,27 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2]) assert torch.all(output_token_ids[:, -3:] == -1) + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_replacement_token_ids( + seed: int, disable_bonus_tokens: bool, device: str): + set_random_seed(seed) + k = 10 + batch_size = 5 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + expected_replacement_tokens = -torch.ones( + (batch_size, k), dtype=torch.long) + expected_replacement_tokens[:, 0] = torch.argmax( + target_probs[:, 0, :], dim=1) + actual_replacement_tokens = ( + typical_acceptance_sampler._replacement_token_ids(target_probs)) + assert torch.all(expected_replacement_tokens == actual_replacement_tokens) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 64c6f36bc3322..edfd1faf9cb2f 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -17,7 +17,8 @@ def __init__( self, disable_bonus_tokens: bool = False, strict_mode: bool = False, - posterior_threshold: float = 0.3, + posterior_threshold: float = 0.09, + posterior_alpha: float = 0.3, ): """Create a Typical Acceptance Sampler. @@ -29,10 +30,14 @@ def __init__( during sampling. This catches correctness issues but adds nontrivial latency. posterior_threshold : A threshold value that sets a lower bound - for the posterior probability. Default is 0.3. - posterior_alpha : A threshold value that sets a lower bound - for the posterior probability. Default is 0.3. + on the posterior probability of a token in target model for it + to be accepted. Default is 0.09 + posterior_alpha : A scaling factor for the entropy-based + threshold in typical acceptance sampling. Typically defaults to + sqrt of posterior_threshold and is set to 0.3. """ + self._posterior_threshold = posterior_threshold + self._posterior_alpha = posterior_alpha super().__init__() SpecDecodeBaseSampler.__init__( self, @@ -111,18 +116,27 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): .. 
math:: p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > \min \left( \epsilon, \delta * \exp \left( - -H(p_{\text{original}}(\cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) + -H(p_{\text{original}}( + \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) where :math:`p_{\text{original}}` corresponds to target_probs and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters - specified using self._posterior_threshold_ and self._posterior_alpha_ + specified using self._posterior_threshold and self._posterior_alpha + + This method computes the posterior probabilities for the given + draft token ids based on the provided target probabilities. It + calculates the entropy of the posterior distribution and determines + a dynamic threshold for each token position using the provided + posterior_threshold and posterior_alpha values. The method then + returns a boolean mask indicating which tokens can be accepted. Returns: ------- torch.Tensor A boolean tensor of shape (batch_size, k) where each element indicates whether the corresponding draft token has been accepted - or rejected. + or rejected. True indicates acceptance and false indicates + rejection. """ candidates_prob = torch.gather( @@ -138,6 +152,29 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): return accepted_mask def _replacement_token_ids(self, target_probs): + """ + Generate one replacement token ID for each sequence based on target + probabilities. The replacement token is used as the fallback option + if typical acceptance sampling does not accept any draft tokens for + that particular sequence. + + This method computes the token IDs to be replaced by selecting the + token with the highest probability for each sequence in the first + position. The rest of the output is filled with -1. + + Parameters + ---------- + target_probs : torch.Tensor + A tensor of shape (batch_size, k, vocab_size) containing + the target probability distribution + + Returns + ------- + torch.Tensor + A tensor of shape (batch_size, k) with the replacement + token IDs. Only the first column is set, and the rest of the + columns are filled with -1. + """ max_indices = torch.argmax(target_probs[:, 0, :], dim=1) output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), dtype=self.token_id_dtype) From aecc2e82d9209fde33cde40a94dc19f1933426b4 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 31 May 2024 18:52:38 +0000 Subject: [PATCH 06/37] Dummy commit --- vllm/model_executor/layers/spec_decode_base_sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 95f8e44f2f2de..4cd0b92cfcbae 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -13,7 +13,6 @@ def __init__(self, disable_bonus_tokens: bool = True, strict_mode: bool = False): """Base class constructor. - Args: disable_bonus_tokens: Whether or not to disable the bonus token. 
Require when bonus tokens will cause corrupt KV cache for From 513e25223bd526ee9e652595650d300238f84065 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 31 May 2024 19:09:19 +0000 Subject: [PATCH 07/37] Reverting change to llm_engine_example.py --- examples/llm_engine_example.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/llm_engine_example.py b/examples/llm_engine_example.py index a92777d92da39..a81c4b3e399c3 100644 --- a/examples/llm_engine_example.py +++ b/examples/llm_engine_example.py @@ -1,5 +1,4 @@ import argparse - from typing import List, Tuple from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams From 69e52f087442391ca3300911bea0797f3a0af9fc Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 31 May 2024 23:09:12 +0000 Subject: [PATCH 08/37] Updating a comment --- .../layers/typical_acceptance_sampler.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index edfd1faf9cb2f..79c3358cea1e0 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -51,16 +51,15 @@ def forward( bonus_token_ids: torch.Tensor, draft_token_ids: torch.Tensor, ) -> torch.Tensor: - """Sample token ids using rejection sampling. This accepts or rejects - tokens proposed by the draft model using the probability of each token - according to the draft and target models. + """Sample token ids using typical acceptance sampling. This accepts + or rejects tokens proposed by the draft model using the probability + of each token according to the draft and target models. In the worst case where all draft tokens are rejected, it is guaranteed - one correct token will be emitted. + one token will be emitted. - In the case where all draft tokens are accepted, a bonus token will be - accepted as its cheap to have the target model score this speculative - sequence. + In the case where all draft tokens are accepted, the bonus token will be + accepted conditioned on self._disable_bonus_tokens being false. 
Args: target_probs: The probability distribution over token ids given From 312bc49c3942acd11d0c15733128e534f0a0de37 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 3 Jun 2024 17:47:56 +0000 Subject: [PATCH 09/37] Dummy commit --- vllm/model_executor/layers/typical_acceptance_sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 79c3358cea1e0..76d14f63762f4 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -12,7 +12,6 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module): Multiple Decoding Heads" https://arxiv.org/pdf/2401.10774 """ - def __init__( self, disable_bonus_tokens: bool = False, From ca412152184fd56f65cb97b7b9decaefeae372c6 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 3 Jun 2024 22:44:30 +0000 Subject: [PATCH 10/37] Fix device for tensors --- .../layers/typical_acceptance_sampler.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 76d14f63762f4..b3ebfe1e4ee82 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -137,13 +137,16 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): rejection. """ + device = target_probs.device candidates_prob = torch.gather( target_probs, dim=-1, - index=draft_token_ids.unsqueeze(-1)).squeeze(-1) + index=draft_token_ids.unsqueeze(-1).to(device)).squeeze(-1) posterior_entropy = -torch.sum( target_probs * torch.log(target_probs + 1e-5), dim=-1) threshold = torch.minimum( - torch.ones_like(posterior_entropy) * self._posterior_threshold, + torch.ones_like( + posterior_entropy, + device=device) * self._posterior_threshold, torch.exp(-posterior_entropy) * self._posterior_alpha, ) accepted_mask = candidates_prob > threshold @@ -175,6 +178,7 @@ def _replacement_token_ids(self, target_probs): """ max_indices = torch.argmax(target_probs[:, 0, :], dim=1) output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), - dtype=self.token_id_dtype) + dtype=self.token_id_dtype).to( + target_probs.device) output[:, 0] = max_indices return output From ba4d3fb1d0339b57c6300e062ba174d6df09ba92 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Thu, 6 Jun 2024 05:59:20 +0000 Subject: [PATCH 11/37] Fixing review comments --- .../test_typical_acceptance_sampler.py | 73 ++++++++++++++++++- .../layers/spec_decode_base_sampler.py | 1 - .../layers/typical_acceptance_sampler.py | 11 ++- 3 files changed, 77 insertions(+), 8 deletions(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 07f9498421f1b..b1b84d00f7cf0 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -59,6 +59,10 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int, @torch.inference_mode() def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, device: str): + """ + Tests that the TypicalAcceptancSampler forward succeeds for + different combinations of k, vocab_size, batch_size and num devices. 
+ """ torch.set_default_device(device) typical_acceptance_sampler = TypicalAcceptanceSampler() typical_acceptance_sampler.init_gpu_tensors(rank=0) @@ -82,11 +86,16 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, @torch.inference_mode() def test_raises_when_vocab_oob(above_or_below_vocab_range: str, which_token_ids: str, device: str): + """ + Tests that we throw an exception of the token ids fall outside + the bound of the provided vocabulary. + """ k = 3 batch_size = 5 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True) typical_acceptance_sampler.init_gpu_tensors(rank=0) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) bonus_token_ids = torch.randint(low=0, @@ -127,6 +136,17 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, @torch.inference_mode() def test_uniform_target_distribution_accepts_all_tokens( seed: int, disable_bonus_tokens: bool, device: str): + """ + Test the TypicalAcceptanceSampler with a uniform target probability + distribution. + + This test verifies that when provided with a uniform target probability + distribution, the TypicalAcceptanceSampler accepts all draft tokens. The + entropy of the uniform target distribution being high should lead to all + draft tokens being accepted. The test also ensures that the behavior + regarding bonus tokens is consistent with the `disable_bonus_tokens` + flag. + """ set_random_seed(seed) k = 3 batch_size = 5 @@ -167,6 +187,17 @@ def test_uniform_target_distribution_accepts_all_tokens( def test_temperature_zero_target_distribution(seed: int, disable_bonus_tokens: bool, device: str): + """ + Test the TypicalAcceptanceSampler with a zero-temperature target + probability distribution. + + This test verifies that when using a zero-temperature target probability + distribution, where only one token has a probability of 1.0, the + TypicalAcceptanceSampler correctly rejects all draft tokens that do not + match this probability. Additionally, it ensures that when all draft + tokens are rejected, the sampler falls back to greedy sampling to select a + single token from the target distribution. + """ set_random_seed(seed) k = 3 batch_size = 5 @@ -210,6 +241,22 @@ def test_temperature_zero_target_distribution(seed: int, @torch.inference_mode() def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, device: str): + """ + Test the TypicalAcceptanceSampler with a mixed target probability + distribution. + + This test ensures that the TypicalAcceptanceSampler handles a mixed + target probability distribution correctly. Specifically, it uses a + zero-temperature distribution for some sequences and a uniform + distribution for others. The test verifies that: + + - For sequences with a zero-temperature distribution, only the token + with a probability of 1.0 is accepted, and all other tokens are rejected. + - For sequences with a uniform distribution, all draft tokens are + accepted. + - When `disable_bonus_tokens` is False, the bonus tokens are also accepted + for sequences with a uniform distribution. 
+ """ set_random_seed(seed) k = 3 batch_size = 4 @@ -261,6 +308,20 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, @torch.inference_mode() def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, device: str): + """ + Test the TypicalAcceptanceSampler's behavior when only a subset of draft + tokens should be accepted. + + This test verifies that the TypicalAcceptanceSampler correctly accepts or + rejects draft tokens based on a zero-temperature target probability + distribution. Specifically, it ensures that: + + - When all draft tokens match tokens with a probability of 1.0 in the + target distribution, all draft tokens are accepted. + - When only some draft tokens match tokens with a probability of 1.0 in + the target distribution, only those matching tokens are accepted, and the + rest are rejected. + """ set_random_seed(seed) k = 5 batch_size = 1 @@ -312,6 +373,16 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, @torch.inference_mode() def test_replacement_token_ids( seed: int, disable_bonus_tokens: bool, device: str): + """ + Test the TypicalAcceptanceSampler's method for generating + replacement token IDs. + + This test verifies that the `_replacement_token_ids` method of the + TypicalAcceptanceSampler correctly identifies the token IDs to be used + as replacements based on the target probability distribution. + Specifically, it ensures that the method correctly identifies the + tokens with the highest probability for each sequence in the batch. + """ set_random_seed(seed) k = 10 batch_size = 5 diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 4cd0b92cfcbae..9856a7e7ddea0 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -1,7 +1,6 @@ from typing import Optional import torch -import torch.jit class SpecDecodeBaseSampler(): diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index b3ebfe1e4ee82..3bdb72948a360 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -35,14 +35,13 @@ def __init__( threshold in typical acceptance sampling. Typically defaults to sqrt of posterior_threshold and is set to 0.3. 
""" - self._posterior_threshold = posterior_threshold - self._posterior_alpha = posterior_alpha - super().__init__() SpecDecodeBaseSampler.__init__( self, disable_bonus_tokens=disable_bonus_tokens, strict_mode=strict_mode) nn.Module.__init__(self) + self._posterior_threshold = posterior_threshold + self._posterior_alpha = posterior_alpha def forward( self, @@ -140,7 +139,7 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): device = target_probs.device candidates_prob = torch.gather( target_probs, dim=-1, - index=draft_token_ids.unsqueeze(-1).to(device)).squeeze(-1) + index=draft_token_ids.unsqueeze(-1)).squeeze(-1) posterior_entropy = -torch.sum( target_probs * torch.log(target_probs + 1e-5), dim=-1) threshold = torch.minimum( @@ -178,7 +177,7 @@ def _replacement_token_ids(self, target_probs): """ max_indices = torch.argmax(target_probs[:, 0, :], dim=1) output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), - dtype=self.token_id_dtype).to( - target_probs.device) + dtype=self.token_id_dtype, + device=target_probs.device) output[:, 0] = max_indices return output From c9c7a8ba9ffced77b935ffcc448e11f322a17709 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 7 Jun 2024 06:41:57 +0000 Subject: [PATCH 12/37] Addressing comments --- vllm/model_executor/layers/typical_acceptance_sampler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 3bdb72948a360..3fb33d9c2d375 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -5,7 +5,6 @@ from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) - class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module): """Apply typical acceptance sampling as described in section 3.3.1 in "MEDUSA: Simple LLM Inference Acceleration Framework with @@ -140,8 +139,11 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): candidates_prob = torch.gather( target_probs, dim=-1, index=draft_token_ids.unsqueeze(-1)).squeeze(-1) + # A small constant added to prevent computing the logarithm of zero, + # which can lead to undefined values. 
+ epsilon = 1e-5 posterior_entropy = -torch.sum( - target_probs * torch.log(target_probs + 1e-5), dim=-1) + target_probs * torch.log(target_probs + epsilon), dim=-1) threshold = torch.minimum( torch.ones_like( posterior_entropy, From 0ad9afd8c36db329ee022299ffcd066737b6eefe Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 7 Jun 2024 07:29:37 +0000 Subject: [PATCH 13/37] Add a new test for non default posteriors --- .../test_typical_acceptance_sampler.py | 63 +++++++++++++++++++ .../layers/spec_decode_base_sampler.py | 2 +- 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index b1b84d00f7cf0..db5406e598afb 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -367,6 +367,69 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, assert torch.all(output_token_ids[:, -3:] == -1) +@pytest.mark.parametrize("seed", list(range(1))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_accept_tokens_set_non_default_posteriors( + seed: int, disable_bonus_tokens: bool, device: str): + """ + Test the TypicalAcceptanceSampler's behavior when only a subset of draft + tokens should be accepted. + + This test verifies that the TypicalAcceptanceSampler correctly accepts or + rejects draft tokens based on a zero-temperature target probability + distribution. Specifically, it ensures that: + + - When all draft tokens match tokens with a probability of 1.0 in the + target distribution, all draft tokens are accepted. + - When only some draft tokens match tokens with a probability of 1.0 in + the target distribution, only those matching tokens are accepted, and the + rest are rejected. + """ + set_random_seed(seed) + k = 5 + batch_size = 1 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + # Create a temperature zero target probability distribution and ensure + # all draft token ids correspond to the tokens with 1.0 probability. + # Verify that all of them are accepted. 
+ target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( + batch_size, k, vocab_size)) + target_probs[target_probs == 0] = 0.00001 + draft_token_ids = get_draft_token_ids( + batch_size, k, vocab_size, zero_temperature_token_ids) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[:, 1:-1] == -1) + + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens, + posterior_threshold=0.0, posterior_alpha=0.0) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids) + if disable_bonus_tokens: + assert torch.all(output_token_ids[:, -1] == -1) + else: + assert torch.all(output_token_ids[:, -1] == bonus_token_ids) + + @pytest.mark.parametrize("seed", list(range(10))) @pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 9856a7e7ddea0..06917a44cb6d4 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -115,7 +115,7 @@ def _create_output( # Fill the recovered token ids. output.mul_(~after_false_mask).add_( substitute_token_ids.mul(after_false_mask)) - + self.num_accepted_tokens += accepted.sum() self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() self.num_draft_tokens += batch_size * k From 9b572f7a62c4e7aec665a65183e6d07e90487f6e Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 7 Jun 2024 18:06:43 +0000 Subject: [PATCH 14/37] Documentation for test --- .../test_typical_acceptance_sampler.py | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index db5406e598afb..757e07dcd77ce 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -374,18 +374,10 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, def test_accept_tokens_set_non_default_posteriors( seed: int, disable_bonus_tokens: bool, device: str): """ - Test the TypicalAcceptanceSampler's behavior when only a subset of draft - tokens should be accepted. - - This test verifies that the TypicalAcceptanceSampler correctly accepts or - rejects draft tokens based on a zero-temperature target probability - distribution. Specifically, it ensures that: - - - When all draft tokens match tokens with a probability of 1.0 in the - target distribution, all draft tokens are accepted. - - When only some draft tokens match tokens with a probability of 1.0 in - the target distribution, only those matching tokens are accepted, and the - rest are rejected. + Test the TypicalAcceptanceSampler with custom posterior thresholds and + alpha values. 
This test verifies that by modifying the posterior + thresholds and alpha values we can change the acceptance behavior of the + sampler. """ set_random_seed(seed) k = 5 @@ -395,9 +387,12 @@ def test_accept_tokens_set_non_default_posteriors( typical_acceptance_sampler = TypicalAcceptanceSampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) - # Create a temperature zero target probability distribution and ensure - # all draft token ids correspond to the tokens with 1.0 probability. - # Verify that all of them are accepted. + # Simulate temperature 0 probability distribution for target + # probabilities and create target probabilities such that only 1 token + # id has probability 1.0 and others have a very low probability of + # 0.00001. Populate draft_token_ids such that they exclude the token_ids + # with probability = 1.0. Without any changes to the posterior thresholds + # none of the draft tokens are accepted. target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( batch_size, k, vocab_size)) target_probs[target_probs == 0] = 0.00001 @@ -414,6 +409,9 @@ def test_accept_tokens_set_non_default_posteriors( assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, 1:-1] == -1) + # Change the posterior threshold values to 0.0 so that we will + # now accept even draft tokens with very low probability in the + # target distribution. Simulate and verify the same. typical_acceptance_sampler = TypicalAcceptanceSampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens, posterior_threshold=0.0, posterior_alpha=0.0) From 644cae444cb8d8979086efb420ecf57c5abe3a77 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 7 Jun 2024 18:21:15 +0000 Subject: [PATCH 15/37] Fix ruff errors --- vllm/model_executor/layers/rejection_sampler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index a4d51e948f5b0..be87d9d92d7c5 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -1,11 +1,12 @@ from functools import cached_property -from typing import Optional, Tuple +from typing import Tuple import torch import torch.jit import torch.nn as nn -from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler - +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler +) class RejectionSampler(SpecDecodeBaseSampler, nn.Module): """Apply modified rejection sampling as described in "Accelerating Large From 5ee4018eedc70931c654c8f5ca0897cc97ffe239 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 7 Jun 2024 18:31:24 +0000 Subject: [PATCH 16/37] Fix spell corrections --- vllm/model_executor/layers/typical_acceptance_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 3fb33d9c2d375..521de7c21724e 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -106,7 +106,7 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): A tensor of shape (batch_size, k) representing the proposed token ids. 
- A draft token_id x_{n+k} is accepted if it satisifies the + A draft token_id x_{n+k} is accepted if it satisfies the following condition .. math:: From b414d425a735509377b4bcb3e5674da257086b0f Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 7 Jun 2024 18:33:12 +0000 Subject: [PATCH 17/37] Fixing spell errors --- tests/samplers/test_typical_acceptance_sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 757e07dcd77ce..341901d4f6c3c 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -12,9 +12,9 @@ def get_zero_temperature_prob_dist(batch_size, k, vocab_size): """ - Generates a fake temperature zero probablity distribution. + Generates a fake temperature zero probability distribution. Returns: - 1. A fake temperature zero probablity distribution of shape + 1. A fake temperature zero probability distribution of shape [batch_size, k, vocab_size] 2. Tensor of shape [batch_size, k] containing the token ids of the probability 1.0 tokens at each position. From 7aa132b90a7c0306c7754f624bf7909b57093dec Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 7 Jun 2024 18:39:56 +0000 Subject: [PATCH 18/37] Ran format.sh --- .../test_typical_acceptance_sampler.py | 30 ++++++++++--------- .../layers/rejection_sampler.py | 5 ++-- .../layers/spec_decode_base_sampler.py | 2 +- .../layers/typical_acceptance_sampler.py | 7 +++-- 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 341901d4f6c3c..87cf37bc926bc 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -94,8 +94,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, batch_size = 5 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler( - strict_mode=True) + typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True) typical_acceptance_sampler.init_gpu_tensors(rank=0) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) bonus_token_ids = torch.randint(low=0, @@ -371,8 +370,9 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, @pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_accept_tokens_set_non_default_posteriors( - seed: int, disable_bonus_tokens: bool, device: str): +def test_accept_tokens_set_non_default_posteriors(seed: int, + disable_bonus_tokens: bool, + device: str): """ Test the TypicalAcceptanceSampler with custom posterior thresholds and alpha values. This test verifies that by modifying the posterior @@ -391,13 +391,13 @@ def test_accept_tokens_set_non_default_posteriors( # probabilities and create target probabilities such that only 1 token # id has probability 1.0 and others have a very low probability of # 0.00001. Populate draft_token_ids such that they exclude the token_ids - # with probability = 1.0. Without any changes to the posterior thresholds + # with probability = 1.0. Without any changes to the posterior thresholds # none of the draft tokens are accepted. 
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( batch_size, k, vocab_size)) target_probs[target_probs == 0] = 0.00001 - draft_token_ids = get_draft_token_ids( - batch_size, k, vocab_size, zero_temperature_token_ids) + draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, + zero_temperature_token_ids) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -413,8 +413,10 @@ def test_accept_tokens_set_non_default_posteriors( # now accept even draft tokens with very low probability in the # target distribution. Simulate and verify the same. typical_acceptance_sampler = TypicalAcceptanceSampler( - strict_mode=True, disable_bonus_tokens=disable_bonus_tokens, - posterior_threshold=0.0, posterior_alpha=0.0) + strict_mode=True, + disable_bonus_tokens=disable_bonus_tokens, + posterior_threshold=0.0, + posterior_alpha=0.0) typical_acceptance_sampler.init_gpu_tensors(rank=0) output_token_ids = typical_acceptance_sampler(target_probs, bonus_token_ids, @@ -432,8 +434,8 @@ def test_accept_tokens_set_non_default_posteriors( @pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_replacement_token_ids( - seed: int, disable_bonus_tokens: bool, device: str): +def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool, + device: str): """ Test the TypicalAcceptanceSampler's method for generating replacement token IDs. @@ -455,8 +457,8 @@ def test_replacement_token_ids( target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) expected_replacement_tokens = -torch.ones( (batch_size, k), dtype=torch.long) - expected_replacement_tokens[:, 0] = torch.argmax( - target_probs[:, 0, :], dim=1) + expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :], + dim=1) actual_replacement_tokens = ( - typical_acceptance_sampler._replacement_token_ids(target_probs)) + typical_acceptance_sampler._replacement_token_ids(target_probs)) assert torch.all(expected_replacement_tokens == actual_replacement_tokens) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index be87d9d92d7c5..fe9b2fac1117e 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -4,9 +4,10 @@ import torch import torch.jit import torch.nn as nn + from vllm.model_executor.layers.spec_decode_base_sampler import ( - SpecDecodeBaseSampler -) + SpecDecodeBaseSampler) + class RejectionSampler(SpecDecodeBaseSampler, nn.Module): """Apply modified rejection sampling as described in "Accelerating Large diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 06917a44cb6d4..9856a7e7ddea0 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -115,7 +115,7 @@ def _create_output( # Fill the recovered token ids. 
output.mul_(~after_false_mask).add_( substitute_token_ids.mul(after_false_mask)) - + self.num_accepted_tokens += accepted.sum() self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() self.num_draft_tokens += batch_size * k diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 521de7c21724e..f12d6a03b4d16 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -5,12 +5,14 @@ from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) + class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module): """Apply typical acceptance sampling as described in section 3.3.1 in "MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads" https://arxiv.org/pdf/2401.10774 """ + def __init__( self, disable_bonus_tokens: bool = False, @@ -145,9 +147,8 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): posterior_entropy = -torch.sum( target_probs * torch.log(target_probs + epsilon), dim=-1) threshold = torch.minimum( - torch.ones_like( - posterior_entropy, - device=device) * self._posterior_threshold, + torch.ones_like(posterior_entropy, device=device) * + self._posterior_threshold, torch.exp(-posterior_entropy) * self._posterior_alpha, ) accepted_mask = candidates_prob > threshold From 26694a7ed11bf3d0830f56a75f0b23fb2aaf586d Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 12 Jun 2024 23:51:08 +0000 Subject: [PATCH 19/37] Fix formatting --- vllm/model_executor/layers/rejection_sampler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 1a7d11d0d4bc2..fe9b2fac1117e 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -240,6 +240,7 @@ def _smallest_positive_value(self) -> float: """ return torch.finfo(self.probs_dtype).tiny + # torch.multinomial forces a GPU<->CPU sync. # Therefore, we use an optimized implementation instead that skips the sync. # Note that we always sample with replacement. 
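
The patches above repeatedly describe the acceptance rule of the typical acceptance sampler in docstrings and comments: a draft token is kept when its probability under the target model exceeds min(posterior_threshold, posterior_alpha * exp(-entropy)) of the target distribution at that position. The following minimal sketch evaluates that rule on toy tensors; it is independent of the vLLM classes being modified here, and the function name, toy shapes, and usage are illustrative assumptions (only the default hyperparameter values 0.09 and 0.3 are taken from the TypicalAcceptanceSampler constructor in the patches).

```python
# Standalone sketch of the typical-acceptance rule described above.
# Not the vLLM implementation; names, shapes and inputs are illustrative.
import torch


def typical_acceptance_mask(target_probs: torch.Tensor,
                            draft_token_ids: torch.Tensor,
                            posterior_threshold: float = 0.09,
                            posterior_alpha: float = 0.3) -> torch.Tensor:
    """Return a [batch_size, k] bool mask of accepted draft tokens.

    A draft token is accepted when its probability under the target model
    exceeds min(posterior_threshold, posterior_alpha * exp(-entropy)),
    with the entropy computed from the target distribution at that position.
    """
    # Probability the target model assigns to each proposed draft token.
    candidates_prob = torch.gather(
        target_probs, dim=-1,
        index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
    # Entropy of the target distribution at each position; the small
    # constant avoids taking log(0).
    posterior_entropy = -torch.sum(
        target_probs * torch.log(target_probs + 1e-5), dim=-1)
    threshold = torch.minimum(
        torch.full_like(posterior_entropy, posterior_threshold),
        posterior_alpha * torch.exp(-posterior_entropy))
    return candidates_prob > threshold


# Toy usage: batch of 2 sequences, k=3 speculative positions, vocab of 8.
batch_size, k, vocab_size = 2, 3, 8
target_probs = torch.softmax(torch.randn(batch_size, k, vocab_size), dim=-1)
draft_token_ids = torch.randint(0, vocab_size, (batch_size, k))
accepted = typical_acceptance_mask(target_probs, draft_token_ids)
# A sequence keeps its draft tokens only up to the first rejected position;
# if nothing is accepted, the sampler in the patches falls back to the
# argmax of the target distribution at position 0 (the replacement token).
print(accepted)
```

The entropy-scaled threshold is what the new tests exercise: a near-uniform target distribution has high entropy, so the threshold collapses and essentially all draft tokens pass, while a temperature-zero target distribution has zero entropy, so only the probability-1.0 token can be accepted.
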
From 1b8fd3e69bf503c38e9dbb5a829c3a850b6a202c Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 8 May 2024 19:40:10 +0000 Subject: [PATCH 20/37] Test commit --- examples/llm_engine_example.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/llm_engine_example.py b/examples/llm_engine_example.py index a81c4b3e399c3..a92777d92da39 100644 --- a/examples/llm_engine_example.py +++ b/examples/llm_engine_example.py @@ -1,4 +1,5 @@ import argparse + from typing import List, Tuple from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams From 5e80dd8200aa68d2915143151432131e074ee3fe Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Thu, 30 May 2024 07:30:18 +0000 Subject: [PATCH 21/37] Refactoring some of the logic in rejection_sampling to a base class and adding new class for acceptance sampling --- .../test_typical_acceptance_sampler.py | 220 ++++++++++++++++++ .../layers/rejection_sampler.py | 171 +------------- .../layers/spec_decode_base_sampler.py | 211 +++++++++++++++++ .../layers/typical_acceptance_sampler.py | 128 ++++++++++ 4 files changed, 567 insertions(+), 163 deletions(-) create mode 100644 tests/samplers/test_typical_acceptance_sampler.py create mode 100644 vllm/model_executor/layers/spec_decode_base_sampler.py create mode 100644 vllm/model_executor/layers/typical_acceptance_sampler.py diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py new file mode 100644 index 0000000000000..f6ce11f500451 --- /dev/null +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -0,0 +1,220 @@ +"""Tests for rejection sampling.""" +from typing import List, Tuple + +import pytest +import torch +import torch.nn.functional as F + +from vllm.model_executor.layers.typical_acceptance_sampler import ( + TypicalAcceptanceSampler) +from vllm.model_executor.utils import set_random_seed + +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1) +] + +def get_zero_temperature_prob_dist(batch_size, k, vocab_size): + # Simulate temperature 0 probability distribution for target probabilities + # and create target probabilities such that only 1 token id has + # probability 1.0 + target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + probs = torch.rand(batch_size, k, vocab_size) + _, max_indices = torch.max(probs, dim=-1) + # set the probability of the tokens with ids in max_indices to 1 and + # the rest to 0. 
+ target_probs = torch.zeros_like(probs).scatter_( + -1, max_indices.unsqueeze(-1), 1.0) + return target_probs, max_indices + +def get_draft_token_ids( + batch_size: int, k: int, vocab_size: int, token_ids_to_exclude: torch.Tensor): + # Populate draft_token_ids such that they exclude the token_ids + # with probability = 1.0 + draft_token_ids = torch.empty(batch_size, k, dtype=torch.long) + for i in range(batch_size): + for j in range(k): + # Generate a random token ID excluding max_indices[i, j] + while True: + token_id = torch.randint(0, vocab_size, (1,)).item() + if token_id != token_ids_to_exclude[i, j]: + draft_token_ids[i, j] = token_id + break + return draft_token_ids + +@pytest.mark.parametrize("k", list(range(1, 6))) +@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) +@pytest.mark.parametrize("batch_size", list(range(1, 32))) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, + device: str): + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler() + typical_acceptance_sampler.init_gpu_tensors(rank=0) + target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64) + + typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) + + +@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"]) +@pytest.mark.parametrize("which_token_ids", + ["bonus_token_ids", "draft_token_ids"]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_raises_when_vocab_oob(above_or_below_vocab_range: str, + which_token_ids: str, device: str): + k = 3 + batch_size = 5 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64) + oob_token_ids = None + if which_token_ids == "bonus_token_ids": + oob_token_ids = bonus_token_ids + elif which_token_ids == "draft_token_ids": + oob_token_ids = draft_token_ids + else: + raise AssertionError() + + if above_or_below_vocab_range == "above": + rogue_token_id = vocab_size + 1 + elif above_or_below_vocab_range == "below": + rogue_token_id = -1 + else: + raise AssertionError() + + oob_token_ids[0][0] = rogue_token_id + + with pytest.raises(AssertionError): + typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_uniform_target_distribution_accepts_all_tokens( + seed: int, disable_bonus_tokens: bool, device: str): + set_random_seed(seed) + k = 3 + batch_size = 5 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + target_probs = torch.rand(batch_size, 
k, vocab_size, dtype=torch.float32) + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + output_token_ids = typical_acceptance_sampler( + target_probs, bonus_token_ids, draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + if disable_bonus_tokens: + assert torch.all(output_token_ids[:, -1] == -1) + else: + assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze()) + + assert torch.all(output_token_ids[:, : k] == draft_token_ids) + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_temperature_zero_target_distribution( + seed: int, disable_bonus_tokens: bool, device: str): + set_random_seed(seed) + k = 3 + batch_size = 5 + vocab_size = 30_000 + torch.set_default_device(device) + + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + # Simulate temperature 0 probability distribution for target probabilities + # and create target probabilities such that only 1 token id has + # probability 1.0 + target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist( + batch_size, k, vocab_size) + # Populate draft_token_ids such that they exclude the token_ids + # with probability = 1.0 + draft_token_ids = get_draft_token_ids( + batch_size, k, vocab_size, zero_temperature_token_ids) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + output_token_ids = typical_acceptance_sampler( + target_probs, bonus_token_ids, draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[:, -1] == -1) + assert torch.all( + output_token_ids[:, 0] == zero_temperature_token_ids[:, 0]) + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_mixed_target_distribution( + seed: int, disable_bonus_tokens: bool, device: str): + set_random_seed(seed) + k = 3 + batch_size = 4 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + # For batches 0 and 2 set the distribution to an uniform distribution. For + # batches 1 and 3 set it to a temperature 0 distribution. 
+ target_probs, zero_temperature_token_ids = ( + get_zero_temperature_prob_dist(batch_size, k, vocab_size)) + draft_token_ids = get_draft_token_ids( + batch_size, k, vocab_size, zero_temperature_token_ids) + # Create target_probs such that only one token_id has probability 1.0 + uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32) + target_probs[[1, 3]] = uniform_probs + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + output_token_ids = typical_acceptance_sampler( + target_probs, bonus_token_ids, draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[[0, 2], 1:] == -1) + assert ( + torch.all(output_token_ids[[0, 2], 0] == zero_temperature_token_ids[[0, 2], 0])) + assert torch.all(output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :]) + if disable_bonus_tokens: + assert torch.all(output_token_ids[[1, 3], -1] == -1) + else: + assert torch.all(output_token_ids[[1, 3], -1] != -1) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index a80703155c0b6..301159ee0f20d 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -4,9 +4,10 @@ import torch import torch.jit import torch.nn as nn +from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler -class RejectionSampler(nn.Module): +class RejectionSampler(SpecDecodeBaseSampler, nn.Module): """Apply modified rejection sampling as described in "Accelerating Large Language Model Decoding with Speculative Sampling" https://arxiv.org/pdf/2302.01318.pdf. @@ -25,37 +26,9 @@ def __init__(self, during sampling. This catches correctness issues but adds nontrivial latency. """ - super().__init__() - self._disable_bonus_tokens = disable_bonus_tokens - self._strict_mode = strict_mode - - # NOTE: A "bonus token" is accepted iff all proposal tokens are - # accepted. There is always only one possible bonus token. We store this - # value in a variable for readability. - self._num_bonus_tokens = 1 - - self.num_accepted_tokens: Optional[torch.Tensor] = None - self.num_emitted_tokens: Optional[torch.Tensor] = None - self.num_draft_tokens: int = 0 - - def init_gpu_tensors(self, rank: int) -> None: - assert self.num_accepted_tokens is None - device = f"cuda:{rank}" - self.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device=device) - self.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device=device) - - @property - def probs_dtype(self): - return torch.float32 - - @property - def token_id_dtype(self): - return torch.int64 - + SpecDecodeBaseSampler.__init__(self, disable_bonus_tokens, strict_mode) + nn.Module.__init__(self) + def forward( self, target_probs: torch.Tensor, @@ -100,15 +73,10 @@ def forward( # Only perform shape/dtype/device checking in strict mode, as it adds # overhead. 
if self._strict_mode: - self._raise_if_incorrect_shape(target_probs, bonus_token_ids, - draft_probs, draft_token_ids) - self._raise_if_incorrect_dtype(target_probs, bonus_token_ids, + self._raise_if_incorrect_input(target_probs, bonus_token_ids, draft_probs, draft_token_ids) - self._raise_if_inconsistent_device(target_probs, bonus_token_ids, - draft_probs, draft_token_ids) - self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], - bonus_token_ids, - draft_token_ids) + + print('target_probs ' + str(target_probs)) accepted, recovered_token_ids = self._batch_modified_rejection_sampling( target_probs, @@ -272,129 +240,6 @@ def _smallest_positive_value(self) -> float: """ return torch.finfo(self.probs_dtype).tiny - def _create_output( - self, - accepted: torch.Tensor, # [batch_size, k] - recovered_token_ids: torch.Tensor, # [batch_size, k] - draft_token_ids: torch.Tensor, # [batch_size, k] - bonus_token_ids: torch.Tensor, # [batch_size] - ) -> torch.Tensor: - """Format output. Returns a matrix of token ids. When - a token is rejected via rejection sampling, all subsequent - token ids are set to -1 for the sequence. - - shape = [batch_size, k + num_bonus_tokens] - """ - bonus_token_ids = bonus_token_ids.squeeze() - batch_size, k = recovered_token_ids.shape - - # Determine the index of the first False value for each row. - limits = (accepted == 0).max(1).indices - limits[~(accepted == 0).any(1)] = k - - # Create masks using the indices. - indices = torch.arange(k, device=accepted.device).unsqueeze(0) - accepted_mask = indices < limits.unsqueeze(1) - after_false_mask = indices == limits.unsqueeze(1) - - # Create an extended output tensor - output_with_bonus_tokens = -torch.ones( - (batch_size, k + self._num_bonus_tokens), - dtype=self.token_id_dtype, - device=accepted.device) - output = output_with_bonus_tokens[:, :k] - - # Fill in the first k columns of the output tensor using masks and data - # tensors. - torch.where(accepted_mask, - draft_token_ids, - -torch.ones_like(draft_token_ids), - out=output) - - # Fill the last column. - # We check output directly as accepted may have True values inconsistent - # with causal acceptance. - output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, - bonus_token_ids, -1) - - # We disable bonus tokens because it causes corrupt KV cache for - # proposal methods that require KV cache. We can fix it by "prefilling" - # the bonus token in the proposer. The following issue tracks the fix. - # https://github.com/vllm-project/vllm/issues/4212 - if self._disable_bonus_tokens: - output_with_bonus_tokens[:, -1] = -1 - - # Fill the recovered token ids. 
- output.mul_(~after_false_mask).add_( - recovered_token_ids.mul(after_false_mask)) - - self.num_accepted_tokens += accepted.sum() - self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() - self.num_draft_tokens += batch_size * k - - return output_with_bonus_tokens - - def _raise_if_incorrect_shape( - self, - target_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - (target_batch_size, num_target_probs, - target_vocab_size) = target_probs.shape - bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape - draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape - draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape - - assert draft_batch_size == target_batch_size - assert num_draft_probs == num_target_probs - assert (draft_vocab_size == target_vocab_size - ), f"{draft_vocab_size=} {target_vocab_size=}" - - assert draft_token_ids_batch_size == draft_batch_size - assert num_draft_token_ids == num_draft_probs - - assert bonus_batch_size == target_batch_size - assert num_bonus_tokens == self._num_bonus_tokens - - def _raise_if_incorrect_dtype( - self, - target_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - assert all(probs.dtype == self.probs_dtype - for probs in [target_probs, draft_probs]) - assert all(token_ids.dtype == self.token_id_dtype - for token_ids in [bonus_token_ids, draft_token_ids]) - - def _raise_if_inconsistent_device( - self, - target_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - devices = [ - t.device for t in - [target_probs, bonus_token_ids, draft_probs, draft_token_ids] - ] - assert all([devices[0] == device for device in devices]) - - def _raise_if_out_of_bounds_vocab( - self, - vocab_size: int, - bonus_token_ids: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - assert torch.all(bonus_token_ids < vocab_size) - assert torch.all(bonus_token_ids >= 0) - assert torch.all(draft_token_ids < vocab_size) - assert torch.all(draft_token_ids >= 0) - - # torch.multinomial forces a GPU<->CPU sync. # Therefore, we use an optimized implementation instead that skips the sync. # Note that we always sample with replacement. diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py new file mode 100644 index 0000000000000..64a7f6474644a --- /dev/null +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -0,0 +1,211 @@ +from abc import ABC, abstractmethod +from functools import cached_property +from typing import Optional, Tuple + +import torch +import torch.jit +import torch.nn as nn + + +class SpecDecodeBaseSampler(ABC): + """Apply modified rejection sampling as described in "Accelerating Large + Language Model Decoding with Speculative Sampling" + https://arxiv.org/pdf/2302.01318.pdf. + """ + + def __init__(self, + disable_bonus_tokens: bool = True, + strict_mode: bool = False): + """Create a rejection sampler. + + Args: + disable_bonus_tokens: Whether or not to disable the bonus token. + Require when bonus tokens will cause corrupt KV cache for + proposal methods that require KV cache. + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. 
+ """ + super().__init__() + self._disable_bonus_tokens = disable_bonus_tokens + self._strict_mode = strict_mode + + # NOTE: A "bonus token" is accepted iff all proposal tokens are + # accepted. There is always only one possible bonus token. We store this + # value in a variable for readability. + self._num_bonus_tokens = 1 + + self.num_accepted_tokens: Optional[torch.Tensor] = None + self.num_emitted_tokens: Optional[torch.Tensor] = None + self.num_draft_tokens: int = 0 + + def init_gpu_tensors(self, rank: int) -> None: + assert self.num_accepted_tokens is None + device = f"cuda:{rank}" + self.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + self.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + + @property + def probs_dtype(self): + return torch.float32 + + @property + def token_id_dtype(self): + return torch.int64 + + def _create_output( + self, + accepted: torch.Tensor, # [batch_size, k] + recovered_token_ids: torch.Tensor, # [batch_size, k] + draft_token_ids: torch.Tensor, # [batch_size, k] + bonus_token_ids: torch.Tensor, # [batch_size] + ) -> torch.Tensor: + """Format output. Returns a matrix of token ids. When + a token is rejected via rejection sampling, all subsequent + token ids are set to -1 for the sequence. + + shape = [batch_size, k + num_bonus_tokens] + """ + batch_size, k = recovered_token_ids.shape + bonus_token_ids = bonus_token_ids.squeeze() + # Determine the index of the first False value for each row. + limits = (accepted == 0).max(1).indices + limits[~(accepted == 0).any(1)] = k + + # Create masks using the indices. + indices = torch.arange(k, device=accepted.device).unsqueeze(0) + accepted_mask = indices < limits.unsqueeze(1) + after_false_mask = indices == limits.unsqueeze(1) + + # Create an extended output tensor + output_with_bonus_tokens = -torch.ones( + (batch_size, k + self._num_bonus_tokens), + dtype=self.token_id_dtype, + device=accepted.device) + output = output_with_bonus_tokens[:, :k] + + # Fill in the first k columns of the output tensor using masks and data + # tensors. + output[:, :k] = torch.where(accepted_mask, draft_token_ids, + -torch.ones_like(draft_token_ids)) + + # Fill the last column. + # We check output directly as accepted may have True values inconsistent + # with causal acceptance. + output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, + bonus_token_ids, -1) + + # We disable bonus tokens because it causes corrupt KV cache for + # proposal methods that require KV cache. We can fix it by "prefilling" + # the bonus token in the proposer. The following issue tracks the fix. + # https://github.com/vllm-project/vllm/issues/4212 + if self._disable_bonus_tokens: + output_with_bonus_tokens[:, -1] = -1 + + # Fill the recovered token ids. 
+ output.mul_(~after_false_mask).add_( + recovered_token_ids.mul(after_false_mask)) + + self.num_accepted_tokens += accepted.sum() + self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() + self.num_draft_tokens += batch_size * k + + return output_with_bonus_tokens + + def _raise_if_incorrect_input( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + self._raise_if_incorrect_shape( + target_probs, draft_token_ids, bonus_token_ids, draft_probs) + self._raise_if_incorrect_dtype( + target_probs, draft_token_ids, bonus_token_ids, draft_probs) + self._raise_if_inconsistent_device( + target_probs, draft_token_ids, bonus_token_ids, draft_probs) + self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], + draft_token_ids, + bonus_token_ids) + + + def _raise_if_incorrect_shape( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + (target_batch_size, num_target_probs, + target_vocab_size) = target_probs.shape + + # validate the shape of draft token ids. + draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape + assert draft_token_ids_batch_size == target_batch_size + assert num_draft_token_ids == num_target_probs + + # validate the shape of bonus token ids + bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape + assert bonus_batch_size == target_batch_size + assert num_bonus_tokens == self._num_bonus_tokens + + # validate the shape of draft probs if it is set + if draft_probs is not None: + draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape + assert draft_batch_size == target_batch_size + assert num_draft_probs == num_target_probs + assert (draft_vocab_size == target_vocab_size + ), f"{draft_vocab_size=} {target_vocab_size=}" + + def _raise_if_incorrect_dtype( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + assert target_probs.dtype == self.probs_dtype + assert draft_token_ids.dtype == self.token_id_dtype + assert bonus_token_ids.dtype == self.token_id_dtype + if draft_probs is not None: + assert draft_probs.dtype == self.probs_dtype + + def _raise_if_inconsistent_device( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + devices = [ + t.device for t in + [target_probs, bonus_token_ids, draft_probs, draft_token_ids] if t is not None + ] + assert all([devices[0] == device for device in devices]) + + + def _raise_if_out_of_bounds_vocab( + self, + vocab_size: int, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + ) -> None: + assert torch.all(bonus_token_ids < vocab_size) + assert torch.all(bonus_token_ids >= 0) + assert torch.all(draft_token_ids < vocab_size) + assert torch.all(draft_token_ids >= 0) + +def get_rejection_sampler(disable_bonus_tokens: bool = True, strict_mode: bool = False): + from vllm.model_executor.layers.rejection_sampler import RejectionSampler + return RejectionSampler(disable_bonus_tokens, strict_mode) + + +def get_typical_acceptance_sampler(disable_bonus_tokens: bool = True, strict_mode: bool = False): + from vllm.model_executor.layers.typical_acceptance_sampler import ( + TypicalAcceptanceSampler) + return TypicalAcceptanceSampler(disable_bonus_tokens, 
strict_mode) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py new file mode 100644 index 0000000000000..b4aac132e10a7 --- /dev/null +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -0,0 +1,128 @@ +from functools import cached_property +from typing import Optional, Tuple + +import torch +import torch.jit +import torch.nn as nn + +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler) + + +class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module): + """Apply modified rejection sampling as described in "Accelerating Large + Language Model Decoding with Speculative Sampling" + https://arxiv.org/pdf/2302.01318.pdf. + """ + + def __init__(self, disable_bonus_tokens: bool = False, strict_mode: bool = False): + """Create a rejection sampler. + + Args: + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + """ + super().__init__() + SpecDecodeBaseSampler.__init__( + self, disable_bonus_tokens=disable_bonus_tokens, strict_mode=strict_mode) + nn.Module.__init__(self) + + def forward( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> torch.Tensor: + """Sample token ids using rejection sampling. This accepts or rejects + tokens proposed by the draft model using the probability of each token + according to the draft and target models. + + In the worst case where all draft tokens are rejected, it is guaranteed + one correct token will be emitted. + + In the case where all draft tokens are accepted, a bonus token will be + accepted as its cheap to have the target model score this speculative + sequence. + + Args: + target_probs: The probability distribution over token ids given + context according to the target model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + bonus_token_ids: The "bonus" token ids that are accepted iff all + speculative tokens in a sequence are accepted. + shape = [batch_size, num_bonus_tokens] + + draft_probs: The probability distribution over token ids given + context according to the draft model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + draft_token_ids: The token ids that were sampled from the draft + probabilities. + shape = [batch_size, num_speculative_tokens] + + Returns: + output_token_ids: The token ids sampled via rejection sampling, + or -1 if unable to sample a token because the previous token + was rejected. + shape = [batch_size, num_speculative_tokens + num_bonus_tokens] + """ + # Only perform shape/dtype/device checking in strict mode, as it adds + # overhead. 
+ if self._strict_mode: + self._raise_if_incorrect_input( + target_probs, draft_token_ids, bonus_token_ids) + accepted = self._evaluate_posterior(target_probs, draft_token_ids) + recovered_token_ids = self._replacement_token_ids(target_probs) + output_token_ids = self._create_output( + accepted, + recovered_token_ids, + draft_token_ids, bonus_token_ids + ) + print('----test input----') + print('target_probs ' + str(target_probs)) + print('draft_token_ids ' + str(draft_token_ids)) + print('recovered_token_ids ' + str(recovered_token_ids)) + print(output_token_ids) + return output_token_ids + + def _evaluate_posterior( + self, target_probs, draft_token_ids, + posterior_threshold=0.3, posterior_alpha = 0.09): + + """Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate. + Depending on the temperature value, the function either uses greedy decoding or evaluates posterior + probabilities to select the best candidate. + + Args: + - logits (torch.Tensor): Predicted logits of shape (batch_size, sequence_length, vocab_size). + - candidates (torch.Tensor): Candidate token sequences. + - temperature (float): Softmax temperature for probability scaling. A value of 0 indicates greedy decoding. + - posterior_threshold (float): Threshold for posterior probability. + - posterior_alpha (float): Scaling factor for the threshold. + - top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8. + - sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'. + - fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. Defaults to False. + Returns: + - best_candidate (torch.Tensor): Index of the chosen best candidate. + - accept_length (int): Length of the accepted candidate sequence. + """ + candidates_prob = torch.gather( + target_probs, dim=-1, index=draft_token_ids.unsqueeze(-1) + ).squeeze(-1) + posterior_entropy = -torch.sum( + target_probs * torch.log(target_probs + 1e-5), dim=-1 + ) # torch.sum(torch.log(*)) is faster than torch.prod + threshold = torch.minimum( + torch.ones_like(posterior_entropy) * posterior_threshold, + torch.exp(-posterior_entropy) * posterior_alpha, + ) + posterior_mask = candidates_prob > threshold + return posterior_mask + + def _replacement_token_ids(self, target_probs): + max_indices = torch.argmax(target_probs[:, 0, :], dim=1) + output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), dtype=self.token_id_dtype) + output[:, 0] = max_indices + return output From 77dcb792a0e66eedadc5fe169e949cf07942af99 Mon Sep 17 00:00:00 2001 From: sroy745 Date: Fri, 31 May 2024 06:04:13 +0000 Subject: [PATCH 22/37] Adding comments and some new tests. 
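This patch also documents the output-formatting contract shared by both
samplers (_create_output). For reviewers, a minimal standalone sketch of that
rule follows (illustrative only, not part of the diff; tensor names are made
up): accepted draft tokens are kept as a prefix, the first rejected position
falls back to a substitute token, every later position becomes -1, and the
bonus token is emitted only when every draft token was accepted and bonus
tokens are enabled.

    import torch

    def format_output_sketch(accepted, substitute_ids, draft_ids, bonus_ids,
                             disable_bonus_tokens=True):
        # accepted: [batch, k] bool; substitute_ids, draft_ids: [batch, k] long;
        # bonus_ids: [batch] long.
        batch, k = draft_ids.shape
        device = draft_ids.device
        out = -torch.ones((batch, k + 1), dtype=torch.long, device=device)
        # Index of the first rejection per row (k if everything was accepted).
        limits = (~accepted).int().argmax(dim=1)
        limits[accepted.all(dim=1)] = k
        pos = torch.arange(k, device=device).unsqueeze(0)
        # Keep the accepted prefix of draft tokens.
        out[:, :k] = torch.where(pos < limits.unsqueeze(1), draft_ids,
                                 out[:, :k])
        # The first rejected position falls back to the substitute token.
        out[:, :k] = torch.where(pos == limits.unsqueeze(1), substitute_ids,
                                 out[:, :k])
        # Bonus token only if the whole draft was accepted and bonus is enabled.
        if not disable_bonus_tokens:
            out[:, -1] = torch.where(accepted.all(dim=1), bonus_ids,
                                     out[:, -1])
        return out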
--- tests/samplers/test_rejection_sampler.py | 2 +- .../test_typical_acceptance_sampler.py | 128 +++++++++++++----- .../layers/rejection_sampler.py | 9 +- .../layers/spec_decode_base_sampler.py | 67 +++++---- .../layers/typical_acceptance_sampler.py | 118 +++++++++------- 5 files changed, 202 insertions(+), 122 deletions(-) diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index 6dd643bbea2bb..34bd53d0d7d4c 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -9,7 +9,7 @@ from vllm.model_executor.utils import set_random_seed CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 1) ] diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index f6ce11f500451..db1f6265bff1e 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -9,9 +9,8 @@ TypicalAcceptanceSampler) from vllm.model_executor.utils import set_random_seed -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1)] + def get_zero_temperature_prob_dist(batch_size, k, vocab_size): # Simulate temperature 0 probability distribution for target probabilities @@ -19,15 +18,16 @@ def get_zero_temperature_prob_dist(batch_size, k, vocab_size): # probability 1.0 target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) probs = torch.rand(batch_size, k, vocab_size) - _, max_indices = torch.max(probs, dim=-1) + _, zero_temperature_token_ids = torch.max(probs, dim=-1) # set the probability of the tokens with ids in max_indices to 1 and # the rest to 0. 
target_probs = torch.zeros_like(probs).scatter_( - -1, max_indices.unsqueeze(-1), 1.0) - return target_probs, max_indices + -1, zero_temperature_token_ids.unsqueeze(-1), 1.0) + return target_probs, zero_temperature_token_ids + -def get_draft_token_ids( - batch_size: int, k: int, vocab_size: int, token_ids_to_exclude: torch.Tensor): +def get_draft_token_ids(batch_size: int, k: int, vocab_size: int, + token_ids_to_exclude: torch.Tensor): # Populate draft_token_ids such that they exclude the token_ids # with probability = 1.0 draft_token_ids = torch.empty(batch_size, k, dtype=torch.long) @@ -35,12 +35,13 @@ def get_draft_token_ids( for j in range(k): # Generate a random token ID excluding max_indices[i, j] while True: - token_id = torch.randint(0, vocab_size, (1,)).item() + token_id = torch.randint(0, vocab_size, (1, )).item() if token_id != token_ids_to_exclude[i, j]: draft_token_ids[i, j] = token_id break return draft_token_ids + @pytest.mark.parametrize("k", list(range(1, 6))) @pytest.mark.parametrize("vocab_size", [30_000, 50_000]) @pytest.mark.parametrize("batch_size", list(range(1, 32))) @@ -104,7 +105,8 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, oob_token_ids[0][0] = rogue_token_id with pytest.raises(AssertionError): - typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) + typical_acceptance_sampler(target_probs, bonus_token_ids, + draft_token_ids) @pytest.mark.parametrize("seed", list(range(10))) @@ -112,7 +114,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_uniform_target_distribution_accepts_all_tokens( - seed: int, disable_bonus_tokens: bool, device: str): + seed: int, disable_bonus_tokens: bool, device: str): set_random_seed(seed) k = 3 batch_size = 5 @@ -130,24 +132,26 @@ def test_uniform_target_distribution_accepts_all_tokens( high=vocab_size, size=(batch_size, 1), dtype=torch.int64) - output_token_ids = typical_acceptance_sampler( - target_probs, bonus_token_ids, draft_token_ids) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) if disable_bonus_tokens: assert torch.all(output_token_ids[:, -1] == -1) else: assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze()) - - assert torch.all(output_token_ids[:, : k] == draft_token_ids) + + assert torch.all(output_token_ids[:, :k] == draft_token_ids) @pytest.mark.parametrize("seed", list(range(10))) @pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_temperature_zero_target_distribution( - seed: int, disable_bonus_tokens: bool, device: str): +def test_temperature_zero_target_distribution(seed: int, + disable_bonus_tokens: bool, + device: str): set_random_seed(seed) k = 3 batch_size = 5 @@ -164,27 +168,28 @@ def test_temperature_zero_target_distribution( batch_size, k, vocab_size) # Populate draft_token_ids such that they exclude the token_ids # with probability = 1.0 - draft_token_ids = get_draft_token_ids( - batch_size, k, vocab_size, zero_temperature_token_ids) + draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, + zero_temperature_token_ids) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64) - output_token_ids = typical_acceptance_sampler( - target_probs, bonus_token_ids, 
draft_token_ids) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, -1] == -1) - assert torch.all( - output_token_ids[:, 0] == zero_temperature_token_ids[:, 0]) + assert torch.all(output_token_ids[:, 0] == zero_temperature_token_ids[:, + 0]) @pytest.mark.parametrize("seed", list(range(10))) @pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_mixed_target_distribution( - seed: int, disable_bonus_tokens: bool, device: str): +def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, + device: str): set_random_seed(seed) k = 3 batch_size = 4 @@ -193,12 +198,12 @@ def test_mixed_target_distribution( typical_acceptance_sampler = TypicalAcceptanceSampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) - # For batches 0 and 2 set the distribution to an uniform distribution. For + # For batches 0 and 2 set the distribution to an uniform distribution. For # batches 1 and 3 set it to a temperature 0 distribution. - target_probs, zero_temperature_token_ids = ( - get_zero_temperature_prob_dist(batch_size, k, vocab_size)) - draft_token_ids = get_draft_token_ids( - batch_size, k, vocab_size, zero_temperature_token_ids) + target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( + batch_size, k, vocab_size)) + draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, + zero_temperature_token_ids) # Create target_probs such that only one token_id has probability 1.0 uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32) target_probs[[1, 3]] = uniform_probs @@ -206,15 +211,68 @@ def test_mixed_target_distribution( high=vocab_size, size=(batch_size, 1), dtype=torch.int64) - output_token_ids = typical_acceptance_sampler( - target_probs, bonus_token_ids, draft_token_ids) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[[0, 2], 1:] == -1) - assert ( - torch.all(output_token_ids[[0, 2], 0] == zero_temperature_token_ids[[0, 2], 0])) - assert torch.all(output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :]) + assert (torch.all(output_token_ids[[0, 2], + 0] == zero_temperature_token_ids[[0, 2], + 0])) + assert torch.all( + output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :]) if disable_bonus_tokens: assert torch.all(output_token_ids[[1, 3], -1] == -1) else: assert torch.all(output_token_ids[[1, 3], -1] != -1) + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, + device: str): + set_random_seed(seed) + k = 5 + batch_size = 1 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + # For batches 0 and 2 set the distribution to an uniform distribution. For + # batches 1 and 3 set it to a temperature 0 distribution. 
+ target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( + batch_size, k, vocab_size)) + draft_token_ids = zero_temperature_token_ids + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids) + if disable_bonus_tokens: + assert torch.all(output_token_ids[:, -1] == -1) + else: + assert torch.all(output_token_ids[:, -1] == bonus_token_ids) + # Next only keep the first 2 draft tokens same as the zero temperature + # tokens. For the remaining 3 choose some other tokens. In the + # response we will expect the first 2 tokens to be the same as the + # draft tokens and the rest as -1 + draft_token_ids_to_replace = get_draft_token_ids( + batch_size, k, vocab_size, zero_temperature_token_ids) + draft_token_ids = torch.cat( + (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2]) + assert torch.all(output_token_ids[:, -3:] == -1) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 301159ee0f20d..a4d51e948f5b0 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -23,12 +23,12 @@ def __init__(self, Require when bonus tokens will cause corrupt KV cache for proposal methods that require KV cache. strict_mode: Whether or not to perform shape/device/dtype checks - during sampling. This catches correctness issues but adds - nontrivial latency. + during sampling. This catches correctness issues but adds + nontrivial latency. """ SpecDecodeBaseSampler.__init__(self, disable_bonus_tokens, strict_mode) nn.Module.__init__(self) - + def forward( self, target_probs: torch.Tensor, @@ -76,8 +76,6 @@ def forward( self._raise_if_incorrect_input(target_probs, bonus_token_ids, draft_probs, draft_token_ids) - print('target_probs ' + str(target_probs)) - accepted, recovered_token_ids = self._batch_modified_rejection_sampling( target_probs, draft_probs, @@ -240,6 +238,7 @@ def _smallest_positive_value(self) -> float: """ return torch.finfo(self.probs_dtype).tiny + # torch.multinomial forces a GPU<->CPU sync. # Therefore, we use an optimized implementation instead that skips the sync. # Note that we always sample with replacement. diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 64a7f6474644a..8043358119496 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -8,15 +8,14 @@ class SpecDecodeBaseSampler(ABC): - """Apply modified rejection sampling as described in "Accelerating Large - Language Model Decoding with Speculative Sampling" - https://arxiv.org/pdf/2302.01318.pdf. + """Base class for samplers used for Speculative Decoding verification + step. """ def __init__(self, disable_bonus_tokens: bool = True, strict_mode: bool = False): - """Create a rejection sampler. + """Base class constructor. 
Args: disable_bonus_tokens: Whether or not to disable the bonus token. @@ -60,17 +59,29 @@ def token_id_dtype(self): def _create_output( self, accepted: torch.Tensor, # [batch_size, k] - recovered_token_ids: torch.Tensor, # [batch_size, k] + substitute_token_ids: torch.Tensor, # [batch_size, k] draft_token_ids: torch.Tensor, # [batch_size, k] bonus_token_ids: torch.Tensor, # [batch_size] ) -> torch.Tensor: """Format output. Returns a matrix of token ids. When - a token is rejected via rejection sampling, all subsequent - token ids are set to -1 for the sequence. + a token is rejected via sampling, all subsequent token ids are + set to -1 for the sequence. - shape = [batch_size, k + num_bonus_tokens] + Args: + accepted: A boolean tensor indicating if the corresponding + draft token in draft_token_ids should be accepted or not. + substitute_token_ids: A tensor of token_ids that can be used + as substitutes for the draft token ids if the proposed token + is rejected. + draft_token_ids: A tensor of token ids speculated by the + draft model. + bonus_token_ids: Token ids to use as the bonus token if + all the draft tokens are accepted. + Returns: + A tensor containing the accepted token ids. The shape of the + tensor is [batch_size, k + num_bonus_tokens] """ - batch_size, k = recovered_token_ids.shape + batch_size, k = substitute_token_ids.shape bonus_token_ids = bonus_token_ids.squeeze() # Determine the index of the first False value for each row. limits = (accepted == 0).max(1).indices @@ -108,7 +119,7 @@ def _create_output( # Fill the recovered token ids. output.mul_(~after_false_mask).add_( - recovered_token_ids.mul(after_false_mask)) + substitute_token_ids.mul(after_false_mask)) self.num_accepted_tokens += accepted.sum() self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() @@ -123,16 +134,14 @@ def _raise_if_incorrect_input( bonus_token_ids: torch.Tensor, draft_probs: Optional[torch.Tensor] = None, ) -> None: - self._raise_if_incorrect_shape( - target_probs, draft_token_ids, bonus_token_ids, draft_probs) - self._raise_if_incorrect_dtype( - target_probs, draft_token_ids, bonus_token_ids, draft_probs) - self._raise_if_inconsistent_device( - target_probs, draft_token_ids, bonus_token_ids, draft_probs) + self._raise_if_incorrect_shape(target_probs, draft_token_ids, + bonus_token_ids, draft_probs) + self._raise_if_incorrect_dtype(target_probs, draft_token_ids, + bonus_token_ids, draft_probs) + self._raise_if_inconsistent_device(target_probs, draft_token_ids, + bonus_token_ids, draft_probs) self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], - draft_token_ids, - bonus_token_ids) - + draft_token_ids, bonus_token_ids) def _raise_if_incorrect_shape( self, @@ -143,12 +152,12 @@ def _raise_if_incorrect_shape( ) -> None: (target_batch_size, num_target_probs, target_vocab_size) = target_probs.shape - + # validate the shape of draft token ids. 
draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape assert draft_token_ids_batch_size == target_batch_size assert num_draft_token_ids == num_target_probs - + # validate the shape of bonus token ids bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape assert bonus_batch_size == target_batch_size @@ -183,12 +192,12 @@ def _raise_if_inconsistent_device( draft_probs: Optional[torch.Tensor] = None, ) -> None: devices = [ - t.device for t in - [target_probs, bonus_token_ids, draft_probs, draft_token_ids] if t is not None + t.device for t in + [target_probs, bonus_token_ids, draft_probs, draft_token_ids] + if t is not None ] assert all([devices[0] == device for device in devices]) - def _raise_if_out_of_bounds_vocab( self, vocab_size: int, @@ -199,13 +208,3 @@ def _raise_if_out_of_bounds_vocab( assert torch.all(bonus_token_ids >= 0) assert torch.all(draft_token_ids < vocab_size) assert torch.all(draft_token_ids >= 0) - -def get_rejection_sampler(disable_bonus_tokens: bool = True, strict_mode: bool = False): - from vllm.model_executor.layers.rejection_sampler import RejectionSampler - return RejectionSampler(disable_bonus_tokens, strict_mode) - - -def get_typical_acceptance_sampler(disable_bonus_tokens: bool = True, strict_mode: bool = False): - from vllm.model_executor.layers.typical_acceptance_sampler import ( - TypicalAcceptanceSampler) - return TypicalAcceptanceSampler(disable_bonus_tokens, strict_mode) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index b4aac132e10a7..c914a52e2fb06 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -10,24 +10,40 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module): - """Apply modified rejection sampling as described in "Accelerating Large - Language Model Decoding with Speculative Sampling" - https://arxiv.org/pdf/2302.01318.pdf. + """Apply typical acceptance sampling as described in section 3.3.1 in + "MEDUSA: Simple LLM Inference Acceleration Framework with + Multiple Decoding Heads" + https://arxiv.org/pdf/2401.10774 """ - def __init__(self, disable_bonus_tokens: bool = False, strict_mode: bool = False): - """Create a rejection sampler. + def __init__( + self, + disable_bonus_tokens: bool = False, + strict_mode: bool = False, + posterior_threshold: float = 0.3, + ): + """Create a Typical Acceptance Sampler. Args: + disable_bonus_tokens: Whether or not to disable the bonus token. + Require when bonus tokens will cause corrupt KV cache for + proposal methods that require KV cache. strict_mode: Whether or not to perform shape/device/dtype checks - during sampling. This catches correctness issues but adds - nontrivial latency. + during sampling. This catches correctness issues but adds + nontrivial latency. + posterior_threshold : A threshold value that sets a lower bound + for the posterior probability. Default is 0.3. + + + """ super().__init__() SpecDecodeBaseSampler.__init__( - self, disable_bonus_tokens=disable_bonus_tokens, strict_mode=strict_mode) + self, + disable_bonus_tokens=disable_bonus_tokens, + strict_mode=strict_mode) nn.Module.__init__(self) - + def forward( self, target_probs: torch.Tensor, @@ -54,10 +70,6 @@ def forward( speculative tokens in a sequence are accepted. shape = [batch_size, num_bonus_tokens] - draft_probs: The probability distribution over token ids given - context according to the draft model. 
- shape = [batch_size, num_speculative_tokens, vocab_size] - draft_token_ids: The token ids that were sampled from the draft probabilities. shape = [batch_size, num_speculative_tokens] @@ -71,15 +83,14 @@ def forward( # Only perform shape/dtype/device checking in strict mode, as it adds # overhead. if self._strict_mode: - self._raise_if_incorrect_input( - target_probs, draft_token_ids, bonus_token_ids) - accepted = self._evaluate_posterior(target_probs, draft_token_ids) + self._raise_if_incorrect_input(target_probs, draft_token_ids, + bonus_token_ids) + accepted = self._evaluate_accepted_tokens(target_probs, + draft_token_ids) recovered_token_ids = self._replacement_token_ids(target_probs) - output_token_ids = self._create_output( - accepted, - recovered_token_ids, - draft_token_ids, bonus_token_ids - ) + output_token_ids = self._create_output(accepted, recovered_token_ids, + draft_token_ids, + bonus_token_ids) print('----test input----') print('target_probs ' + str(target_probs)) print('draft_token_ids ' + str(draft_token_ids)) @@ -87,42 +98,55 @@ def forward( print(output_token_ids) return output_token_ids - def _evaluate_posterior( - self, target_probs, draft_token_ids, - posterior_threshold=0.3, posterior_alpha = 0.09): + def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): + r""" + Evaluates and returns a mask of accepted tokens based on the posterior probabilities. + + Parameters: + ---------- + target_probs : torch.Tensor + A tensor of shape (batch_size, k, vocab_size) representing the probabilities + of each token in the vocabulary for each position in the proposed sequence. + This is the distribution generated by the target model. + draft_token_ids : torch.Tensor + A tensor of shape (batch_size, k) representing the proposed token ids. + + A draft token_id x_{n+k} is accepted if it satisifies the following condition - """Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate. - Depending on the temperature value, the function either uses greedy decoding or evaluates posterior - probabilities to select the best candidate. + .. math:: + p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > + \min \left( \epsilon, \delta \exp \left( -H(p_{\text{original}}(\cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) + + where :math:`p_{\text{original}}` corresponds to target_probs + + This method computes the posterior probabilities for the given draft token ids based on the + provided target probabilities. It calculates the entropy of the posterior distribution and + determines a dynamic threshold for each token position using the provided posterior threshold + and posterior alpha values. The method then returns a boolean mask indicating which tokens + have posterior probabilities exceeding the threshold. - Args: - - logits (torch.Tensor): Predicted logits of shape (batch_size, sequence_length, vocab_size). - - candidates (torch.Tensor): Candidate token sequences. - - temperature (float): Softmax temperature for probability scaling. A value of 0 indicates greedy decoding. - - posterior_threshold (float): Threshold for posterior probability. - - posterior_alpha (float): Scaling factor for the threshold. - - top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8. - - sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'. - - fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. 
Defaults to False. Returns: - - best_candidate (torch.Tensor): Index of the chosen best candidate. - - accept_length (int): Length of the accepted candidate sequence. + ------- + torch.Tensor + A boolean tensor of shape (batch_size, k) where each element indicates whether the + corresponding draft token has been accepted or rejected. + """ candidates_prob = torch.gather( - target_probs, dim=-1, index=draft_token_ids.unsqueeze(-1) - ).squeeze(-1) + target_probs, dim=-1, + index=draft_token_ids.unsqueeze(-1)).squeeze(-1) posterior_entropy = -torch.sum( - target_probs * torch.log(target_probs + 1e-5), dim=-1 - ) # torch.sum(torch.log(*)) is faster than torch.prod + target_probs * torch.log(target_probs + 1e-5), dim=-1) threshold = torch.minimum( - torch.ones_like(posterior_entropy) * posterior_threshold, - torch.exp(-posterior_entropy) * posterior_alpha, + torch.ones_like(posterior_entropy) * self._posterior_threshold, + torch.exp(-posterior_entropy) * self._posterior_alpha, ) - posterior_mask = candidates_prob > threshold - return posterior_mask + accepted_mask = candidates_prob > threshold + return accepted_mask def _replacement_token_ids(self, target_probs): max_indices = torch.argmax(target_probs[:, 0, :], dim=1) - output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), dtype=self.token_id_dtype) + output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), + dtype=self.token_id_dtype) output[:, 0] = max_indices return output From db3e6faf40cf1709bfa4bed37370e9ca014d79bb Mon Sep 17 00:00:00 2001 From: sroy745 Date: Fri, 31 May 2024 06:44:35 +0000 Subject: [PATCH 23/37] Formatting and comments. --- .../layers/spec_decode_base_sampler.py | 10 ++-- .../layers/typical_acceptance_sampler.py | 49 ++++++++----------- 2 files changed, 25 insertions(+), 34 deletions(-) diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 8043358119496..95f8e44f2f2de 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -1,13 +1,10 @@ -from abc import ABC, abstractmethod -from functools import cached_property -from typing import Optional, Tuple +from typing import Optional import torch import torch.jit -import torch.nn as nn -class SpecDecodeBaseSampler(ABC): +class SpecDecodeBaseSampler(): """Base class for samplers used for Speculative Decoding verification step. """ @@ -165,7 +162,8 @@ def _raise_if_incorrect_shape( # validate the shape of draft probs if it is set if draft_probs is not None: - draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape + (draft_batch_size, num_draft_probs, + draft_vocab_size) = draft_probs.shape assert draft_batch_size == target_batch_size assert num_draft_probs == num_target_probs assert (draft_vocab_size == target_vocab_size diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index c914a52e2fb06..64c6f36bc3322 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -1,6 +1,3 @@ -from functools import cached_property -from typing import Optional, Tuple - import torch import torch.jit import torch.nn as nn @@ -33,9 +30,8 @@ def __init__( nontrivial latency. posterior_threshold : A threshold value that sets a lower bound for the posterior probability. Default is 0.3. 
- - - + posterior_alpha : A threshold value that sets a lower bound + for the posterior probability. Default is 0.3. """ super().__init__() SpecDecodeBaseSampler.__init__( @@ -91,45 +87,42 @@ def forward( output_token_ids = self._create_output(accepted, recovered_token_ids, draft_token_ids, bonus_token_ids) - print('----test input----') - print('target_probs ' + str(target_probs)) - print('draft_token_ids ' + str(draft_token_ids)) - print('recovered_token_ids ' + str(recovered_token_ids)) - print(output_token_ids) return output_token_ids def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): r""" - Evaluates and returns a mask of accepted tokens based on the posterior probabilities. + Evaluates and returns a mask of accepted tokens based on the + posterior probabilities. Parameters: ---------- target_probs : torch.Tensor - A tensor of shape (batch_size, k, vocab_size) representing the probabilities - of each token in the vocabulary for each position in the proposed sequence. - This is the distribution generated by the target model. + A tensor of shape (batch_size, k, vocab_size) representing + the probabilities of each token in the vocabulary for each + position in the proposed sequence. This is the distribution + generated by the target model. draft_token_ids : torch.Tensor - A tensor of shape (batch_size, k) representing the proposed token ids. + A tensor of shape (batch_size, k) representing the proposed + token ids. - A draft token_id x_{n+k} is accepted if it satisifies the following condition - + A draft token_id x_{n+k} is accepted if it satisifies the + following condition + .. math:: p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > - \min \left( \epsilon, \delta \exp \left( -H(p_{\text{original}}(\cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) - - where :math:`p_{\text{original}}` corresponds to target_probs + \min \left( \epsilon, \delta * \exp \left( + -H(p_{\text{original}}(\cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) - This method computes the posterior probabilities for the given draft token ids based on the - provided target probabilities. It calculates the entropy of the posterior distribution and - determines a dynamic threshold for each token position using the provided posterior threshold - and posterior alpha values. The method then returns a boolean mask indicating which tokens - have posterior probabilities exceeding the threshold. + where :math:`p_{\text{original}}` corresponds to target_probs + and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters + specified using self._posterior_threshold_ and self._posterior_alpha_ Returns: ------- torch.Tensor - A boolean tensor of shape (batch_size, k) where each element indicates whether the - corresponding draft token has been accepted or rejected. + A boolean tensor of shape (batch_size, k) where each element + indicates whether the corresponding draft token has been accepted + or rejected. """ candidates_prob = torch.gather( From 7d876d414618494fa2d25863446e8f35e362a617 Mon Sep 17 00:00:00 2001 From: sroy745 Date: Fri, 31 May 2024 18:36:40 +0000 Subject: [PATCH 24/37] Added comments and some more tests. 
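The comments added here spell out the typical acceptance criterion: a draft
token at position t is accepted iff

    p_target(token_t | context) >
        min(posterior_threshold,
            posterior_alpha * exp(-H(p_target(. | context))))

where H is the entropy of the target distribution at that position. With a
uniform (high-entropy) target distribution the threshold collapses towards
zero, so essentially every draft token is accepted; with a temperature-zero
(one-hot) target distribution the threshold stays at posterior_threshold, so
only draft tokens matching the probability-1.0 token pass. The comments added
to the tests describe exactly these two behaviours. A small self-contained
sketch of the check (illustrative only, not part of the diff; the defaults
mirror the values used in this patch):

    import torch

    def typical_acceptance_mask(target_probs, draft_token_ids,
                                posterior_threshold=0.09,
                                posterior_alpha=0.3):
        # target_probs: [batch, k, vocab]; draft_token_ids: [batch, k].
        # Probability the target model assigns to each proposed draft token.
        candidate_prob = torch.gather(
            target_probs, dim=-1,
            index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
        # Entropy of the target distribution at each position.
        entropy = -torch.sum(target_probs * torch.log(target_probs + 1e-5),
                             dim=-1)
        threshold = torch.minimum(
            torch.full_like(entropy, posterior_threshold),
            posterior_alpha * torch.exp(-entropy))
        return candidate_prob > threshold  # [batch, k] bool mask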
--- tests/samplers/test_rejection_sampler.py | 2 +- .../test_typical_acceptance_sampler.py | 78 +++++++++++++++---- .../layers/typical_acceptance_sampler.py | 51 ++++++++++-- 3 files changed, 110 insertions(+), 21 deletions(-) diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index 34bd53d0d7d4c..6dd643bbea2bb 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -9,7 +9,7 @@ from vllm.model_executor.utils import set_random_seed CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 1) + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index db1f6265bff1e..07f9498421f1b 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -1,9 +1,7 @@ """Tests for rejection sampling.""" -from typing import List, Tuple import pytest import torch -import torch.nn.functional as F from vllm.model_executor.layers.typical_acceptance_sampler import ( TypicalAcceptanceSampler) @@ -13,14 +11,22 @@ def get_zero_temperature_prob_dist(batch_size, k, vocab_size): + """ + Generates a fake temperature zero probablity distribution. + Returns: + 1. A fake temperature zero probablity distribution of shape + [batch_size, k, vocab_size] + 2. Tensor of shape [batch_size, k] containing the token ids + of the probability 1.0 tokens at each position. + """ # Simulate temperature 0 probability distribution for target probabilities # and create target probabilities such that only 1 token id has # probability 1.0 target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) probs = torch.rand(batch_size, k, vocab_size) _, zero_temperature_token_ids = torch.max(probs, dim=-1) - # set the probability of the tokens with ids in max_indices to 1 and - # the rest to 0. + # set the probability of the tokens with ids in zero_temperature_token_ids + # to 1 and the rest to 0. target_probs = torch.zeros_like(probs).scatter_( -1, zero_temperature_token_ids.unsqueeze(-1), 1.0) return target_probs, zero_temperature_token_ids @@ -28,12 +34,16 @@ def get_zero_temperature_prob_dist(batch_size, k, vocab_size): def get_draft_token_ids(batch_size: int, k: int, vocab_size: int, token_ids_to_exclude: torch.Tensor): - # Populate draft_token_ids such that they exclude the token_ids - # with probability = 1.0 + """ + Returns a tensor of shape [batch_size, k] of fake draft token ids + drawn randomly from a vocab of size vocab_size. We however ensure + that token_ids from token_ids_to_exclude are excluded at the + corresponding positions. + """ draft_token_ids = torch.empty(batch_size, k, dtype=torch.long) for i in range(batch_size): for j in range(k): - # Generate a random token ID excluding max_indices[i, j] + # Generate a random token ID excluding token_ids_to_exclude[i, j] while True: token_id = torch.randint(0, vocab_size, (1, )).item() if token_id != token_ids_to_exclude[i, j]: @@ -61,7 +71,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, high=vocab_size, size=(batch_size, k), dtype=torch.int64) - + # Verify that sampling succeeds for all cases. 
typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) @@ -87,6 +97,8 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, high=vocab_size, size=(batch_size, k), dtype=torch.int64) + # Verify that appropriate exceptions are thrown for out + # of bound vocabs. oob_token_ids = None if which_token_ids == "bonus_token_ids": oob_token_ids = bonus_token_ids @@ -135,6 +147,9 @@ def test_uniform_target_distribution_accepts_all_tokens( output_token_ids = typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) + # We are using a uniform target probability distribution. + # For a uniform distribution the entropy is very high and it + # should lead to all draft tokens being accepted. Verify that. assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) if disable_bonus_tokens: @@ -174,6 +189,11 @@ def test_temperature_zero_target_distribution(seed: int, high=vocab_size, size=(batch_size, 1), dtype=torch.int64) + # The target probaility distribution is a temperature zero distribution + # with zero entroy. Since our draft token ids don't match the probability + # 1.0 tokens in the target distribution we will reject all of them and + # fallback to the greedy sampling for selecting 1 token for each sequence. + # Verify the same. output_token_ids = typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) @@ -198,13 +218,13 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, typical_acceptance_sampler = TypicalAcceptanceSampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) - # For batches 0 and 2 set the distribution to an uniform distribution. For - # batches 1 and 3 set it to a temperature 0 distribution. + # For sequences 0 and 2 set the distribution to a temperature + # zero distribution. For sequences 1 and 3 set it to a uniform + # distribution. target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( batch_size, k, vocab_size)) draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, zero_temperature_token_ids) - # Create target_probs such that only one token_id has probability 1.0 uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32) target_probs[[1, 3]] = uniform_probs bonus_token_ids = torch.randint(low=0, @@ -214,12 +234,19 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, output_token_ids = typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) + # verify the shape of output_token_ids assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) + # For sequences 0 and 2 verify that only 1 token is accepted + # which is the token with probability 1.0 in the target distribution + # at position 0. assert torch.all(output_token_ids[[0, 2], 1:] == -1) assert (torch.all(output_token_ids[[0, 2], 0] == zero_temperature_token_ids[[0, 2], 0])) + # For sequences 1 and 3 verify that all tokens are accepted since the + # target probability distribution is uniform. In addition verify that + # if disable_bonus_tokens is false then we also accept the bonus tokens. 
assert torch.all( output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :]) if disable_bonus_tokens: @@ -242,8 +269,9 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, typical_acceptance_sampler = TypicalAcceptanceSampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) - # For batches 0 and 2 set the distribution to an uniform distribution. For - # batches 1 and 3 set it to a temperature 0 distribution. + # Create a temperature zero target probability distribution and ensure + # all draft token ids correspond to the tokens with 1.0 probability. + # Verify that all of them are accepted. target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( batch_size, k, vocab_size)) draft_token_ids = zero_temperature_token_ids @@ -276,3 +304,27 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2]) assert torch.all(output_token_ids[:, -3:] == -1) + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_replacement_token_ids( + seed: int, disable_bonus_tokens: bool, device: str): + set_random_seed(seed) + k = 10 + batch_size = 5 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + expected_replacement_tokens = -torch.ones( + (batch_size, k), dtype=torch.long) + expected_replacement_tokens[:, 0] = torch.argmax( + target_probs[:, 0, :], dim=1) + actual_replacement_tokens = ( + typical_acceptance_sampler._replacement_token_ids(target_probs)) + assert torch.all(expected_replacement_tokens == actual_replacement_tokens) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 64c6f36bc3322..edfd1faf9cb2f 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -17,7 +17,8 @@ def __init__( self, disable_bonus_tokens: bool = False, strict_mode: bool = False, - posterior_threshold: float = 0.3, + posterior_threshold: float = 0.09, + posterior_alpha: float = 0.3, ): """Create a Typical Acceptance Sampler. @@ -29,10 +30,14 @@ def __init__( during sampling. This catches correctness issues but adds nontrivial latency. posterior_threshold : A threshold value that sets a lower bound - for the posterior probability. Default is 0.3. - posterior_alpha : A threshold value that sets a lower bound - for the posterior probability. Default is 0.3. + on the posterior probability of a token in target model for it + to be accepted. Default is 0.09 + posterior_alpha : A scaling factor for the entropy-based + threshold in typical acceptance sampling. Typically defaults to + sqrt of posterior_threshold and is set to 0.3. """ + self._posterior_threshold = posterior_threshold + self._posterior_alpha = posterior_alpha super().__init__() SpecDecodeBaseSampler.__init__( self, @@ -111,18 +116,27 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): .. 
math:: p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > \min \left( \epsilon, \delta * \exp \left( - -H(p_{\text{original}}(\cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) + -H(p_{\text{original}}( + \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) where :math:`p_{\text{original}}` corresponds to target_probs and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters - specified using self._posterior_threshold_ and self._posterior_alpha_ + specified using self._posterior_threshold and self._posterior_alpha + + This method computes the posterior probabilities for the given + draft token ids based on the provided target probabilities. It + calculates the entropy of the posterior distribution and determines + a dynamic threshold for each token position using the provided + posterior_threshold and posterior_alpha values. The method then + returns a boolean mask indicating which tokens can be accepted. Returns: ------- torch.Tensor A boolean tensor of shape (batch_size, k) where each element indicates whether the corresponding draft token has been accepted - or rejected. + or rejected. True indicates acceptance and false indicates + rejection. """ candidates_prob = torch.gather( @@ -138,6 +152,29 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): return accepted_mask def _replacement_token_ids(self, target_probs): + """ + Generate one replacement token ID for each sequence based on target + probabilities. The replacement token is used as the fallback option + if typical acceptance sampling does not accept any draft tokens for + that particular sequence. + + This method computes the token IDs to be replaced by selecting the + token with the highest probability for each sequence in the first + position. The rest of the output is filled with -1. + + Parameters + ---------- + target_probs : torch.Tensor + A tensor of shape (batch_size, k, vocab_size) containing + the target probability distribution + + Returns + ------- + torch.Tensor + A tensor of shape (batch_size, k) with the replacement + token IDs. Only the first column is set, and the rest of the + columns are filled with -1. + """ max_indices = torch.argmax(target_probs[:, 0, :], dim=1) output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), dtype=self.token_id_dtype) From 78d011bf83e8e5f324913cec40966bdd9aa3113a Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 31 May 2024 18:52:38 +0000 Subject: [PATCH 25/37] Dummy commit --- vllm/model_executor/layers/spec_decode_base_sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 95f8e44f2f2de..4cd0b92cfcbae 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -13,7 +13,6 @@ def __init__(self, disable_bonus_tokens: bool = True, strict_mode: bool = False): """Base class constructor. - Args: disable_bonus_tokens: Whether or not to disable the bonus token. 
Require when bonus tokens will cause corrupt KV cache for From f67ffcdbbf70c83ec7dba7b51d9b2b34f8d455c7 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 31 May 2024 19:09:19 +0000 Subject: [PATCH 26/37] Reverting change to llm_engine_example.py --- examples/llm_engine_example.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/llm_engine_example.py b/examples/llm_engine_example.py index a92777d92da39..a81c4b3e399c3 100644 --- a/examples/llm_engine_example.py +++ b/examples/llm_engine_example.py @@ -1,5 +1,4 @@ import argparse - from typing import List, Tuple from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams From 1bdbe09bcdedeee043674ed3b032b96b7c4d1501 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 31 May 2024 23:09:12 +0000 Subject: [PATCH 27/37] Updating a comment --- .../layers/typical_acceptance_sampler.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index edfd1faf9cb2f..79c3358cea1e0 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -51,16 +51,15 @@ def forward( bonus_token_ids: torch.Tensor, draft_token_ids: torch.Tensor, ) -> torch.Tensor: - """Sample token ids using rejection sampling. This accepts or rejects - tokens proposed by the draft model using the probability of each token - according to the draft and target models. + """Sample token ids using typical acceptance sampling. This accepts + or rejects tokens proposed by the draft model using the probability + of each token according to the draft and target models. In the worst case where all draft tokens are rejected, it is guaranteed - one correct token will be emitted. + one token will be emitted. - In the case where all draft tokens are accepted, a bonus token will be - accepted as its cheap to have the target model score this speculative - sequence. + In the case where all draft tokens are accepted, the bonus token will be + accepted conditioned on self._disable_bonus_tokens being false. 
Args: target_probs: The probability distribution over token ids given From 07f7b9c8d205c1afd0bac02e7ed8a86c96dc9e1b Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 3 Jun 2024 17:47:56 +0000 Subject: [PATCH 28/37] Dummy commit --- vllm/model_executor/layers/typical_acceptance_sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 79c3358cea1e0..76d14f63762f4 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -12,7 +12,6 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module): Multiple Decoding Heads" https://arxiv.org/pdf/2401.10774 """ - def __init__( self, disable_bonus_tokens: bool = False, From 74782dc1eccdef2b032c8ecca3ce30d33c34a04a Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 3 Jun 2024 22:44:30 +0000 Subject: [PATCH 29/37] Fix device for tensors --- .../layers/typical_acceptance_sampler.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 76d14f63762f4..b3ebfe1e4ee82 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -137,13 +137,16 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): rejection. """ + device = target_probs.device candidates_prob = torch.gather( target_probs, dim=-1, - index=draft_token_ids.unsqueeze(-1)).squeeze(-1) + index=draft_token_ids.unsqueeze(-1).to(device)).squeeze(-1) posterior_entropy = -torch.sum( target_probs * torch.log(target_probs + 1e-5), dim=-1) threshold = torch.minimum( - torch.ones_like(posterior_entropy) * self._posterior_threshold, + torch.ones_like( + posterior_entropy, + device=device) * self._posterior_threshold, torch.exp(-posterior_entropy) * self._posterior_alpha, ) accepted_mask = candidates_prob > threshold @@ -175,6 +178,7 @@ def _replacement_token_ids(self, target_probs): """ max_indices = torch.argmax(target_probs[:, 0, :], dim=1) output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), - dtype=self.token_id_dtype) + dtype=self.token_id_dtype).to( + target_probs.device) output[:, 0] = max_indices return output From 7bcbbdb35b73819b091bc18ffa4e49af217d4796 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Thu, 6 Jun 2024 05:59:20 +0000 Subject: [PATCH 30/37] Fixing review comments --- .../test_typical_acceptance_sampler.py | 73 ++++++++++++++++++- .../layers/spec_decode_base_sampler.py | 1 - .../layers/typical_acceptance_sampler.py | 11 ++- 3 files changed, 77 insertions(+), 8 deletions(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 07f9498421f1b..b1b84d00f7cf0 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -59,6 +59,10 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int, @torch.inference_mode() def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, device: str): + """ + Tests that the TypicalAcceptancSampler forward succeeds for + different combinations of k, vocab_size, batch_size and num devices. 
+ """ torch.set_default_device(device) typical_acceptance_sampler = TypicalAcceptanceSampler() typical_acceptance_sampler.init_gpu_tensors(rank=0) @@ -82,11 +86,16 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, @torch.inference_mode() def test_raises_when_vocab_oob(above_or_below_vocab_range: str, which_token_ids: str, device: str): + """ + Tests that we throw an exception of the token ids fall outside + the bound of the provided vocabulary. + """ k = 3 batch_size = 5 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True) typical_acceptance_sampler.init_gpu_tensors(rank=0) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) bonus_token_ids = torch.randint(low=0, @@ -127,6 +136,17 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, @torch.inference_mode() def test_uniform_target_distribution_accepts_all_tokens( seed: int, disable_bonus_tokens: bool, device: str): + """ + Test the TypicalAcceptanceSampler with a uniform target probability + distribution. + + This test verifies that when provided with a uniform target probability + distribution, the TypicalAcceptanceSampler accepts all draft tokens. The + entropy of the uniform target distribution being high should lead to all + draft tokens being accepted. The test also ensures that the behavior + regarding bonus tokens is consistent with the `disable_bonus_tokens` + flag. + """ set_random_seed(seed) k = 3 batch_size = 5 @@ -167,6 +187,17 @@ def test_uniform_target_distribution_accepts_all_tokens( def test_temperature_zero_target_distribution(seed: int, disable_bonus_tokens: bool, device: str): + """ + Test the TypicalAcceptanceSampler with a zero-temperature target + probability distribution. + + This test verifies that when using a zero-temperature target probability + distribution, where only one token has a probability of 1.0, the + TypicalAcceptanceSampler correctly rejects all draft tokens that do not + match this probability. Additionally, it ensures that when all draft + tokens are rejected, the sampler falls back to greedy sampling to select a + single token from the target distribution. + """ set_random_seed(seed) k = 3 batch_size = 5 @@ -210,6 +241,22 @@ def test_temperature_zero_target_distribution(seed: int, @torch.inference_mode() def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, device: str): + """ + Test the TypicalAcceptanceSampler with a mixed target probability + distribution. + + This test ensures that the TypicalAcceptanceSampler handles a mixed + target probability distribution correctly. Specifically, it uses a + zero-temperature distribution for some sequences and a uniform + distribution for others. The test verifies that: + + - For sequences with a zero-temperature distribution, only the token + with a probability of 1.0 is accepted, and all other tokens are rejected. + - For sequences with a uniform distribution, all draft tokens are + accepted. + - When `disable_bonus_tokens` is False, the bonus tokens are also accepted + for sequences with a uniform distribution. 
+ """ set_random_seed(seed) k = 3 batch_size = 4 @@ -261,6 +308,20 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, @torch.inference_mode() def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, device: str): + """ + Test the TypicalAcceptanceSampler's behavior when only a subset of draft + tokens should be accepted. + + This test verifies that the TypicalAcceptanceSampler correctly accepts or + rejects draft tokens based on a zero-temperature target probability + distribution. Specifically, it ensures that: + + - When all draft tokens match tokens with a probability of 1.0 in the + target distribution, all draft tokens are accepted. + - When only some draft tokens match tokens with a probability of 1.0 in + the target distribution, only those matching tokens are accepted, and the + rest are rejected. + """ set_random_seed(seed) k = 5 batch_size = 1 @@ -312,6 +373,16 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, @torch.inference_mode() def test_replacement_token_ids( seed: int, disable_bonus_tokens: bool, device: str): + """ + Test the TypicalAcceptanceSampler's method for generating + replacement token IDs. + + This test verifies that the `_replacement_token_ids` method of the + TypicalAcceptanceSampler correctly identifies the token IDs to be used + as replacements based on the target probability distribution. + Specifically, it ensures that the method correctly identifies the + tokens with the highest probability for each sequence in the batch. + """ set_random_seed(seed) k = 10 batch_size = 5 diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 4cd0b92cfcbae..9856a7e7ddea0 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -1,7 +1,6 @@ from typing import Optional import torch -import torch.jit class SpecDecodeBaseSampler(): diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index b3ebfe1e4ee82..3bdb72948a360 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -35,14 +35,13 @@ def __init__( threshold in typical acceptance sampling. Typically defaults to sqrt of posterior_threshold and is set to 0.3. 
""" - self._posterior_threshold = posterior_threshold - self._posterior_alpha = posterior_alpha - super().__init__() SpecDecodeBaseSampler.__init__( self, disable_bonus_tokens=disable_bonus_tokens, strict_mode=strict_mode) nn.Module.__init__(self) + self._posterior_threshold = posterior_threshold + self._posterior_alpha = posterior_alpha def forward( self, @@ -140,7 +139,7 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): device = target_probs.device candidates_prob = torch.gather( target_probs, dim=-1, - index=draft_token_ids.unsqueeze(-1).to(device)).squeeze(-1) + index=draft_token_ids.unsqueeze(-1)).squeeze(-1) posterior_entropy = -torch.sum( target_probs * torch.log(target_probs + 1e-5), dim=-1) threshold = torch.minimum( @@ -178,7 +177,7 @@ def _replacement_token_ids(self, target_probs): """ max_indices = torch.argmax(target_probs[:, 0, :], dim=1) output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), - dtype=self.token_id_dtype).to( - target_probs.device) + dtype=self.token_id_dtype, + device=target_probs.device) output[:, 0] = max_indices return output From f0c2b79d5bfef6099ce25be553d3b087ae94cb13 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 7 Jun 2024 06:41:57 +0000 Subject: [PATCH 31/37] Addressing comments --- vllm/model_executor/layers/typical_acceptance_sampler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 3bdb72948a360..3fb33d9c2d375 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -5,7 +5,6 @@ from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) - class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module): """Apply typical acceptance sampling as described in section 3.3.1 in "MEDUSA: Simple LLM Inference Acceleration Framework with @@ -140,8 +139,11 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): candidates_prob = torch.gather( target_probs, dim=-1, index=draft_token_ids.unsqueeze(-1)).squeeze(-1) + # A small constant added to prevent computing the logarithm of zero, + # which can lead to undefined values. 
+ epsilon = 1e-5 posterior_entropy = -torch.sum( - target_probs * torch.log(target_probs + 1e-5), dim=-1) + target_probs * torch.log(target_probs + epsilon), dim=-1) threshold = torch.minimum( torch.ones_like( posterior_entropy, From 0b7bc75104fac1c6438bd97b99827e2e2f01d77f Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 7 Jun 2024 07:29:37 +0000 Subject: [PATCH 32/37] Add a new test for non default posteriors --- .../test_typical_acceptance_sampler.py | 63 +++++++++++++++++++ .../layers/spec_decode_base_sampler.py | 2 +- 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index b1b84d00f7cf0..db5406e598afb 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -367,6 +367,69 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, assert torch.all(output_token_ids[:, -3:] == -1) +@pytest.mark.parametrize("seed", list(range(1))) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_accept_tokens_set_non_default_posteriors( + seed: int, disable_bonus_tokens: bool, device: str): + """ + Test the TypicalAcceptanceSampler's behavior when only a subset of draft + tokens should be accepted. + + This test verifies that the TypicalAcceptanceSampler correctly accepts or + rejects draft tokens based on a zero-temperature target probability + distribution. Specifically, it ensures that: + + - When all draft tokens match tokens with a probability of 1.0 in the + target distribution, all draft tokens are accepted. + - When only some draft tokens match tokens with a probability of 1.0 in + the target distribution, only those matching tokens are accepted, and the + rest are rejected. + """ + set_random_seed(seed) + k = 5 + batch_size = 1 + vocab_size = 30_000 + torch.set_default_device(device) + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + # Create a temperature zero target probability distribution and ensure + # all draft token ids correspond to the tokens with 1.0 probability. + # Verify that all of them are accepted. 
+ target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( + batch_size, k, vocab_size)) + target_probs[target_probs == 0] = 0.00001 + draft_token_ids = get_draft_token_ids( + batch_size, k, vocab_size, zero_temperature_token_ids) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[:, 1:-1] == -1) + + typical_acceptance_sampler = TypicalAcceptanceSampler( + strict_mode=True, disable_bonus_tokens=disable_bonus_tokens, + posterior_threshold=0.0, posterior_alpha=0.0) + typical_acceptance_sampler.init_gpu_tensors(rank=0) + output_token_ids = typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_token_ids) + assert output_token_ids.shape[0] == batch_size + assert output_token_ids.shape[1] == (k + 1) + assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids) + if disable_bonus_tokens: + assert torch.all(output_token_ids[:, -1] == -1) + else: + assert torch.all(output_token_ids[:, -1] == bonus_token_ids) + + @pytest.mark.parametrize("seed", list(range(10))) @pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 9856a7e7ddea0..06917a44cb6d4 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -115,7 +115,7 @@ def _create_output( # Fill the recovered token ids. output.mul_(~after_false_mask).add_( substitute_token_ids.mul(after_false_mask)) - + self.num_accepted_tokens += accepted.sum() self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() self.num_draft_tokens += batch_size * k From 197ded69348a93a30d3189f78a9d1f7d0cf042ec Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 7 Jun 2024 18:06:43 +0000 Subject: [PATCH 33/37] Documentation for test --- .../test_typical_acceptance_sampler.py | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index db5406e598afb..757e07dcd77ce 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -374,18 +374,10 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, def test_accept_tokens_set_non_default_posteriors( seed: int, disable_bonus_tokens: bool, device: str): """ - Test the TypicalAcceptanceSampler's behavior when only a subset of draft - tokens should be accepted. - - This test verifies that the TypicalAcceptanceSampler correctly accepts or - rejects draft tokens based on a zero-temperature target probability - distribution. Specifically, it ensures that: - - - When all draft tokens match tokens with a probability of 1.0 in the - target distribution, all draft tokens are accepted. - - When only some draft tokens match tokens with a probability of 1.0 in - the target distribution, only those matching tokens are accepted, and the - rest are rejected. + Test the TypicalAcceptanceSampler with custom posterior thresholds and + alpha values. 
This test verifies that by modifying the posterior + thresholds and alpha values we can change the acceptance behavior of the + sampler. """ set_random_seed(seed) k = 5 @@ -395,9 +387,12 @@ def test_accept_tokens_set_non_default_posteriors( typical_acceptance_sampler = TypicalAcceptanceSampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) - # Create a temperature zero target probability distribution and ensure - # all draft token ids correspond to the tokens with 1.0 probability. - # Verify that all of them are accepted. + # Simulate temperature 0 probability distribution for target + # probabilities and create target probabilities such that only 1 token + # id has probability 1.0 and others have a very low probability of + # 0.00001. Populate draft_token_ids such that they exclude the token_ids + # with probability = 1.0. Without any changes to the posterior thresholds + # none of the draft tokens are accepted. target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( batch_size, k, vocab_size)) target_probs[target_probs == 0] = 0.00001 @@ -414,6 +409,9 @@ def test_accept_tokens_set_non_default_posteriors( assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, 1:-1] == -1) + # Change the posterior threshold values to 0.0 so that we will + # now accept even draft tokens with very low probability in the + # target distribution. Simulate and verify the same. typical_acceptance_sampler = TypicalAcceptanceSampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens, posterior_threshold=0.0, posterior_alpha=0.0) From bd435df2705507732d7a94804a8f022988610fba Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 7 Jun 2024 18:21:15 +0000 Subject: [PATCH 34/37] Fix ruff errors --- vllm/model_executor/layers/rejection_sampler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index a4d51e948f5b0..be87d9d92d7c5 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -1,11 +1,12 @@ from functools import cached_property -from typing import Optional, Tuple +from typing import Tuple import torch import torch.jit import torch.nn as nn -from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler - +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler +) class RejectionSampler(SpecDecodeBaseSampler, nn.Module): """Apply modified rejection sampling as described in "Accelerating Large From d4c750e70cf53c2078492fbae397234050ad1fef Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 7 Jun 2024 18:31:24 +0000 Subject: [PATCH 35/37] Fix spell corrections --- vllm/model_executor/layers/typical_acceptance_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 3fb33d9c2d375..521de7c21724e 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -106,7 +106,7 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): A tensor of shape (batch_size, k) representing the proposed token ids. 
- A draft token_id x_{n+k} is accepted if it satisifies the + A draft token_id x_{n+k} is accepted if it satisfies the following condition .. math:: From 2c2004febb82f2a24c35e30af8396960560a4d3c Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 7 Jun 2024 18:33:12 +0000 Subject: [PATCH 36/37] Fixing spell errors --- tests/samplers/test_typical_acceptance_sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 757e07dcd77ce..341901d4f6c3c 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -12,9 +12,9 @@ def get_zero_temperature_prob_dist(batch_size, k, vocab_size): """ - Generates a fake temperature zero probablity distribution. + Generates a fake temperature zero probability distribution. Returns: - 1. A fake temperature zero probablity distribution of shape + 1. A fake temperature zero probability distribution of shape [batch_size, k, vocab_size] 2. Tensor of shape [batch_size, k] containing the token ids of the probability 1.0 tokens at each position. From 85c48f5e9d69f61cf0180272da84251b3bd3bdb9 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 7 Jun 2024 18:39:56 +0000 Subject: [PATCH 37/37] Ran format.sh --- .../test_typical_acceptance_sampler.py | 30 ++++++++++--------- .../layers/rejection_sampler.py | 5 ++-- .../layers/spec_decode_base_sampler.py | 2 +- .../layers/typical_acceptance_sampler.py | 7 +++-- 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 341901d4f6c3c..87cf37bc926bc 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -94,8 +94,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, batch_size = 5 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler( - strict_mode=True) + typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True) typical_acceptance_sampler.init_gpu_tensors(rank=0) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) bonus_token_ids = torch.randint(low=0, @@ -371,8 +370,9 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, @pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_accept_tokens_set_non_default_posteriors( - seed: int, disable_bonus_tokens: bool, device: str): +def test_accept_tokens_set_non_default_posteriors(seed: int, + disable_bonus_tokens: bool, + device: str): """ Test the TypicalAcceptanceSampler with custom posterior thresholds and alpha values. This test verifies that by modifying the posterior @@ -391,13 +391,13 @@ def test_accept_tokens_set_non_default_posteriors( # probabilities and create target probabilities such that only 1 token # id has probability 1.0 and others have a very low probability of # 0.00001. Populate draft_token_ids such that they exclude the token_ids - # with probability = 1.0. Without any changes to the posterior thresholds + # with probability = 1.0. Without any changes to the posterior thresholds # none of the draft tokens are accepted. 
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( batch_size, k, vocab_size)) target_probs[target_probs == 0] = 0.00001 - draft_token_ids = get_draft_token_ids( - batch_size, k, vocab_size, zero_temperature_token_ids) + draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, + zero_temperature_token_ids) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -413,8 +413,10 @@ def test_accept_tokens_set_non_default_posteriors( # now accept even draft tokens with very low probability in the # target distribution. Simulate and verify the same. typical_acceptance_sampler = TypicalAcceptanceSampler( - strict_mode=True, disable_bonus_tokens=disable_bonus_tokens, - posterior_threshold=0.0, posterior_alpha=0.0) + strict_mode=True, + disable_bonus_tokens=disable_bonus_tokens, + posterior_threshold=0.0, + posterior_alpha=0.0) typical_acceptance_sampler.init_gpu_tensors(rank=0) output_token_ids = typical_acceptance_sampler(target_probs, bonus_token_ids, @@ -432,8 +434,8 @@ def test_accept_tokens_set_non_default_posteriors( @pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_replacement_token_ids( - seed: int, disable_bonus_tokens: bool, device: str): +def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool, + device: str): """ Test the TypicalAcceptanceSampler's method for generating replacement token IDs. @@ -455,8 +457,8 @@ def test_replacement_token_ids( target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) expected_replacement_tokens = -torch.ones( (batch_size, k), dtype=torch.long) - expected_replacement_tokens[:, 0] = torch.argmax( - target_probs[:, 0, :], dim=1) + expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :], + dim=1) actual_replacement_tokens = ( - typical_acceptance_sampler._replacement_token_ids(target_probs)) + typical_acceptance_sampler._replacement_token_ids(target_probs)) assert torch.all(expected_replacement_tokens == actual_replacement_tokens) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index be87d9d92d7c5..fe9b2fac1117e 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -4,9 +4,10 @@ import torch import torch.jit import torch.nn as nn + from vllm.model_executor.layers.spec_decode_base_sampler import ( - SpecDecodeBaseSampler -) + SpecDecodeBaseSampler) + class RejectionSampler(SpecDecodeBaseSampler, nn.Module): """Apply modified rejection sampling as described in "Accelerating Large diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 06917a44cb6d4..9856a7e7ddea0 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -115,7 +115,7 @@ def _create_output( # Fill the recovered token ids. 
output.mul_(~after_false_mask).add_( substitute_token_ids.mul(after_false_mask)) - + self.num_accepted_tokens += accepted.sum() self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() self.num_draft_tokens += batch_size * k diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 521de7c21724e..f12d6a03b4d16 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -5,12 +5,14 @@ from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) + class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module): """Apply typical acceptance sampling as described in section 3.3.1 in "MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads" https://arxiv.org/pdf/2401.10774 """ + def __init__( self, disable_bonus_tokens: bool = False, @@ -145,9 +147,8 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): posterior_entropy = -torch.sum( target_probs * torch.log(target_probs + epsilon), dim=-1) threshold = torch.minimum( - torch.ones_like( - posterior_entropy, - device=device) * self._posterior_threshold, + torch.ones_like(posterior_entropy, device=device) * + self._posterior_threshold, torch.exp(-posterior_entropy) * self._posterior_alpha, ) accepted_mask = candidates_prob > threshold