Skip to content

Commit

Permalink
Update on "Add a simple sdpa"
Browse files Browse the repository at this point in the history
Add a simple sdpa so it's decomposed to simpler ops instead of the decompose F.scaled_dot_product_attention, which includes 29 ops including `torch.where`
```
def forward(self, q, k, v):
    aten_mul_scalar = executorch_exir_dialects_edge__ops_aten_mul_Scalar(q, 0.5946035575013605);  q = None
    aten_full_default = executorch_exir_dialects_edge__ops_aten_full_default([8, 8], True, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_arange_start_step = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step, -2);  aten_arange_start_step = None
    aten_arange_start_step_1 = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default_1 = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step_1, -1);  aten_arange_start_step_1 = None
    aten_sub_tensor = executorch_exir_dialects_edge__ops_aten_sub_Tensor(aten_unsqueeze_copy_default, aten_unsqueeze_copy_default_1);  aten_unsqueeze_copy_default = aten_unsqueeze_copy_default_1 = None
    aten_le_scalar = executorch_exir_dialects_edge__ops_aten_le_Scalar(aten_sub_tensor, 0);  aten_sub_tensor = None
    aten_logical_and_default = executorch_exir_dialects_edge__ops_aten_logical_and_default(aten_le_scalar, aten_full_default);  aten_le_scalar = aten_full_default = None
    aten_full_like_default = executorch_exir_dialects_edge__ops_aten_full_like_default(aten_logical_and_default, 0, dtype = torch.float32, pin_memory = False, memory_format = torch.preserve_format)
    aten_logical_not_default = executorch_exir_dialects_edge__ops_aten_logical_not_default(aten_logical_and_default);  aten_logical_and_default = None
    aten_scalar_tensor_default = executorch_exir_dialects_edge__ops_aten_scalar_tensor_default(-inf, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
    aten_where_self = executorch_exir_dialects_edge__ops_aten_where_self(aten_logical_not_default, aten_scalar_tensor_default, aten_full_like_default);  aten_logical_not_default = aten_scalar_tensor_default = aten_full_like_default = None
    aten_permute_copy_default = executorch_exir_dialects_edge__ops_aten_permute_copy_default(k, [0, 1, 3, 2]);  k = None
    aten_mul_scalar_1 = executorch_exir_dialects_edge__ops_aten_mul_Scalar(aten_permute_copy_default, 0.5946035575013605);  aten_permute_copy_default = None
    aten_expand_copy_default = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar, [1, 1, 8, 8]);  aten_mul_scalar = None
    aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default, [1, 8, 8]);  aten_expand_copy_default = None
    aten_expand_copy_default_1 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar_1, [1, 1, 8, 8]);  aten_mul_scalar_1 = None
    aten_view_copy_default_1 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_1, [1, 8, 8]);  aten_expand_copy_default_1 = None
    aten_bmm_default = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default, aten_view_copy_default_1);  aten_view_copy_default = aten_view_copy_default_1 = None
    aten_view_copy_default_2 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default, [1, 1, 8, 8]);  aten_bmm_default = None
    aten_add_tensor = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_view_copy_default_2, aten_where_self);  aten_view_copy_default_2 = aten_where_self = None
    aten__softmax_default = executorch_exir_dialects_edge__ops_aten__softmax_default(aten_add_tensor, -1, False);  aten_add_tensor = None
    aten_expand_copy_default_2 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten__softmax_default, [1, 1, 8, 8]);  aten__softmax_default = None
    aten_view_copy_default_3 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_2, [1, 8, 8]);  aten_expand_copy_default_2 = None
    aten_expand_copy_default_3 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(v, [1, 1, 8, 8]);  v = None
    aten_view_copy_default_4 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_3, [1, 8, 8]);  aten_expand_copy_default_3 = None
    aten_bmm_default_1 = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default_3, aten_view_copy_default_4);  aten_view_copy_default_3 = aten_view_copy_default_4 = None
    aten_view_copy_default_5 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default_1, [1, 1, 8, 8]);  aten_bmm_default_1 = None
    return (aten_view_copy_default_5,)
```

