forked from pytorch/xla
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add dynamo config skip_input_data_check (pytorch#7913)
- Loading branch information
Showing
6 changed files
with
56 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import torch | ||
import torch_xla | ||
import unittest | ||
from torch_xla._dynamo import config | ||
|
||
|
||
class DynamoconfigTest(unittest.TestCase): | ||
|
||
def dummy_test(self, a): | ||
return a.cos().sin() | ||
|
||
def test_config_skip_input_data_check(self): | ||
device = torch_xla.device() | ||
print(config.skip_input_data_check) | ||
config.skip_input_data_check = True | ||
compiled_dummy = torch.compile(self.dummy_test, backend="openxla") | ||
t1 = torch.randn(3, 4, device=device) | ||
compiled_dummy(t1) | ||
t2 = torch.randn(3, 4, device=device) | ||
t2 += 5 | ||
with self.assertRaisesRegex( | ||
RuntimeError, r'input data to dynamo graph can not be a pending ir'): | ||
compiled_dummy(t2) | ||
|
||
|
||
if __name__ == '__main__': | ||
test = unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
import torch_xla._dynamo.config as config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import torch_xla | ||
|
||
# Whether to skip checking input is a device data or not in the optim_mod. | ||
# Enabling it will reduce the overhead a bit but will throw a runtime error | ||
# if input is a pending IR. | ||
skip_input_data_check = False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters