[Core] Support serving encoder/decoder models (vllm-project#7258)
DarkLight1337 authored Aug 9, 2024
1 parent 0fa1490 commit 7eb4a51
Showing 25 changed files with 603 additions and 464 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yaml
@@ -25,7 +25,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install mypy==1.9.0
+          pip install mypy==1.11.1
           pip install types-setuptools
           pip install types-PyYAML
           pip install types-requests
8 changes: 4 additions & 4 deletions examples/offline_inference_encoder_decoder.py
@@ -4,8 +4,8 @@
 '''

 from vllm import LLM, SamplingParams
-from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
-from vllm.utils import zip_enc_dec_prompt_lists
+from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
+                         TokensPrompt, zip_enc_dec_prompts)

 dtype = "float"

@@ -61,9 +61,9 @@
 )

 # - Finally, here's a useful helper function for zipping encoder and
-#   decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt
+#   decoder prompts together into a list of ExplicitEncoderDecoderPrompt
 #   instances
-zipped_prompt_list = zip_enc_dec_prompt_lists(
+zipped_prompt_list = zip_enc_dec_prompts(
     ['An encoder prompt', 'Another encoder prompt'],
     ['A decoder prompt', 'Another decoder prompt'])

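The example above then passes these zipped prompts to LLM.generate for offline inference. A rough, self-contained sketch of that flow (model choice, dtype and sampling settings are illustrative, not taken from the diff):

from vllm import LLM, SamplingParams
from vllm.inputs import zip_enc_dec_prompts

# Pair up encoder and decoder prompts into ExplicitEncoderDecoderPrompt dicts.
zipped_prompts = zip_enc_dec_prompts(
    ['An encoder prompt', 'Another encoder prompt'],
    ['A decoder prompt', 'Another decoder prompt'])

# BART is the encoder/decoder model exercised by the tests in this commit.
llm = LLM(model="facebook/bart-base", dtype="float", enforce_eager=True)
sampling_params = SamplingParams(temperature=0, max_tokens=20)

# One generated output per (encoder prompt, decoder prompt) pair.
for output in llm.generate(zipped_prompts, sampling_params):
    print(output.outputs[0].text)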
2 changes: 1 addition & 1 deletion requirements-common.txt
@@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer == 0.10.3
 outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
-typing_extensions
+typing_extensions >= 4.10
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
 pyzmq
 gguf == 0.9.1
2 changes: 1 addition & 1 deletion requirements-lint.txt
@@ -8,7 +8,7 @@ isort==5.13.2
 clang-format==18.1.5

 # type checking
-mypy==1.9.0
+mypy==1.11.1
 types-PyYAML
 types-requests
 types-setuptools
32 changes: 19 additions & 13 deletions tests/conftest.py
@@ -3,6 +3,7 @@
 import os
 import sys
 from collections import UserList
+from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union

 import pytest
@@ -14,20 +15,19 @@
                           AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
                           BatchFeature)

-from tests.models.utils import DecoderPromptType
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.config import TokenizerPoolConfig
 from vllm.connections import global_http_connection
 from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel)
-from vllm.inputs import TextPrompt
+from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
+                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sequence import SampleLogprobs
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-                        is_cpu, to_enc_dec_tuple_list,
-                        zip_enc_dec_prompt_lists)
+                        is_cpu)

 logger = init_logger(__name__)

@@ -124,10 +124,16 @@ def example_prompts() -> List[str]:
     return prompts


+class DecoderPromptType(Enum):
+    """For encoder/decoder models only."""
+    CUSTOM = 1
+    NONE = 2
+    EMPTY_STR = 3
+
+
 @pytest.fixture
-def example_encoder_decoder_prompts() \
-        -> Dict[DecoderPromptType,
-                Tuple[List[str], List[Optional[str]]]]:
+def example_encoder_decoder_prompts(
+) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
     '''
     Returns an encoder prompt list and a decoder prompt list, wherein each pair
     of same-index entries in both lists corresponds to an (encoder prompt,
@@ -150,11 +156,11 @@ def example_encoder_decoder_prompts() \
     # NONE decoder prompt type
     return {
         DecoderPromptType.NONE:
-        zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts),
+        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
         DecoderPromptType.EMPTY_STR:
-        zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts),
+        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
         DecoderPromptType.CUSTOM:
-        zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts),
+        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
     }


@@ -444,7 +450,7 @@ def generate_greedy_logprobs_limit(

     def generate_encoder_decoder_greedy_logprobs_limit(
         self,
-        encoder_decoder_prompts: Tuple[List[str], List[str]],
+        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         max_tokens: int,
         num_logprobs: int,
         **kwargs: Any,
@@ -608,7 +614,7 @@ def generate_w_logprobs(

     def generate_encoder_decoder_w_logprobs(
         self,
-        encoder_decoder_prompts: Tuple[List[str], List[str]],
+        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         sampling_params: SamplingParams,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         '''
@@ -653,7 +659,7 @@ def generate_greedy_logprobs(

     def generate_encoder_decoder_greedy_logprobs(
         self,
-        encoder_decoder_prompts: Tuple[List[str], List[str]],
+        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         max_tokens: int,
         num_logprobs: int,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
@@ -11,9 +11,9 @@

 import pytest

-from tests.models.utils import DecoderPromptType
 from vllm.utils import cuda_device_count_stateless

+from ..conftest import DecoderPromptType
 from ..models.utils import check_logprobs_close
 from ..utils import fork_new_process_for_each_test

50 changes: 50 additions & 0 deletions tests/entrypoints/openai/test_encoder_decoder.py
@@ -0,0 +1,50 @@
+import openai
+import pytest
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "facebook/bart-base"
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--dtype",
+        "bfloat16",
+        "--enforce-eager",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server):
+    return server.get_async_client()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
+    completion = await client.completions.create(model=model_name,
+                                                  prompt="Hello, my name is",
+                                                  max_tokens=5,
+                                                  temperature=0.0)
+
+    assert completion.id is not None
+    assert completion.choices is not None and len(completion.choices) == 1
+
+    choice = completion.choices[0]
+    assert len(choice.text) >= 5
+    assert choice.finish_reason == "length"
+    assert completion.usage == openai.types.CompletionUsage(
+        completion_tokens=5, prompt_tokens=2, total_tokens=7)
+
+    # test using token IDs
+    completion = await client.completions.create(
+        model=model_name,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    assert len(completion.choices[0].text) >= 1
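The new test drives an encoder/decoder model (BART) through the OpenAI-compatible completions endpoint, with RemoteOpenAIServer handling server startup. A rough manual equivalent of the same round trip, assuming a locally started server on the default port (launch command and client settings are illustrative):

# In a separate shell, start the server roughly as the fixture does:
#   python -m vllm.entrypoints.openai.api_server \
#       --model facebook/bart-base --dtype bfloat16 --enforce-eager

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="facebook/bart-base",
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
)
print(completion.choices[0].text)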
38 changes: 27 additions & 11 deletions tests/models/test_bart.py
@@ -2,6 +2,8 @@
 Run `pytest tests/models/test_bart.py`.
 """
+from typing import List, Optional, Tuple
+
 from vllm.utils import is_cpu

 if not is_cpu():
@@ -11,22 +13,31 @@

 import pytest

-from tests.models.utils import DecoderPromptType
+from vllm.sequence import SampleLogprobs

+from ..conftest import DecoderPromptType
 from .utils import check_logprobs_close

 MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]

-DECODER_PROMPT_TYPES = ([
-    DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR,
-    DecoderPromptType.NONE
-])
+def vllm_to_hf_output(
+    vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
+    decoder_prompt_type: DecoderPromptType,
+):
+    """Sanitize vllm output to be comparable with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    hf_output_str = output_str + "</s>"
+    if decoder_prompt_type == DecoderPromptType.NONE:
+        hf_output_str = "<s>" + hf_output_str
+
+    return output_ids, hf_output_str, out_logprobs

 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float", "bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
-@pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES)
+@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
 def test_models(
     hf_runner,
     vllm_runner,
@@ -146,8 +157,13 @@ def test_models(
     hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
                       else 0)

-    check_logprobs_close(outputs_0_lst=hf_outputs,
-                         outputs_1_lst=vllm_outputs,
-                         name_0="hf",
-                         name_1="vllm",
-                         num_outputs_0_skip_tokens=hf_skip_tokens)
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=[
+            vllm_to_hf_output(vllm_output, decoder_prompt_type)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+        num_outputs_0_skip_tokens=hf_skip_tokens,
+    )
11 changes: 0 additions & 11 deletions tests/models/utils.py
@@ -1,5 +1,4 @@
 import warnings
-from enum import Enum
 from typing import Dict, List, Optional, Sequence, Tuple, Union

 from vllm.sequence import SampleLogprobs
@@ -136,13 +135,3 @@ def check_logprobs_close(
         warnings.simplefilter("always")

         warnings.warn(fail_msg, stacklevel=2)
-
-
-class DecoderPromptType(Enum):
-    '''
-    For encoder/decoder models only -
-    '''
-    CUSTOM = 1
-    NONE = 2
-    EMPTY_STR = 3
2 changes: 1 addition & 1 deletion tests/test_inputs.py
@@ -2,7 +2,7 @@

 import pytest

-from vllm.inputs import parse_and_batch_prompt
+from vllm.inputs.parse import parse_and_batch_prompt

 STRING_INPUTS = [
     '',
10 changes: 10 additions & 0 deletions vllm/config.py
@@ -464,6 +464,16 @@ def _get_num_seqlen_agnostic_layers(
             if t != "attention"
         ])

+    @property
+    def is_encoder_decoder_model(self) -> bool:
+        """Extract the HF encoder/decoder model flag."""
+        return getattr(self.hf_config, "is_encoder_decoder", False)
+
+    @property
+    def is_embedding_model(self) -> bool:
+        """Extract the embedding model flag."""
+        return self.embedding_mode
+

 class CacheConfig:
     """Configuration for the KV cache.
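As a side note, the is_encoder_decoder_model property added above just reads the is_encoder_decoder flag from the underlying Hugging Face config. A quick standalone check of what it would return for BART (a sketch, assuming the model config can be fetched from the Hub):

from transformers import AutoConfig

# The new property boils down to this getattr on the HF config object.
hf_config = AutoConfig.from_pretrained("facebook/bart-base")
print(getattr(hf_config, "is_encoder_decoder", False))  # True for BART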