Skip to content

Commit

Permalink
exclude mutated buffer (#2876)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2876

Fixing the tag constant for mutable buffer. The buffer shouldn't be tagged if it's going to be mutated by the delegated. It's more common in hardware backends

Will follow up and test having delegate consume mutation

Reviewed By: mcr229, angelayi

Differential Revision: D55812844

fbshipit-source-id: e0be4c2dc295141d673cccb1aeecee45894b1e70
  • Loading branch information
cccclai authored and facebook-github-bot committed Apr 8, 2024
1 parent dc7e4d5 commit 599cfde
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 15 deletions.
84 changes: 83 additions & 1 deletion exir/backend/test/test_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
ExecutorBackend,
)
from executorch.exir.backend.utils import get_delegates
from executorch.exir.backend.utils import get_delegates, tag_constant_data

from executorch.exir.dialects._ops import ops as exir_ops

Expand Down Expand Up @@ -523,3 +523,85 @@ def partition(
"constant data node (b_const) is tagged with (tag0) but has user (aten_sub_tensor) which has tag (None)",
str(error.exception),
)

def test_not_delegate_mutable_buffers(self) -> None:
"""
A test case to check the mutated buffer is not delegated. We'll need to add a test case
to consider when the delegate can consume the mutable buffer.
"""

class MutableStateModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("my_state", torch.zeros(1))

def forward(self, x):
y = x + self.my_state
self.my_state.add_(1)
return y

edge = exir.to_edge(
torch.export.export(
MutableStateModule(),
(torch.zeros(1),),
)
)
self.assertGreater(
len(edge.exported_program().graph_signature.buffers_to_mutate),
0,
"The test case should at leaset one mutable buffer",
)

class PartitionerTagData(Partitioner):
def __init__(self):
super().__init__()
self.delegation_spec = DelegationSpec(
ExecutorBackend.__name__,
[CompileSpec(key, value) for key, value in self.spec.items()],
)

def partition(
self, edge_exported_program: ExportedProgram
) -> PartitionResult:
partition_tags = {}
for node in edge_exported_program.graph.nodes:
if node.op == "call_function" and node.target in [
exir_ops.edge.aten.add.Tensor
]:
delegation_tag = "tag0"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
tag_constant_data(edge_exported_program)
return PartitionResult(
tagged_exported_program=edge_exported_program,
partition_tags=partition_tags,
)

# Check the edge program inital buffers_to_mutate
mutate_op = "aten_add_tensor_1"
self.assertEqual(
edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
"my_state",
)
edge = edge.to_backend(PartitionerTagData())
# After to_backend, add is delegated and is no longer in buffers_to_mutate.
self.assertNotIn(
mutate_op,
edge.exported_program().graph_signature.buffers_to_mutate,
)

mutate_op = "getitem_1"
# Ensure the mutated buffer is not delegated, and the new mutate node is getitem (from call_delegate)
self.assertEqual(
edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
"my_state",
)
# Check the copy_ node is inserted
edge = edge.to_executorch()
copy_node = [
node
for node in edge.exported_program().graph.nodes
if node.op == "call_function"
and node.target == torch.ops.aten.copy_.default
]
self.assertEqual(len(copy_node), 1)
43 changes: 29 additions & 14 deletions exir/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,27 +508,42 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
subgraph. Throw error when const/param/buffers is used across different partitions. That is the
underlying data will be owned by multiple delegates.
"""
mutated_buffer = set()
for node in edge_program.graph.nodes:
if node.op == "placeholder" and (
is_param(edge_program, node)
or is_buffer(edge_program, node)
or is_lifted_tensor_constant(edge_program, node)
):
for node_user in node.users:
if node_user.name in edge_program.graph_signature.buffers_to_mutate:
logging.info(
"The buffer node is a mutated buffer node, which is not constant."
)
mutated_buffer.add(node)

for node in edge_program.graph.nodes:
# go through const/param/buffer nodes, if all users of const/param/buffer nodes are partitioned then partition
if node.op == "placeholder" and (
is_param(edge_program, node)
or is_buffer(edge_program, node)
or is_lifted_tensor_constant(edge_program, node)
):
user_tags = set()
for user in node.users:
user_tag = user.meta.get("delegation_tag", None)
if user_tag is not None:
user_tags.add(user_tag)
if len(user_tags) > 1:
logging.info(
f"The data node is used across multiple partitions, including {user_tags}. "
"If the data is too large and it's not preferred to copy, please tag the "
"constant node like node.['no_copy'] = True and they won't be copied."
)
# tag the data node with the same tag as the last user
if len(user_tags) > 0:
node.meta["delegation_tag"] = user_tags.pop()
if node not in mutated_buffer:
user_tags = set()
for user in node.users:
user_tag = user.meta.get("delegation_tag", None)
if user_tag is not None:
user_tags.add(user_tag)
if len(user_tags) > 1:
logging.info(
f"The data node is used across multiple partitions, including {user_tags}. "
"If the data is too large and it's not preferred to copy, please tag the "
"constant node like node.['no_copy'] = True and they won't be copied."
)
# tag the data node with the same tag as the last user
if len(user_tags) > 0:
node.meta["delegation_tag"] = user_tags.pop()


# TODO - style: use templated types
Expand Down

0 comments on commit 599cfde

Please sign in to comment.