From 79c5cc1460b8a08b4548a275292033ee93d36f47 Mon Sep 17 00:00:00 2001
From: sasha0552
Date: Tue, 11 Jun 2024 17:42:26 +0000
Subject: [PATCH] [Frontend] Customizable RoPE theta (#5197)

---
 tests/test_config.py              |  7 ++++++-
 vllm/config.py                    |  4 +++-
 vllm/engine/arg_utils.py          |  8 ++++++++
 vllm/engine/llm_engine.py         |  3 ++-
 vllm/transformers_utils/config.py | 13 ++++++++-----
 5 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/tests/test_config.py b/tests/test_config.py
index 7cbdaeca9c4d4..6c8af9d7966b4 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -63,8 +63,9 @@ def test_get_sliding_window():
     assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
 
 
-def test_rope_scaling():
+def test_rope_customization():
     TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
+    TEST_ROPE_THETA = 16_000_000.0
     LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}
 
     llama_model_config = ModelConfig(
@@ -76,6 +77,7 @@ def test_rope_scaling():
         seed=0,
     )
     assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
+    assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000
     assert llama_model_config.max_model_len == 8192
 
     llama_model_config = ModelConfig(
@@ -86,9 +88,12 @@ def test_rope_scaling():
         dtype="float16",
         seed=0,
         rope_scaling=TEST_ROPE_SCALING,
+        rope_theta=TEST_ROPE_THETA,
     )
     assert getattr(llama_model_config.hf_config, "rope_scaling",
                    None) == TEST_ROPE_SCALING
+    assert getattr(llama_model_config.hf_config, "rope_theta",
+                   None) == TEST_ROPE_THETA
     assert llama_model_config.max_model_len == 16384
 
     longchat_model_config = ModelConfig(

diff --git a/vllm/config.py b/vllm/config.py
index c07597b5b74a7..7ffb93c19ede9 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -93,6 +93,7 @@ def __init__(
         revision: Optional[str] = None,
         code_revision: Optional[str] = None,
         rope_scaling: Optional[dict] = None,
+        rope_theta: Optional[float] = None,
         tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
@@ -113,6 +114,7 @@ def __init__(
         self.revision = revision
         self.code_revision = code_revision
         self.rope_scaling = rope_scaling
+        self.rope_theta = rope_theta
         # The tokenizer version is consistent with the model version by default.
         if tokenizer_revision is None:
             self.tokenizer_revision = revision
@@ -132,7 +134,7 @@ def __init__(
         self.skip_tokenizer_init = skip_tokenizer_init
 
         self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision, rope_scaling)
+                                    code_revision, rope_scaling, rope_theta)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.max_model_len = _get_and_verify_max_len(

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index f87ee13091187..cd29db7d7a9ed 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -53,6 +53,7 @@ class EngineArgs:
     revision: Optional[str] = None
     code_revision: Optional[str] = None
     rope_scaling: Optional[dict] = None
+    rope_theta: Optional[float] = None
    tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     enforce_eager: bool = False
@@ -400,6 +401,12 @@ def add_cli_args(
                             type=json.loads,
                             help='RoPE scaling configuration in JSON format. '
                             'For example, {"type":"dynamic","factor":2.0}')
+        parser.add_argument('--rope-theta',
+                            default=None,
+                            type=float,
+                            help='RoPE theta. Use with `rope_scaling`. In '
+                            'some cases, changing the RoPE theta improves the '
+                            'performance of the scaled model.')
         parser.add_argument('--enforce-eager',
                             action='store_true',
                             help='Always use eager-mode PyTorch. If False, '
@@ -630,6 +637,7 @@ def create_engine_config(self, ) -> EngineConfig:
             revision=self.revision,
             code_revision=self.code_revision,
             rope_scaling=self.rope_scaling,
+            rope_theta=self.rope_theta,
             tokenizer_revision=self.tokenizer_revision,
             max_model_len=self.max_model_len,
             quantization=self.quantization,

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index cb5893e707c8b..4f56bbd5c2dc5 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -162,7 +162,7 @@ def __init__(
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
-            "rope_scaling=%r, tokenizer_revision=%s, "
+            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
             "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
             "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
             "disable_custom_all_reduce=%s, quantization=%s, "
@@ -177,6 +177,7 @@ def __init__(
             model_config.tokenizer_mode,
             model_config.revision,
             model_config.rope_scaling,
+            model_config.rope_theta,
             model_config.tokenizer_revision,
             model_config.trust_remote_code,
             model_config.dtype,

diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index e971df16a46d4..ada84018212a0 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -23,7 +23,8 @@ def get_config(model: str,
                trust_remote_code: bool,
                revision: Optional[str] = None,
                code_revision: Optional[str] = None,
-               rope_scaling: Optional[dict] = None) -> PretrainedConfig:
+               rope_scaling: Optional[dict] = None,
+               rope_theta: Optional[float] = None) -> PretrainedConfig:
     try:
         if VLLM_USE_MODELSCOPE:
             from modelscope import AutoConfig
@@ -50,10 +51,12 @@ def get_config(model: str,
         config = config_class.from_pretrained(model,
                                               revision=revision,
                                               code_revision=code_revision)
-    if rope_scaling is not None:
-        logger.info("Updating rope_scaling from %r to %r",
-                    getattr(config, "rope_scaling", None), rope_scaling)
-        config.update({"rope_scaling": rope_scaling})
+    for key, value in [("rope_scaling", rope_scaling),
+                       ("rope_theta", rope_theta)]:
+        if value is not None:
+            logger.info("Updating %s from %r to %r", key,
+                        getattr(config, key, None), value)
+            config.update({key: value})
     return config
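
Usage note (not part of the patch): the sketch below shows one way the new rope_theta override added above might be exercised, through the Python API that EngineArgs backs and through the matching --rope-theta CLI flag added in this change. It assumes a vLLM build that includes this patch; the model name and the scaling/theta values are illustrative placeholders, not values taken from the patch.

# Minimal sketch, assuming vLLM with this patch applied.
# The model name and numeric values below are hypothetical examples.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",      # placeholder model
    rope_scaling={"type": "dynamic", "factor": 2.0},  # passed through to the HF config
    rope_theta=1_000_000.0,                           # overrides rope_theta in the HF config
    max_model_len=16384,
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)

# Roughly equivalent server invocation using the new CLI flag:
#   python -m vllm.entrypoints.openai.api_server \
#       --model meta-llama/Meta-Llama-3-8B-Instruct \
#       --rope-scaling '{"type":"dynamic","factor":2.0}' \
#       --rope-theta 1000000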