Skip to content

Commit

Permalink
Refactor stablehlo API and put them in official location.
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Aug 24, 2023
1 parent 8a98a88 commit 85b97dc
Show file tree
Hide file tree
Showing 7 changed files with 470 additions and 163 deletions.
13 changes: 7 additions & 6 deletions test/stablehlo/test_export_llama.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental.stablehlo_saved_model import save_as_stablehlo
from torch_xla.stablehlo import save_as_stablehlo, StableHLOExportOptions
import torch
import torch._export
import torchvision
Expand All @@ -26,22 +26,23 @@ def test_llama_export(self):

gen = llama_model.GenLoop(model).eval()
arg = (torch.randint(0, 1000, (1, 10)),)
print(gen(*arg).shape)
arg[0].requires_grad = False
options = StableHLOExportOptions()
options.override_tracing_arguments = arg
with torch.no_grad():
exported2 = torch._export.export(gen, arg)
print(exported2.graph_module.code)
with tempfile.TemporaryDirectory() as tempdir:
save_as_stablehlo(exported2, arg, tempdir)
save_as_stablehlo(exported2, tempdir, options)

def test_llama_export(self):
options = llama_model.ModelArgs()
model = llama_model2.Transformer(options)
arg = (torch.randint(0, 1000, (8, 100)), torch.arange(0, 100), None)
print('val', model(*arg))
options = StableHLOExportOptions()
options.override_tracing_arguments = arg
exported = torch._export.export(model, arg)
with tempfile.TemporaryDirectory() as tempdir:
save_as_stablehlo(exported, arg, tempdir)
save_as_stablehlo(exported, tempdir, options)


if __name__ == '__main__':
Expand Down
45 changes: 45 additions & 0 deletions test/stablehlo/test_saved_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.stablehlo import StableHLOExportOptions, exported_program_to_stablehlo
from torch_xla.tf_saved_model_integration import make_tf_function, save_torch_module_as_tf_saved_model
from torch.utils import _pytree as pytree
import torch
import torchvision

import tempfile
import unittest
import tensorflow as tf


class StableHLOInferenceTest(unittest.TestCase):

def test_resnet18_inference(self):
resnet18 = torchvision.models.resnet18().eval()
data = torch.randn(4, 3, 224, 224)
output = resnet18(data)

exported = torch.export.export(resnet18, (data,))
options = StableHLOExportOptions(override_tracing_arguments=(data,))
stablehlo_program = exported_program_to_stablehlo(exported, options)
tf_func = make_tf_function(stablehlo_program)

output_tf = tf_func(*options.override_tracing_arguments)
output2 = torch.tensor(output_tf[0].numpy())
self.assertTrue(torch.allclose(output, output2, atol=1e-5))

def test_resnet18_save_load(self):
resnet18 = torchvision.models.resnet18().eval()
data = torch.randn(4, 3, 224, 224)
output = resnet18(data)

with tempfile.TemporaryDirectory() as tempdir:
save_torch_module_as_tf_saved_model(resnet18, (data,), tempdir)
loaded_m = tf.saved_model.load(tempdir)
res = loaded_m.f(data.detach().numpy())[0]
output2 = torch.tensor(res.numpy())
self.assertTrue(torch.allclose(output, output2, atol=1e-5))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
51 changes: 44 additions & 7 deletions test/stablehlo/test_stablehlo_inference.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental.stablehlo_saved_model import export_torch_model
from torch_xla.stablehlo import exported_program_to_stablehlo, StableHLOExportOptions
from torch.utils import _pytree as pytree
import torch
import torchvision

import tempfile
import unittest


def export_torch_model(model, args):
exported = torch._export.export(model, args)
options = StableHLOExportOptions()
options.override_tracing_arguments = args
return exported_program_to_stablehlo(exported, options)


class StableHLOInferenceTest(unittest.TestCase):

