add unbounded dynamism test for some aten ops
format
fix comment for skipped tests
cover mul (cherry picked from commit f55abc88ae361e89da675a1aa1e4a19e7a5c762a)
cover mul (cherry picked from commit 30abe2be43defc25db8954c525d34f7f3de35292)
add missing tests to ci scripts
yapf
fix scalar type (cherry picked from commit 8526b2091ffafccf6972ecba3c111d1b0869621e)
disable addmm test
disable mark pattern api in gh ci, due to tf dep
enable conv dynamism
support addmm
enable softmax dynamism
update comment for slice
add slice support, need converter change
update test script
take dynamic shape in save model export api
verify lowering by adding tfl inference in tests
remove debug print
add assertion of sliced dim in select lowering
remove log in conv, remove assertion in select
re-enable test
add select fx pass
add no-op slice removal pass
add fx passes
add tests
support layernorm
add vit export script
fix ep callable
enable gelu test
add export script
support dynamic view with sym dim on dims other than BS
add tests for gemma export
support unsqueeze
support softmax reduction on dynamic dim
support unbounded index (unfinished)
support dynamic expand
add groupnorm
add conv1d support, add dynamism (partially) to view
add wav2vec2 export script
add cumsum test, ne test
remove existing tests
change from crlf to lf
add checks on view
move stablehlo test util script
remove debugging print
add more assertions to fx passes
remove test print
add docstr to dynamic op
make export script more concise
remove debug print
add comments to shape inference
fix linter
fix test util path
yapf
remove stack
yapf
update export script
fix meta val not available in some nodes
Siyuan Liu committed Mar 20, 2024
1 parent 7e0d3a5, commit eb7420f
Showing 9 changed files with 186 additions and 42 deletions.
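The two export scripts added below exercise this path end to end on ViT and Wav2Vec2. As a minimal sketch of the same flow on a toy module (the toy model and the get_stablehlo_text() inspection are illustrative additions, not part of this commit; they assume the EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM flag and exported_program_to_stablehlo behave as in the scripts below), lowering with an unbounded batch dimension looks roughly like this:

# Minimal sketch (not part of this commit): export a toy module with an
# unbounded batch dimension and inspect the lowered StableHLO.
import os

import torch
import torch.nn as nn
from torch.export import Dim, export
from torch_xla.stablehlo import exported_program_to_stablehlo

os.environ['EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM'] = '1'


class TinyModel(nn.Module):

  def forward(self, x):
    # mul + softmax: ops whose unbounded-dynamism lowering this commit covers.
    return torch.nn.functional.softmax(x * 2.0, dim=-1)


model = TinyModel().eval()
args = (torch.rand(10, 8),)
dynamic_shapes = ({0: Dim("bs")},)  # mark the batch dim as dynamic
ep = export(model, args=args, dynamic_shapes=dynamic_shapes)

# The dynamic batch dim should appear as a '?' dimension
# (e.g. tensor<?x8xf32>) in the emitted StableHLO text.
shlo = exported_program_to_stablehlo(ep)
print(shlo.get_stablehlo_text())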
@@ -0,0 +1,47 @@
import os
from typing import Callable, List, Tuple, Type, Union

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import torch_xla
from torch.export import Dim, export
from torch.utils import _pytree as pytree
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch_xla.tf_saved_model_integration import \
    save_torch_module_as_tf_saved_model
from transformers import ViTForImageClassification

os.environ['EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM'] = '1'


class ViTForImageClassificationModelWrapper(nn.Module):

  def __init__(self, model_name):
    super().__init__()
    self.m = ViTForImageClassification.from_pretrained(model_name)

  def forward(self, img):
    return self.m(pixel_values=img).logits


model = ViTForImageClassificationModelWrapper(
    'google/vit-base-patch16-224').eval()
args = (torch.rand(10, 3, 224, 224),)
dynamic_shapes = ({0: Dim("dim")},)

# Export to saved_model
tmp_dir = "/tmp/vit-export/vit-1"
save_torch_module_as_tf_saved_model(
    model, args, tmp_dir, dynamic_shapes=dynamic_shapes)

# Verify numeric accuracy with an input with a different BS.
args = (torch.rand(2, 3, 224, 224),)
loaded_m = tf.saved_model.load(tmp_dir)
tf_input = pytree.tree_map_only(torch.Tensor, lambda x: tf.constant(x.numpy()),
                                args)
tf_output = loaded_m.f(*tf_input)
with torch.no_grad():
  torch_output = model(*args)
print(np.max(torch_output.numpy() - tf_output[0].numpy()))
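For a stricter check than eyeballing the printed max difference, the comparison could be wrapped in an assertion (a sketch; the tolerance is a hypothetical value, not taken from the commit):

# Sketch: assert closeness instead of printing the raw difference.
# atol is a hypothetical tolerance for fp32 ViT logits; tune as needed.
np.testing.assert_allclose(
    torch_output.numpy(), tf_output[0].numpy(), atol=1e-5)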
@@ -0,0 +1,44 @@
import os

import numpy as np
import tensorflow as tf
import torch
import torch_xla
from torch.export import Dim, export
from torch.utils import _pytree as pytree
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model
from transformers import Wav2Vec2ForCTC

os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"


class ModelWrapper(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self._model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

  def forward(self, input):
    r = self._model(input)
    return r.logits


model = ModelWrapper().eval()
args = (torch.rand(3, 800),)
dynamic_shapes = ({0: Dim("bs")},)
ep = export(model, args=args, dynamic_shapes=dynamic_shapes)

tmp_dir = "/tmp/wav2vec2-export/tmp"
save_torch_module_as_tf_saved_model(
    model, args, tmp_dir, dynamic_shapes=dynamic_shapes)

# Verify numeric accuracy with an input with a different BS.
args = (torch.rand(2, 800),)
loaded_m = tf.saved_model.load(tmp_dir)
tf_input = pytree.tree_map_only(torch.Tensor, lambda x: tf.constant(x.numpy()),
                                args)
tf_output = loaded_m.f(*tf_input)
with torch.no_grad():
  torch_output = model(*args)
print(np.max(torch_output.numpy() - tf_output[0].numpy()))
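Since the batch dimension was exported as unbounded, the same SavedModel should accept other batch sizes without re-export; a quick spot check (illustrative only, not part of the commit):

# Sketch: run the loaded model at a few different batch sizes.
for bs in (1, 5, 16):
  out = loaded_m.f(tf.constant(torch.rand(bs, 800).numpy()))
  print(bs, out[0].shape)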