Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added query-key norm to accomodate OLMo2 #1894

Merged
merged 6 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions litgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Config:
# Transformer block (structure, normalizations)
norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
norm_eps: float = 1e-5
norm_qk: bool = False
post_attention_norm: bool = False
post_mlp_norm: bool = False
parallel_residual: bool = True
Expand Down
15 changes: 15 additions & 0 deletions litgpt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,12 @@ def __init__(self, config: Config, block_idx: int) -> None:
block_idx % config.sliding_window_layer_stride == 0
)

if config.norm_qk:
self.norm_q = config.norm_class(config.head_size * config.n_head, eps=config.norm_eps)
self.norm_k = config.norm_class(config.head_size * config.n_query_groups, eps=config.norm_eps)
else:
self.norm_q = self.norm_k = None

self.config = config

def forward(
Expand Down Expand Up @@ -324,6 +330,15 @@ def forward(
k = k.view(B, T, self.config.n_query_groups, self.config.head_size) # (B, T, nh_k, hs)
v = v.view(B, T, self.config.n_query_groups, self.config.head_size) # (B, T, nh_v, hs)

if self.config.norm_qk:
q = q.reshape(B, T, -1) # (B, T, nh_q * hs)
q = self.norm_q(q)
q = q.view(B, T, self.config.n_head, self.config.head_size)

k = k.reshape(B, T, -1) # (B, T, nh_k * hs)
k = self.norm_k(k)
k = k.view(B, T, self.config.n_query_groups, self.config.head_size)

ysjprojects marked this conversation as resolved.
Show resolved Hide resolved
# The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are
# multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector
# of size `hs`.
Expand Down
Loading