Differential Revision: [D56119737](https://our.internmc.facebook.com/intern/diff/D56119737/)

[ghstack-poisoned]
  • Loading branch information
cccclai committed Apr 19, 2024
2 parents f5ec6cf + 1de1fe7 commit 5465fb7
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 14 deletions.
36 changes: 33 additions & 3 deletions backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,21 @@

#define divup4(x) ((x + 3) / 4)

#define to_buffer_i(idx, sizes) \
idx.x + idx.y* sizes.x + idx.z* sizes.y* sizes.x + \
idx.w* sizes.z* sizes.y* sizes.x;
// Input: idx is a ivec4 user-level coordinate, sizes is the tensor shape
// Output: buffer_idx in the continuous nchw-buffer.
#define to_buffer_i(idx, sizes) \
(idx.x + idx.y * sizes.x + idx.z * sizes.y * sizes.x + \
idx.w * sizes.z * sizes.y * sizes.x)

// Inverse of to_buffer_i
// Input: buffer_idx in the continuous nchw-buffer, sizes is the tensor shape
// Output: ivec4 user-level coorindate
#define from_buffer_i(buf_i, sizes) \
ivec4( \
buf_i % sizes.x, \
(buf_i / (sizes.x)) % sizes.y, \
(buf_i / (sizes.x * sizes.y)) % sizes.z, \
(buf_i / (sizes.x * sizes.y * sizes.z)))

#define get_packed_dim_C_packed(vec) vec.z
#define get_packed_dim_W_packed(vec) vec.x
Expand All @@ -20,6 +32,8 @@
#define get_packed_stride_W_packed(vec) (1)
#define get_packed_stride_H_packed(vec) (vec.x)

// Input: pos is a texture position, sizes is a pack-aligned size.
// Output: a user-level (w, h, c, n) coordinate
#define to_tensor_idx_C_packed(pos, sizes) \
ivec4(pos.x, pos.y, (pos.z * 4) % sizes.z, (pos.z * 4) / sizes.z)

Expand All @@ -29,6 +43,9 @@
#define to_tensor_idx_H_packed(pos, sizes) \
ivec4(pos.x, (pos.y * 4), pos.z % sizes.z, pos.z / sizes.z)

// Input: idx is a user-level (w, h, c, n) coordinate. size is a pack-aligned
// size.
// Output: texture location
#define to_texture_pos_C_packed(idx, sizes) \
ivec3(idx.x, idx.y, (idx.z + idx.w * sizes.z) / 4)

Expand All @@ -38,6 +55,19 @@
#define to_texture_pos_H_packed(idx, sizes) \
ivec3(idx.x, idx.y / 4, (idx.z + idx.w * sizes.z))

// Input: idx is a user-level (w, h, c, n) coordinate. size is a pack-aligned
// size with the index in the texel.
// Output: ivec4, xyz is the texture position, w is the element index in the
// texel.
#define to_texture_pos_elem_C_packed(idx, sizes) \
ivec4(idx.x, idx.y, (idx.z + idx.w * sizes.z) / 4, idx.z % 4)

#define to_texture_pos_elem_W_packed(idx, sizes) \
ivec4(idx.x / 4, idx.y, (idx.z + idx.w * sizes.z), idx.x % 4)

#define to_texture_pos_elem_H_packed(idx, sizes) \
ivec4(idx.x, idx.y / 4, (idx.z + idx.w * sizes.z), idx.y % 4)

// Given a buffer(1-D) index cur, compute a new index where the corresponding
// tensor(N-D)'s adjacent dimensions are swapped. The parameters x,y and plane
// describe sizes. As an example, let's say we want to swap dimensions 0,1 for a
Expand Down
76 changes: 76 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/view.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;

#define VEC4_T ${texel_type(DTYPE)}

#define to_tensor_idx to_tensor_idx_${PACKING}
#define to_texture_pos_elem to_texture_pos_elem_${PACKING}
#define get_packed_stride get_packed_stride_${PACKING}

layout(set = 0, binding = 2) uniform PRECISION restrict OutGpuSizes {
uvec4 out_gpu_sizes;
};

layout(set = 0, binding = 3) uniform PRECISION restrict OutCpuSizes {
uvec4 out_cpu_sizes;
};

layout(set = 0, binding = 4) uniform PRECISION restrict InGpuSizes {
uvec4 in_gpu_sizes;
};

layout(set = 0, binding = 5) uniform PRECISION restrict InCpuSizes {
uvec4 in_cpu_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;


void main() {
const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
const ivec4 out_tensor_idx = to_tensor_idx(out_pos, out_gpu_sizes);

if (all(greaterThanEqual(out_tensor_idx, out_gpu_sizes))) {
return;
}

// Assume there is a virtual continous buffer in nchw format. From the output
// pos, we first calculate the index in the virual buffer, and then calculate
// the input position from the indx.

const uint base_index = to_buffer_i(out_tensor_idx, out_cpu_sizes);
const uvec4 buf_indices =
base_index + ivec4(0, 1, 2, 3) * get_packed_stride(out_cpu_sizes);

VEC4_T value;
// Need to look up the 4 values in the output texel separately.
for (int i=0; i<4; i++) {
ivec4 user_coor = from_buffer_i(buf_indices[i], in_cpu_sizes);

ivec4 in_pos_elem = to_texture_pos_elem(user_coor, in_gpu_sizes);

VEC4_T intex = VEC4_T(texelFetch(image_in, in_pos_elem.xyz, 0));

value[i] = intex[in_pos_elem.w];
}

imageStore(image_out, out_pos, value);
}
14 changes: 14 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/view.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
view:
parameter_names_with_default_values:
DTYPE: float
NDIM: 3
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
PACKING:
- VALUE: C_packed
- VALUE: W_packed
- VALUE: H_packed
shader_variants:
- NAME: view
51 changes: 51 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/View.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void add_view_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
vTensorPtr t_in = graph.get_tensor(in);
vTensorPtr t_out = graph.get_tensor(out);

std::string kernel_name = "view";
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, *t_out);
add_memory_layout_suffix(kernel_name, *t_out);

api::utils::uvec3 global_size = t_out->extents();
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
global_size,
local_size,
{{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
{t_out->gpu_sizes_ubo(),
t_out->cpu_sizes_ubo(),
t_in->gpu_sizes_ubo(),
t_in->cpu_sizes_ubo()}));
}

