[fx_importer] Convert non-persistent buffers lifted as tensor constants (llvm#2902)

The investigation is largely recorded in llvm#2881, but this change allows us
to capture non-persistent buffers that were lifted as tensor constants
(after pytorch/pytorch#118969 landed in upstream PyTorch) and propagate
them to the `Torch` dialect as "frozen" `torch.vtensor.literal` ops. I
believe this patch should work with both nightly and stable PyTorch, but
will let CI confirm. Thanks @stellaraccident for the valuable pointers
and guidance.
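The version-dependent lookup this change introduces can be sketched standalone. This is a minimal illustration of the fallback logic, not the actual fx_importer API: the function name `resolve_lifted_inputs` and its plain-dict signature are hypothetical, standing in for the real `ExportedProgram` attributes (`graph_signature`, `state_dict`, `constants`).

```python
from typing import Any, Mapping, Optional


def resolve_lifted_inputs(
    inputs_to_buffers: Mapping[str, str],
    inputs_to_tensor_constants: Mapping[str, str],
    state_dict: Mapping[str, Any],
    constants: Optional[Mapping[str, Any]] = None,
) -> dict[str, Any]:
    """Map lifted placeholder names to their backing tensors.

    Mirrors the patch's branching: on newer PyTorch (post
    pytorch/pytorch#118969) lifted tensor constants live in
    ``ExportedProgram.constants``; on older releases, buffers are
    resolved from ``state_dict`` instead.
    """
    replacements: dict[str, Any] = {}
    if constants is not None:
        # Newer PyTorch: non-persistent buffers were lifted as tensor constants.
        for input_name, state_name in inputs_to_tensor_constants.items():
            if state_name not in constants:
                raise AssertionError("Could not find state mapping for tensor constants")
            replacements[input_name] = constants[state_name]
    else:
        # Older PyTorch: fall back to resolving buffers from the state dict.
        for input_name, state_name in inputs_to_buffers.items():
            if state_name not in state_dict:
                raise AssertionError("Could not find state mapping for buffer")
            replacements[input_name] = state_dict[state_name]
    return replacements
```

Either way the result feeds the same `arg_replacements` mapping, so the rest of the importer is unaware of which PyTorch version produced the program.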

---------

Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
sjain-stanford and vivekkhandelwal1 authored Feb 13, 2024
1 parent 9b967f6 commit 3e836d8
Showing 4 changed files with 23 additions and 10 deletions.
27 changes: 20 additions & 7 deletions python/torch_mlir/extras/fx_importer.py
@@ -373,13 +373,26 @@ def import_frozen_exported_program(self, prog: torch.export.ExportedProgram):
         sig = prog.graph_signature
         state_dict = prog.state_dict
         arg_replacements: dict[str, Any] = {}
-        # Lift buffers.
-        for input_name, state_name in sig.inputs_to_buffers.items():
-            try:
-                state_value = state_dict[state_name]
-            except KeyError as e:
-                raise AssertionError("Could not find state mapping for buffer") from e
-            arg_replacements[input_name] = state_value
+
+        # If there is no "constants" attribute, consult the "state_dict". Otherwise, only look
+        # at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
+        if hasattr(prog, "constants"):
+            constants = prog.constants
+            # Lift tensor constants.
+            for input_name, state_name in sig.inputs_to_lifted_tensor_constants.items():
+                try:
+                    state_value = constants[state_name]
+                except KeyError as e:
+                    raise AssertionError("Could not find state mapping for tensor constants") from e
+                arg_replacements[input_name] = state_value
+        else:
+            # Lift buffers.
+            for input_name, state_name in sig.inputs_to_buffers.items():
+                try:
+                    state_value = state_dict[state_name]
+                except KeyError as e:
+                    raise AssertionError("Could not find state mapping for buffer") from e
+                arg_replacements[input_name] = state_value
 
         # Lift parameters.
         for input_name, state_name in sig.inputs_to_parameters.items():
2 changes: 1 addition & 1 deletion pytorch-hash.txt
@@ -1 +1 @@
-72fcb9ad662bb941a266e3d747835382634c2be6
+3cbc8e89fd09b0ffb4914187b438f15c121e2302
2 changes: 1 addition & 1 deletion pytorch-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
-torch==2.3.0.dev20240122
+torch==2.3.0.dev20240207
2 changes: 1 addition & 1 deletion torchvision-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
-torchvision==0.18.0.dev20240122
+torchvision==0.18.0.dev20240207
