diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index e1a5d4ee28ea1..68b401d5bbbb7 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -90,6 +90,7 @@ def run_vllm(
     download_dir: Optional[str] = None,
     load_format: str = EngineArgs.load_format,
     disable_async_output_proc: bool = False,
+    use_new_beam_search_impl: bool = False,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -132,9 +133,23 @@ def run_vllm(
                 max_tokens=output_len,
             ))
 
-    start = time.perf_counter()
-    llm.generate(prompts, sampling_params, use_tqdm=True)
-    end = time.perf_counter()
+    if not use_new_beam_search_impl:
+        start = time.perf_counter()
+        llm.generate(prompts, sampling_params, use_tqdm=True)
+        end = time.perf_counter()
+    else:
+        assert use_beam_search
+        prompts = [prompt for prompt, _, _ in requests]
+        # output_len should be the same for all requests.
+        output_len = requests[0][2]
+        for prompt, input_len, _output_len in requests:
+            assert _output_len == output_len
+        start = time.perf_counter()
+        llm.beam_search(prompts,
+                        beam_width=n,
+                        max_tokens=output_len,
+                        ignore_eos=True)
+        end = time.perf_counter()
     return end - start
 
 
@@ -336,7 +351,7 @@ def main(args: argparse.Namespace):
             run_args.append(args.disable_frontend_multiprocessing)
             elapsed_time = uvloop.run(run_vllm_async(*run_args))
         else:
-            elapsed_time = run_vllm(*run_args)
+            elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -396,6 +411,7 @@ def main(args: argparse.Namespace):
                         default=1,
                         help="Number of generated sequences per prompt.")
     parser.add_argument("--use-beam-search", action="store_true")
+    parser.add_argument("--use-new-beam-search-impl", action="store_true")
     parser.add_argument("--num-prompts",
                         type=int,
                         default=1000,
diff --git a/tests/conftest.py b/tests/conftest.py
index c2616bcf7091c..69ac4aaee0fda 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -798,6 +798,20 @@ def generate_beam_search(
         outputs = self.generate(prompts, beam_search_params)
         return outputs
 
+    def generate_beam_search_new(
+        self,
+        prompts: Union[List[str], List[List[int]]],
+        beam_width: int,
+        max_tokens: int,
+    ) -> List[Tuple[List[List[int]], List[str]]]:
+        outputs = self.model.beam_search(prompts, beam_width, max_tokens)
+        returned_outputs = []
+        for output in outputs:
+            token_ids = [x.tokens for x in output.sequences]
+            texts = [x.text for x in output.sequences]
+            returned_outputs.append((token_ids, texts))
+        return returned_outputs
+
     def encode(self, prompts: List[str]) -> List[List[float]]:
         req_outputs = self.model.encode(prompts)
         outputs = []
diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py
index 98a02dec895d2..a9bedc2956fdd 100644
--- a/tests/samplers/test_beam_search.py
+++ b/tests/samplers/test_beam_search.py
@@ -9,7 +9,7 @@
 # 1. Increase max_tokens to 256.
 # 2. Increase beam_width to 8.
 # 3. Use the model "huggyllama/llama-7b".
-MAX_TOKENS = [128]
+MAX_TOKENS = [64]
 BEAM_WIDTHS = [4]
 MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
 
@@ -33,8 +33,8 @@ def test_beam_search_single_input(
                                                    max_tokens)
 
     with vllm_runner(model, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.generate_beam_search(example_prompts,
-                                                       beam_width, max_tokens)
+        vllm_outputs = vllm_model.generate_beam_search_new(
+            example_prompts, beam_width, max_tokens)
 
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_texts = hf_outputs[i]
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index a86c51d23b34d..387813f374daa 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1,6 +1,8 @@
+import itertools
 from contextlib import contextmanager
-from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
-                    overload)
+from dataclasses import dataclass
+from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
+                    Union, cast, overload)
 
 from tqdm import tqdm
 
@@ -30,6 +32,37 @@
 
 logger = init_logger(__name__)
 
 
+@dataclass
+class BeamSearchSequence:
+    """A sequence for beam search.
+    It keeps track of the tokens and the log probability of the sequence.
+    The text field is optional and will only be filled when the sequence is
+    about to be returned to the user.
+    """
+    # The tokens include the prompt.
+    tokens: List[int]
+    cum_logprob: float = 0.0
+    text: Optional[str] = None
+
+
+@dataclass
+class BeamSearchOutput:
+    """The output of beam search.
+    It contains the list of the best beam search sequences.
+    The length of the list is equal to the beam width.
+    """
+    sequences: List[BeamSearchSequence]
+
+
+class BeamSearchInstance:
+
+    def __init__(self, prompt_tokens: List[int]):
+        self.beams: List[BeamSearchSequence] = [
+            BeamSearchSequence(tokens=prompt_tokens)
+        ]
+        self.completed: List[BeamSearchSequence] = []
+
+
 class LLM:
     """An LLM for generating texts from given prompts and sampling parameters.
 
@@ -354,6 +387,105 @@ def generate(
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return LLMEngine.validate_outputs(outputs, RequestOutput)
 
+    def beam_search(
+        self,
+        prompts: List[Union[str, List[int]]],
+        beam_width: int,
+        max_tokens: int,
+        ignore_eos: bool = False,
+    ) -> List[BeamSearchOutput]:
+        """
+        Generate sequences using beam search.
+
+        Args:
+            prompts: A list of prompts. Each prompt can be a string or a list
+                of token IDs.
+            beam_width: The number of beams to keep at each step.
+            max_tokens: The maximum number of tokens to generate for each
+                prompt.
+
+        TODO: how does beam search work together with length penalty,
+        frequency penalty, and stopping criteria, etc.?
+        """
+
+        tokenizer = self.get_tokenizer()
+        # generate 2 * beam_width candidates at each step
+        # following the huggingface transformers implementation
+        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=0.0)
+        instances: List[BeamSearchInstance] = []
+
+        for prompt in prompts:
+            prompt_tokens = prompt if isinstance(
+                prompt, list) else tokenizer.encode(prompt)
+            instances.append(BeamSearchInstance(prompt_tokens))
+
+        for _ in range(max_tokens):
+            all_beams: List[BeamSearchSequence] = list(
+                sum((instance.beams for instance in instances), []))
+            pos = [0] + list(
+                itertools.accumulate(
+                    len(instance.beams) for instance in instances))
+            instance_start_and_end: List[Tuple[int, int]] = list(
+                zip(pos[:-1], pos[1:]))
+
+            if len(all_beams) == 0:
+                break
+
+            prompts_batch = [
+                TokensPrompt(prompt_token_ids=beam.tokens)
+                for beam in all_beams
+            ]
+
+            # only runs for one step
+            # we don't need to use tqdm here
+            output = self.generate(prompts_batch,
+                                   sampling_params=beam_search_params,
+                                   use_tqdm=False)
+
+            for (start, end), instance in zip(instance_start_and_end,
+                                              instances):
+                instance_new_beams = []
+                for i in range(start, end):
+                    current_beam = all_beams[i]
+                    result = output[i]
+
+                    if result.outputs[0].logprobs is not None:
+                        # if `result.outputs[0].logprobs` is None, the sequence
+                        # completed because max-model-len was reached or the
+                        # request was aborted; don't add it to the new beams.
+                        logprobs = result.outputs[0].logprobs[0]
+                        for token_id, logprob_obj in logprobs.items():
+                            new_beam = BeamSearchSequence(
+                                tokens=current_beam.tokens + [token_id],
+                                cum_logprob=current_beam.cum_logprob +
+                                logprob_obj.logprob)
+
+                            if token_id == tokenizer.eos_token_id and \
+                                not ignore_eos:
+                                instance.completed.append(new_beam)
+                            else:
+                                instance_new_beams.append(new_beam)
+                sorted_beams = sorted(instance_new_beams,
+                                      key=lambda x: x.cum_logprob,
+                                      reverse=True)
+                instance.beams = sorted_beams[:beam_width]
+
+        outputs = []
+        for instance in instances:
+            instance.completed.extend(instance.beams)
+            sorted_completed = sorted(instance.completed,
+                                      key=lambda x: x.cum_logprob,
+                                      reverse=True)
+            best_beams = sorted_completed[:beam_width]
+
+            for beam in best_beams:
+                beam.text = tokenizer.decode(beam.tokens)
+            outputs.append(BeamSearchOutput(sequences=best_beams))
+
+        return outputs
+
     def chat(
         self,
         messages: List[ChatCompletionMessageParam],
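Below is a minimal usage sketch (not part of the patch) of the `LLM.beam_search` API added in this diff. The model name, prompts, and the `beam_width`/`max_tokens` values are illustrative placeholders; only the `beam_search` signature and the `BeamSearchOutput`/`BeamSearchSequence` fields are taken from the code above.

```python
# Minimal sketch of calling the new LLM.beam_search API from this diff.
# Assumes a small model for illustration; any vLLM-supported model works.
from vllm import LLM

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

prompts = [
    "The capital of France is",
    "The future of AI is",
]

# One BeamSearchOutput per prompt; `sequences` holds the best beams
# (at most beam_width), sorted by cumulative log probability.
outputs = llm.beam_search(prompts, beam_width=4, max_tokens=64)

for prompt, output in zip(prompts, outputs):
    print(f"Prompt: {prompt!r}")
    for beam in output.sequences:
        # `text` is decoded from `tokens`, which include the prompt tokens.
        print(f"  cum_logprob={beam.cum_logprob:.3f}  text={beam.text!r}")
```

Since the returned beams are ranked by `cum_logprob` in descending order, `output.sequences[0]` is the highest-scoring completion for each prompt.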