Summary:
- note that in the printed ops, there isn't `batch_norm` anymore
- 52 conv+batch_norm instances have been fused

| fuse | Loading (ms) | vmRss (KB) | vmaBlock (KB) | Inference (ms) | vmRss (KB) | vmaBlock (KB) |
| ---- | ------------ | ---------- | ------------- | -------------- | ---------- | ------------- |
| Yes  | 380          | 22928      | 65536         | 148            | 24296      | 65536         |
| No   | 473          | 26036      | 65536         | 161            | 27416      | 65536         |

Differential Revision: D57895439
1 parent 55d11e1 · commit e9ad21d · Showing 5 changed files with 250 additions and 0 deletions.
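The fusion is loss-free because a batch norm with frozen statistics is just a per-channel affine transform, which folds into the preceding convolution's weight and bias. Below is a minimal standalone sketch (not part of this diff; the shapes and tolerance are illustrative assumptions) checking that conv followed by eval-mode batch norm matches a single conv built with `fuse_conv_bn_weights`:

```python
import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

conv = torch.nn.Conv2d(3, 8, kernel_size=3).eval()
bn = torch.nn.BatchNorm2d(8).eval()
# Give batch norm non-trivial statistics so the check is meaningful.
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)

# Fold the batch-norm statistics into a fresh conv's weight and bias.
fused = torch.nn.Conv2d(3, 8, kernel_size=3).eval()
fused.weight, fused.bias = fuse_conv_bn_weights(
    conv.weight, conv.bias,
    bn.running_mean, bn.running_var, bn.eps,
    bn.weight, bn.bias,
)

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
```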
@@ -0,0 +1,161 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator

import torch

from executorch.backends.transforms.utils import get_param_tensor, is_param_node
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from torch.nn.utils.fusion import fuse_conv_bn_weights

class FuseBatchNormWithConvPass(ExportPass):
    """
    Batch norm can be implemented using a 1x1 depthwise convolution. However,
    doing so increases memory usage, since we serialize new weights to represent
    the convolution. In most cases, batch norm is used right after a convolution,
    so the 1x1 depthwise convolution can instead be fused with the preceding
    convolution.
    """

    def __init__(self, exported_program: ExportedProgram):
        super().__init__()
        self.exported_program = exported_program

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        counter = 0
        for conv in graph.nodes:
            # We want to discover a chain of conv -> batch_norm.
            # Only proceed if the current node is a conv node, and has a single
            # user/successor.
            if (
                conv.target != exir_ops.edge.aten.convolution.default
                or len(conv.users) != 1
            ):
                continue

            # The single user of conv op must be batch_norm. If not, bail.
            bn = list(conv.users.keys())[0]
            if (
                bn.target != exir_ops.edge.aten.native_batch_norm.default
                and bn.target
                != exir_ops.edge.aten._native_batch_norm_legit_no_training.default
            ):
                continue

            if not self.can_fuse(conv, bn, self.exported_program):
                continue

            # Get the parameters from conv op
            assert len(conv.args) == 9

            conv_weight = get_param_tensor(self.exported_program, conv.args[1])
            assert conv_weight is not None

            conv_bias = get_param_tensor(self.exported_program, conv.args[2])

            # Get the parameters from the batchnorm op
            assert (
                bn.target == exir_ops.edge.aten.native_batch_norm.default
                and len(bn.args) == 8
            ) or (
                bn.target
                == exir_ops.edge.aten._native_batch_norm_legit_no_training.default
                and len(bn.args) == 7
            )
            bn_weight = get_param_tensor(self.exported_program, bn.args[1])
            bn_bias = get_param_tensor(self.exported_program, bn.args[2])

            running_mean = get_param_tensor(self.exported_program, bn.args[3])
            assert running_mean is not None

            running_var = get_param_tensor(self.exported_program, bn.args[4])
            assert running_var is not None

            # args[7] for native_batch_norm, but args[6] for
            # _native_batch_norm_legit_no_training (which doesn't have training
            # as an arg)
            eps = bn.args[-1]

            # Compute the updated weight and bias after fusing conv op
            # with batchnorm op.
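            # For reference, fuse_conv_bn_weights folds the frozen batch-norm
            # statistics into the conv parameters in closed form:
            #   scale        = bn_weight / sqrt(running_var + eps)
            #   fused_weight = conv_weight * scale   (broadcast over out channels)
            #   fused_bias   = (conv_bias - running_mean) * scale + bn_bias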
            fused_weight, fused_bias = fuse_conv_bn_weights(
                conv_weight,
                conv_bias,
                running_mean,
                running_var,
                eps,
                bn_weight,
                bn_bias,
            )

            # Modify the graph by updating the weight and bias of conv op
            # with the fused weight and bias params, and replacing all the users
            # of getitem(batchnorm) with the conv op.
            with graph.inserting_before(conv):
                fused_weight_name = f"_fused_with_bn_weight_{counter}"
                graph_module.register_parameter(fused_weight_name, fused_weight)
                fused_weight_node = graph.get_attr(fused_weight_name)
                fused_bias_name = f"_fused_with_bn_bias_{counter}"
                graph_module.register_parameter(fused_bias_name, fused_bias)
                fused_bias_node = graph.get_attr(fused_bias_name)

            # Update the weight and bias of conv op
            conv_args = list(conv.args) + ([None] if len(conv.args) == 2 else [])
            conv_args[1] = fused_weight_node
            conv_args[2] = fused_bias_node
            conv.args = tuple(conv_args)
            # Remove any use of batchnorm from the graph
            for user in bn.users.copy():
                assert user.target == operator.getitem
                user.replace_all_uses_with(conv)
                graph.erase_node(user)

            graph.erase_node(bn)

            counter += 1

        graph_module.recompile()
        # To regenerate metadata and shape information, retrace the module.
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)

    @staticmethod
    def can_fuse(
        conv: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram
    ) -> bool:
        """
        Determine whether a batch norm node can be fused with a preceding conv node.
        """

        # All the users of batchnorm node must be getitem ops. batchnorm
        # returns a 3-element tuple. Each user must only access the first
        # element of the tuple.
        if [
            (user.target == operator.getitem and user.args[1] == 0) for user in bn.users
        ].count(False):
            return False

        conv_weights = conv.args[1]
        bn_weights = bn.args[1]

        # Check that the weights for conv and batchnorm are both params
        if not isinstance(conv_weights, torch.fx.Node) or not isinstance(
            bn_weights, torch.fx.Node
        ):
            return False

        if [is_param_node(program, node) for node in {conv_weights, bn_weights}].count(
            False
        ):
            return False

        return True
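For orientation, here is a hypothetical sketch of driving this pass by hand. The toy model, input shape, and `to_edge` flow are illustrative assumptions, not part of this commit; in practice the pass would run inside a backend's pass manager:

```python
import torch
from executorch.exir import to_edge

class ConvBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

edge = to_edge(torch.export.export(ConvBN().eval(), (torch.randn(1, 3, 16, 16),)))
prog = edge.exported_program()

# ExportPass objects are callable on a graph module and return a PassResult.
result = FuseBatchNormWithConvPass(prog)(prog.graph_module)

# The retraced graph should list the convolution but no batch_norm op.
print(result.graph_module.graph)
```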
@@ -0,0 +1,55 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from executorch.exir import ExportedProgram

from torch._export.utils import (
    get_buffer,
    get_lifted_tensor_constant,
    get_param,
    is_buffer,
    is_lifted_tensor_constant,
    is_param,
)

def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Returns True if the given node is a get_attr node for a tensor of the model.
    """
    return isinstance(node, torch.fx.Node) and node.op == "get_attr"


def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    return (
        is_get_attr_node(node)
        or is_param(exp_prog, node)
        or is_buffer(exp_prog, node)
        or is_lifted_tensor_constant(exp_prog, node)
    )


def get_param_tensor(
    exp_prog: ExportedProgram, node: torch.fx.Node
) -> Optional[torch.Tensor]:
    if node is None:
        return None
    elif is_param(exp_prog, node):
        return get_param(exp_prog, node)
    elif is_buffer(exp_prog, node):
        return get_buffer(exp_prog, node)
    elif is_lifted_tensor_constant(exp_prog, node):
        return get_lifted_tensor_constant(exp_prog, node)
    elif is_get_attr_node(node):
        # This is a hack to support both lifted and unlifted graphs.
        try:
            return getattr(node.graph.owning_module, node.target)
        except AttributeError:
            return getattr(exp_prog.graph_module, node.target)
    raise RuntimeError(f"unsupported param type, {node.op}.")
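As a quick illustration of these helpers (a hypothetical snippet, not from the commit, assuming a plain `torch.export` program is acceptable where the signatures say `ExportedProgram`): exporting a small module lifts its parameters to placeholder nodes, and `get_param_tensor` recovers the concrete tensors from those nodes:

```python
import torch

ep = torch.export.export(torch.nn.Linear(4, 2), (torch.randn(1, 4),))
for node in ep.graph.nodes:
    if node.op == "placeholder" and is_param_node(ep, node):
        print(node.name, tuple(get_param_tensor(ep, node).shape))
# Expected: one weight of shape (2, 4) and one bias of shape (2,)
```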