From 5ef8427cfe9a0ccc1e13bf49e96892467a20dcf0 Mon Sep 17 00:00:00 2001 From: Hardik Sharma Date: Thu, 11 Apr 2024 15:29:22 -0700 Subject: [PATCH] Extend constant prop pass to work with int/float/etc scalars and fix input specs. (#2950) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/2950 1. Cleanup / Refactor constant prop pass. 2. Enable constant propagation for ops with constant scalar arguments -- int/float/dtype/bool/str. Nodes of type `Op(constant_tensor, some_int, some_float, some_dtype, ...)` can now be constant propagated. 3. Fix order of input spec to match the expected spec in `ExportGraphSignature` class. parameters->buffers->constants->user_inputs. Before this diff, input_specs for the newly added constant tensors were appended to graph_signature, which would cause failures. Reviewed By: dulinriley Differential Revision: D55891278 fbshipit-source-id: fe1867cb6a99d0140d6a2e076027688cb1ddc0cd --- exir/passes/TARGETS | 2 + exir/passes/constant_prop_pass.py | 342 ++++++++++++++++++++++-------- exir/tests/test_passes.py | 102 ++++++++- 3 files changed, 362 insertions(+), 84 deletions(-) diff --git a/exir/passes/TARGETS b/exir/passes/TARGETS index f7f56ece2b..7ec14fb7d8 100644 --- a/exir/passes/TARGETS +++ b/exir/passes/TARGETS @@ -92,6 +92,8 @@ python_library( ], deps = [ "//caffe2:torch", + "//executorch/exir/dialects:lib", + "//executorch/exir/dialects/edge:lib", ], ) diff --git a/exir/passes/constant_prop_pass.py b/exir/passes/constant_prop_pass.py index 14ff651c93..0fabf223fb 100644 --- a/exir/passes/constant_prop_pass.py +++ b/exir/passes/constant_prop_pass.py @@ -4,58 +4,145 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from collections import OrderedDict +from typing import cast, Mapping, Optional + import torch -from torch._export.utils import get_buffer, get_param, is_buffer, is_param +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from torch._export.utils import ( + get_buffer, + get_lifted_tensor_constant, + get_param, + is_buffer, + is_lifted_tensor_constant, + is_param, +) from torch._guards import detect_fake_mode from torch.export import ExportedProgram from torch.export.exported_program import InputKind, InputSpec, TensorArgument +from torch.utils import _pytree as pytree + + +# Avoid propagating constants for `exir.ops.edge.aten.full.default`. +# Propagating aten.full can significantly increase compiled model size. 
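+# For example, folding `aten.full([1024, 1024], 0.0)` would replace a size and
+# a fill value in the graph with a ~1M-element constant tensor that has to be
+# serialized with the program. Additional targets can be excluded from
+# propagation via the `custom_skip_targets` argument of `constant_prop_pass`.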
+_DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default} +_PRIMITIVE_TYPES = ( + float, + int, + bool, + str, + torch.Tensor, + torch.device, + torch.dtype, + torch.layout, +) -def is_const(arg, exported_program, const_data_list) -> bool: + +def is_const( + arg, + exported_program: ExportedProgram, + const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor], +) -> bool: if isinstance(arg, (tuple, list)): - return all(is_const(x, exported_program, const_data_list) for x in arg) + return all(is_const(x, exported_program, const_node_to_tensor) for x in arg) elif isinstance(arg, dict): - return all(is_const(x, exported_program, const_data_list) for x in arg.values()) - elif not isinstance(arg, torch.fx.Node) or arg.op != "placeholder": + return all( + is_const(x, exported_program, const_node_to_tensor) for x in arg.values() + ) + elif isinstance(arg, _PRIMITIVE_TYPES): + return True + elif not isinstance(arg, torch.fx.Node): return False - elif ( - is_param(exported_program, arg) - or is_buffer(exported_program, arg) - or arg.name in const_data_list - ): + elif arg in const_node_to_tensor: return True return False -def get_data(exported_program, arg): +def get_data( + arg, + exported_program: ExportedProgram, + const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor], +): if isinstance(arg, (tuple, list)): - return [get_data(exported_program, x) for x in arg] - elif is_param(exported_program, arg): - return get_param(exported_program, arg) - elif is_buffer(exported_program, arg): - return get_buffer(exported_program, arg) + return type(arg)( + get_data(x, exported_program, const_node_to_tensor) for x in arg + ) + elif isinstance(arg, _PRIMITIVE_TYPES): + return arg + elif arg in const_node_to_tensor: + return const_node_to_tensor[arg] return None -def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram: +def get_constant_placeholder_dict( + exported_program: ExportedProgram, +) -> OrderedDict[torch.fx.Node, torch.Tensor]: """ - This pass is for constant propagation for Exported Program with lifted parameters, - as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph. + Returns a dictionary of placeholder node -> constant tensor. """ - if ( - len([node for node in exported_program.graph.nodes if node.op == "placeholder"]) - == 0 - ): - return exported_program + const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict() + for node in exported_program.graph.nodes: + if node.op != "placeholder": + continue + + if is_param(exported_program, node): + const_node_to_tensor[node] = cast( + torch.Tensor, get_param(exported_program, node) + ) + elif is_buffer(exported_program, node): + const_node_to_tensor[node] = cast( + torch.Tensor, get_buffer(exported_program, node) + ) + elif is_lifted_tensor_constant(exported_program, node): + const_node_to_tensor[node] = cast( + torch.Tensor, get_lifted_tensor_constant(exported_program, node) + ) + return const_node_to_tensor - has_cond = [ - node - for node in exported_program.graph.nodes - if node.target == torch.ops.higher_order.cond - ] - if len(has_cond) > 0: - raise RuntimeError("constant_prop_pass for control flow is not supported yet.") +def get_propagated_const_tensor_dict( + exported_program: ExportedProgram, + custom_skip_targets: Optional[set[EdgeOpOverload]], +) -> OrderedDict[torch.fx.Node, torch.Tensor]: + """ + Propagates constants and returns a dictionary of node->constant tensors. + """ + # Initialize dict with all constant placeholders. 
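+    # These are the program's parameters, buffers, and lifted tensor
+    # constants. A call_function node whose arguments are all either entries
+    # in this dict or primitive scalars (int/float/bool/str/dtype/...) is
+    # itself constant, so its result is computed eagerly and added to the dict.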
+ const_node_to_tensor = get_constant_placeholder_dict(exported_program) + + all_skip_targets: set[EdgeOpOverload] = set() + # Default set of targets to skip. + all_skip_targets.update(_DEFAULT_SKIP_TARGETS) + if custom_skip_targets is not None: + all_skip_targets.update(custom_skip_targets) + + for node in exported_program.graph.nodes: + if node.op != "call_function" or node.target in all_skip_targets: + continue + + if not is_const( + node.args, + exported_program, + const_node_to_tensor, + ): + continue + + args_data, kwargs_data = pytree.tree_map( + lambda x: get_data(x, exported_program, const_node_to_tensor), + (node.args, node.kwargs), + ) + + # Execute the `node.target` and create a new propagated constant tensor. + prop_constant_tensor = node.target(*args_data, **kwargs_data) + const_node_to_tensor[node] = prop_constant_tensor + + return const_node_to_tensor + + +def get_first_user_input(exported_program: ExportedProgram) -> torch.fx.Node: + """Returns the first user input node in the graph.""" first_user_input = None for node in exported_program.graph.nodes: if ( @@ -64,11 +151,42 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram: ): first_user_input = node break + return first_user_input + + +def replace_with_constant_node( + node: torch.fx.Node, + prop_constant_tensor: torch.Tensor, + first_user_input: torch.fx.Node, + fake_mode, + exported_program: ExportedProgram, +) -> tuple[torch.fx.Node, str]: + # Add `prop_constant_tensor` to program.state_dict. + prop_constant_tensor_fqn = f"_prop_tensor_constant{len(exported_program.constants)}" + exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor + + # Insert a new placeholder node for the propagated constant tensor. + with exported_program.graph.inserting_before(first_user_input): + const_placeholder_node = exported_program.graph.placeholder( + prop_constant_tensor_fqn + ) + + # Update the meta data of the new placeholder (buffer) node. + for k, v in node.meta.items(): + const_placeholder_node.meta[k] = v + const_placeholder_node.meta["val"] = fake_mode.from_tensor( + prop_constant_tensor, static_shapes=True + ) + const_placeholder_node.meta["val"].constant = prop_constant_tensor + + # Replace the original node with the new constant node. + node.replace_all_uses_with(const_placeholder_node) + exported_program.graph.erase_node(node) + + return const_placeholder_node, prop_constant_tensor_fqn - buffers = exported_program.graph_signature.buffers - prop_constant_data = [] - const_data_to_be_removed = set() +def get_fake_mode(exported_program: ExportedProgram): fake_mode = detect_fake_mode( tuple( node.meta["val"] @@ -77,57 +195,115 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram: ) ) assert fake_mode is not None + return fake_mode + +def erase_constant_node( + exported_program: ExportedProgram, + node: torch.fx.Node, +) -> None: + # Remove corresponding tensor from param/constants dict. + signature = exported_program.graph_signature + if name := signature.inputs_to_parameters.pop(node.name, None): + exported_program.state_dict.pop(name, None) + elif name := signature.inputs_to_lifted_tensor_constants.pop(node.name, None): + exported_program.constants.pop(name, None) + elif name := signature.inputs_to_buffers.pop(node.name, None): + exported_program.constants.pop(name, None) + exported_program.state_dict.pop(name, None) + + # Remove from graph. 
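+    # The graph signature is not updated here; input specs for erased
+    # placeholders simply drop out when the signature is rebuilt from the
+    # remaining placeholder nodes at the end of constant_prop_pass.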
+ exported_program.graph.erase_node(node) + + +def create_constant_nodes_and_return_specs( + const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor], + exported_program: ExportedProgram, +) -> dict[str, InputSpec]: + """ + Creates constant nodes for all entries in `const_node_to_tensor` and returns a node.name -> InputSpec dict. + """ + name_to_spec_dict: dict[str, InputSpec] = {} + + fake_mode = get_fake_mode(exported_program) + first_user_input = get_first_user_input(exported_program) + + # Iterate over nodes in reverse order. + for node, prop_constant_tensor in reversed(const_node_to_tensor.items()): + if all(x in const_node_to_tensor for x in node.users): + # All users of this constant node are also constant, so we don't need to create a new constant node. + erase_constant_node(exported_program, node) + continue + + if node.op == "placeholder": + continue + + const_placeholder_node, prop_constant_tensor_fqn = replace_with_constant_node( + node, prop_constant_tensor, first_user_input, fake_mode, exported_program + ) + + # Create input spec for lifted constant. + name_to_spec_dict[const_placeholder_node.name] = InputSpec( + kind=InputKind.CONSTANT_TENSOR, + arg=TensorArgument(name=const_placeholder_node.name), + target=prop_constant_tensor_fqn, + persistent=True, + ) + return name_to_spec_dict + + +def constant_prop_pass( + exported_program: ExportedProgram, + custom_skip_targets: Optional[set[EdgeOpOverload]] = None, +) -> ExportedProgram: + """ + This pass is for constant propagation for Exported Program with lifted parameters, + as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph. + + Args: + exported_program: The ExportedProgram to perform constant propagation on. + custom_skip_targets: Optional set of EdgeOpOverload targets to skip during constant propagation. + + Returns: + The modified ExportedProgram with constant propagation applied. + """ + if ( + len([node for node in exported_program.graph.nodes if node.op == "placeholder"]) + == 0 + ): + return exported_program + + has_control_flow = [ + node + for node in exported_program.graph.nodes + if node.target == torch.ops.higher_order.cond + ] + if len(has_control_flow) > 0: + raise RuntimeError("constant_prop_pass for control flow is not supported yet.") + + const_node_to_tensor = get_propagated_const_tensor_dict( + exported_program, custom_skip_targets + ) + + # Get old input specs. + name_to_spec_dict = { + s.arg.name: s for s in exported_program.graph_signature.input_specs + } + # Add the new constants to input specs dict. + name_to_spec_dict.update( + create_constant_nodes_and_return_specs(const_node_to_tensor, exported_program) + ) + + # Generate new input spec. 
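+    # Rebuild the input specs by walking the placeholders in graph order. The
+    # new constant placeholders were inserted before the first user input, so
+    # the resulting order matches what ExportGraphSignature expects:
+    # parameters -> buffers -> constants -> user inputs.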
+ new_input_specs = [] for node in exported_program.graph.nodes: - if node.op == "call_function": - constant_data_name_list = [ - input_spec.target for input_spec in prop_constant_data - ] - if is_const(node.args, exported_program, constant_data_name_list): - args_data = [get_data(exported_program, arg) for arg in node.args] - kwargs_data = node.kwargs - const_data_to_be_removed.update(node.args) - prop_constant_tensor = node.target(*args_data, **kwargs_data) - prop_constant_tensor_fqn = f"_prop_tensor_constant{len(buffers)}" - - with exported_program.graph.inserting_before(first_user_input): - const_placeholder_node = exported_program.graph.placeholder( - prop_constant_tensor_fqn - ) - # Update the meta data of the new placeholder (buffer) node - for k, v in node.meta.items(): - const_placeholder_node.meta[k] = v - const_placeholder_node.meta["val"] = fake_mode.from_tensor( - prop_constant_tensor, static_shapes=True - ) - const_placeholder_node.meta["val"].constant = prop_constant_tensor - - node.replace_all_uses_with(const_placeholder_node) - exported_program.graph.erase_node(node) - prop_constant_node_input_spec = InputSpec( - kind=InputKind.BUFFER, - arg=TensorArgument(name=const_placeholder_node.name), - target=prop_constant_tensor_fqn, - persistent=True, - ) - prop_constant_data.append(prop_constant_node_input_spec) - buffers.append(prop_constant_tensor_fqn) - exported_program.state_dict[prop_constant_tensor_fqn] = ( - prop_constant_tensor - ) - exported_program.graph_signature.input_specs.append( - prop_constant_node_input_spec - ) - - # Remove the propogated buffer from the state dict - for node in exported_program.graph.nodes: - if ( - node.op == "placeholder" - and node in const_data_to_be_removed - and len(node.users) == 0 - ): - exported_program.state_dict.pop(node.name, None) - exported_program.graph.erase_node(node) + if node.op != "placeholder": + continue + new_input_specs.append(name_to_spec_dict[node.name]) + exported_program.graph_signature.input_specs = new_input_specs + # Cleanup the graph. 
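+    # Dead-code elimination sweeps out any nodes left without users after the
+    # replacements above, and recompile() regenerates the GraphModule's Python
+    # code to match the edited graph.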
+ exported_program.graph.eliminate_dead_code() exported_program.graph_module.recompile() + return exported_program diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index bfa0d39323..582941eb2c 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -65,6 +65,7 @@ XNNPACKQuantizer, ) from torch.export import export +from torch.export.graph_signature import InputKind, InputSpec, TensorArgument from torch.fx import GraphModule, subgraph_rewriter from torch.fx.experimental.proxy_tensor import make_fx from torch.library import impl, Library @@ -1145,7 +1146,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Check (_lifted_tensor_constant + to_copy) node is replaced by prop tensor FileCheck().check_not("_lifted_tensor_constant").check( - "_prop_tensor_constant1" + "_prop_tensor_constant0" ).check_not("executorch_exir_dialects_edge__ops_aten__to_copy_default").run( new_ep.graph_module.code ) @@ -1174,6 +1175,105 @@ def forward(self, x): new_ep = constant_prop_pass(aten) self.assertEqual(count_additions(new_ep.graph_module), 1) + def test_constant_prop_pass_graph_signature(self) -> None: + def count_additions(gm: torch.fx.GraphModule) -> int: + return sum( + (node.target == torch.ops.aten.add.Tensor) for node in gm.graph.nodes + ) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Parameter(torch.ones(1, 2, 3)) + + def forward(self, x): + b = self.a + self.a + c = torch.cat([self.a, b]) + return (c + c) + x + + aten = export( + M(), + (torch.zeros(2, 2, 3),), + ) + # Input signature will have two entries: + # (1) parameter `a` and (2) user input `x`. + self.assertEqual(len(aten.graph_signature.input_specs), 2) + new_ep = constant_prop_pass(aten) + # Check that there are exactly two propagated tensors - (1) propagated + # constant and (2) user input. + self.assertEqual( + new_ep.graph_signature.input_specs, + [ + InputSpec( + kind=InputKind.CONSTANT_TENSOR, + arg=TensorArgument(name="_prop_tensor_constant0"), + target="_prop_tensor_constant0", + persistent=True, + ), + # User input graph signature. + aten.graph_signature.input_specs[-1], + ], + ) + + def test_constant_prop_pass_for_parameter_slice(self) -> None: + def count_slice(gm: torch.fx.GraphModule) -> int: + return sum( + (node.target == torch.ops.aten.slice_copy.Tensor) + for node in gm.graph.nodes + ) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Parameter(torch.ones(3, 2, 2)) + + def forward(self, x): + # Create slice of shape (1, 2, 2) + slice_tensor = torch.slice_copy(self.a, dim=0, start=0, end=1) + return torch.cat([x, slice_tensor]) + + aten = export( + M(), + (torch.zeros(2, 2, 2),), + ) + self.assertIn("a", aten.state_dict) + self.assertEqual(count_slice(aten.graph_module), 1) + + new_ep = constant_prop_pass(aten) + # Check there is a propagated tensor. + FileCheck().check("_prop_tensor_constant0").run(aten.graph_module.code) + self.assertIn("_prop_tensor_constant0", new_ep.constants) + self.assertNotIn("a", new_ep.state_dict) + # No more slice copy. + self.assertEqual(count_slice(new_ep.graph_module), 0) + + def test_constant_prop_pass_no_propagate(self) -> None: + def count_placeholder(gm: torch.fx.GraphModule) -> int: + return sum((node.op == "placeholder") for node in gm.graph.nodes) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Parameter(torch.ones(3, 2, 4)) + + def forward(self, x, y): + # y is unused. 
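+                # Every path to the output goes through the user input `x`,
+                # so there is no constant-only subexpression to fold and the
+                # pass should leave the graph unchanged.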
+ return x + self.a + + aten = export( + M(), + (torch.zeros(3, 2, 4), torch.zeros(3, 2, 4)), + ) + self.assertIn("a", aten.state_dict) + self.assertEqual(count_placeholder(aten.graph_module), 3) + + new_ep = constant_prop_pass(aten) + # Check there is no propagated tensor. + FileCheck().check("p_a").check("x").check("y").run(aten.graph_module.code) + self.assertNotIn("_prop_tensor_constant0", new_ep.constants) + self.assertIn("a", new_ep.state_dict) + self.assertEqual(count_placeholder(new_ep.graph_module), 3) + def test_constant_prop_pass_for_control_flow(self) -> None: class Module(torch.nn.Module): def __init__(self):