[Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734)
CatherineSue authored May 11, 2024
1 parent 4e12131 commit e254497
Showing 38 changed files with 1,627 additions and 160 deletions.
17 changes: 17 additions & 0 deletions examples/offline_inference_embedding.py
@@ -0,0 +1,17 @@
from vllm import LLM

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create an LLM.
model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.encode(prompts)
# Print the outputs.
for output in outputs:
    print(output.outputs.embedding)  # list of 4096 floats
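
The returned vectors can be compared directly. A minimal sketch, assuming the `outputs` list from the example above and using the same cosine-similarity measure as the new tests/models/test_embedding.py:

import torch
import torch.nn.functional as F

# Stack the 4096-dim embeddings from the example above into one tensor.
vectors = torch.tensor([o.outputs.embedding for o in outputs])
# Cosine similarity of every prompt embedding against the first one.
scores = F.cosine_similarity(vectors, vectors[0].unsqueeze(0), dim=1)
print(scores.tolist())  # scores[0] is 1.0 by construction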
23 changes: 23 additions & 0 deletions examples/openai_embedding_client.py
@@ -0,0 +1,23 @@
from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

responses = client.embeddings.create(input=[
    "Hello my name is",
    "The best thing about vLLM is that it supports many different models"
],
                                     model=model)

for data in responses.data:
    print(data.embedding)  # list of float of len 4096
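
The embeddings endpoint also accepts pre-tokenized input, as the new server tests below exercise. A minimal sketch, reusing `client` and `model` from this example and assuming the server is already serving the embedding model:

# Token IDs can be sent instead of raw strings.
token_responses = client.embeddings.create(
    input=[[4, 5, 7, 9, 20], [15, 29, 499]],
    model=model,
    encoding_format="float",
)
for data in token_responses.data:
    print(len(data.embedding))  # 4096 per input sequence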
9 changes: 6 additions & 3 deletions requirements-dev.txt
@@ -19,12 +19,15 @@ pytest-forked
pytest-asyncio
pytest-rerunfailures
pytest-shard
httpx

# testing utils
awscli
einops # required for MPT
httpx
peft
requests
ray
peft
awscli
sentence-transformers # required for embedding

# Benchmarking
aiohttp
38 changes: 30 additions & 8 deletions tests/conftest.py
@@ -133,6 +133,10 @@ def example_long_prompts() -> List[str]:
"llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration,
}

_EMBEDDING_MODELS = [
"intfloat/e5-mistral-7b-instruct",
]


class HfRunner:

@@ -145,14 +149,7 @@ def __init__(
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
        self.model_name = model_name
        if model_name not in _VISION_LANGUAGE_MODELS:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            ).cuda()
            self.processor = None
        else:
        if model_name in _VISION_LANGUAGE_MODELS:
            self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
@@ -162,6 +159,20 @@ def __init__(
                model_name,
                torch_dtype=torch_dtype,
            )
        elif model_name in _EMBEDDING_MODELS:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer(
                model_name,
                device="cpu",
            ).to(dtype=torch_dtype).cuda()
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            ).cuda()
            self.processor = None
        if tokenizer_name is None:
            tokenizer_name = model_name
        self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
@@ -334,6 +345,9 @@ def generate_greedy_logprobs_limit(
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        return self.model.encode(prompts)

    def __del__(self):
        del self.model
        cleanup()
@@ -459,6 +473,14 @@ def generate_beam_search(
        outputs = self.generate(prompts, beam_search_params)
        return outputs

    def encode(self, prompts: List[str]) -> List[List[float]]:
        req_outputs = self.model.encode(prompts)
        outputs = []
        for req_output in req_outputs:
            embedding = req_output.outputs.embedding
            outputs.append(embedding)
        return outputs

    def __del__(self):
        del self.model
        cleanup()
12 changes: 6 additions & 6 deletions tests/engine/output_processor/test_multi_step.py
@@ -9,8 +9,8 @@
from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
                           SequenceStatus)
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                           SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter

@@ -51,7 +51,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
new_token_ids = list(range(num_new_tokens))

outputs = [
SequenceGroupOutput(
CompletionSequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
@@ -103,7 +103,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
new_token_ids = list(range(num_new_tokens))

outputs = [
SequenceGroupOutput(
CompletionSequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
@@ -170,7 +170,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
new_token_ids[eos_index] = eos_token_id

outputs = [
SequenceGroupOutput(
CompletionSequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
@@ -239,7 +239,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
new_token_ids[eos_index] = eos_token_id

outputs = [
SequenceGroupOutput(
CompletionSequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
1 change: 1 addition & 0 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -14,6 +14,7 @@ class MockModelConfig:
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    embedding_mode = False


@dataclass
96 changes: 95 additions & 1 deletion tests/entrypoints/test_openai_server.py
@@ -23,6 +23,7 @@
MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 600 seconds
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
@@ -121,7 +122,7 @@ def zephyr_lora_files():
    return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def server(zephyr_lora_files):
    ray.init()
    server_runner = ServerRunner.remote([
@@ -150,6 +151,25 @@ def server(zephyr_lora_files):
    ray.shutdown()


@pytest.fixture(scope="module")
def embedding_server(zephyr_lora_files):
    ray.shutdown()
    ray.init()
    server_runner = ServerRunner.remote([
        "--model",
        EMBEDDING_MODEL_NAME,
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
    ])
    ray.get(server_runner.ready.remote())
    yield server_runner
    ray.shutdown()
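
    # Outside of CI, an equivalent server for examples/openai_embedding_client.py can be
    # started with the same flags this fixture passes to ServerRunner. A minimal sketch,
    # assuming vLLM's OpenAI-compatible entrypoint module (not part of this diff):
    #
    #     import subprocess
    #     subprocess.Popen([
    #         "python", "-m", "vllm.entrypoints.openai.api_server",
    #         "--model", EMBEDDING_MODEL_NAME,
    #         "--dtype", "bfloat16",
    #         "--max-model-len", "8192",
    #         "--enforce-eager",
    #     ])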


@pytest.fixture(scope="module")
def client():
    client = openai.AsyncOpenAI(
@@ -890,5 +910,79 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
or "less_than_equal" in exc_info.value.message)


@pytest.mark.parametrize(
    "model_name",
    [EMBEDDING_MODEL_NAME],
)
async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
                                model_name: str):
    input = [
        "The chef prepared a delicious meal.",
    ]

    # test single embedding
    embeddings = await client.embeddings.create(
        model=model_name,
        input=input,
        encoding_format="float",
    )
    assert embeddings.id is not None
    assert embeddings.data is not None and len(embeddings.data) == 1
    assert len(embeddings.data[0].embedding) == 4096
    assert embeddings.usage.completion_tokens == 0
    assert embeddings.usage.prompt_tokens == 9
    assert embeddings.usage.total_tokens == 9

    # test using token IDs
    input = [1, 1, 1, 1, 1]
    embeddings = await client.embeddings.create(
        model=model_name,
        input=input,
        encoding_format="float",
    )
    assert embeddings.id is not None
    assert embeddings.data is not None and len(embeddings.data) == 1
    assert len(embeddings.data[0].embedding) == 4096
    assert embeddings.usage.completion_tokens == 0
    assert embeddings.usage.prompt_tokens == 5
    assert embeddings.usage.total_tokens == 5


@pytest.mark.parametrize(
    "model_name",
    [EMBEDDING_MODEL_NAME],
)
async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
                               model_name: str):
    # test List[str]
    inputs = [
        "The cat sat on the mat.", "A feline was resting on a rug.",
        "Stars twinkle brightly in the night sky."
    ]
    embeddings = await client.embeddings.create(
        model=model_name,
        input=inputs,
        encoding_format="float",
    )
    assert embeddings.id is not None
    assert embeddings.data is not None and len(embeddings.data) == 3
    assert len(embeddings.data[0].embedding) == 4096

    # test List[List[int]]
    inputs = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
              [25, 32, 64, 77]]
    embeddings = await client.embeddings.create(
        model=model_name,
        input=inputs,
        encoding_format="float",
    )
    assert embeddings.id is not None
    assert embeddings.data is not None and len(embeddings.data) == 4
    assert len(embeddings.data[0].embedding) == 4096
    assert embeddings.usage.completion_tokens == 0
    assert embeddings.usage.prompt_tokens == 17
    assert embeddings.usage.total_tokens == 17


if __name__ == "__main__":
    pytest.main([__file__])
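
The usage figures asserted above follow directly from the inputs: an embedding request generates nothing, so completion_tokens stays 0 and total_tokens equals the prompt-token count. A quick check of the arithmetic, assuming the List[List[int]] batch from test_batch_embedding:

token_batches = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
                 [25, 32, 64, 77]]
assert sum(len(ids) for ids in token_batches) == 17  # matches usage.prompt_tokens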
44 changes: 44 additions & 0 deletions tests/models/test_embedding.py
@@ -0,0 +1,44 @@
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
Run `pytest tests/models/test_llama_embedding.py`.
"""
import pytest
import torch
import torch.nn.functional as F

MODELS = [
    "intfloat/e5-mistral-7b-instruct",
]


def compare_embeddings(embeddings1, embeddings2):
    similarities = [
        F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0)
        for e1, e2 in zip(embeddings1, embeddings2)
    ]
    return similarities


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.encode(example_prompts)
    del hf_model

    vllm_model = vllm_runner(model, dtype=dtype)
    vllm_outputs = vllm_model.encode(example_prompts)
    del vllm_model

    similarities = compare_embeddings(hf_outputs, vllm_outputs)
    all_similarities = torch.stack(similarities)
    tolerance = 1e-2
    assert torch.all((all_similarities <= 1.0 + tolerance)
                     & (all_similarities >= 1.0 - tolerance)
                     ), f"Not all values are within {tolerance} of 1.0"
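
The final assertion could also be written with torch.allclose; a roughly equivalent sketch, assuming `all_similarities` from the test above (allclose adds a small default relative tolerance on top of atol):

assert torch.allclose(all_similarities,
                      torch.ones_like(all_similarities),
                      atol=1e-2), "HF and vLLM embeddings diverge"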
6 changes: 3 additions & 3 deletions tests/samplers/test_logits_processor.py
@@ -36,14 +36,14 @@ def pick_vllm(token_ids, logits):
# test logits_processors when prompt_logprobs is not None
vllm_model.model._add_request(
prompt=example_prompts[0],
sampling_params=params_with_logprobs,
params=params_with_logprobs,
prompt_token_ids=None,
)

# test prompt_logprobs is not None
vllm_model.model._add_request(
prompt=example_prompts[1],
sampling_params=SamplingParams(
params=SamplingParams(
prompt_logprobs=3,
max_tokens=max_tokens,
),
@@ -53,7 +53,7 @@ def pick_vllm(token_ids, logits):
# test grouped requests
vllm_model.model._add_request(
prompt=example_prompts[2],
sampling_params=SamplingParams(max_tokens=max_tokens),
params=SamplingParams(max_tokens=max_tokens),
prompt_token_ids=None,
)

2 changes: 1 addition & 1 deletion tests/samplers/test_seeded_generate.py
@@ -60,7 +60,7 @@ def test_random_sample_with_seed(
llm._add_request(
prompt=prompt,
prompt_token_ids=None,
sampling_params=params,
params=params,
)

results = llm._run_engine(use_tqdm=False)