Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hardware][Nvidia] Enable support for Pascal GPUs #4409

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

jasonacox
Copy link
Contributor

@jasonacox jasonacox commented Apr 27, 2024

[Hardware][Nvidia] Enable support for Pascal GPUs (sm_60, sm_61)

FIX: #963 #1284

Related: #4290 #2635

--

This is a new PR as a placeholder in the hope that the wheel size >100MB request is someday granted. This only adds compute capability 6.0 and 6.1. Note: pytorch is now only supporting sm_60.

>>> torch.__version__
'2.2.1+cu121'
>>> torch.cuda.torch.cuda.get_arch_list()
['sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90']
>>> 

Pascal Architecture

  • (+) SM60 or SM_60, compute_60 – Quadro GP100, Tesla P100, DGX-1 (Generic Pascal)
  • (+) SM61 or SM_61, compute_61– GTX 1080, GTX 1070, GTX 1060, GTX 1050, GTX 1030 (GP108), GT 1010 (GP108) Titan Xp, Tesla P40, Tesla P4, Discrete GPU on the NVIDIA Drive PX2
  • (-) SM62 or SM_62, compute_62 – Integrated GPU on the NVIDIA Drive PX2, Tegra (Jetson) TX2

Example test on 4 x P100 GPUs on CUDA 12.2 system:

# build
DOCKER_BUILDKIT=1 docker build . --target vllm-openai --tag vllm-openai --no-cache

 # run
docker run -d \
    --shm-size=10.24gb \
    --gpus '"device=0,1,2,3"' \
    -v /data/models:/root/.cache/huggingface \
    --env "HF_TOKEN=xyz" \
    -p 8000:8000 \
    --restart unless-stopped \
    --name vllm-openai \
    vllm-openai \
    --host 0.0.0.0 \
    --model=mistralai/Mistral-7B-Instruct-v0.1 \
    --enforce-eager \
    --dtype=float \
    --gpu-memory-utilization 0.95 \
    --tensor-parallel-size=4

@sasha0552
Copy link
Contributor

@youkaichao
As I see it, pypi/support#3792 has been approved. Is it possible to merge this PR now?

@sasha0552
Copy link
Contributor

sasha0552 commented May 19, 2024

(From Release Tracker)

#4409 might need a little bit more discussion given what features are supported for Pascal GPUs and whether building from source might be a better option.

I've been using vLLM on my P40s every day for almost a month now, and everything works fine. triton didn't accept one of my patches (they said we dropped support for pre-A100 GPUs, so I think there will soon be problems with other older architectures as well.), so things that depend on triton and use the tl.dot operation won't work (prefix caching, for example). However, there is a patched triton (sasha0552/triton), and just installing the patched triton is easier than installing both the patched triton and the patched vLLM. Also considering that the basic functionality works fine without triton.

Maybe the patched triton could be shipped like nccl (although not installed by default)? The patch is very simple, and I don't think it would be hard to maintain. I can maintain support for Pascal GPUs, if needed (I'm not going to move on from these GPUs until better options become available for the price per VRAM GB).

P.S. Whoever is reading this, you might want to check out my project, which has pre-built vllm and triton wheels for Pascal GPUs (and also patches & build scripts).

@AslanEZ
Copy link

AslanEZ commented Jul 2, 2024

[Hardware][Nvidia] Enable support for Pascal GPUs (sm_60, sm_61)

FIX: #963 #1284

Related: #4290 #2635

--

This is a new PR as a placeholder in the hope that the wheel size >100MB request is someday granted. This only adds compute capability 6.0 and 6.1. Note: pytorch is now only supporting sm_60.

>>> torch.__version__
'2.2.1+cu121'
>>> torch.cuda.torch.cuda.get_arch_list()
['sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90']
>>> 

Pascal Architecture

  • (+) SM60 or SM_60, compute_60 – Quadro GP100, Tesla P100, DGX-1 (Generic Pascal)
  • (+) SM61 or SM_61, compute_61– GTX 1080, GTX 1070, GTX 1060, GTX 1050, GTX 1030 (GP108), GT 1010 (GP108) Titan Xp, Tesla P40, Tesla P4, Discrete GPU on the NVIDIA Drive PX2
  • (-) SM62 or SM_62, compute_62 – Integrated GPU on the NVIDIA Drive PX2, Tegra (Jetson) TX2

Example test on 4 x P100 GPUs on CUDA 12.2 system:

# build
DOCKER_BUILDKIT=1 docker build . --target vllm-openai --tag vllm-openai --no-cache

 # run
docker run -d \
    --shm-size=10.24gb \
    --gpus '"device=0,1,2,3"' \
    -v /data/models:/root/.cache/huggingface \
    --env "HF_TOKEN=xyz" \
    -p 8000:8000 \
    --restart unless-stopped \
    --name vllm-openai \
    vllm-openai \
    --host 0.0.0.0 \
    --model=mistralai/Mistral-7B-Instruct-v0.1 \
    --enforce-eager \
    --dtype=float \
    --gpu-memory-utilization 0.95 \
    --tensor-parallel-size=4

Does this mean I can't run vllm on a Tesla P4, Even a small model?

@jasonacox
Copy link
Contributor Author

Does this mean I can't run vllm on a Tesla P4, Even a small model?

@AslanEZ I believe the P4 has a compute capability of 6.1. This PR requests to add that. Have you tested?

@AslanEZ
Copy link

AslanEZ commented Jul 3, 2024

Does this mean I can't run vllm on a Tesla P4, Even a small model?

@AslanEZ I believe the P4 has a compute capability of 6.1. This PR requests to add that. Have you tested?

I have tested it by installing with pip. It didn't work.

[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/attention/backends/xformers.py", line 323, in forward [rank0]: output[num_prefill_tokens:] = PagedAttention.forward_decode( [rank0]: RuntimeError: CUDA error: no kernel image is available for execution on the device [rank0]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. [rank0]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1. [rank0]: Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

I intend to try your code now.

@AslanEZ
Copy link

AslanEZ commented Jul 3, 2024

Does this mean I can't run vllm on a Tesla P4, Even a small model?

@AslanEZ I believe the P4 has a compute capability of 6.1. This PR requests to add that. Have you tested?

Oh, it works! Thank you!

@dirkson
Copy link

dirkson commented Aug 12, 2024

Could we get an update on the status of this PR? I've been eagerly awaiting it, as I can't use vllm until it supports my hardware.

@sasha0552
Copy link
Contributor

@dirkson it was answered here #6434 (comment)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support for compute capability <7.0
4 participants