
[Bug]: sending request using response_format json twice breaks vLLM #4070

Open
samos123 opened this issue Apr 14, 2024 · 13 comments
Labels
bug Something isn't working

Comments

@samos123
Contributor

Your current environment

Collecting environment information...
PyTorch version: 2.1.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.29.0
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.133+-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA L4
Nvidia driver version: 535.129.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             8
On-line CPU(s) list:                0-7
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family:                         6
Model:                              85
Thread(s) per core:                 2
Core(s) per socket:                 4
Socket(s):                          1
Stepping:                           7
BogoMIPS:                           4400.44
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat avx512_vnni md_clear arch_capabilities
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          128 KiB (4 instances)
L1i cache:                          128 KiB (4 instances)
L2 cache:                           4 MiB (4 instances)
L3 cache:                           38.5 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-7
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:             Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; Clear CPU buffers; SMT Host state unknown

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.1.2
[pip3] triton==2.1.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.4.0.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
	GPU0	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	0-7	0		N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

vLLM gets into a corrupted state and responds only with garbage after receiving a specific request with response_format = json. For the first request, vLLM returns a somewhat reasonable response, but once the same request is repeated, it responds only with \n\t\t\t\t..., where \t repeats until max_tokens is reached.

Steps to reproduce:

  1. Deploy vLLM v0.4.0.post1 with the OpenAI-compatible API endpoint and Mistral 7B Instruct v0.2 from Hugging Face (model ID: mistralai/Mistral-7B-Instruct-v0.2), using the following config:
gpuMemoryUtilization: "0.90"
maxModelLen: 16752

This is running on a single L4 GPU

  2. Send the following request twice:
curl http://localhost:8080/v1/completions  -H "Content-Type: application/json" -d @request-body-sanitized.json
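For anyone reproducing from Python instead of the shell, here is a minimal sketch using only the standard library. It assumes the config keys above correspond to vLLM's `--gpu-memory-utilization` and `--max-model-len` server flags and that the server listens on port 8080, as in the curl command. `launch_command`, `post_json` and `reproduce` are hypothetical helpers, and the HTTP sender is injected so the send-twice loop can be exercised without a live server.

```python
import json
import sys
import urllib.request
from typing import Callable


def launch_command(model: str, gpu_memory_utilization: float,
                   max_model_len: int, port: int = 8080) -> list[str]:
    """Hypothetical mapping of the config above onto vLLM's OpenAI server flags."""
    return [
        sys.executable, "-m", "vllm.entrypoints.openai.api_server",
        "--model", model,
        "--gpu-memory-utilization", str(gpu_memory_utilization),
        "--max-model-len", str(max_model_len),
        "--port", str(port),
    ]


def post_json(url: str, body: dict) -> dict:
    """POST a JSON body with the stdlib and decode the JSON response."""
    request = urllib.request.Request(
        url,
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        return json.load(response)


def reproduce(send: Callable[[dict], dict], body: dict, times: int = 2) -> list[str]:
    """Send the identical completion request `times` times; return each choice's text."""
    return [send(body)["choices"][0]["text"] for _ in range(times)]
```

Against a running server, `reproduce(lambda b: post_json("http://localhost:8080/v1/completions", b), body)`, with `body` loaded from request-body-sanitized.json, would return both completion texts for side-by-side comparison.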

Current results:

  1. The first request returns a result like this:
{"id":"cmpl-e7eef8ba397b4dab8c396a0611773178","object":"text_completion","created":1713118051,"model":"mistral-7b-instruct-v0.2","choices":[{"index":0,"text":"\n{\n   \t\"isvalid\"\t:\ttrue,\n   \t\"summary\"\t:\t\"A large language model (LLM) is a type of AI model that can generate and process natural language. It learns from text documents and can be used for tasks like text generation and classification. LLMs can predict the next word or token in a text input.\"\n}\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n

(\n repeats until max_tokens is hit.)

  2. The second request returns a response that contains only \n\t\t\t\t..., where \t repeats until max_tokens is hit.
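When comparing responses like these, it can help to flag the degenerate ones automatically. The helper below is a hypothetical sketch, not part of vLLM: it treats a completion as degenerate if it is all whitespace or contains a whitespace run of at least `max_run` characters. The threshold of 32 is an arbitrary assumption. Note that the `\n` and `\t` escapes in the raw JSON response become real newline and tab characters once the response is parsed.

```python
import re


def longest_whitespace_run(text: str) -> int:
    """Length of the longest run of consecutive whitespace characters."""
    return max((len(run) for run in re.findall(r"\s+", text)), default=0)


def looks_degenerate(text: str, max_run: int = 32) -> bool:
    """True if the text is all whitespace or contains a whitespace run >= max_run."""
    if not text.strip():
        return True
    return longest_whitespace_run(text) >= max_run
```

Under this rule both responses above would be flagged: the second is whitespace only, and the first pads its JSON with newlines until max_tokens.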

Occasionally, vLLM also gets into a bad state where all requests return errors, but I can't reach that state consistently. The following errors were seen when it happens:

lark.exceptions.UnexpectedToken: Unexpected token Token('LBRACE', '{') at line 11, column 2.
Expected one of:
        * RBRACE
        * UNESCAPED_STRING
Previous tokens: [Token('LBRACE', '{')]
lark.exceptions.UnexpectedCharacters: No terminal matches '{' in the current parser context, at line 11 col 2

{{{{{{{{{{

This issue was originally reported in Lingo (substratusai/kubeai#96), but it seems to be an issue with vLLM itself.

Expected results

vLLM should not get into a broken state in which subsequent responses contain no usable results because a request used response_format = json.

@samos123 samos123 added the bug Something isn't working label Apr 14, 2024
@samos123
Contributor Author

request-body-sanitized.json
Attaching the request body that was used to reproduce

@samos123 samos123 changed the title [Bug]: response_format json corrupts vLLM [Bug]: sending request using response_format json twice breaks vLLM Apr 14, 2024
@simon-mo
Collaborator

This looks like an issue with outlines' FSM copying and initialization. For this case, using lm-format-enforcer might be better. #3868
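As a sketch of that suggestion, assuming a vLLM build that includes #3868: the guided decoding backend can then be chosen server-wide with `--guided-decoding-backend lm-format-enforcer`, or per request with a `guided_decoding_backend` field. The hypothetical helper below builds keyword arguments for the OpenAI Python client, which merges `extra_body` into the JSON request body. Verify the field name against the docs for the vLLM version you actually run.

```python
def json_request_kwargs(model: str, messages: list[dict],
                        backend: str = "lm-format-enforcer",
                        max_tokens: int = 200) -> dict:
    """Build kwargs for client.chat.completions.create() against a vLLM server.

    `guided_decoding_backend` is a vLLM-specific request field (assumed from #3868);
    the OpenAI client merges `extra_body` into the JSON request body.
    """
    return {
        "model": model,
        "messages": messages,
        "response_format": {"type": "json_object"},
        "max_tokens": max_tokens,
        "extra_body": {"guided_decoding_backend": backend},
    }
```

Usage would be `client.chat.completions.create(**json_request_kwargs(model, messages))`, where `client` is an `OpenAI` instance pointed at the vLLM server.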

@robertgshaw2-neuralmagic
Collaborator

There seems to be a bug with:

response_format = { 
   "type": "json_object" 
},
Sample client:
from openai import OpenAI
import os

def prompt_json_completion(messages):
    base_url = os.getenv("BASE_URL", "http://localhost:8000/v1")
    api_key = os.getenv("API_KEY", "EMPTY")
    # os.getenv returns a string when the variable is set, so cast to int
    max_tokens = int(os.getenv("MAX_TOKENS", "100"))

    client = OpenAI(api_key=api_key, base_url=base_url)
    completion = client.chat.completions.create(
        # Use the first (and only) model served by vLLM
        model=client.models.list().data[0].id,
        # Uncommenting response_format reproduces the all-whitespace output
        # response_format={
        #     "type": "json_object"
        # },
        messages=messages,
        max_tokens=max_tokens,
    )
    # print(completion)
    print(completion.choices[0].message.content)

if __name__ == "__main__":
    user_prompt = "Generate example JSON data of a student in an SIS"
    messages = [
        {"role": "user", "content": user_prompt}
    ]
    prompt_json_completion(messages=messages)

I get all whitespace if I uncomment response_format.

@nydpy

nydpy commented Apr 15, 2024

I have the same error with json_object. Did anyone encounter this error with a previous version?

@XeonHis

XeonHis commented Apr 26, 2024

Same problem when setting "response_format" to {"type": "json_object"}: text generation only stops when it reaches the max model length. When setting "response_format" to {"type": "text"}, everything goes well.
Model: Mistral-7B-Instruct-v0.2-Function-Calling
vllm: 0.4.1

@JGSweets
Contributor

JGSweets commented Jun 6, 2024

Outlines has made several improvements to its JSON output, and vLLM previously pinned it to outlines==0.0.34.

These issues might have been fixed with the nightly:

outlines >= 0.0.43 # Requires torch >= 2.1.0
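To confirm which outlines release a vLLM environment actually has installed, a small check with `importlib.metadata` is enough. This is a sketch: `parse_release` and `outlines_at_least` are hypothetical helpers, and the parser only compares the leading numeric components, so it ignores pre-release suffixes; use `packaging.version` if you need full PEP 440 ordering.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_release(v: str) -> tuple[int, ...]:
    """Leading numeric components of a version, e.g. '0.0.43' -> (0, 0, 43).

    Simplified: stops at the first non-numeric piece, so '0.0.43.dev1' -> (0, 0, 43).
    """
    parts = []
    for piece in v.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def outlines_at_least(minimum: str = "0.0.43") -> bool:
    """True if the installed outlines release is at least `minimum`."""
    try:
        installed = version("outlines")
    except PackageNotFoundError:
        return False
    return parse_release(installed) >= parse_release(minimum)
```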

@maxdebayser
Contributor

I think PR #4109, which was merged into main, fixes this issue. (@br3no)

@JGSweets
Contributor

@maxdebayser I recently tried v0.5.0.post1 and vLLM + outlines still exhibits the issue of producing \t \n repeatedly until max_length when specifying {"type": "json_object"}.

@Randolph-zeng

Yes, I am also having the same problem with these versions:
vllm 0.4.2, vllm-nccl-cu12 2.18.1.0.4.0
I'm curious what causes this issue and whether there is any workaround. I would definitely appreciate any pointers.

@CNYoki

CNYoki commented Jul 24, 2024

Same as you. I had to give up on response_format.

@K-Mistele
Contributor

Hi guys!
This probably isn't an outlines or lm-format-enforcer issue, but an issue with guided decoding. Copypasta from my response to #8020:

important to note: When you're using json_object or json_schema in response_format, you must instruct the model to produce JSON if you want good results; and you will get the best results if you tell it to produce JSON and you give it an example of what you want.
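The advice above, instructing the model to produce JSON and showing it an example, can be sketched as a hypothetical message builder. The example dict is purely illustrative. By default the instruction is folded into the user turn, on the assumption that some chat templates (Mistral Instruct v0.2's original template, for instance, enforced strict user/assistant alternation) may reject a system role; set `use_system_role=True` for models whose templates accept one.

```python
import json


def json_mode_messages(user_request: str, example: dict,
                       use_system_role: bool = False) -> list[dict]:
    """Build chat messages that ask for JSON explicitly and show the expected shape."""
    instruction = (
        "Respond only with a single JSON object that follows the structure "
        "of this example:\n" + json.dumps(example, indent=2)
    )
    if use_system_role:
        return [
            {"role": "system", "content": instruction},
            {"role": "user", "content": user_request},
        ]
    # Some chat templates may reject a system role, so by default the
    # instruction is appended to the single user turn instead.
    return [{"role": "user", "content": f"{user_request}\n\n{instruction}"}]
```

For the earlier sample client, this would mean passing `json_mode_messages("Generate example JSON data of a student in an SIS", {"name": "Ada", "id": 1})` as `messages` together with `response_format={"type": "json_object"}`.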

The following guidance is from the OpenAI docs, but it applies to vLLM as well:

[Screenshot of the OpenAI docs guidance, taken 2024-09-27]

Basically, if you force the model to generate JSON when it is trying to generate natural text, it may produce nothing but whitespace: valid JSON can include leading and trailing whitespace, so whitespace tokens stay legal, and they may be more likely than a { token in the logprobs. This is very likely if you use json_object or json_schema without telling the model, either in the system prompt or in a user message, that you want JSON output and what that output should look like.
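This mechanism can be illustrated with a deliberately crude toy. A fixed score table stands in for the model's logits, and the only constraint is that anything before the opening `{` must be whitespace. Everything here (the token set, the scores, `constrained_greedy`) is hypothetical and far simpler than what outlines or lm-format-enforcer actually do; in particular, a real model recomputes its scores after every token.

```python
WHITESPACE = {" ", "\n", "\t"}


def legal_before_object(token: str) -> bool:
    """Toy constraint: JSON may start with whitespace, and the object opens with '{'."""
    return token in WHITESPACE or token == "{"


def constrained_greedy(scores: dict[str, float], max_tokens: int) -> list[str]:
    """Greedy decoding under the toy constraint, stopping once '{' is emitted.

    `scores` is a fixed next-token score table standing in for model logits.
    """
    generated: list[str] = []
    for _ in range(max_tokens):
        # Mask out tokens the constraint forbids, then take the best legal one.
        legal = {t: s for t, s in scores.items() if legal_before_object(t)}
        token = max(legal, key=legal.get)
        generated.append(token)
        if token == "{":
            break
    return generated


# A model that "wants" to write prose: the mask removes "Hello", leaving
# whitespace as the best legal choice at every step.
prose_scores = {"Hello": 5.0, "\n": 2.0, "\t": 1.5, "{": 0.5}
# A model that was told to answer in JSON ranks "{" highest.
json_scores = {"{": 4.0, "Hello": 3.0, "\n": 1.0}
```

With `prose_scores`, `constrained_greedy(prose_scores, 8)` yields eight newlines and never opens the object; with `json_scores`, the first token is already `{`.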

Hope that helps!

@K-Mistele
Contributor

It may be worth adding something about this to the vLLM docs; it seems to be a point of confusion. I have been discussing this with the Nous team too.

@gohar94

gohar94 commented Nov 5, 2024

Same issue when using nvidia/NVLM-D-72B and llava-hf/llava-1.5-7b-hf. Not passing response_format seems to work for now, but that does not seem like a proper solution.
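If you drop response_format as a workaround, nothing guarantees that the output is valid JSON, so parsing has to happen on the client side. Below is a minimal sketch using the standard library's `json.JSONDecoder.raw_decode`. `extract_json_object` is a hypothetical helper: it scans for each `{` and returns the first position that decodes to a JSON object. Retrying the request when it raises is left to the caller.

```python
import json


def extract_json_object(text: str) -> dict:
    """Return the first JSON object embedded in free-form model output."""
    decoder = json.JSONDecoder()
    for i, ch in enumerate(text):
        if ch != "{":
            continue
        try:
            # raw_decode parses one JSON value starting at index i and
            # ignores any trailing text after it.
            obj, _end = decoder.raw_decode(text, i)
        except json.JSONDecodeError:
            continue
        if isinstance(obj, dict):
            return obj
    raise ValueError("no JSON object found in model output")
```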
