Skip to content

Commit

Permalink
Handle constants in exported program
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Feb 9, 2024
1 parent 54bd43f commit e64f5b1
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion test/stablehlo/test_export_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def test_llama_export(self):
arg = (torch.randint(0, 1000, (8, 100)), torch.arange(0, 100), None)
options = StableHLOExportOptions()
options.override_tracing_arguments = arg
exported = torch.export.export(model, arg)
with torch.no_grad():
exported = torch.export.export(model, arg)
with tempfile.TemporaryDirectory() as tempdir:
save_as_stablehlo(exported, tempdir, options)

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def _exported_program_to_stablehlo_bundle(exported_model,
state_dict = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device),
exported_model.state_dict)

if (constants := getattr(exported_model, 'constants')) is not None
if (constants := getattr(exported_model, 'constants')) is not None:
state_dict.update(pytree.tree_map_only(
torch.Tensor, lambda x: x.to(device=device), constants))

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/tf_saved_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def inner(*args):
call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
return tfxla.call_module(
tuple(call_args),
version=5,
version=6,
Tout=Touts, # dtype information
Sout=Souts, # Shape information
function_list=[],
Expand Down

0 comments on commit e64f5b1

Please sign in to comment.