Add example for showcasing how to do multi-latent Attention
stack-info: PR: #113, branch: drisspg/stack/6
Showing 2 changed files with 546 additions and 0 deletions.
@@ -0,0 +1,79 @@
"""Implementation of Multi-head Latent Attention (MLA) RoPE score modification from DeepSeek-V2.

Reference: https://arxiv.org/pdf/2405.04434 - DeepSeek-V2: A Strong, Economical, and
Efficient Mixture-of-Experts Language Model
"""

import torch
from torch import Tensor
from torch.nn.attention.flex_attention import _score_mod_signature


def generate_mla_rope_score_mod(
    query_rope: Tensor,
    key_rope: Tensor,
    num_heads: int,
    scale: float = 1.0,
) -> _score_mod_signature:
"""Returns an MLA RoPE score modification function to be used w/ FlexAttention | ||
Args: | ||
query_pe: Positional embeddings for queries [batch, num_heads, seq_len, head_dim] | ||
key_pe: Positional embeddings for keys [batch, num_heads//128, seq_len, head_dim] | ||
num_heads: The number of query heads | ||
scale: Scaling factor for the positional embedding contribution | ||
use_vmap: Whether to use vectorized operations (recommended for training) | ||
Returns: | ||
mla_rope_score_mod: Score modification function for FlexAttention | ||
""" | ||
|
||
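    # MLA uses a decoupled RoPE: FlexAttention computes the content score from the
    # latent query/key, and this score_mod adds the positional term on top:
    #   score = q_latent . k_latent + scale * (q_rope . k_rope)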
    def mla_rope_score_mod(
        score: Tensor, b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor
    ) -> Tensor:
        # key_rope has a single shared head, so h // num_heads maps every query head to index 0.
        return score + (
            scale * torch.dot(query_rope[b, h, q_idx], key_rope[b, h // num_heads, kv_idx])
        )

    mla_rope_score_mod.__name__ = f"mla_rope_score_mod_scale_{scale}"
    return mla_rope_score_mod


def main(device: str = "cuda"):
    """Visualize the attention scores with MLA RoPE modification.

    Args:
        device: Device to use for computation
    """
    from attn_gym import visualize_attention_scores

    # Example dimensions
    B, H, SEQ_LEN, LATENT_HEAD_DIM = 1, 128, 8, 512
    ROPE_HEAD_DIM = 64

    # Create random tensors for visualization
    query = torch.rand(B, H, SEQ_LEN, LATENT_HEAD_DIM, device=device)
    key = torch.rand(B, 1, SEQ_LEN, LATENT_HEAD_DIM, device=device)

    # Create positional embeddings
    query_pe = torch.rand(B, H, SEQ_LEN, ROPE_HEAD_DIM, device=device)
    key_pe = torch.rand(B, 1, SEQ_LEN, ROPE_HEAD_DIM, device=device)

    # Generate the score modification function
    mla_rope_score_mod = generate_mla_rope_score_mod(
        query_rope=query_pe, key_rope=key_pe, num_heads=H
    )

    # Visualize attention scores with MLA RoPE modification
    visualize_attention_scores(
        query, key, score_mod=mla_rope_score_mod, device=device, name="mla_rope_score_mod"
    )


if __name__ == "__main__":
    try:
        from jsonargparse import CLI
    except ImportError:
        raise ImportError("Be sure to run: pip install -e .'[viz]'")
    CLI(main)
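For reference, here is a minimal sketch (not part of this commit) of how the generated score_mod could be plugged into FlexAttention itself. It assumes the dimensions from the example above, a recent PyTorch that provides torch.nn.attention.flex_attention, and that generate_mla_rope_score_mod is in scope; the variable names are illustrative, and enable_gqa=True is used so the single latent KV head is broadcast across all 128 query heads.

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, LATENT_HEAD_DIM, ROPE_HEAD_DIM = 1, 128, 8, 512, 64
device = "cuda"

# Content (latent) projections: one shared KV head, H query heads.
query = torch.rand(B, H, S, LATENT_HEAD_DIM, device=device)
key = torch.rand(B, 1, S, LATENT_HEAD_DIM, device=device)
value = torch.rand(B, 1, S, LATENT_HEAD_DIM, device=device)

# Decoupled RoPE components consumed by the score_mod.
query_pe = torch.rand(B, H, S, ROPE_HEAD_DIM, device=device)
key_pe = torch.rand(B, 1, S, ROPE_HEAD_DIM, device=device)

score_mod = generate_mla_rope_score_mod(query_rope=query_pe, key_rope=key_pe, num_heads=H)

# enable_gqa=True lets the single KV head serve every query head.
out = flex_attention(query, key, value, score_mod=score_mod, enable_gqa=True)
print(out.shape)  # torch.Size([1, 128, 8, 512])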