From fcc5187a409bf2aaa3a60ed6aade0d06a86952a2 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Wed, 9 Oct 2024 09:21:28 -0500 Subject: [PATCH] Use more realistic RoPE tests --- tests/test_rope.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/test_rope.py b/tests/test_rope.py index 14ea33c0aa..7293e52fa7 100644 --- a/tests/test_rope.py +++ b/tests/test_rope.py @@ -110,7 +110,7 @@ def test_rope_llama_3(): # See https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json for settings @torch.inference_mode() def test_rope_llama_3_1(): - head_dim = 128 + head_dim = 32 rope_theta = 50_000 their_rope_config = { @@ -130,7 +130,8 @@ def test_rope_llama_3_1(): config = LlamaConfig( rope_theta=rope_theta, - rope_scaling=their_rope_config + rope_scaling=their_rope_config, + head_dim=head_dim ) ################################## @@ -138,7 +139,7 @@ def test_rope_llama_3_1(): ################################## # transformer rope rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3") - batch_size, seq_len = 1, 10 + batch_size, seq_len = 1, 131_072 qk_tensor = torch.randn(batch_size, seq_len, head_dim) position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0) theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids) @@ -169,7 +170,7 @@ def test_rope_llama_3_1(): # See https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json for settings @torch.inference_mode() def test_rope_llama_3_2(): - head_dim = 128 + head_dim = 32 rope_theta = 50_000 their_rope_config = { @@ -189,7 +190,8 @@ def test_rope_llama_3_2(): config = LlamaConfig( rope_theta=rope_theta, - rope_scaling=their_rope_config + rope_scaling=their_rope_config, + head_dim=head_dim ) ################################## @@ -197,7 +199,7 @@ def test_rope_llama_3_2(): ################################## # transformer rope rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3") - batch_size, seq_len = 1, 10 + batch_size, seq_len = 1, 131_072 qk_tensor = torch.randn(batch_size, seq_len, head_dim) position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0) theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids) @@ -222,4 +224,5 @@ def test_rope_llama_3_2(): ours_k_rot = apply_rope(keys, ours_cos, ours_sin) theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin) torch.testing.assert_close(theirs_q_rot, ours_q_rot) - torch.testing.assert_close(theirs_k_rot, ours_k_rot) \ No newline at end of file + torch.testing.assert_close(theirs_k_rot, ours_k_rot) +