Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fuse conv and batch_norm #3769

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions backends/transforms/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,22 @@ runtime.python_library(
],
)

runtime.python_library(
name = "fuse_batch_norm_with_conv",
srcs = ["fuse_batch_norm_with_conv.py"],
visibility = [
"//executorch/backends/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
":utils",
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir:sym_util",
"//executorch/exir/dialects:lib",
],
)

runtime.python_library(
name = "mean_to_sum_div",
srcs = ["mean_to_sum_div.py"],
Expand All @@ -48,6 +64,19 @@ runtime.python_library(
],
)

runtime.python_library(
name = "utils",
srcs = ["utils.py"],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir:pass_manager",
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
"//executorch/exir/dialects:lib",
"//pytorch/ao:torchao", # @manual
],
)

runtime.python_library(
name = "duplicate_dynamic_quant_chain",
srcs = ["duplicate_dynamic_quant_chain.py"],
Expand Down
161 changes: 161 additions & 0 deletions backends/transforms/fuse_batch_norm_with_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator

import torch

from executorch.backends.transforms.utils import get_param_tensor, is_param_node
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from torch.nn.utils.fusion import fuse_conv_bn_weights


class FuseBatchNormWithConvPass(ExportPass):
"""
Batch Norm can be implemented using 1x1 Depthwise Convolution. However doing so will increase
memory usage since we serialize new weights to represent the convolution. In most cases,
Batch norm is used after convolution. The 1x1 depthwise convolution can then be fused
with the previous convolution
"""

def __init__(self, exported_program: ExportedProgram):
super().__init__()
self.exported_program = exported_program

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
counter = 0
for conv in graph.nodes:
# We want to discover a chain of conv -> batch_norm.
# Only proceed if the current node is a conv node, and has a single
# user/successor.
if (
conv.target != exir_ops.edge.aten.convolution.default
or len(conv.users) != 1
):
continue

# The single user of conv op must be batch_norm. If not, bail.
bn = list(conv.users.keys())[0]
if (
bn.target != exir_ops.edge.aten.native_batch_norm.default
and bn.target
!= exir_ops.edge.aten._native_batch_norm_legit_no_training.default
):
continue

if not self.can_fuse(conv, bn, self.exported_program):
continue

# Get the parameters from conv op
assert len(conv.args) == 9

conv_weight = get_param_tensor(self.exported_program, conv.args[1])
assert conv_weight is not None

conv_bias = get_param_tensor(self.exported_program, conv.args[2])

# Get the parameters from the batchnorm op
assert (
bn.target == exir_ops.edge.aten.native_batch_norm.default
and len(bn.args) == 8
) or (
bn.target
== exir_ops.edge.aten._native_batch_norm_legit_no_training.default
and len(bn.args) == 7
)
bn_weight = get_param_tensor(self.exported_program, bn.args[1])
bn_bias = get_param_tensor(self.exported_program, bn.args[2])

running_mean = get_param_tensor(self.exported_program, bn.args[3])
assert running_mean is not None

running_var = get_param_tensor(self.exported_program, bn.args[4])
assert running_var is not None

# args[7] for native_batch_norm, but args[6] for
# _native_batch_norm_legit_no_training (which doesn't have training
# as an arg)
eps = bn.args[-1]

# Compute the updated weight and bias after fusing conv op
# with batchnorm op.
fused_weight, fused_bias = fuse_conv_bn_weights(
conv_weight,
conv_bias,
running_mean,
running_var,
eps,
bn_weight,
bn_bias,
)

# Modify the graph by updating the weight and bias of conv op
# with the fused weight and bias params, and replacing all the users
# of getitem(batchnorm) with the conv op.
with graph.inserting_before(conv):
fused_weight_name = f"_fused_with_bn_weight_{counter}"
graph_module.register_parameter(fused_weight_name, fused_weight)
fused_weight_node = graph.get_attr(fused_weight_name)
fused_bias_name = f"_fused_with_bn_bias_{counter}"
graph_module.register_parameter(fused_bias_name, fused_bias)
fused_bias_node = graph.get_attr(fused_bias_name)

# Update the weight and bias of conv op
conv_args = list(conv.args) + ([None] if len(conv.args) == 2 else [])
conv_args[1] = fused_weight_node
conv_args[2] = fused_bias_node
conv.args = tuple(conv_args)
# Remove any use of batchnorm from the graph
for user in bn.users.copy():
assert user.target == operator.getitem
user.replace_all_uses_with(conv)
graph.erase_node(user)

graph.erase_node(bn)

counter += 1

graph_module.recompile()
# To Regenerate meta data and shape information, retrace module
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)

