
Commit e254497

[Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734)
1 parent 4e12131 commit e254497

38 files changed: +1627 −160 lines changed
+17

@@ -0,0 +1,17 @@
+from vllm import LLM
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+# Create an LLM.
+model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
+# Generate embedding. The output is a list of EmbeddingRequestOutputs.
+outputs = model.encode(prompts)
+# Print the outputs.
+for output in outputs:
+    print(output.outputs.embedding)  # list of 4096 floats
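
As a rough illustration (not part of this commit), the 4096-dimensional vectors returned by LLM.encode can be compared with ordinary cosine similarity. The sketch below assumes the `prompts` and `outputs` variables from the example above and uses torch, mirroring how the new tests/models/test_embedding.py compares embeddings:

import torch
import torch.nn.functional as F

# Illustrative only: `prompts` and `outputs` come from the example above.
embeddings = torch.tensor([o.outputs.embedding for o in outputs])  # shape (4, 4096)

# Score every prompt against the first one by cosine similarity.
query = embeddings[0].unsqueeze(0)
scores = F.cosine_similarity(query, embeddings, dim=1)
for prompt, score in zip(prompts, scores.tolist()):
    print(f"{score:.3f}  {prompt}")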

examples/openai_embedding_client.py

+23
@@ -0,0 +1,23 @@
+from openai import OpenAI
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+
+client = OpenAI(
+    # defaults to os.environ.get("OPENAI_API_KEY")
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+models = client.models.list()
+model = models.data[0].id
+
+responses = client.embeddings.create(input=[
+    "Hello my name is",
+    "The best thing about vLLM is that it supports many different models"
+],
+                                     model=model)
+
+for data in responses.data:
+    print(data.embedding)  # list of float of len 4096
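
The client above assumes the vLLM OpenAI-compatible server is already running at http://localhost:8000 with the embedding model loaded (the test fixtures below start it with flags such as --model intfloat/e5-mistral-7b-instruct --dtype bfloat16 --enforce-eager). For illustration only, and assuming the server exposes the standard OpenAI-style /v1/embeddings route, the same call could be made as a raw HTTP request:

import requests

# Hypothetical raw-HTTP equivalent of the OpenAI client call above.
response = requests.post(
    "http://localhost:8000/v1/embeddings",  # assumed OpenAI-compatible route
    json={
        "model": "intfloat/e5-mistral-7b-instruct",
        "input": ["Hello my name is"],
        "encoding_format": "float",
    },
)
response.raise_for_status()
for item in response.json()["data"]:
    print(len(item["embedding"]))  # expected length: 4096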

requirements-dev.txt

+6-3
@@ -19,12 +19,15 @@ pytest-forked
 pytest-asyncio
 pytest-rerunfailures
 pytest-shard
-httpx
+
+# testing utils
+awscli
 einops # required for MPT
+httpx
+peft
 requests
 ray
-peft
-awscli
+sentence-transformers # required for embedding
 
 # Benchmarking
 aiohttp

tests/conftest.py

+30-8
@@ -133,6 +133,10 @@ def example_long_prompts() -> List[str]:
     "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration,
 }
 
+_EMBEDDING_MODELS = [
+    "intfloat/e5-mistral-7b-instruct",
+]
+
 
 class HfRunner:
 
@@ -145,14 +149,7 @@ def __init__(
         assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
         self.model_name = model_name
-        if model_name not in _VISION_LANGUAGE_MODELS:
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                torch_dtype=torch_dtype,
-                trust_remote_code=True,
-            ).cuda()
-            self.processor = None
-        else:
+        if model_name in _VISION_LANGUAGE_MODELS:
             self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained(
                 model_name,
                 torch_dtype=torch_dtype,
@@ -162,6 +159,20 @@ def __init__(
                 model_name,
                 torch_dtype=torch_dtype,
             )
+        elif model_name in _EMBEDDING_MODELS:
+            # Lazy init required for AMD CI
+            from sentence_transformers import SentenceTransformer
+            self.model = SentenceTransformer(
+                model_name,
+                device="cpu",
+            ).to(dtype=torch_dtype).cuda()
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                trust_remote_code=True,
+            ).cuda()
+            self.processor = None
         if tokenizer_name is None:
             tokenizer_name = model_name
         self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
@@ -334,6 +345,9 @@ def generate_greedy_logprobs_limit(
         return [(output_ids, output_str, output_logprobs)
                 for output_ids, output_str, output_logprobs in outputs]
 
+    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
+        return self.model.encode(prompts)
+
     def __del__(self):
         del self.model
         cleanup()
@@ -459,6 +473,14 @@ def generate_beam_search(
         outputs = self.generate(prompts, beam_search_params)
         return outputs
 
+    def encode(self, prompts: List[str]) -> List[List[float]]:
+        req_outputs = self.model.encode(prompts)
+        outputs = []
+        for req_output in req_outputs:
+            embedding = req_output.outputs.embedding
+            outputs.append(embedding)
+        return outputs
+
     def __del__(self):
         del self.model
         cleanup()

tests/engine/output_processor/test_multi_step.py

+6-6
@@ -9,8 +9,8 @@
 from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
-                           SequenceStatus)
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
+                           SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.utils import Counter
 
@@ -51,7 +51,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
     new_token_ids = list(range(num_new_tokens))
 
     outputs = [
-        SequenceGroupOutput(
+        CompletionSequenceGroupOutput(
             samples=[
                 SequenceOutput(
                     parent_seq_id=seq.seq_id,
@@ -103,7 +103,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
     new_token_ids = list(range(num_new_tokens))
 
     outputs = [
-        SequenceGroupOutput(
+        CompletionSequenceGroupOutput(
             samples=[
                 SequenceOutput(
                     parent_seq_id=seq.seq_id,
@@ -170,7 +170,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
     new_token_ids[eos_index] = eos_token_id
 
     outputs = [
-        SequenceGroupOutput(
+        CompletionSequenceGroupOutput(
             samples=[
                 SequenceOutput(
                     parent_seq_id=seq.seq_id,
@@ -239,7 +239,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
     new_token_ids[eos_index] = eos_token_id
 
     outputs = [
-        SequenceGroupOutput(
+        CompletionSequenceGroupOutput(
             samples=[
                 SequenceOutput(
                     parent_seq_id=seq.seq_id,

tests/entrypoints/openai/test_serving_chat.py

+1
@@ -14,6 +14,7 @@ class MockModelConfig:
     tokenizer_mode = "auto"
     max_model_len = 100
     tokenizer_revision = None
+    embedding_mode = False
 
 
 @dataclass

tests/entrypoints/test_openai_server.py

+95-1
@@ -23,6 +23,7 @@
 MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 60 seconds
 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
 # technically this needs Mistral-7B-v0.1 as base, but we're not testing
 # generation quality here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"
@@ -121,7 +122,7 @@ def zephyr_lora_files():
     return snapshot_download(repo_id=LORA_NAME)
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture(scope="module")
 def server(zephyr_lora_files):
     ray.init()
     server_runner = ServerRunner.remote([
@@ -150,6 +151,25 @@ def server(zephyr_lora_files):
     ray.shutdown()
 
 
+@pytest.fixture(scope="module")
+def embedding_server(zephyr_lora_files):
+    ray.shutdown()
+    ray.init()
+    server_runner = ServerRunner.remote([
+        "--model",
+        EMBEDDING_MODEL_NAME,
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+    ])
+    ray.get(server_runner.ready.remote())
+    yield server_runner
+    ray.shutdown()
+
+
 @pytest.fixture(scope="module")
 def client():
     client = openai.AsyncOpenAI(
@@ -890,5 +910,79 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
                 or "less_than_equal" in exc_info.value.message)
 
 
+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
+                                model_name: str):
+    input = [
+        "The chef prepared a delicious meal.",
+    ]
+
+    # test single embedding
+    embeddings = await client.embeddings.create(
+        model=model_name,
+        input=input,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert embeddings.data is not None and len(embeddings.data) == 1
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 9
+    assert embeddings.usage.total_tokens == 9
+
+    # test using token IDs
+    input = [1, 1, 1, 1, 1]
+    embeddings = await client.embeddings.create(
+        model=model_name,
+        input=input,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert embeddings.data is not None and len(embeddings.data) == 1
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 5
+    assert embeddings.usage.total_tokens == 5
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
+                               model_name: str):
+    # test List[str]
+    inputs = [
+        "The cat sat on the mat.", "A feline was resting on a rug.",
+        "Stars twinkle brightly in the night sky."
+    ]
+    embeddings = await client.embeddings.create(
+        model=model_name,
+        input=inputs,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert embeddings.data is not None and len(embeddings.data) == 3
+    assert len(embeddings.data[0].embedding) == 4096
+
+    # test List[List[int]]
+    inputs = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
+              [25, 32, 64, 77]]
+    embeddings = await client.embeddings.create(
+        model=model_name,
+        input=inputs,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert embeddings.data is not None and len(embeddings.data) == 4
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 17
+    assert embeddings.usage.total_tokens == 17
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

tests/models/test_embedding.py

+44
@@ -0,0 +1,44 @@
+"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
+
+Run `pytest tests/models/test_llama_embedding.py`.
+"""
+import pytest
+import torch
+import torch.nn.functional as F
+
+MODELS = [
+    "intfloat/e5-mistral-7b-instruct",
+]
+
+
+def compare_embeddings(embeddings1, embeddings2):
+    similarities = [
+        F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0)
+        for e1, e2 in zip(embeddings1, embeddings2)
+    ]
+    return similarities
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    hf_model = hf_runner(model, dtype=dtype)
+    hf_outputs = hf_model.encode(example_prompts)
+    del hf_model
+
+    vllm_model = vllm_runner(model, dtype=dtype)
+    vllm_outputs = vllm_model.encode(example_prompts)
+    del vllm_model
+
+    similarities = compare_embeddings(hf_outputs, vllm_outputs)
+    all_similarities = torch.stack(similarities)
+    tolerance = 1e-2
+    assert torch.all((all_similarities <= 1.0 + tolerance)
+                     & (all_similarities >= 1.0 - tolerance)
+                     ), f"Not all values are within {tolerance} of 1.0"

tests/samplers/test_logits_processor.py

+3-3
@@ -36,14 +36,14 @@ def pick_vllm(token_ids, logits):
     # test logits_processors when prompt_logprobs is not None
     vllm_model.model._add_request(
         prompt=example_prompts[0],
-        sampling_params=params_with_logprobs,
+        params=params_with_logprobs,
         prompt_token_ids=None,
     )
 
     # test prompt_logprobs is not None
     vllm_model.model._add_request(
         prompt=example_prompts[1],
-        sampling_params=SamplingParams(
+        params=SamplingParams(
             prompt_logprobs=3,
             max_tokens=max_tokens,
         ),
@@ -53,7 +53,7 @@ def pick_vllm(token_ids, logits):
     # test grouped requests
     vllm_model.model._add_request(
         prompt=example_prompts[2],
-        sampling_params=SamplingParams(max_tokens=max_tokens),
+        params=SamplingParams(max_tokens=max_tokens),
         prompt_token_ids=None,
     )
 
tests/samplers/test_seeded_generate.py

+1-1
@@ -60,7 +60,7 @@ def test_random_sample_with_seed(
         llm._add_request(
             prompt=prompt,
             prompt_token_ids=None,
-            sampling_params=params,
+            params=params,
         )
 
     results = llm._run_engine(use_tqdm=False)
