add flag for dynamic shapes to filter out static ops (#3733)
Summary:
Pull Request resolved: #3733

Since we have updated XNNPACK to support dynamic shapes for almost 100% of its ops, we can now invert the old allowlist approach: instead of enumerating the few ops with dynamic shape support, we maintain static-op lists (ops with no dynamic shape support) and filter those out.
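
For callers, opting in is a one-line change; a minimal usage sketch (mirroring the new test added in this diff):

    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    # With the flag set, ops in STATIC_OPS are removed from the partitioner's
    # supported-op lists, so they are left out of the XNNPACK delegate.
    partitioner = XnnpackPartitioner(has_dynamic_shapes=True)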

Reviewed By: digantdesai, kirklandsign

Differential Revision: D57787384

fbshipit-source-id: 21c9a2491c1b0a8ac46459ba70fec645025e9a2d
mcr229 authored and facebook-github-bot committed May 29, 2024
1 parent 2eaed1b commit 55d11e1
Showing 3 changed files with 41 additions and 46 deletions.
18 changes: 8 additions & 10 deletions backends/xnnpack/partition/configs.py
@@ -144,14 +144,12 @@

 SUPPORTED_DYN_QUANT_MODULES = SUPPORTED_DYN_QUANT_LINEAR_MODULES

-# TODO delete this once we catch up to 100% of the supported ops with dynamic shape support.
-# This is to be used only during the transition when we may not want to partition all the
-# nodes for a dynamic model.
-_SUPPORTED_OPS_WITH_DYNAMIC_SHAPE = [
-    exir_ops.edge.aten.add.Tensor,
-    exir_ops.edge.aten.mul.Tensor,
-]
-_SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE = [
-    torch.nn.Conv1d,
-    torch.nn.Conv2d,
+# XNNPACK supports the majority of shape dynamism; however, some ops are
+# explicitly static, so we maintain a set here to exclude them from
+# dynamic shape support.
+STATIC_OPS = [
+    exir_ops.edge.aten.cat.default,
+    exir_ops.edge.aten.slice_copy.Tensor,
 ]
+
+STATIC_MODULES = []
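
The entries above are edge-dialect op overloads. A hedged sketch of how such a denylist can be consulted against graph nodes (the helper below is illustrative, not the partitioner's actual code path):

    import torch
    from executorch.exir.dialects._ops import ops as exir_ops

    STATIC_OPS = [
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.slice_copy.Tensor,
    ]

    def is_static_only(node: torch.fx.Node) -> bool:
        # Illustrative check: a call_function node targeting a STATIC_OPS entry
        # should not be delegated when the model uses dynamic shapes.
        return node.op == "call_function" and node.target in STATIC_OPS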
46 changes: 11 additions & 35 deletions backends/xnnpack/partition/xnnpack_partitioner.py
@@ -12,8 +12,8 @@
 import torch

 from executorch.backends.xnnpack.partition.configs import (
-    _SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE,
-    _SUPPORTED_OPS_WITH_DYNAMIC_SHAPE,
+    STATIC_MODULES,
+    STATIC_OPS,
     SUPPORTED_DYN_QUANT_LINEAR_MODULES,
     SUPPORTED_DYN_QUANT_MODULES,
     SUPPORTED_MODULES,
@@ -838,7 +838,7 @@ def __init__(
         supported_quant_modules: List[Callable] = SUPPORTED_QUANT_MODULES,
         supported_quant_ops: Optional[List[Callable]] = SUPPORTED_QUANT_OPS,
         quant: Optional[bool] = None,
-        _only_ops_with_dynamic_shape_support: Optional[bool] = False,
+        has_dynamic_shapes: bool = False,
         _lower_recomposed_sdpa: Optional[bool] = True,
     ):
         super().__init__()
@@ -851,44 +851,20 @@

         self.quant = quant

-        if _only_ops_with_dynamic_shape_support is True:
-            self._update_op_lists_for_dynamic_shapes()
-
         # TODO(T174256335) - remove this once we have a better way to handle >2d Mask
         self._lower_recomposed_sdpa: bool = _lower_recomposed_sdpa or True

         self.delegation_spec = DelegationSpec(XnnpackBackend.__name__, [])
         self.partition_tags: Dict[str, DelegationSpec] = {}

-    def _update_op_lists_for_dynamic_shapes(self):
-        # Not ready for quants yet
-        assert (
-            self.quant is not True
-        ), "Dynamic shape only supported for valid FP32 ops, no quants support yet."
-        self.supported_quant_ops = set()
-        self.supported_quant_modules = set()
-
-        # for supported ops
-        self.supported_ops_with_dynamic_shape = set(_SUPPORTED_OPS_WITH_DYNAMIC_SHAPE)
-        assert self.supported_ops_with_dynamic_shape.issubset(
-            self.supported_ops
-        ), "All ops with dynamic shape support must be in SUPPORTED_OPS"
-        self.supported_ops = self.supported_ops_with_dynamic_shape
-        log.info(
-            f"Xnnpack Partitioner updated supported op for dynamic shapes: {self.supported_ops}"
-        )
-
-        # for supported modules
-        self.supported_modules_with_dynamic_shape = set(
-            _SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE
-        )
-        assert self.supported_modules_with_dynamic_shape.issubset(
-            self.supported_modules
-        ), "All modules with dynamic shape support must be in SUPPORTED_MODULES"
-        self.supported_modules = self.supported_modules_with_dynamic_shape
-        log.info(
-            f"Xnnpack Partitioner updated supported modules with dynamic shapes: {self.supported_modules}"
-        )
+        self.has_dynamic_shapes = has_dynamic_shapes
+        if has_dynamic_shapes:
+            self.supported_ops = self.supported_ops - set(STATIC_OPS)
+            self.supported_modules = self.supported_modules - set(STATIC_MODULES)
+            self.supported_quant_ops = self.supported_quant_ops - set(STATIC_OPS)
+            self.supported_quant_modules = self.supported_quant_modules - set(
+                STATIC_MODULES
+            )

     def get_supported_modules(self, quant: bool) -> Set[Callable]:
         """
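To make the filtering concrete, a small self-contained sketch of the set-difference behavior above, using stand-in strings rather than real edge-dialect op targets:

    SUPPORTED_OPS = {"add", "mul", "cat", "slice_copy"}  # stand-in allowlist
    STATIC_OPS = {"cat", "slice_copy"}  # stand-ins for ops lacking dynamic-shape support

    def effective_ops(has_dynamic_shapes: bool) -> set:
        # Mirrors the logic in XnnpackPartitioner.__init__: with dynamic shapes
        # enabled, static-only ops are dropped; otherwise the allowlist is kept.
        ops = set(SUPPORTED_OPS)
        if has_dynamic_shapes:
            ops -= STATIC_OPS
        return ops

    assert effective_ops(False) == {"add", "mul", "cat", "slice_copy"}
    assert effective_ops(True) == {"add", "mul"}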
23 changes: 22 additions & 1 deletion backends/xnnpack/test/ops/slice_copy.py
@@ -7,7 +7,8 @@
 import unittest

 import torch
-from executorch.backends.xnnpack.test.tester import Tester
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.backends.xnnpack.test.tester import Partition, Tester


 class TestSliceCopy(unittest.TestCase):
@@ -112,6 +113,26 @@ def forward(self, x):
             .check_not(["torch.ops.higher_order.executorch_call_delegate"])
         )

+    def test_fp32_static_slice_with_dynamic_dim(self):
+        """
+        XNNPACK does not support dynamic dims with static slice
+        """
+
+        class SliceCopy(torch.nn.Module):
+            def forward(self, x):
+                return x[1:3, -2:, :-1]
+
+        inputs = (torch.randn(5, 5, 5),)
+        (
+            Tester(SliceCopy(), inputs)
+            .export()
+            .to_edge()
+            .partition(
+                Partition(partitioner=XnnpackPartitioner(has_dynamic_shapes=True))
+            )
+            .check_not(["torch.ops.higher_order.executorch_call_delegate"])
+        )
+
     # Note: Slice ends up as slice_copy later in the process, but during quantization,
     # it's still slice, which isn't supported by the XNNPACK quantizer.
     @unittest.skip("T156004676 - slice isn't propagated")
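For reference, "dynamic dims" means dimensions marked dynamic at export time. A hedged sketch in plain torch.export, independent of the Tester harness (whose dynamic-shape plumbing is not shown in this diff; the dim name and bounds are arbitrary):

    import torch
    from torch.export import Dim, export

    class SliceCopy(torch.nn.Module):
        def forward(self, x):
            return x[1:3, -2:, :-1]

    x = torch.randn(5, 5, 5)
    # Mark dim 0 as dynamic; the static slice x[1:3, ...] then operates on a
    # dimension whose size is only known at runtime.
    ep = export(SliceCopy(), (x,), dynamic_shapes={"x": {0: Dim("batch", min=4, max=8)}})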
