Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Customizable RoPE theta #5197

Merged
merged 4 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ def test_get_sliding_window():
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW


def test_rope_scaling():
def test_rope_customization():
TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
TEST_ROPE_THETA = 16_000_000.0
LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}

llama_model_config = ModelConfig(
Expand All @@ -76,6 +77,7 @@ def test_rope_scaling():
seed=0,
)
assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000
assert llama_model_config.max_model_len == 8192

llama_model_config = ModelConfig(
Expand All @@ -86,9 +88,12 @@ def test_rope_scaling():
dtype="float16",
seed=0,
rope_scaling=TEST_ROPE_SCALING,
rope_theta=TEST_ROPE_THETA,
)
assert getattr(llama_model_config.hf_config, "rope_scaling",
None) == TEST_ROPE_SCALING
assert getattr(llama_model_config.hf_config, "rope_theta",
None) == TEST_ROPE_THETA
assert llama_model_config.max_model_len == 16384

longchat_model_config = ModelConfig(
Expand Down
4 changes: 3 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
Expand All @@ -113,6 +114,7 @@ def __init__(
self.revision = revision
self.code_revision = code_revision
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
self.tokenizer_revision = revision
Expand All @@ -132,7 +134,7 @@ def __init__(
self.skip_tokenizer_init = skip_tokenizer_init

self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, rope_scaling)
code_revision, rope_scaling, rope_theta)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.max_model_len = _get_and_verify_max_len(
Expand Down
8 changes: 8 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class EngineArgs:
revision: Optional[str] = None
code_revision: Optional[str] = None
rope_scaling: Optional[dict] = None
rope_theta: Optional[float] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: bool = False
Expand Down Expand Up @@ -400,6 +401,12 @@ def add_cli_args(
type=json.loads,
help='RoPE scaling configuration in JSON format. '
'For example, {"type":"dynamic","factor":2.0}')
parser.add_argument('--rope-theta',
default=None,
type=float,
help='RoPE theta. Use with `rope_scaling`. In '
'some cases, changing the RoPE theta improves the '
'performance of the scaled model.')
parser.add_argument('--enforce-eager',
action='store_true',
help='Always use eager-mode PyTorch. If False, '
Expand Down Expand Up @@ -630,6 +637,7 @@ def create_engine_config(self, ) -> EngineConfig:
revision=self.revision,
code_revision=self.code_revision,
rope_scaling=self.rope_scaling,
rope_theta=self.rope_theta,
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
quantization=self.quantization,
Expand Down
3 changes: 2 additions & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def __init__(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"rope_scaling=%r, tokenizer_revision=%s, "
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"disable_custom_all_reduce=%s, quantization=%s, "
Expand All @@ -177,6 +177,7 @@ def __init__(
model_config.tokenizer_mode,
model_config.revision,
model_config.rope_scaling,
model_config.rope_theta,
model_config.tokenizer_revision,
model_config.trust_remote_code,
model_config.dtype,
Expand Down
15 changes: 9 additions & 6 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def get_config(model: str,
trust_remote_code: bool,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None) -> PretrainedConfig:
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None) -> PretrainedConfig:
try:
if VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
Expand All @@ -50,10 +51,12 @@ def get_config(model: str,
config = config_class.from_pretrained(model,
revision=revision,
code_revision=code_revision)
if rope_scaling is not None:
logger.info("Updating rope_scaling from %r to %r",
getattr(config, "rope_scaling", None), rope_scaling)
config.update({"rope_scaling": rope_scaling})
for key in ["rope_scaling", "rope_theta"]:
if locals()[key] is not None:
logger.info("Updating %s from %r to %r", key,
getattr(config, key, None),
locals()[key])
config.update({key: locals()[key]})
sasha0552 marked this conversation as resolved.
Show resolved Hide resolved
return config


Expand All @@ -68,4 +71,4 @@ def get_hf_text_config(config: PretrainedConfig):
assert hasattr(config.text_config, "num_attention_heads")
return config.text_config
else:
return config
return config
Loading