Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Rename Phi3 rope scaling type #5595

Merged
merged 3 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,10 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len

rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None and rope_scaling["type"] != "su":
# The correct one should be "longrope", kept "su" here
# to be backward compatible
if rope_scaling is not None and rope_scaling["type"] != "su" \
and rope_scaling["type"] != "longrope":
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
Expand Down
19 changes: 12 additions & 7 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
return cache


class Phi3SuScaledRotaryEmbedding(nn.Module):
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
"""Phi3 family of models scaled rotary embedding.

Based on the original RotaryEmbedding implementation.
Expand All @@ -491,11 +491,12 @@ def __init__(

if rotary_dim != head_size:
raise ValueError(
f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \
head_size ({rotary_dim}!={head_size}).")
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
rotary_dim != head_size ({rotary_dim}!={head_size}).")
if is_neox_style is False:
raise ValueError(
"`Phi3SuScaledRotaryEmbedding` only supports neox_style.")
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
)

self.head_size = head_size
self.max_position_embeddings = max_position_embeddings
Expand Down Expand Up @@ -608,7 +609,9 @@ def get_rope(
is_neox_style, dtype)
else:
scaling_type = rope_scaling["type"]
if scaling_type != "su":
# The correct one should be "longrope" but keep "su" here
# for backward compatible
if scaling_type != "su" and scaling_type != "longrope":
scaling_factor = rope_scaling["factor"]
if scaling_type == "linear":
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
Expand All @@ -633,7 +636,9 @@ def get_rope(
base, is_neox_style,
scaling_factor, dtype,
**extra_kwargs)
elif scaling_type == "su":
# The correct one should be "longrope" but keep "su" here
# for backward compatible
elif scaling_type == "su" or scaling_type == "longrope":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling[
Expand All @@ -643,7 +648,7 @@ def get_rope(
for k, v in rope_scaling.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3SuScaledRotaryEmbedding(
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs)
Expand Down
Loading