re-implement beam search on top of vllm core (#8726)
Co-authored-by: Brendan Wong <bjwpokemon@gmail.com>
youkaichao and LunrEclipse authored Sep 24, 2024
1 parent 88577ac commit 0250dd6
Showing 4 changed files with 171 additions and 9 deletions.
24 changes: 20 additions & 4 deletions benchmarks/benchmark_throughput.py
@@ -90,6 +90,7 @@ def run_vllm(
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
use_new_beam_search_impl: bool = False,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
@@ -132,9 +133,23 @@ def run_vllm(
max_tokens=output_len,
))

-    start = time.perf_counter()
-    llm.generate(prompts, sampling_params, use_tqdm=True)
-    end = time.perf_counter()
+    if not use_new_beam_search_impl:
+        start = time.perf_counter()
+        llm.generate(prompts, sampling_params, use_tqdm=True)
+        end = time.perf_counter()
+    else:
+        assert use_beam_search
+        prompts = [prompt for prompt, _, _ in requests]
+        # output_len should be the same for all requests.
+        output_len = requests[0][2]
+        for prompt, input_len, _output_len in requests:
+            assert _output_len == output_len
+        start = time.perf_counter()
+        llm.beam_search(prompts,
+                        beam_width=n,
+                        max_tokens=output_len,
+                        ignore_eos=True)
+        end = time.perf_counter()
return end - start


@@ -336,7 +351,7 @@ def main(args: argparse.Namespace):
run_args.append(args.disable_frontend_multiprocessing)
elapsed_time = uvloop.run(run_vllm_async(*run_args))
else:
-        elapsed_time = run_vllm(*run_args)
+        elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -396,6 +411,7 @@ def main(args: argparse.Namespace):
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--use-new-beam-search-impl", action="store_true")
parser.add_argument("--num-prompts",
type=int,
default=1000,
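With both flags in place, the two code paths can be benchmarked head to head. A plausible invocation (assuming the script's existing --model, --input-len, and --output-len options; the model is illustrative):

    python benchmarks/benchmark_throughput.py \
        --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
        --input-len 32 --output-len 128 --n 4 \
        --use-beam-search --use-new-beam-search-impl

Here --n doubles as the beam width for the new path, since run_vllm passes n as beam_width.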
14 changes: 14 additions & 0 deletions tests/conftest.py
@@ -798,6 +798,20 @@ def generate_beam_search(
outputs = self.generate(prompts, beam_search_params)
return outputs

def generate_beam_search_new(
self,
prompts: Union[List[str], List[List[int]]],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[List[int]], List[str]]]:
outputs = self.model.beam_search(prompts, beam_width, max_tokens)
returned_outputs = []
for output in outputs:
token_ids = [x.tokens for x in output.sequences]
texts = [x.text for x in output.sequences]
returned_outputs.append((token_ids, texts))
return returned_outputs

def encode(self, prompts: List[str]) -> List[List[float]]:
req_outputs = self.model.encode(prompts)
outputs = []
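For reference, a minimal sketch of the new helper's return shape (hypothetical prompt; vllm_model is the runner provided by the vllm_runner fixture):

    # One (token_ids, texts) tuple per prompt; the inner lists hold one
    # entry per returned beam, normally beam_width of them.
    outputs = vllm_model.generate_beam_search_new(["Hello, my name is"],
                                                  beam_width=4,
                                                  max_tokens=8)
    token_ids, texts = outputs[0]
    assert len(token_ids) == len(texts)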
6 changes: 3 additions & 3 deletions tests/samplers/test_beam_search.py
@@ -9,7 +9,7 @@
# 1. Increase max_tokens to 256.
# 2. Increase beam_width to 8.
# 3. Use the model "huggyllama/llama-7b".
-MAX_TOKENS = [128]
+MAX_TOKENS = [64]
BEAM_WIDTHS = [4]
MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]

@@ -33,8 +33,8 @@ def test_beam_search_single_input(
max_tokens)

with vllm_runner(model, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.generate_beam_search(example_prompts,
-                                                       beam_width, max_tokens)
+        vllm_outputs = vllm_model.generate_beam_search_new(
+            example_prompts, beam_width, max_tokens)

for i in range(len(example_prompts)):
hf_output_ids, hf_output_texts = hf_outputs[i]
136 changes: 134 additions & 2 deletions vllm/entrypoints/llm.py
@@ -1,6 +1,8 @@
+import itertools
 from contextlib import contextmanager
-from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
-                    overload)
+from dataclasses import dataclass
+from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
+                    Union, cast, overload)

from tqdm import tqdm

@@ -30,6 +32,37 @@
logger = init_logger(__name__)


@dataclass
class BeamSearchSequence:
"""A sequence for beam search.
It keeps track of the tokens and the log probability of the sequence.
The text field is optional and will only be filled when the sequence is
about to be returned to the user.
"""
    # The tokens include the prompt.
tokens: List[int]
cum_logprob: float = 0.0
text: Optional[str] = None


@dataclass
class BeamSearchOutput:
"""The output of beam search.
It contains the list of the best beam search sequences.
The length of the list is equal to the beam width.
"""
sequences: List[BeamSearchSequence]


class BeamSearchInstance:

def __init__(self, prompt_tokens: List[int]):
self.beams: List[BeamSearchSequence] = [
BeamSearchSequence(tokens=prompt_tokens)
]
self.completed: List[BeamSearchSequence] = []


class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
@@ -354,6 +387,105 @@ def generate(
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)

def beam_search(
self,
prompts: List[Union[str, List[int]]],
beam_width: int,
max_tokens: int,
ignore_eos: bool = False,
) -> List[BeamSearchOutput]:
"""
Generate sequences using beam search.
Args:
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
beam_width: The number of beams to keep at each step.
max_tokens: The max number of tokens to generate for each prompt.
        TODO: how does beam search interact with length penalty, frequency
        penalty, stopping criteria, etc.?
"""

tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=0.0)
instances: List[BeamSearchInstance] = []

for prompt in prompts:
prompt_tokens = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
instances.append(BeamSearchInstance(prompt_tokens))

for _ in range(max_tokens):
all_beams: List[BeamSearchSequence] = list(
sum((instance.beams for instance in instances), []))
pos = [0] + list(
itertools.accumulate(
len(instance.beams) for instance in instances))
instance_start_and_end: List[Tuple[int, int]] = list(
zip(pos[:-1], pos[1:]))

if len(all_beams) == 0:
break

prompts_batch = [
TokensPrompt(prompt_token_ids=beam.tokens)
for beam in all_beams
]

# only runs for one step
# we don't need to use tqdm here
output = self.generate(prompts_batch,
sampling_params=beam_search_params,
use_tqdm=False)

for (start, end), instance in zip(instance_start_and_end,
instances):
instance_new_beams = []
for i in range(start, end):
current_beam = all_beams[i]
result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, the
                        # sequence completed because of the max-model-len
                        # limit or was aborted, so it is not added to the
                        # new beams.
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)

if token_id == tokenizer.eos_token_id and \
not ignore_eos:
instance.completed.append(new_beam)
else:
instance_new_beams.append(new_beam)
sorted_beams = sorted(instance_new_beams,
key=lambda x: x.cum_logprob,
reverse=True)
instance.beams = sorted_beams[:beam_width]

outputs = []
for instance in instances:
instance.completed.extend(instance.beams)
sorted_completed = sorted(instance.completed,
key=lambda x: x.cum_logprob,
reverse=True)
best_beams = sorted_completed[:beam_width]

for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens)
outputs.append(BeamSearchOutput(sequences=best_beams))

return outputs

def chat(
self,
messages: List[ChatCompletionMessageParam],
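The new implementation mirrors the HuggingFace scheme referenced above: each step requests the top 2 * beam_width logprobs for every live beam in a single batched generate() call, extends each beam with those candidate tokens, routes finished beams (EOS, unless ignore_eos) to the completed list, and keeps the beam_width extensions with the highest cumulative log probability. A minimal usage sketch of the new entry point (model name and prompt are illustrative):

    from vllm import LLM

    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    outputs = llm.beam_search(["The capital of France is"],
                              beam_width=4,
                              max_tokens=32)
    for output in outputs:
        # sequences are sorted by cumulative log probability, best first
        for seq in output.sequences:
            print(f"{seq.cum_logprob:.3f}  {seq.text}")

Note that BeamSearchSequence.tokens includes the prompt tokens, so seq.text contains the decoded prompt as well as the completion.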
