Skip to content

Commit

Permalink
[ET-VK] Include FuseDequantLinearPass() in vulkan_preprocess
Browse files Browse the repository at this point in the history
## Context

Include `FuseDequantLinearPass` as a part of `vulkan_preprocess`, so that fusing the quant/dequant nodes added by `VulkanQuantizer` can be done as part of the lowering process.

Differential Revision: [D64249613](https://our.internmc.facebook.com/intern/diff/D64249613/)

[ghstack-poisoned]
  • Loading branch information
SS-JIA committed Oct 11, 2024
1 parent 5696b35 commit 2bac617
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions backends/vulkan/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ runtime.python_library(
"//executorch/backends/transforms:addmm_mm_to_linear",
"//executorch/backends/transforms:fuse_batch_norm_with_conv",
"//executorch/backends/transforms:fuse_conv_with_clamp",
"//executorch/backends/transforms:fuse_dequant_linear",
"//executorch/backends/transforms:fuse_view_copy",
"//executorch/backends/transforms:mean_to_sum_div",
"//executorch/backends/transforms:remove_clone_ops",
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/partitioner/supported_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __contains__(self, op):

PRIM_OPS = [
operator.getitem,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
]

SUPPORTS_DYNAMIC_SHAPE = [
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/vulkan_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
FuseBatchNormWithConvPass,
)
from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
Expand Down Expand Up @@ -59,6 +60,7 @@ def preprocess( # noqa: C901
passes = [
RemoveCloneOpsTransform(),
AddmmToLinearTransform(),
FuseDequantLinearPass(),
FuseViewCopyTransform(),
FuseBatchNormWithConvPass(program),
FuseClampPass(),
Expand Down

0 comments on commit 2bac617

Please sign in to comment.