From 69e1d23e1eae4d33ac44fdab2257e29a542cd9a9 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 16 Feb 2025 12:25:29 -0800
Subject: [PATCH] [V1][BugFix] Clean up rejection sampler & Fix warning msg (#13362)

Signed-off-by: Woosuk Kwon

---
 vllm/v1/sample/rejection_sampler.py | 109 ++++++++++++++++++----
 1 file changed, 69 insertions(+), 40 deletions(-)

diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 6a0bbe7b216fc..df1da89302111 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -3,7 +3,9 @@
 import torch.nn as nn
 from torch.nn.utils.rnn import pad_sequence
 
+from vllm import envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.v1.outputs import SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 
@@ -19,27 +21,50 @@ class RejectionSampler(nn.Module):
 
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda:
+            if is_flashinfer_available:
+                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
+                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
+                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
+                    # default it is unused). For backward compatibility, we set
+                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
+                    # interpret it differently in V0 and V1 samplers: In V0,
+                    # None means False, while in V1, None means True. This is
+                    # why we use the condition
+                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
+                    logger.info("Using FlashInfer for rejection sampling.")
+                    self.forward_method = self.flashinfer_sample
+                else:
+                    logger.warning(
+                        "FlashInfer is available, but it is not enabled. "
+                        "Falling back to the PyTorch-native implementation of "
+                        "rejection sampling. For the best performance, "
+                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
+                    self.forward_method = self.forward_native
+            else:
+                logger.warning(
+                    "FlashInfer is not available. Falling back to the PyTorch-"
+                    "native implementation of rejection sampling. For the "
+                    "best performance, please install FlashInfer.")
+                self.forward_method = self.forward_native
+        else:
+            self.forward_method = self.forward_native
+
     def forward(self, logits: torch.Tensor,
                 sampling_metadata: SamplingMetadata) -> SamplerOutput:
         if not sampling_metadata.all_greedy:
             raise NotImplementedError(
-                "Only greedy sampling is supported by rejection sampler.")
+                "Currently, only greedy sampling is supported by "
+                "rejection sampler.")
+        return self.forward_method(logits, sampling_metadata)
 
-        if is_flashinfer_available:
-            logger.info("User FlashInfer for rejection sampling.")
-            return RejectionSampler.flashinfer_sample(logits,
-                                                      sampling_metadata)
-        else:
-            logger.warning(
-                "FlashInfer is not available. Falling back to the PyTorch-"
-                "native implementation of rejection sampling.")
-            return RejectionSampler.greedy_sample_native(
-                logits, sampling_metadata)
-
-    @staticmethod
     def flashinfer_sample(
-            logits: torch.Tensor,
-            sampling_metadata: SamplingMetadata) -> SamplerOutput:
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         # NOTE: The following input preparationg can be moved
         # to the model runner with a persistent manner for better
         # performance.
@@ -71,10 +96,10 @@ def flashinfer_sample(
         vocab_size = logits.size(-1)
         # NOTE: CPU <-> GPU synchronization happens here.
         draft_token_ids = draft_token_ids.to(logits.device)
-        draft_probs = RejectionSampler._create_greedy_token_probs(
-            draft_token_ids, vocab_size, logits.device)
-        target_probs = RejectionSampler._create_greedy_token_probs(
-            target_token_ids, vocab_size, logits.device)
+        draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size,
+                                                 logits.device)
+        target_probs = _create_greedy_token_probs(target_token_ids,
+                                                  vocab_size, logits.device)
         uniform_samples = torch.zeros(batch_size,
                                       max_spec_len + 1,
                                       device=logits.device)
@@ -89,10 +114,11 @@ def flashinfer_sample(
                              logprobs_tensors=None)
 
     # TODO: The following method can be optimized for better performance.
-    @staticmethod
-    def greedy_sample_native(
-            logits: torch.Tensor,
-            sampling_metadata: SamplingMetadata) -> SamplerOutput:
+    def forward_native(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
         # Add 1 to include the 'bonus' token.
         sample_lens = [x + 1 for x in spec_lens]
@@ -137,24 +163,27 @@ def greedy_sample_native(
         return SamplerOutput(sampled_token_ids=output_token_ids,
                              logprobs_tensors=None)
 
-    @staticmethod
-    def _create_greedy_token_probs(token_ids: torch.Tensor, vocab_size: int,
-                                   out_device: torch.device) -> torch.Tensor:
-        batch_size, num_tokens = token_ids.shape
-        token_probs = torch.zeros(batch_size,
-                                  num_tokens,
-                                  vocab_size,
-                                  dtype=torch.float,
-                                  device=out_device)
+
+def _create_greedy_token_probs(
+    token_ids: torch.Tensor,
+    vocab_size: int,
+    out_device: torch.device,
+) -> torch.Tensor:
+    batch_size, num_tokens = token_ids.shape
+
+    token_probs = torch.zeros(batch_size,
+                              num_tokens,
+                              vocab_size,
+                              dtype=torch.float,
+                              device=out_device)
 
-        # Ignore INVALID_TOKEN_ID.
-        valid_mask = (token_ids != INVALID_TOKEN_ID)
-        valid_indices = token_ids.clone()
-        valid_indices[~valid_mask] = 0
+    # Ignore INVALID_TOKEN_ID.
+    valid_mask = (token_ids != INVALID_TOKEN_ID)
+    valid_indices = token_ids.clone()
+    valid_indices[~valid_mask] = 0
 
-        token_probs.scatter_(dim=2,
-                             index=valid_indices.unsqueeze(-1),
-                             src=valid_mask.unsqueeze(-1).float())
+    token_probs.scatter_(dim=2,
+                         index=valid_indices.unsqueeze(-1),
+                         src=valid_mask.unsqueeze(-1).float())
 
-        return token_probs
+    return token_probs
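For context on the NOTE(woosuk) comment in __init__: the patch keeps VLLM_USE_FLASHINFER_SAMPLER tri-state (None when unset, False/True when set to 0/1) so V0 and V1 can interpret the unset case differently. A minimal sketch of that interpretation, with a hypothetical helper name (use_flashinfer is not part of the patch):

from typing import Optional

def use_flashinfer(flag: Optional[bool], is_v1: bool) -> bool:
    # flag mirrors envs.VLLM_USE_FLASHINFER_SAMPLER: None when the env var
    # is unset, True/False when it is set to 1/0.
    if is_v1:
        # V1 opts in unless FlashInfer is explicitly disabled.
        return flag is not False
    # V0 stays opted out unless FlashInfer is explicitly enabled.
    return flag is True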
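The PyTorch-native path (forward_native) implements greedy rejection sampling: a draft token is accepted only while it matches the target model's argmax, and the target model contributes one extra token, either the correction for the first mismatch or a bonus token when every draft token matches. Below is a single-request sketch of that rule with illustrative names; the actual method operates on padded batches:

import torch

def greedy_rejection_sample(draft_token_ids: torch.Tensor,
                            target_logits: torch.Tensor) -> torch.Tensor:
    # draft_token_ids: [num_spec_tokens], proposed by the draft model.
    # target_logits: [num_spec_tokens + 1, vocab_size], from the target model.
    target_token_ids = target_logits.argmax(dim=-1)
    num_spec = draft_token_ids.shape[0]
    # Accept the longest prefix of draft tokens that agrees with the
    # target model's greedy choice at each position.
    matches = (draft_token_ids == target_token_ids[:num_spec]).int()
    num_accepted = int(matches.cumprod(dim=0).sum())
    # The target model supplies one extra token: the correction for the
    # first rejected position, or a bonus token if everything matched.
    return target_token_ids[:num_accepted + 1]

For example, with draft tokens [5, 9, 2] and target argmax [5, 9, 7, 3], the first two draft tokens are accepted and 7 is emitted as the correction, yielding [5, 9, 7].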
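The helper _create_greedy_token_probs builds the one-hot distributions that the flashinfer_sample path passes as draft_probs and target_probs: each valid token id gets probability 1, while padded slots stay all-zero. A small standalone demonstration, assuming INVALID_TOKEN_ID is a -1 padding sentinel (the real constant lives in vLLM):

import torch

INVALID_TOKEN_ID = -1  # assumed padding sentinel for this sketch
token_ids = torch.tensor([[2, 0, INVALID_TOKEN_ID]])  # [batch=1, num_tokens=3]
vocab_size = 4

token_probs = torch.zeros(1, 3, vocab_size)
valid_mask = token_ids != INVALID_TOKEN_ID
valid_indices = token_ids.clone()
valid_indices[~valid_mask] = 0  # scatter_ needs in-range indices
token_probs.scatter_(dim=2,
                     index=valid_indices.unsqueeze(-1),
                     src=valid_mask.unsqueeze(-1).float())
# token_probs[0] is now [[0,0,1,0], [1,0,0,0], [0,0,0,0]]:
# one-hot rows for the valid ids, an all-zero row for the padded slot.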