From f387247b4e52187226dec232de77f720e4137fce Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 23 Jul 2024 09:46:05 -0700
Subject: [PATCH] [BugFix] Fix RoPE error in Llama 3.1 (#6693)

---
 vllm/config.py                                 | 53 +++++++++----------
 .../model_executor/layers/rotary_embedding.py |  7 +--
 2 files changed, 30 insertions(+), 30 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 069963053045f..c27d26c098b59 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -154,15 +154,6 @@ def __init__(
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
-        if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
-                and getattr(self.hf_config, "rope_scaling", None) is None):
-            # Note(simon): this is a special case for a model that doesn't
-            # supply rope_scaling. We should remove this once the model is
-            # updated.
-            self.hf_config.update({"rope_scaling": {
-                "type": "extended",
-            }})
-
         if (not self.disable_sliding_window
                 and self.hf_text_config.model_type == "gemma2"
                 and self.hf_text_config.sliding_window is not None):
@@ -1492,24 +1483,32 @@ def _get_and_verify_max_len(
             derived_max_model_len = default_max_len
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
-    # The correct one should be "longrope", kept "su" here
-    # to be backward compatible
-    if rope_scaling is not None and rope_scaling["type"] not in {
-            "su", "longrope", "extended"
-    }:
-        if disable_sliding_window:
-            # TODO(robertgshaw): Find a model that supports rope_scaling
-            # with sliding window to see if this case should be allowed.
-            raise NotImplementedError(
-                "Disabling sliding window is not supported for models "
-                "with rope_scaling. Please raise an issue so we can "
-                "investigate.")
-        assert "factor" in rope_scaling
-        scaling_factor = rope_scaling["factor"]
-        if rope_scaling["type"] == "yarn":
-            derived_max_model_len = rope_scaling[
-                "original_max_position_embeddings"]
-        derived_max_model_len *= scaling_factor
+    if rope_scaling is not None:
+        if "type" in rope_scaling:
+            rope_type = rope_scaling["type"]
+        elif "rope_type" in rope_scaling:
+            rope_type = rope_scaling["rope_type"]
+        else:
+            raise ValueError(
+                "rope_scaling must have a 'type' or 'rope_type' key.")
+
+        # The correct one should be "longrope", kept "su" here
+        # to be backward compatible
+        if rope_type not in ("su", "longrope", "llama3"):
+            if disable_sliding_window:
+                # TODO(robertgshaw): Find a model that supports rope_scaling
+                # with sliding window to see if this case should be allowed.
+                raise NotImplementedError(
+                    "Disabling sliding window is not supported for models "
+                    "with rope_scaling. Please raise an issue so we can "
+                    "investigate.")
+
+            assert "factor" in rope_scaling
+            scaling_factor = rope_scaling["factor"]
+            if rope_type == "yarn":
+                derived_max_model_len = rope_scaling[
+                    "original_max_position_embeddings"]
+            derived_max_model_len *= scaling_factor
 
     # If the user specified a max length, make sure it is smaller than the
     # derived length from the HF model config.
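Note: Llama 3.1 checkpoints ship a rope_scaling block keyed by "rope_type"
("llama3") rather than the legacy "type" key, which is why the hunk above
accepts either spelling before deciding whether to stretch the derived
context length. Below is a minimal standalone sketch of that logic;
_get_rope_type and _derive_max_len are hypothetical helper names, and the
config dicts are made-up examples, not vLLM's API.

    # Sketch of the key fallback and max-length derivation from the hunk
    # above. Helper names and config dicts are illustrative only.

    def _get_rope_type(rope_scaling: dict) -> str:
        # Accept the legacy "type" key as well as the newer "rope_type" key.
        if "type" in rope_scaling:
            return rope_scaling["type"]
        if "rope_type" in rope_scaling:
            return rope_scaling["rope_type"]
        raise ValueError("rope_scaling must have a 'type' or 'rope_type' key.")

    def _derive_max_len(derived_max_model_len: int, rope_scaling: dict) -> int:
        # Mirrors the patched branch: "su", "longrope", and "llama3" leave
        # the derived length untouched instead of multiplying by a factor.
        rope_type = _get_rope_type(rope_scaling)
        if rope_type in ("su", "longrope", "llama3"):
            return derived_max_model_len
        scaling_factor = rope_scaling["factor"]
        if rope_type == "yarn":
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]
        return int(derived_max_model_len * scaling_factor)

    # Legacy spelling ("type") with yarn scaling: 8192 * 4.0 = 32768.
    assert _derive_max_len(4096, {"type": "yarn", "factor": 4.0,
                                  "original_max_position_embeddings": 8192
                                  }) == 32768
    # Llama 3.1 spelling ("rope_type"): no multiplicative stretch applied.
    assert _derive_max_len(131072, {"rope_type": "llama3",
                                    "factor": 8.0}) == 131072

In the "llama3" case the scaling happens inside the RoPE layer itself (see
the rotary_embedding.py hunk below), so the derived max length stays at the
config's max_position_embeddings.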
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 3f9573f550341..60ba4623edc38 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -794,12 +794,13 @@ def get_rope(
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position,
                                      base, is_neox_style, dtype)
     else:
-        scaling_type = rope_scaling["type"]
+        scaling_type = rope_scaling[
+            "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
         # The correct one should be "longrope" but keep "su" here
         # for backward compatible
-        if scaling_type not in {"su", "longrope", "extended"}:
+        if scaling_type not in {"su", "longrope", "llama3"}:
             scaling_factor = rope_scaling["factor"]
-        if scaling_type == "extended":
+        if scaling_type == "llama3":
             rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
                                                  max_position, base,
                                                  is_neox_style, dtype)
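For reference, a hedged sketch of the dispatch inside get_rope() after this
patch. The class bodies below are stubs (dtype and frequency computation
omitted); only the selection logic mirrors the diff, and get_rope_sketch is
a made-up name, not vLLM's function.

    from typing import Optional

    class RotaryEmbedding:
        def __init__(self, head_size: int, rotary_dim: int, max_position: int,
                     base: int, is_neox_style: bool) -> None:
            self.head_size = head_size
            self.rotary_dim = rotary_dim
            self.max_position = max_position
            self.base = base
            self.is_neox_style = is_neox_style

    class ExtendedRotaryEmbedding(RotaryEmbedding):
        # Stand-in for vLLM's Llama 3.1 RoPE class, which rescales the
        # rotary frequencies instead of stretching the context length.
        pass

    def get_rope_sketch(head_size: int, rotary_dim: int, max_position: int,
                        base: int, is_neox_style: bool = True,
                        rope_scaling: Optional[dict] = None) -> RotaryEmbedding:
        if rope_scaling is None:
            return RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                   is_neox_style)
        # Accept both key spellings, as in the patch.
        scaling_type = (rope_scaling["type"] if "type" in rope_scaling
                        else rope_scaling["rope_type"])
        if scaling_type == "llama3":
            return ExtendedRotaryEmbedding(head_size, rotary_dim, max_position,
                                           base, is_neox_style)
        # "linear", "yarn", "su"/"longrope", etc. map to their own classes
        # in the real implementation.
        raise NotImplementedError(scaling_type)

    # A Llama 3.1-style config selects the extended class via "rope_type".
    emb = get_rope_sketch(128, 128, 131072, 500000,
                          rope_scaling={"rope_type": "llama3", "factor": 8.0})
    assert isinstance(emb, ExtendedRotaryEmbedding)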