Update on "Add Vulkan Quantizer to Llama export lib"
TSIA.

Note that only 8-bit weight-only quantization is supported for now, since `VulkanQuantizer` does not yet support 4-bit weight-only quantization.

Differential Revision: [D64249615](https://our.internmc.facebook.com/intern/diff/D64249615/)

[ghstack-poisoned]
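
As context for the note above, here is a minimal sketch of what the 8-bit weight-only flow could look like through the standard PT2E entry points (`prepare_pt2e`/`convert_pt2e`). The quantizer import path, the `get_weight_quantization_config` helper, and `set_global` are assumptions for illustration, modeled on other PT2E quantizers rather than verified against this commit:

```python
# Hedged sketch, not verified against this commit: the import path, config
# helper, and set_global API below are assumptions modeled on other PT2E quantizers.
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.vulkan.quantizer.vulkan_quantizer import (  # assumed path
    get_weight_quantization_config,  # assumed 8-bit weight-only config helper
    VulkanQuantizer,
)

quantizer = VulkanQuantizer().set_global(get_weight_quantization_config())

# `captured_module` is assumed to be the pre-autograd exported GraphModule of the model.
prepared = prepare_pt2e(captured_module, quantizer)  # insert weight observers
quantized = convert_pt2e(prepared)  # 8-bit weight-only; 4-bit is not supported yet
```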
SS-JIA committed Oct 11, 2024
2 parents a650605 + 96948c1 commit 201741e
Showing 23 changed files with 437 additions and 28 deletions.
65 changes: 65 additions & 0 deletions backends/qualcomm/_passes/decompose_einsum.py
@@ -0,0 +1,65 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.experimental.proxy_tensor import make_fx


class DecomposeEinsum(ExportPass):
"""
    Decompose einsum so that quantization annotation works properly.
"""

def __init__(self) -> None:
super().__init__()

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph = graph_module.graph
for node in graph.nodes:
if node.target == torch.ops.aten.einsum.default:
decomposed_module = make_fx(
node.target,
tracing_mode="fake",
)(node.args[0], [arg.meta["val"] for arg in node.args[1]])

with graph.inserting_before(node):
                    # remap maps original node values to new node values so that
                    # references to nodes are correctly updated in the new graph
remap = {}
                    # Unlike other nodes, einsum stores its equation string in args[0],
                    # while the input nodes are stored in args[1]
for i, arg in enumerate(node.args[1]):
remap[f"arg1_{i+1}"] = arg

for decomposed_node in decomposed_module.graph.nodes:
                        # This is the args[0] equation string, which is no longer needed after decomposition
if "arg0" in decomposed_node.name:
continue

                        # no need to copy the existing 'output' node
if decomposed_node.op == "output":
for user in node.users.copy():
                                # rewire each user of the original einsum node to the decomposed result
user.replace_input_with(
node,
remap[decomposed_node.args[0][0]],
)
                        # no need to copy existing placeholders
elif decomposed_node.op == "placeholder":
                            # re-key the remap entry from the placeholder's name (a string) to the graph node itself
remap[decomposed_node] = remap.pop(decomposed_node.name)
else:
remap[decomposed_node] = graph.node_copy(
decomposed_node,
arg_transform=lambda x, remap=remap: remap[x],
)

graph.erase_node(node)

graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
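
A quick usage sketch for the pass above. The export call is an assumption; the pass only needs a `GraphModule` whose `aten.einsum.default` input nodes carry `meta["val"]` fake tensors, and it is a no-op if einsum was already decomposed during capture:

```python
# Minimal sketch, assuming capture preserves aten.einsum.default
# (as in the pre-autograd graph the Qualcomm flow operates on).
import torch
from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum


class BatchedMatmul(torch.nn.Module):
    def forward(self, x, y):
        return torch.einsum("bij,bjk->bik", x, y)


gm = torch.export.export(
    BatchedMatmul(), (torch.randn(2, 3, 4), torch.randn(2, 4, 5))
).module()
result = DecomposeEinsum()(gm)  # ExportPass instances are callable, returning a PassResult
# Any einsum node is now replaced by its decomposed ops (e.g. permutes/matmuls),
# each of which quantization annotation can handle individually.
```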
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/insert_requantize.py
@@ -28,6 +28,7 @@ class InsertRequantize(ExportPass):
     # we don't use the 2nd output, 2nd output is an integer, etc.
     multi_output_op_ignore_set = {
         exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
+        exir_ops.edge.aten.topk.default,
     }
 
     def __init__(
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/layout_transform.py
@@ -65,6 +65,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.sqrt.default,
         exir_ops.edge.aten.sub.Tensor,
         exir_ops.edge.aten.sum.dim_IntList,
+        exir_ops.edge.aten.topk.default,
         exir_ops.edge.aten._to_copy.default,
         exir_ops.edge.aten.split_with_sizes.default,
         *q_ops,
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
@@ -53,6 +53,7 @@
     op_sum_int_list,
     op_tanh,
     op_to,
+    op_topk,
     op_transpose,
     op_unsqueeze,
     op_upsample_bilinear2d,
@@ -107,6 +108,7 @@
     op_sub,
     op_sum_int_list,
     op_tanh,
+    op_topk,
     op_to,
     op_transpose,
     op_unsqueeze,
6 changes: 5 additions & 1 deletion backends/qualcomm/builders/op_avg_pool2d.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -85,7 +86,10 @@ def define_node(
         if len(node.args) > 6:
             divisor_override = cast(int, node.args[6])
             if divisor_override != pooling_region:
-                print("Not support divisor_override which is not equal to pooling region.")
+                warnings.warn(
+                    "[QNN Delegate Op Builder]: A divisor_override that is not equal to the pooling region is not supported.",
+                    stacklevel=1,
+                )
                 return
 
         avg_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(
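
The `print` → `warnings.warn(..., stacklevel=1)` conversion above repeats across the op builders below. As a brief aside (standard-library behavior, not code from this commit), the benefit is that callers can filter, record, or escalate these fallback messages, which bare `print` calls do not allow:

```python
# Stdlib-only sketch: warnings emitted via warnings.warn can be captured and
# inspected (or turned into errors), unlike print output.
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is not suppressed
    warnings.warn("[QNN Delegate Op Builder]: example fallback", stacklevel=1)

assert "example fallback" in str(caught[0].message)
```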
6 changes: 4 additions & 2 deletions backends/qualcomm/builders/op_cat.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -43,8 +44,9 @@ def define_node(
         )
 
         if len(list_of_tensors) != len(list_of_tensor_wrappers):
-            print(
-                "The number or input tensors is not equal to the number of input tensor wrappers."
+            warnings.warn(
+                "[QNN Delegate Op Builder]: The number of input tensors is not equal to the number of input tensor wrappers.",
+                stacklevel=1,
             )
             return
 
11 changes: 9 additions & 2 deletions backends/qualcomm/builders/op_conv2d.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -189,12 +190,18 @@ def _define_conv1d(
 
         # args[6] = transposed
         if cast(bool, node.args[6]):
-            print("Currently, No support for transposed convolution")
+            warnings.warn(
+                "[QNN Delegate Op Builder]: Transposed convolution is not currently supported.",
+                stacklevel=1,
+            )
             return
 
         # args[7] = output padding
         if not all(out_pad == 0 for out_pad in cast(List[int], node.args[7])):
-            print("QNN does not support output padding")
+            warnings.warn(
+                "[QNN Delegate Op Builder]: QNN does not support output padding.",
+                stacklevel=1,
+            )
             return
 
         stride_shape = [len(stride)]
6 changes: 4 additions & 2 deletions backends/qualcomm/builders/op_expand.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -52,8 +53,9 @@ def define_node(
         output_dims = len(output_tensor.size())
 
         if input_dims < output_dims:
-            print(
-                f"The rank of input tensor: {input_dims} is less than the rank of output tensor: {output_dims}."
+            warnings.warn(
+                f"[QNN Delegate Op Builder]: The rank of the input tensor ({input_dims}) is less than the rank of the output tensor ({output_dims}).",
+                stacklevel=1,
             )
             return
 
6 changes: 5 additions & 1 deletion backends/qualcomm/builders/op_layer_norm.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from typing import Dict
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -44,7 +45,10 @@ def define_node(
             len(normalized_shapes) != 1
             and normalized_shapes[0] != input_tensor.shape[-1]
         ):
-            print("Only supports normalization with last input dimension")
+            warnings.warn(
+                "[QNN Delegate Op Builder]: Only normalization over the last input dimension is supported.",
+                stacklevel=1,
+            )
             return
         axis = [len(input_tensor.shape) - 1]
         axis_shape = [len(axis)]
6 changes: 4 additions & 2 deletions backends/qualcomm/builders/op_linear.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from typing import Dict
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -70,8 +71,9 @@ def define_node(
 
         # TODO remove this when qnn sdk support
         if QCOM_SCALES in bias_node.meta.get(QCOM_QUANT_ATTRS, {}):
-            print(
-                f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet."
+            warnings.warn(
+                f"[QNN Delegate Op Builder]: Falling back for linear bias, {bias_node}; per-channel bias quantization is not supported yet.",
+                stacklevel=1,
             )
             bias_tensor = get_parameter(bias_node, self.edge_program)
             bias_tensor_wrapper = self.define_tensor(
11 changes: 7 additions & 4 deletions backends/qualcomm/builders/op_max_pool2d.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -42,8 +43,9 @@ def define_node(
             if user.target.__name__ == "getitem":
                 getitem_index = user.args[1]
                 if getitem_index != 0:
-                    print(
-                        f"Expected second argument of getitem node for {node.target.__name__ } to be 0, got {getitem_index}"
+                    warnings.warn(
+                        f"[QNN Delegate Op Builder]: Expected the second argument of the getitem node for {node.target.__name__} to be 0, got {getitem_index}.",
+                        stacklevel=1,
                     )
                     return
 
@@ -78,8 +80,9 @@ def define_node(
         if len(node.args) > 4:
             dilation = cast(List[int], node.args[4])
             if not (dilation == 1 or dilation == [1, 1]):
-                print(
-                    f"Not support dilation argument for max pool2d, but got {dilation}"
+                warnings.warn(
+                    f"[QNN Delegate Op Builder]: The dilation argument for max_pool2d is not supported, got {dilation}.",
+                    stacklevel=1,
                 )
                 return
 
6 changes: 5 additions & 1 deletion backends/qualcomm/builders/op_rms_norm.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from typing import Dict
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -47,7 +48,10 @@ def define_node(
             len(normalized_shapes) != 1
             and normalized_shapes[0] != input_tensor.shape[-1]
         ):
-            print("Only supports normalization with last input dimension")
+            warnings.warn(
+                "[QNN Delegate Op Builder]: Only normalization over the last input dimension is supported.",
+                stacklevel=1,
+            )
             return
         axes = [node.args[0].meta["val"].dim() - 1]
         axes_shape = [len(axes)]
107 changes: 107 additions & 0 deletions backends/qualcomm/builders/op_topk.py
@@ -0,0 +1,107 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpTopK, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class TopK(NodeVisitor):
target = ["aten.topk.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:

input_node = node.args[0]
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
is_input_tensor=True,
)

k = cast(int, node.args[1])

if len(node.args) > 2:
dim = cast(int, node.args[2])
if dim < 0:
dim = dim % len(input_tensor.shape)
if QCOM_AXIS_ORDER in node.meta:
dim = node.meta[QCOM_AXIS_ORDER].index(dim)
if dim != len(input_tensor.shape) - 1:
warnings.warn(
"[QNN Delegate Op Builder]: QNN currently only supports channel as dimension for topK.",
stacklevel=1,
)
return

topk_input_tensors = [input_tensor_wrapper]

output_val_tensor = self.get_tensor(node, node, 0)
output_idx_tensor = self.get_tensor(node, node, 1).to(torch.int32)

        # QNN constraint: topk output_0 must share the input's quantization config
node.meta["quant_attrs"] = input_node.meta.get("quant_attrs")
output_val_tensor_wrapper = self.define_tensor(
node,
output_val_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=False,
)

        # topk output_1 is the index output; do not quantize it
node.meta.pop("quant_attrs", None)
output_index_tensor_wrapper = self.define_tensor(
node,
output_idx_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=False,
wrapper_idx=1,
)
topk_output_tensors = [output_val_tensor_wrapper, output_index_tensor_wrapper]

topk_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpTopK.op_name,
)
topk_op.AddInputTensors(topk_input_tensors)
topk_op.AddOutputTensors(topk_output_tensors)

topk_op.AddScalarParam(
OpTopK.param_k,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{"data": np.uint32(k)},
)

        # As of QNN 2.26, the QNN HTP backend only allows this value to be 1; anything else fails op validation
if len(node.args) > 3:
largest = cast(bool, node.args[3])
topk_op.AddScalarParam(
OpTopK.param_largest,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
{QCOM_DATA: largest},
)

return topk_op
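
To make the dimension constraint above concrete, here is a standalone sketch (an illustration mirroring the builder's handling of negative dims, not part of the commit). QNN's TopK only operates on the innermost axis, so any other `dim` causes the builder to warn and fall back:

```python
# Standalone sketch of the dim check: only a (normalized) last-axis topk is delegable.
import torch


def qnn_topk_delegable(rank: int, dim: int) -> bool:
    dim = dim % rank  # normalize negative dims, as the builder does
    return dim == rank - 1


x = torch.randn(2, 3, 8)
assert qnn_topk_delegable(x.dim(), -1)  # last axis: lowers to QNN TopK
assert not qnn_topk_delegable(x.dim(), 0)  # non-last axis: builder warns and returns
values, indices = torch.topk(x, k=3, dim=-1)  # aten.topk.default; QNN indices are int32
```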