[Model]: Add transformers backend support #11330

Open · wants to merge 105 commits into base: main

Conversation

@ArthurZucker commented Dec 19, 2024

Adds support for transformers as a backend

Following huggingface/transformers#35235, a number of models should already be supported, and we are ramping up support for more models.

Thanks @Isotr0py for the TP support, and @hmellor for his help as well!
This includes:

  • trust_remote_code=True support: any model on the Hub can be natively supported if it implements attention the correct way!
  • tensor parallel support
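
For illustration, a minimal usage sketch. The `model_impl` engine argument is inferred from the lm_eval invocations later in this thread, and the model name is only an example, so treat this as an assumption rather than the documented API:

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # any supported decoder-only model on the Hub
    model_impl="transformers",                 # explicitly select the Transformers backend
    trust_remote_code=True,                    # allow custom modeling code from the Hub
    tensor_parallel_size=1,                    # TP is also supported by this backend
)

outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```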

ArthurZucker and others added 2 commits December 19, 2024 10:33
Co-authored-by: Isotr0py <41363108+Isotr0py@users.noreply.github.com>

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs don't trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Dec 19, 2024
@ywang96 (Member) commented Dec 19, 2024

Hello @ArthurZucker! This is very exciting!

I know this PR is still a draft, but could you provide some context on the scope of this effort? Is it to support any model on transformers?

@Isotr0py mentioned this pull request Dec 19, 2024
@ArthurZucker (Author)

Yep, overall this should support any model that is supported in transformers where the cache is "simple", so for now that means most of the decoder models and the encoder models for a single modality! For multimodal models we might need a little bit of extra work, but I think LLaVA models should work out of the box!

We are refactoring our models to make sure this is propagated to as many models as possible!

@ArthurZucker (Author)

Might not have time to finish this week, will make it ready for next week 🎄
This should be minimal (no support for LoRA, or at least I am not testing it! That might need to either call transformers' from_pretrained, or replace modules similarly to TP).

@simon-mo mentioned this pull request Jan 9, 2025
hmellor added 13 commits January 9, 2025 11:35
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
…orted

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@hmellor (Collaborator) commented Jan 16, 2025

Benchmarks on A100 using the following command:

python benchmarks/benchmark_throughput.py --backend vllm --model meta-llama/Llama-3.1-8B-Instruct --dataset ShareGPT_V3_unfiltered_cleaned_split.json

Results:

| Class | Result |
|---|---|
| LlamaForCausalLM | Throughput: 12.88 requests/s, 5325.05 total tokens/s, 2554.02 output tokens/s |
| TransformersModel | Throughput: 11.38 requests/s, 4705.90 total tokens/s, 2257.06 output tokens/s |

hmellor and others added 2 commits January 16, 2025 19:08
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@hmellor (Collaborator) commented Jan 30, 2025

Seems like the only relevant failure is:

[2025-01-30T08:42:53Z] FAILED models/test_registry.py::test_registry_imports[TransformersModel] - KeyError: 'TransformersModel'
[2025-01-30T08:42:53Z] FAILED models/test_registry.py::test_hf_registry_coverage - AssertionError: Please add the following architectures to `tests/models/registry.py`: {'TransformersModel'}
[2025-01-30T08:42:53Z] assert not {'TransformersModel'}

This passes locally but fails in CI, and I'm not sure why. I have temporarily reordered the pytest calls so we can see the test_transformers.py tests in CI.
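
As a quick local sanity check that the new architecture is actually registered (a sketch, assuming vLLM's `ModelRegistry.get_supported_archs()` helper):

```python
from vllm import ModelRegistry

# If this assertion fails, tests/models/registry.py (and hence CI) will not
# know about the architecture either.
assert "TransformersModel" in ModelRegistry.get_supported_archs()
```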

hmellor and others added 3 commits January 30, 2025 19:14
@ArthurZucker (Author)

Just waiting for the CIs

@mgoin (Member) left a comment

Nice work, LGTM! I've tested locally with Llama by running gsm8k evals, where I see good accuracy and slightly lower throughput, as we would expect.

vLLM impl:

lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto
...
Processed prompts: 100%|████████████| 1319/1319 [00:45<00:00, 29.09it/s, est. speed input: 25375.03 toks/s, output: 2839.50 toks/s]
Running generate_until requests: 100%|████████████| 1319/1319 [00:45<00:00, 28.98it/s]
2025-01-31:20:32:00,673 INFO     [evaluation_tracker.py:269] Output path not provided, skipping saving results aggregated
vllm (pretrained=meta-llama/Llama-3.1-8B-Instruct), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7801|±  |0.0114|
|     |       |strict-match    |     5|exact_match|↑  |0.7566|±  |0.0118|

Transformers impl:

lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct,model_impl=transformers --tasks gsm8k --num_fewshot 5 --batch_size auto
...
Processed prompts: 100%|█████████████| 1319/1319 [00:54<00:00, 24.36it/s, est. speed input: 21247.29 toks/s, output: 2378.45 toks/s]
Running generate_until requests: 100%|███████████████| 1319/1319 [00:54<00:00, 24.25it/s]
2025-01-31:20:30:10,932 INFO     [evaluation_tracker.py:269] Output path not provided, skipping saving results aggregated
vllm (pretrained=meta-llama/Llama-3.1-8B-Instruct,model_impl=transformers), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7763|±  |0.0115|
|     |       |strict-match    |     5|exact_match|↑  |0.7521|±  |0.0119|

One thing I would like to note is that it seems V1 is not supported yet. Running with VLLM_USE_V1=1 and Llama results in an error about the inputs_embeds forward-pass arg. This argument is currently only used for multimodal models, so we could get around it in this case by ignoring it.

ERROR 01-31 20:27:32 core.py:208]   File "/home/mgoin/code/vllm/vllm/v1/worker/gpu_model_runner.py", line 870, in _dummy_run
ERROR 01-31 20:27:32 core.py:208]     hidden_states = model(
ERROR 01-31 20:27:32 core.py:208]                     ^^^^^^
ERROR 01-31 20:27:32 core.py:208]   File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
ERROR 01-31 20:27:32 core.py:208]     return self._call_impl(*args, **kwargs)
ERROR 01-31 20:27:32 core.py:208]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 01-31 20:27:32 core.py:208]   File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
ERROR 01-31 20:27:32 core.py:208]     return forward_call(*args, **kwargs)
ERROR 01-31 20:27:32 core.py:208]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 01-31 20:27:32 core.py:208] TypeError: TransformersModel.forward() got an unexpected keyword argument 'inputs_embeds'
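
A minimal sketch of the suggested workaround; the class and signature below are simplified stand-ins for illustration, not the actual code in vllm/model_executor/models/transformers.py:

```python
from typing import Optional

import torch
import torch.nn as nn


class TransformersModel(nn.Module):
    """Simplified stand-in that wraps a Hugging Face model (illustration only)."""

    def __init__(self, hf_model: nn.Module):
        super().__init__()
        self.model = hf_model

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        # Accepted so V1's dummy run can pass it, but ignored here: only
        # multimodal models currently make use of embedding inputs.
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        return self.model(input_ids=input_ids).last_hidden_state
```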

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 31, 2025
@DarkLight1337 (Member)

#12599 has been merged; can you merge from main to fix the merge conflicts?


mergify bot commented Feb 1, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ArthurZucker.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 1, 2025
Isotr0py and others added 2 commits February 1, 2025 14:07
Signed-off-by: Isotr0py <2037008807@qq.com>
@mergify mergify bot removed the needs-rebase label Feb 1, 2025
@Isotr0py (Collaborator) commented Feb 1, 2025

transformers backend should work with V1 now:

INFO 02-01 14:05:23 __init__.py:183] Automatically detected platform cuda.
WARNING 02-01 14:05:25 arg_utils.py:1325] Setting max_num_batched_tokens to 8192 for LLM_CLASS usage context.
INFO 02-01 14:05:31 config.py:540] This model supports multiple tasks: {'classify', 'score', 'reward', 'embed', 'generate'}. Defaulting to 'generate'.
INFO 02-01 14:05:31 config.py:1508] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 02-01 14:05:32 core.py:45] Initializing a V1 LLM engine (v0.1.dev4338+g2079e43) with config: model='../Llama-3.2-1B-Instruct', speculative_config=None, tokenizer='../Llama-3.2-1B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=../Llama-3.2-1B-Instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"level":3,"custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":512}
WARNING 02-01 14:05:33 registry.py:336] `mm_limits` has already been set for model=../Llama-3.2-1B-Instruct, and will be overwritten by the new values.
INFO 02-01 14:05:33 gpu_model_runner.py:843] Starting to load model ../Llama-3.2-1B-Instruct...
INFO 02-01 14:05:33 transformers.py:129] Using Transformers backend.
INFO 02-01 14:05:33 cuda.py:162] Using Flash Attention backend on V1 engine.
WARNING 02-01 14:05:33 topk_topp_sampler.py:44] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
WARNING 02-01 14:05:33 config.py:3361] `torch.compile` is turned on, but the model ../Llama-3.2-1B-Instruct does not support it. Please open an issue on GitHubif you want it to be supported.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:02<00:00,  2.40s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:02<00:00,  2.40s/it]

INFO 02-01 14:05:36 gpu_model_runner.py:848] Loading model weights took 2.3029 GB
INFO 02-01 14:05:37 kv_cache_utils.py:395] # GPU blocks: 11864
INFO 02-01 14:05:39 gpu_model_runner.py:1019] Graph capturing finished in 2 secs, took 0.09 GiB
INFO 02-01 14:05:39 core.py:89] init engine (profile, create kv cache, warmup model) took 3.11 seconds
Processed prompts: 100%|████████████████████████████████████████████████| 4/4 [00:00<00:00,  7.99it/s, est. speed input: 51.99 toks/s, output: 127.97 toks/s]
Prompt: 'Hello, my name is', Generated text: ' Emma and I am a 4th grade student. I just wanted to say'
Prompt: 'The president of the United States is', Generated text: ' the head of state and the head of government. He or she is typically the'
Prompt: 'The capital of France is', Generated text: ' Paris. The French word for the phrase "good morning" is "bonjour'
Prompt: 'The future of AI is', Generated text: ' multifaceted, and researchers are exploring various applications that will benefit society. One'
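
For reference, a sketch of the kind of script that produces the output above; the prompts match vLLM's standard offline-inference example, and the exact arguments (local model path, sampling settings) are assumptions:

```python
import os

os.environ["VLLM_USE_V1"] = "1"  # exercise the V1 engine

from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

llm = LLM(
    model="../Llama-3.2-1B-Instruct",  # local checkpoint used in the log above
    model_impl="transformers",         # force the Transformers backend
    trust_remote_code=True,
)

for out in llm.generate(prompts, SamplingParams(temperature=0.8, max_tokens=16)):
    print(f"Prompt: {out.prompt!r}, Generated text: {out.outputs[0].text!r}")
```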

@Isotr0py (Collaborator) left a comment

Let's get this merged first! We can add BNB and LoRA support in follow-up PRs.

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) February 1, 2025 06:22
Signed-off-by: Isotr0py <2037008807@qq.com>
@DarkLight1337 (Member)

Please fix the failing tests

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
@DarkLight1337 (Member)

Please also add the distributed transformers test to the distributed tests CI

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py and others added 2 commits February 2, 2025 10:50
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
@Isotr0py Isotr0py enabled auto-merge (squash) February 2, 2025 02:55
Labels: ci/build, documentation, ready
8 participants