Skip to content

Commit

Permalink
Fix logits processor handling
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill authored and joerunde committed Apr 2, 2024
1 parent e92db68 commit 74803b6
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 27 deletions.
4 changes: 2 additions & 2 deletions tests/test_logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.tgis_utils.logits_processors import LengthPenaltyWarper
from vllm.tgis_utils.logits_processors import ExpDecayLengthPenaltyWarper
from vllm.worker.model_runner import ModelRunner


Expand Down Expand Up @@ -106,7 +106,7 @@ def test_exponential_decay_length_penalty(seed: int, device: str):
logits_processor.scale = 1.0

eos_token_id = 100
lenpen = LengthPenaltyWarper([2, 2.0], eos_token_id)
lenpen = ExpDecayLengthPenaltyWarper((2, 2.0), eos_token_id)

seq_group_metadata_list = []
prompt_lens = []
Expand Down
21 changes: 10 additions & 11 deletions vllm/entrypoints/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from vllm.entrypoints.openai.serving_completion import merge_async_iterators
from vllm.logger import init_logger
from vllm.sequence import Logprob
from vllm.tgis_utils.logits_processors import (LengthPenaltyWarper,
from vllm.tgis_utils.logits_processors import (ExpDecayLengthPenaltyWarper,
TypicalLogitsWarperWrapper)
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

Expand Down Expand Up @@ -276,6 +276,7 @@ async def _validate_and_convert_params(
resp_options = params.response
sampling = params.sampling
stopping = params.stopping
decoding = params.decoding
greedy = params.method == DecodingMethod.GREEDY

max_new_tokens: Optional[int] = None
Expand All @@ -295,9 +296,6 @@ async def _validate_and_convert_params(

logprobs = with_default(logprobs, None)

# GAPS:
# - exp_decay_length_penalty

# NEW FUNCTION TO ADD (later)
# - presence penalty, freq penalty
# - min_p
Expand All @@ -316,14 +314,15 @@ async def _validate_and_convert_params(
if not greedy and 0.0 < sampling.typical_p < 1.0:
logits_processors.append(
TypicalLogitsWarperWrapper(mass=sampling.typical_p))
if params.decoding.length_penalty is not None:
length_penalty = (
params.decoding.length_penalty.start_index,
params.decoding.length_penalty.decay_factor,

if decoding.HasField("length_penalty"):
length_penalty_tuple = (
decoding.length_penalty.start_index,
decoding.length_penalty.decay_factor,
)
logits_processors.append(
LengthPenaltyWarper(length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
ExpDecayLengthPenaltyWarper(length_penalty=length_penalty_tuple,
eos_token_id=self.tokenizer.eos_token_id))

time_limit_millis = stopping.time_limit_millis
deadline = time.time(
Expand All @@ -342,7 +341,7 @@ async def _validate_and_convert_params(
top_p=with_default(sampling.top_p, 1.0),
seed=sampling.seed if sampling.HasField("seed") else None,
repetition_penalty=with_default(
params.decoding.repetition_penalty, 1.0),
decoding.repetition_penalty, 1.0),
logits_processors=logits_processors,
stop=with_default(stopping.stop_sequences, None),
include_stop_str_in_output=stopping.include_stop_sequence
Expand Down
10 changes: 3 additions & 7 deletions vllm/entrypoints/grpc/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,9 @@ def validate_params(params: Parameters, max_max_new_tokens: int):
decoding = params.decoding

# Decoding parameter checks
if decoding.HasField("length_penalty"):
args = [
decoding.length_penalty.start_index,
decoding.length_penalty.decay_factor
]
if None in args or not (1.0 <= args[1] <= 10.0):
TGISValidationError.LengthPenalty.error()
if decoding.HasField("length_penalty") and not (
1.0 <= decoding.length_penalty.decay_factor <= 10.0):
TGISValidationError.LengthPenalty.error()

if not (0 <= decoding.repetition_penalty <= 2):
# (a value of 0 means no penalty / unset)
Expand Down
15 changes: 8 additions & 7 deletions vllm/tgis_utils/logits_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,26 @@ def __init__(self, mass: float):
self.warper = TypicalLogitsWarper(mass=mass)

def __call__(self, token_ids: List[int],
logits: torch.tensor) -> torch.tensor:
logits: torch.Tensor) -> torch.Tensor:
# transformers warpers assume tensors of shape (batch_size, vocab_size)
# and the typical warper doesn't use input_ids
return self.warper(input_ids=None, scores=logits.reshape((1, -1)))
return self.warper(input_ids=None,
scores=logits.reshape(1, -1)).flatten()


class LengthPenaltyWarper:
class ExpDecayLengthPenaltyWarper:

def __init__(self, length_penalty: Tuple[int, float], eos_token_id: int):
self.length_penalty = length_penalty
self.start, self.penalty = length_penalty
self.eos_token_id = eos_token_id

def __call__(self, token_ids: List[int],
logits: torch.tensor) -> torch.tensor:
tokens_past = len(token_ids) - self.length_penalty[0]
logits: torch.Tensor) -> torch.Tensor:
tokens_past = len(token_ids) - self.start
if tokens_past > 0:
eos_logit = logits[self.eos_token_id]
# To support negative logits we compute the penalty of the
# absolute value and add to the original logit
logits[self.eos_token_id] = eos_logit + torch.abs(eos_logit) * (
pow(self.length_penalty[1], tokens_past) - 1)
pow(self.penalty, tokens_past) - 1)
return logits

0 comments on commit 74803b6

Please sign in to comment.