From 74803b616e76255a47f2d486bebb9f854b042620 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 2 Apr 2024 16:19:38 -0700 Subject: [PATCH] Fix logits processor handling --- tests/test_logits_processor.py | 4 ++-- vllm/entrypoints/grpc/grpc_server.py | 21 ++++++++++----------- vllm/entrypoints/grpc/validation.py | 10 +++------- vllm/tgis_utils/logits_processors.py | 15 ++++++++------- 4 files changed, 23 insertions(+), 27 deletions(-) diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index c2a263814..f1db62111 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -8,7 +8,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.tgis_utils.logits_processors import LengthPenaltyWarper +from vllm.tgis_utils.logits_processors import ExpDecayLengthPenaltyWarper from vllm.worker.model_runner import ModelRunner @@ -106,7 +106,7 @@ def test_exponential_decay_length_penalty(seed: int, device: str): logits_processor.scale = 1.0 eos_token_id = 100 - lenpen = LengthPenaltyWarper([2, 2.0], eos_token_id) + lenpen = ExpDecayLengthPenaltyWarper((2, 2.0), eos_token_id) seq_group_metadata_list = [] prompt_lens = [] diff --git a/vllm/entrypoints/grpc/grpc_server.py b/vllm/entrypoints/grpc/grpc_server.py index 311e4a386..c693fbc98 100644 --- a/vllm/entrypoints/grpc/grpc_server.py +++ b/vllm/entrypoints/grpc/grpc_server.py @@ -33,7 +33,7 @@ from vllm.entrypoints.openai.serving_completion import merge_async_iterators from vllm.logger import init_logger from vllm.sequence import Logprob -from vllm.tgis_utils.logits_processors import (LengthPenaltyWarper, +from vllm.tgis_utils.logits_processors import (ExpDecayLengthPenaltyWarper, TypicalLogitsWarperWrapper) from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup @@ -276,6 +276,7 @@ async def _validate_and_convert_params( resp_options = params.response sampling = params.sampling stopping = params.stopping + decoding = params.decoding greedy = params.method == DecodingMethod.GREEDY max_new_tokens: Optional[int] = None @@ -295,9 +296,6 @@ async def _validate_and_convert_params( logprobs = with_default(logprobs, None) - # GAPS: - # - exp_decay_length_penalty - # NEW FUNCTION TO ADD (later) # - presence penalty, freq penalty # - min_p @@ -316,14 +314,15 @@ async def _validate_and_convert_params( if not greedy and 0.0 < sampling.typical_p < 1.0: logits_processors.append( TypicalLogitsWarperWrapper(mass=sampling.typical_p)) - if params.decoding.length_penalty is not None: - length_penalty = ( - params.decoding.length_penalty.start_index, - params.decoding.length_penalty.decay_factor, + + if decoding.HasField("length_penalty"): + length_penalty_tuple = ( + decoding.length_penalty.start_index, + decoding.length_penalty.decay_factor, ) logits_processors.append( - LengthPenaltyWarper(length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + ExpDecayLengthPenaltyWarper(length_penalty=length_penalty_tuple, + eos_token_id=self.tokenizer.eos_token_id)) time_limit_millis = stopping.time_limit_millis deadline = time.time( @@ -342,7 +341,7 @@ async def _validate_and_convert_params( top_p=with_default(sampling.top_p, 1.0), seed=sampling.seed if sampling.HasField("seed") else None, repetition_penalty=with_default( - params.decoding.repetition_penalty, 1.0), + decoding.repetition_penalty, 1.0), logits_processors=logits_processors, stop=with_default(stopping.stop_sequences, None), include_stop_str_in_output=stopping.include_stop_sequence diff --git a/vllm/entrypoints/grpc/validation.py b/vllm/entrypoints/grpc/validation.py index 308fef1d5..17920bdcd 100644 --- a/vllm/entrypoints/grpc/validation.py +++ b/vllm/entrypoints/grpc/validation.py @@ -71,13 +71,9 @@ def validate_params(params: Parameters, max_max_new_tokens: int): decoding = params.decoding # Decoding parameter checks - if decoding.HasField("length_penalty"): - args = [ - decoding.length_penalty.start_index, - decoding.length_penalty.decay_factor - ] - if None in args or not (1.0 <= args[1] <= 10.0): - TGISValidationError.LengthPenalty.error() + if decoding.HasField("length_penalty") and not ( + 1.0 <= decoding.length_penalty.decay_factor <= 10.0): + TGISValidationError.LengthPenalty.error() if not (0 <= decoding.repetition_penalty <= 2): # (a value of 0 means no penalty / unset) diff --git a/vllm/tgis_utils/logits_processors.py b/vllm/tgis_utils/logits_processors.py index 5ea20c442..2a8437e10 100644 --- a/vllm/tgis_utils/logits_processors.py +++ b/vllm/tgis_utils/logits_processors.py @@ -10,25 +10,26 @@ def __init__(self, mass: float): self.warper = TypicalLogitsWarper(mass=mass) def __call__(self, token_ids: List[int], - logits: torch.tensor) -> torch.tensor: + logits: torch.Tensor) -> torch.Tensor: # transformers warpers assume tensors of shape (batch_size, vocab_size) # and the typical warper doesn't use input_ids - return self.warper(input_ids=None, scores=logits.reshape((1, -1))) + return self.warper(input_ids=None, + scores=logits.reshape(1, -1)).flatten() -class LengthPenaltyWarper: +class ExpDecayLengthPenaltyWarper: def __init__(self, length_penalty: Tuple[int, float], eos_token_id: int): - self.length_penalty = length_penalty + self.start, self.penalty = length_penalty self.eos_token_id = eos_token_id def __call__(self, token_ids: List[int], - logits: torch.tensor) -> torch.tensor: - tokens_past = len(token_ids) - self.length_penalty[0] + logits: torch.Tensor) -> torch.Tensor: + tokens_past = len(token_ids) - self.start if tokens_past > 0: eos_logit = logits[self.eos_token_id] # To support negative logits we compute the penalty of the # absolute value and add to the original logit logits[self.eos_token_id] = eos_logit + torch.abs(eos_logit) * ( - pow(self.length_penalty[1], tokens_past) - 1) + pow(self.penalty, tokens_past) - 1) return logits