Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fx_importer] Convert non-persistent buffers lifted as tensor constants #2902

Merged
merged 4 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions python/torch_mlir/extras/fx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,13 +342,26 @@ def import_frozen_exported_program(self, prog: torch.export.ExportedProgram):
sig = prog.graph_signature
state_dict = prog.state_dict
arg_replacements: dict[str, Any] = {}
# Lift buffers.
for input_name, state_name in sig.inputs_to_buffers.items():
try:
state_value = state_dict[state_name]
except KeyError as e:
raise AssertionError("Could not find state mapping for buffer") from e
arg_replacements[input_name] = state_value

# If there is no "constants" attribute, consult the "state_dict". Otherwise, only look
# at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
if hasattr(prog, "constants"):
constants = prog.constants
# Lift tensor constants.
for input_name, state_name in sig.inputs_to_lifted_tensor_constants.items():
try:
state_value = constants[state_name]
except KeyError as e:
raise AssertionError("Could not find state mapping for tensor constants") from e
arg_replacements[input_name] = state_value
else:
# Lift buffers.
for input_name, state_name in sig.inputs_to_buffers.items():
try:
state_value = state_dict[state_name]
except KeyError as e:
raise AssertionError("Could not find state mapping for buffer") from e
arg_replacements[input_name] = state_value

# Lift parameters.
for input_name, state_name in sig.inputs_to_parameters.items():
Expand Down
2 changes: 1 addition & 1 deletion pytorch-hash.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
72fcb9ad662bb941a266e3d747835382634c2be6
3cbc8e89fd09b0ffb4914187b438f15c121e2302
2 changes: 1 addition & 1 deletion pytorch-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torch==2.3.0.dev20240122
torch==2.3.0.dev20240207
2 changes: 1 addition & 1 deletion torchvision-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torchvision==0.18.0.dev20240122
torchvision==0.18.0.dev20240207
Loading