
[Bugfix][Model] Skip loading lm_head weights if using tie_word_embeddings #6758

Merged
merged 3 commits into vllm-project:main on Aug 1, 2024

Conversation

tjohnson31415
Contributor

In llama and other models with tie_word_embeddings, there seem to be cases where the weight files include both the lm_head.weight and embed_tokens.weight tensors, particularly after tuning procedures. Attempting to load such weights results in an error like:

[rank0]:   File "/opt/vllm/lib64/python3.11/site-packages/vllm/worker/model_runner.py", line 553, in load_model
[rank0]:     self.model = get_model(model_config=self.model_config,
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/vllm/lib64/python3.11/site-packages/vllm/model_executor/model_loader/__init__.py", line 21, in get_model
[rank0]:     return loader.load_model(model_config=model_config,
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/vllm/lib64/python3.11/site-packages/vllm/model_executor/model_loader/loader.py", line 278, in load_model
[rank0]:     model.load_weights(
[rank0]:   File "/opt/vllm/lib64/python3.11/site-packages/vllm/model_executor/models/llama.py", line 481, in load_weights
[rank0]:     param = params_dict[name]
[rank0]:             ~~~~~~~~~~~^^^^^^
[rank0]: KeyError: 'lm_head.weight'

This error arises when the model does not have an lm_head.weight in named_parameters() but the weight files include lm_head.weight. With tie_word_embeddings set to true, lm_head.weight should be a duplicate of embed_tokens.weight, so the extra tensor seems to be included in the .safetensors unnecessarily. The change in this PR is to ignore lm_head.weight when tie_word_embeddings is true.

#3553 is a related issue that hit the same problem for Gemma, which always uses tied weights.

I was made aware of this error in regard to a fine-tune of ibm-granite/granite-3b-code-instruct, which uses the llama architecture with tie_word_embeddings set to True. After understanding the cause, I looked for other model implementations that might have the same issue with:

grep -r tie_word_embeddings vllm/model_executor/models/

I found that some model implementations with tie_word_embeddings already include a check to skip loading lm_head.weight (qwen2, starcoder2, falcon). I added the check to the other models whose load_weights was missing it.
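
For illustration, here is a minimal sketch of the kind of check involved; this is not the exact vLLM code, and the helper name and signature are hypothetical (in the real models the check sits inline in each load_weights loop):

```python
from typing import Iterable, Iterator, Tuple

import torch


def skip_tied_lm_head(
    weights: Iterable[Tuple[str, torch.Tensor]],
    tie_word_embeddings: bool,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield every (name, tensor) pair except a duplicated lm_head.weight."""
    for name, loaded_weight in weights:
        # With tied embeddings, lm_head.weight duplicates embed_tokens.weight
        # and has no entry in the model's named_parameters(), so looking it up
        # in params_dict raises the KeyError shown above.
        if tie_word_embeddings and name == "lm_head.weight":
            continue
        yield name, loaded_weight
```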

I only tested the fix with llama, but included the checks for the models that were missing it to head off future issues. Let me know if it would be preferable to only change this for llama and leave the other models for later PRs.

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build on the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@DarkLight1337 added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Jul 26, 2024
@njhill (Member) left a comment

Thanks @tjohnson31415!

@njhill (Member) commented on Jul 29, 2024

@tjohnson31415 could you merge in the latest main branch? I think that should fix the failing tests

* upstream/main: (66 commits)
  [Bugfix] Fix PaliGemma MMP (vllm-project#6930)
  [TPU] Fix greedy decoding (vllm-project#6933)
  [Kernel] Tuned int8 kernels for Ada Lovelace (vllm-project#6848)
  [Kernel] Fix marlin divide-by-zero warnings (vllm-project#6904)
  [ci] GHA workflow to remove ready label upon "/notready" comment (vllm-project#6921)
  [Kernel] Remove unused variables in awq/gemm_kernels.cu (vllm-project#6908)
  [Frontend] New `allowed_token_ids` decoding request parameter (vllm-project#6753)
  [Bugfix] Allow vllm to still work if triton is not installed. (vllm-project#6786)
  [TPU] Support tensor parallelism in async llm engine (vllm-project#6891)
  [Kernel] Fix deprecation function warnings squeezellm quant_cuda_kernel (vllm-project#6901)
  [Core] Reduce unnecessary compute when logprobs=None (vllm-project#6532)
  [Kernel] Tuned FP8 Kernels for Ada Lovelace (vllm-project#6677)
  [Model] Initialize support for InternVL2 series models (vllm-project#6514)
  [Misc] Pass cutlass_fp8_supported correctly in fbgemm_fp8 (vllm-project#6871)
  Add Nemotron to PP_SUPPORTED_MODELS (vllm-project#6863)
  [Kernel] Increase precision of GPTQ/AWQ Marlin kernel (vllm-project#6795)
  [TPU] Reduce compilation time & Upgrade PyTorch XLA version  (vllm-project#6856)
  [Docs] Add RunLLM chat widget (vllm-project#6857)
  [Model] Initial support for BLIP-2 (vllm-project#5920)
  [CI/Build][Doc] Update CI and Doc for VLM example changes (vllm-project#6860)
  ...
@njhill (Member) commented on Jul 30, 2024

@tjohnson31415 could you do it one more time? :) another fix went in for the tensorizer tests

* upstream/main:
  [Build] Temporarily Disable Kernels and LoRA tests (vllm-project#6961)
  [core][misc] improve free_finished_seq_groups (vllm-project#6865)
  [Kernel] Remove scaled_fp8_quant kernel padding footgun (vllm-project#6842)
  [Bugfix] Fix tensorizer memory profiling bug during testing (vllm-project#6881)
  [OpenVINO] Updated OpenVINO requirements and build docs (vllm-project#6948)
  [Kernel] Squash a few more warnings (vllm-project#6914)
  [BugFix] Fix use of per-request seed with pipeline parallel (vllm-project#6698)
  [Doc] Super tiny fix doc typo (vllm-project#6949)
@TissueC commented on Aug 1, 2024

An interim solution for this issue: manually initialize lm_head with the word embeddings and set tie_word_embeddings in config.json to False.
Hope this PR can be merged ASAP
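
For reference, a rough sketch of that manual workaround, assuming a Hugging Face-format checkpoint; the paths are placeholders and this is not code from the PR:

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder path; point this at the actual fine-tuned checkpoint.
model = AutoModelForCausalLM.from_pretrained("path/to/finetuned-model")

# Untie the weights: give lm_head its own copy of the embedding matrix
# instead of sharing the same parameter with embed_tokens.
model.lm_head.weight = torch.nn.Parameter(
    model.get_input_embeddings().weight.detach().clone()
)
model.config.tie_word_embeddings = False

# The saved checkpoint now ships an explicit lm_head.weight and a config.json
# with tie_word_embeddings: false, so the KeyError above no longer applies.
model.save_pretrained("path/to/untied-model")  # placeholder output path
```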

@njhill merged commit 630dd9e into vllm-project:main on Aug 1, 2024
63 checks passed
@tjohnson31415 deleted the skip-lm-head branch on Aug 1, 2024 at 17:28
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
…ings (vllm-project#6758)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
davidthomas426 added a commit to davidthomas426/vllm that referenced this pull request Sep 13, 2024
Cherry-pick vllm-project#6758 fix: skip loading lm_head if tie_word_embeddings
davidthomas426 pushed a commit to davidthomas426/vllm that referenced this pull request Sep 13, 2024
…ings (vllm-project#6758)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>