Update test_spmd_debugging.py to avoid code test code self #6263
@@ -36,6 +36,8 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 8:
+      self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
     mesh_shape = (2, num_devices // 2)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
@@ -51,13 +53,14 @@ def test_debugging_spmd_single_host_tiled_tpu(self):

     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
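Note on the change above, which repeats in every expected-table fixture below: rich's Console exposes a color_system attribute that is a string such as "truecolor" on color-capable terminals and None otherwise, so each expected table only uses box lines, row separators, and highlighting when the console cannot render color. A minimal standalone sketch of the pattern (standard rich API; the row contents are placeholders, not the tensor-sharding cells the tests build):

```python
import rich.box
import rich.console
import rich.table

# color_system is e.g. "truecolor" on capable terminals and None when the
# output stream (such as a plain CI log) does not support color.
use_color = bool(rich.console.Console().color_system)

# Mirror the test fixtures: draw square box lines, row separators, and
# automatic highlighting only when color output is unavailable.
table = rich.table.Table(
    show_header=False,
    show_lines=not use_color,
    padding=0,
    highlight=not use_color,
    pad_edge=False,
    box=rich.box.SQUARE if not use_color else None)
table.add_row("TPU 0", "TPU 1")  # placeholder cells
rich.console.Console().print(table)
```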
@@ -116,6 +119,8 @@ def test_single_host_partial_replication_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 8:
+      self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
     mesh_shape = (2, num_devices // 2)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
@@ -132,13 +137,14 @@ def test_single_host_partial_replication_tpu(self):

     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -167,6 +173,8 @@ def test_single_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 8:
+      self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
     mesh_shape = (2, num_devices // 2)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
@@ -183,13 +191,14 @@ def test_single_host_replicated_tpu(self):

     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -212,6 +221,8 @@ def test_debugging_spmd_single_host_tiled_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 1:
+      self.skipTest("skip num_devices!=1 env to avoid circular reasoning")
     mesh_shape = (1, num_devices)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
@@ -227,13 +238,14 @@ def test_debugging_spmd_single_host_tiled_cpu(self):

     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -255,6 +267,8 @@ def test_single_host_partial_replication_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 1:
+      self.skipTest("skip num_devices!=1 env to avoid circular reasoning")
     mesh_shape = (1, num_devices)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)

Reviewer (on the added check):
Naive question, but is there a case where [...]?

Author:
Thanks, Jon, very good question; checked with TPU Pod-16: looks like [...]
@@ -271,13 +285,14 @@ def test_single_host_partial_replication_cpu(self):

     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -299,6 +314,8 @@ def test_single_host_replicated_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 1:
+      self.skipTest("skip num_devices!=1 env to avoid circular reasoning")
     mesh_shape = (1, num_devices)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
@@ -315,13 +332,14 @@ def test_single_host_replicated_cpu(self):

     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -356,13 +374,14 @@ def test_debugging_spmd_multi_host_tiled_tpu(self):

     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -468,13 +487,14 @@ def test_multi_host_partial_replication_tpu(self):

     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -543,6 +563,9 @@ def test_multi_host_partial_replication_tpu(self):
                     f"Requires PJRT_DEVICE set to `TPU`.")
   def test_multi_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
+    num_devices = self.n_devices
+    if num_devices != 8:
+      self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
     sharding = '{replicated}'
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
@@ -552,13 +575,14 @@ def test_multi_host_replicated_tpu(self):

     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -588,13 +612,14 @@ def test_debugging_spmd_multi_host_tiled_cpu(self):

     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -700,13 +725,14 @@ def test_multi_host_partial_replication_cpu(self):

     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -784,13 +810,14 @@ def test_multi_host_replicated_cpu(self):

     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
        padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
Reviewer:
Could we make these checks into test decorators? e.g. unittest.skipIf(xr.global_runtime_device_count() != 8). Also, for my understanding, what's the circular reasoning here?
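A minimal sketch of the suggested decorator form (assuming torch_xla.runtime is imported as xr, as in the example above; the class name here is hypothetical, and note that unittest.skipIf also requires a reason string):

```python
import unittest

import torch_xla.runtime as xr


class DebuggingSpmdTest(unittest.TestCase):  # hypothetical class name

  # Skip at collection time instead of inside the test body.
  @unittest.skipIf(xr.global_runtime_device_count() != 8,
                   "Test fixtures assume an 8-device mesh")
  def test_debugging_spmd_single_host_tiled_tpu(self):
    ...
```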
Reviewer:
+1, unittest.skipIf is the way to go. What's the reason behind this check? !=8 seems too restrictive -- so if we can generalize to n devices, it would be better?
Author:
Thanks for both suggestions; I've updated the tests to generate the expected table in a more general way, without the restrictive limitation. As for "circular reasoning": I wanted to describe the situation where we produce the test input with our function and also produce the expected output with the same function. In test/spmd/test_spmd_debugging.py, testing on an arbitrary device kind would mean generating both the actual table and the expected table with our own function; to avoid that (circular reasoning), we limited the test to 8-device environments. Let me change the wording to "code test code self" for a better description.
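In other words, the expected table must be built by hand from rich primitives rather than by calling visualize_tensor_sharding a second time. A minimal sketch of that pattern (the rows argument and the expected_table helper are illustrative; the real fixtures in test_spmd_debugging.py build their cells from the mesh shape):

```python
import rich.box
import rich.console
import rich.table


def expected_table(rows):
  """Build the reference table independently of the code under test."""
  use_color = bool(rich.console.Console().color_system)
  table = rich.table.Table(
      show_header=False,
      show_lines=not use_color,
      padding=0,
      highlight=not use_color,
      pad_edge=False,
      box=rich.box.SQUARE if not use_color else None)
  for row in rows:
    table.add_row(*row)
  return table

# A test then renders both this hand-built table and the table returned by
# visualize_tensor_sharding, and compares the two console outputs, so the
# reference is never produced by the function under test.
```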