Skip to content

Commit

Permalink
Read from constants if it exists (#6510)
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi authored and bhavya01 committed Apr 22, 2024
1 parent 9b31b3d commit 5b1f7d2
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
3 changes: 2 additions & 1 deletion test/stablehlo/test_export_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def test_llama_export(self):
arg = (torch.randint(0, 1000, (8, 100)), torch.arange(0, 100), None)
options = StableHLOExportOptions()
options.override_tracing_arguments = arg
exported = torch.export.export(model, arg)
with torch.no_grad():
exported = torch.export.export(model, arg)
with tempfile.TemporaryDirectory() as tempdir:
save_as_stablehlo(exported, tempdir, options)

Expand Down
6 changes: 6 additions & 0 deletions torch_xla/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,12 @@ def _exported_program_to_stablehlo_bundle(exported_model,
param_and_buffer_keys = exported_model.graph_signature.parameters + exported_model.graph_signature.buffers
state_dict = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device),
exported_model.state_dict)

if (constants := getattr(exported_model, 'constants')) is not None:
state_dict.update(
pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device),
constants))

param_buffer_values = (state_dict[key] for key in param_and_buffer_keys)

if hasattr(exported_model.graph_signature, "lifted_tensor_constants"):
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/tf_saved_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def inner(*args):
call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
return tfxla.call_module(
tuple(call_args),
version=5,
version=6,
Tout=Touts, # dtype information
Sout=Souts, # Shape information
function_list=[],
Expand Down

0 comments on commit 5b1f7d2

Please sign in to comment.