[Core] Use array to speedup padding (vllm-project#6779)
peng1999 authored Jul 26, 2024
Commit 89a84b0 (1 parent: 084a01f)
Showing 3 changed files with 29 additions and 17 deletions.
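The change swaps the Python List[int] buffers that hold prompt and output token ids for array('l') arrays from the standard library. An array stores the ids as raw C longs in one contiguous block, so building padded, fixed-width token rows copies machine integers instead of per-element Python object pointers. Below is a standalone micro-benchmark sketch of that effect; it is illustrative only and not code from this commit, and the pad id and row width are made-up values.

import timeit
from array import array
from typing import List

PAD_ID = 0        # hypothetical padding token id
ROW_WIDTH = 2048  # hypothetical padded row width
TOKENS = list(range(1500))

def pad_with_list(tokens: List[int]) -> List[int]:
    # Pad a list row out to ROW_WIDTH with PAD_ID.
    return tokens + [PAD_ID] * (ROW_WIDTH - len(tokens))

def pad_with_array(tokens: array) -> array:
    # Same padding, but on a contiguous array of C longs.
    return tokens + array('l', [PAD_ID]) * (ROW_WIDTH - len(tokens))

if __name__ == "__main__":
    list_row = TOKENS
    array_row = array('l', TOKENS)
    print("list :", timeit.timeit(lambda: pad_with_list(list_row), number=20_000))
    print("array:", timeit.timeit(lambda: pad_with_array(array_row), number=20_000))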
vllm/model_executor/layers/sampler.py (2 changes: 1 addition and 1 deletion)
@@ -220,7 +220,7 @@ def _apply_min_tokens_penalty(
             seqs_to_penalize: List[int] = []
             for j, seq_id in enumerate(seq_ids):
                 seq_data = seq_group.seq_data[seq_id]
-                if len(seq_data.output_token_ids) < min_tokens:
+                if len(seq_data.output_token_ids_array) < min_tokens:
                     seqs_to_penalize.append(j)

             if seqs_to_penalize:
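The only change in this file follows from the property changes in vllm/sequence.py further down: output_token_ids now materializes a fresh tuple on every access, while the new output_token_ids_array property hands back the underlying array, so this length check no longer copies the whole output sequence. A minimal sketch of the difference, using an illustrative stand-in class rather than vLLM code:

from array import array

class SeqDataSketch:
    """Toy stand-in mirroring the two accessor shapes used above."""

    def __init__(self) -> None:
        self._output_token_ids = array('l', range(1000))

    @property
    def output_token_ids(self) -> tuple:
        # Copies every id into a new tuple on each access.
        return tuple(self._output_token_ids)

    @property
    def output_token_ids_array(self) -> array:
        # Returns the backing array directly; len() needs no copy.
        return self._output_token_ids

seq_data = SeqDataSketch()
assert len(seq_data.output_token_ids) == len(seq_data.output_token_ids_array)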
vllm/model_executor/sampling_metadata.py (21 changes: 12 additions and 9 deletions)
@@ -1,4 +1,5 @@
 import random
+from array import array
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple

@@ -329,8 +330,8 @@ def from_sampling_metadata(
                 user-defined seed for each sequence.
             extra_entropy: extra entropy to use when generating seeds.
         """
-        prompt_tokens: List[List[int]] = []
-        output_tokens: List[List[int]] = []
+        prompt_tokens: List[array] = []
+        output_tokens: List[array] = []
         top_ks: List[int] = []
         temperatures: List[float] = []
         top_ps: List[float] = []
@@ -432,13 +433,15 @@ def from_sampling_metadata(
             if (seq_group.is_prompt
                     and sampling_params.prompt_logprobs is not None):
                 prefill_len = len(seq_group.prompt_logprob_indices)
-                prompt_tokens.extend([] for _ in range(prefill_len))
-                output_tokens.extend([] for _ in range(prefill_len))
+                prompt_tokens.extend(
+                    array('l') for _ in range(prefill_len))
+                output_tokens.extend(
+                    array('l') for _ in range(prefill_len))
             if seq_group.do_sample:
                 for seq_id in seq_ids:
                     seq_data = seq_group.seq_data[seq_id]
-                    prompt_tokens.append(list(seq_data.prompt_token_ids))
-                    output_tokens.append(list(seq_data.output_token_ids))
+                    prompt_tokens.append(seq_data.prompt_token_ids_array)
+                    output_tokens.append(seq_data.output_token_ids_array)

         sampling_tensors = SamplingTensors.from_lists(
             temperatures, top_ps, top_ks, min_ps, presence_penalties,
@@ -454,9 +457,9 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
                    frequency_penalties: List[float],
                    repetition_penalties: List[float],
                    sampling_seeds: List[int], sample_indices: List[int],
-                   prompt_tokens: List[List[int]],
-                   output_tokens: List[List[int]], vocab_size: int,
-                   extra_seeds_to_generate: int, device: torch.device,
+                   prompt_tokens: List[array], output_tokens: List[array],
+                   vocab_size: int, extra_seeds_to_generate: int,
+                   device: torch.device,
                    dtype: torch.dtype) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
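The hunks above only change the types flowing into from_lists; the padding and tensor construction inside from_lists sit outside the shown context. As a rough sketch of the kind of work that benefits from array rows, here is a hypothetical helper (not code from this commit), assuming each row is padded to the batch maximum with a fixed pad id:

from array import array
from typing import List

import torch

def pad_rows_to_tensor(rows: List[array], pad_id: int,
                       device: torch.device) -> torch.Tensor:
    # Hypothetical helper: pad variable-length array('l') rows to a rectangle
    # and hand them to torch in one call; torch.tensor accepts array rows as
    # ordinary integer sequences.
    width = max(len(row) for row in rows) if rows else 0
    padded = [row + array('l', [pad_id]) * (width - len(row)) for row in rows]
    tokens = torch.tensor(padded, dtype=torch.long,
                          pin_memory=torch.cuda.is_available())
    return tokens.to(device=device, non_blocking=True)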
vllm/sequence.py (23 changes: 16 additions and 7 deletions)
@@ -3,6 +3,7 @@
 import enum
 import math
 from abc import ABC, abstractmethod
+from array import array
 from collections import defaultdict
 from dataclasses import dataclass, field
 from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
@@ -119,10 +120,10 @@ def __init__(
         prompt_token_ids: List[int],
         output_token_ids: Optional[List[int]] = None,
     ) -> None:
-        self._prompt_token_ids: List[int] = list(prompt_token_ids)
+        self._prompt_token_ids = array('l', prompt_token_ids)
         self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
-        self._output_token_ids: List[int] = (
-            list(output_token_ids) if output_token_ids is not None else [])
+        self._output_token_ids = array(
+            'l', output_token_ids if output_token_ids is not None else [])

         self.cumulative_logprob = 0.0
         # The number of tokens that are computed (that run against the model).
@@ -132,28 +133,36 @@ def __init__(
         self._update_cached_all_tokens()

     def _update_cached_all_tokens(self):
-        self._cached_all_token_ids: List[int] = (self._prompt_token_ids +
-                                                 self._output_token_ids)
+        self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
+                                                     self._output_token_ids)

     @property
     def prompt_token_ids(self) -> Tuple[int, ...]:
         return self._prompt_token_ids_tuple

     @prompt_token_ids.setter
     def prompt_token_ids(self, new_prompt_token_ids) -> None:
-        self._prompt_token_ids = list(new_prompt_token_ids)
+        self._prompt_token_ids = array('l', new_prompt_token_ids)
         self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
         self._update_cached_all_tokens()

+    @property
+    def prompt_token_ids_array(self) -> array:
+        return self._prompt_token_ids
+
     @property
     def output_token_ids(self) -> Tuple[int, ...]:
         return tuple(self._output_token_ids)

     @output_token_ids.setter
     def output_token_ids(self, new_output_token_ids) -> None:
-        self._output_token_ids = list(new_output_token_ids)
+        self._output_token_ids = array('l', new_output_token_ids)
         self._update_cached_all_tokens()

+    @property
+    def output_token_ids_array(self) -> array:
+        return self._output_token_ids
+
     def append_token_id(self, token_id: int, logprob: float) -> None:
         self._output_token_ids.append(token_id)
         self._cached_all_token_ids.append(token_id)
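One practical note on this file: array('l') supports the same operations the class previously relied on with lists (append, extend, len, slicing, concatenation), which is why append_token_id above is untouched. The main behavioural difference is that an array only accepts integers that fit in a C long. A small sketch of that contract, illustrative rather than taken from the commit:

from array import array

token_ids = array('l', [1, 2, 3])
token_ids.append(4)            # same call shape as list.append
token_ids.extend([5, 6])       # extend from any iterable of ints
assert list(token_ids[:3]) == [1, 2, 3]

try:
    token_ids.append(2.5)      # non-integers are rejected, unlike with a list
except TypeError:
    pass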
