Decouple custom ops in llama_transformer.py Part 1/N (pytorch#3005)
Summary:
This is a no-op refactor: the scaled-dot-product-attention logic (both the default eager path and the custom `sdpa_with_kv_cache` op path) moves out of `Attention.forward` into a dedicated `SDPA` module, with no change in behavior.

Pull Request resolved: pytorch#3005

Test Plan:
CI

Run with

`python -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -kv --use_sdpa_with_kv_cache -X`

and with

`python -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -kv -X`

and make sure both work.

Reviewed By: cccclai

Differential Revision: D56048177

Pulled By: mergennachin

fbshipit-source-id: 3ac9ac5c34f6fe215de1cfe8b5ddc7aae3635359
mergennachin authored and facebook-github-bot committed Apr 12, 2024
1 parent b1edc3d commit 488afc5
Showing 4 changed files with 104 additions and 56 deletions.
6 changes: 1 addition & 5 deletions examples/models/llama2/builder.py
@@ -206,11 +206,7 @@ def source_transform(
     def _get_dynamic_shape(self) -> Any:
         dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1)
         if self.use_kv_cache:
-            if self.use_sdpa_with_kv_cache:
-                return None
-            else:
-                # return {1: dim}, {0: dim}} TODO update xnnpack to be able to handle dynamic shape kv cache
-                return None
+            return None
         else:
             return ({1: dim},)

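For readers of the `_get_dynamic_shape` hunk above: the `({1: dim},)` tuple is a `dynamic_shapes` spec for `torch.export`. Below is a minimal, self-contained sketch of how such a spec is consumed; `TokenModel` and the sizes are made up for illustration and are not code from this repo.

import torch

class TokenModel(torch.nn.Module):
    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return tokens * 2

model = TokenModel()
example_inputs = (torch.zeros(1, 8, dtype=torch.long),)
token_dim = torch.export.Dim("token_dim", max=127)
# ({1: token_dim},) marks dim 1 of the first positional input as dynamic,
# mirroring the ({1: dim},) returned for the non-kv-cache path above.
exported = torch.export.export(model, example_inputs, dynamic_shapes=({1: token_dim},))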
4 changes: 4 additions & 0 deletions examples/models/llama2/export_llama_lib.py
@@ -492,6 +492,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

+    if args.use_sdpa_with_kv_cache:
+        pass
+        # TODO: Next diff transforms.append()
+
     return (
         load_llama_model(
             checkpoint=checkpoint_path,
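The `transforms` list in the hunk above is applied by the builder's `source_transform` step; each entry is expected to be a callable that takes the eager module and returns a (possibly rewritten) module. The sketch below only illustrates that expected shape — `noop_transform` is hypothetical, and the real SDPA rewrite is deliberately left to the follow-up diff.

import torch

def noop_transform(module: torch.nn.Module) -> torch.nn.Module:
    # Placeholder with the expected transform signature; a real transform would
    # rewrite submodules (e.g. swap in a custom-op implementation) before returning.
    return module

transforms = []
transforms.append(noop_transform)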
134 changes: 98 additions & 36 deletions examples/models/llama2/llama_transformer.py
@@ -209,6 +209,95 @@ def update(
         return k_out, v_out


+class SDPA(nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        use_sdpa_with_kv_cache_op: bool,
+        dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
+        self.dim = dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        if not self.use_sdpa_with_kv_cache_op:
+            return self._forward_default(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+        else:
+            return self._forward_custom(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+
+    def _forward_custom(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        from .custom_ops import sdpa_with_kv_cache  # noqa
+
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+    def _forward_default(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        mask = self.mask[None, None, input_pos]
+
+        k = k.repeat_interleave(self.n_rep, dim=1)
+        v = v.repeat_interleave(self.n_rep, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
@@ -229,7 +318,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

-        self.use_sdpa_with_kv_cache_op = args.use_sdpa_with_kv_cache_op
         self.layer_id = layer_id

         causal_mask = torch.tril(
@@ -250,6 +338,13 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 self.head_dim,
                 not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op, don't transpose the cache. Expect untransposed q, k, v
             )
+            self.SDPA = SDPA(
+                self.kv_cache,
+                self.mask,
+                args.use_sdpa_with_kv_cache_op,
+                self.dim,
+                self.n_rep,
+            )

     def forward(
         self,
@@ -272,41 +367,8 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
-
-            if not self.use_sdpa_with_kv_cache_op:
-
-                q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-                k = k.transpose(1, 2)
-                v = v.transpose(1, 2)
-
-                k, v = self.kv_cache.update(input_pos, k, v)
-                mask = self.mask[None, None, input_pos]
-
-                k = k.repeat_interleave(self.n_rep, dim=1)
-                v = v.repeat_interleave(self.n_rep, dim=1)
-                y = F.scaled_dot_product_attention(
-                    q, k, v, attn_mask=mask, dropout_p=0.0
-                )
-
-                y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
-                y = self.wo(y)
-                return y
-            else:
-                from .custom_ops import sdpa_with_kv_cache  # noqa
-
-                output = torch.ops.llama.sdpa_with_kv_cache(
-                    q,
-                    k,
-                    v,
-                    self.kv_cache.k_cache,
-                    self.kv_cache.v_cache,
-                    input_pos[-1].item(),
-                    seqlen,
-                )
-                output = output.view(bsz, seqlen, -1)
-                output = self.wo(output)
-                return output
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            return self.wo(output)

         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
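A detail of the new `SDPA._forward_default` path above that is worth calling out: `repeat_interleave(self.n_rep, dim=1)` expands the `n_kv_heads` key/value heads to match the `n_heads` query heads (grouped-query attention) before calling `F.scaled_dot_product_attention`. A small self-contained sketch with made-up sizes:

import torch
import torch.nn.functional as F

bsz, seqlen, head_dim = 1, 4, 8
n_heads, n_kv_heads = 8, 2
n_rep = n_heads // n_kv_heads  # each kv head serves n_rep query heads

q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, n_kv_heads, seqlen, head_dim)

k = k.repeat_interleave(n_rep, dim=1)  # -> (bsz, n_heads, seqlen, head_dim)
v = v.repeat_interleave(n_rep, dim=1)
y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
assert y.shape == (bsz, n_heads, seqlen, head_dim)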
16 changes: 1 addition & 15 deletions examples/models/llama2/model.py
@@ -209,11 +209,7 @@ def get_eager_model(self):

     def get_example_inputs(self):
         if self.use_kv_cache:
-            if self.use_sdpa_with_kv_cache_op:
-                return self.get_example_inputs_kvcache_sdpa()
-            else:
-                # return self.get_example_inputs_kvcache() TODO xnnpack does not handle forwarding symints, update partitioner to not partition symints
-                return self.get_example_inputs_kvcache_sdpa()
+            return self.get_example_inputs_kvcache_sdpa()
         else:
             return (
                 torch.tensor(
@@ -231,13 +227,3 @@ def get_example_inputs_kvcache_sdpa(self):
                 [0], dtype=torch.long
             ),  # start_pos, what token of output are we on.)
         )
-
-    def get_example_inputs_kvcache(self):
-        return (
-            torch.tensor(
-                [[1, 2, 3]], dtype=torch.long
-            ),  # tokens, with kv cache our input token length is always just 1 token.
-            torch.tensor(
-                [0, 1, 2], dtype=torch.long
-            ),  # start_pos, what token of output are we on.
-        )
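With `get_example_inputs_kvcache` removed, the kv-cache path always uses `get_example_inputs_kvcache_sdpa`, whose inputs are a single-token `tokens` tensor plus a `start_pos` tensor. A rough sketch of those shapes (the token value is assumed for illustration; only the `start_pos` literal is visible in this diff):

import torch

tokens = torch.tensor([[1]], dtype=torch.long)   # (batch=1, seqlen=1): one token per decode step
start_pos = torch.tensor([0], dtype=torch.long)  # which output token we are on
assert tokens.shape == (1, 1) and start_pos.shape == (1,)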
