Skip to content

Commit

Permalink
[Frontend] Customizable RoPE theta (vllm-project#5197)
Browse files Browse the repository at this point in the history
  • Loading branch information
sasha0552 authored and jimpang committed Jun 27, 2024
1 parent 79fc9ca commit 226e869
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 8 deletions.
7 changes: 6 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ def test_get_sliding_window():
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW


def test_rope_scaling():
def test_rope_customization():
TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
TEST_ROPE_THETA = 16_000_000.0
LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}

llama_model_config = ModelConfig(
Expand All @@ -76,6 +77,7 @@ def test_rope_scaling():
seed=0,
)
assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000
assert llama_model_config.max_model_len == 8192

llama_model_config = ModelConfig(
Expand All @@ -86,9 +88,12 @@ def test_rope_scaling():
dtype="float16",
seed=0,
rope_scaling=TEST_ROPE_SCALING,
rope_theta=TEST_ROPE_THETA,
)
assert getattr(llama_model_config.hf_config, "rope_scaling",
None) == TEST_ROPE_SCALING
assert getattr(llama_model_config.hf_config, "rope_theta",
None) == TEST_ROPE_THETA
assert llama_model_config.max_model_len == 16384

longchat_model_config = ModelConfig(
Expand Down
4 changes: 3 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
Expand All @@ -113,6 +114,7 @@ def __init__(
self.revision = revision
self.code_revision = code_revision
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
self.tokenizer_revision = revision
Expand All @@ -132,7 +134,7 @@ def __init__(
self.skip_tokenizer_init = skip_tokenizer_init

self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, rope_scaling)
code_revision, rope_scaling, rope_theta)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.max_model_len = _get_and_verify_max_len(
Expand Down
8 changes: 8 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class EngineArgs:
revision: Optional[str] = None
code_revision: Optional[str] = None
rope_scaling: Optional[dict] = None
rope_theta: Optional[float] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: bool = False
Expand Down Expand Up @@ -400,6 +401,12 @@ def add_cli_args(
type=json.loads,
help='RoPE scaling configuration in JSON format. '
'For example, {"type":"dynamic","factor":2.0}')
parser.add_argument('--rope-theta',
default=None,
type=float,
help='RoPE theta. Use with `rope_scaling`. In '
'some cases, changing the RoPE theta improves the '
'performance of the scaled model.')
parser.add_argument('--enforce-eager',
action='store_true',
help='Always use eager-mode PyTorch. If False, '
Expand Down Expand Up @@ -630,6 +637,7 @@ def create_engine_config(self, ) -> EngineConfig:
revision=self.revision,
code_revision=self.code_revision,
rope_scaling=self.rope_scaling,
rope_theta=self.rope_theta,
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
quantization=self.quantization,
Expand Down
3 changes: 2 additions & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def __init__(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"rope_scaling=%r, tokenizer_revision=%s, "
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"disable_custom_all_reduce=%s, quantization=%s, "
Expand All @@ -177,6 +177,7 @@ def __init__(
model_config.tokenizer_mode,
model_config.revision,
model_config.rope_scaling,
model_config.rope_theta,
model_config.tokenizer_revision,
model_config.trust_remote_code,
model_config.dtype,
Expand Down
13 changes: 8 additions & 5 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def get_config(model: str,
trust_remote_code: bool,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None) -> PretrainedConfig:
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None) -> PretrainedConfig:
try:
if VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
Expand All @@ -50,10 +51,12 @@ def get_config(model: str,
config = config_class.from_pretrained(model,
revision=revision,
code_revision=code_revision)
if rope_scaling is not None:
logger.info("Updating rope_scaling from %r to %r",
getattr(config, "rope_scaling", None), rope_scaling)
config.update({"rope_scaling": rope_scaling})
for key, value in [("rope_scaling", rope_scaling),
("rope_theta", rope_theta)]:
if value is not None:
logger.info("Updating %s from %r to %r", key,
getattr(config, key, None), value)
config.update({key: value})
return config


Expand Down

0 comments on commit 226e869

Please sign in to comment.