From 3e836d8dad551b6e5302de1b84840b90ee039c83 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Tue, 13 Feb 2024 12:38:32 -0800 Subject: [PATCH] [fx_importer] Convert non-persistent buffers lifted as tensor constants (#2902) The investigation is largely recorded in https://github.com/llvm/torch-mlir/pull/2881, but this change allows us to capture non-persistent buffers that were lifted as tensor constants (after https://github.com/pytorch/pytorch/pull/118969 landed in upstream PyTorch), and propagate them to `Torch` dialect as "frozen" `torch.vtensor.literal`. I believe this patch should work with both nightly and stable PyTorch, but will let CI confirm the same. Thanks @stellaraccident for the valuable pointers and guidance. --------- Co-authored-by: Vivek Khandelwal --- python/torch_mlir/extras/fx_importer.py | 27 ++++++++++++++++++------- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 4 files changed, 23 insertions(+), 10 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 6749c6078e49..b70487ad5ad9 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -373,13 +373,26 @@ def import_frozen_exported_program(self, prog: torch.export.ExportedProgram): sig = prog.graph_signature state_dict = prog.state_dict arg_replacements: dict[str, Any] = {} - # Lift buffers. - for input_name, state_name in sig.inputs_to_buffers.items(): - try: - state_value = state_dict[state_name] - except KeyError as e: - raise AssertionError("Could not find state mapping for buffer") from e - arg_replacements[input_name] = state_value + + # If there is no "constants" attribute, consult the "state_dict". Otherwise, only look + # at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969 + if hasattr(prog, "constants"): + constants = prog.constants + # Lift tensor constants. + for input_name, state_name in sig.inputs_to_lifted_tensor_constants.items(): + try: + state_value = constants[state_name] + except KeyError as e: + raise AssertionError("Could not find state mapping for tensor constants") from e + arg_replacements[input_name] = state_value + else: + # Lift buffers. + for input_name, state_name in sig.inputs_to_buffers.items(): + try: + state_value = state_dict[state_name] + except KeyError as e: + raise AssertionError("Could not find state mapping for buffer") from e + arg_replacements[input_name] = state_value # Lift parameters. for input_name, state_name in sig.inputs_to_parameters.items(): diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 16be42d6c147..d78b0d0694d4 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -72fcb9ad662bb941a266e3d747835382634c2be6 +3cbc8e89fd09b0ffb4914187b438f15c121e2302 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 1de47ff9a195..540e78dccd49 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.3.0.dev20240122 +torch==2.3.0.dev20240207 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index fad713123493..4f775c549c6c 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.18.0.dev20240122 +torchvision==0.18.0.dev20240207