diff --git a/examples/models/llama2/builder.py b/examples/models/llama2/builder.py
index 35577ad3ec..00d71a5b01 100644
--- a/examples/models/llama2/builder.py
+++ b/examples/models/llama2/builder.py
@@ -206,11 +206,7 @@ def source_transform(
     def _get_dynamic_shape(self) -> Any:
         dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1)
         if self.use_kv_cache:
-            if self.use_sdpa_with_kv_cache:
-                return None
-            else:
-                # return {1: dim}, {0: dim}} TODO update xnnpack to be able to handle dynamic shape kv cache
-                return None
+            return None
         else:
             return ({1: dim},)
 
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 76cfd00f3b..8728b3fdd2 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -492,6 +492,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
+    if args.use_sdpa_with_kv_cache:
+        pass
+        # TODO: Next diff transforms.append()
+
     return (
         load_llama_model(
             checkpoint=checkpoint_path,
diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 66fc47b17f..d0794b8c37 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -209,6 +209,95 @@ def update(
         return k_out, v_out
 
+
+class SDPA(nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        use_sdpa_with_kv_cache_op: bool,
+        dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
+        self.dim = dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        if not self.use_sdpa_with_kv_cache_op:
+            return self._forward_default(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+        else:
+            return self._forward_custom(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+
+    def _forward_custom(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        from .custom_ops import sdpa_with_kv_cache  # noqa
+
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+    def _forward_default(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        mask = self.mask[None, None, input_pos]
+
+        k = k.repeat_interleave(self.n_rep, dim=1)
+        v = v.repeat_interleave(self.n_rep, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
 
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
@@ -229,7 +318,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
 
-        self.use_sdpa_with_kv_cache_op = args.use_sdpa_with_kv_cache_op
         self.layer_id = layer_id
 
         causal_mask = torch.tril(
@@ -250,6 +338,13 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 self.head_dim,
                 not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
             )
+            self.SDPA = SDPA(
+                self.kv_cache,
+                self.mask,
+                args.use_sdpa_with_kv_cache_op,
+                self.dim,
+                self.n_rep,
+            )
 
     def forward(
         self,
@@ -272,41 +367,8 @@ def forward(
 
         if self.use_kv_cache:
             assert input_pos is not None
-
-            if not self.use_sdpa_with_kv_cache_op:
-
-                q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-                k = k.transpose(1, 2)
-                v = v.transpose(1, 2)
-
-                k, v = self.kv_cache.update(input_pos, k, v)
-                mask = self.mask[None, None, input_pos]
-
-                k = k.repeat_interleave(self.n_rep, dim=1)
-                v = v.repeat_interleave(self.n_rep, dim=1)
-                y = F.scaled_dot_product_attention(
-                    q, k, v, attn_mask=mask, dropout_p=0.0
-                )
-
-                y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
-                y = self.wo(y)
-                return y
-            else:
-                from .custom_ops import sdpa_with_kv_cache  # noqa
-
-                output = torch.ops.llama.sdpa_with_kv_cache(
-                    q,
-                    k,
-                    v,
-                    self.kv_cache.k_cache,
-                    self.kv_cache.v_cache,
-                    input_pos[-1].item(),
-                    seqlen,
-                )
-                output = output.view(bsz, seqlen, -1)
-                output = self.wo(output)
-                return output
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            return self.wo(output)
 
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py
index 5428e34b74..31931adb38 100644
--- a/examples/models/llama2/model.py
+++ b/examples/models/llama2/model.py
@@ -209,11 +209,7 @@ def get_eager_model(self):
 
     def get_example_inputs(self):
         if self.use_kv_cache:
-            if self.use_sdpa_with_kv_cache_op:
-                return self.get_example_inputs_kvcache_sdpa()
-            else:
-                # return self.get_example_inputs_kvcache() TODO xnnpack does not handle forwarding symints, update partitioner to not partition symints
-                return self.get_example_inputs_kvcache_sdpa()
+            return self.get_example_inputs_kvcache_sdpa()
         else:
             return (
                 torch.tensor(
@@ -231,13 +227,3 @@ def get_example_inputs_kvcache_sdpa(self):
                 [0], dtype=torch.long
             ),  # start_pos, what token of output are we on.)
         )
-
-    def get_example_inputs_kvcache(self):
-        return (
-            torch.tensor(
-                [[1, 2, 3]], dtype=torch.long
-            ),  # tokens, with kv cache our input token length is always just 1 token.
-            torch.tensor(
-                [0, 1, 2], dtype=torch.long
-            ),  # start_pos, what token of output are we on.
-        )
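
Usage sketch (not part of the patch): the refactor above moves both KV-cache attention paths out of Attention.forward and into the new SDPA module, which dispatches on use_sdpa_with_kv_cache_op between the eager F.scaled_dot_product_attention path and the custom torch.ops.llama.sdpa_with_kv_cache op. The snippet below exercises only the eager path; the StubKVCache class and the import path are illustrative assumptions, not code from the repository.

import torch
import torch.nn as nn

# Hypothetical import path; point it at wherever llama_transformer.py is importable from.
from examples.models.llama2.llama_transformer import SDPA


class StubKVCache(nn.Module):
    # Hypothetical stand-in exposing only the interface SDPA relies on:
    # k_cache/v_cache tensors plus an update() that returns both full caches.
    def __init__(self, bsz, n_heads, max_seq_len, head_dim):
        super().__init__()
        self.k_cache = torch.zeros(bsz, n_heads, max_seq_len, head_dim)
        self.v_cache = torch.zeros(bsz, n_heads, max_seq_len, head_dim)

    def update(self, input_pos, k_val, v_val):
        # Write the new entries at input_pos along the sequence dimension,
        # mirroring the contract of KVCache.update in llama_transformer.py.
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


bsz, n_heads, head_dim, max_seq_len = 1, 4, 16, 32
seqlen = 1

sdpa = SDPA(
    kv_cache=StubKVCache(bsz, n_heads, max_seq_len, head_dim),
    mask=torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)),
    use_sdpa_with_kv_cache_op=False,  # eager path; True routes to torch.ops.llama.sdpa_with_kv_cache
    dim=n_heads * head_dim,
    n_rep=1,  # no grouped-query replication in this toy config
)

# q/k/v arrive untransposed as (bsz, seqlen, n_heads, head_dim), matching Attention.forward.
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)
v = torch.randn(bsz, seqlen, n_heads, head_dim)
input_pos = torch.tensor([0], dtype=torch.long)

out = sdpa(input_pos, q, k, v, bsz, seqlen)
print(out.shape)  # torch.Size([1, 1, 64])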