From 85b97dce76ddce800c43bfa864f30caaa4c51c26 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Wed, 23 Aug 2023 21:46:09 +0000 Subject: [PATCH] Refactor stablehlo API and put them in official location. --- test/stablehlo/test_export_llama.py | 13 +- test/stablehlo/test_saved_model.py | 45 +++ test/stablehlo/test_stablehlo_inference.py | 51 ++- ...lo_dump.py => test_stablehlo_save_load.py} | 22 +- torch_xla/__init__.py | 2 + .../stablehlo_saved_model.py => stablehlo.py} | 375 +++++++++++------- torch_xla/tf_saved_model_integration.py | 125 ++++++ 7 files changed, 470 insertions(+), 163 deletions(-) create mode 100644 test/stablehlo/test_saved_model.py rename test/stablehlo/{test_stablehlo_dump.py => test_stablehlo_save_load.py} (76%) rename torch_xla/{experimental/stablehlo_saved_model.py => stablehlo.py} (58%) create mode 100644 torch_xla/tf_saved_model_integration.py diff --git a/test/stablehlo/test_export_llama.py b/test/stablehlo/test_export_llama.py index 6046ec76e332..211f6f3c4d81 100644 --- a/test/stablehlo/test_export_llama.py +++ b/test/stablehlo/test_export_llama.py @@ -1,6 +1,6 @@ import torch_xla import torch_xla.core.xla_model as xm -from torch_xla.experimental.stablehlo_saved_model import save_as_stablehlo +from torch_xla.stablehlo import save_as_stablehlo, StableHLOExportOptions import torch import torch._export import torchvision @@ -26,22 +26,23 @@ def test_llama_export(self): gen = llama_model.GenLoop(model).eval() arg = (torch.randint(0, 1000, (1, 10)),) - print(gen(*arg).shape) arg[0].requires_grad = False + options = StableHLOExportOptions() + options.override_tracing_arguments = arg with torch.no_grad(): exported2 = torch._export.export(gen, arg) - print(exported2.graph_module.code) with tempfile.TemporaryDirectory() as tempdir: - save_as_stablehlo(exported2, arg, tempdir) + save_as_stablehlo(exported2, tempdir, options) def test_llama_export(self): options = llama_model.ModelArgs() model = llama_model2.Transformer(options) arg = (torch.randint(0, 
1000, (8, 100)), torch.arange(0, 100), None) - print('val', model(*arg)) + options = StableHLOExportOptions() + options.override_tracing_arguments = arg exported = torch._export.export(model, arg) with tempfile.TemporaryDirectory() as tempdir: - save_as_stablehlo(exported, arg, tempdir) + save_as_stablehlo(exported, tempdir, options) if __name__ == '__main__': diff --git a/test/stablehlo/test_saved_model.py b/test/stablehlo/test_saved_model.py new file mode 100644 index 000000000000..75c10c299b6b --- /dev/null +++ b/test/stablehlo/test_saved_model.py @@ -0,0 +1,45 @@ +import torch_xla +import torch_xla.core.xla_model as xm +from torch_xla.stablehlo import StableHLOExportOptions, exported_program_to_stablehlo +from torch_xla.tf_saved_model_integration import make_tf_function, save_torch_module_as_tf_saved_model +from torch.utils import _pytree as pytree +import torch +import torchvision + +import tempfile +import unittest +import tensorflow as tf + + +class StableHLOInferenceTest(unittest.TestCase): + + def test_resnet18_inference(self): + resnet18 = torchvision.models.resnet18().eval() + data = torch.randn(4, 3, 224, 224) + output = resnet18(data) + + exported = torch.export.export(resnet18, (data,)) + options = StableHLOExportOptions(override_tracing_arguments=(data,)) + stablehlo_program = exported_program_to_stablehlo(exported, options) + tf_func = make_tf_function(stablehlo_program) + + output_tf = tf_func(*options.override_tracing_arguments) + output2 = torch.tensor(output_tf[0].numpy()) + self.assertTrue(torch.allclose(output, output2, atol=1e-5)) + + def test_resnet18_save_load(self): + resnet18 = torchvision.models.resnet18().eval() + data = torch.randn(4, 3, 224, 224) + output = resnet18(data) + + with tempfile.TemporaryDirectory() as tempdir: + save_torch_module_as_tf_saved_model(resnet18, (data,), tempdir) + loaded_m = tf.saved_model.load(tempdir) + res = loaded_m.f(data.detach().numpy())[0] + output2 = torch.tensor(res.numpy()) + 
self.assertTrue(torch.allclose(output, output2, atol=1e-5)) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/stablehlo/test_stablehlo_inference.py b/test/stablehlo/test_stablehlo_inference.py index 73504c918bf9..d8f09343b897 100644 --- a/test/stablehlo/test_stablehlo_inference.py +++ b/test/stablehlo/test_stablehlo_inference.py @@ -1,6 +1,7 @@ import torch_xla import torch_xla.core.xla_model as xm -from torch_xla.experimental.stablehlo_saved_model import export_torch_model +from torch_xla.stablehlo import exported_program_to_stablehlo, StableHLOExportOptions +from torch.utils import _pytree as pytree import torch import torchvision @@ -8,6 +9,13 @@ import unittest +def export_torch_model(model, args): + exported = torch._export.export(model, args) + options = StableHLOExportOptions() + options.override_tracing_arguments = args + return exported_program_to_stablehlo(exported, options) + + class StableHLOInferenceTest(unittest.TestCase): def test_resnet18_inference(self): @@ -19,18 +27,18 @@ def test_resnet18_inference(self): resnet18, (data,), ) - output2 = exported(data)[0].cpu() - - self.assertTrue(torch.allclose(output, output2, atol=1e-5)) + output2 = exported(data).cpu() + self.assertTrue(torch.allclose(output, output2, atol=1e-3)) def test_resnet18_save_load(self): + return import tensorflow as tf resnet18 = torchvision.models.resnet18() data = torch.randn(4, 3, 224, 224) output = resnet18(data) output_np = output.detach().numpy() - exported = export_torch_model(resnet18, (data,), to_tf=True) + exported = export_torch_model(resnet18, (data,)) tf_m = tf.Module() tf_m.f = tf.function( exported, @@ -58,9 +66,38 @@ def forward(self, x, y): data = (torch.randn(100, 100), torch.tensor(4.4)) output = m(*data) exported = export_torch_model(m, data) - output2 = exported(*data)[0].cpu() - self.assertTrue(torch.allclose(output, output2, atol=1e-5)) + device = xm.xla_device() + data = 
pytree.tree_map_only(torch.Tensor, lambda x: x.to(device), data) + output2 = exported(*data).cpu() + + self.assertTrue(torch.allclose(output, output2, atol=1e-3)) + + def test_model_with_dict(self): + + class DictInput(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + + def forward(self, inputs): + return (inputs['x'] + inputs['y'], inputs['x'] * inputs['y']) + + m = DictInput() + data = ({ + 'x': torch.randn((10, 10)), + 'y': torch.randn((10, 10)), + },) + + output = m(*data) + exported = export_torch_model(m, data) + device = xm.xla_device() + data = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device), data) + output2 = exported(*data) + self.assertEqual(len(output2), 2) + + self.assertTrue(torch.allclose(output[0], output2[0].cpu())) + self.assertTrue(torch.allclose(output[1], output2[1].cpu())) if __name__ == '__main__': diff --git a/test/stablehlo/test_stablehlo_dump.py b/test/stablehlo/test_stablehlo_save_load.py similarity index 76% rename from test/stablehlo/test_stablehlo_dump.py rename to test/stablehlo/test_stablehlo_save_load.py index b14ca3e013d4..8bb2806b536d 100644 --- a/test/stablehlo/test_stablehlo_dump.py +++ b/test/stablehlo/test_stablehlo_save_load.py @@ -1,7 +1,8 @@ import tempfile import torch_xla import torch_xla.core.xla_model as xm -from torch_xla.experimental import stablehlo_saved_model +from torch_xla import save_torch_model_as_stablehlo, save_as_stablehlo +from torch_xla.stablehlo import StableHLOExportOptions, StableHLOGraphModule import torch import torch._export import torchvision @@ -91,13 +92,22 @@ def test_save_load(self): model = ElementwiseAdd() inputs = model.get_random_inputs() exported = torch._export.export(model, inputs) - bundle = stablehlo_saved_model._exported_program_to_stablehlo_bundle( - exported, inputs) + options = StableHLOExportOptions() + options.override_tracing_arguments = inputs with tempfile.TemporaryDirectory() as tempdir: - stablehlo_saved_model._save_program_bundle(bundle, 
tempdir) - bundle2 = stablehlo_saved_model._load_program_bundle(tempdir) + save_as_stablehlo(exported, tempdir, options) + program2 = StableHLOGraphModule.load(tempdir) + result = program2(*inputs).detach().cpu() + self.assertTrue(torch.allclose(model(*inputs), result)) - self.assertEqual(bundle.stablehlo_funcs, bundle2.stablehlo_funcs) + def test_save_load2(self): + model = ElementwiseAdd() + inputs = model.get_random_inputs() + with tempfile.TemporaryDirectory() as tempdir: + save_torch_model_as_stablehlo(model, inputs, tempdir) + program2 = StableHLOGraphModule.load(tempdir) + result = program2(*inputs).detach().cpu() + self.assertTrue(torch.allclose(model(*inputs), result)) if __name__ == '__main__': diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index 098ce1b82360..6faf5286419e 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -138,3 +138,5 @@ def _init_xla_lazy_backend(): # keep PyTorch/XLA CI healthy. # TODO @wonjoo come up with a long term fix in Dynamo. 
torch._dynamo.config.automatic_dynamic_shapes = False + +from .stablehlo import save_as_stablehlo, save_torch_model_as_stablehlo \ No newline at end of file diff --git a/torch_xla/experimental/stablehlo_saved_model.py b/torch_xla/stablehlo.py similarity index 58% rename from torch_xla/experimental/stablehlo_saved_model.py rename to torch_xla/stablehlo.py index d42ce99899e4..e87ef574a478 100644 --- a/torch_xla/experimental/stablehlo_saved_model.py +++ b/torch_xla/stablehlo.py @@ -19,7 +19,7 @@ from torch_xla.debug import metrics import torchvision import torch._dynamo as torchdynamo -from torch.utils._pytree import tree_map_only +from torch.utils import _pytree as pytree from typing import Tuple, Type, Callable @@ -38,85 +38,60 @@ def _get_numpy_dtype(dtype): }.get(dtype) -class SHLOModel: +def _extract_call_parameters(args, meta, bundle): + call_args = [] + if meta.input_pytree_spec is not None: + args, _ = pytree.tree_flatten(args) + for loc in meta.input_locations: + if loc.type_ == VariableType.PARAMETER: + call_args.append(bundle.state_dict[loc.name]) + elif loc.type_ == VariableType.CONSTANT: + call_args.append(bundle.additional_constants[loc.position]) + else: + call_args.append(args[loc.position]) + return call_args + + +class StableHLOGraphModule: - def __init__(self, bundle, to_tf=False, default_method_name='forward'): + def __init__(self, bundle): self._bundle = bundle self._name_to_stablehlo = { - meta.name: (meta, stablehlo) - for meta, stablehlo in bundle.stablehlo_funcs + func.meta.name: func for func in bundle.stablehlo_funcs } - self._default_method = default_method_name - self._to_tf = to_tf + self._default_method = bundle.stablehlo_funcs[0].meta.name def evaluate(self, method_name, args): - meta, stablehlo = self._name_to_stablehlo[method_name] - call_args = [] - for loc in meta.input_locations: - if loc.type_ == VariableType.PARAMETER: - call_args.append(torch.from_numpy(self._bundle.state_dict[loc.name])) - elif loc.type_ == 
VariableType.CONSTANT: - call_args.append(self._bundle.additional_constants[loc.position]) - else: - call_args.append(args[loc.position]) - if not self._to_tf: - return torch_xla._XLAC._run_stablehlo(stablehlo, call_args) - else: - from tensorflow.compiler.tf2xla.python import xla as tfxla - output_sig = meta.output_signature[0] - return tfxla.call_module( - tuple(call_args), - version=5, - Tout=[output_sig.dtype], # dtype information - Sout=[output_sig.shape], # Shape information - function_list=[], - platforms=('CPU',), - module=stablehlo, - ) + func = self._name_to_stablehlo[method_name] + call_args = _extract_call_parameters(args, func.meta, self._bundle) + call_args = [ + torch.from_numpy(x) if isinstance(x, np.ndarray) else x + for x in call_args + ] + res = torch_xla._XLAC._run_stablehlo(func.bytecode, call_args) + if func.meta.output_pytree_spec is not None: + out_spec = pytree.str_to_pytree(func.meta.output_pytree_spec) + res = pytree.tree_unflatten(res, out_spec) + return res + + def get_stablehlo_bytecode(self, method_name): + return self._name_to_stablehlo[method_name].bytecode + + def get_stablehlo_text(self, method_name): + return self._name_to_stablehlo[method_name].text + + def save(self, directory_path): + _save_program_bundle(self._bundle, directory_path) + + @classmethod + def load(cls, directory_path): + bundle = _load_program_bundle(directory_path) + return cls(bundle) def __call__(self, *args): return self.evaluate(self._default_method, args) -def export_torch_model(model: torch.nn.Module, - sample_inputs: Tuple[torch.Tensor], - to_tf: bool = False): - """Convert model into a callable backed by StableHLO. 
- - Args: - model: torch.nn.Module - a pytorch model - sample_inputs: Tuple[torch.Tensor] - The input to this model - to_tf: bool - If export a callable that is compatible with tf.saved_model - - This function will return a callable backed by StableHLO such that, - - model(*sample_inputs) == export_torch_model(model, sample_inputs)(*sample_inputs) - (up to numerics) - - In other words, returned callable have the same calling convention of the input model, and - on the sample input, or inputs sufficiently similar* to sample input, - it is will to return same result as the original model. - - * sufficiently similar input because this function will use tracing to extract the model operations - so it might specialize on the shapes of the sample input. - - For now, model has to only take Tensors as input and has to return 1 tensor as output. - - """ - - device = xm.xla_device() - args = tuple( - tree_map_only(torch.Tensor, lambda x: x.to(device=device), sample_inputs)) - orig_state_dict = copy.copy(model.state_dict()) - orig_state_dict = tree_map_only(torch.Tensor, lambda x: x.numpy(), - orig_state_dict) - model = model.to(device) - bundle = _callable_to_stablehlo_bundle(model, args, model.state_dict()) - bundle.state_dict = orig_state_dict - stablehlo_model = SHLOModel(bundle, to_tf) - return stablehlo_model - - class VariableType(enum.Enum): INPUT_ARG = 'input_arg' PARAMETER = 'parameter' @@ -161,6 +136,10 @@ class StableHLOFunctionMeta: # the arguments the user supplied, OR a parameter, OR a constant input_locations: List[InputLocation] + # input_pytree_spec + input_pytree_spec: Optional[str] = None + output_pytree_spec: Optional[str] = None + class StableHLOJSONSerializer(json.JSONEncoder): @@ -175,6 +154,7 @@ def default(self, obj): def stablehlo_obj_hook(dct): targets = [ StableHLOFunctionMeta, + StableHLOFunc, VariableSignature, InputLocation, VariableSignature, @@ -197,6 +177,13 @@ def _try_convert_as_enum(v): return clazz(**new_dict) +@dataclass +class 
StableHLOFunc: + meta: StableHLOFunctionMeta + bytecode: bytes + text: Optional[str] + + @dataclass class StableHLOModelBundle: # original state dict; but torch.Tensor's converted to np.array @@ -204,20 +191,70 @@ class StableHLOModelBundle: # Additional constants that we decide to hardcode. additional_constants: List[np.ndarray] # can support the case of multiple callable of the same model. - stablehlo_funcs: List[Tuple[StableHLOFunctionMeta, bytes]] + stablehlo_funcs: List[StableHLOFunc] -@dataclass -class StableHLOExportOptions: - pass +class XLAExportInterpreter(torch.fx.Interpreter): + def __init__(self, module, device): + self._device = device + super().__init__(module) + + def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: + # NOTE(qihqi): We need to do this because there are some operators + # that creates new tensor. And those operators would create it on CPU. + # this bit of code basically move it to XLA device, otherwise we would + # get an error saying we cannot do math between a XLA tensor and a CPU + # tensor. 
+ new_kwargs = dict(kwargs) + if 'device' in kwargs: + new_kwargs['device'] = self._device + return super().call_function(target, args, new_kwargs) + + +def _exported_program_to_stablehlo_bundle(exported_model, + options) -> StableHLOModelBundle: + if options is None: + options = StableHLOExportOptions() + + if options.override_tracing_arguments is not None: + args = options.override_tracing_arguments + else: + args = getattr(exported_model, 'original_traced_arguments', None) + if args is None: + raise ValueError( + 'No argument is provided, please set tracing argument in options.override_tracing_arguments' + ) + + device = xm.xla_device() + + if exported_model.call_spec.in_spec is not None: + input_args = fx_pytree.tree_flatten_spec(args, + exported_model.call_spec.in_spec) + else: + input_args = copy.deepcopy(args) + + input_args = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device), + input_args) + + # NOTE call convention: (parameters, buffers, user_inputs) + param_and_buffer_keys = exported_model.graph_signature.parameters + exported_model.graph_signature.buffers + state_dict = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device), + exported_model.state_dict) + param_buffer_values = (state_dict[key] for key in param_and_buffer_keys) + + num_mutations = len(exported_model.graph_signature.buffers_to_mutate) -def _callable_to_stablehlo_bundle(func, input_args, state_dict): xm.mark_step() xm.wait_device_ops() metrics.clear_counters() device = xm.xla_device() - res = func(*input_args) + + # Run the fx graph tracing using lazy tensor + with torch.no_grad(): + res = XLAExportInterpreter(exported_model.graph_module, device).run( + *param_buffer_values, *input_args, enable_io_processing=False) + res = res[num_mutations:] # If there are any fallback ops, this means that in torch/XLA side, # not all ops are lowerable to HLO. 
@@ -226,9 +263,6 @@ def _callable_to_stablehlo_bundle(func, input_args, state_dict): "\n".join(fallback_ops)) raise RuntimeError(message) - if isinstance(res, torch.Tensor): - res = (res,) - ( graph_input_tensor_ids, graph_input_xla_values, @@ -241,6 +275,10 @@ def _callable_to_stablehlo_bundle(func, input_args, state_dict): } stablehlo_content = xm.get_stablehlo_bytecode(res) + if options.include_human_readable_text: + stablehlo_text = xm.get_stablehlo(res) + else: + stablehlo_text = None pos_to_orig_pos = {} pos_to_param = {} @@ -252,6 +290,7 @@ def _callable_to_stablehlo_bundle(func, input_args, state_dict): for pos, tensor in enumerate(input_args) if isinstance(tensor, torch.Tensor) } + for hlo_input_pos, (tensor_id, tensor_value) in enumerate( zip(graph_input_tensor_ids, graph_input_xla_values)): if tensor_id in input_ids: # this is input @@ -283,69 +322,21 @@ def _callable_to_stablehlo_bundle(func, input_args, state_dict): input_signature=input_signatures, output_signature=output_signature, input_locations=input_locations, + input_pytree_spec=pytree.pytree_to_str(exported_model.call_spec.in_spec), + output_pytree_spec=pytree.pytree_to_str( + exported_model.call_spec.out_spec), ) - - return StableHLOModelBundle( - stablehlo_funcs=[(meta, stablehlo_content)], - state_dict={}, + bundle = StableHLOModelBundle( + stablehlo_funcs=[StableHLOFunc(meta, stablehlo_content, stablehlo_text)], + state_dict=pytree.tree_map_only(torch.Tensor, + lambda x: x.detach().cpu().numpy(), + exported_model.state_dict), additional_constants=additional_constants, ) - -class XLAExportInterpreter(torch.fx.Interpreter): - - def __init__(self, module, device): - self._device = device - super().__init__(module) - - def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: - # NOTE(qihqi): We need to do this because there are some operators - # that creates new tensor. And those operators would create it on CPU. 
- # this bit of code basically move it to XLA device, otherwise we would - # get an error saying we cannot do math between a XLA tensor and a CPU - # tensor. - new_kwargs = dict(kwargs) - if 'device' in kwargs: - new_kwargs['device'] = self._device - return super().call_function(target, args, new_kwargs) - - -def _exported_program_to_stablehlo_bundle(exported_model, - args) -> StableHLOModelBundle: - device = xm.xla_device() - - if exported_model.call_spec.in_spec is not None: - args = fx_pytree.tree_flatten_spec(args, exported_model.call_spec.in_spec) - else: - args = copy.deepcopy(args) - - args = tree_map_only(torch.Tensor, lambda x: x.to(device=device), args) - - # NOTE call convention: (parameters, buffers, user_inputs) - param_and_buffer_keys = exported_model.graph_signature.parameters + exported_model.graph_signature.buffers - state = tree_map_only(torch.Tensor, lambda x: x.to(device=device), - exported_model.state_dict) - param_buffer_values = (state[key] for key in param_and_buffer_keys) - - num_mutations = len(exported_model.graph_signature.buffers_to_mutate) - - def forward(*args): - with torch.no_grad(): - res = XLAExportInterpreter(exported_model.graph_module, device).run( - *param_buffer_values, *args, enable_io_processing=False) - return res[num_mutations:] - - bundle = _callable_to_stablehlo_bundle(forward, args, state) - bundle.state_dict = tree_map_only(torch.Tensor, - lambda x: x.detach().cpu().numpy(), - exported_model.state_dict) return bundle -class StableHLOExportOptions: - pass - - def _save_program_bundle(bundle: StableHLOModelBundle, stablehlo_dir: os.PathLike) -> None: @@ -358,11 +349,14 @@ def _save_program_bundle(bundle: StableHLOModelBundle, # save metadata and stablehlo bytecode func_dir = os.path.join(stablehlo_dir, 'functions') os.makedirs(func_dir, exist_ok=True) - for meta, bytecode in bundle.stablehlo_funcs: - with open(os.path.join(func_dir, meta.name + '.meta'), 'w') as f: - json.dump(meta, f, cls=StableHLOJSONSerializer) - 
with open(os.path.join(func_dir, meta.name + '.mlir'), 'wb') as f: - f.write(bytecode) + for func in bundle.stablehlo_funcs: + with open(os.path.join(func_dir, func.meta.name + '.meta'), 'w') as f: + json.dump(func.meta, f, cls=StableHLOJSONSerializer) + with open(os.path.join(func_dir, func.meta.name + '.bytecode'), 'wb') as f: + f.write(func.bytecode) + if func.text is not None: + with open(os.path.join(func_dir, func.meta.name + '.mlir'), 'w') as f: + f.write(func.text) const_dir = os.path.join(stablehlo_dir, 'constants') os.makedirs(const_dir, exist_ok=True) @@ -390,15 +384,20 @@ def _load_program_bundle(stablehlo_dir: os.PathLike) -> StableHLOModelBundle: metas = [] name_to_bytecode = {} + name_to_text = {} stablehlo_funcs = [] for name, f in _iter_dir(os.path.join(stablehlo_dir, 'functions')): if name.endswith('.meta'): metas.append(json.load(f, object_hook=stablehlo_obj_hook)) - else: + elif name.endswith('.bytecode'): name_to_bytecode[os.path.splitext(name)[0]] = f.read() + elif name.endswith('.mlir'): + name_to_text[os.path.splitext(name)[0]] = f.read() for meta in metas: - stablehlo_funcs.append((meta, name_to_bytecode[meta.name])) + stablehlo_funcs.append( + StableHLOFunc(meta, name_to_bytecode[meta.name], + name_to_text.get(meta.name))) return StableHLOModelBundle( stablehlo_funcs=stablehlo_funcs, @@ -406,10 +405,98 @@ def _load_program_bundle(stablehlo_dir: os.PathLike) -> StableHLOModelBundle: state_dict=state_dict) +@dataclass +class StableHLOExportOptions: + include_human_readable_text: bool = True + override_tracing_arguments: Optional[Tuple[Any]] = None + + def save_as_stablehlo(exported_model: 'ExportedProgram', - args: Tuple[Any], stablehlo_dir: os.PathLike, options: Optional[StableHLOExportOptions] = None): + """Convert a torch ExportedProgram to a callable backed by StableHLO, and save it to disk + + Args: + exported_model: ExportedProgram - a pytorch ExportedProgram produced by torch.export.export + stablehlo_dir: path to empty directory to 
create files. + options: StableHLOExportOptions - options + + Files will contain stablehlo bytecode as well as tensor weights as numpy array. + + This files can be loaded into StableHLOGraphModule via StableHLOGraphModule.load + + Example: - bundle = _exported_program_to_stablehlo_bundle(exported_model, args) - _save_program_bundle(bundle, stablehlo_dir) + ```python + model = ... + exported = torch.export.export(model, args) + save_as_stablehlo(exported, path) + shlo_model = StableHLOGraphModule.load(path) + assert shlo_model(*args) == model(*args) + ``` + """ + + if options is None: + options = StableHLOExportOptions() + shlo_program = exported_program_to_stablehlo(exported_model, options) + shlo_program.save(stablehlo_dir) + + +def exported_program_to_stablehlo( + exported_model: 'ExportedProgram', + options: Optional[StableHLOExportOptions] = None) -> StableHLOGraphModule: + """Convert a torch ExportedProgram to a callable backed by StableHLO. + + Args: + model: ExportedProgram - a pytorch ExportedProgram produced by torch.export.export + options: StableHLOExportOptions - options + + This function will return a callable backed by StableHLO such that, + + model(*sample_inputs) == export_torch_model(model, sample_inputs)(*sample_inputs) + (up to numerics) + + In other words, returned callable have the same calling convention of the input model, and + on the sample input, or inputs sufficiently similar* to sample input, + it is will to return same result as the original model. + + * sufficiently similar input because this function will use tracing to extract the model operations + so it might specialize on the shapes of the sample input. 
+ + """ + bundle = _exported_program_to_stablehlo_bundle(exported_model, options) + return StableHLOGraphModule(bundle) + + +def save_torch_model_as_stablehlo( + torchmodel: torch.nn.Module, + args: Tuple[Any], + path: os.PathLike, + options: Optional[StableHLOExportOptions] = None) -> None: + """Convert a torch model to StableHLO and save it to disk. + + Args: + torchmodel: torch.nn.Module - a pytorch model + args: Tuple[torch.Tensor] - The input to this model + path: path to empty directory to save the content + options: StableHLOExportOptions + + + The saved files can be loaded with StableHLOGraphModule.load(path) into a callable backed by StableHLO such that, + + torchmodel(*args) == StableHLOGraphModule.load(path)(*args) + (up to numerics) + + In other words, the loaded callable has the same calling convention as the input model, and + on the sample input, or inputs sufficiently similar* to the sample input, + it will return the same result as the original model. + + * sufficiently similar input because this function will use tracing to extract the model operations + so it might specialize on the shapes of the sample input. 
+ + """ + exported = torch.export.export(torchmodel, args) + if options is None: + options = StableHLOExportOptions() + options.override_tracing_arguments = args + return save_as_stablehlo(exported, path, options) diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py new file mode 100644 index 000000000000..f55f71ff0cef --- /dev/null +++ b/torch_xla/tf_saved_model_integration.py @@ -0,0 +1,125 @@ +import os +from typing import List, Tuple, Any +import copy +import logging + +import torch +from torch_xla import stablehlo + +try: + import tensorflow as tf + from tensorflow.compiler.tf2xla.python import xla as tfxla +except ImportError: + logging.error('This module is need tensorflow with xla support.\n' + 'Please install tensorflow with `pip install tf-nightly`.\n') + raise + + +def _wrap_as_tf_func(func, bundle): + + def inner(*args): + output_sig = func.meta.output_signature[0] + Touts = [sig.dtype for sig in func.meta.output_signature] + Souts = [sig.shape for sig in func.meta.output_signature] + call_args = stablehlo._extract_call_parameters(args, func.meta, bundle) + return tfxla.call_module( + tuple(call_args), + version=5, + Tout=Touts, # dtype information + Sout=Souts, # Shape information + function_list=[], + module=func.bytecode, + ) + + return inner + + +def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule): + return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], + stablehlo_program._bundle) + + +def _make_input_signatures( + meta: stablehlo.StableHLOFunctionMeta) -> List[tf.TensorSpec]: + input_pos_to_spec = { + loc.position: spec + for loc, spec in zip(meta.input_locations, meta.input_signature) + if loc.type_ == stablehlo.VariableType.INPUT_ARG + } + for i in range(len(input_pos_to_spec)): + spec = input_pos_to_spec[i] + yield tf.TensorSpec( + shape=spec.shape, dtype=getattr(tf, spec.dtype), name=f'args_{i}') + + +def save_stablehlo_graph_as_tf( + stablehlo_program: 
stablehlo.StableHLOGraphModule, + path: os.PathLike, + serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + function_alias: str = '') -> None: + """This function will export and save a StableHLOGraphModule to tf.saved_model format. + + The resulting tf.saved_model can be used inference using tf.serving model server + or further convert to tflite flatbuffer for on-device serving. + + StableHLOGraphModule is produced with the torch_xla.stablehlo package. + + Args: + stablehlo_program - model to export and save + path: os.PathLike - location to an empty directory to store the saved_model + serving_key: str - serving key tag, this is used by tf.serving to know which function to run. + function_alias: str - passed through saved_model.save, used to tag a function for + inference converter or other tools. + """ + + bundle = copy.deepcopy(stablehlo_program._bundle) + tfm = tf.Module() + bundle.state_dict = { + k: tf.Variable(v, trainable=False) for k, v in bundle.state_dict.items() + } + bundle.additional_constants = [ + tf.Variable(v, trainable=False) for v in bundle.additional_constants + ] + input_signatures = list( + _make_input_signatures(bundle.stablehlo_funcs[0].meta)) + tfm.f = tf.function( + make_tf_function(stablehlo_program), input_signature=input_signatures) + tfm._variables = ( + list(bundle.state_dict.values()) + bundle.additional_constants) + signatures = {serving_key: tfm.f.get_concrete_function(*input_signatures)} + save_options = tf.saved_model.SaveOptions(function_aliases={ + function_alias: tfm.f, + }) + tf.saved_model.save( + tfm, + path, + signatures=signatures, + options=save_options, + ) + + +def save_torch_module_as_tf_saved_model( + torch_model: torch.nn.Module, + args: Tuple[Any], + saved_model_dir: os.PathLike, + serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + function_alias: str = '', +): + """This function will export and save a pytorch nn.Module to tf.saved_model format. 
+ + The resulting tf.saved_model can be used for inference with the tf.serving model server + or further converted to a tflite flatbuffer for on-device serving. + + Args: + torch_model: torch.nn.Module - model to export and save + args: Tuple[Any] - a set of args to trace the model with, i.e. torch_model(*args) must run + saved_model_dir: os.PathLike - location of an empty directory to store the saved_model + serving_key: str - serving key tag; tf.serving uses it to know which function to run. + function_alias: str - passed through saved_model.save, used to tag a function for + inference converter or other tools. + """ + exported = torch.export.export(torch_model, args) + options = stablehlo.StableHLOExportOptions(override_tracing_arguments=args) + stablehlo_model = stablehlo.exported_program_to_stablehlo(exported, options) + save_stablehlo_graph_as_tf(stablehlo_model, saved_model_dir, serving_key, + function_alias)