[Core] Use flashinfer sampling kernel when available #7137

Merged: 27 commits, Aug 19, 2024

Commits (27)
7e47c2d  Use flashinfer kernel to do sampling if available  (peng1999, Aug 2, 2024)
fa61264  Merge remote-tracking branch 'upstream/main' into opt-topk  (peng1999, Aug 2, 2024)
56beab0  Fix type mismatch  (peng1999, Aug 5, 2024)
5396c9d  Some renaming  (peng1999, Aug 5, 2024)
5999bd3  Fallback for flashinfer sampler  (peng1999, Aug 5, 2024)
420b004  Formatting fix  (peng1999, Aug 5, 2024)
98d372e  Tests fix  (peng1999, Aug 5, 2024)
0a8be18  Fix mypy  (peng1999, Aug 5, 2024)
f170646  Add test for flashinfer sampler  (peng1999, Aug 5, 2024)
88c8a98  Suppress yapf on import  (peng1999, Aug 5, 2024)
c404cd5  Fix pipeline  (peng1999, Aug 5, 2024)
c361a95  Change back to torch generator, add env flags  (peng1999, Aug 6, 2024)
8af0e09  Merge remote-tracking branch 'upstream/main' into opt-topk  (peng1999, Aug 6, 2024)
99f7ecc  rename env for flashinfer, rollback changes in utils  (peng1999, Aug 7, 2024)
7e03711  rollback changes to utils  (peng1999, Aug 7, 2024)
6416046  rename env  (peng1999, Aug 8, 2024)
fdc23a3  add top_k_top_p when fallback  (peng1999, Aug 8, 2024)
b97c911  Adapt flashinfer 0.1.4  (peng1999, Aug 12, 2024)
f8d7093  Revert changes to sampling_metadata  (peng1999, Aug 12, 2024)
2d7e5c3  Change flashinfer 0.1.2 to 0.1.4 in test  (peng1999, Aug 12, 2024)
20eee6a  Merge remote-tracking branch 'upstream/main' into opt-topk  (peng1999, Aug 12, 2024)
f893110  Disable flashinfer in GPTQ reproduce test  (peng1999, Aug 15, 2024)
e4cfcfc  Disable flashinfer sampler in distributed test  (peng1999, Aug 15, 2024)
c5194ec  Merge remote-tracking branch 'upstream/main' into opt-topk  (peng1999, Aug 15, 2024)
0ec8b61  Disable flashinfer sampler by default  (peng1999, Aug 16, 2024)
9eaea5c  Update vllm/envs.py  (peng1999, Aug 17, 2024)
18d59a1  Merge branch 'vllm-project:main' into opt-topk  (peng1999, Aug 19, 2024)

Files changed

.buildkite/test-pipeline.yaml (3 additions, 1 deletion)
@@ -192,7 +192,9 @@ steps:
- vllm/model_executor/layers
- vllm/sampling_metadata.py
- tests/samplers
command: pytest -v -s samplers
commands:
- pytest -v -s samplers
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers

- label: LogitsProcessor Test # 5min
mirror_hardwares: [amd]
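
For reference, a rough local equivalent of the new CI step (an assumption about typical usage, not part of this PR): the flag has to be set before any vLLM module is imported, because the sampler module reads it at import time to decide whether to import flashinfer.

```python
# Hypothetical local run of the sampler tests with the FlashInfer sampler enabled,
# assuming a vLLM checkout with a compatible flashinfer wheel installed.
import os

# Must happen before importing vllm, since the sampler checks the flag at import time.
os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "1"

import pytest

raise SystemExit(pytest.main(["-v", "-s", "tests/samplers"]))
```
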
Dockerfile (1 addition, 1 deletion)
@@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir

RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu121torch2.4-cp310-cp310-linux_x86_64.whl
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp310-cp310-linux_x86_64.whl
#################### vLLM installation IMAGE ####################


tests/samplers/test_sampler.py (36 additions, 1 deletion)
@@ -8,6 +8,7 @@
import torch
from transformers import GenerationConfig, GenerationMixin

import vllm.envs as envs
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
@@ -634,7 +635,10 @@ def mock_sample(probs, *args, **kwargs):
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
for prob in probs], None)

with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
# top-k and top-p are only calculated when the flashinfer kernel is not available
with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
patch("vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling", None):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

assert sample_probs is not None
@@ -645,6 +649,37 @@ def mock_sample(probs, *args, **kwargs):
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_flashinfer_fallback(seed: int, device: str):
if not envs.VLLM_USE_FLASHINFER_SAMPLER:
pytest.skip("Flashinfer sampler is disabled")

set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)

def failing_flashinfer_sampling(*_args, **_kwargs):
return None, torch.zeros(batch_size, device=device, dtype=torch.int32)

sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)

with patch(
"vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling):
fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)

assert sampler_output == fallback_sampler_output


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):

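
A toy sketch (assumption: standalone illustration, not vLLM code) of the patching pattern the tests above rely on: the sampler looks up flashinfer_top_k_top_p_sampling as a module-level name at call time, so a test can swap it for None (kernel unavailable) or for a stub whose success mask is all False (kernel present but rejection sampling fails).

```python
# Toy module standing in for vllm.model_executor.layers.sampler; the real tests
# patch the module-level name "flashinfer_top_k_top_p_sampling" the same way.
import types
from unittest.mock import patch

import torch

toy = types.ModuleType("toy_sampler")
toy.kernel = lambda probs, uniform: (probs.argmax(-1),
                                     torch.ones(probs.shape[0], dtype=torch.bool))

def sample(probs: torch.Tensor) -> torch.Tensor:
    if toy.kernel is None:          # kernel not importable: pure-torch path
        return probs.argmax(-1)
    ids, success = toy.kernel(probs, None)
    if not success.all():           # rejection sampling reported failure
        return probs.argmax(-1)     # fallback path (greedy here, for brevity)
    return ids

probs = torch.tensor([[0.1, 0.9], [0.7, 0.3]])
with patch.object(toy, "kernel", None):
    print(sample(probs))            # exercises the "kernel absent" branch

def failing_kernel(probs, uniform):
    return None, torch.zeros(probs.shape[0], dtype=torch.bool)

with patch.object(toy, "kernel", failing_kernel):
    print(sample(probs))            # exercises the forced-failure fallback branch
```
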
vllm/envs.py (5 additions)
@@ -30,6 +30,7 @@
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
@@ -256,6 +257,10 @@ def get_default_config_root():
"VLLM_ATTENTION_BACKEND":
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),

# If set, vllm will use flashinfer sampler
"VLLM_USE_FLASHINFER_SAMPLER":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),

# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
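
A minimal sketch of the parsing rule above (an illustration, not vLLM code): the value is read as an integer, so "1" or any non-zero integer string enables the FlashInfer sampler, "0" or an unset variable leaves it disabled, and a non-integer value would raise ValueError.

```python
import os

def use_flashinfer_sampler() -> bool:
    # Mirrors the lambda registered above for VLLM_USE_FLASHINFER_SAMPLER.
    return bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0")))

os.environ.pop("VLLM_USE_FLASHINFER_SAMPLER", None)
print(use_flashinfer_sampler())                    # False (default: disabled)

os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "1"
print(use_flashinfer_sampler())                    # True
```
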
vllm/model_executor/layers/sampler.py (85 additions, 25 deletions)
@@ -1,5 +1,7 @@
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import warnings
from importlib.util import find_spec
from math import inf
from typing import Dict, List, Optional, Tuple

@@ -11,6 +13,7 @@
if HAS_TRITON:
from vllm.model_executor.layers.ops.sample import sample as sample_triton

import vllm.envs as envs
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors,
SequenceGroupToSample)
@@ -19,6 +22,16 @@
PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceOutput)

if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
# yapf: disable
from flashinfer.sampling import (
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)

# yapf: enable
else:
flashinfer_top_k_top_p_sampling = None

# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]

@@ -123,7 +136,7 @@ def forward(
logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

if do_top_p_top_k:
if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)

@@ -476,32 +489,65 @@ def _multinomial(
seq_groups: Optional[List[SequenceGroupToSample]] = None,
) -> torch.Tensor:
if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also
# forces a GPU<->CPU sync).
# This allows us to do sampling with replacement by creating
# num_samples copies of each row in the tensor, and then
# batch sampling the resulting tensor.
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
probs = probs.repeat_interleave(num_samples, dim=0)
q = torch.empty_like(probs)
if seq_groups is None:
q.exponential_()
else:
sample_idx = 0
for seq_group in seq_groups:
seq_ids = seq_group.seq_ids
next_sample_idx = sample_idx + len(seq_ids) * num_samples
q[sample_idx:next_sample_idx].exponential_(
generator=seq_group.generator)
sample_idx = next_sample_idx
stride = len(seq_ids) * num_samples
assert seq_group.generator is not None
q[sample_idx:sample_idx +
stride].exponential_(generator=seq_group.generator)
sample_idx += stride
return probs.div_(q).argmax(dim=1).view(-1, num_samples)


def _top_k_top_p_multinomial_with_flashinfer(
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
max_top_k_round = 32
if num_samples > 1:
probs = probs.repeat_interleave(num_samples, dim=0)
top_ks = top_ks.repeat_interleave(num_samples)
top_ps = top_ps.repeat_interleave(num_samples)
batch_size = probs.shape[0]
uniform_samples = torch.empty((max_top_k_round, batch_size),
device=probs.device)
if seq_groups is None:
uniform_samples.uniform_()
else:
sample_idx = 0
for seq_group in seq_groups:
seq_ids = seq_group.seq_ids
stride = len(seq_ids) * num_samples
assert seq_group.generator is not None
uniform_samples[:, sample_idx:sample_idx +
stride].uniform_(generator=seq_group.generator)
sample_idx += stride
batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
probs,
uniform_samples,
top_ks,
top_ps,
)
if not success.all():
warnings.warn("FlashInfer rejection sampling failed, fallback.",
stacklevel=1)
probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
probs, uniform_samples[0])
return batch_next_token_ids.view(-1, num_samples)


def _sample_with_torch(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
@@ -564,18 +610,28 @@ def _sample_with_torch(
sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
seeded_args = {} if sampling_type == SamplingType.RANDOM else {
"seq_groups": seq_groups,
}

multinomial_samples[sampling_type] = _multinomial(
probs[long_sample_indices], max_best_of_in_batch,
**seeded_args)
seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
seq_groups)

if flashinfer_top_k_top_p_sampling is not None:
multinomial_samples[
sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
probs[long_sample_indices],
sampling_tensors.top_ks[long_sample_indices],
sampling_tensors.top_ps[long_sample_indices],
max_best_of_in_batch,
seq_groups_arg,
)
else:
multinomial_samples[sampling_type] = _multinomial(
probs[long_sample_indices],
max_best_of_in_batch,
seq_groups=seq_groups_arg)

if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[
long_sample_indices] = multinomial_samples[sampling_type]
sampled_token_ids_tensor[long_sample_indices] = \
multinomial_samples[sampling_type].to(torch.long)

elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
@@ -693,9 +749,12 @@ def _sample_with_triton_kernel(


def _sample(
probs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
"""
Args:
@@ -713,6 +772,7 @@ def _sample(
probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=include_gpu_probs_tensor,
modify_greedy_probs=modify_greedy_probs,
)
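
For readers unfamiliar with the sampling trick _multinomial keeps using after this refactor, a standalone sketch (plain PyTorch, not vLLM code): dividing each row of probabilities by i.i.d. Exponential(1) noise and taking the argmax draws index k with probability proportional to p_k (the exponential-race form of the Gumbel-max trick), and best_of > 1 is handled by repeat_interleave-ing the rows before drawing, as the new code does.

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([[0.1, 0.2, 0.7],
                      [0.5, 0.25, 0.25]])
num_samples = 4                       # plays the role of max_best_of_in_batch

# repeat_interleave gives num_samples adjacent copies of each row, so one batched
# draw yields num_samples samples per original sequence.
expanded = probs.repeat_interleave(num_samples, dim=0)
q = torch.empty_like(expanded).exponential_()        # Exp(1) noise, as in _multinomial
samples = expanded.div(q).argmax(dim=1).view(-1, num_samples)
print(samples)                        # one row of num_samples token ids per sequence

# Sanity check: empirical frequencies over many draws approach the first row of probs.
many = probs[0].expand(100_000, -1)
draws = many.div(torch.empty_like(many).exponential_()).argmax(dim=1)
print(torch.bincount(draws, minlength=3) / draws.numel())   # ~[0.10, 0.20, 0.70]
```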