Qualcomm AI Engine Direct - Support kv_cached stories 110M llama2 (#4142)

Summary:
- Add custom memory descriptor
- Add e2e example script, verified with stories110M in 8a8w and 16a4w
- Add qnn_llama_runner to run static LLAMA
- Add README
- Add slice op test
- Rename RemoveClone pass to RemoveRedundancy
- Rename SimpleADB parameter artifact to build_path and update related code
- Change multi-head attention into multiple single-head attentions (see the sketch after this list)
- Move input sorting from execute to init
- Remove split op
- Support u16 and u8 mixed-precision quantization
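
For context on the multi-head change above, a minimal, hedged sketch of the idea: compute attention head by head so that each head lowers to plain single-head matmuls. The helper names below are hypothetical and do not appear in this commit.

    import torch

    def single_head_attention(q, k, v):
        # q, k, v: (seq_len, head_dim)
        scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
        return torch.softmax(scores, dim=-1) @ v

    def multihead_as_single_heads(q, k, v, n_heads):
        # q, k, v: (seq_len, n_heads * head_dim); slice per head, attend, concat
        head_dim = q.shape[-1] // n_heads
        outs = [
            single_head_attention(
                q[:, h * head_dim : (h + 1) * head_dim],
                k[:, h * head_dim : (h + 1) * head_dim],
                v[:, h * head_dim : (h + 1) * head_dim],
            )
            for h in range(n_heads)
        ]
        return torch.cat(outs, dim=-1)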

Pull Request resolved: #4142

Reviewed By: kirklandsign

Differential Revision: D59339823

Pulled By: cccclai

fbshipit-source-id: 51fcf14e406b04c51de6e421cccbad91a8ffa01e
shewu-quic authored and facebook-github-bot committed Jul 6, 2024
1 parent 29fdaa1 commit 5584b9e
Showing 68 changed files with 3,572 additions and 759 deletions.
4 changes: 4 additions & 0 deletions backends/qualcomm/aot/wrappers/TensorWrapper.h
@@ -83,6 +83,10 @@ class TensorWrapper {
     return QNN_VER_PTR(tensor_)->rank;
   };
 
+  std::uint32_t GetBytes() const {
+    return bytes_;
+  };
+
   const void* GetStaticTensorData() const {
     return QNN_VER_PTR(tensor_)->clientBuf.data;
   };
4 changes: 2 additions & 2 deletions backends/qualcomm/builders/__init__.py
@@ -10,7 +10,6 @@
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,
-    op_cast,
     op_cat,
     op_ceil,
     op_clamp,
@@ -50,6 +49,7 @@
     op_sub,
     op_sum_int_list,
     op_tanh,
+    op_to,
     op_transpose,
     op_unsqueeze,
     op_upsample_bilinear2d,
@@ -62,7 +62,6 @@
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,
-    op_cast,
     op_cat,
     op_ceil,
     op_clamp,
@@ -102,6 +101,7 @@
     op_sub,
     op_sum_int_list,
     op_tanh,
+    op_to,
     op_transpose,
     op_unsqueeze,
     op_upsample_bilinear2d,
30 changes: 12 additions & 18 deletions backends/qualcomm/builders/node_visitor.py
@@ -14,7 +14,13 @@
 
 from executorch.exir.dialects._ops import ops as exir_ops
 
-from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter
+from .utils import (
+    deduce_dtype,
+    get_parameter,
+    is_graph_input,
+    is_graph_output,
+    is_parameter,
+)
 
 
 QNN_QUANT_TYPE_MAP = {
@@ -217,21 +223,7 @@ def get_data_type(
         quant_config: Dict,
     ) -> PyQnnWrapper.Qnn_TensorType_t:
         if quant_config:
-            quant_range = quant_config["quant_max"] - quant_config["quant_min"]
-            unsigned = quant_config["quant_min"] >= 0
-            if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
-                if unsigned:
-                    quant_config["dtype"] = torch.uint8
-                else:
-                    quant_config["dtype"] = torch.int8
-            elif (
-                quant_range
-                <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
-            ):
-                if unsigned:
-                    quant_config["dtype"] = torch.uint16
-                else:
-                    quant_config["dtype"] = torch.int16
+            quant_config["dtype"] = deduce_dtype(tensor, quant_config)
             return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
 
         return QNN_TENSOR_TYPE_MAP[tensor.dtype]
@@ -277,7 +269,6 @@ def define_tensor(
         nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
         is_input_tensor: bool,
         node_name: str = None,
-        is_tensor: bool = True,
         wrapper_idx: int = 0,
     ) -> PyQnnWrapper.TensorWrapper:
         """
@@ -296,7 +287,10 @@
 
         if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
            return cached
-        tensor_name = node.name
+
+        tensor_name = f"{node.name}_{wrapper_idx}"
+        if is_graph_input(node, self.edge_program):
+            tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name
         if is_graph_output(node):
             tensor_name = "output_" + tensor_name
         dims = [1] if len(tensor.size()) == 0 else tensor.size()
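
As a quick illustration of the new tensor-naming scheme above (a standalone sketch; qnn_tensor_name is a hypothetical re-statement, not code from this commit):

    def qnn_tensor_name(node_name, wrapper_idx, external_id=None, is_output=False):
        # Names now embed the wrapper index plus input/output prefixes.
        name = f"{node_name}_{wrapper_idx}"
        if external_id is not None:  # graph input
            name = f"input_{external_id}_{name}"
        if is_output:  # graph output
            name = f"output_{name}"
        return name

    assert qnn_tensor_name("x", 0, external_id=0) == "input_0_x_0"
    assert qnn_tensor_name("y", 0, is_output=True) == "output_y_0"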
57 changes: 0 additions & 57 deletions backends/qualcomm/builders/op_cast.py

This file was deleted.

2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_embedding.py
@@ -34,7 +34,7 @@ def define_node(
             weight_tensor,
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
             nodes_to_wrappers,
-            is_input_tensor=False,
+            is_input_tensor=True,
         )
 
         indices_node = node.args[1]
4 changes: 2 additions & 2 deletions backends/qualcomm/builders/op_pow.py
@@ -53,14 +53,14 @@ def define_node(
 
         # scalar input
         scalar = node.args[1]
-        scalar_tensor = torch.full(input_tensor.size(), scalar).to(torch.float32)
+        scalar_tensor = torch.tensor(scalar).to(torch.float32)
 
         # 'graph', 'name', 'op', 'target', 'args', and 'kwargs'
         scalar_node = torch.fx.Node(
             node.graph,
             node.name + "_runtime_scalar",
             "call_function",
-            exir_ops.edge.aten.full.default,
+            exir_ops.edge.aten.scalar_tensor.default,
             (), # args
             {}, # kwargs
         )
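
Why the op_pow change is safe, as a hedged illustration (not from the commit): the exponent no longer needs to be materialized at the input's full shape, because broadcasting gives the same result from a 0-dim tensor.

    import torch

    base = torch.ones(2, 3) * 3.0
    old_exponent = torch.full(base.size(), 2.0)  # shape (2, 3), one copy per element
    new_exponent = torch.tensor(2.0)             # 0-dim scalar tensor
    # Broadcasting makes both forms equivalent; the scalar form is smaller.
    assert torch.equal(torch.pow(base, old_exponent), torch.pow(base, new_exponent))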
4 changes: 3 additions & 1 deletion backends/qualcomm/builders/op_slice_copy.py
@@ -61,7 +61,9 @@ def define_node(
         ranges = []
         for i in range(input_tensor_rank):
             if i == dim:
-                ranges.extend([start, end, 1])
+                # find step
+                step = node.args[4] if len(node.args) > 4 else 1
+                ranges.extend([start, end, step])
             else:
                 ranges.extend([0, input_tensor.shape[i], 1])
 
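
A small standalone sketch of the ranges construction above, including the newly honored step argument (build_ranges is a hypothetical re-statement, not code from this commit):

    def build_ranges(shape, dim, start, end, step):
        # One [begin, end, stride] triple per dimension, QNN-style.
        ranges = []
        for i, size in enumerate(shape):
            if i == dim:
                ranges.extend([start, end, step])
            else:
                ranges.extend([0, size, 1])
        return ranges

    # x[:, 0:8:2] on a (4, 10, 3) tensor:
    assert build_ranges((4, 10, 3), 1, 0, 8, 2) == [0, 4, 1, 0, 8, 2, 0, 3, 1]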
1 change: 0 additions & 1 deletion backends/qualcomm/builders/op_split_with_sizes.py
@@ -59,7 +59,6 @@ def define_node(
         # Edge represents chunks by specifying the size of each chunk
         # QNN represents chunks by specifying the index to split chunks
         for index, _value in enumerate(chunks[:-1]):
-
             sum = sum + chunks[index]
             split_indices.append(sum)
 
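
To make the comment above concrete (hypothetical illustration): Edge chunk sizes become QNN split indices via a running sum that drops the final chunk.

    # Edge chunk sizes [2, 3, 5] become QNN split indices [2, 5].
    chunks = [2, 3, 5]
    split_indices, running = [], 0
    for size in chunks[:-1]:
        running += size
        split_indices.append(running)
    assert split_indices == [2, 5]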
104 changes: 104 additions & 0 deletions backends/qualcomm/builders/op_to.py
@@ -0,0 +1,104 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpCast, OpConvert, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class To(NodeVisitor):
    target = ["aten._to_copy.default"]
    sufixed_8_offset_diff = 128
    sufixed_16_offset_diff = 32768
    epsilon = 1e-6
    sufixed_8 = {
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8,
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
    }
    sufixed_16 = {
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16,
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
    }

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def is_cast_node(self, node):
        input_node = node.args[0]

        # Not a case with two quantized nodes, so no need to consider the Convert op
        if not all([input_node.meta.get("quant_attrs"), node.meta.get("quant_attrs")]):
            return True

        input_tensor = self.get_tensor(input_node, node)
        _, inp_qconfs = self.get_quant_encoding_conf(input_node, False)
        inp_dtype = self.get_data_type(input_tensor, inp_qconfs)

        output_tensor = self.get_tensor(node, node)
        _, out_qconfs = self.get_quant_encoding_conf(node, False)
        out_dtype = self.get_data_type(output_tensor, out_qconfs)
        is_qparam_castable = (
            lambda o1, o2, s1, s2, diff: abs(s1 - s2) < self.epsilon
            and abs(o1 - o2) == diff
        )

        if {inp_dtype, out_dtype} == self.sufixed_8:
            return is_qparam_castable(
                inp_qconfs["offset"],
                out_qconfs["offset"],
                inp_qconfs["scale"],
                out_qconfs["scale"],
                self.sufixed_8_offset_diff,
            )
        elif {inp_dtype, out_dtype} == self.sufixed_16:
            return is_qparam_castable(
                inp_qconfs["offset"],
                out_qconfs["offset"],
                inp_qconfs["scale"],
                out_qconfs["scale"],
                self.sufixed_16_offset_diff,
            )
        return False

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)

        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        output_tensor = self.get_tensor(node, node)

        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        qnn_op = OpCast if self.is_cast_node(node) else OpConvert
        op = PyQnnWrapper.PyQnnOpWrapper(
            node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op.op_name
        )
        op.AddInputTensors([input_tensor_wrapper])
        op.AddOutputTensors([output_tensor_wrapper])

        return op
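
For intuition about is_cast_node, a sketch under the assumption that a u8/s8 pair with identical scale and offsets differing by exactly 128 is a pure bit-reinterpretation (qparams_castable is hypothetical, mirroring the lambda above):

    EPS = 1e-6

    def qparams_castable(inp_offset, out_offset, inp_scale, out_scale, diff):
        # Same scale and a fixed offset gap means a cheap reinterpretation.
        return abs(inp_scale - out_scale) < EPS and abs(inp_offset - out_offset) == diff

    # u8 -> s8 with matching scale: lower to Cast.
    assert qparams_castable(0, -128, 0.05, 0.05, 128)
    # Different scales: must lower to Convert (actual requantization).
    assert not qparams_castable(0, -128, 0.05, 0.07, 128)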
5 changes: 5 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
@@ -39,6 +39,11 @@ class OpConv2d:
     param_dilation: str = "dilation"
 
 
+@dataclass(init=False, frozen=True)
+class OpConvert:
+    op_name: str = "Convert"
+
+
 @dataclass(init=False, frozen=True)
 class OpDepthToSpace:
     op_name: str = "DepthToSpace"
19 changes: 19 additions & 0 deletions backends/qualcomm/builders/utils.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Dict, Optional
+
 import torch
 from torch._export.utils import get_buffer, get_param, is_buffer, is_param
 
@@ -100,3 +102,20 @@ def is_constant(
         return tensor.meta["val"].constant is not None
 
     return False
+
+
+def deduce_dtype(
+    tensor: torch.Tensor, quant_infos: Optional[Dict] = None
+) -> torch.dtype:
+    if quant_infos:
+        quant_range = quant_infos["quant_max"] - quant_infos["quant_min"]
+        unsigned = quant_infos["quant_min"] >= 0
+        if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
+            return torch.uint8 if unsigned else torch.int8
+
+        elif quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min:
+            return torch.uint16 if unsigned else torch.int16
+
+        return quant_infos["dtype"]
+
+    return tensor.dtype
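
Expected behavior of deduce_dtype, as a hedged usage sketch (the observer ranges below are illustrative, not taken from the commit):

    import torch

    # An 8-bit unsigned range (255 steps, quant_min >= 0) maps to torch.uint8.
    q_infos = {"quant_min": 0, "quant_max": 255, "dtype": torch.int32}
    assert deduce_dtype(torch.empty(1), q_infos) == torch.uint8

    # A 16-bit symmetric range maps to torch.int16.
    q_infos = {"quant_min": -32768, "quant_max": 32767, "dtype": torch.int32}
    assert deduce_dtype(torch.empty(1), q_infos) == torch.int16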
2 changes: 1 addition & 1 deletion backends/qualcomm/partition/common_defs.py
@@ -11,9 +11,9 @@
 not_supported_operator = [
     exir_ops.edge.aten.arange.start_step,
     exir_ops.edge.aten.clone.default,
-    exir_ops.edge.aten.index.Tensor,
     exir_ops.edge.aten.full.default,
     exir_ops.edge.aten.slice_scatter.default,
+    exir_ops.edge.aten.index.Tensor,
     exir_ops.edge.aten.index_put.default,
 ]
 
6 changes: 5 additions & 1 deletion backends/qualcomm/partition/qnn_partitioner.py
@@ -50,7 +50,7 @@ def __init__(
         )
 
         self.skip_node_id_set = skip_node_id_set
-        self.nodes_to_wrappers = self.nodes_to_wrappers = defaultdict(dict)
+        self.nodes_to_wrappers = defaultdict(dict)
         self.qnn_manager = PyQnnManager.QnnManager(
             generate_qnn_executorch_option(compiler_specs)
         )
@@ -96,6 +96,9 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
         print(f"[QNN Partitioner Op Support]: {node.target.__name__} | {supported}")
         return supported
 
+    def __del__(self):
+        self.qnn_manager.Destroy()
+
 
 class QnnPartitioner(Partitioner):
     def __init__(
@@ -145,6 +148,7 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult:
         # pop certain keys in meta for not affecting the passes in compilation
         # TODO: need to put property name in common definitions
         node.meta.pop("axis_order", "")
+        del self.op_support_checker
         return PartitionResult(
             tagged_exported_program=edge_program, partition_tags=self.partition_tags
         )
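
The __del__ hook and the explicit del above follow a common pattern for native handles (a generic sketch, not this repo's API): free the underlying C++ object as soon as the short-lived Python helper goes away, rather than at interpreter shutdown.

    class NativeHandleOwner:
        """Hypothetical stand-in for a helper holding a native QnnManager."""

        def __init__(self):
            self.handle = object()  # pretend native resource

        def __del__(self):
            # Stand-in for self.qnn_manager.Destroy(): release eagerly.
            self.handle = None

    owner = NativeHandleOwner()
    del owner  # CPython refcounting runs __del__ here, freeing the handle now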