Add a simple sdpa (#3037) (#3166)
Summary:
Pull Request resolved: #3037

Add a simple SDPA module so that attention is decomposed into simpler ops, instead of relying on the default decomposition of `F.scaled_dot_product_attention`, which expands into 29 ops, including `torch.where`:
```
def forward(self, q, k, v):
    aten_mul_scalar = executorch_exir_dialects_edge__ops_aten_mul_Scalar(q, 0.5946035575013605);  q = None
    aten_full_default = executorch_exir_dialects_edge__ops_aten_full_default([8, 8], True, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_arange_start_step = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step, -2);  aten_arange_start_step = None
    aten_arange_start_step_1 = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default_1 = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step_1, -1);  aten_arange_start_step_1 = None
    aten_sub_tensor = executorch_exir_dialects_edge__ops_aten_sub_Tensor(aten_unsqueeze_copy_default, aten_unsqueeze_copy_default_1);  aten_unsqueeze_copy_default = aten_unsqueeze_copy_default_1 = None
    aten_le_scalar = executorch_exir_dialects_edge__ops_aten_le_Scalar(aten_sub_tensor, 0);  aten_sub_tensor = None
    aten_logical_and_default = executorch_exir_dialects_edge__ops_aten_logical_and_default(aten_le_scalar, aten_full_default);  aten_le_scalar = aten_full_default = None
    aten_full_like_default = executorch_exir_dialects_edge__ops_aten_full_like_default(aten_logical_and_default, 0, dtype = torch.float32, pin_memory = False, memory_format = torch.preserve_format)
    aten_logical_not_default = executorch_exir_dialects_edge__ops_aten_logical_not_default(aten_logical_and_default);  aten_logical_and_default = None
    aten_scalar_tensor_default = executorch_exir_dialects_edge__ops_aten_scalar_tensor_default(-inf, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
    aten_where_self = executorch_exir_dialects_edge__ops_aten_where_self(aten_logical_not_default, aten_scalar_tensor_default, aten_full_like_default);  aten_logical_not_default = aten_scalar_tensor_default = aten_full_like_default = None
    aten_permute_copy_default = executorch_exir_dialects_edge__ops_aten_permute_copy_default(k, [0, 1, 3, 2]);  k = None
    aten_mul_scalar_1 = executorch_exir_dialects_edge__ops_aten_mul_Scalar(aten_permute_copy_default, 0.5946035575013605);  aten_permute_copy_default = None
    aten_expand_copy_default = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar, [1, 1, 8, 8]);  aten_mul_scalar = None
    aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default, [1, 8, 8]);  aten_expand_copy_default = None
    aten_expand_copy_default_1 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar_1, [1, 1, 8, 8]);  aten_mul_scalar_1 = None
    aten_view_copy_default_1 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_1, [1, 8, 8]);  aten_expand_copy_default_1 = None
    aten_bmm_default = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default, aten_view_copy_default_1);  aten_view_copy_default = aten_view_copy_default_1 = None
    aten_view_copy_default_2 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default, [1, 1, 8, 8]);  aten_bmm_default = None
    aten_add_tensor = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_view_copy_default_2, aten_where_self);  aten_view_copy_default_2 = aten_where_self = None
    aten__softmax_default = executorch_exir_dialects_edge__ops_aten__softmax_default(aten_add_tensor, -1, False);  aten_add_tensor = None
    aten_expand_copy_default_2 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten__softmax_default, [1, 1, 8, 8]);  aten__softmax_default = None
    aten_view_copy_default_3 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_2, [1, 8, 8]);  aten_expand_copy_default_2 = None
    aten_expand_copy_default_3 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(v, [1, 1, 8, 8]);  v = None
    aten_view_copy_default_4 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_3, [1, 8, 8]);  aten_expand_copy_default_3 = None
    aten_bmm_default_1 = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default_3, aten_view_copy_default_4);  aten_view_copy_default_3 = aten_view_copy_default_4 = None
    aten_view_copy_default_5 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default_1, [1, 1, 8, 8]);  aten_bmm_default_1 = None
    return (aten_view_copy_default_5,)
```
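
For reference, the graph above is roughly equivalent to the following eager-mode sketch (assuming the `[1, 1, 8, 8]` shapes from the trace; the constant `0.5946035575013605` is `head_dim ** -0.25`, so scaling both `q` and `k.T` by it gives the usual `1 / sqrt(head_dim)` factor on the scores). This is only an illustration of what the decomposition computes, not code from this diff.
```
import torch


def decomposed_sdpa_reference(q, k, v):
    # Minimal eager-mode equivalent of the decomposed graph above.
    head_dim = q.size(-1)
    seqlen = q.size(-2)
    scale = head_dim ** -0.25

    # The arange/sub/le/logical_and/where chain builds a causal bias:
    # 0 where query position i may attend to key position j (j <= i),
    # -inf elsewhere.
    col = torch.arange(seqlen).unsqueeze(-2)  # [1, seqlen]
    row = torch.arange(seqlen).unsqueeze(-1)  # [seqlen, 1]
    causal = (col - row) <= 0                 # lower-triangular True
    attn_bias = torch.zeros(seqlen, seqlen).masked_fill(~causal, float("-inf"))

    # Scale q and k^T separately, combine, mask, softmax, then weight v.
    attn_weight = (q * scale) @ (k.transpose(-2, -1) * scale) + attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return attn_weight @ v
```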
After applying the diff, we remove the following ops:
```
    %aten_full_like_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.full_like.default](args = (%aten_index_tensor_2, 0), kwargs = {dtype: torch.float32, pin_memory: False, memory_format: torch.preserve_format})

    %aten_logical_not_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.logical_not.default](args = (%aten_index_tensor_2,), kwargs = {})

    %aten_scalar_tensor_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.scalar_tensor.default](args = (-inf,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})

    %aten_where_self : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.where.self](args = (%aten_logical_not_default, %aten_scalar_tensor_default, %aten_full_like_default), kwargs = {})

    %aten_mul_scalar : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mul.Scalar](args = (%aten_permute_copy_default_3, 0.5946035575013605), kwargs = {})
    ...
    %aten_mul_scalar_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mul.Scalar](args = (%aten_permute_copy_default_6, 0.5946035575013605), kwargs = {})
```
but introduce an add:
```
    %aten_add_tensor_3 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_mul_tensor_11, %aten_index_tensor_2), kwargs = {})
```
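
A hypothetical usage sketch (the wrapper name `prepare_model_for_export` is made up here, and model construction is elided) showing how the two new passes from the diff below might be chained before export:
```
import torch

from executorch.examples.models.llama2.export_llama_lib import (
    replace_causal_mask,
    replace_sdpa_with_simple_sdpa,
)


def prepare_model_for_export(model: torch.nn.Module) -> torch.nn.Module:
    # Swap every SDPA module for the decomposition-friendly SDPASimple and
    # replace the boolean mask buffer with a precomputed 0/-inf causal mask.
    model = replace_sdpa_with_simple_sdpa(model)
    model = replace_causal_mask(model)
    return model
```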
ghstack-source-id: 223152096
exported-using-ghexport

Reviewed By: mergennachin, kimishpatel

Differential Revision: D56119737

fbshipit-source-id: ec8e875f0a4c4ec67b7493e4872c9a5b081e6de7
(cherry picked from commit cf78107)
cccclai authored Apr 19, 2024
1 parent aa3f22c commit efb7cf3
Showing 3 changed files with 144 additions and 0 deletions.
75 changes: 75 additions & 0 deletions examples/models/llama2/export_llama_lib.py
@@ -9,6 +9,7 @@
import argparse
import copy
import logging
import math
import os
import shlex

@@ -143,6 +144,80 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
return module


class SDPASimple(torch.nn.Module):

def __init__(
self,
kv_cache: KVCache,
dim: int,
head_dim: int,
n_rep: int,
):
super().__init__()
self.kv_cache = kv_cache
self.dim = dim
self.head_dim = head_dim
self.n_rep = n_rep

def forward(
self,
input_pos: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
bsz,
seqlen,
mask,
):
q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

k, v = self.kv_cache.update(input_pos, k, v)
attn_mask = mask[None, None, input_pos]

k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
scale_factor = 1 / math.sqrt(q.size(-1))
attn_weight = q @ k.transpose(-2, -1) * scale_factor
attn_weight += attn_mask
attn_weight = torch.softmax(attn_weight, dim=-1)
y = attn_weight @ v

return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
for name, child in module.named_children():
if isinstance(child, SDPA):
setattr(
module,
name,
SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep),
)
else:
replace_sdpa_with_simple_sdpa(child)
return module


def replace_causal_mask(module: torch.nn.Module):
for buffer_fqn_name, buffer in module.named_buffers():
buffer_name = buffer_fqn_name.split(".")[-1]
if buffer_name == "mask":
max_seq_len = buffer.shape[-1]
mask = torch.full(
(max_seq_len, max_seq_len),
float("-inf"),
device="cpu",
)

mask = torch.triu(mask, diagonal=1)
module.register_buffer(buffer_name, mask)
for _, child in module.named_children():
replace_causal_mask(child)
return module


def quantize(
model: torch.nn.Module,
qmode: str,
15 changes: 15 additions & 0 deletions examples/models/llama2/tests/TARGETS
@@ -0,0 +1,15 @@
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

oncall("executorch")

python_unittest(
name = "test_simple_sdpa",
srcs = [
"test_simple_sdpa.py",
],
deps = [
"//caffe2:torch",
"//executorch/examples/models/llama2:export_library",
"//executorch/examples/models/llama2:llama_transformer",
],
)
54 changes: 54 additions & 0 deletions examples/models/llama2/tests/test_simple_sdpa.py
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
from executorch.examples.models.llama2.export_llama_lib import SDPASimple
from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA


class SDPATest(unittest.TestCase):
def test_simple_sdpa(self):
        # Verify the correctness of the simple SDPA against the original SDPA module defined in llama_transformer.py
max_batch_size = 1
max_seq_length = 128
n_heads = 8
head_dim = 8
dim = 64
n_rep = 1
bsz = 1
seqlen = 1
n_local_heads = n_heads
kv_cache = KVCache(
max_batch_size=max_batch_size,
max_seq_length=max_seq_length,
n_heads=n_heads,
head_dim=head_dim,
transpose_cache=True,
)
sdpa = SDPA(
kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep
)
input_pos = torch.tensor([0])
query = torch.randn(1, 1, n_local_heads, head_dim)
key = torch.randn(1, 1, n_local_heads, head_dim)
value = torch.randn(1, 1, n_local_heads, head_dim)
mask = torch.randn(max_seq_length, max_seq_length)
sdpa_output = sdpa(
input_pos, query, key, value, bsz=bsz, seqlen=seqlen, mask=mask
)

simple_sdpa = SDPASimple(
kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep
)
simple_sdpa_output = simple_sdpa(
input_pos, query, key, value, bsz=bsz, seqlen=seqlen, mask=mask
)

        # Compare the outputs from the two SDPA implementations
self.assertTrue(torch.allclose(sdpa_output, simple_sdpa_output))
