
[Bug]: distributed_executor_backend=mp does not work with GPTQ tp>1 #6004

Closed
robertgshaw2-redhat opened this issue Jun 30, 2024 · 5 comments · Fixed by #6007
Labels: bug (Something isn't working)

@robertgshaw2-redhat (Collaborator) commented Jun 30, 2024

Your current environment

The output of `python collect_env.py`

🐛 Describe the bug

distributed_executor_backend="mp" is now enabled by default for vLLM. However, this feature is currently incompatible with GPTQ quantization for tp>1 due to the order in which torch is initialized. We get the classic RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method.

Setting distributed_executor_backend="ray" works for GPTQ.

The following fails:

from vllm import LLM

MODEL_NAME = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
TENSOR_PARALLEL_SIZE = 2

model = LLM(
    MODEL_NAME,
    enforce_eager=True,
    tensor_parallel_size=TENSOR_PARALLEL_SIZE,
    distributed_executor_backend="mp",
)
print(model.generate("The best thing about the internet is")[0].outputs[0].text)

with:

(vllm-upstream) rshaw@beaker:~/vllm$ python3 run.py 
INFO 06-30 14:26:18 gptq_marlin.py:140] The model is convertible to gptq_marlin during runtime. Using gptq_marlin kernel.
INFO 06-30 14:26:18 llm_engine.py:169] Initializing an LLM engine (v0.5.0.post1) with config: model='TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ', speculative_config=None, tokenizer='TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, disable_custom_all_reduce=False, quantization=gptq_marlin, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ)
(VllmWorkerProcess pid=3438507) Process VllmWorkerProcess:
(VllmWorkerProcess pid=3438507) Traceback (most recent call last):
(VllmWorkerProcess pid=3438507)   File "/home/rshaw/.pyenv/versions/3.10.14/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
(VllmWorkerProcess pid=3438507)     self.run()
(VllmWorkerProcess pid=3438507)   File "/home/rshaw/.pyenv/versions/3.10.14/lib/python3.10/multiprocessing/process.py", line 108, in run
(VllmWorkerProcess pid=3438507)     self._target(*self._args, **self._kwargs)
(VllmWorkerProcess pid=3438507)   File "/home/rshaw/vllm/vllm/executor/multiproc_worker_utils.py", line 210, in _run_worker_process
(VllmWorkerProcess pid=3438507)     worker = worker_factory()
(VllmWorkerProcess pid=3438507)   File "/home/rshaw/vllm/vllm/executor/gpu_executor.py", line 67, in _create_worker
(VllmWorkerProcess pid=3438507)     wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
(VllmWorkerProcess pid=3438507)   File "/home/rshaw/vllm/vllm/worker/worker_base.py", line 311, in init_worker
(VllmWorkerProcess pid=3438507)     self.worker = worker_class(*args, **kwargs)
(VllmWorkerProcess pid=3438507)   File "/home/rshaw/vllm/vllm/worker/worker.py", line 86, in __init__
(VllmWorkerProcess pid=3438507)     self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
(VllmWorkerProcess pid=3438507)   File "/home/rshaw/vllm/vllm/worker/model_runner.py", line 196, in __init__
(VllmWorkerProcess pid=3438507)     self.attn_backend = get_attn_backend(
(VllmWorkerProcess pid=3438507)   File "/home/rshaw/vllm/vllm/attention/selector.py", line 45, in get_attn_backend
(VllmWorkerProcess pid=3438507)     backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
(VllmWorkerProcess pid=3438507)   File "/home/rshaw/vllm/vllm/attention/selector.py", line 151, in which_attn_to_use
(VllmWorkerProcess pid=3438507)     if torch.cuda.get_device_capability()[0] < 8:
(VllmWorkerProcess pid=3438507)   File "/home/rshaw/vllm-upstream/lib/python3.10/site-packages/torch/cuda/__init__.py", line 430, in get_device_capability
(VllmWorkerProcess pid=3438507)     prop = get_device_properties(device)
(VllmWorkerProcess pid=3438507)   File "/home/rshaw/vllm-upstream/lib/python3.10/site-packages/torch/cuda/__init__.py", line 444, in get_device_properties
(VllmWorkerProcess pid=3438507)     _lazy_init()  # will define _get_device_properties
(VllmWorkerProcess pid=3438507)   File "/home/rshaw/vllm-upstream/lib/python3.10/site-packages/torch/cuda/__init__.py", line 279, in _lazy_init
(VllmWorkerProcess pid=3438507)     raise RuntimeError(
(VllmWorkerProcess pid=3438507) RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
ERROR 06-30 14:26:19 multiproc_worker_utils.py:120] Worker VllmWorkerProcess pid 3438507 died, exit code: 1
INFO 06-30 14:26:19 multiproc_worker_utils.py:123] Killing local vLLM worker processes
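
For context, a minimal sketch of the underlying failure mode outside of vLLM (assuming a CUDA-capable machine): once the parent process has initialized CUDA, a forked child cannot use it.

# Minimal sketch of the fork-after-CUDA-init failure, independent of vLLM.
# Assumes a CUDA-capable machine.
import multiprocessing as mp

import torch

def child():
    # Raises "Cannot re-initialize CUDA in forked subprocess" because the
    # parent process already created a CUDA context before fork().
    print(torch.cuda.get_device_capability())

if __name__ == "__main__":
    torch.cuda.get_device_capability()  # initializes CUDA in the parent
    proc = mp.get_context("fork").Process(target=child)
    proc.start()
    proc.join()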

Workarounds for now:

  • Set distributed_executor_backend="ray" (see example below)
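
For example, the repro script above runs if the executor backend is switched to Ray (assumes ray is installed):

from vllm import LLM

# Same repro as above, but with the Ray executor backend. Ray launches the
# workers as separate processes rather than forking the current one, so the
# CUDA re-initialization error does not occur. Assumes ray is installed.
model = LLM(
    "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
    enforce_eager=True,
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
)
print(model.generate("The best thing about the internet is")[0].outputs[0].text)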
robertgshaw2-redhat added the bug (Something isn't working) label on Jun 30, 2024
robertgshaw2-redhat changed the title from "[Bug]: distributed_executor_backend=mp does not work with quantization" to "[Bug]: distributed_executor_backend=mp does not work with GPTQ tp>1" on Jun 30, 2024
@DarkLight1337 (Member)

It would be great if we could better control when CUDA is initialized; I'm not sure whether that's feasible, though.

Despite @youkaichao's fix to the distributed tests, it's quite a headache to ensure that CUDA is not accidentally initialized during those tests.

@robertgshaw2-redhat (Collaborator, Author)

For sure - I'm going to look into this when I get some time. It's specific to GPTQ (it does not happen for fp8 quantization). I think the source is that we check the CUDA device capability when deciding whether we can convert GPTQ --> Marlin, and I think this check happens "too early" in the lifecycle.

Will look into a workaround when I get some time. I have a couple of PRs I want to wrap up before I look into this.

@llmpros (Contributor) commented Jun 30, 2024

I hit the same issue - happy to poke around if folks are busy with higher-priority tasks.

@robertgshaw2-redhat (Collaborator, Author)

Feel free. I think this function is the culprit (it initializes torch):

It is called by this function:

Which is called by this function:

I think possible solutions are:

  • See if there is a way to get the CUDA device capability without calling torch.cuda

@youkaichao (Member)

We can use pynvml to check the compute capability without calling torch.cuda (which would initialize the CUDA context).
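
A minimal sketch of that approach, assuming the pynvml package is installed (the helper name below is illustrative, not an existing vLLM function):

# Illustrative sketch: read the compute capability through NVML so that no
# CUDA context is created in the parent process before the workers fork.
# Assumes the pynvml package is installed.
import pynvml

def get_device_capability(device_id: int = 0) -> tuple[int, int]:
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        return major, minor
    finally:
        pynvml.nvmlShutdown()

if __name__ == "__main__":
    print(get_device_capability(0))  # e.g. (8, 0) for an A100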