def test_resnet18_inference(self):
Expand All @@ -19,18 +27,18 @@ def test_resnet18_inference(self):
resnet18,
(data,),
)
output2 = exported(data)[0].cpu()

self.assertTrue(torch.allclose(output, output2, atol=1e-5))
output2 = exported(data).cpu()
self.assertTrue(torch.allclose(output, output2, atol=1e-3))

def test_resnet18_save_load(self):
return
import tensorflow as tf
resnet18 = torchvision.models.resnet18()
data = torch.randn(4, 3, 224, 224)
output = resnet18(data)
output_np = output.detach().numpy()

exported = export_torch_model(resnet18, (data,), to_tf=True)
exported = export_torch_model(resnet18, (data,))
tf_m = tf.Module()
tf_m.f = tf.function(
exported,
Expand Down Expand Up @@ -58,9 +66,38 @@ def forward(self, x, y):
data = (torch.randn(100, 100), torch.tensor(4.4))
output = m(*data)
exported = export_torch_model(m, data)
output2 = exported(*data)[0].cpu()

self.assertTrue(torch.allclose(output, output2, atol=1e-5))
device = xm.xla_device()
data = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device), data)
output2 = exported(*data).cpu()

self.assertTrue(torch.allclose(output, output2, atol=1e-3))

def test_model_with_dict(self):

class DictInput(torch.nn.Module):

def __init__(self) -> None:
super().__init__()

def forward(self, inputs):
return (inputs['x'] + inputs['y'], inputs['x'] * inputs['y'])

m = DictInput()
data = ({
'x': torch.randn((10, 10)),
'y': torch.randn((10, 10)),
},)

output = m(*data)
exported = export_torch_model(m, data)
device = xm.xla_device()
data = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device), data)
output2 = exported(*data)
self.assertEqual(len(output2), 2)

self.assertTrue(torch.allclose(output[0], output2[0].cpu()))
self.assertTrue(torch.allclose(output[1], output2[1].cpu()))


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import tempfile
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental import stablehlo_saved_model
from torch_xla import save_torch_model_as_stablehlo, save_as_stablehlo
from torch_xla.stablehlo import StableHLOExportOptions, StableHLOGraphModule
import torch
import torch._export
import torchvision
Expand Down Expand Up @@ -91,13 +92,22 @@ def test_save_load(self):
model = ElementwiseAdd()
inputs = model.get_random_inputs()
exported = torch._export.export(model, inputs)
bundle = stablehlo_saved_model._exported_program_to_stablehlo_bundle(
exported, inputs)
options = StableHLOExportOptions()
options.override_tracing_arguments = inputs
with tempfile.TemporaryDirectory() as tempdir:
stablehlo_saved_model._save_program_bundle(bundle, tempdir)
bundle2 = stablehlo_saved_model._load_program_bundle(tempdir)
save_as_stablehlo(exported, tempdir, options)
program2 = StableHLOGraphModule.load(tempdir)
result = program2(*inputs).detach().cpu()
self.assertTrue(torch.allclose(model(*inputs), result))

self.assertEqual(bundle.stablehlo_funcs, bundle2.stablehlo_funcs)
def test_save_load2(self):
model = ElementwiseAdd()
inputs = model.get_random_inputs()
with tempfile.TemporaryDirectory() as tempdir:
save_torch_model_as_stablehlo(model, inputs, tempdir)
program2 = StableHLOGraphModule.load(tempdir)
result = program2(*inputs).detach().cpu()
self.assertTrue(torch.allclose(model(*inputs), result))


if __name__ == '__main__':
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,5 @@ def _init_xla_lazy_backend():
# keep PyTorch/XLA CI healthy.
# TODO @wonjoo come up with a long term fix in Dynamo.
torch._dynamo.config.automatic_dynamic_shapes = False

from .stablehlo import save_as_stablehlo, save_torch_model_as_stablehlo
Loading

0 comments on commit 85b97dc

Please sign in to comment.