From 61cabec535464af825424b84e9055e30a9ee06a3 Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Wed, 17 Apr 2024 12:24:43 -0700
Subject: [PATCH] move mask as sdpa input instead of attribute (#3036)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/3036

sdpa (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) takes the attention mask as an input, so refactor the SDPA module's inputs to more closely match the sdpa signature.

ghstack-source-id: 222650466
exported-using-ghexport

Reviewed By: mergennachin

Differential Revision: D56119739

fbshipit-source-id: d9adda66e540abc518b7ffb6a5ebd2aab1626b3b
(cherry picked from commit b341223dcfaeb6ae451060e5925aec01ee01d340)
---
 examples/models/llama2/export_llama_lib.py  |  5 ++---
 examples/models/llama2/llama_transformer.py | 19 ++++++++++---------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index cb166445a8..890c909f66 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -96,12 +96,10 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         kv_cache: KVCache,
-        mask,
         dim: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
-        self.mask = mask
         self.dim = dim
 
     def forward(
@@ -112,6 +110,7 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
+        mask,
     ):
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
@@ -131,7 +130,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPACustom(child.kv_cache, child.mask, child.dim),
+                SDPACustom(child.kv_cache, child.dim),
             )
         else:
             _replace_sdpa_with_custom_op(child)
diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index f16ec2bdec..189280bb8a 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -197,14 +197,14 @@ class SDPA(nn.Module):
     def __init__(
         self,
         kv_cache: KVCache,
-        mask,
         dim: int,
+        head_dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
-        self.mask = mask
         self.dim = dim
+        self.head_dim = head_dim
         self.n_rep = n_rep
 
     def forward(
@@ -215,17 +215,18 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
+        mask: torch.Tensor,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
 
         k, v = self.kv_cache.update(input_pos, k, v)
-        mask = self.mask[None, None, input_pos]
+        attn_mask = mask[None, None, input_pos]
 
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
 
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
@@ -271,10 +272,10 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
             )
             self.SDPA = SDPA(
-                self.kv_cache,
-                self.mask,
-                self.dim,
-                self.n_rep,
+                kv_cache=self.kv_cache,
+                dim=self.dim,
+                head_dim=self.head_dim,
+                n_rep=self.n_rep,
             )
 
     def forward(
@@ -298,7 +299,7 @@ def forward(
 
         if self.use_kv_cache:
             assert input_pos is not None
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)
 
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
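
Note: the following is an illustrative, simplified sketch (not part of the patch) of the calling convention this change moves to: the attention mask is supplied as a forward() argument, mirroring the attn_mask parameter of torch.nn.functional.scaled_dot_product_attention, instead of being stored on the module at construction time. The module and variable names below are hypothetical, and the KV-cache handling shown in the diff is omitted.

import torch
import torch.nn.functional as F


class MaskAsInputSDPA(torch.nn.Module):
    """Hypothetical stand-in for the refactored SDPA: mask is a forward() input."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim  # model dimension used to reshape the attention output

    def forward(self, q, k, v, bsz, seqlen, mask):
        # q, k, v: (bsz, n_heads, seqlen, head_dim); the mask is passed per call,
        # matching the attn_mask argument of F.scaled_dot_product_attention.
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


bsz, n_heads, seqlen, head_dim = 1, 4, 8, 16
q = k = v = torch.randn(bsz, n_heads, seqlen, head_dim)
causal_mask = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))
out = MaskAsInputSDPA(dim=n_heads * head_dim)(q, k, v, bsz, seqlen, causal_mask)
print(out.shape)  # torch.Size([1, 8, 64])

With this shape of API, the caller (Attention.forward in the patch) owns the mask and passes self.mask explicitly on each call, keeping the SDPA module stateless with respect to the mask.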