From 8f7d9d5219132508a72f8b7918f3cfc35d94f81c Mon Sep 17 00:00:00 2001
From: helunwencser
Date: Tue, 17 Sep 2024 15:31:46 -0700
Subject: [PATCH] Allow mutating input tensor (#4850)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4850

To support dynamic kv cache, we need to pass in kv cache as an input
tensor and update it inside the model. This PR allows mutating input
tensor.

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: JacobSzwejbka

Differential Revision: D61683366

Pulled By: helunwencser

fbshipit-source-id: b480073d16ddcc624d12c23918a78dfca966e0dd
---
 exir/emit/test/test_emit.py                   | 20 ++++++++++
 .../insert_write_back_for_buffers_pass.py     | 38 +++++++++++--------
 2 files changed, 42 insertions(+), 16 deletions(-)

diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index 4d362d1b51..2feeefc4ef 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -1649,3 +1649,23 @@ def forward(self, x):
         self.assertEqual(
             pte_data.execution_plan, model.executorch_program.execution_plan
         )
+
+    def test_mutate_input_tensor(self) -> None:
+        class MutateInputTensorModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                x.add_(1)
+
+        model = to_edge(
+            export(MutateInputTensorModule(), (torch.zeros(1),))
+        ).to_executorch(
+            config=ExecutorchBackendConfig(
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False)
+            )
+        )
+        executorch_model = _load_for_executorch_from_buffer(model.buffer)
+        input = torch.zeros(1)
+        executorch_model(input)
+        self.assertEqual(input, torch.ones(1))
diff --git a/exir/passes/insert_write_back_for_buffers_pass.py b/exir/passes/insert_write_back_for_buffers_pass.py
index 7aef357191..1ddbf98e7e 100644
--- a/exir/passes/insert_write_back_for_buffers_pass.py
+++ b/exir/passes/insert_write_back_for_buffers_pass.py
@@ -15,6 +15,7 @@
     OutputKind,
     OutputSpec,
 )
+from torch.export.graph_signature import TensorArgument
 from torch.utils import _pytree as pytree
 
 
@@ -73,20 +74,21 @@ def insert_write_back_for_buffers_pass(
     ep: ExportedProgram,
 ) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]:
     gm: torch.fx.GraphModule = ep.graph_module
-    lifted_inputs: List[Optional[str]] = [
-        (
-            in_spec.target
-            if in_spec.kind
-            in (
-                InputKind.BUFFER,
-                InputKind.CONSTANT_TENSOR,
-                InputKind.PARAMETER,
-                InputKind.CUSTOM_OBJ,
-            )
-            else None
-        )
-        for in_spec in ep.graph_signature.input_specs
-    ]
+    lifted_inputs: List[Optional[str]] = []
+    for in_spec in ep.graph_signature.input_specs:
+        if in_spec.kind in (
+            InputKind.BUFFER,
+            InputKind.CONSTANT_TENSOR,
+            InputKind.PARAMETER,
+            InputKind.CUSTOM_OBJ,
+        ):
+            lifted_inputs.append(in_spec.target)
+        elif in_spec.kind is InputKind.USER_INPUT and isinstance(
+            in_spec.arg, TensorArgument
+        ):
+            lifted_inputs.append(in_spec.arg.name)
+        else:
+            lifted_inputs.append(None)
 
     input_name_to_node: Dict[str, torch.fx.Node] = {}
 
@@ -101,7 +103,8 @@ def insert_write_back_for_buffers_pass(
     mutated_outputs: List[Optional[str]] = [
         (
             out_spec.target
-            if out_spec.kind in (OutputKind.BUFFER_MUTATION,)
+            if out_spec.kind
+            in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION)
             and out_spec.arg.name
             not in {
                 val.name for val in input_name_to_node.values()
@@ -121,7 +124,10 @@ def insert_write_back_for_buffers_pass(
     new_output_specs: List[OutputSpec] = []
     i = 0
     for output_spec in ep.graph_signature.output_specs:
-        if output_spec.kind == OutputKind.BUFFER_MUTATION:
+        if output_spec.kind in (
+            OutputKind.BUFFER_MUTATION,
+            OutputKind.USER_INPUT_MUTATION,
+        ):
             output_spec.arg.name = buffer_output_nodes[i].name
             i += 1
         new_output_specs.append(output_spec)