@staticmethod
def can_fuse(
conv: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram
) -> bool:
"""
Determine whether a batch norm node can be fused with a preceding conv node.
"""

# All the users of batchnorm node must be getitem ops. batchnorm
# returns a 3-element tuple. Each user must only access the first
# element of the tuple.
if [
(user.target == operator.getitem and user.args[1] == 0) for user in bn.users
].count(False):
return False

conv_weights = conv.args[1]
bn_weights = bn.args[1]

# Check that the weights for conv and batchnorm are both params
if not isinstance(conv_weights, torch.fx.Node) or not isinstance(
bn_weights, torch.fx.Node
):
return False

if [is_param_node(program, node) for node in {conv_weights, bn_weights}].count(
False
):
return False

return True
55 changes: 55 additions & 0 deletions backends/transforms/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from executorch.exir import ExportedProgram

from torch._export.utils import (
get_buffer,
get_lifted_tensor_constant,
get_param,
is_buffer,
is_lifted_tensor_constant,
is_param,
)


def is_get_attr_node(node: torch.fx.Node) -> bool:
"""
Returns true if the given node is a get attr node for a tensor of the model
"""
return isinstance(node, torch.fx.Node) and node.op == "get_attr"


def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
return (
is_get_attr_node(node)
or is_param(exp_prog, node)
or is_buffer(exp_prog, node)
or is_lifted_tensor_constant(exp_prog, node)
)


def get_param_tensor(
exp_prog: ExportedProgram, node: torch.fx.Node
) -> Optional[torch.Tensor]:
if node is None:
return None
elif is_param(exp_prog, node):
return get_param(exp_prog, node)
elif is_buffer(exp_prog, node):
return get_buffer(exp_prog, node)
elif is_lifted_tensor_constant(exp_prog, node):
return get_lifted_tensor_constant(exp_prog, node)
elif is_get_attr_node(node):
# This is a hack to support both lifted and unlifted graph
try:
return getattr(node.graph.owning_module, node.target)
except AttributeError:
return getattr(exp_prog.graph_module, node.target)
raise RuntimeError(f"unsupported param type, {node.op}.")
1 change: 1 addition & 0 deletions backends/vulkan/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ runtime.python_library(
],
deps = [
"//executorch/backends/transforms:addmm_mm_to_linear",
"//executorch/backends/transforms:fuse_batch_norm_with_conv",
"//executorch/exir:graph_module",
"//executorch/exir/_serialize:_bindings",
"//executorch/exir/_serialize:lib",
Expand Down
4 changes: 4 additions & 0 deletions backends/vulkan/vulkan_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from typing import final, List

from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform
from executorch.backends.transforms.fuse_batch_norm_with_conv import (
FuseBatchNormWithConvPass,
)

from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
Expand Down Expand Up @@ -40,6 +43,7 @@ def preprocess( # noqa: C901
) -> PreprocessResult:
passes = [
AddmmToLinearTransform(),
FuseBatchNormWithConvPass(program),
SpecPropPass(),
ConstraintBasedSymShapeEvalPass(),
MemoryPlanningPass("greedy"),
Expand Down
Loading