Extend constant prop pass to work with int/float/etc scalars and fix input specs. (#2950)

Summary:
Pull Request resolved: #2950

1. Clean up / refactor the constant prop pass.

2. Enable constant propagation for ops with constant scalar arguments (int/float/dtype/bool/str). Nodes of the form `Op(constant_tensor, some_int, some_float, some_dtype, ...)` can now be constant propagated (see the sketch after this list).

3. Fix the order of input specs to match the order expected by the `ExportGraphSignature` class: parameters -> buffers -> constants -> user_inputs.
Before this diff, input specs for the newly added constant tensors were appended to the end of the graph signature, which would cause failures.
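
For illustration, here is a minimal sketch (the module and names are hypothetical, not part of this diff) of the op pattern that item 2 makes foldable:

```python
import torch


class AddScaled(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Lifted by torch.export as a constant-tensor placeholder.
        self.weight = torch.ones(4)

    def forward(self, x):
        # `self.weight * 2 + 0.5` mixes a constant tensor with int/float
        # scalars, so the whole subexpression can now be folded into a
        # single propagated constant tensor.
        return x + (self.weight * 2 + 0.5)
```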

Reviewed By: dulinriley

Differential Revision: D55891278

fbshipit-source-id: fe1867cb6a99d0140d6a2e076027688cb1ddc0cd
hsharma35 authored and facebook-github-bot committed Apr 11, 2024
1 parent 3b727a7 commit 5ef8427
Showing 3 changed files with 362 additions and 84 deletions.
2 changes: 2 additions & 0 deletions exir/passes/TARGETS
@@ -92,6 +92,8 @@ python_library(
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/dialects/edge:lib",
    ],
)

342 changes: 259 additions & 83 deletions exir/passes/constant_prop_pass.py
@@ -4,58 +4,145 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
from typing import cast, Mapping, Optional

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch._export.utils import (
    get_buffer,
    get_lifted_tensor_constant,
    get_param,
    is_buffer,
    is_lifted_tensor_constant,
    is_param,
)
from torch._guards import detect_fake_mode
from torch.export import ExportedProgram
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
from torch.utils import _pytree as pytree


# Avoid propagating constants for `exir.ops.edge.aten.full.default`.
# Propagating aten.full can significantly increase compiled model size.
_DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default}

_PRIMITIVE_TYPES = (
    float,
    int,
    bool,
    str,
    torch.Tensor,
    torch.device,
    torch.dtype,
    torch.layout,
)


def is_const(
    arg,
    exported_program: ExportedProgram,
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
) -> bool:
    if isinstance(arg, (tuple, list)):
        return all(is_const(x, exported_program, const_node_to_tensor) for x in arg)
    elif isinstance(arg, dict):
        return all(
            is_const(x, exported_program, const_node_to_tensor) for x in arg.values()
        )
    elif isinstance(arg, _PRIMITIVE_TYPES):
        return True
    elif not isinstance(arg, torch.fx.Node):
        return False
    elif arg in const_node_to_tensor:
        return True
    return False


def get_data(
    arg,
    exported_program: ExportedProgram,
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
):
    if isinstance(arg, (tuple, list)):
        return type(arg)(
            get_data(x, exported_program, const_node_to_tensor) for x in arg
        )
    elif isinstance(arg, _PRIMITIVE_TYPES):
        return arg
    elif arg in const_node_to_tensor:
        return const_node_to_tensor[arg]
    return None


def get_constant_placeholder_dict(
    exported_program: ExportedProgram,
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """
    Returns a dictionary of placeholder node -> constant tensor.
    """
    const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict()
    for node in exported_program.graph.nodes:
        if node.op != "placeholder":
            continue

        if is_param(exported_program, node):
            const_node_to_tensor[node] = cast(
                torch.Tensor, get_param(exported_program, node)
            )
        elif is_buffer(exported_program, node):
            const_node_to_tensor[node] = cast(
                torch.Tensor, get_buffer(exported_program, node)
            )
        elif is_lifted_tensor_constant(exported_program, node):
            const_node_to_tensor[node] = cast(
                torch.Tensor, get_lifted_tensor_constant(exported_program, node)
            )
    return const_node_to_tensor


def get_propagated_const_tensor_dict(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]],
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """
    Propagates constants and returns a dictionary of node -> constant tensors.
    """
    # Initialize the dict with all constant placeholders.
    const_node_to_tensor = get_constant_placeholder_dict(exported_program)

    all_skip_targets: set[EdgeOpOverload] = set()
    # Default set of targets to skip.
    all_skip_targets.update(_DEFAULT_SKIP_TARGETS)
    if custom_skip_targets is not None:
        all_skip_targets.update(custom_skip_targets)

    for node in exported_program.graph.nodes:
        if node.op != "call_function" or node.target in all_skip_targets:
            continue

        if not is_const(
            node.args,
            exported_program,
            const_node_to_tensor,
        ):
            continue

        args_data, kwargs_data = pytree.tree_map(
            lambda x: get_data(x, exported_program, const_node_to_tensor),
            (node.args, node.kwargs),
        )

        # Execute `node.target` and create a new propagated constant tensor.
        prop_constant_tensor = node.target(*args_data, **kwargs_data)
        const_node_to_tensor[node] = prop_constant_tensor

    return const_node_to_tensor
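
As an aside, the `pytree.tree_map` call above applies `get_data` across the nested `(args, kwargs)` structure in one pass. A minimal self-contained sketch of the same pattern, using a toy lookup table in place of `const_node_to_tensor` (names here are illustrative, not from this diff):

```python
import torch
from torch.utils import _pytree as pytree

# Toy stand-in for const_node_to_tensor: maps "node names" to tensors.
lookup = {"w": torch.ones(2)}

args = ("w", 2.0)
kwargs = {"alpha": 0.5, "weights": ["w", "w"]}

# tree_map visits only the leaves (the strings/floats here), preserving
# the container structure of the (args, kwargs) tuple.
args_data, kwargs_data = pytree.tree_map(
    lambda x: lookup.get(x, x) if isinstance(x, str) else x,
    (args, kwargs),
)
# args_data   == (tensor([1., 1.]), 2.0)
# kwargs_data == {"alpha": 0.5, "weights": [tensor([1., 1.]), tensor([1., 1.])]}
```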


def get_first_user_input(exported_program: ExportedProgram) -> torch.fx.Node:
    """Returns the first user input node in the graph."""
    first_user_input = None
    for node in exported_program.graph.nodes:
        if (
@@ -64,11 +151,42 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
        ):
            first_user_input = node
            break
    return first_user_input


def replace_with_constant_node(
    node: torch.fx.Node,
    prop_constant_tensor: torch.Tensor,
    first_user_input: torch.fx.Node,
    fake_mode,
    exported_program: ExportedProgram,
) -> tuple[torch.fx.Node, str]:
    # Add `prop_constant_tensor` to the program's constants.
    prop_constant_tensor_fqn = f"_prop_tensor_constant{len(exported_program.constants)}"
    exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor

    # Insert a new placeholder node for the propagated constant tensor.
    with exported_program.graph.inserting_before(first_user_input):
        const_placeholder_node = exported_program.graph.placeholder(
            prop_constant_tensor_fqn
        )

    # Update the metadata of the new placeholder (constant) node.
    for k, v in node.meta.items():
        const_placeholder_node.meta[k] = v
    const_placeholder_node.meta["val"] = fake_mode.from_tensor(
        prop_constant_tensor, static_shapes=True
    )
    const_placeholder_node.meta["val"].constant = prop_constant_tensor

    # Replace the original node with the new constant node.
    node.replace_all_uses_with(const_placeholder_node)
    exported_program.graph.erase_node(node)

    return const_placeholder_node, prop_constant_tensor_fqn


def get_fake_mode(exported_program: ExportedProgram):
    fake_mode = detect_fake_mode(
        tuple(
            node.meta["val"]
@@ -77,57 +195,115 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
        )
    )
    assert fake_mode is not None
    return fake_mode


def erase_constant_node(
    exported_program: ExportedProgram,
    node: torch.fx.Node,
) -> None:
    # Remove the corresponding tensor from the param/constants dict.
    signature = exported_program.graph_signature
    if name := signature.inputs_to_parameters.pop(node.name, None):
        exported_program.state_dict.pop(name, None)
    elif name := signature.inputs_to_lifted_tensor_constants.pop(node.name, None):
        exported_program.constants.pop(name, None)
    elif name := signature.inputs_to_buffers.pop(node.name, None):
        exported_program.constants.pop(name, None)
        exported_program.state_dict.pop(name, None)

    # Remove the node from the graph.
    exported_program.graph.erase_node(node)


def create_constant_nodes_and_return_specs(
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
    exported_program: ExportedProgram,
) -> dict[str, InputSpec]:
    """
    Creates constant nodes for all entries in `const_node_to_tensor` and
    returns a node.name -> InputSpec dict.
    """
    name_to_spec_dict: dict[str, InputSpec] = {}

    fake_mode = get_fake_mode(exported_program)
    first_user_input = get_first_user_input(exported_program)

    # Iterate over nodes in reverse order.
    for node, prop_constant_tensor in reversed(const_node_to_tensor.items()):
        if all(x in const_node_to_tensor for x in node.users):
            # All users of this constant node are also constant, so we don't
            # need to create a new constant node.
            erase_constant_node(exported_program, node)
            continue

        if node.op == "placeholder":
            continue

        const_placeholder_node, prop_constant_tensor_fqn = replace_with_constant_node(
            node, prop_constant_tensor, first_user_input, fake_mode, exported_program
        )

        # Create the input spec for the lifted constant.
        name_to_spec_dict[const_placeholder_node.name] = InputSpec(
            kind=InputKind.CONSTANT_TENSOR,
            arg=TensorArgument(name=const_placeholder_node.name),
            target=prop_constant_tensor_fqn,
            persistent=True,
        )
    return name_to_spec_dict


def constant_prop_pass(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
) -> ExportedProgram:
    """
    Performs constant propagation on an ExportedProgram with lifted parameters;
    the parameters do not show up as `get_attr` nodes but as `placeholder`
    nodes in the graph.

    Args:
        exported_program: The ExportedProgram to perform constant propagation on.
        custom_skip_targets: Optional set of EdgeOpOverload targets to skip
            during constant propagation.

    Returns:
        The modified ExportedProgram with constant propagation applied.
    """
    if (
        len([node for node in exported_program.graph.nodes if node.op == "placeholder"])
        == 0
    ):
        return exported_program

    has_control_flow = [
        node
        for node in exported_program.graph.nodes
        if node.target == torch.ops.higher_order.cond
    ]
    if len(has_control_flow) > 0:
        raise RuntimeError("constant_prop_pass for control flow is not supported yet.")

    const_node_to_tensor = get_propagated_const_tensor_dict(
        exported_program, custom_skip_targets
    )

    # Get the old input specs.
    name_to_spec_dict = {
        s.arg.name: s for s in exported_program.graph_signature.input_specs
    }
    # Add the new constants to the input specs dict.
    name_to_spec_dict.update(
        create_constant_nodes_and_return_specs(const_node_to_tensor, exported_program)
    )

    # Generate the new input specs in signature order.
    new_input_specs = []
    for node in exported_program.graph.nodes:
        if node.op != "placeholder":
            continue
        new_input_specs.append(name_to_spec_dict[node.name])
    exported_program.graph_signature.input_specs = new_input_specs

    # Clean up the graph.
    exported_program.graph.eliminate_dead_code()
    exported_program.graph_module.recompile()

    return exported_program
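
A hedged usage sketch of the updated pass (the module and export flow below are illustrative assumptions, not part of this diff):

```python
import torch
from executorch.exir import to_edge
from executorch.exir.passes.constant_prop_pass import constant_prop_pass


class ScaledLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        # Constant tensor combined with a float scalar -- now foldable.
        return x @ (self.w * 0.5)


edge = to_edge(torch.export.export(ScaledLinear(), (torch.randn(4, 4),)))
ep = constant_prop_pass(
    edge.exported_program(),
    # Optionally extend the default skip set ({aten.full.default}).
    custom_skip_targets=None,
)
```

After the pass, `self.w * 0.5` should appear as a single `_prop_tensor_constant<N>` placeholder, ordered with the other constants in the graph signature.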