void view(ComputeGraph& graph, const std::vector<ValueRef>& args) {
// Note: The second argument size_ref is not used here. Since the output
// tensor's size have been determined during compilation.
return add_view_node(graph, args[0], args[2]);
}

REGISTER_OPERATORS {
VK_REGISTER_OP(aten.view_copy.default, view);
}

} // namespace vkcompute
28 changes: 28 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,33 @@ def get_permute_inputs():
return test_suite


def get_view_inputs():
test_suite = VkTestSuite(
[
((3, 4, 5), [1, 1, -1]),
((3, 4, 5), [1, -1, 1]),
((3, 4, 5), [-1, 1, 1]),
((8, 7, 2, 3), [4, 3, 7, 4]),
((8, 7, 2, 3), [7, -1, 2, 1]),
((8, 7, 2, 3), [1, 1, 1, -1]),
((8, 7, 2, 3), [-1]),
((2, 3, 3, 7), [2, -1, 1, 1]),
((3, 5, 2, 7), [7, -1, 2, 1]),
((2, 2, 8, 6), [2, 6, -1, 1]),
((2, 2, 8, 6), [6, -1, 1]),
((S1, S2, S1, S2), [S2, -1, 1, S1]),
((S1, S2, S1, S2), [S1, 1, -1, S2]),
((S1, S2, S1, S2), [-1, 1, S1, S2]),
]
)
test_suite.layouts = [
"api::kWidthPacked",
"api::kHeightPacked",
"api::kChannelsPacked",
]
return test_suite


test_suites = {
"aten.add.Tensor": get_binary_elementwise_inputs(),
"aten.sub.Tensor": get_binary_elementwise_inputs(),
Expand All @@ -208,4 +235,5 @@ def get_permute_inputs():
"aten.select_copy.int": get_select_int_inputs(),
"aten.permute.default": get_permute_inputs(),
"aten.permute_copy.default": get_permute_inputs(),
"aten.view_copy.default": get_view_inputs(),
}
8 changes: 7 additions & 1 deletion backends/vulkan/test/op_tests/utils/codegen_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,16 @@ def gen_case_name(self, inputs: List[Any], prepack: bool = False) -> str:
for size in arg_sizes_or_val:
name_str += str(size) + "x"
name_str = name_str[:-1]
# minus sign is a invalid char for test case. change to "n".
name_str = name_str.replace("-", "n")

elif isinstance(arg_sizes_or_val, list):
for size in arg_sizes_or_val:
name_str += str(size) + "c"
name_str = name_str[:-1]
# minus sign is a invalid char for test case. change to "n".
name_str = name_str.replace("-", "n")

else:
name_str += str(arg_sizes_or_val).replace(".", "p")
return name_str
Expand Down Expand Up @@ -234,7 +240,7 @@ def generate_suite_cpp(self) -> str:
// from_blob doesn't take ownership of data. Hence must create a copy as
// "values" will go out of scope.
return at::from_blob(values.data(), sizes, dtype).detach().clone();
return at::from_blob(values.data(), sizes, at::kFloat).toType(dtype).detach().clone();
}}
{test_suites_cpp}
Expand Down
11 changes: 1 addition & 10 deletions examples/models/llama2/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,7 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:


class SDPASimple(torch.nn.Module):
"""
This is a simpler implementation of SDPA module defined in llama_transformer.py. Notice that it's
an implementation including both some preprocessing logic and F.scaled_dot_product_attention.
"""

def __init__(
self,
kv_cache: KVCache,
Expand All @@ -172,7 +169,6 @@ def forward(
seqlen,
mask,
):
# The first few lines are the same as the original SDPA module.
q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
Expand All @@ -182,11 +178,6 @@ def forward(

k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)

# Following is the different part. Instead of calling F.scaled_dot_product_attention,
# we use the following implementation to avoid the decomposition from F.scaled_dot_product_attention,
# as the decompostion is too expensive. The following will get rid of aten.full_like, aten.logical_not,
# aten.scalar_tensor, aten.where and 2 extra aten.mul.
scale_factor = 1 / math.sqrt(q.size(-1))
attn_weight = q @ k.transpose(-2, -1) * scale_factor
attn_weight += attn_mask
Expand Down

0 comments on commit 5465fb7

Please sign in to comment.