diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index e5ce8ec2d7..a362c7dd8f 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -396,7 +396,8 @@ def register_node_visitor(visitor): and issubclass(visitor, NodeVisitor) and hasattr(visitor, "target") ), f"Illformed NodeVisitor subclass, can't register!, got: {visitor}" - _node_visitor_dict[visitor.target] = visitor + for target in visitor.target: + _node_visitor_dict[target] = visitor def generate_node_to_external_map( diff --git a/backends/qualcomm/builders/op_add.py b/backends/qualcomm/builders/op_add.py index f151ca6698..ce61db2d6a 100644 --- a/backends/qualcomm/builders/op_add.py +++ b/backends/qualcomm/builders/op_add.py @@ -15,7 +15,7 @@ @register_node_visitor class Add(NodeVisitor): - target = "aten.add.Tensor" + target = ["aten.add.Tensor"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_avg_pool2d.py b/backends/qualcomm/builders/op_avg_pool2d.py index 38c3bd6d47..e6a6fd3f1c 100644 --- a/backends/qualcomm/builders/op_avg_pool2d.py +++ b/backends/qualcomm/builders/op_avg_pool2d.py @@ -16,7 +16,7 @@ @register_node_visitor class AvgPool2d(NodeVisitor): - target = "aten.avg_pool2d.default" + target = ["aten.avg_pool2d.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_batch_norm.py b/backends/qualcomm/builders/op_batch_norm.py index 280cc86d7b..a0efda4072 100644 --- a/backends/qualcomm/builders/op_batch_norm.py +++ b/backends/qualcomm/builders/op_batch_norm.py @@ -16,7 +16,7 @@ @register_node_visitor class BatchNorm(NodeVisitor): - target = "aten._native_batch_norm_legit_no_training.default" + target = ["aten._native_batch_norm_legit_no_training.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_bmm.py b/backends/qualcomm/builders/op_bmm.py index 4648321a6b..c207d73ad7 100644 --- a/backends/qualcomm/builders/op_bmm.py +++ b/backends/qualcomm/builders/op_bmm.py @@ -15,7 +15,7 @@ @register_node_visitor class BMM(NodeVisitor): - target = "aten.bmm.default" + target = ["aten.bmm.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_cast.py b/backends/qualcomm/builders/op_cast.py index 18666e5544..d8173126c1 100644 --- a/backends/qualcomm/builders/op_cast.py +++ b/backends/qualcomm/builders/op_cast.py @@ -15,7 +15,7 @@ @register_node_visitor class Cast(NodeVisitor): - target = "aten._to_copy.default" + target = ["aten._to_copy.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_cat.py b/backends/qualcomm/builders/op_cat.py index 4cbfd6b542..bd7e8153c7 100644 --- a/backends/qualcomm/builders/op_cat.py +++ b/backends/qualcomm/builders/op_cat.py @@ -16,7 +16,7 @@ @register_node_visitor class Cat(NodeVisitor): - target = "aten.cat.default" + target = ["aten.cat.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_ceil.py b/backends/qualcomm/builders/op_ceil.py index 00ce561440..c486669fc8 100644 --- a/backends/qualcomm/builders/op_ceil.py +++ b/backends/qualcomm/builders/op_ceil.py @@ -15,7 +15,7 @@ @register_node_visitor class Ceil(NodeVisitor): - target = "aten.ceil.default" + target = ["aten.ceil.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_clamp.py b/backends/qualcomm/builders/op_clamp.py index 9417f726d5..24f2e01964 100644 --- a/backends/qualcomm/builders/op_clamp.py +++ b/backends/qualcomm/builders/op_clamp.py @@ -16,7 +16,7 @@ @register_node_visitor class Clamp(NodeVisitor): - target = "aten.clamp.default" + target = ["aten.clamp.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py index f899e98efd..5c20b1372e 100644 --- a/backends/qualcomm/builders/op_conv2d.py +++ b/backends/qualcomm/builders/op_conv2d.py @@ -24,7 +24,7 @@ @register_node_visitor class Conv2d(NodeVisitor): - target = "aten.convolution.default" + target = ["aten.convolution.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_depth_to_space.py b/backends/qualcomm/builders/op_depth_to_space.py index 41e141cbfa..8624b6eb07 100644 --- a/backends/qualcomm/builders/op_depth_to_space.py +++ b/backends/qualcomm/builders/op_depth_to_space.py @@ -17,7 +17,7 @@ @register_node_visitor class DepthToSpaceVisitor(NodeVisitor): - target = "aten.pixel_shuffle.default" + target = ["aten.pixel_shuffle.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_dequantize.py b/backends/qualcomm/builders/op_dequantize.py index 0574a4e2e2..56eb609575 100644 --- a/backends/qualcomm/builders/op_dequantize.py +++ b/backends/qualcomm/builders/op_dequantize.py @@ -55,19 +55,19 @@ def define_node( @register_node_visitor class PerTensorDequantizeDefault(DequantizeOpBase): - target = "quantized_decomposed.dequantize_per_tensor.default" + target = ["quantized_decomposed.dequantize_per_tensor.default"] @register_node_visitor class PerTensorDequantizeTensor(DequantizeOpBase): - target = "quantized_decomposed.dequantize_per_tensor.tensor" + target = ["quantized_decomposed.dequantize_per_tensor.tensor"] @register_node_visitor class PerChannelDequantizeDefault(DequantizeOpBase): - target = "quantized_decomposed.dequantize_per_channel.default" + target = ["quantized_decomposed.dequantize_per_channel.default"] @register_node_visitor class PerChannelDequantizeTensor(DequantizeOpBase): - target = "quantized_decomposed.dequantize_per_channel.tensor" + target = ["quantized_decomposed.dequantize_per_channel.tensor"] diff --git a/backends/qualcomm/builders/op_div.py b/backends/qualcomm/builders/op_div.py index 4f0157bbdf..6b4e674349 100644 --- a/backends/qualcomm/builders/op_div.py +++ b/backends/qualcomm/builders/op_div.py @@ -15,7 +15,7 @@ @register_node_visitor class Div(NodeVisitor): - target = "aten.div.Tensor" + target = ["aten.div.Tensor"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_embedding.py b/backends/qualcomm/builders/op_embedding.py index 60d3a3906c..faf33eac12 100644 --- a/backends/qualcomm/builders/op_embedding.py +++ b/backends/qualcomm/builders/op_embedding.py @@ -17,7 +17,7 @@ @register_node_visitor class Embedding(NodeVisitor): - target = "aten.embedding.default" + target = ["aten.embedding.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_expand.py b/backends/qualcomm/builders/op_expand.py index afef5e2269..a1ef1c2949 100644 --- a/backends/qualcomm/builders/op_expand.py +++ b/backends/qualcomm/builders/op_expand.py @@ -16,7 +16,7 @@ @register_node_visitor class Expand(NodeVisitor): - target = "aten.expand_copy.default" + target = ["aten.expand_copy.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_gelu.py b/backends/qualcomm/builders/op_gelu.py index 7dd627ce58..c488d6b5d8 100644 --- a/backends/qualcomm/builders/op_gelu.py +++ b/backends/qualcomm/builders/op_gelu.py @@ -16,7 +16,7 @@ @register_node_visitor class GeluVisitor(NodeVisitor): - target = "aten.gelu.default" + target = ["aten.gelu.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_hardswish.py b/backends/qualcomm/builders/op_hardswish.py index 940bfc7d42..c7ad702ae6 100644 --- a/backends/qualcomm/builders/op_hardswish.py +++ b/backends/qualcomm/builders/op_hardswish.py @@ -16,7 +16,7 @@ @register_node_visitor class HardSwishVisitor(NodeVisitor): - target = "aten.hardswish.default" + target = ["aten.hardswish.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_hardtanh.py b/backends/qualcomm/builders/op_hardtanh.py index 0f16d006da..d7d322cbdc 100644 --- a/backends/qualcomm/builders/op_hardtanh.py +++ b/backends/qualcomm/builders/op_hardtanh.py @@ -17,7 +17,7 @@ @register_node_visitor class HardTanhVisitor(NodeVisitor): - target = "aten.hardtanh.default" + target = ["aten.hardtanh.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py index 53a30e434f..1f4b47672d 100644 --- a/backends/qualcomm/builders/op_layer_norm.py +++ b/backends/qualcomm/builders/op_layer_norm.py @@ -18,7 +18,7 @@ @register_node_visitor class LayerNormVisitor(NodeVisitor): - target = "aten.native_layer_norm.default" + target = ["aten.native_layer_norm.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py index 907bda3d81..1e75df7b1a 100644 --- a/backends/qualcomm/builders/op_linear.py +++ b/backends/qualcomm/builders/op_linear.py @@ -17,7 +17,7 @@ @register_node_visitor class LinearVisitor(NodeVisitor): - target = "aten.linear.default" + target = ["aten.linear.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_log_softmax.py b/backends/qualcomm/builders/op_log_softmax.py index a8259a5ca9..de01a14f89 100644 --- a/backends/qualcomm/builders/op_log_softmax.py +++ b/backends/qualcomm/builders/op_log_softmax.py @@ -16,7 +16,7 @@ @register_node_visitor class LogSoftmax(NodeVisitor): - target = "aten._log_softmax.default" + target = ["aten._log_softmax.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_matmul.py b/backends/qualcomm/builders/op_matmul.py index 68540949b7..9a94cb1d60 100644 --- a/backends/qualcomm/builders/op_matmul.py +++ b/backends/qualcomm/builders/op_matmul.py @@ -15,7 +15,7 @@ @register_node_visitor class Matmul(NodeVisitor): - target = "aten.matmul.default" + target = ["aten.matmul.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_max_pool2d.py b/backends/qualcomm/builders/op_max_pool2d.py index 9b8076ba22..f64a13faed 100644 --- a/backends/qualcomm/builders/op_max_pool2d.py +++ b/backends/qualcomm/builders/op_max_pool2d.py @@ -16,7 +16,7 @@ @register_node_visitor class MaxPool2d(NodeVisitor): - target = "aten.max_pool2d_with_indices.default" + target = ["aten.max_pool2d_with_indices.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_mean_dim.py b/backends/qualcomm/builders/op_mean_dim.py index 4d151eb9f2..29d9f8b30f 100644 --- a/backends/qualcomm/builders/op_mean_dim.py +++ b/backends/qualcomm/builders/op_mean_dim.py @@ -17,7 +17,7 @@ @register_node_visitor class MeanDim(NodeVisitor): - target = "aten.mean.dim" + target = ["aten.mean.dim"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_mul.py b/backends/qualcomm/builders/op_mul.py index 943891a993..645910b7d9 100644 --- a/backends/qualcomm/builders/op_mul.py +++ b/backends/qualcomm/builders/op_mul.py @@ -15,7 +15,7 @@ @register_node_visitor class Mul(NodeVisitor): - target = "aten.mul.Tensor" + target = ["aten.mul.Tensor"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_pad.py b/backends/qualcomm/builders/op_pad.py index bf3bbbcab8..677cf77bd2 100644 --- a/backends/qualcomm/builders/op_pad.py +++ b/backends/qualcomm/builders/op_pad.py @@ -16,7 +16,7 @@ @register_node_visitor class Pad(NodeVisitor): - target = "aten.constant_pad_nd.default" + target = ["aten.constant_pad_nd.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_pow.py b/backends/qualcomm/builders/op_pow.py index 14f4edd9f5..cae2e68161 100644 --- a/backends/qualcomm/builders/op_pow.py +++ b/backends/qualcomm/builders/op_pow.py @@ -17,7 +17,7 @@ # TODO Add more class Like PowTensorTensor if needed @register_node_visitor class PowTensorScalar(NodeVisitor): - target = "aten.pow.Tensor_Scalar" + target = ["aten.pow.Tensor_Scalar"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_quantize.py b/backends/qualcomm/builders/op_quantize.py index e1d491cadb..b74ca7fb6d 100644 --- a/backends/qualcomm/builders/op_quantize.py +++ b/backends/qualcomm/builders/op_quantize.py @@ -61,9 +61,9 @@ def define_node( @register_node_visitor class PerTensorQuantize(QuantizeOpBase): - target = "quantized_decomposed.quantize_per_tensor.default" + target = ["quantized_decomposed.quantize_per_tensor.default"] @register_node_visitor class PerChannelQuantize(QuantizeOpBase): - target = "quantized_decomposed.quantize_per_channel.default" + target = ["quantized_decomposed.quantize_per_channel.default"] diff --git a/backends/qualcomm/builders/op_relu.py b/backends/qualcomm/builders/op_relu.py index 52cbd410ee..d6c5ff79bc 100644 --- a/backends/qualcomm/builders/op_relu.py +++ b/backends/qualcomm/builders/op_relu.py @@ -15,7 +15,7 @@ @register_node_visitor class Relu(NodeVisitor): - target = "aten.relu.default" + target = ["aten.relu.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_reshape.py b/backends/qualcomm/builders/op_reshape.py index 96278b0f80..23eb1ff59b 100644 --- a/backends/qualcomm/builders/op_reshape.py +++ b/backends/qualcomm/builders/op_reshape.py @@ -15,7 +15,7 @@ @register_node_visitor class Reshape(NodeVisitor): - target = "aten.view_copy.default" + target = ["aten.view_copy.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_rsqrt.py b/backends/qualcomm/builders/op_rsqrt.py index 5976cab67f..cf3e8c5e38 100644 --- a/backends/qualcomm/builders/op_rsqrt.py +++ b/backends/qualcomm/builders/op_rsqrt.py @@ -15,7 +15,7 @@ @register_node_visitor class Rsqrt(NodeVisitor): - target = "aten.rsqrt.default" + target = ["aten.rsqrt.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_select_copy.py b/backends/qualcomm/builders/op_select_copy.py index ea53467521..5d74d038f7 100644 --- a/backends/qualcomm/builders/op_select_copy.py +++ b/backends/qualcomm/builders/op_select_copy.py @@ -17,7 +17,7 @@ @register_node_visitor class SelectCopy(NodeVisitor): - target = "aten.select_copy.int" + target = ["aten.select_copy.int", "aten.select.int"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_sigmoid.py b/backends/qualcomm/builders/op_sigmoid.py index b6eeb88935..3b7dd2abe2 100644 --- a/backends/qualcomm/builders/op_sigmoid.py +++ b/backends/qualcomm/builders/op_sigmoid.py @@ -15,7 +15,7 @@ @register_node_visitor class Sigmoid(NodeVisitor): - target = "aten.sigmoid.default" + target = ["aten.sigmoid.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_skip_ops.py b/backends/qualcomm/builders/op_skip_ops.py index f91ef70d44..9a1839f604 100644 --- a/backends/qualcomm/builders/op_skip_ops.py +++ b/backends/qualcomm/builders/op_skip_ops.py @@ -35,7 +35,7 @@ class OpGetItem(OpSkipOps): do nothing if node is getitem """ - target = "getitem" + target = ["getitem"] def define_node( self, diff --git a/backends/qualcomm/builders/op_slice_copy.py b/backends/qualcomm/builders/op_slice_copy.py index 5ed2b99cc0..6d121135e4 100644 --- a/backends/qualcomm/builders/op_slice_copy.py +++ b/backends/qualcomm/builders/op_slice_copy.py @@ -16,7 +16,7 @@ @register_node_visitor class StrideSlice(NodeVisitor): - target = "aten.slice_copy.Tensor" + target = ["aten.slice_copy.Tensor"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_softmax.py b/backends/qualcomm/builders/op_softmax.py index 2a1abce3d5..031c0244f3 100644 --- a/backends/qualcomm/builders/op_softmax.py +++ b/backends/qualcomm/builders/op_softmax.py @@ -16,7 +16,7 @@ @register_node_visitor class Softmax(NodeVisitor): - target = "aten._softmax.default" + target = ["aten._softmax.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_squeeze.py b/backends/qualcomm/builders/op_squeeze.py index 43ef39fab7..b13643783c 100644 --- a/backends/qualcomm/builders/op_squeeze.py +++ b/backends/qualcomm/builders/op_squeeze.py @@ -15,7 +15,7 @@ @register_node_visitor class Squeeze(NodeVisitor): - target = "aten.squeeze_copy.dims" + target = ["aten.squeeze_copy.dims", "aten.squeeze.dims"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_sub.py b/backends/qualcomm/builders/op_sub.py index 212e7a75cd..131fecd4cf 100644 --- a/backends/qualcomm/builders/op_sub.py +++ b/backends/qualcomm/builders/op_sub.py @@ -15,7 +15,7 @@ @register_node_visitor class Sub(NodeVisitor): - target = "aten.sub.Tensor" + target = ["aten.sub.Tensor"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_tanh.py b/backends/qualcomm/builders/op_tanh.py index cff4f7e447..af37256046 100644 --- a/backends/qualcomm/builders/op_tanh.py +++ b/backends/qualcomm/builders/op_tanh.py @@ -16,7 +16,7 @@ @register_node_visitor class Tanh(NodeVisitor): - target = "aten.tanh.default" + target = ["aten.tanh.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_transpose.py b/backends/qualcomm/builders/op_transpose.py index 161e8cef9d..7dc9352673 100644 --- a/backends/qualcomm/builders/op_transpose.py +++ b/backends/qualcomm/builders/op_transpose.py @@ -17,7 +17,7 @@ @register_node_visitor class TransposeVisitor(NodeVisitor): - target = "aten.permute_copy.default" + target = ["aten.permute_copy.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_unsqueeze.py b/backends/qualcomm/builders/op_unsqueeze.py index 636dc94e84..1a94903291 100644 --- a/backends/qualcomm/builders/op_unsqueeze.py +++ b/backends/qualcomm/builders/op_unsqueeze.py @@ -15,7 +15,7 @@ @register_node_visitor class Unsqueeze(NodeVisitor): - target = "aten.unsqueeze_copy.default" + target = ["aten.unsqueeze_copy.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/op_upsample_bilinear2d.py b/backends/qualcomm/builders/op_upsample_bilinear2d.py index f32f136aa1..b383693ead 100644 --- a/backends/qualcomm/builders/op_upsample_bilinear2d.py +++ b/backends/qualcomm/builders/op_upsample_bilinear2d.py @@ -15,7 +15,7 @@ @register_node_visitor class ResizeBilinear(NodeVisitor): - target = "aten.upsample_bilinear2d.default" + target = ["aten.upsample_bilinear2d.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index ee7d6a7a3b..585711c378 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -201,7 +201,8 @@ def lower_module_and_test_output( # Assert the backend name is qnn self.assertEqual( - len(exec_prog.program.execution_plan[0].delegates), expected_partitions + len(exec_prog.program.execution_plan[0].delegates), + expected_partitions, ) for i in range(expected_partitions): self.assertEqual( diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 7fa696efba..b5c5d4dfed 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Tuple +from typing import Callable, Dict, List, Tuple import executorch.exir as exir @@ -19,8 +19,6 @@ ConvertBinaryOpsWithScalar, ) from executorch.backends.qualcomm.passes.convert_bmm_to_matmul import ConvertBmmToMatmul -from executorch.backends.qualcomm.passes.convert_hardsigmoid import ConvertHardsigmoid -from executorch.backends.qualcomm.passes.convert_hardswish import ConvertHardswish from executorch.backends.qualcomm.passes.convert_interpolate_with_upsample2d import ( ConvertInterpolateWithUpsample2D, ) @@ -29,9 +27,6 @@ from executorch.backends.qualcomm.passes.i64_to_i32 import I64toI32 from executorch.backends.qualcomm.passes.insert_requantize import InsertRequantize from executorch.backends.qualcomm.passes.layout_transform import LayoutTransform -from executorch.backends.qualcomm.passes.recompose_pixel_shuffle import ( - RecomposePixelShuffle, -) from executorch.backends.qualcomm.passes.remove_clone import RemoveClone from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( _soc_info_table, @@ -49,7 +44,9 @@ convert_to_flatbuffer, convert_to_option, ) +from executorch.exir import ExirExportedProgram from executorch.exir.backend.compile_spec_schema import CompileSpec +from torch._decomp import core_aten_decompositions as torch_core_aten_decompositions from torch.export.exported_program import ExportedProgram from torch.fx import passes @@ -86,16 +83,27 @@ def canonicalize_program(prog: ExportedProgram): ) +def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]: + source_decompositions = torch_core_aten_decompositions() + # The below super ops are supported by QNN + remove_decompositions = [ + torch.ops.aten.pixel_shuffle.default, + torch.ops.aten.hardswish.default, + ] + + for key in remove_decompositions: + source_decompositions.pop(key) + + return source_decompositions + + def _transform(edge_program: ExportedProgram) -> None: # currently ExirExportedProgram.transform does not accept # changes of input number which was caused by FoldQDQ # apply passes one by one here to avoid IR capture failure graph_module = edge_program.graph_module RemoveClone()(graph_module) - RecomposePixelShuffle()(graph_module) ConvertToLinear()(graph_module) - ConvertHardsigmoid()(graph_module) - ConvertHardswish()(graph_module) ConvertBmmToMatmul()(graph_module) ConvertInterpolateWithUpsample2D()(graph_module) I64toI32(edge_program)(graph_module) @@ -111,19 +119,18 @@ def capture_program( module: torch.nn.Module, inputs: Tuple[torch.Tensor], ) -> exir.ExirExportedProgram: - # TODO: should switch to torch.export.export & custom deomposition - # to reduce maintaining effort. - exir_exported_program = exir.capture( - module, - inputs, - qnn_capture_config(), - ) + ep = torch.export.export(module, inputs) + decomposed_ep = ep.run_decompositions(get_decomp_table()) + # We choose call_operator by target in ConvertBinaryOpsWithScalar # because it is the same source_fn_stack for MultiheadAttention - exir_exported_program.transform(ConvertBinaryOpsWithScalar()) - ex_prog = exir_exported_program.to_edge(qnn_edge_config()) - _transform(ex_prog.exported_program) - return ex_prog + # TODO: Should modify the scalar op in the op builder instead of + # using transformation + core_ep = ExirExportedProgram(decomposed_ep, False) + core_ep.transform(ConvertBinaryOpsWithScalar()) + edge_ep = core_ep.to_edge(qnn_edge_config()) + _transform(edge_ep.exported_program) + return edge_ep def draw_graph(title, path, graph_module: torch.fx.GraphModule):