From be349962f3362f8afde4f083ec04d335245992bb Mon Sep 17 00:00:00 2001
From: jianan-gu
Date: Mon, 6 Nov 2023 13:00:04 +0800
Subject: [PATCH] Cherry pick #2235 to rls2.1 (#2236)

* cherry pick #2235

* Update linear_fusion.py

* Update linear_fusion.py

* Update linear_fusion.py
---
 .../models/cpu/fusions/linear_fusion.py      | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py b/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py
index 806833257..38094f095 100644
--- a/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py
+++ b/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py
@@ -27,7 +27,7 @@ def __init__(self, module, tpp=False, woq=False):
         self.linear = module
 
     def forward(self, x):
-        if self.tpp:
+        if self.tpp and not self.linear.tpp_fallback:
             x = x.to(self.dtype).contiguous()
             return torch.ops.torch_ipex.tpp_linear_silu(
                 x,
@@ -45,7 +45,7 @@ def __init__(self, module, tpp=False, woq=False):
         self.linear = module
 
     def forward(self, x):
-        if self.tpp:
+        if self.tpp and not self.linear.tpp_fallback:
             x = x.to(self.dtype).contiguous()
             return torch.ops.torch_ipex.tpp_linear_relu(
                 x,
@@ -63,7 +63,7 @@ def __init__(self, module, tpp=False, woq=False):
         self.linear = module
 
     def forward(self, x, y):
-        if self.tpp:
+        if self.tpp and not self.linear.tpp_fallback:
             x = x.to(self.dtype).contiguous()
             y = y.to(self.dtype).contiguous()
             return torch.ops.torch_ipex.tpp_linear_mul(
@@ -83,7 +83,7 @@ def __init__(self, module, tpp=False, woq=False):
         self.linear = module
 
     def forward(self, x, y):
-        if self.tpp:
+        if self.tpp and not self.linear.tpp_fallback:
             x = x.to(self.dtype).contiguous()
             y = y.to(self.dtype).contiguous()
             return torch.ops.torch_ipex.tpp_linear_add(
@@ -114,7 +114,7 @@ def __init__(self, module, tpp=False, woq=False):
         self.linear = module
 
     def forward(self, x, y, z):
-        if self.tpp:
+        if self.tpp and not self.linear.tpp_fallback:
             x = x.to(self.dtype).contiguous()
             y = y.to(self.dtype).contiguous()
             z = z.to(self.dtype).contiguous()
@@ -147,7 +147,7 @@ def __init__(self, module, tpp=False, woq=False):
         self.linear = module
 
     def forward(self, x):
-        if self.tpp:
+        if self.tpp and not self.linear.tpp_fallback:
             x = x.to(self.dtype).contiguous()
             return torch.ops.torch_ipex.tpp_linear_gelu(
                 x,
@@ -186,7 +186,7 @@ def __init__(self, module, tpp=False, woq=False):
         self.gelu = nn.GELU()
 
     def forward(self, x):
-        if self.tpp:
+        if self.tpp and not self.linear.tpp_fallback:
             x = x.to(self.dtype).contiguous()
             return torch.ops.torch_ipex.tpp_linear_gelu(
                 x,
@@ -320,7 +320,11 @@ def __init__(self, module_s, module_m, tpp=False, woq=False):
         self.dtype = module_s.weight.dtype if self.tpp else None
 
     def forward(self, x):
-        if self.tpp:
+        if (
+            self.tpp
+            and not self.linear_s.tpp_fallback
+            and not self.linear_m.tpp_fallback
+        ):
             x = x.to(self.dtype).contiguous()
             x1 = torch.ops.torch_ipex.tpp_linear_silu(
                 x,
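
Reviewer note: the patch guards every fused TPP dispatch on the wrapped linear's tpp_fallback flag, so a layer that the TPP kernels cannot handle takes its regular unfused path instead of calling torch.ops.torch_ipex.tpp_linear_* unconditionally. Below is a minimal sketch of that guard pattern, not the IPEX implementation; LinearSiluGuarded and _fused_linear_silu are hypothetical stand-ins, and only the shape of the condition mirrors the patch.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    def _fused_linear_silu(x, weight, bias):
        # Hypothetical stand-in for a fused linear+SiLU kernel
        # (torch.ops.torch_ipex.tpp_linear_silu in the real file).
        return F.silu(F.linear(x, weight, bias))


    class LinearSiluGuarded(nn.Module):
        """Sketch of the guard the patch adds around each fused TPP dispatch."""

        def __init__(self, module, tpp=False):
            super().__init__()
            self.tpp = tpp
            self.linear = module
            self.dtype = module.weight.dtype if tpp else None

        def forward(self, x):
            # Fused path only when TPP is enabled AND the wrapped linear has not
            # fallen back (tpp_fallback is presumably set when the TPP kernel
            # cannot be used for this layer); a missing flag defaults to fallback.
            if self.tpp and not getattr(self.linear, "tpp_fallback", True):
                x = x.to(self.dtype).contiguous()
                return _fused_linear_silu(x, self.linear.weight, self.linear.bias)
            # Otherwise take the plain eager path.
            return F.silu(self.linear(x))

For the fused two-projection variant in the last hunk, the patch requires both linear_s.tpp_fallback and linear_m.tpp_fallback to be clear before taking the fused branch, since a fallback on either projection invalidates the fused kernel call.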