[Core] Factor out common code in SequenceData and Sequence #8675

Merged · 1 commit · Sep 21, 2024
27 changes: 7 additions & 20 deletions tests/samplers/test_sampler.py
@@ -1,6 +1,5 @@
 import itertools
 import random
-from array import array
 from typing import Dict, List, Optional, Tuple
 from unittest.mock import Mock, patch

@@ -12,8 +11,7 @@
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import Counter, is_pin_memory_available

@@ -59,9 +57,7 @@ def _do_sample(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
@@ -205,9 +201,8 @@ def create_sampling_params(min_tokens,
         return sampling_params

     def create_sequence_data(num_input=3, num_generated=0):
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                  random.choices(range(0, VOCAB_SIZE), k=num_input)))
+        seq_data = SequenceData.from_seqs(
+            random.choices(range(0, VOCAB_SIZE), k=num_input))
         if num_generated > 0:
             seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                        k=num_generated)
@@ -511,9 +506,7 @@ def test_sampler_mixed(seed: int, device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
@@ -613,9 +606,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=SamplingParams(
                     temperature=1,
                     top_k=top_k,
@@ -699,11 +690,7 @@ def test_sampling_params(sampling_params: List[SamplingParams]):
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={
-                    0:
-                    SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                       [1, 2, 3]))
-                },
+                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                 sampling_params=sampling_params[i],
                 block_tables={0: [1]},
             ))
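Note: the refactored create_sequence_data helper above combines the new from_seqs factory with the existing output_token_ids setter. A minimal standalone sketch of that pattern, assuming a vLLM checkout with this PR applied (the VOCAB_SIZE constant and the assertion are illustrative, not from the test suite):

    import random

    from vllm.sequence import SequenceData

    VOCAB_SIZE = 32000  # illustrative stand-in for the test module's constant

    def create_sequence_data(num_input=3, num_generated=0):
        # Prompt tokens go through the new factory method.
        seq_data = SequenceData.from_seqs(
            random.choices(range(0, VOCAB_SIZE), k=num_input))
        # Generated tokens can still be attached via the property setter.
        if num_generated > 0:
            seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                       k=num_generated)
        return seq_data

    seq = create_sequence_data(num_input=3, num_generated=2)
    assert seq.get_len() == 5  # prompt tokens + generated tokens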
12 changes: 3 additions & 9 deletions tests/spec_decode/utils.py
@@ -1,4 +1,3 @@
-from array import array
 from itertools import count
 from typing import Callable, Dict, List, Optional
 from typing import Sequence as GenericSequence
@@ -11,8 +10,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.utils import set_random_seed
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
-                           CompletionSequenceGroupOutput, Logprob,
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            SequenceData, SequenceGroupMetadata, SequenceOutput)
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.cache_engine import CacheEngine
@@ -138,12 +136,8 @@ def create_seq_group_metadata_from_prompts(
             request_id=str(i),
             is_prompt=len(cont_token_ids) == 0,
             seq_data={
-                i:
-                SequenceData(
-                    array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]),
-                    _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                            cont_token_ids[:]),
-                ),
+                i: SequenceData.from_seqs(prompt_token_ids[:],
+                                          cont_token_ids[:]),
             },
             sampling_params=SamplingParams(temperature=0.0, ),
             block_tables={i: block_allocations[i][:]},
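Note: for these speculative-decoding helpers, from_seqs also accepts the continuation tokens, replacing the hand-built _output_token_ids array. A sketch of the intended equivalence, assuming a checkout with this PR applied (token values are illustrative):

    # Both constructions should describe the same sequence state.
    from array import array

    from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

    prompt_token_ids = [10, 11, 12]
    cont_token_ids = [13, 14]

    old_style = SequenceData(
        array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids),
        _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, cont_token_ids),
    )
    new_style = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)

    assert old_style.get_token_ids() == new_style.get_token_ids()
    assert old_style.get_output_len() == new_style.get_output_len()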
8 changes: 2 additions & 6 deletions tests/test_logits_processor.py
@@ -1,5 +1,4 @@
 import random
-from array import array
 from typing import Tuple
 from unittest.mock import patch

@@ -9,8 +8,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import is_pin_memory_available

@@ -71,9 +69,7 @@ def pick_ith(token_ids, logits):
         SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
-            seq_data={
-                0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
-            },
+            seq_data={0: SequenceData.from_seqs([1, 2, 3])},
             sampling_params=SamplingParams(temperature=0,
                                            logits_processors=[pick_ith]),
             block_tables={0: [1]},
7 changes: 2 additions & 5 deletions tests/test_sequence.py
@@ -1,10 +1,7 @@
-from array import array
-
 import pytest

 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
-                           CompletionSequenceGroupOutput, SequenceData,
+from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData,
                            SequenceOutput)

 from .core.utils import create_dummy_prompt
@@ -58,7 +55,7 @@ def test_sampler_output_eq(sample_outputs):


 def test_sequence_data_prefill():
-    seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4]))
+    seq_data = SequenceData.from_seqs([1, 2, 3, 4])
     assert seq_data.get_num_uncomputed_tokens() == 4
     assert seq_data.get_num_computed_tokens() == 0
     # advance by 2
22 changes: 7 additions & 15 deletions tests/worker/test_encoder_decoder_model_runner.py
@@ -1,13 +1,11 @@
 import itertools
-from array import array
 from typing import List

 import pytest
 import torch

 from vllm.engine.arg_utils import EngineArgs
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import is_cpu, make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import _get_graph_batch_size
@@ -119,12 +117,10 @@ def test_prepare_prompt(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
         encoder_seq_lens.append(encoder_seq_len)
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len)))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -317,11 +313,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))

         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
@@ -523,11 +517,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
16 changes: 5 additions & 11 deletions tests/worker/test_model_runner.py
@@ -1,4 +1,3 @@
-from array import array
 from typing import List

 import pytest
@@ -8,8 +7,7 @@
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
-                           SequenceData, SequenceGroupMetadata)
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import get_open_port
 from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size

@@ -48,8 +46,7 @@ def test_prepare_prompt(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -166,8 +163,7 @@ def test_prepare_decode_cuda_graph(batch_size):
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
         context_lens.append(context_len)
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)))
+        seq_data = SequenceData.from_seqs(range(context_len))
         seq_data.update_num_computed_tokens(context_len)
         # Append one token ID since prefill is finished.
         seq_data.append_token_id(1, 0)
@@ -326,8 +322,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                      range(seq_len)))
+        seq_data = SequenceData.from_seqs(range(seq_len))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -343,8 +338,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     for i in range(prefill_batch_size, batch_size):
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
-        prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))
-        seq_data = SequenceData(prompt_toks)
+        seq_data = SequenceData.from_seqs(range(context_len))
         seq_data.append_token_id(1, 0)
         seq_data.update_num_computed_tokens(context_len)
         seq_group_metadata = SequenceGroupMetadata(
8 changes: 1 addition & 7 deletions vllm/inputs/registry.py
@@ -1,5 +1,4 @@
 import functools
-from array import array
 from collections import UserDict
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
@@ -22,10 +21,6 @@

 C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)

-# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
-# We cannot import it here because of circular dependencies.
-VLLM_TOKEN_ID_ARRAY_TYPE = "l"
-

 @dataclass(frozen=True)
 class InputContext:
@@ -130,8 +125,7 @@ def _default_dummy_data_factory(
         # Avoid circular import
        from vllm.sequence import SequenceData

-        dummy_seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)
+        dummy_seq_data = SequenceData.from_counts({0: seq_len})
         dummy_multi_modal_data = None

         return dummy_seq_data, dummy_multi_modal_data
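Note: from_counts({0: seq_len}) reproduces the dummy-data token array that this factory used to build inline, and it removes the duplicated typecode constant along with its circular-import workaround. A sketch of the equivalence, assuming a checkout with this PR applied:

    from array import array

    from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

    seq_len = 8
    old_tokens = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len  # seq_len zeros
    new_data = SequenceData.from_counts({0: seq_len})
    assert list(new_data.prompt_token_ids) == list(old_tokens)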
61 changes: 37 additions & 24 deletions vllm/sequence.py
@@ -5,6 +5,7 @@
 from array import array
 from collections import defaultdict
 from dataclasses import dataclass
+from functools import cached_property, reduce
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
 from typing import Sequence as GenericSequence
 from typing import Set, Tuple, Union, cast
@@ -169,6 +170,35 @@ class SequenceData(msgspec.Struct,
     # It is used to compute mrope_position_ids.
     _mrope_position_delta: Optional[int] = None

+    @staticmethod
+    def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData":
+        if len(counts_by_token) == 0:
+            return SequenceData.from_seqs([])
+
+        arrs = [
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
+            for token_id, count in counts_by_token.items()
+        ]
+
+        return SequenceData(reduce(array.__add__, arrs))
+
+    @staticmethod
+    def from_seqs(
+        prompt_token_ids: GenericSequence[int],
+        output_token_ids: Optional[GenericSequence[int]] = None,
+    ) -> "SequenceData":
+        prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                     prompt_token_ids)
+
+        if output_token_ids is None:
+            return SequenceData(prompt_token_ids_arr)
+
+        output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                     output_token_ids)
+
+        return SequenceData(prompt_token_ids_arr,
+                            _output_token_ids=output_token_ids_arr)
+
     def __post_init__(self) -> None:
         assert self._prompt_token_ids.typecode == "l"
         assert self._output_token_ids.typecode == "l"
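Note: these two factories centralize the array-typecode detail that every caller previously duplicated. A usage sketch for from_seqs, assuming a checkout with this PR applied (the token IDs and assertions are illustrative):

    from vllm.sequence import SequenceData

    # Prompt-only sequence; replaces SequenceData(array("l", [1, 2, 3])).
    prefill = SequenceData.from_seqs([1, 2, 3])
    assert prefill.get_num_uncomputed_tokens() == 3

    # Prompt plus already-generated tokens, e.g. for decode-phase tests.
    decode = SequenceData.from_seqs([1, 2, 3], [4, 5])
    assert decode.get_token_ids() == [1, 2, 3, 4, 5]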
@@ -370,8 +400,6 @@ def __init__(
         self.lora_request = lora_request
         self.prompt_adapter_request = prompt_adapter_request
         self.from_decoder_prompt = from_decoder_prompt
-        self._prompt: Optional[str] = None
-        self._prompt_token_ids: Optional[List[int]] = None

         # For decoder-only models, a Sequence is constructed
         # from an LLMInputs instance (the `inputs` arg.)
@@ -400,8 +428,7 @@ def __init__(
                 f"invalid input {inputs}; did you forget the "
                 "encoder input prompt fields?")

-        self.data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids))
+        self.data = SequenceData.from_seqs(self.prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""

@@ -422,37 +449,23 @@
     def n_blocks(self) -> int:
         return (self.get_len() + self.block_size - 1) // self.block_size

-    @property
+    @cached_property
     def prompt(self) -> Optional[str]:
-        if self._prompt is not None:
-            # Reuse precomputed prompt string
-            return self._prompt
-
-        # Select decoder or encoder input prompt str,
-        # as appropriate
+        # Select decoder or encoder input prompt str, as appropriate
         prompt_key: str = ("prompt"
                            if self.from_decoder_prompt else "encoder_prompt")
-
-        # Cache prompt
-        self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
-        return self._prompt
+        return cast(Optional[str], self.inputs.get(prompt_key))

-    @property
+    @cached_property
     def prompt_token_ids(self) -> List[int]:
-        if self._prompt_token_ids is not None:
-            # Reuse precomputed prompt token ids
-            return self._prompt_token_ids
-
-        # Select decoder or encoder input prompt
-        # token ids, as appropriate
+        # Select decoder or encoder input prompt token ids, as appropriate
         prompt_token_ids_key: str = ("prompt_token_ids"
                                      if self.from_decoder_prompt else
                                      "encoder_prompt_token_ids")
-
-        # Cache computed prompt token ids
-        self._prompt_token_ids = cast(List[int],
-                                      self.inputs.get(prompt_token_ids_key))
-        return self._prompt_token_ids
+        return cast(List[int], self.inputs.get(prompt_token_ids_key))

     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
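Note: dropping the hand-rolled _prompt / _prompt_token_ids caching is safe here because Sequence.inputs is not expected to change after construction, and functools.cached_property stores the computed value on the instance after the first access. A standalone sketch of the stdlib behavior being relied on:

    from functools import cached_property

    class Example:
        def __init__(self) -> None:
            self.calls = 0

        @cached_property
        def prompt(self) -> str:
            self.calls += 1  # runs only on the first access
            return "hello"

    e = Example()
    assert e.prompt == "hello"
    assert e.prompt == "hello"
    assert e.calls == 1  # the second access hit the per-instance cache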