Add a simple sdpa #3037

Closed
wants to merge 7 commits into from
Conversation

cccclai
Contributor

@cccclai commented Apr 14, 2024

Stack from ghstack (oldest at bottom):

Add a simple SDPA module so attention is decomposed into simpler ops, instead of the full decomposition of `F.scaled_dot_product_attention`, which produces 29 ops, including `torch.where`:

```
def forward(self, q, k, v):
    aten_mul_scalar = executorch_exir_dialects_edge__ops_aten_mul_Scalar(q, 0.5946035575013605);  q = None
    aten_full_default = executorch_exir_dialects_edge__ops_aten_full_default([8, 8], True, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_arange_start_step = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step, -2);  aten_arange_start_step = None
    aten_arange_start_step_1 = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default_1 = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step_1, -1);  aten_arange_start_step_1 = None
    aten_sub_tensor = executorch_exir_dialects_edge__ops_aten_sub_Tensor(aten_unsqueeze_copy_default, aten_unsqueeze_copy_default_1);  aten_unsqueeze_copy_default = aten_unsqueeze_copy_default_1 = None
    aten_le_scalar = executorch_exir_dialects_edge__ops_aten_le_Scalar(aten_sub_tensor, 0);  aten_sub_tensor = None
    aten_logical_and_default = executorch_exir_dialects_edge__ops_aten_logical_and_default(aten_le_scalar, aten_full_default);  aten_le_scalar = aten_full_default = None
    aten_full_like_default = executorch_exir_dialects_edge__ops_aten_full_like_default(aten_logical_and_default, 0, dtype = torch.float32, pin_memory = False, memory_format = torch.preserve_format)
    aten_logical_not_default = executorch_exir_dialects_edge__ops_aten_logical_not_default(aten_logical_and_default);  aten_logical_and_default = None
    aten_scalar_tensor_default = executorch_exir_dialects_edge__ops_aten_scalar_tensor_default(-inf, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
    aten_where_self = executorch_exir_dialects_edge__ops_aten_where_self(aten_logical_not_default, aten_scalar_tensor_default, aten_full_like_default);  aten_logical_not_default = aten_scalar_tensor_default = aten_full_like_default = None
    aten_permute_copy_default = executorch_exir_dialects_edge__ops_aten_permute_copy_default(k, [0, 1, 3, 2]);  k = None
    aten_mul_scalar_1 = executorch_exir_dialects_edge__ops_aten_mul_Scalar(aten_permute_copy_default, 0.5946035575013605);  aten_permute_copy_default = None
    aten_expand_copy_default = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar, [1, 1, 8, 8]);  aten_mul_scalar = None
    aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default, [1, 8, 8]);  aten_expand_copy_default = None
    aten_expand_copy_default_1 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar_1, [1, 1, 8, 8]);  aten_mul_scalar_1 = None
    aten_view_copy_default_1 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_1, [1, 8, 8]);  aten_expand_copy_default_1 = None
    aten_bmm_default = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default, aten_view_copy_default_1);  aten_view_copy_default = aten_view_copy_default_1 = None
    aten_view_copy_default_2 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default, [1, 1, 8, 8]);  aten_bmm_default = None
    aten_add_tensor = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_view_copy_default_2, aten_where_self);  aten_view_copy_default_2 = aten_where_self = None
    aten__softmax_default = executorch_exir_dialects_edge__ops_aten__softmax_default(aten_add_tensor, -1, False);  aten_add_tensor = None
    aten_expand_copy_default_2 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten__softmax_default, [1, 1, 8, 8]);  aten__softmax_default = None
    aten_view_copy_default_3 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_2, [1, 8, 8]);  aten_expand_copy_default_2 = None
    aten_expand_copy_default_3 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(v, [1, 1, 8, 8]);  v = None
    aten_view_copy_default_4 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_3, [1, 8, 8]);  aten_expand_copy_default_3 = None
    aten_bmm_default_1 = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default_3, aten_view_copy_default_4);  aten_view_copy_default_3 = aten_view_copy_default_4 = None
    aten_view_copy_default_5 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default_1, [1, 1, 8, 8]);  aten_bmm_default_1 = None
    return (aten_view_copy_default_5,)
```
Differential Revision: [D56119737](https://our.internmc.facebook.com/intern/diff/D56119737/)
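For readability, here is a rough eager-mode sketch of what the decomposed graph above computes (a sketch only, not the exact module added in this PR; the 0.5946... factor is `head_dim ** -0.25`, i.e. the softmax scale split evenly across `q` and `k^T`):

```
import torch

def simple_sdpa(q, k, v, seqlen=8):
    # Split the softmax scale across q and k^T, as in the graph above
    # (0.5946... == head_dim ** -0.25 for head_dim == 8).
    scale = q.size(-1) ** -0.25
    q = q * scale
    kT = k.transpose(-2, -1) * scale          # permute_copy in the graph

    # Causal mask built from arange/unsqueeze/le, then turned into an
    # additive float mask of 0 / -inf (full_like, logical_not, where).
    pos_k = torch.arange(seqlen).unsqueeze(-2)
    pos_q = torch.arange(seqlen).unsqueeze(-1)
    causal = (pos_k - pos_q) <= 0
    attn_mask = torch.zeros(seqlen, seqlen).masked_fill(~causal, float("-inf"))

    attn_weight = q @ kT + attn_mask          # bmm + add
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return attn_weight @ v                    # second bmm
```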

pytorch-bot bot commented Apr 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/3037

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5465fb7 with merge base 1eed125:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Apr 14, 2024
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D56119737

This was referenced Apr 16, 2024
@@ -143,6 +144,79 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
return module


class SDPASimple(torch.nn.Module):
Contributor
Can we just keep sdpa as-is by registering it as a custom QNN op that the QNN delegate can directly consume?

Contributor Author
It's not just for QNN. For any backend that doesn't support SDPA, it makes more sense to use this version instead.
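As a sketch of what such a source-level swap can look like in general (a generic illustration, not the exact `replace_sdpa_with_custom_op` pass in this file; `swap_modules`, `old_cls`, and `new_factory` are made-up names):

```
import torch

def swap_modules(module: torch.nn.Module, old_cls, new_factory) -> torch.nn.Module:
    # Recursively replace every submodule of type `old_cls` with the module
    # returned by `new_factory(child)`; shown only to illustrate swapping the
    # attention module for a decomposition-friendly variant.
    for name, child in module.named_children():
        if isinstance(child, old_cls):
            setattr(module, name, new_factory(child))
        else:
            swap_modules(child, old_cls, new_factory)
    return module
```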

attn_weight = torch.softmax(attn_weight, dim=-1)
y = attn_weight @ v

return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
Contributor
Do you need contiguous?
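For reference, a small standalone example of why `.contiguous()` is usually paired with `.view()` after a transpose (`.view()` requires a compatible memory layout, while `.reshape()` copies when it has to):

```
import torch

bsz, n_heads, seqlen, head_dim = 1, 8, 2, 4
y = torch.randn(bsz, n_heads, seqlen, head_dim)
t = y.transpose(1, 2)                                  # (bsz, seqlen, n_heads, head_dim), non-contiguous
# t.view(bsz, seqlen, n_heads * head_dim) raises "view size is not compatible
# with input tensor's size and stride"; copying first makes the view legal.
out = t.contiguous().view(bsz, seqlen, n_heads * head_dim)
out_alt = t.reshape(bsz, seqlen, n_heads * head_dim)   # reshape() copies only when necessary
```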


@facebook-github-bot
Contributor

This pull request has been merged in cf78107.

cccclai added a commit to cccclai/executorch-1 that referenced this pull request Apr 19, 2024
Summary:
Pull Request resolved: pytorch#3037

Add a simple SDPA so it's decomposed into simpler ops instead of the full decomposition of F.scaled_dot_product_attention, which includes 29 ops including `torch.where` (the full decomposed graph is shown in the PR description above).
After applying the diff, we remove the following ops
```
    %aten_full_like_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.full_like.default](args = (%aten_index_tensor_2, 0), kwargs = {dtype: torch.float32, pin_memory: False, memory_format: torch.preserve_format})

    %aten_logical_not_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.logical_not.default](args = (%aten_index_tensor_2,), kwargs = {})

    %aten_scalar_tensor_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.scalar_tensor.default](args = (-inf,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})

    %aten_where_self : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.where.self](args = (%aten_logical_not_default, %aten_scalar_tensor_default, %aten_full_like_default), kwargs = {})

    %aten_mul_scalar : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mul.Scalar](args = (%aten_permute_copy_default_3, 0.5946035575013605), kwargs = {})
    ...
    %aten_mul_scalar_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mul.Scalar](args = (%aten_permute_copy_default_6, 0.5946035575013605), kwargs = {})
```
but introduce an add:
```
    %aten_add_tensor_3 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_mul_tensor_11, %aten_index_tensor_2), kwargs = {})
```
ghstack-source-id: 223152096
exported-using-ghexport

Reviewed By: mergennachin, kimishpatel

Differential Revision: D56119737

fbshipit-source-id: ec8e875f0a4c4ec67b7493e4872c9a5b081e6de7
(cherry picked from commit cf78107)
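In other words, part of the saving comes from precomputing the additive mask instead of rebuilding it from boolean ops on every call; a minimal sketch of the two approaches (illustrative only, not the PR's exact code):

```
import torch

seqlen = 8
causal = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))

# Before: materialize the float mask inside the graph on every call
# (full_like, logical_not, scalar_tensor, where end up in the export).
zeros = torch.zeros(seqlen, seqlen)
neg_inf = torch.full((seqlen, seqlen), float("-inf"))
mask_old = torch.where(torch.logical_not(causal), neg_inf, zeros)

# After: precompute the float mask once (e.g. as a module buffer) and add
# it to the attention scores, leaving a single aten.add.Tensor in the graph.
attn_mask = torch.zeros(seqlen, seqlen).masked_fill(~causal, float("-inf"))
scores = torch.randn(1, 1, seqlen, seqlen)
scores = scores + attn_mask
```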
guangy10 pushed a commit that referenced this pull request Apr 19, 2024
This was referenced Apr 25, 2024
Labels
CLA Signed, fb-exported, Merged
4 participants