Skip to content

Commit

Permalink
[BugFix] Typing fixes to RequestOutput.prompt and beam search (#9473)
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill authored Oct 18, 2024
1 parent 944dd8e commit 1ffc8a7
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 14 deletions.
7 changes: 5 additions & 2 deletions vllm/beam_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass
from typing import List, Optional
from typing import Dict, List, Optional

from vllm.sequence import Logprob


@dataclass
Expand All @@ -11,6 +13,7 @@ class BeamSearchSequence:
"""
# The tokens includes the prompt.
tokens: List[int]
logprobs: List[Dict[int, Logprob]]
cum_logprob: float = 0.0
text: Optional[str] = None

Expand All @@ -28,7 +31,7 @@ class BeamSearchInstance:

def __init__(self, prompt_tokens: List[int]):
self.beams: List[BeamSearchSequence] = [
BeamSearchSequence(tokens=prompt_tokens)
BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
]
self.completed: List[BeamSearchSequence] = []

Expand Down
29 changes: 19 additions & 10 deletions vllm/engine/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def generate(

async def beam_search(
self,
prompt: Union[PromptType, List[int]],
prompt: Union[str, List[int]],
request_id: str,
params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:
Expand All @@ -71,17 +71,25 @@ async def beam_search(
length_penalty = params.length_penalty

tokenizer = await self.get_tokenizer(lora_request=None)
tokenizedPrompt = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
tokenizedLength = len(tokenizedPrompt)
if isinstance(prompt, str):
tokenized_prompt = tokenizer.encode(prompt)
prompt_text = prompt
else:
tokenized_prompt = prompt
prompt_text = None
tokenized_length = len(tokenized_prompt)

sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty)

beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature)
all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
all_beams = [
BeamSearchSequence(tokens=tokenized_prompt,
logprobs=[],
cum_logprob=0)
]
completed = []

for _ in range(max_tokens):
Expand Down Expand Up @@ -114,6 +122,7 @@ async def beam_search(
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)

Expand All @@ -131,22 +140,22 @@ async def beam_search(
best_beams = sorted_completed[:beam_width]

for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
beam.text = tokenizer.decode(beam.tokens[tokenized_length:])

beam_search_output = RequestOutput(
request_id=request_id,
prompt=prompt,
prompt=prompt_text,
outputs=[
CompletionOutput(
text=beam.text,
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens,
token_ids=beam.tokens[tokenized_length:],
index=i,
logprobs=beam.cum_logprob,
logprobs=beam.logprobs,
) for (i, beam) in enumerate(best_beams)
],
finished=True,
prompt_token_ids=tokenizedPrompt,
prompt_token_ids=tokenized_prompt,
prompt_logprobs=None)

yield beam_search_output
Expand Down
1 change: 1 addition & 0 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)

Expand Down
3 changes: 1 addition & 2 deletions vllm/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Sequence as GenericSequence
from typing import Union

from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
Expand Down Expand Up @@ -93,7 +92,7 @@ class RequestOutput:
def __init__(
self,
request_id: str,
prompt: Optional[PromptType],
prompt: Optional[str],
prompt_token_ids: Optional[List[int]],
prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput],
Expand Down

0 comments on commit 1ffc8a7

Please sign in to comment.