-
Notifications
You must be signed in to change notification settings - Fork 312
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Qualcomm AI Engine Direct - support embedding op (#2057)
Summary: - support embedding op with int32 index input - make mobilebert / llama2 be fully delegated - add requantize passes for mixed precision - bug fixes Pull Request resolved: #2057 Reviewed By: dbort Differential Revision: D54348816 Pulled By: cccclai fbshipit-source-id: ec3c8e87cc879d6f642859231255d5094d78349f
- Loading branch information
1 parent
75352ad
commit 57e192b
Showing
16 changed files
with
279 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Copyright (c) Qualcomm Innovation Center, Inc. | ||
# All rights reserved | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
from typing import Dict | ||
|
||
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from .node_visitor import NodeVisitor, register_node_visitor | ||
from .qnn_constants import OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW | ||
from .utils import get_parameter | ||
|
||
|
||
@register_node_visitor
class Embedding(NodeVisitor):
    """Lower ``aten.embedding.default`` onto the QNN backend as a Gather op.

    An embedding lookup is exactly a gather along axis 0 of the weight
    table, so the node is emitted as a single QNN Gather whose inputs are
    the (static) embedding table and the index tensor.
    """

    target = "aten.embedding.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build and return the PyQnnOpWrapper for this embedding node.

        Args:
            node: the fx node being lowered; ``args[0]`` is the embedding
                weight parameter, ``args[1]`` the index tensor.
            nodes_to_wrappers: shared cache mapping fx nodes to their
                already-defined QNN tensor wrappers.
        """
        # The embedding table is a parameter of the exported program and
        # becomes a STATIC (constant) tensor input to Gather.
        weight = node.args[0]
        weight_wrapper = self.define_tensor(
            weight,
            get_parameter(weight, self.edge_program),
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
        )

        # The indices are a runtime activation, hence a NATIVE tensor.
        indices = node.args[1]
        indices_wrapper = self.define_scalar(
            indices,
            self.get_tensor(indices, node),
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        output_wrapper = self.define_tensor(
            node,
            self.get_tensor(node, node),
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        gather = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpGather.op_name,
        )
        gather.AddInputTensors([weight_wrapper, indices_wrapper])
        gather.AddOutputTensors([output_wrapper])

        # For now, default axis is zero.
        gather.AddScalarParam(
            OpGather.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
            {"data": np.int32(0)},
        )

        return gather
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright (c) Qualcomm Innovation Center, Inc. | ||
# All rights reserved | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
|
||
from executorch.backends.qualcomm.passes.insert_io_qdq import InsertIOQDQ | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
|
||
|
||
class InsertRequantize(InsertIOQDQ):
    """
    This pass inserts dq/q nodes for non-arithmetic operators which have
    different quantization specs in input and activation
    """

    def __init__(
        self,
        edge_program: torch.export.ExportedProgram,
        insert_requantize: bool = False,
    ):
        # insert_requantize selects the pass's mode:
        #   False — only record the requantize attributes on matching nodes
        #   True  — actually insert the dq/q pair, using attributes recorded
        #           by an earlier run of this pass in recording mode.
        super().__init__(edge_program)
        # add non-arithmetic operators here if condition met
        self.op_map = {
            exir_ops.edge.aten.permute_copy.default: self._single_io_annotation,
        }
        self.insert_requantize = insert_requantize

    def _single_io_annotation(self, gm: torch.fx.GraphModule, n: torch.fx.Node) -> None:
        # Handler for single-input/single-output ops: requantization is only
        # needed when the producer's quantization dtype differs from this
        # node's own quantization dtype.
        in_q_attr = n.args[0].meta.get("quant_attrs")
        out_q_attr = n.meta["quant_attrs"]
        if in_q_attr is not None and in_q_attr["dtype"] != out_q_attr["dtype"]:
            if self.insert_requantize:
                # Attributes were stored under n.meta["requantize"] by a
                # previous recording-mode run (the else branch below).
                dq_attr = n.meta["requantize"]["dq_attrs"]
                q_attr = n.meta["requantize"]["q_attrs"]
                # insert dq with given quantization attribute in input node
                dq = self._insert_quant_node(gm, n, dq_attr["encoding"], dq_attr)
                dq.meta["quant_attrs"] = dq_attr
                # insert q with given quantization attribute in current node
                q = self._insert_quant_node(gm, dq, q_attr["encoding"], q_attr)
                q.meta["quant_attrs"] = q_attr
            else:
                # Recording mode: the dq takes the input's quant spec with its
                # encoding mapped through q_dq_map (inherited from InsertIOQDQ —
                # presumably maps a q op to its matching dq op; confirm there),
                # the q takes this node's output spec.
                dq_attr = in_q_attr.copy()
                dq_attr["encoding"] = self.q_dq_map[out_q_attr["encoding"]]
                q_attr = out_q_attr.copy()
                n.meta["requantize"] = {"dq_attrs": dq_attr, "q_attrs": q_attr}

    # NOTE: returns None (the graph module is mutated in place); the original
    # annotation claimed a GraphModule return that was never produced.
    def _insert(self, graph_module: torch.fx.GraphModule) -> None:
        # Dispatch every quantized call_function node whose target has a
        # registered handler in op_map.
        for n in graph_module.graph.nodes:
            if (
                n.op == "call_function"
                and n.meta.get("quant_attrs")
                and n.target in self.op_map
            ):
                self.op_map[n.target](graph_module, n)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Copyright (c) Qualcomm Innovation Center, Inc. | ||
# All rights reserved | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
import torch | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
from executorch.exir.pass_base import ExportPass, PassResult | ||
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions | ||
|
||
|
||
class RecomposePixelShuffle(ExportPass):
    """
    Merge decomposed operators back to one super node.
    """

    def __init__(self):
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule):
        """Collapse each decomposed ``torch.nn.PixelShuffle`` partition into a
        single edge-dialect ``pixel_shuffle`` node and return a PassResult."""
        graph = graph_module.graph
        # decomposed core aten ops
        for matches in get_source_partitions(graph, [torch.nn.PixelShuffle]).values():
            for partition in matches:
                in_node = partition.input_nodes[0]
                out_node = partition.output_nodes[0]
                with graph.inserting_after(in_node):
                    # Recover the upscale factor from the change in the
                    # height dimension (dim 2) across the partition.
                    h_before = in_node.meta["val"].shape[2]
                    h_after = out_node.meta["val"].shape[2]
                    fused = graph.create_node(
                        "call_function",
                        exir_ops.edge.aten.pixel_shuffle.default,
                        (in_node, int(h_after / h_before)),
                    )
                # Reroute every consumer of the partition's output to the
                # fused node, then carry the metadata over.
                for user in out_node.users.copy():
                    user.replace_input_with(out_node, fused)
                # copy metadata
                fused.meta = out_node.meta

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.