Qualcomm AI Engine Direct - support embedding op (#2057)
Summary:
- support embedding op with int32 index input
- make mobilebert / llama2 fully delegated
- add requantize passes for mixed precision
- bug fixes
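
For context, the newly delegated pattern is an aten.embedding lookup driven by int32 indices. A minimal sketch of such a module, with purely illustrative sizes (plain PyTorch, not backend code):

import torch

class TokenEmbedding(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # weight table: 1000 rows of 64-dim vectors (sizes are illustrative)
        self.emb = torch.nn.Embedding(num_embeddings=1000, embedding_dim=64)

    def forward(self, indices):
        # int32 indices -- the index dtype this commit adds support for
        return self.emb(indices)

indices = torch.randint(0, 1000, (1, 8), dtype=torch.int32)
out = TokenEmbedding()(indices)  # shape: (1, 8, 64)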

Pull Request resolved: #2057

Reviewed By: dbort

Differential Revision: D54348816

Pulled By: cccclai

fbshipit-source-id: ec3c8e87cc879d6f642859231255d5094d78349f
haowhsu-quic authored and facebook-github-bot committed Mar 3, 2024
1 parent 75352ad commit 57e192b
Showing 16 changed files with 279 additions and 30 deletions.
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
@@ -18,6 +18,7 @@
op_depth_to_space,
op_dequantize,
op_div,
op_embedding,
op_expand,
op_gelu,
op_hardswish,
@@ -62,6 +63,7 @@
op_depth_to_space,
op_dequantize,
op_div,
op_embedding,
op_expand,
op_gelu,
op_hardswish,
7 changes: 6 additions & 1 deletion backends/qualcomm/builders/node_visitor.py
@@ -90,8 +90,13 @@ def get_quant_encoding_conf(self, node: torch.fx.Node) -> Tuple[Any, Dict]:
{},
)

-        quant_attrs = node.meta["quant_attrs"]
+        quant_attrs = (
+            node.meta["requantize"]["dq_attrs"]
+            if "requantize" in node.meta
+            else node.meta["quant_attrs"]
+        )
encoding = quant_attrs["encoding"]

quant_config = {}
if encoding in PER_CHANNEL_ENCODING_MAPPING:
scales = quant_attrs["scales"]
74 changes: 74 additions & 0 deletions backends/qualcomm/builders/op_embedding.py
@@ -0,0 +1,74 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW
from .utils import get_parameter


@register_node_visitor
class Embedding(NodeVisitor):
target = "aten.embedding.default"

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
weight_node = node.args[0]
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
weight_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
)

indices_node = node.args[1]
indices_tensor = self.get_tensor(indices_node, node)
indices_tensor_wrapper = self.define_scalar(
indices_node,
indices_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

gather_input_tensors = [weight_tensor_wrapper, indices_tensor_wrapper]

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
gather_output_tensors = [output_tensor_wrapper]

gather_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpGather.op_name,
)
gather_op.AddInputTensors(gather_input_tensors)
gather_op.AddOutputTensors(gather_output_tensors)

# For now, the axis defaults to zero.
gather_op.AddScalarParam(
OpGather.param_axis,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
{"data": np.int32(0)},
)

return gather_op
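
A note on the lowering choice above: aten.embedding is a row gather over the weight table, so emitting QNN Gather with axis 0 preserves its semantics exactly. A quick plain-PyTorch sanity check of that equivalence (shapes are illustrative):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 4)
# recent PyTorch accepts int32 indices directly; use .long() on older versions
indices = torch.tensor([[1, 3], [5, 7]], dtype=torch.int32)
# embedding selects rows of weight along axis 0, i.e. a Gather with axis=0
assert torch.equal(F.embedding(indices, weight), weight[indices.long()])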
1 change: 0 additions & 1 deletion backends/qualcomm/partition/common_defs.py
@@ -12,7 +12,6 @@
exir_ops.edge.aten.arange.start_step,
exir_ops.edge.aten.index.Tensor,
exir_ops.edge.aten.full.default,
-    exir_ops.edge.aten.embedding.default,
]

allow_list_operator = [
6 changes: 5 additions & 1 deletion backends/qualcomm/passes/annotate_and_quant_scalar.py
@@ -98,7 +98,11 @@ def _traverse_binary_node(self, graph_module: torch.fx.GraphModule):
q_node = dq_node.args[0]
q_node_attrs = get_quant_attrs(graph_module, q_node)

-            scalar_node = [n for n in output.args if n != dq_node][0]
+            scalar_nodes = [n for n in output.args if n != dq_node]
+            if len(scalar_nodes) == 0:
+                continue
+
+            scalar_node = scalar_nodes[0]
source_scalar_node = self._get_source_scalar_node(scalar_node)
# we'll abandon cast op here, since the constant scalar will
# be pre-loaded into QNN context binary
14 changes: 11 additions & 3 deletions backends/qualcomm/passes/insert_io_qdq.py
@@ -49,8 +49,13 @@ def _insert_node(
graph_module: torch.fx.GraphModule,
node: torch.fx.node,
target: torch.fx.node.Target,
+        quant_attrs: Dict = None,
    ) -> torch.fx.node:
-        quant_attrs = node.meta.get("quant_attrs")
+        # use the caller-specified quant_attrs if given;
+        # otherwise fall back to the info already on the node
+        if quant_attrs is None:
+            quant_attrs = node.meta.get("quant_attrs")

inserted_node = graph_module.graph.create_node(
"call_function",
target,
@@ -69,13 +74,16 @@ def _insert_quant_node(
graph_module: torch.fx.GraphModule,
node: torch.fx.node,
target: torch.fx.node.Target,
-    ) -> None:
+        quant_attrs: Dict = None,
+    ) -> torch.fx.Node:
        with graph_module.graph.inserting_after(node):
            users = list(node.users.keys())
-            inserted_node = self._insert_node(graph_module, node, target)
+            inserted_node = self._insert_node(graph_module, node, target, quant_attrs)
            for user in users:
                user.replace_input_with(node, inserted_node)
+
+        return inserted_node

def _insert_dequant_node(
self,
graph_module: torch.fx.GraphModule,
57 changes: 57 additions & 0 deletions backends/qualcomm/passes/insert_requantize.py
@@ -0,0 +1,57 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.backends.qualcomm.passes.insert_io_qdq import InsertIOQDQ
from executorch.exir.dialects._ops import ops as exir_ops


class InsertRequantize(InsertIOQDQ):
"""
This pass inserts dq/q nodes for non-arithmetic operators which have
different quantization specs in input and activation
"""

def __init__(
self,
edge_program: torch.export.ExportedProgram,
insert_requantize: bool = False,
):
super().__init__(edge_program)
# add non-arithmetic operators here if the condition is met
self.op_map = {
exir_ops.edge.aten.permute_copy.default: self._single_io_annotation,
}
self.insert_requantize = insert_requantize

def _single_io_annotation(self, gm: torch.fx.GraphModule, n: torch.fx.node) -> None:
in_q_attr = n.args[0].meta.get("quant_attrs")
out_q_attr = n.meta["quant_attrs"]
if in_q_attr is not None and in_q_attr["dtype"] != out_q_attr["dtype"]:
if self.insert_requantize:
dq_attr = n.meta["requantize"]["dq_attrs"]
q_attr = n.meta["requantize"]["q_attrs"]
# insert dq with given quantization attribute in input node
dq = self._insert_quant_node(gm, n, dq_attr["encoding"], dq_attr)
dq.meta["quant_attrs"] = dq_attr
# insert q with given quantization attribute in current node
q = self._insert_quant_node(gm, dq, q_attr["encoding"], q_attr)
q.meta["quant_attrs"] = q_attr
else:
dq_attr = in_q_attr.copy()
dq_attr["encoding"] = self.q_dq_map[out_q_attr["encoding"]]
q_attr = out_q_attr.copy()
n.meta["requantize"] = {"dq_attrs": dq_attr, "q_attrs": q_attr}

def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
for n in graph_module.graph.nodes:
if (
n.op == "call_function"
and n.meta.get("quant_attrs")
and n.target in self.op_map
):
self.op_map[n.target](graph_module, n)
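
For intuition, the dq/q pair this pass materializes implements a requantization: dequantize with the producer's encoding, then quantize with the consumer's. A rough numeric sketch of that round trip (the scales, zero points, and dtypes below are illustrative, not taken from the pass):

import torch

def requantize(x_q, in_scale, in_zp, out_scale, out_zp, out_dtype):
    # dequantize with the producer's encoding ...
    x_fp = (x_q.to(torch.float32) - in_zp) * in_scale
    # ... then quantize with the consumer's encoding
    q = torch.round(x_fp / out_scale) + out_zp
    info = torch.iinfo(out_dtype)
    return q.clamp(info.min, info.max).to(out_dtype)

x_q = torch.randint(0, 256, (4,), dtype=torch.int32)  # stand-in for uint8 values
y_q = requantize(x_q, in_scale=0.02, in_zp=128, out_scale=0.002, out_zp=0, out_dtype=torch.int32)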
7 changes: 7 additions & 0 deletions backends/qualcomm/passes/layout_transform.py
@@ -13,6 +13,8 @@
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.sym_util import eval_shape

from .utils import dq_ops, q_ops


class LayoutTransform(ExportPass):
"""
@@ -50,6 +52,8 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.gelu.default,
*q_ops,
*dq_ops,
_operator.getitem,
}

@@ -77,6 +81,7 @@ def __init__(
super(LayoutTransform, self).__init__()
self.edge_program = edge_program
self.insert_permute = insert_permute
self.qdq_opset = {*q_ops, *dq_ops}

def mark_as_transformed(self, node: torch.fx.Node) -> None:
if isinstance(node.meta["val"], (tuple, list)):
@@ -108,6 +113,8 @@ def is_layout_agnostic(self, node: torch.fx.Node) -> bool:
# if dimension is not kept, we'll have no clue how to do layout transform
if len(node.args) < 3 or not node.args[2]:
return False
if node.target in self.qdq_opset:
return "requantize" in node.meta
return node.target in self.layout_agnostic_ops

def is_edge_condition(self, node):
46 changes: 46 additions & 0 deletions backends/qualcomm/passes/recompose_pixel_shuffle.py
@@ -0,0 +1,46 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class RecomposePixelShuffle(ExportPass):
"""
Merge decomposed operators back to one super node.
"""

def __init__(self):
super().__init__()

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
# decomposed core aten ops
partitions = get_source_partitions(graph, [torch.nn.PixelShuffle])
for _, src_partitions in partitions.items():
for src_partition in src_partitions:
input_node = src_partition.input_nodes[0]
output_node = src_partition.output_nodes[0]
with graph.inserting_after(input_node):
h_in_shape = input_node.meta["val"].shape[2]
h_out_shape = output_node.meta["val"].shape[2]
upscale_factor = h_out_shape / h_in_shape

pixel_shuffle_node = graph.create_node(
"call_function",
exir_ops.edge.aten.pixel_shuffle.default,
(input_node, int(upscale_factor)),
)
users = output_node.users.copy()
for user in users:
user.replace_input_with(output_node, pixel_shuffle_node)
# copy metadata
pixel_shuffle_node.meta = output_node.meta

graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
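
The recovered upscale factor comes straight from tensor shapes: pixel_shuffle maps (N, C*r^2, H, W) to (N, C, H*r, W*r), so r is the ratio of output height to input height, which is exactly what the pass reads off dim 2. For example:

import torch

x = torch.randn(1, 4 * 3**2, 8, 8)  # C=4, r=3
y = torch.nn.functional.pixel_shuffle(x, upscale_factor=3)
assert y.shape == (1, 4, 24, 24)  # r = 24 / 8 = 3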
2 changes: 2 additions & 0 deletions backends/qualcomm/qnn_preprocess.py
@@ -12,6 +12,7 @@

from executorch.backends.qualcomm.passes.convert_to_linear import ConvertToLinear
from executorch.backends.qualcomm.passes.insert_io_qdq import InsertIOQDQ
from executorch.backends.qualcomm.passes.insert_requantize import InsertRequantize
from executorch.backends.qualcomm.passes.layout_transform import LayoutTransform
from executorch.backends.qualcomm.utils.utils import generate_qnn_executorch_option
from executorch.exir.backend.backend_details import (
@@ -44,6 +45,7 @@ def preprocess(
passes=[
ConvertToLinear(),
InsertIOQDQ(edge_program),
InsertRequantize(edge_program, insert_requantize=True),
LayoutTransform(edge_program, insert_permute=True),
]
)
16 changes: 15 additions & 1 deletion backends/qualcomm/quantizer/utils.py
@@ -104,7 +104,7 @@ def annotate_single_in_single_out(
input_qspec_map[input_act] = quantization_config.input_activation

node_tensor = node.meta.get("val")
-    if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64:
+    if torch.is_tensor(node_tensor) and node_tensor.dtype != torch.float32:
return

node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
@@ -356,6 +356,20 @@ def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.embedding.default])
def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> None:
weight = node.args[0]

input_qspec_map = {}
input_qspec_map[weight] = quantization_config.input_activation

node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=SharedQuantizationSpec((weight, node)),
_annotated=True,
)


@register_annotator([torch.ops.aten.expand.default])
def annotate_expand(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_in_out_obs_sharing_op(node, quantization_config)
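
On annotate_embedding above: sharing the weight's quantization spec with the output is natural because an embedding output is a verbatim copy of rows from the weight table, so gathering commutes with dequantization and one encoding describes both. A small numeric illustration (values are made up):

import torch

scale, zp = 0.05, 128
w_q = torch.randint(0, 256, (10, 4), dtype=torch.int32)  # quantized weight rows
w_fp = (w_q.to(torch.float32) - zp) * scale              # dequantized table
idx = torch.tensor([2, 7])
# gathering in the quantized domain, then dequantizing with the weight's
# (scale, zp), reproduces the gathered float rows exactly
out_fp = (w_q[idx].to(torch.float32) - zp) * scale
assert torch.equal(out_fp, w_fp[idx])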