From e9ad21d4dc47f15f44de5bec843720bcc1a43fb6 Mon Sep 17 00:00:00 2001 From: Wei Lu Date: Wed, 29 May 2024 13:08:45 -0700 Subject: [PATCH] fuse conv and batch_norm Summary: - note that in the printed ops, there isn't `batch_norm` anymore - 52 conv+batch_norm instances have been fused | fuse | Loading(ms) | vmRss(KB) | vmaBlock(KB) | Inference(ms) | vmRss(KB) | vmaBlock(KB) | | -------- | ------- | ------- | ------- | ------- | ------- | ------- | | Yes | 380 | 22928 | 65536 | 148 | 24296 | 65536 | | No | 473 | 26036 | 65536 | 161 | 27416 | 65536 | Differential Revision: D57895439 --- backends/transforms/TARGETS | 29 ++++ .../transforms/fuse_batch_norm_with_conv.py | 161 ++++++++++++++++++ backends/transforms/utils.py | 55 ++++++ backends/vulkan/TARGETS | 1 + backends/vulkan/vulkan_preprocess.py | 4 + 5 files changed, 250 insertions(+) create mode 100644 backends/transforms/fuse_batch_norm_with_conv.py create mode 100644 backends/transforms/utils.py diff --git a/backends/transforms/TARGETS b/backends/transforms/TARGETS index 549de43bd81..41c9cfc7bec 100644 --- a/backends/transforms/TARGETS +++ b/backends/transforms/TARGETS @@ -33,6 +33,22 @@ runtime.python_library( ], ) +runtime.python_library( + name = "fuse_batch_norm_with_conv", + srcs = ["fuse_batch_norm_with_conv.py"], + visibility = [ + "//executorch/backends/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + ":utils", + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir:sym_util", + "//executorch/exir/dialects:lib", + ], +) + runtime.python_library( name = "mean_to_sum_div", srcs = ["mean_to_sum_div.py"], @@ -48,6 +64,19 @@ runtime.python_library( ], ) +runtime.python_library( + name = "utils", + srcs = ["utils.py"], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir:pass_manager", + "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib", + "//executorch/exir/dialects:lib", + "//pytorch/ao:torchao", # @manual + ], +) + runtime.python_library( name = 
"duplicate_dynamic_quant_chain", srcs = ["duplicate_dynamic_quant_chain.py"], diff --git a/backends/transforms/fuse_batch_norm_with_conv.py b/backends/transforms/fuse_batch_norm_with_conv.py new file mode 100644 index 00000000000..dda74a3dd6b --- /dev/null +++ b/backends/transforms/fuse_batch_norm_with_conv.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import operator + +import torch + +from executorch.backends.transforms.utils import get_param_tensor, is_param_node +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + +from torch.nn.utils.fusion import fuse_conv_bn_weights + + +class FuseBatchNormWithConvPass(ExportPass): + """ + Batch Norm can be implemented using 1x1 Depthwise Convolution. However doing so will increase + memory usage since we serialize new weights to represent the convolution. In most cases, + Batch norm is used after convolution. The 1x1 depthwise convolution can then be fused + with the previous convolution + """ + + def __init__(self, exported_program: ExportedProgram): + super().__init__() + self.exported_program = exported_program + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + counter = 0 + for conv in graph.nodes: + # We want to discover a chain of conv -> batch_norm. + # Only proceed if the current node is a conv node, and has a single + # user/successor. + if ( + conv.target != exir_ops.edge.aten.convolution.default + or len(conv.users) != 1 + ): + continue + + # The single user of conv op must be batch_norm. If not, bail. 
+ bn = list(conv.users.keys())[0] + if ( + bn.target != exir_ops.edge.aten.native_batch_norm.default + and bn.target + != exir_ops.edge.aten._native_batch_norm_legit_no_training.default + ): + continue + + if not self.can_fuse(conv, bn, self.exported_program): + continue + + # Get the parameters from conv op + assert len(conv.args) == 9 + + conv_weight = get_param_tensor(self.exported_program, conv.args[1]) + assert conv_weight is not None + + conv_bias = get_param_tensor(self.exported_program, conv.args[2]) + + # Get the parameters from the batchnorm op + assert ( + bn.target == exir_ops.edge.aten.native_batch_norm.default + and len(bn.args) == 8 + ) or ( + bn.target + == exir_ops.edge.aten._native_batch_norm_legit_no_training.default + and len(bn.args) == 7 + ) + bn_weight = get_param_tensor(self.exported_program, bn.args[1]) + bn_bias = get_param_tensor(self.exported_program, bn.args[2]) + + running_mean = get_param_tensor(self.exported_program, bn.args[3]) + assert running_mean is not None + + running_var = get_param_tensor(self.exported_program, bn.args[4]) + assert running_var is not None + + # args[7] for native_batch_norm, but args[6] for + # _native_batch_norm_legit_no_training (which doesn't have training + # as an arg) + eps = bn.args[-1] + + # Compute the updated weight and bias after fusing conv op + # with batchnorm op. + fused_weight, fused_bias = fuse_conv_bn_weights( + conv_weight, + conv_bias, + running_mean, + running_var, + eps, + bn_weight, + bn_bias, + ) + + # Modify the graph by updating the weight and bias of conv op + # with the fused weight and bias params, and replacing all the users + # of getitem(batchnorm) with the conv op. 
+ with graph.inserting_before(conv): + fused_weight_name = f"_fused_with_bn_weight_{counter}" + graph_module.register_parameter(fused_weight_name, fused_weight) + fused_weight_node = graph.get_attr(fused_weight_name) + fused_bias_name = f"_fused_with_bn_bias_{counter}" + graph_module.register_parameter(fused_bias_name, fused_bias) + fused_bias_node = graph.get_attr(fused_bias_name) + + # Update the weight and bias of conv op + conv_args = list(conv.args) + ([None] if len(conv.args) == 2 else []) + conv_args[1] = fused_weight_node + conv_args[2] = fused_bias_node + conv.args = tuple(conv_args) + # Remove any use of batchnorm from the graph + for user in bn.users.copy(): + assert user.target == operator.getitem + user.replace_all_uses_with(conv) + graph.erase_node(user) + + graph.erase_node(bn) + + counter += 1 + + graph_module.recompile() + # To Regenerate meta data and shape information, retrace module + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) + + @staticmethod + def can_fuse( + conv: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram + ) -> bool: + """ + Determine whether a batch norm node can be fused with a preceding conv node. + """ + + # All the users of batchnorm node must be getitem ops. batchnorm + # returns a 3-element tuple. Each user must only access the first + # element of the tuple. 
+ if [ + (user.target == operator.getitem and user.args[1] == 0) for user in bn.users + ].count(False): + return False + + conv_weights = conv.args[1] + bn_weights = bn.args[1] + + # Check that the weights for conv and batchnorm are both params + if not isinstance(conv_weights, torch.fx.Node) or not isinstance( + bn_weights, torch.fx.Node + ): + return False + + if [is_param_node(program, node) for node in {conv_weights, bn_weights}].count( + False + ): + return False + + return True diff --git a/backends/transforms/utils.py b/backends/transforms/utils.py new file mode 100644 index 00000000000..03c48039b93 --- /dev/null +++ b/backends/transforms/utils.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from executorch.exir import ExportedProgram + +from torch._export.utils import ( + get_buffer, + get_lifted_tensor_constant, + get_param, + is_buffer, + is_lifted_tensor_constant, + is_param, +) + + +def is_get_attr_node(node: torch.fx.Node) -> bool: + """ + Returns true if the given node is a get attr node for a tensor of the model + """ + return isinstance(node, torch.fx.Node) and node.op == "get_attr" + + +def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool: + return ( + is_get_attr_node(node) + or is_param(exp_prog, node) + or is_buffer(exp_prog, node) + or is_lifted_tensor_constant(exp_prog, node) + ) + + +def get_param_tensor( + exp_prog: ExportedProgram, node: torch.fx.Node +) -> Optional[torch.Tensor]: + if node is None: + return None + elif is_param(exp_prog, node): + return get_param(exp_prog, node) + elif is_buffer(exp_prog, node): + return get_buffer(exp_prog, node) + elif is_lifted_tensor_constant(exp_prog, node): + return get_lifted_tensor_constant(exp_prog, node) + elif is_get_attr_node(node): + # This is 
a hack to support both lifted and unlifted graph + try: + return getattr(node.graph.owning_module, node.target) + except AttributeError: + return getattr(exp_prog.graph_module, node.target) + raise RuntimeError(f"unsupported param type, {node.op}.") diff --git a/backends/vulkan/TARGETS b/backends/vulkan/TARGETS index 274cae57ceb..c6de1b0c4b2 100644 --- a/backends/vulkan/TARGETS +++ b/backends/vulkan/TARGETS @@ -23,6 +23,7 @@ runtime.python_library( ], deps = [ "//executorch/backends/transforms:addmm_mm_to_linear", + "//executorch/backends/transforms:fuse_batch_norm_with_conv", "//executorch/exir:graph_module", "//executorch/exir/_serialize:_bindings", "//executorch/exir/_serialize:lib", diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 51d14a985ec..30e5c8c7331 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -7,6 +7,9 @@ from typing import final, List from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform +from executorch.backends.transforms.fuse_batch_norm_with_conv import ( + FuseBatchNormWithConvPass, +) from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder from executorch.backends.vulkan.serialization.vulkan_graph_serialize import ( @@ -40,6 +43,7 @@ def preprocess( # noqa: C901 ) -> PreprocessResult: passes = [ AddmmToLinearTransform(), + FuseBatchNormWithConvPass(program), SpecPropPass(), ConstraintBasedSymShapeEvalPass(), MemoryPlanningPass("greedy"),