[V1][BugFix] Clean up rejection sampler & Fix warning msg (vllm-project#13362)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
WoosukKwon authored Feb 16, 2025
1 parent d67cc21 commit 69e1d23
Showing 1 changed file with 69 additions and 40 deletions.
109 changes: 69 additions & 40 deletions vllm/v1/sample/rejection_sampler.py
@@ -3,7 +3,9 @@
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata

@@ -19,27 +21,50 @@

class RejectionSampler(nn.Module):

def __init__(self):
super().__init__()
if current_platform.is_cuda:
if is_flashinfer_available:
if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
# sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
# default it is unused). For backward compatibility, we set
# `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
# interpret it differently in V0 and V1 samplers: In V0,
# None means False, while in V1, None means True. This is
# why we use the condition
# `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
logger.info("Using FlashInfer for rejection sampling.")
self.forward_method = self.flashinfer_sample
else:
logger.warning(
"FlashInfer is available, but it is not enabled. "
"Falling back to the PyTorch-native implementation of "
"rejection sampling. For the best performance, "
"please set VLLM_USE_FLASHINFER_SAMPLER=1.")
self.forward_method = self.forward_native
else:
logger.warning(
"FlashInfer is not available. Falling back to the PyTorch-"
"native implementation of rejection sampling. For the "
"best performance, please install FlashInfer.")
self.forward_method = self.forward_native
else:
self.forward_method = self.forward_native

def forward(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> SamplerOutput:
if not sampling_metadata.all_greedy:
raise NotImplementedError(
"Only greedy sampling is supported by rejection sampler.")
"Currently, only greedy sampling is supported by "
"rejection sampler.")
return self.forward_method(logits, sampling_metadata)

if is_flashinfer_available:
logger.info("User FlashInfer for rejection sampling.")
return RejectionSampler.flashinfer_sample(logits,
sampling_metadata)
else:
logger.warning(
"FlashInfer is not available. Falling back to the PyTorch-"
"native implementation of rejection sampling.")
return RejectionSampler.greedy_sample_native(
logits, sampling_metadata)

@staticmethod
def flashinfer_sample(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> SamplerOutput:
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
# NOTE: The following input preparation can be moved
# to the model runner in a persistent manner for better
# performance.
@@ -71,10 +96,10 @@ def flashinfer_sample(
vocab_size = logits.size(-1)
# NOTE: CPU <-> GPU synchronization happens here.
draft_token_ids = draft_token_ids.to(logits.device)
draft_probs = RejectionSampler._create_greedy_token_probs(
draft_token_ids, vocab_size, logits.device)
target_probs = RejectionSampler._create_greedy_token_probs(
target_token_ids, vocab_size, logits.device)
draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size,
logits.device)
target_probs = _create_greedy_token_probs(target_token_ids, vocab_size,
logits.device)
uniform_samples = torch.zeros(batch_size,
max_spec_len + 1,
device=logits.device)
@@ -89,10 +114,11 @@
logprobs_tensors=None)

# TODO: The following method can be optimized for better performance.
@staticmethod
def greedy_sample_native(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> SamplerOutput:
def forward_native(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
# Add 1 to include the 'bonus' token.
sample_lens = [x + 1 for x in spec_lens]
@@ -137,24 +163,27 @@ def greedy_sample_native(
return SamplerOutput(sampled_token_ids=output_token_ids,
logprobs_tensors=None)

@staticmethod
def _create_greedy_token_probs(token_ids: torch.Tensor, vocab_size: int,
out_device: torch.device) -> torch.Tensor:
batch_size, num_tokens = token_ids.shape

token_probs = torch.zeros(batch_size,
num_tokens,
vocab_size,
dtype=torch.float,
device=out_device)
def _create_greedy_token_probs(
token_ids: torch.Tensor,
vocab_size: int,
out_device: torch.device,
) -> torch.Tensor:
batch_size, num_tokens = token_ids.shape

token_probs = torch.zeros(batch_size,
num_tokens,
vocab_size,
dtype=torch.float,
device=out_device)

# Ignore INVALID_TOKEN_ID.
valid_mask = (token_ids != INVALID_TOKEN_ID)
valid_indices = token_ids.clone()
valid_indices[~valid_mask] = 0
# Ignore INVALID_TOKEN_ID.
valid_mask = (token_ids != INVALID_TOKEN_ID)
valid_indices = token_ids.clone()
valid_indices[~valid_mask] = 0

token_probs.scatter_(dim=2,
index=valid_indices.unsqueeze(-1),
src=valid_mask.unsqueeze(-1).float())
token_probs.scatter_(dim=2,
index=valid_indices.unsqueeze(-1),
src=valid_mask.unsqueeze(-1).float())

return token_probs
return token_probs
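
For readers unfamiliar with the tri-state environment variable discussed in the NOTE inside __init__: the snippet below is a minimal, standalone sketch (not part of the commit) of how VLLM_USE_FLASHINFER_SAMPLER is interpreted differently by the V0 and V1 samplers, per that comment. The helper names v0_flashinfer_enabled and v1_flashinfer_enabled are hypothetical and exist only for illustration.

from typing import Optional

# Hypothetical helpers illustrating the tri-state flag described in the NOTE:
# the variable defaults to None for backward compatibility.
def v0_flashinfer_enabled(flag: Optional[bool]) -> bool:
    # V0 sampler: unset (None) means "off"; FlashInfer is used only when
    # the flag is explicitly set to 1.
    return flag is True

def v1_flashinfer_enabled(flag: Optional[bool]) -> bool:
    # V1 rejection sampler: unset (None) means "on", matching the
    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` check in the diff.
    return flag is not False

assert v0_flashinfer_enabled(None) is False
assert v1_flashinfer_enabled(None) is True
assert v0_flashinfer_enabled(True) and v1_flashinfer_enabled(True)
assert not v0_flashinfer_enabled(False) and not v1_flashinfer_enabled(False)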
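
Similarly, here is a tiny usage sketch (again not part of the commit) of the module-level _create_greedy_token_probs helper that the diff factors out of the class: it scatters 1.0 into a zero tensor at each valid token id, so every valid position becomes a one-hot distribution while positions holding INVALID_TOKEN_ID stay all-zero. The sentinel value -1 below is an assumption for illustration only.

import torch

INVALID_TOKEN_ID = -1  # assumed sentinel for padded/absent draft tokens

token_ids = torch.tensor([[2, INVALID_TOKEN_ID],
                          [0, 3]])               # [batch=2, num_tokens=2]
vocab_size = 4
batch_size, num_tokens = token_ids.shape

token_probs = torch.zeros(batch_size, num_tokens, vocab_size)
valid_mask = token_ids != INVALID_TOKEN_ID
valid_indices = token_ids.clone()
valid_indices[~valid_mask] = 0  # any in-range index; masked to 0.0 via src below

# Writes 1.0 at each valid token id; invalid slots receive 0.0 and stay all-zero.
token_probs.scatter_(dim=2,
                     index=valid_indices.unsqueeze(-1),
                     src=valid_mask.unsqueeze(-1).float())

# token_probs[0, 0] == [0, 0, 1, 0]   (one-hot at id 2)
# token_probs[0, 1] == [0, 0, 0, 0]   (invalid slot stays all-zero)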
