diff --git a/intel_extension_for_pytorch/csrc/jit/cpu/passes/folding_common_utils.h b/intel_extension_for_pytorch/csrc/jit/cpu/passes/folding_common_utils.h
new file mode 100644
index 000000000..88bea7972
--- /dev/null
+++ b/intel_extension_for_pytorch/csrc/jit/cpu/passes/folding_common_utils.h
@@ -0,0 +1,63 @@
+#pragma once
+
+#include
+
+namespace torch {
+namespace jit {
+
+inline bool nonConstantParameters(Node* n) {
+  // Checks if the parameters, not including the
+  // first param are all constants.
+  for (size_t i = 1; i < n->inputs().size(); i++) {
+    if (n->inputs().at(i)->node()->kind() != prim::Constant) {
+      return true;
+    }
+  }
+  return false;
+}
+
+inline bool supportedAddOrSub(Node* n) {
+  if (n->kind() == aten::add || n->kind() == aten::sub) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+inline bool supportedMulOrDiv(Node* n) {
+  if (n->kind() == aten::mul || n->kind() == aten::div) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
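+// Broadcasts the constant operand of a folded binary op to `shape` so it can
+// be combined with the conv/linear weight or bias: a scalar constant is
+// materialized as a filled tensor, a single-element tensor is expanded, and a
+// tensor whose numel matches the product of `shape` is viewed to that shape.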
+inline at::Tensor resizeConstantScalarOrTensorToShape(
+    Value* v,
+    const std::vector<int64_t>& shape,
+    at::TensorOptions options) {
+  at::Tensor ret_tensor;
+  if (v->type()->cast<TensorType>()) {
+    ret_tensor = constant_as<at::Tensor>(v).value();
+  } else {
+    ret_tensor = at::zeros(shape, options);
+    if (v->type()->cast<IntType>()) {
+      ret_tensor.fill_(constant_as<int64_t>(v).value());
+    } else {
+      ret_tensor.fill_(constant_as<double>(v).value());
+    }
+  }
+
+  if (ret_tensor.numel() == 1) {
+    // expand errors if the shape input has less # dims than the tensor input
+    ret_tensor = ret_tensor.reshape({1});
+    ret_tensor = ret_tensor.expand(shape);
+  } else {
+    TORCH_INTERNAL_ASSERT(ret_tensor.numel() == c10::multiply_integers(shape));
+    ret_tensor = ret_tensor.view(shape);
+  }
+  return ret_tensor;
+}
+
+} // namespace jit
+} // namespace torch
\ No newline at end of file
diff --git a/intel_extension_for_pytorch/csrc/jit/cpu/passes/frozen_conv_folding.cpp b/intel_extension_for_pytorch/csrc/jit/cpu/passes/frozen_conv_folding.cpp
index a861639fc..c26ac3c0f 100644
--- a/intel_extension_for_pytorch/csrc/jit/cpu/passes/frozen_conv_folding.cpp
+++ b/intel_extension_for_pytorch/csrc/jit/cpu/passes/frozen_conv_folding.cpp
@@ -12,6 +12,7 @@
 #include

 #include "csrc/aten/cpu/WeightPack.h"
+#include "folding_common_utils.h"
 #include "frozen_conv_folding.h"

 namespace torch {
@@ -21,17 +22,6 @@ namespace {

 using Tensor = at::Tensor;

-bool nonConstantParameters(Node* n) {
-  // Checks if the parameters, not including the
-  // first param are all constants.
-  for (size_t i = 1; i < n->inputs().size(); i++) {
-    if (n->inputs().at(i)->node()->kind() != prim::Constant) {
-      return true;
-    }
-  }
-  return false;
-}
-
 bool supportedConvNode(Node* n) {
   if (n->kind() == aten::conv2d || n->kind() == aten::conv3d ||
       n->kind() == Symbol::fromQualString("torch_ipex::convolution_forward")) {
@@ -41,14 +31,6 @@ bool supportedConvNode(Node* n) {
   }
 }

-bool supportedAddOrSub(Node* n) {
-  if (n->kind() == aten::add || n->kind() == aten::sub) {
-    return true;
-  } else {
-    return false;
-  }
-}
-
 // In order to fuse add/sub/mul/div with conv, the dimensions of its
 // constant tensor must satisfy the following:
 // - with resizing, broadcast to w/ weight/bias tensor shape
@@ -137,33 +119,6 @@ bool checkConvAndBroadcastingOpPreConditions(Node* conv, Node* op) {
   return true;
 }

-Tensor resizeConstantScalarOrTensorToShape(
-    Value* v,
-    const std::vector<int64_t>& shape,
-    at::TensorOptions options) {
-  Tensor ret_tensor;
-  if (v->type()->cast<TensorType>()) {
-    ret_tensor = constant_as<Tensor>(v).value();
-  } else {
-    ret_tensor = at::zeros(shape, options);
-    if (v->type()->cast<IntType>()) {
-      ret_tensor.fill_(constant_as<int64_t>(v).value());
-    } else {
-      ret_tensor.fill_(constant_as<double>(v).value());
-    }
-  }
-
-  if (ret_tensor.numel() == 1) {
-    // expand errors if the shape input has less # dims than the tensor input
-    ret_tensor = ret_tensor.reshape({1});
-    ret_tensor = ret_tensor.expand(shape);
-  } else {
-    TORCH_INTERNAL_ASSERT(ret_tensor.numel() == c10::multiply_integers(shape));
-    ret_tensor = ret_tensor.view(shape);
-  }
-  return ret_tensor;
-}
-
 void FoldFrozenConvAddOrSub(Block* b) {
   for (Node* n : b->nodes()) {
     for (Block* block : n->blocks()) {
       FoldFrozenConvAddOrSub(block);
     }
@@ -224,14 +179,6 @@ void FoldFrozenConvAddOrSub(Block* b) {
   }
 }

-bool supportedMulOrDiv(Node* n) {
-  if (n->kind() == aten::mul || n->kind() == aten::div) {
-    return true;
-  } else {
-    return false;
-  }
-}
-
 void FoldFrozenConvMulOrDiv(Block* b) {
   for (Node* n : b->nodes()) {
     for (Block* block : n->blocks()) {
diff --git a/intel_extension_for_pytorch/csrc/jit/cpu/passes/frozen_linear_folding.cpp b/intel_extension_for_pytorch/csrc/jit/cpu/passes/frozen_linear_folding.cpp
new file mode 100644
index 000000000..c5dd635ba
--- /dev/null
+++ b/intel_extension_for_pytorch/csrc/jit/cpu/passes/frozen_linear_folding.cpp
@@ -0,0 +1,251 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "csrc/aten/cpu/WeightPack.h"
+#include "folding_common_utils.h"
+#include "frozen_linear_folding.h"
+
+namespace torch {
+namespace jit {
+
+namespace {
+
+using Tensor = at::Tensor;
+
+bool supportedLinearNode(Node* n) {
+  if (n->kind() == aten::linear ||
+      n->kind() == Symbol::fromQualString("torch_ipex::ipex_linear")) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+bool checkLinearAndBroadcastingOpPreConditions(Node* linear, Node* op) {
+  if (nonConstantParameters(linear) || nonConstantParameters(op)) {
+    return false;
+  }
+
+  if (linear->output()->uses().size() > 1) {
+    return false;
+  }
+
+  Tensor weight_tensor =
+      constant_as<Tensor>(linear->namedInput("weight")).value();
+
+  // avoid fusing op that causes type promotion
+  // restricting to float avoids int/float difficulties with scalar overload
+  if (!weight_tensor.is_floating_point()) {
+    return false;
+  }
+
+  if (op->inputs().at(1)->type()->cast<TensorType>()) {
+    auto op_tensor = constant_as<Tensor>(op->inputs().at(1)).value();
+
+    int64_t output_channel;
+    if (linear->kind() == aten::linear) {
+      output_channel =
+          constant_as<Tensor>(linear->namedInput("weight")).value().size(0);
+    } else {
+      output_channel =
+          constant_as<int64_t>(linear->namedInput("out_features")).value();
+    }
+    if (op_tensor.sizes() != at::IntArrayRef({1, output_channel}) &&
+        op_tensor.sizes() != at::IntArrayRef({output_channel})) {
+      return false;
+    }
+
+    if (!op_tensor.is_floating_point() &&
+        c10::promoteTypes(
+            op_tensor.scalar_type(), weight_tensor.scalar_type()) !=
+            weight_tensor.scalar_type()) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
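+// Rewrites linear(x, W, b) followed by add/sub with a constant c into a
+// single linear whose bias is b + c (or b - c): the constant is resized to
+// the bias shape, the new bias is computed at compile time via
+// runNodeIfInputsAreConstant, and the add/sub output is redirected to the
+// linear output so DCE can remove the binary node.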
+void FoldFrozenLinearAddOrSub(Block* b) {
+  for (Node* n : b->nodes()) {
+    for (Block* block : n->blocks()) {
+      FoldFrozenLinearAddOrSub(block);
+    }
+
+    if (supportedAddOrSub(n) &&
+        supportedLinearNode(n->inputs().at(0)->node())) {
+      auto linear = n->inputs().at(0)->node();
+      auto add_or_sub = n;
+
+      if (!checkLinearAndBroadcastingOpPreConditions(linear, add_or_sub)) {
+        continue;
+      }
+
+      Tensor weight_tensor =
+          constant_as<Tensor>(linear->namedInput("weight")).value();
+
+      Tensor add_or_sub_tensor;
+      if (linear->kind() == aten::linear) {
+        add_or_sub_tensor = resizeConstantScalarOrTensorToShape(
+            add_or_sub->inputs().at(1),
+            {weight_tensor.size(0)},
+            weight_tensor.options());
+      } else {
+        add_or_sub_tensor = resizeConstantScalarOrTensorToShape(
+            add_or_sub->inputs().at(1),
+            {constant_as<int64_t>(linear->namedInput("out_features")).value()},
+            weight_tensor.options());
+      }
+      Tensor bias;
+      if (linear->namedInput("bias")->type() == NoneType::get()) {
+        bias = at::zeros_like(add_or_sub_tensor, weight_tensor.dtype());
+      } else {
+        bias = constant_as<Tensor>(linear->namedInput("bias")).value();
+      }
+
+      WithInsertPoint guard(linear);
+
+      add_or_sub->replaceInputWith(
+          linear->output(), b->owningGraph()->insertConstant(bias));
+      add_or_sub->replaceInput(
+          1, b->owningGraph()->insertConstant(add_or_sub_tensor));
+
+      auto stack_out = runNodeIfInputsAreConstant(add_or_sub);
+      TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
+      Tensor fuse_bias = (*stack_out)[0].toTensor().to(bias.dtype());
+
+      auto fused_linear_b = b->owningGraph()->insertConstant(fuse_bias);
+      auto linear_b_value = linear->namedInput("bias");
+
+      fused_linear_b->setDebugName(
+          linear_b_value->debugName() + "_fused_" +
+          add_or_sub->kind().toUnqualString());
+      linear->replaceInputWith(linear_b_value, fused_linear_b);
+      add_or_sub->output()->replaceAllUsesWith(linear->output());
+      // DCE run after cleans up nodes
+    }
+  }
+}
+
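+// Rewrites linear(x, W, b) followed by mul/div with a constant c into a
+// single linear whose weight and bias have c multiplied (or divided) in.
+// The constant must broadcast over the output-channel dimension; for
+// torch_ipex::ipex_linear the packed weight is unpacked, folded, and
+// repacked.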
+void FoldFrozenLinearMulOrDiv(Block* b) {
+  for (Node* n : b->nodes()) {
+    for (Block* block : n->blocks()) {
+      FoldFrozenLinearMulOrDiv(block);
+    }
+
+    if (supportedMulOrDiv(n) &&
+        supportedLinearNode(n->inputs().at(0)->node())) {
+      auto linear = n->inputs().at(0)->node();
+      auto mul_or_div = n;
+
+      if (!checkLinearAndBroadcastingOpPreConditions(linear, mul_or_div)) {
+        continue;
+      }
+
+      Tensor weight_tensor;
+      if (linear->kind() == aten::linear) {
+        weight_tensor =
+            constant_as<Tensor>(linear->namedInput("weight")).value();
+      } else {
+        weight_tensor = torch_ipex::cpu::linear_weight_unpack(
+            constant_as<Tensor>(linear->namedInput("weight")).value(),
+            constant_as<int64_t>(linear->inputs().at(2)).value(),
+            constant_as<int64_t>(linear->inputs().at(3)).value(),
+            false,
+            c10::nullopt);
+      }
+
+      int64_t out_channels = weight_tensor.size(0);
+
+      // We've already verified that the second input has numel == 1 or
+      // channels-out; resize it to a shape that will broadcast to
+      // weight_tensor when the op is run so we don't change the weight size
+      std::vector<int64_t> weight_compatible_size = {out_channels};
+      for (const auto i : c10::irange(1, weight_tensor.ndimension())) {
+        (void)i; // Suppress unused variable warning
+        weight_compatible_size.push_back(1);
+      }
+
+      WithInsertPoint guard(linear);
+
+      Tensor mul_tensor = resizeConstantScalarOrTensorToShape(
+          mul_or_div->inputs().at(1),
+          weight_compatible_size,
+          weight_tensor.options());
+
+      // First fold with weight tensor
+      mul_or_div->replaceInputWith(
+          linear->output(), b->owningGraph()->insertConstant(weight_tensor));
+      mul_or_div->replaceInput(1, b->owningGraph()->insertConstant(mul_tensor));
+
+      auto stack_out = runNodeIfInputsAreConstant(mul_or_div);
+      TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
+
+      Tensor fuse_weight;
+      if (linear->kind() == aten::linear) {
+        fuse_weight = (*stack_out)[0].toTensor().to(weight_tensor.dtype());
+      } else {
+        fuse_weight = torch_ipex::cpu::linear_weight_pack(
+            (*stack_out)[0].toTensor().to(weight_tensor.dtype()), c10::nullopt);
+      }
+
+      auto fused_linear_weight = b->owningGraph()->insertConstant(fuse_weight);
+      auto linear_weight_value = linear->namedInput("weight");
+
+      fused_linear_weight->setDebugName(
+          linear_weight_value->debugName() + "_fused_" +
+          mul_or_div->kind().toUnqualString());
+      linear->replaceInputWith(linear_weight_value, fused_linear_weight);
+      mul_or_div->output()->replaceAllUsesWith(linear->output());
+
+      // now fold with bias tensor
+      if (linear->namedInput("bias")->type() != NoneType::get()) {
+        Tensor bias = constant_as<Tensor>(linear->namedInput("bias")).value();
+        // bias is of shape {channels_out}
+        auto mul_tensor = resizeConstantScalarOrTensorToShape(
+            mul_or_div->inputs().at(1), {out_channels}, bias.options());
+
+        mul_or_div->replaceInput(0, b->owningGraph()->insertConstant(bias));
+        mul_or_div->replaceInput(
+            1, b->owningGraph()->insertConstant(mul_tensor));
+
+        auto stack_out = runNodeIfInputsAreConstant(mul_or_div);
+        TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
+        Tensor fuse_bias = (*stack_out)[0].toTensor().to(bias.dtype());
+
+        auto fused_linear_bias = b->owningGraph()->insertConstant(fuse_bias);
+        auto linear_b_value = linear->namedInput("bias");
+
+        fused_linear_bias->setDebugName(
+            linear_b_value->debugName() + "_fused_" +
+            mul_or_div->kind().toUnqualString());
+        linear->replaceInputWith(linear_b_value, fused_linear_bias);
+      }
+      // DCE run after cleans up nodes
+    }
+  }
+}
+
+} // namespace
+
+void FoldFrozenLinearAddOrSub(std::shared_ptr<Graph>& graph) {
+  FoldFrozenLinearAddOrSub(graph->block());
+  EliminateDeadCode(graph);
+}
+
+void FoldFrozenLinearMulOrDiv(std::shared_ptr<Graph>& graph) {
+  FoldFrozenLinearMulOrDiv(graph->block());
+  EliminateDeadCode(graph);
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/intel_extension_for_pytorch/csrc/jit/cpu/passes/frozen_linear_folding.h b/intel_extension_for_pytorch/csrc/jit/cpu/passes/frozen_linear_folding.h
new file mode 100644
index 000000000..11f3ce37b
--- /dev/null
+++ b/intel_extension_for_pytorch/csrc/jit/cpu/passes/frozen_linear_folding.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include
+
+namespace torch {
+namespace jit {
+
+// Fuses Linear -> Add/Sub into a single Linear by
+// folding the add/sub constant tensor into the linear bias.
+// This pass only works on Frozen Graphs; otherwise it is a No-Op.
+TORCH_API void FoldFrozenLinearAddOrSub(std::shared_ptr<Graph>& graph);
+
+// Fuses Linear -> Mul/Div into a single Linear by
+// folding the mul/div constant tensor into the linear weights and bias.
+// This pass only works on Frozen Graphs; otherwise it is a No-Op.
+TORCH_API void FoldFrozenLinearMulOrDiv(std::shared_ptr<Graph>& graph);
+
+} // namespace jit
+} // namespace torch
diff --git a/intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp b/intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp
index 475f83957..3862ca51b 100644
--- a/intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp
+++ b/intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp
@@ -8,6 +8,7 @@
 #include "cpu/kernels/Matmul.h"
 #include "cpu/passes/concat_linear.h"
 #include "cpu/passes/frozen_conv_folding.h"
+#include "cpu/passes/frozen_linear_folding.h"

 #include
 #include
@@ -365,6 +366,10 @@ void IPEXFusionPass(std::shared_ptr<Graph>& graph) {
   graph_rewrite::fuseConvWithEltwise(graph);
   graph_rewrite::fuseConvAddRelu(graph);

+  // linear folding
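+  // fold constant add/sub/mul/div into the linear weight and bias before the
+  // linear prepack/fusion passes below run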
+  FoldFrozenLinearAddOrSub(graph);
+  FoldFrozenLinearMulOrDiv(graph);
+
   // linear fusion
   graph_rewrite::insertPrePackedLinearOp(graph);
   graph_rewrite::fuseLinearWithEltwise(graph);
diff --git a/tests/cpu/test_jit.py b/tests/cpu/test_jit.py
index 7f452d174..3141c33a4 100644
--- a/tests/cpu/test_jit.py
+++ b/tests/cpu/test_jit.py
@@ -308,6 +308,29 @@ def forward(self, x):
         b = self.bn2(b)
         return F.relu(a.add_(b), inplace=True)

+class Linear_Scalar_Binary(nn.Module):
+    def __init__(self, op, in_channels, out_channels, **kwargs):
+        super(Linear_Scalar_Binary, self).__init__()
+        seed = 2018
+        torch.manual_seed(seed)
+        self.linear = nn.Linear(in_channels, out_channels, **kwargs)
+        self.op = op
+
+    def forward(self, x):
+        return self.op(self.linear(x), 2.0)
+
+class Linear_Tensor_Binary(nn.Module):
+    def __init__(self, op, in_channels, out_channels, **kwargs):
+        super(Linear_Tensor_Binary, self).__init__()
+        seed = 2018
+        torch.manual_seed(seed)
+        self.linear = nn.Linear(in_channels, out_channels, **kwargs)
+        self.op = op
+        self.tensor = torch.randn(out_channels)
+
+    def forward(self, x):
+        return self.op(self.linear(x), self.tensor)
+
 class LinearRelu(nn.Module):
     def __init__(self, in_channels, out_channels, **kwargs):
         super(LinearRelu, self).__init__()
@@ -1779,6 +1802,114 @@ def test_linear_auto_kernel_selection_bf16(self):
             # for bfloat16 path, we will use ipex linear for 'O0' and 'O1'
             self.assertTrue(any(n.kind() == 'ipex_prepack::linear_relu_run' for n in trace_graph.nodes()))

+    def test_output_linear_scalar_binary(self):
+        for bias in [True, False]:
+            self._test_output(
+                Linear_Scalar_Binary(torch.add, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="aten::linear",
+                kind_not_in_graph="aten::add")
+
+            self._test_output(
+                Linear_Scalar_Binary(torch.sub, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="aten::linear",
+                kind_not_in_graph="aten::sub")
+
+            self._test_output(
+                Linear_Scalar_Binary(torch.mul, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="aten::linear",
+                kind_not_in_graph="aten::mul")
+
+            self._test_output(
+                Linear_Scalar_Binary(torch.div, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="aten::linear",
+                kind_not_in_graph="aten::div")
+
+            self._test_output_bf16(
+                Linear_Scalar_Binary(torch.add, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="ipex_prepack::linear_run",
+                kind_not_in_graph="aten::add",
+                prec=0.1)
+
+            self._test_output_bf16(
+                Linear_Scalar_Binary(torch.sub, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="ipex_prepack::linear_run",
+                kind_not_in_graph="aten::sub",
+                prec=0.1)
+
+            self._test_output_bf16(
+                Linear_Scalar_Binary(torch.mul, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="ipex_prepack::linear_run",
+                kind_not_in_graph="aten::mul",
+                prec=0.1)
+
+            self._test_output_bf16(
+                Linear_Scalar_Binary(torch.div, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="ipex_prepack::linear_run",
+                kind_not_in_graph="aten::div",
+                prec=0.1)
+
+    def test_output_linear_tensor_binary(self):
+        for bias in [True, False]:
+            self._test_output(
+                Linear_Tensor_Binary(torch.add, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="aten::linear",
+                kind_not_in_graph="aten::add")
+
+            self._test_output(
+                Linear_Tensor_Binary(torch.sub, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="aten::linear",
+                kind_not_in_graph="aten::sub")
+
+            self._test_output(
+                Linear_Tensor_Binary(torch.mul, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="aten::linear",
+                kind_not_in_graph="aten::mul")
+
+            self._test_output(
+                Linear_Tensor_Binary(torch.div, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="aten::linear",
+                kind_not_in_graph="aten::div")
+
+            self._test_output_bf16(
+                Linear_Tensor_Binary(torch.add, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="ipex_prepack::linear_run",
+                kind_not_in_graph="aten::add",
+                prec=0.1)
+
+            self._test_output_bf16(
+                Linear_Tensor_Binary(torch.sub, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="ipex_prepack::linear_run",
+                kind_not_in_graph="aten::sub",
+                prec=0.1)
+
+            self._test_output_bf16(
+                Linear_Tensor_Binary(torch.mul, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="ipex_prepack::linear_run",
+                kind_not_in_graph="aten::mul",
+                prec=0.1)
+
+            self._test_output_bf16(
+                Linear_Tensor_Binary(torch.div, 3, 32, bias=bias),
+                torch.randn(52, 3),
+                kind_in_graph="ipex_prepack::linear_run",
+                kind_not_in_graph="aten::div",
+                prec=0.2)
+
     def test_output_linear_relu(self):
         self._test_output(
             LinearRelu(3, 32, bias=True),