IMPORTANT Bug: Model returns empty responses (output len = 0) when receiving multiple concurrent requests. #3209

Closed
tattrongvu opened this issue Mar 5, 2024 · 21 comments

Comments


tattrongvu commented Mar 5, 2024

When I ran a bunch of load tests against the vLLM endpoint with the OpenAI API, I saw that the server returns 20% to 50% empty responses when it receives multiple concurrent requests.
Configuration:

vLLM==0.3.0
Model: Zephyr-7b-beta, although it is even worse with bigger models like Llama-2-70b and Mixtral 8x7b
Cuda 12.2
1 x A100 80GB GPU

Number of concurrent requests: 100, with the request rate left at the default "inf", meaning all 100 requests are sent concurrently.

How to replicate:
I'm using the latest benchmark_serving.py (as of 5 Mar 2024) with the following modifications:
I commented out L69-L72: https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_serving.py#L69
and changed the following:

  • Line 88 to:
if prompt_len < MIN_PROMPT_LEN or output_len < MIN_OUTPUT_LEN:
  • Line 93 to:
if prompt_len > MAX_PROMPT_LEN or output_len > MAX_OUTPUT_LEN:

and set the params like this:

MIN_PROMPT_LEN = 400
MAX_PROMPT_LEN = 700
MIN_OUTPUT_LEN = 100
MAX_OUTPUT_LEN = 300
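
Putting these modifications together, the filter ends up looking roughly like this (a sketch; the surrounding loop and variable names such as tokenized_dataset are paraphrased from my reading of the upstream script and may differ slightly):

# Sketch of the modified filtering in sample_requests(); the loop and variable
# names are approximate, only the two conditions and the constants are my changes.
MIN_PROMPT_LEN = 400
MAX_PROMPT_LEN = 700
MIN_OUTPUT_LEN = 100
MAX_OUTPUT_LEN = 300

filtered_dataset = []
for prompt, prompt_len, output_len in tokenized_dataset:
    if prompt_len < MIN_PROMPT_LEN or output_len < MIN_OUTPUT_LEN:
        continue  # prune sequences that are too short for this workload
    if prompt_len > MAX_PROMPT_LEN or output_len > MAX_OUTPUT_LEN:
        continue  # prune sequences that are too long for this workload
    filtered_dataset.append((prompt, prompt_len, output_len))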

Moreover, since the current code at https://github.com/vllm-project/vllm/blob/main/benchmarks/backend_request_func.py#L265
doesn't check the output length, I needed to add two more lines of code:

if generated_text == "":
    output.success = False
else:
    output.success = True

With the above configuration, the success rate is 76/100; 24 requests returned an empty response.

NOTE:
I tried replacing aiohttp with httpx and with the openai package as the HTTP client. The problem persists.
Moreover, the empty responses happen even with a small number of concurrent requests. If we send 5 or 10 requests concurrently, we get 2 or 3 empty responses.

I tested with Mixtral 8x7b and Llama-2-70b unquantized, and the number of empty responses is even worse, up to 50%.

My suggestion is to put a rate limiter (maybe just a decorator or an await rate_limit() function) into api_server.py to limit the request rate the server will handle to some predefined number, e.g. 10 requests / 5 seconds. I could provide an MR later if needed.
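
To make the idea concrete, here is a rough sketch of what I mean by an await rate_limit() call (hypothetical names, plain asyncio, not vLLM code):

import asyncio
import time

class RateLimiter:
    """Sliding-window limiter: at most max_requests per per_seconds."""

    def __init__(self, max_requests: int, per_seconds: float):
        self.max_requests = max_requests
        self.per_seconds = per_seconds
        self._timestamps: list[float] = []
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        while True:
            async with self._lock:
                now = time.monotonic()
                # Keep only timestamps inside the current window.
                self._timestamps = [t for t in self._timestamps
                                    if now - t < self.per_seconds]
                if len(self._timestamps) < self.max_requests:
                    self._timestamps.append(now)
                    return
                wait = self.per_seconds - (now - self._timestamps[0])
            await asyncio.sleep(wait)

rate_limit = RateLimiter(max_requests=10, per_seconds=5.0)

# and at the top of the completion endpoint handler:
#     await rate_limit.acquire()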

Please take a look and let me know if something is wrong here.
Thanks in advance.


tattrongvu commented Mar 5, 2024

  • Zephyr-7b with 1 x A100 GPU (80 GB): success rate 76% (result screenshot attached)

  • Llama-2-70b-instruct with 4 x A100 GPUs (80 GB each, 320 GB VRAM total): success rate 65% (result screenshot attached)

  • Mixtral 8x7b instruct with 2 x A100 GPUs (80 GB each, 160 GB VRAM total): success rate 72% (result screenshot attached)


ywang96 commented Mar 5, 2024

I don't think the model is returning empty responses; rather, these requests were just failing due to some error.

Each output.generated_text starts as an empty string, and if an error happens or a non-200 response code is received, it stays an empty string with output.success=False.
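
In code, the behaviour I'm describing is roughly this (a simplified sketch; only the generated_text and success fields come from the benchmark's output object, the rest is illustrative):

from dataclasses import dataclass

@dataclass
class RequestFuncOutput:
    generated_text: str = ""   # stays "" if the request errors out
    success: bool = False      # set True only after a clean 200 response

def handle_response(status: int, body_text: str) -> RequestFuncOutput:
    output = RequestFuncOutput()
    if status == 200:
        output.generated_text = body_text
        output.success = True
    # on a non-200 status (or an exception), generated_text remains ""
    return output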

Regarding the rate limiter: I think that's a good point, and it's also mentioned in #3127. I'm curious, is there any reason why you're not setting --request-rate?


tattrongvu commented Mar 5, 2024

Hi, thanks for your quick response.
As said above, I modified https://github.com/vllm-project/vllm/blob/main/benchmarks/backend_request_func.py#L265 to add:

if generated_text == "":
    output.success = False
else:
    output.success = True

And it is INSIDE the if statement:

if response.status == 200: 

(https://github.com/vllm-project/vllm/blob/main/benchmarks/backend_request_func.py#L248)

It means that even when response.status == 200, the model outputs empty text.

I'm setting the request rate to "inf" in benchmark_serving.py to simulate concurrent users.

But what I suggest is putting the rate limiter directly into the API endpoint code in openai/api_server.py,
NOT into benchmark_serving.py, in the hope that it will fix the problem by reducing the rate the endpoint has to handle in parallel.


ywang96 commented Mar 5, 2024

Moreover, since the current code at https://github.com/vllm-project/vllm/blob/main/benchmarks/backend_request_func.py#L265
doesn't check the output length, I needed to add two more lines of code:

We actually do when we calculate the metrics.

output_len = len(tokenizer.encode(outputs[i].generated_text))


tattrongvu commented Mar 5, 2024

Moreover, since the current code at https://github.com/vllm-project/vllm/blob/main/benchmarks/backend_request_func.py#L265
doesn't check the output length, I needed to add two more lines of code:

We actually do when we calculate the metrics

Yes, I also saw that, but the problem is that when calculating the metrics, adding a zero to the output lens just reduces the final average throughput (tokens/s); it doesn't show any warning or error.
I think it makes more sense to explicitly show an error or warning when we receive a response with empty text.
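
For example, something like this during the metrics calculation (a sketch, not the actual code, reusing the output_len computation quoted above) would surface it:

def warn_on_empty_outputs(outputs, tokenizer) -> int:
    """Count and report successful requests whose completion came back empty."""
    empty = 0
    for i, output in enumerate(outputs):
        output_len = len(tokenizer.encode(output.generated_text))
        if output.success and output_len == 0:
            print(f"WARNING: request {i} returned HTTP 200 but an empty completion.")
            empty += 1
    return empty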


ywang96 commented Mar 5, 2024

adding a zero to the output lens just reduces the final average throughput (tokens/s); it doesn't show any warning or error.

That's true, but there are two separate issues we should be tracking here:

  • whether the model is actually generating empty responses (in this case output.success = True and the 0 output len will be added; we can add a warning if an empty output is generated), and
  • whether the request fails (in this case output.success = False and it'll be disregarded when we do the metric calculation).

Error logging makes sense to me and I can add it to my current PR in progress #3194


tattrongvu commented Mar 5, 2024

Thanks, it would be great if you added what I identified above to your PR for benchmark_serving.py.

But the main point here is to figure out why the model returns empty responses and fix it.

As far as I know, we already have a request queue handled by the Async LLM Engine here: https://github.com/vllm-project/vllm/blob/main/vllm/engine/async_llm_engine.py#L130
So I'm not sure which parameters we should tune to fix this problem.

Or I could just put another rate-limit queue decorator explicitly above the POST API at https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py#L178


ywang96 commented Mar 5, 2024

It means that even when response.status == 200, the model outputs empty text.

That's correct, but for an empty output, how do you currently differentiate between a connection failure and the model actually generating nothing?

@tattrongvu

how do you currently differentiate between a connection failure and the model actually generating nothing?

So could a connection failure happen even when response.status == 200? I thought the response status is what determines that, no?


ywang96 commented Mar 5, 2024

What I'm saying is that the aiohttp.ClientSession that's initialized to send the POST request gets a disconnect error during the streaming phase. AFAIK, by default there's a limit on how many aiohttp.ClientSessions you can initialize from a single IP address; going beyond that will give you an error from aiohttp.

except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
    output.success = False
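
For reference, the knob I know of for this is the connection cap on aiohttp's TCPConnector (it defaults to 100 concurrent connections per session); a minimal sketch of lifting it, with a placeholder URL:

import asyncio
import aiohttp

async def main():
    # limit=0 removes the per-session connection cap (the default is 100).
    connector = aiohttp.TCPConnector(limit=0)
    async with aiohttp.ClientSession(connector=connector) as session:
        async with session.get("http://localhost:8000/health") as resp:
            print(resp.status)

asyncio.run(main())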


tattrongvu commented Mar 5, 2024

how many aiohttp.ClientSession you can initialize from a single IP address, going beyond that will give you an error from aiohttp

Good point, let's check how we can explicitly set this number.

But I already replaced the aiohttp client currently used in benchmark_serving.py: I tried the OpenAI Python package and HTTPX (which is actually what the openai package uses under the hood), and the empty responses persist across different client packages.

And it happens with a small number of concurrent requests as well: even with 5 or 10 concurrent requests, we get 2 or 3 empty responses.
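
For completeness, the cross-check with the openai package looked roughly like this (a sketch; the base URL, model name, and prompt are placeholders for my deployment):

import asyncio
from openai import AsyncOpenAI

async def main():
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    # Fire 10 completion requests concurrently and count empty outputs.
    tasks = [
        client.completions.create(
            model="HuggingFaceH4/zephyr-7b-beta",
            prompt="Given the following context: ... What is this document about?",
            max_tokens=150,
        )
        for _ in range(10)
    ]
    results = await asyncio.gather(*tasks)
    empty = sum(1 for r in results if not r.choices[0].text.strip())
    print(f"{empty}/10 responses were empty")

asyncio.run(main())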


akshatjagga commented Mar 5, 2024

Consistently facing the same issue, even without concurrent requests


ywang96 commented Mar 5, 2024

But I already replaced the aiohttp client currently used in benchmark_serving.py: I tried the OpenAI Python package and HTTPX (which is actually what the openai package uses under the hood), and the empty responses persist across different client packages.

That's more alarming indeed. I'll incorporate some changes in my PR to try to rule out the possibility that these errors are caused by the benchmark script itself rather than the actual model server.


ywang96 commented Mar 6, 2024

I've added changes in #3194 to capture request errors and the actual output length; you can examine them in the result JSON afterwards by specifying the --save-result flag. I am able to run at least 400 requests with inf traffic on a Mixtral deployment on 2 A100-80G GPUs, and all the responses look fine.

Command:

python3 vllm/benchmarks/benchmark_serving.py \
               --backend openai \
               --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
               --num-prompts 400 \
               --dataset-name sonnet \
               --base-url http://mixtral-vllm.local \
               --endpoint /v1/completions \
               --dataset-path vllm/benchmarks/sonnet.txt \
               --save-result

Results:

Traffic request rate: inf
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:52<00:00,  7.63it/s]
Successful requests: 400
Benchmark duration: 52.406242 s
Total input tokens: 207627
Total generated tokens: 54717
Request throughput: 7.63 requests/s
Input token throughput: 3961.88 tokens/s
Output token throughput: 1044.09 tokens/s
Mean TTFT: 18562.55 ms
Median TTFT: 13210.84 ms
P99 TTFT: 42933.43 ms
Mean TPOT: 168.24 ms
Median TPOT: 162.74 ms
P99 TPOT: 431.18 ms


tattrongvu commented Mar 6, 2024

@ywang96 I see that you used another dataset named sonnet instead of ShareGPT, and the requests per second are 7.63 instead of 1.72 as in my result. If possible, could you run on ShareGPT with my configuration of about 500 input tokens and about 150 max output tokens per request?
I just checked the sonnet.txt in your MR; it is a very small dataset indeed.

What are your mean prompt input and max output lengths?
In my experiments, I set the prompt input to an average of 500 tokens and the max output to a mean of about 160 tokens to simulate a realistic RAG use case.


ywang96 commented Mar 6, 2024

@tattrongvu 550 and 150 (you can also tell by dividing the total input and output tokens by the number of requests in the benchmark result printout).

If you take a look at #3194 and use that version of benchmark_serving.py and backend_request_func.py, you should be able to run this same command with minor changes to the URL too. (I've also uploaded sonnet.txt.)

@tattrongvu

@ywang96 ok, let me run your modified benchmark on sonnet and then again on ShareGPT.

@tattrongvu

@ywang96 So, I just copied the sample_sonnet_requests function from your repo at https://github.com/ywang96/vllm/blob/add-prefix/benchmarks/benchmark_serving.py#L103, put it into my modified benchmark as above, and used it like this:

input_requests = sample_sonnet_requests(args.dataset, args.num_prompts, 500, 200, 15, tokenizer)

And the result (screenshot attached):

So from the result, the good news is that it completed all 400 requests without empty responses.
But the empty responses are still there if I run with the ShareGPT dataset.
That could indicate the problem lies in the dataset: maybe the prompt input is not clear, or it is already a completed sentence, etc.

Could you use ShareGPT on your setup to confirm?

Otherwise, I see that in my setup I get 2.9 req/s instead of 7.2; maybe I should update to vLLM 0.3.3 :)

@hahmad2008

@tattrongvu @ywang96 guys any help?
#3230

@tattrongvu

I identified the problem.
The prompts sampled randomly from the ShareGPT dataset contain many "completed sentences"; when these are sent to the models' completion API, the model doesn't know how to continue the text. If I add a prefix like:

template = """
Given following context:
{}
Base on the given context, answer following question:
What is this document about?
"""

and then use the above template to format the sampled prompt from the ShareGPT dataset like this:

prompt = template.format(request_func_input.prompt)

The models will then try to summarize the document and hence no longer output empty responses.

@hahmad2008

@tattrongvu but I already pass my prompt through a prompt template before feeding it to my model.
