|
MAX_SERVER_START_WAIT_S = 600  # wait up to 600 s for the server to start

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# embedding model served by the `embedding_server` fixture
EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
|
@@ -121,7 +122,7 @@ def zephyr_lora_files():
|
121 | 122 | return snapshot_download(repo_id=LORA_NAME)
|
122 | 123 |
|
123 | 124 |
|
124 |
| -@pytest.fixture(scope="session") |
| 125 | +@pytest.fixture(scope="module") |
125 | 126 | def server(zephyr_lora_files):
|
126 | 127 | ray.init()
|
127 | 128 | server_runner = ServerRunner.remote([
|
@@ -150,6 +151,25 @@ def server(zephyr_lora_files):
|
150 | 151 | ray.shutdown()
|
151 | 152 |
|
152 | 153 |
|
@pytest.fixture(scope="module")
def embedding_server(zephyr_lora_files):
    """Start a module-scoped vLLM OpenAI-compatible server running the
    embedding model (EMBEDDING_MODEL_NAME) on a fresh Ray cluster.

    Yields the ServerRunner actor handle; shuts Ray down on teardown.
    """
    # NOTE(review): another module-scoped fixture may have left Ray running;
    # shut it down first so ray.init() below starts from a clean state.
    ray.shutdown()
    ray.init()
    # NOTE(review): `zephyr_lora_files` is not referenced in this body —
    # presumably requested only to force the LoRA download before server
    # startup; confirm and drop if unnecessary.
    server_runner = ServerRunner.remote([
        "--model",
        EMBEDDING_MODEL_NAME,
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
    ])
    # block until the server reports ready (bounded by ServerRunner's wait)
    ray.get(server_runner.ready.remote())
    yield server_runner
    ray.shutdown()
| 172 | + |
153 | 173 | @pytest.fixture(scope="module")
|
154 | 174 | def client():
|
155 | 175 | client = openai.AsyncOpenAI(
|
@@ -890,5 +910,79 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
|
890 | 910 | or "less_than_equal" in exc_info.value.message)
|
891 | 911 |
|
892 | 912 |
|
@pytest.mark.parametrize(
    "model_name",
    [EMBEDDING_MODEL_NAME],
)
async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
                                model_name: str):
    """Single-input embeddings via the OpenAI-compatible endpoint.

    Covers a one-string prompt and a single token-ID list, checking the
    embedding dimension (4096 for e5-mistral-7b-instruct) and token usage
    accounting.
    """
    # test single embedding from a text prompt
    # (renamed from `input` to avoid shadowing the `input` builtin)
    prompts = [
        "The chef prepared a delicious meal.",
    ]
    embeddings = await client.embeddings.create(
        model=model_name,
        input=prompts,
        encoding_format="float",
    )
    assert embeddings.id is not None
    assert embeddings.data is not None and len(embeddings.data) == 1
    assert len(embeddings.data[0].embedding) == 4096
    assert embeddings.usage.completion_tokens == 0
    # 9 tokens for this prompt under the model's tokenizer
    assert embeddings.usage.prompt_tokens == 9
    assert embeddings.usage.total_tokens == 9

    # test using token IDs
    token_ids = [1, 1, 1, 1, 1]
    embeddings = await client.embeddings.create(
        model=model_name,
        input=token_ids,
        encoding_format="float",
    )
    assert embeddings.id is not None
    assert embeddings.data is not None and len(embeddings.data) == 1
    assert len(embeddings.data[0].embedding) == 4096
    assert embeddings.usage.completion_tokens == 0
    # prompt_tokens equals the number of IDs supplied
    assert embeddings.usage.prompt_tokens == 5
    assert embeddings.usage.total_tokens == 5
| 950 | + |
@pytest.mark.parametrize(
    "model_name",
    [EMBEDDING_MODEL_NAME],
)
async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
                               model_name: str):
    """Batched embeddings via the OpenAI-compatible endpoint.

    Sends a List[str] batch and a List[List[int]] token-ID batch and checks
    batch sizes, embedding dimension (4096), and usage accounting.
    """
    # test List[str]
    text_batch = [
        "The cat sat on the mat.", "A feline was resting on a rug.",
        "Stars twinkle brightly in the night sky."
    ]
    response = await client.embeddings.create(
        model=model_name,
        input=text_batch,
        encoding_format="float",
    )
    assert response.id is not None
    assert response.data is not None and len(response.data) == 3
    assert len(response.data[0].embedding) == 4096

    # test List[List[int]]
    token_batch = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
                   [25, 32, 64, 77]]
    response = await client.embeddings.create(
        model=model_name,
        input=token_batch,
        encoding_format="float",
    )
    assert response.id is not None
    assert response.data is not None and len(response.data) == 4
    assert len(response.data[0].embedding) == 4096
    assert response.usage.completion_tokens == 0
    # 17 = total number of token IDs across the four sequences
    assert response.usage.prompt_tokens == 17
    assert response.usage.total_tokens == 17
| 986 | + |
if __name__ == "__main__":
    # allow running this test module directly (outside a pytest invocation)
    pytest.main([__file__])
|
0 commit comments