[AO][Inductor] Enable WOQ fusion pattern with permute (pytorch#135928)
**Summary**
Fixes pytorch#135831 and pytorch/ao#890. The root cause of the numerical failure was that the customized WOQ-int8 kernel was no longer triggered: after the pattern change the weight reaches the matmul through a permute, which none of the existing fusion patterns matched. After adding a fusion pattern for this case, the accuracy check passes. I will open a separate TorchAO PR to enable these unit tests in TorchAO.
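
For reference, here is a minimal standalone sketch (not part of this PR; shapes and tolerances are illustrative, and it assumes a recent PyTorch CPU build where `aten._weight_int8pack_mm` is available) of the computation the WOQ-int8 fusion targets and the customized kernel it should lower to:

```python
import torch

x = torch.randn(1, 1, 256, dtype=torch.bfloat16)
w = torch.randint(-128, 127, (12, 256), dtype=torch.int8)
s = torch.randn(12, dtype=torch.bfloat16)

# Eager reference: dequantize the int8 weight via a dtype cast, then scale.
ref = torch.nn.functional.linear(x, w.to(x.dtype)) * s

# The customized WOQ-int8 kernel the fusion should dispatch to.
fused = torch.ops.aten._weight_int8pack_mm(
    x.reshape(-1, x.shape[-1]), w, s
).reshape(*x.shape[:-1], -1)

# Loose tolerances to absorb bf16 rounding; illustrative only.
torch.testing.assert_close(ref, fused, atol=5e-2, rtol=5e-2)
```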

**Test Plan**
```
python test/inductor/test_mkldnn_pattern_matcher.py -k test_woq_int8
```

Pull Request resolved: pytorch#135928
Approved by: https://github.com/jgong5, https://github.com/eellison
leslie-fang-intel authored and Chao1Han committed Sep 20, 2024
1 parent 4d8f1e3 commit 5127c4b
Showing 3 changed files with 62 additions and 9 deletions.
25 changes: 21 additions & 4 deletions test/inductor/test_mkldnn_pattern_matcher.py
```diff
@@ -2585,19 +2585,36 @@ def forward(self, x):
     @skipIfNoDynamoSupport
     def test_woq_int8(self):
         class M(torch.nn.Module):
+            def __init__(self, is_permute):
+                super().__init__()
+                self.is_permute = is_permute
+
             def forward(self, x, weight, scales):
-                return torch.nn.functional.linear(x, weight.to(dtype=x.dtype)) * scales
+                if self.is_permute:
+                    weight = weight.t()
+                    m = torch.mm(
+                        x.reshape(-1, x.shape[-1]),
+                        weight.to(x.dtype),
+                    )
+                    y = m * scales.to(m.dtype)
+                    y = y.reshape(*x.shape[:-1], y.shape[-1])
+                    return y
+                else:
+                    return (
+                        torch.nn.functional.linear(x, weight.to(dtype=x.dtype)) * scales
+                    )
 
-        mod = M().eval()
         x_shape = (1, 1, 256)
-        w_shape = (12, 256)
         s_shape = 12
         x_strides = [
             (256, 256, 1),  # linear dispatching to mm
             (256, 32, 1),  # linear dispatching to bmm
         ]
-        for x_stride in x_strides:
+        is_permutes = [False, True]
+        for x_stride, is_permute in itertools.product(x_strides, is_permutes):
+            mod = M(is_permute=is_permute).eval()
             x = torch.randn(x_shape, dtype=torch.bfloat16).as_strided(x_shape, x_stride)
+            w_shape = (12, 256)
             w = torch.randint(-128, 127, w_shape, dtype=torch.int8)
             s = torch.randn(s_shape, dtype=torch.bfloat16)
 
```
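The `is_permute=True` branch is the case this PR fixes: the weight enters the matmul transposed, so the traced graph routes the int8 constant through `aten.permute` before the dtype cast. A quick numerics check of that variant under `torch.compile` (a sketch, assuming a recent PyTorch build; the helper name is made up):

```python
import torch

def woq_permute(x, weight, scales):
    weight = weight.t()  # int8 weight enters the matmul transposed
    m = torch.mm(x.reshape(-1, x.shape[-1]), weight.to(x.dtype))
    y = m * scales.to(m.dtype)
    return y.reshape(*x.shape[:-1], y.shape[-1])

x = torch.randn(1, 1, 256, dtype=torch.bfloat16)
w = torch.randint(-128, 127, (12, 256), dtype=torch.int8)
s = torch.randn(12, dtype=torch.bfloat16)

compiled = torch.compile(woq_permute)
torch.testing.assert_close(
    compiled(x, w, s), woq_permute(x, w, s), atol=5e-2, rtol=5e-2
)
```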
24 changes: 19 additions & 5 deletions torch/_inductor/constant_folding.py
```diff
@@ -86,13 +86,27 @@ def _deduce_value(self, node: torch.fx.Node) -> Any:
         return super().run_node(node)
 
     def is_impure(self, node: torch.fx.node.Node) -> bool:
+        def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
+            return (
+                node.target == torch.ops.prims.convert_element_type.default  # type: ignore[return-value]
+                and isinstance(node.args[0], torch.fx.Node)
+                and "val" in node.args[0].meta
+                and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
+                and node.args[1] == torch.bfloat16
+            )
+
         if (
-            node.target == torch.ops.prims.convert_element_type.default
-            and is_const_source(node.args[0], self.lifted_constants)  # type: ignore[arg-type]
-            and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
-            and node.args[1] == torch.bfloat16
+            is_woq_int8_pattern(node)
+            or (
+                node.target == torch.ops.aten.permute.default
+                and len(node.users) == 1
+                and is_woq_int8_pattern(next(iter(node.users)))
+            )
+        ) and is_const_source(
+            node.args[0], self.lifted_constants  # type: ignore[arg-type]
         ):
-            # For int8_weight -> dq -> bf16_weight
+            # Case 1: int8_weight -> dq -> bf16_weight
+            # Case 2: int8_weight -> permute -> dq -> bf16_weight
             return True
 
         quant_registered = (
```
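The `is_impure` change exists because constant folding would otherwise evaluate the `int8_weight -> permute -> dq` chain at compile time, baking a permuted bf16 constant into the graph. A small sketch (illustrative, not from the PR) of what that folding would cost:

```python
import torch

w_int8 = torch.randint(-128, 127, (12, 256), dtype=torch.int8)

# What the folder would precompute for Case 2 if the chain were treated as pure:
folded = w_int8.t().to(torch.bfloat16)

# The folded constant doubles the weight bytes and, more importantly, the
# int8 tensor that aten._weight_int8pack_mm consumes disappears from the
# graph, so the WOQ fusion pattern can no longer match.
assert folded.element_size() == 2 * w_int8.element_size()
```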
22 changes: 22 additions & 0 deletions torch/_inductor/fx_passes/quantization.py
```diff
@@ -1561,6 +1561,27 @@ def _register_woq_mm_int8_pattern3():
     _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
 
 
+def _register_woq_mm_int8_pattern4():
+    _woq_pattern = CallFunction(
+        aten.mul.Tensor,
+        CallFunction(
+            aten.mm.default,
+            KeywordArg("x"),
+            CallFunction(
+                prims.convert_element_type.default,
+                CallFunction(
+                    aten.permute.default,
+                    KeywordArg("weight"),
+                    Arg(),
+                ),
+                Arg(),
+            ),
+        ),
+        KeywordArg("scales"),
+    )
+    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
+
+
 def _register_quantization_lowerings():
     _register_quantization_unary_fusion()
     _register_quantization_binary_fusion()
@@ -1573,6 +1594,7 @@ def _register_woq_lowerings():
     _register_woq_mm_int8_pattern1()
     _register_woq_mm_int8_pattern2()
     _register_woq_mm_int8_pattern3()
+    _register_woq_mm_int8_pattern4()
 
 
 def _is_valid_dequant_promotion_pattern(dtype=torch.float32):
```
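Read inside-out, pattern 4 matches the aten-level expression below, written here as an eager-mode sketch (the two `Arg()` slots stand for the permute dims and the cast's target dtype; shapes are illustrative):

```python
import torch

x = torch.randn(4, 256, dtype=torch.bfloat16)
w = torch.randint(-128, 127, (12, 256), dtype=torch.int8)
s = torch.randn(12, dtype=torch.bfloat16)

# mul(mm(x, convert_element_type(permute(weight, dims), dtype)), scales)
out = torch.mul(
    torch.mm(
        x,
        torch.ops.prims.convert_element_type(w.permute([1, 0]), torch.bfloat16),
    ),
    s,
)

# On a match, the registered lowering rewrites this subgraph into
# aten._weight_int8pack_mm(x, weight, scales), reshaping where needed.
```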
