Skip to content

Commit

Permalink
Qualcomm AI Engine Direct - Adapt to new IR capture flow (#2431)
Browse files Browse the repository at this point in the history
Summary:
- Change existent IR capture flow (exir.capture) to torch.export.export
- Add custom decomposition table for mitigating maintaining effort
- Fix breakages encountered and make sure all tests passed as well

Pull Request resolved: #2431

Reviewed By: mergennachin

Differential Revision: D55353449

Pulled By: cccclai

fbshipit-source-id: aa2e27d0ae93aa62208fd03ec39b3891a70b954e
  • Loading branch information
chuntl authored and facebook-github-bot committed Mar 28, 2024
1 parent 3cf9f22 commit e7a429a
Show file tree
Hide file tree
Showing 44 changed files with 76 additions and 67 deletions.
3 changes: 2 additions & 1 deletion backends/qualcomm/builders/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,8 @@ def register_node_visitor(visitor):
and issubclass(visitor, NodeVisitor)
and hasattr(visitor, "target")
), f"Illformed NodeVisitor subclass, can't register!, got: {visitor}"
_node_visitor_dict[visitor.target] = visitor
for target in visitor.target:
_node_visitor_dict[target] = visitor


def generate_node_to_external_map(
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_node_visitor
class Add(NodeVisitor):
target = "aten.add.Tensor"
target = ["aten.add.Tensor"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@register_node_visitor
class AvgPool2d(NodeVisitor):
target = "aten.avg_pool2d.default"
target = ["aten.avg_pool2d.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@register_node_visitor
class BatchNorm(NodeVisitor):
target = "aten._native_batch_norm_legit_no_training.default"
target = ["aten._native_batch_norm_legit_no_training.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_node_visitor
class BMM(NodeVisitor):
target = "aten.bmm.default"
target = ["aten.bmm.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_node_visitor
class Cast(NodeVisitor):
target = "aten._to_copy.default"
target = ["aten._to_copy.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@register_node_visitor
class Cat(NodeVisitor):
target = "aten.cat.default"
target = ["aten.cat.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_ceil.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_node_visitor
class Ceil(NodeVisitor):
target = "aten.ceil.default"
target = ["aten.ceil.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@register_node_visitor
class Clamp(NodeVisitor):
target = "aten.clamp.default"
target = ["aten.clamp.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

@register_node_visitor
class Conv2d(NodeVisitor):
target = "aten.convolution.default"
target = ["aten.convolution.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_depth_to_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@register_node_visitor
class DepthToSpaceVisitor(NodeVisitor):
target = "aten.pixel_shuffle.default"
target = ["aten.pixel_shuffle.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
8 changes: 4 additions & 4 deletions backends/qualcomm/builders/op_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,19 @@ def define_node(

@register_node_visitor
class PerTensorDequantizeDefault(DequantizeOpBase):
target = "quantized_decomposed.dequantize_per_tensor.default"
target = ["quantized_decomposed.dequantize_per_tensor.default"]


@register_node_visitor
class PerTensorDequantizeTensor(DequantizeOpBase):
target = "quantized_decomposed.dequantize_per_tensor.tensor"
target = ["quantized_decomposed.dequantize_per_tensor.tensor"]


@register_node_visitor
class PerChannelDequantizeDefault(DequantizeOpBase):
target = "quantized_decomposed.dequantize_per_channel.default"
target = ["quantized_decomposed.dequantize_per_channel.default"]


@register_node_visitor
class PerChannelDequantizeTensor(DequantizeOpBase):
target = "quantized_decomposed.dequantize_per_channel.tensor"
target = ["quantized_decomposed.dequantize_per_channel.tensor"]
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_div.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_node_visitor
class Div(NodeVisitor):
target = "aten.div.Tensor"
target = ["aten.div.Tensor"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@register_node_visitor
class Embedding(NodeVisitor):
target = "aten.embedding.default"
target = ["aten.embedding.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@register_node_visitor
class Expand(NodeVisitor):
target = "aten.expand_copy.default"
target = ["aten.expand_copy.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@register_node_visitor
class GeluVisitor(NodeVisitor):
target = "aten.gelu.default"
target = ["aten.gelu.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_hardswish.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@register_node_visitor
class HardSwishVisitor(NodeVisitor):
target = "aten.hardswish.default"
target = ["aten.hardswish.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_hardtanh.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@register_node_visitor
class HardTanhVisitor(NodeVisitor):
target = "aten.hardtanh.default"
target = ["aten.hardtanh.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

@register_node_visitor
class LayerNormVisitor(NodeVisitor):
target = "aten.native_layer_norm.default"
target = ["aten.native_layer_norm.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@register_node_visitor
class LinearVisitor(NodeVisitor):
target = "aten.linear.default"
target = ["aten.linear.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_log_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@register_node_visitor
class LogSoftmax(NodeVisitor):
target = "aten._log_softmax.default"
target = ["aten._log_softmax.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_node_visitor
class Matmul(NodeVisitor):
target = "aten.matmul.default"
target = ["aten.matmul.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_max_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@register_node_visitor
class MaxPool2d(NodeVisitor):
target = "aten.max_pool2d_with_indices.default"
target = ["aten.max_pool2d_with_indices.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_mean_dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@register_node_visitor
class MeanDim(NodeVisitor):
target = "aten.mean.dim"
target = ["aten.mean.dim"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_node_visitor
class Mul(NodeVisitor):
target = "aten.mul.Tensor"
target = ["aten.mul.Tensor"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@register_node_visitor
class Pad(NodeVisitor):
target = "aten.constant_pad_nd.default"
target = ["aten.constant_pad_nd.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_pow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# TODO Add more class Like PowTensorTensor if needed
@register_node_visitor
class PowTensorScalar(NodeVisitor):
target = "aten.pow.Tensor_Scalar"
target = ["aten.pow.Tensor_Scalar"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
4 changes: 2 additions & 2 deletions backends/qualcomm/builders/op_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def define_node(

@register_node_visitor
class PerTensorQuantize(QuantizeOpBase):
target = "quantized_decomposed.quantize_per_tensor.default"
target = ["quantized_decomposed.quantize_per_tensor.default"]


@register_node_visitor
class PerChannelQuantize(QuantizeOpBase):
target = "quantized_decomposed.quantize_per_channel.default"
target = ["quantized_decomposed.quantize_per_channel.default"]
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_node_visitor
class Relu(NodeVisitor):
target = "aten.relu.default"
target = ["aten.relu.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_node_visitor
class Reshape(NodeVisitor):
target = "aten.view_copy.default"
target = ["aten.view_copy.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_rsqrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_node_visitor
class Rsqrt(NodeVisitor):
target = "aten.rsqrt.default"
target = ["aten.rsqrt.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_select_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@register_node_visitor
class SelectCopy(NodeVisitor):
target = "aten.select_copy.int"
target = ["aten.select_copy.int", "aten.select.int"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_sigmoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_node_visitor
class Sigmoid(NodeVisitor):
target = "aten.sigmoid.default"
target = ["aten.sigmoid.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_skip_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class OpGetItem(OpSkipOps):
do nothing if node is getitem
"""

target = "getitem"
target = ["getitem"]

def define_node(
self,
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_slice_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@register_node_visitor
class StrideSlice(NodeVisitor):
target = "aten.slice_copy.Tensor"
target = ["aten.slice_copy.Tensor"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@register_node_visitor
class Softmax(NodeVisitor):
target = "aten._softmax.default"
target = ["aten._softmax.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_squeeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_node_visitor
class Squeeze(NodeVisitor):
target = "aten.squeeze_copy.dims"
target = ["aten.squeeze_copy.dims", "aten.squeeze.dims"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_node_visitor
class Sub(NodeVisitor):
target = "aten.sub.Tensor"
target = ["aten.sub.Tensor"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_tanh.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@register_node_visitor
class Tanh(NodeVisitor):
target = "aten.tanh.default"
target = ["aten.tanh.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@register_node_visitor
class TransposeVisitor(NodeVisitor):
target = "aten.permute_copy.default"
target = ["aten.permute_copy.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_unsqueeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_node_visitor
class Unsqueeze(NodeVisitor):
target = "aten.unsqueeze_copy.default"
target = ["aten.unsqueeze_copy.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_upsample_bilinear2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@register_node_visitor
class ResizeBilinear(NodeVisitor):
target = "aten.upsample_bilinear2d.default"
target = ["aten.upsample_bilinear2d.default"]

def __init__(self, *args) -> None:
super().__init__(*args)
Expand Down
Loading

0 comments on commit e7a429a

Please sign in to comment.