[V1] Implement vLLM V1 [1/N] #9289

Merged: 101 commits on Oct 22, 2024

Changes from 4 commits

Commits (101)
0d8a651  Add vllm_v1 (WoosukKwon, Oct 11, 2024)
9a5c899  Max num seqs (WoosukKwon, Oct 11, 2024)
b2b90a9  Fix chunked prefill (WoosukKwon, Oct 11, 2024)
0ee05d2  Fix flash-attn (WoosukKwon, Oct 11, 2024)
647cb1b  yapf (WoosukKwon, Oct 11, 2024)
2f04e52  Fix memory (WoosukKwon, Oct 13, 2024)
9a159be  Fix (WoosukKwon, Oct 15, 2024)
00d3975  Remove time (WoosukKwon, Oct 15, 2024)
dff359d  Minor (WoosukKwon, Oct 15, 2024)
0cb2454  Merge branch 'main' into re-arch-v1 (WoosukKwon, Oct 15, 2024)
90390bc  Minor (WoosukKwon, Oct 15, 2024)
fa82f0d  Fix (WoosukKwon, Oct 15, 2024)
5cf508c  Revert (WoosukKwon, Oct 15, 2024)
8c476d5  Remove commit_id (WoosukKwon, Oct 15, 2024)
10e474a  Fix slot_mapping (WoosukKwon, Oct 15, 2024)
e35a3d2  Remove comment (WoosukKwon, Oct 15, 2024)
4ce3470  Fix (WoosukKwon, Oct 15, 2024)
815f137  Remove logits processor (WoosukKwon, Oct 16, 2024)
e7605a7  Fix dummy run (WoosukKwon, Oct 16, 2024)
ae5089b  comment (WoosukKwon, Oct 16, 2024)
05934ea  Fix (WoosukKwon, Oct 16, 2024)
fa5ad10  Remove redundancy (WoosukKwon, Oct 16, 2024)
ea44286  Minor (WoosukKwon, Oct 16, 2024)
789aeb8  Fix (WoosukKwon, Oct 16, 2024)
58053f0  Minor (WoosukKwon, Oct 18, 2024)
e56e3e5  Merge branch 'main' into re-arch-v1 (WoosukKwon, Oct 18, 2024)
7b3219f  Clean up (WoosukKwon, Oct 18, 2024)
d0090d2  Add inits (WoosukKwon, Oct 18, 2024)
deacb3b  yapf (WoosukKwon, Oct 18, 2024)
e51bda9  vllm_v1 -> vllm.v1 (WoosukKwon, Oct 18, 2024)
9abf055  yapf (WoosukKwon, Oct 18, 2024)
6dd0155  Fix (WoosukKwon, Oct 18, 2024)
c83966b  Add VLLM_USE_V1 (WoosukKwon, Oct 18, 2024)
bffa71c  fix (WoosukKwon, Oct 18, 2024)
ba1dc5e  Minor (WoosukKwon, Oct 18, 2024)
8a9a114  Fix (WoosukKwon, Oct 18, 2024)
405f895  isort (WoosukKwon, Oct 18, 2024)
c7a70e9  Move detokenizer_utils (WoosukKwon, Oct 18, 2024)
2bec533  Minor (WoosukKwon, Oct 18, 2024)
f4f573b  yapf (WoosukKwon, Oct 18, 2024)
2a29e1d  Fix (WoosukKwon, Oct 18, 2024)
dc2106f  Minor (WoosukKwon, Oct 18, 2024)
68bd6f7  Rename ports (WoosukKwon, Oct 18, 2024)
f03d574  Comment (WoosukKwon, Oct 18, 2024)
44b152b  comment (WoosukKwon, Oct 18, 2024)
1b186a8  Comment (WoosukKwon, Oct 18, 2024)
4e07a47  Minor (WoosukKwon, Oct 18, 2024)
c6ab902  Add comments (WoosukKwon, Oct 18, 2024)
fd59c5e  Remove unused methods (WoosukKwon, Oct 18, 2024)
4afd2d2  Add check_health (WoosukKwon, Oct 18, 2024)
6cea5e7  Fix switching between V0 and V1 engine (WoosukKwon, Oct 18, 2024)
9978be3  Make async detokenizer work (WoosukKwon, Oct 18, 2024)
0e93601  yapf (WoosukKwon, Oct 18, 2024)
b96dd05  Do not send prompt tokens redundantly (WoosukKwon, Oct 19, 2024)
248f890  Remove async gpu executor (WoosukKwon, Oct 19, 2024)
6225c8d  compatibility (WoosukKwon, Oct 19, 2024)
f7752d8  Optimize random_sample (WoosukKwon, Oct 20, 2024)
c460da9  Remove (WoosukKwon, Oct 20, 2024)
b4a674b  Use dict (WoosukKwon, Oct 20, 2024)
0c5f5a9  yapf (WoosukKwon, Oct 20, 2024)
40b4c78  Fix (WoosukKwon, Oct 20, 2024)
b2aaea2  Minor (WoosukKwon, Oct 20, 2024)
d5ec4cb  Fix deotkenizer (WoosukKwon, Oct 20, 2024)
fbbb771  Minor (WoosukKwon, Oct 20, 2024)
ad3b0d9  Detokenizer & DetokenizerProc (WoosukKwon, Oct 20, 2024)
91ae792  yapf (WoosukKwon, Oct 20, 2024)
9598d43  Minor (WoosukKwon, Oct 20, 2024)
40c5114  Merge branch 'main' into re-arch-v1 (WoosukKwon, Oct 20, 2024)
0ef47d5  Fix (WoosukKwon, Oct 20, 2024)
96e5781  Fix (WoosukKwon, Oct 20, 2024)
aefa95f  Optimize (WoosukKwon, Oct 20, 2024)
952fab8  Comment (WoosukKwon, Oct 20, 2024)
eb2008a  Fix (WoosukKwon, Oct 20, 2024)
ec8e871  Add comment on scheduler (WoosukKwon, Oct 20, 2024)
f811fe0  Optimize object creation (WoosukKwon, Oct 20, 2024)
f03416b  Optimize finish_requests (WoosukKwon, Oct 20, 2024)
cd57404  Minor: (WoosukKwon, Oct 20, 2024)
9f637d6  Minor (WoosukKwon, Oct 20, 2024)
8ac308f  Support API server (WoosukKwon, Oct 20, 2024)
d35fb71  Fix (WoosukKwon, Oct 20, 2024)
3af10d7  Minor (WoosukKwon, Oct 20, 2024)
da2958f  Support stop ids (WoosukKwon, Oct 21, 2024)
f89edac  Minor (WoosukKwon, Oct 21, 2024)
864dd27  RequestMetrics (WoosukKwon, Oct 21, 2024)
f8f7d23  Fix (WoosukKwon, Oct 21, 2024)
e5fb326  mypy (WoosukKwon, Oct 21, 2024)
380568c  mypy (WoosukKwon, Oct 21, 2024)
ec43110  mypy (WoosukKwon, Oct 21, 2024)
cd99b21  Fix (WoosukKwon, Oct 21, 2024)
76bb54f  TODO on top-p top-k (WoosukKwon, Oct 21, 2024)
261a1ef  Refactor (WoosukKwon, Oct 21, 2024)
5a2ddbf  Minor (WoosukKwon, Oct 21, 2024)
44412b5  Remove (WoosukKwon, Oct 21, 2024)
f8c8b8e  typo (WoosukKwon, Oct 21, 2024)
a0fa8eb  Merge branch 'main' into re-arch-v1 (WoosukKwon, Oct 21, 2024)
3ba8865  Preallocate instead of watermark (WoosukKwon, Oct 21, 2024)
8c4b84c  RequestOutput (WoosukKwon, Oct 21, 2024)
0d21798  Minor (WoosukKwon, Oct 21, 2024)
c13b503  num_new_tokens (WoosukKwon, Oct 21, 2024)
804f0cd  Minor (WoosukKwon, Oct 21, 2024)
e441f0a  Add __init__ (WoosukKwon, Oct 22, 2024)
examples/offline_inference.py: 2 changes (1 addition & 1 deletion)
@@ -11,7 +11,7 @@
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="meta-llama/Llama-3.1-8B")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
vllm/attention/selector.py: 5 changes (5 additions & 0 deletions)
@@ -17,6 +17,7 @@

class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
FLASH_ATTN_NEW = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
@@ -112,6 +113,10 @@ def get_attn_backend(
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
return FlashAttentionBackend
if backend == _Backend.FLASH_ATTN_NEW:
from vllm_v1.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
return FlashAttentionBackend
if backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
from vllm.attention.backends.xformers import ( # noqa: F401
vllm/commit_id.py: 1 change (1 addition & 0 deletions)
@@ -0,0 +1 @@
__commit__ = "f2bd246c17ba67d7749a2560a30711f74cd19177"
vllm/entrypoints/llm.py: 43 changes (24 additions & 19 deletions)
@@ -31,6 +31,9 @@
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs, is_list_of

from vllm_v1.engine.llm_engine import LLMEngine as LLMEngineV1
from vllm_v1.outputs import RequestOutput as RequestOutputV1
Collaborator:
If the interface is compatible, would the following be easier?

if USE_V1:
    from vllm_v1.engine.llm_engine import LLMEngine
    from vllm_v1.outputs import RequestOutput
else:
    from vllm.engine.llm_engine import LLMEngine
    from vllm.outputs import RequestOutput  

Collaborator (Author):
Yeah I introduced the VLLM_USE_V1 env variable and added a similar if statement. PTAL.
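
For reference, a minimal sketch (not necessarily the PR's exact code) of how such an environment-variable switch could be read and used for the conditional import; the flag name VLLM_USE_V1 and its default of 0 come from this thread, while the parsing shown here is an assumption:

    import os

    # Assumption: any value other than "1" keeps the existing V0 engine.
    VLLM_USE_V1 = os.getenv("VLLM_USE_V1", "0") == "1"

    if VLLM_USE_V1:
        from vllm_v1.engine.llm_engine import LLMEngine
        from vllm_v1.outputs import RequestOutput
    else:
        from vllm.engine.llm_engine import LLMEngine
        from vllm.outputs import RequestOutput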


logger = init_logger(__name__)


@@ -174,8 +177,13 @@ def __init__(
mm_processor_kwargs=mm_processor_kwargs,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(
# FIXME:
engine_args.max_num_seqs = max(engine_args.max_num_seqs, 2048)
engine_args.enable_chunked_prefill = False
self.llm_engine = LLMEngineV1.from_engine_args(
Collaborator:
Is this where we want to switch between vllm_v1 and the old vllm?

Member:
Consider adding an if here?

Collaborator (Author):
I introduced a new env variable VLLM_USE_V1, which is 0 by default. By setting this env variable, users can use the V1 code path.
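
As a usage note, a minimal sketch of how a user could opt into the V1 code path under that flag (illustrative only; the prompt is a placeholder and the model name is taken from the example script in this PR):

    import os

    # Assumption: the flag must be set before vLLM is imported so the
    # engine selection sees it.
    os.environ["VLLM_USE_V1"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Llama-3.1-8B")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.8, top_p=0.95))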

engine_args, usage_context=UsageContext.LLM_CLASS)
# self.llm_engine = LLMEngine.from_engine_args(
# engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter()

def get_tokenizer(self) -> AnyTokenizer:
@@ -876,27 +884,24 @@ def _run_engine(
total_in_toks = 0
total_out_toks = 0
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
if use_tqdm:
if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(
len(stp.token_ids) for stp in output.outputs)
out_spd = (total_out_toks /
pbar.format_dict["elapsed"])
pbar.postfix = (
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")
pbar.update(1)
finished_reqs, _ = self.llm_engine.step()
for req in finished_reqs:
output = RequestOutputV1.from_request(req)
outputs.append(output)
if use_tqdm:
# Calculate tokens only for RequestOutput
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += len(output.outputs[0].token_ids)
out_spd = (total_out_toks / pbar.format_dict["elapsed"])
pbar.postfix = (f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")
pbar.update(1)

if use_tqdm:
pbar.close()
self.llm_engine.terminate_detokenizer()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
vllm/model_executor/layers/logits_processor.py: 10 changes (6 additions & 4 deletions)
@@ -48,14 +48,15 @@ def forward(
self,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_metadata: Optional[SamplingMetadata] = None,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
if self.logits_as_input:
logits = hidden_states
else:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
if sampling_metadata is not None:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)

# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
@@ -69,7 +70,8 @@ def forward(
logits *= self.scale

# Apply logits processors (if any).
logits = _apply_logits_processors(logits, sampling_metadata)
if sampling_metadata is not None:
logits = _apply_logits_processors(logits, sampling_metadata)

return logits

vllm_v1/attention/backends/flash_attn.py: 247 changes (247 additions & 0 deletions, new file)
@@ -0,0 +1,247 @@
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.vllm_flash_attn import flash_attn_varlen_func
from vllm.forward_context import get_forward_context


class FlashAttentionBackend(AttentionBackend):

@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]

@staticmethod
def get_name() -> str:
return "flash-attn-new"

@staticmethod
def get_impl_cls() -> Type["FlashAttentionImpl"]:
return FlashAttentionImpl

@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)


@dataclass
class FlashAttentionMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|

max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_start_loc: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor


class FlashAttentionImpl(AttentionImpl):

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
if blocksparse_params is not None:
raise ValueError(
"FlashAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

if sliding_window is not None:
# NOTE(woosuk): flash-attn's sliding window does not work with
# paged KV cache.
raise ValueError(
"Sliding window is not supported in FlashAttention.")

support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.")

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with FlashAttention.

Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")

# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")

output = torch.ops.vllm.unified_flash_attention(
query,
key,
value,
self.num_heads,
self.head_size,
self.num_kv_heads,
kv_cache,
self.kv_cache_dtype,
k_scale,
v_scale,
self.scale,
self.sliding_window,
self.alibi_slopes,
self.logits_soft_cap,
)
return output


@torch.library.custom_op("vllm::unified_flash_attention",
mutates_args=["kv_cache"])
def unified_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
current_metadata = get_forward_context()
assert current_metadata is not None
assert isinstance(current_metadata, FlashAttentionMetadata)
attn_metadata: FlashAttentionMetadata = current_metadata

num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)

if kv_cache is not None:
key_cache = kv_cache[0]
value_cache = kv_cache[1]

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[0],
kv_cache[1],
attn_metadata.slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)

if (attn_metadata.block_table is None
or attn_metadata.block_table.numel() == 0):
# Profiling run.
output = torch.empty_like(query)
return output

output = flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
cu_seqlens_k=attn_metadata.seq_start_loc,
max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=softmax_scale,
causal=True,
alibi_slopes=alibi_slopes,
window_size=window_size,
block_table=attn_metadata.block_table,
softcap=logits_soft_cap,
)
return output.view(num_tokens, hidden_size)


@unified_flash_attention.register_fake
def _(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query)