Update test_spmd_debugging.py to avoid code test code self #6263

Merged · 25 commits · Jan 19, 2024
Showing changes from 6 commits
99 changes: 63 additions & 36 deletions test/spmd/test_spmd_debugging.py
@@ -36,6 +36,8 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 8:
+      self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
     mesh_shape = (2, num_devices // 2)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
@@ -51,13 +53,14 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
 
     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
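For context on the new use_color flag: rich reports detected color support through Console.color_system, which is None when the output target (e.g. a redirected pipe or a dumb terminal) cannot render color. The expected table presumably drops box lines and highlighting exactly when the visualizer would rely on colored cells instead. A minimal sketch of the detection:

```python
import rich.console

console = rich.console.Console()
# color_system is a string such as "truecolor", "256", "standard",
# or "windows" -- and None when color output is unavailable.
use_color = console.color_system is not None
```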
@@ -116,6 +119,8 @@ def test_single_host_partial_replication_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 8:
+      self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
Collaborator:
Could we make these checks into test decorators? e.g. unittest.skipIf(xr.global_runtime_device_count() != 8)

Also for my understanding, what's the circular reasoning here?
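A minimal sketch of the suggested decorator form (the class name here is hypothetical and the skip message illustrative):

```python
import unittest

import torch_xla.runtime as xr


class DebuggingSpmdTest(unittest.TestCase):  # hypothetical class name

  @unittest.skipIf(xr.global_runtime_device_count() != 8,
                   "Expected table is laid out for exactly 8 devices")
  def test_debugging_spmd_single_host_tiled_tpu(self):
    ...
```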

Contributor:
+1, unittest.skipIf is the way to go. What's the reason behind this check? != 8 seems too restrictive, so if we can generalize to n devices, that would be better.
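One way the expected table could be generalized to n devices, as suggested here (a hypothetical sketch mirroring the test's rich usage, not the PR's final code):

```python
import rich.align
import rich.box
import rich.padding
import rich.table


def build_expected_table(num_devices, use_color):
  # Hypothetical helper: lay out a (2, num_devices // 2) tiled mesh with
  # one cell per device, mirroring visualize_tensor_sharding's layout.
  table = rich.table.Table(
      show_header=False,
      show_lines=not use_color,
      padding=0,
      highlight=not use_color,
      pad_edge=False,
      box=rich.box.SQUARE if not use_color else None)
  rows, cols = 2, num_devices // 2
  for r in range(rows):
    cells = []
    for c in range(cols):
      cells.append(
          rich.padding.Padding(
              rich.align.Align(f'TPU {r * cols + c}', "center", vertical="middle"),
              (1, 1, 1, 1)))
    table.add_row(*cells)
  return table
```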

Collaborator Author:
Thanks for both suggestions; I've updated to a more generalized way of generating the expected table, without the restrictive limitation.

As for the circular reasoning: I wanted to describe a test situation where we produce the test output with our function and also produce the expected output with the same function.

For this test/spmd/test_spmd_debugging.py, if we wanted to test on any given device kind, we would have to generate the table with our function and generate the expected table with our function too. To avoid that (circular reasoning), we limited the test to the 8-device case; a sketch of the anti-pattern is below.

Let me change it to "code test code self" in the description, which describes it better.
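The anti-pattern being avoided, as a schematic sketch (t stands for any sharded tensor; this is not code from the PR):

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

t = torch.randn(8, 8).to(xm.xla_device())  # assume t was sharded beforehand
# Circular: the expected table is produced by the very code under test,
# so the comparison is vacuous and passes even if the visualizer breaks.
expected = visualize_tensor_sharding(t)
actual = visualize_tensor_sharding(t)
assert actual == expected  # always true; verifies nothing
```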

     mesh_shape = (2, num_devices // 2)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
@@ -132,13 +137,14 @@ def test_single_host_partial_replication_tpu(self):
 
     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
Expand Down Expand Up @@ -167,6 +173,8 @@ def test_single_host_replicated_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
num_devices = self.n_devices
if num_devices != 8:
self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
mesh_shape = (2, num_devices // 2)
device_ids = np.array(range(num_devices))
mesh = self._get_mesh(mesh_shape)
@@ -183,13 +191,14 @@ def test_single_host_replicated_tpu(self):
 
     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -212,6 +221,8 @@ def test_debugging_spmd_single_host_tiled_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 1:
+      self.skipTest("skip num_devices!=1 env to avoid circular reasoning")
     mesh_shape = (1, num_devices)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
@@ -227,13 +238,14 @@ def test_debugging_spmd_single_host_tiled_cpu(self):
 
     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -255,6 +267,8 @@ def test_single_host_partial_replication_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 1:
Collaborator:
Naive question, but is there a case where PJRT_DEVICE=CPU has n_devices > 1?

Collaborator Author (@ManfeiBai, Jan 12, 2024):
Thanks, Jon, very good question. I checked with a TPU Pod-16:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} --zone  ${ZONE} --worker all --command='PJRT_DEVICE=CPU python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count());"'
Using ssh batch size of 4. Attempting to SSH into 1 nodes with a total of 4 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
1
1
1
1

It looks like PJRT_DEVICE=CPU only ever has n_devices = 1, compared with the TPU device count:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} --zone  ${ZONE} --worker all --command='PJRT_DEVICE=TPU python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count());"'
Using ssh batch size of 4. Attempting to SSH into 1 nodes with a total of 4 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
16
16
16

+      self.skipTest("skip num_devices!=1 env to avoid circular reasoning")
     mesh_shape = (1, num_devices)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
@@ -271,13 +285,14 @@ def test_single_host_partial_replication_cpu(self):
 
     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -299,6 +314,8 @@ def test_single_host_replicated_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 1:
+      self.skipTest("skip num_devices!=1 env to avoid circular reasoning")
     mesh_shape = (1, num_devices)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
@@ -315,13 +332,14 @@ def test_single_host_replicated_cpu(self):
 
     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -356,13 +374,14 @@ def test_debugging_spmd_multi_host_tiled_tpu(self):
 
     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -468,13 +487,14 @@ def test_multi_host_partial_replication_tpu(self):
 
     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -543,6 +563,9 @@ def test_multi_host_partial_replication_tpu(self):
       f"Requires PJRT_DEVICE set to `TPU`.")
   def test_multi_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
+    num_devices = self.n_devices
+    if num_devices != 8:
+      self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
     sharding = '{replicated}'
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
@@ -552,13 +575,14 @@ def test_multi_host_replicated_tpu(self):
 
     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -588,13 +612,14 @@ def test_debugging_spmd_multi_host_tiled_cpu(self):
 
     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -700,13 +725,14 @@ def test_multi_host_partial_replication_cpu(self):
 
     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -784,13 +810,14 @@ def test_multi_host_replicated_cpu(self):
 
     color = None
     text_color = None
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(