diff --git a/test/stablehlo/test_export_llama.py b/test/stablehlo/test_export_llama.py index 9387d11b2bfa..74449e39b7bd 100644 --- a/test/stablehlo/test_export_llama.py +++ b/test/stablehlo/test_export_llama.py @@ -39,7 +39,8 @@ def test_llama_export(self): arg = (torch.randint(0, 1000, (8, 100)), torch.arange(0, 100), None) options = StableHLOExportOptions() options.override_tracing_arguments = arg - exported = torch.export.export(model, arg) + with torch.no_grad(): + exported = torch.export.export(model, arg) with tempfile.TemporaryDirectory() as tempdir: save_as_stablehlo(exported, tempdir, options) diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index a8ff5df76675..ba37dad6af3b 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -303,6 +303,12 @@ def _exported_program_to_stablehlo_bundle(exported_model, param_and_buffer_keys = exported_model.graph_signature.parameters + exported_model.graph_signature.buffers state_dict = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device), exported_model.state_dict) + + if (constants := getattr(exported_model, 'constants')) is not None: + state_dict.update( + pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device), + constants)) + param_buffer_values = (state_dict[key] for key in param_and_buffer_keys) if hasattr(exported_model.graph_signature, "lifted_tensor_constants"): diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py index 4737c7cf9ff2..946d5460f5f1 100644 --- a/torch_xla/tf_saved_model_integration.py +++ b/torch_xla/tf_saved_model_integration.py @@ -33,7 +33,7 @@ def inner(*args): call_args = stablehlo._extract_call_parameters(args, func.meta, bundle) return tfxla.call_module( tuple(call_args), - version=5, + version=6, Tout=Touts, # dtype information Sout=Souts, # Shape information function_list=[],