From 0b4629ad9abeac932a531a36c75ec89d7b831afd Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 28 Dec 2024 06:14:27 -0500 Subject: [PATCH 1/3] added query-key norm to accomodate OLMo2 --- litgpt/config.py | 1 + litgpt/model.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/litgpt/config.py b/litgpt/config.py index a4a70c8238..133a9247a1 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -25,6 +25,7 @@ class Config: # Transformer block (structure, normalizations) norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" norm_eps: float = 1e-5 + norm_qk: bool = False post_attention_norm: bool = False post_mlp_norm: bool = False parallel_residual: bool = True diff --git a/litgpt/model.py b/litgpt/model.py index cbdf2a4bdd..ca52a1cec2 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -287,6 +287,12 @@ def __init__(self, config: Config, block_idx: int) -> None: block_idx % config.sliding_window_layer_stride == 0 ) + if config.norm_qk: + self.norm_q = config.norm_class(config.head_size * config.n_head, eps=config.norm_eps) + self.norm_k = config.norm_class(config.head_size * config.n_query_groups, eps=config.norm_eps) + else: + self.norm_q = self.norm_k = None + self.config = config def forward( @@ -324,6 +330,15 @@ def forward( k = k.view(B, T, self.config.n_query_groups, self.config.head_size) # (B, T, nh_k, hs) v = v.view(B, T, self.config.n_query_groups, self.config.head_size) # (B, T, nh_v, hs) + if self.config.norm_qk: + q = q.reshape(B, T, -1) # (B, T, nh_q * hs) + q = self.norm_q(q) + q = q.view(B, T, self.config.n_head, self.config.head_size) + + k = k.reshape(B, T, -1) # (B, T, nh_k * hs) + k = self.norm_k(k) + k = k.view(B, T, self.config.n_query_groups, self.config.head_size) + # The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are # multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector # of size `hs`. From 5f9df570e4513f3b799b6fa5c4b53de90333a27f Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Mon, 30 Dec 2024 19:24:57 +0300 Subject: [PATCH 2/3] Add rerun on failures for test_readme/download[model,books] --- tests/test_readme.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_readme.py b/tests/test_readme.py index 65c21dcde2..95b03e1474 100644 --- a/tests/test_readme.py +++ b/tests/test_readme.py @@ -32,6 +32,7 @@ def run_command(command): @pytest.mark.dependency() +@pytest.mark.flaky(reruns=5, reruns_delay=2) def test_download_model(): repo_id = str(REPO_ID).replace("\\", "/") # fix for Windows CI command = ["litgpt", "download", str(repo_id)] @@ -48,6 +49,7 @@ def test_download_model(): @pytest.mark.dependency() +@pytest.mark.flaky(reruns=5, reruns_delay=2) def test_download_books(): CUSTOM_TEXTS_DIR.mkdir(parents=True, exist_ok=True) From 395f99f45ca097be5f36e36493bf83022365068d Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Wed, 1 Jan 2025 09:55:33 -0500 Subject: [PATCH 3/3] refactoring --- litgpt/model.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/litgpt/model.py b/litgpt/model.py index ca52a1cec2..7f827a9b36 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -324,21 +324,16 @@ def forward( # Split qkv into query, key and value matrices. q, k, v = qkv.split((query_size, key_size, value_size), dim=-1) # 3x(B, T, C*) + if self.config.norm_qk: + q = self.norm_q(q) + k = self.norm_k(k) + # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the # embedding size (C) into num_heads (nh) and head_size (hs). q = q.view(B, T, self.config.n_head, self.config.head_size) # (B, T, nh_q, hs) k = k.view(B, T, self.config.n_query_groups, self.config.head_size) # (B, T, nh_k, hs) v = v.view(B, T, self.config.n_query_groups, self.config.head_size) # (B, T, nh_v, hs) - if self.config.norm_qk: - q = q.reshape(B, T, -1) # (B, T, nh_q * hs) - q = self.norm_q(q) - q = q.view(B, T, self.config.n_head, self.config.head_size) - - k = k.reshape(B, T, -1) # (B, T, nh_k * hs) - k = self.norm_k(k) - k = k.view(B, T, self.config.n_query_groups, self.config.head_size) - # The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are # multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector # of size `hs`.