From 781f6c547e61fce85b8b084317f11f2e26310297 Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Fri, 5 Jan 2024 12:02:07 -0800
Subject: [PATCH 01/24] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index 02139c55a5a2..6a128d9f79a8 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -36,6 +36,8 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 8:
+      self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
     mesh_shape = (2, num_devices // 2)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
@@ -116,6 +118,8 @@ def test_single_host_partial_replication_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 8:
+      self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
     mesh_shape = (2, num_devices // 2)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
@@ -167,6 +171,8 @@ def test_single_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 8:
+      self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
     mesh_shape = (2, num_devices // 2)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
@@ -212,6 +218,8 @@ def test_debugging_spmd_single_host_tiled_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 1:
+      self.skipTest("skip num_devices!=1 env to avoid circular reasoning")
     mesh_shape = (1, num_devices)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
@@ -255,6 +263,8 @@ def test_single_host_partial_replication_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 1:
+      self.skipTest("skip num_devices!=1 env to avoid circular reasoning")
     mesh_shape = (1, num_devices)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
@@ -299,6 +309,8 @@ def test_single_host_replicated_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = self.n_devices
+    if num_devices != 1:
+      self.skipTest("skip num_devices!=1 env to avoid circular reasoning")
     mesh_shape = (1, num_devices)
     device_ids = np.array(range(num_devices))
     mesh = self._get_mesh(mesh_shape)
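Note on the guards above: the expected tables in these tests are hard-coded for a fixed device count, so the tests only make sense on hosts with exactly that many devices. A condensed sketch of the flow the guards protect, assuming the file's existing imports (torch, numpy as np, torch_xla, xm, xs) and the test base class's _get_mesh helper; the mesh shape here is illustrative:

    device = xm.xla_device()
    mesh = xs.Mesh(np.arange(8), (2, 4))  # hypothetical 8-device mesh
    t = torch.randn(8, 4, device=device)
    xs.mark_sharding(t, mesh, (0, 1))  # tile both tensor dimensions
    sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
    # e.g. '{devices=[2,4]0,1,2,3,4,5,6,7}' on a single 8-device host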
From 03b5d0d4414184c63ac0d0f79465fdd30f4689ed Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Fri, 5 Jan 2024 14:21:47 -0800
Subject: [PATCH 02/24] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index 6a128d9f79a8..9d6f7874b448 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -368,13 +368,16 @@ def test_debugging_spmd_multi_host_tiled_tpu(self):
     color = None
     text_color = None
 
+    test_debugging_spmd_multi_host_tiled_tpu
+    # console = rich.console.Console() # width=max_width)
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(

From bfd4d366047fcefdef16d164f0192e9ef6f9e15b Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Fri, 5 Jan 2024 14:26:16 -0800
Subject: [PATCH 03/24] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index 9d6f7874b448..ab64694b7e8a 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -368,7 +368,6 @@ def test_debugging_spmd_multi_host_tiled_tpu(self):
     color = None
     text_color = None
 
-    test_debugging_spmd_multi_host_tiled_tpu
     # console = rich.console.Console() # width=max_width)
     use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(

From ecf0ca85dfc6d3186cfd31fb2e8604f55d7bdfcb Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Fri, 5 Jan 2024 14:29:24 -0800
Subject: [PATCH 04/24] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 78 ++++++++++++++++++--------------
 1 file changed, 44 insertions(+), 34 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index ab64694b7e8a..5eea1e136e25 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -53,13 +53,14 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
     color = None
     text_color = None
 
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
        pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -136,13 +137,14 @@ def test_single_host_partial_replication_tpu(self):
     color = None
     text_color = None
 
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -189,13 +191,14 @@ def test_single_host_replicated_tpu(self):
     color = None
     text_color = None
 
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -235,13 +238,14 @@ def test_debugging_spmd_single_host_tiled_cpu(self):
     color = None
     text_color = None
 
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -281,13 +285,14 @@ def test_single_host_partial_replication_cpu(self):
     color = None
     text_color = None
 
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -327,13 +332,14 @@ def test_single_host_replicated_cpu(self):
     color = None
     text_color = None
 
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -368,7 +374,6 @@ def test_debugging_spmd_multi_host_tiled_tpu(self):
     color = None
     text_color = None
 
-    # console = rich.console.Console() # width=max_width)
     use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
@@ -482,13 +487,14 @@ def test_multi_host_partial_replication_tpu(self):
     color = None
     text_color = None
 
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -566,13 +572,14 @@ def test_multi_host_replicated_tpu(self):
     color = None
     text_color = None
 
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -602,13 +609,14 @@ def test_debugging_spmd_multi_host_tiled_cpu(self):
     color = None
     text_color = None
 
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -714,13 +722,14 @@ def test_multi_host_partial_replication_cpu(self):
     color = None
     text_color = None
 
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
@@ -798,13 +807,14 @@ def test_multi_host_replicated_cpu(self):
     color = None
     text_color = None
 
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
-        show_lines=True,
+        show_lines=not use_color,
         padding=0,
-        highlight=True,
+        highlight=not use_color,
         pad_edge=False,
-        box=rich.box.SQUARE)
+        box=rich.box.SQUARE if not use_color else None)
     col = []
     col.append(
         rich.padding.Padding(
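The use_color probe repeated in every hunk above keys the expected table off rich's color detection: when the console reports a color system, the renderer apparently styles cells with color instead of grid lines, so the expected table must drop its lines and box too. A standalone sketch of the same construction, using only the rich API, with values mirroring the patch:

    import rich.box
    import rich.console
    import rich.table

    use_color = True if rich.console.Console().color_system else False
    fake_table = rich.table.Table(
        show_header=False,
        show_lines=not use_color,  # grid lines only in no-color mode
        padding=0,
        highlight=not use_color,
        pad_edge=False,
        box=rich.box.SQUARE if not use_color else None)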
From 46f58e79714f53e4ab0a29752a4fa5da24ec24a8 Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Fri, 5 Jan 2024 14:35:40 -0800
Subject: [PATCH 05/24] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index 5eea1e136e25..647ac419791d 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -563,6 +563,9 @@ def test_multi_host_partial_replication_tpu(self):
       f"Requires PJRT_DEVICE set to `TPU`.")
   def test_multi_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
+    num_devices = self.n_devices
+    if num_devices != 8:
+      self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
     sharding = '{replicated}'
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()

From 456f5d1ef8f40696aee58e541c5ccf5d5804695d Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Fri, 5 Jan 2024 14:41:09 -0800
Subject: [PATCH 06/24] format

---
 test/spmd/test_spmd_debugging.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index 647ac419791d..d11c4416f844 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -53,7 +53,7 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
     color = None
     text_color = None
 
-    use_color = True if rich.console.Console().color_system else False
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
         show_lines=not use_color,
@@ -137,7 +137,7 @@ def test_single_host_partial_replication_tpu(self):
     color = None
     text_color = None
 
-    use_color = True if rich.console.Console().color_system else False
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
         show_lines=not use_color,
@@ -191,7 +191,7 @@ def test_single_host_replicated_tpu(self):
     color = None
     text_color = None
 
-    use_color = True if rich.console.Console().color_system else False
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
         show_lines=not use_color,
@@ -238,7 +238,7 @@ def test_debugging_spmd_single_host_tiled_cpu(self):
     color = None
     text_color = None
 
-    use_color = True if rich.console.Console().color_system else False
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
         show_lines=not use_color,
@@ -285,7 +285,7 @@ def test_single_host_partial_replication_cpu(self):
     color = None
     text_color = None
 
-    use_color = True if rich.console.Console().color_system else False
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
         show_lines=not use_color,
@@ -332,7 +332,7 @@ def test_single_host_replicated_cpu(self):
     color = None
     text_color = None
 
-    use_color = True if rich.console.Console().color_system else False
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
         show_lines=not use_color,
@@ -374,7 +374,7 @@ def test_debugging_spmd_multi_host_tiled_tpu(self):
     color = None
     text_color = None
 
-    use_color = True if rich.console.Console().color_system else False
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
         show_lines=not use_color,
@@ -487,7 +487,7 @@ def test_multi_host_partial_replication_tpu(self):
     color = None
     text_color = None
 
-    use_color = True if rich.console.Console().color_system else False
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
         show_lines=not use_color,
@@ -575,7 +575,7 @@ def test_multi_host_replicated_tpu(self):
     color = None
     text_color = None
 
-    use_color = True if rich.console.Console().color_system else False
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
         show_lines=not use_color,
@@ -612,7 +612,7 @@ def test_debugging_spmd_multi_host_tiled_cpu(self):
     color = None
     text_color = None
 
-    use_color = True if rich.console.Console().color_system else False
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
         show_lines=not use_color,
@@ -725,7 +725,7 @@ def test_multi_host_partial_replication_cpu(self):
     color = None
     text_color = None
 
-    use_color = True if rich.console.Console().color_system else False
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
         show_lines=not use_color,
@@ -810,7 +810,7 @@ def test_multi_host_replicated_cpu(self):
     color = None
     text_color = None
 
-    use_color = True if rich.console.Console().color_system else False
+    use_color = True if rich.console.Console().color_system else False
     fake_table = rich.table.Table(
         show_header=False,
         show_lines=not use_color,
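PATCH 06 is a whitespace-only formatter pass: each '-'/'+' pair above differs only in spacing that is invisible here. Incidentally, the repeated expression could be written as a plain boolean conversion; an equivalent form, not what the series uses:

    use_color = bool(rich.console.Console().color_system)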
From 7a7989c0ce68b8871a02865fb1b1198404256df8 Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Mon, 8 Jan 2024 17:45:33 -0800
Subject: [PATCH 07/24] Update test_spmd_debugging.py with clean sharding string

---
 test/spmd/test_spmd_debugging.py | 79 +++-----------------------------
 1 file changed, 6 insertions(+), 73 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index d11c4416f844..dda4ae3a51ac 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -34,17 +34,7 @@ def setUpClass(cls):
       f"Requires PJRT_DEVICE set to `TPU`.")
   def test_debugging_spmd_single_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
-    device = xm.xla_device()
-    num_devices = self.n_devices
-    if num_devices != 8:
-      self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
-    mesh_shape = (2, num_devices // 2)
-    device_ids = np.array(range(num_devices))
-    mesh = self._get_mesh(mesh_shape)
-    t = torch.randn(8, 4, device=device)
-    partition_spec = (0, 1)
-    xs.mark_sharding(t, mesh, partition_spec)
-    sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
+    sharding={devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}
     generated_table = visualize_tensor_sharding(t)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -117,18 +107,7 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
       f"Requires PJRT_DEVICE set to `TPU`.")
   def test_single_host_partial_replication_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
-    device = xm.xla_device()
-    num_devices = self.n_devices
-    if num_devices != 8:
-      self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
-    mesh_shape = (2, num_devices // 2)
-    device_ids = np.array(range(num_devices))
-    mesh = self._get_mesh(mesh_shape)
-
-    partition_spec = (0, None)
-    t = torch.randn(8, 32, device=device)
-    xs.mark_sharding(t, mesh, (0, None))
-    sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
+    sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}
     generated_table = visualize_tensor_sharding(t)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -171,18 +150,7 @@ def test_single_host_partial_replication_tpu(self):
       f"Requires PJRT_DEVICE set to `TPU`.")
   def test_single_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
-    device = xm.xla_device()
-    num_devices = self.n_devices
-    if num_devices != 8:
-      self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
-    mesh_shape = (2, num_devices // 2)
-    device_ids = np.array(range(num_devices))
-    mesh = self._get_mesh(mesh_shape)
-
-    partition_spec_replicated = (None, None)
-    t = torch.randn(8, 32, device=device)
-    xs.mark_sharding(t, mesh, partition_spec_replicated)
-    sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
+    sharding = '{replicated}'
     generated_table = visualize_tensor_sharding(t)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -218,18 +186,7 @@ def test_single_host_replicated_tpu(self):
       xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
       f"Requires PJRT_DEVICE set to `CPU`.")
   def test_debugging_spmd_single_host_tiled_cpu(self):
-    from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
-    device = xm.xla_device()
-    num_devices = self.n_devices
-    if num_devices != 1:
-      self.skipTest("skip num_devices!=1 env to avoid circular reasoning")
-    mesh_shape = (1, num_devices)
-    device_ids = np.array(range(num_devices))
-    mesh = self._get_mesh(mesh_shape)
-    t = torch.randn(8, 4, device=device)
-    partition_spec = (0, 1)
-    xs.mark_sharding(t, mesh, partition_spec)
-    sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
+    sharding={devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}
     generated_table = visualize_tensor_sharding(t)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -265,18 +222,7 @@ def test_debugging_spmd_single_host_tiled_cpu(self):
       f"Requires PJRT_DEVICE set to `CPU`.")
   def test_single_host_partial_replication_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
-    device = xm.xla_device()
-    num_devices = self.n_devices
-    if num_devices != 1:
-      self.skipTest("skip num_devices!=1 env to avoid circular reasoning")
-    mesh_shape = (1, num_devices)
-    device_ids = np.array(range(num_devices))
-    mesh = self._get_mesh(mesh_shape)
-
-    partition_spec = (0, None)
-    t = torch.randn(8, 32, device=device)
-    xs.mark_sharding(t, mesh, (0, None))
-    sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
+    sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}
     generated_table = visualize_tensor_sharding(t)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -312,18 +258,7 @@ def test_single_host_partial_replication_cpu(self):
       f"Requires PJRT_DEVICE set to `CPU`.")
   def test_single_host_replicated_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
-    device = xm.xla_device()
-    num_devices = self.n_devices
-    if num_devices != 1:
-      self.skipTest("skip num_devices!=1 env to avoid circular reasoning")
-    mesh_shape = (1, num_devices)
-    device_ids = np.array(range(num_devices))
-    mesh = self._get_mesh(mesh_shape)
-
-    partition_spec_replicated = (None, None)
-    t = torch.randn(8, 32, device=device)
-    xs.mark_sharding(t, mesh, partition_spec_replicated)
-    sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
+    sharding = '{replicated}'
     generated_table = visualize_tensor_sharding(t)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -564,8 +499,6 @@ def test_multi_host_partial_replication_tpu(self):
       f"Requires PJRT_DEVICE set to `TPU`.")
   def test_multi_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     num_devices = self.n_devices
-    if num_devices != 8:
-      self.skipTest("skip num_devices!=8 env to avoid circular reasoning")
     sharding = '{replicated}'
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()

From f0351f6d4ee5cc91d8143cc9f3d4fd8fd8278690 Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Mon, 8 Jan 2024 17:47:13 -0800
Subject: [PATCH 08/24] test for GPU too

---
 test/spmd/test_spmd_debugging.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index dda4ae3a51ac..5bba5f90c3bc 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -28,10 +28,10 @@ def setUpClass(cls):
     xr.use_spmd()
     super().setUpClass()
 
-  @unittest.skipIf(
-      not xr.using_pjrt() or
-      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
-      f"Requires PJRT_DEVICE set to `TPU`.")
+# @unittest.skipIf(
+#     not xr.using_pjrt() or
+#     xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
+#     f"Requires PJRT_DEVICE set to `TPU`.")
   def test_debugging_spmd_single_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     sharding={devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}
@@ -101,10 +101,10 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
 
-  @unittest.skipIf(
-      not xr.using_pjrt() or
-      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
-      f"Requires PJRT_DEVICE set to `TPU`.")
+# @unittest.skipIf(
+#     not xr.using_pjrt() or
+#     xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
+#     f"Requires PJRT_DEVICE set to `TPU`.")
   def test_single_host_partial_replication_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}
@@ -144,10 +144,10 @@ def test_single_host_partial_replication_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
 
-  @unittest.skipIf(
-      not xr.using_pjrt() or
-      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
-      f"Requires PJRT_DEVICE set to `TPU`.")
+# @unittest.skipIf(
+#     not xr.using_pjrt() or
+#     xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
+#     f"Requires PJRT_DEVICE set to `TPU`.")
   def test_single_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     sharding = '{replicated}'
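The literal strings patched in above use XLA's HLO sharding notation: '{devices=[r,c]d0,d1,...}' gives a device-mesh shape followed by device ids in iteration order, a trailing 'last_tile_dim_replicate' marks the last mesh axis as a replication axis, and '{replicated}' means fully replicated. A rough parser for the tiled form, illustrative only — the real parsing lives in torch_xla's debugging module:

    import re

    def parse_tiled_sharding(sharding):
      m = re.match(
          r'\{devices=\[([\d,]+)\]([\d,]+)( last_tile_dim_replicate)?\}',
          sharding)
      if m is None:
        return None  # not a tiled sharding string
      mesh_shape = tuple(int(d) for d in m.group(1).split(','))
      device_order = [int(d) for d in m.group(2).split(',')]
      return mesh_shape, device_order, bool(m.group(3))

    # parse_tiled_sharding('{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}')
    # -> ((2, 8), [0, 4, 8, ...], False)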
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -34,7 +34,7 @@ def setUpClass(cls):
 #     f"Requires PJRT_DEVICE set to `TPU`.")
   def test_debugging_spmd_single_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
-    sharding={devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}
+    sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'
     generated_table = visualize_tensor_sharding(t)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -107,7 +107,7 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
 #     f"Requires PJRT_DEVICE set to `TPU`.")
   def test_single_host_partial_replication_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
-    sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}
+    sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'
     generated_table = visualize_tensor_sharding(t)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -186,7 +186,7 @@ def test_single_host_replicated_tpu(self):
       xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'),
       f"Requires PJRT_DEVICE set to `CPU`.")
   def test_debugging_spmd_single_host_tiled_cpu(self):
-    sharding={devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}
+    sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'
     generated_table = visualize_tensor_sharding(t)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -222,7 +222,7 @@ def test_debugging_spmd_single_host_tiled_cpu(self):
       f"Requires PJRT_DEVICE set to `CPU`.")
   def test_single_host_partial_replication_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
-    sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}
+    sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'
     generated_table = visualize_tensor_sharding(t)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -498,7 +498,7 @@ def test_multi_host_partial_replication_tpu(self):
       f"Requires PJRT_DEVICE set to `TPU`.")
   def test_multi_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
-    num_devices = self.n_devices
     sharding = '{replicated}'
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()

From 0271290e1615821a59975f1c25b4c5ff05336e36 Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Wed, 10 Jan 2024 11:42:35 -0800
Subject: [PATCH 10/24] format

---
 test/spmd/test_spmd_debugging.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index 9dd71d554f9e..30eca1c245b1 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -28,7 +28,6 @@ def setUpClass(cls):
     xr.use_spmd()
     super().setUpClass()
 
-
   def test_debugging_spmd_single_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'
@@ -98,7 +97,6 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
 
-
   def test_single_host_partial_replication_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'
@@ -136,6 +134,5 @@ def test_single_host_partial_replication_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
-
   def test_single_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     sharding = '{replicated}'
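Note that PATCH 07 removed the tensor setup but left calls to visualize_tensor_sharding(t), so between PATCH 07 and PATCH 11 the tests reference an undefined t; PATCH 11 below resolves this by switching to the string-based entry point. The two entry points, as used in this file (treat the signatures as illustrative):

    from torch_xla.distributed.spmd.debugging import (
        visualize_sharding, visualize_tensor_sharding)

    table = visualize_tensor_sharding(t)   # takes a live sharded XLA tensor
    table = visualize_sharding(sharding)   # takes only the spec string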
From 899737005129a5a3c63f4553394bcd8ebf19a43a Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Wed, 10 Jan 2024 15:38:23 -0800
Subject: [PATCH 11/24] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index 30eca1c245b1..d13ef88b7c3c 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -29,9 +29,9 @@ def setUpClass(cls):
     super().setUpClass()
 
   def test_debugging_spmd_single_host_tiled_tpu(self):
-    from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
+    from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'
-    generated_table = visualize_tensor_sharding(t)
+    generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
       console.print(generated_table)
@@ -98,9 +98,9 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
     assert output == fake_output
 
   def test_single_host_partial_replication_tpu(self):
-    from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
+    from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'
-    generated_table = visualize_tensor_sharding(t)
+    generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
       console.print(generated_table)
@@ -137,9 +137,9 @@ def test_single_host_partial_replication_tpu(self):
     assert output == fake_output
 
   def test_single_host_replicated_tpu(self):
-    from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
+    from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{replicated}'
-    generated_table = visualize_tensor_sharding(t)
+    generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
       console.print(generated_table)
@@ -172,8 +172,9 @@ def test_single_host_replicated_tpu(self):
   @unittest.skipIf(xr.device_type() != 'CPU',
                    f"Requires PJRT_DEVICE set to `CPU`.")
   def test_debugging_spmd_single_host_tiled_cpu(self):
+    from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'
-    generated_table = visualize_tensor_sharding(t)
+    generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
       console.print(generated_table)
@@ -205,9 +206,9 @@ def test_debugging_spmd_single_host_tiled_cpu(self):
   @unittest.skipIf(xr.device_type() != 'CPU',
                    f"Requires PJRT_DEVICE set to `CPU`.")
   def test_single_host_partial_replication_cpu(self):
-    from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
+    from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'
-    generated_table = visualize_tensor_sharding(t)
+    generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
       console.print(generated_table)
@@ -239,9 +240,9 @@ def test_single_host_partial_replication_cpu(self):
   @unittest.skipIf(xr.device_type() != 'CPU',
                    f"Requires PJRT_DEVICE set to `CPU`.")
   def test_single_host_replicated_cpu(self):
-    from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
+    from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{replicated}'
-    generated_table = visualize_tensor_sharding(t)
+    generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
       console.print(generated_table)

From 90cdfd2251434af9fa24cdb06d95eb2a80ad7ac6 Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Thu, 11 Jan 2024 17:25:39 -0800
Subject: [PATCH 12/24] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index d13ef88b7c3c..9da0b13ff600 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -30,7 +30,7 @@ def setUpClass(cls):
 
   def test_debugging_spmd_single_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
-    sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'
+    sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}'
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -99,7 +99,7 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
 
   def test_single_host_partial_replication_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
-    sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'
+    sharding = '{devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}'
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -139,6 +139,9 @@ def test_single_host_partial_replication_tpu(self):
   def test_single_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{replicated}'
+    num_devices = self.n_devices
+    if num_devices != 8:
+      self.skipTest("limit test num_devices to 8 for function consistency")
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -173,7 +176,7 @@ def test_single_host_replicated_tpu(self):
                    f"Requires PJRT_DEVICE set to `CPU`.")
   def test_debugging_spmd_single_host_tiled_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
-    sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'
+    sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}'
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -207,7 +210,7 @@ def test_debugging_spmd_single_host_tiled_cpu(self):
                    f"Requires PJRT_DEVICE set to `CPU`.")
   def test_single_host_partial_replication_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
-    sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'
+    sharding = '{devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}'
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -242,6 +245,9 @@ def test_single_host_partial_replication_cpu(self):
   def test_single_host_replicated_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{replicated}'
+    num_devices = self.n_devices
+    if num_devices != 8:
+      self.skipTest("limit test num_devices to 8 for function consistency")
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -476,6 +482,9 @@ def test_multi_host_partial_replication_tpu(self):
   def test_multi_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{replicated}'
+    num_devices = self.n_devices
+    if num_devices != 8:
+      self.skipTest("limit test num_devices to 8 for function consistency")
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -705,6 +714,9 @@ def test_multi_host_partial_replication_cpu(self):
   def test_multi_host_replicated_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{replicated}'
+    num_devices = self.n_devices
+    if num_devices != 8:
+      self.skipTest("limit test num_devices to 8 for function consistency")
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
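Every assertion in this file uses the same capture-and-compare pattern: the generated table and the hand-built expected table are both rendered through a capturing console, and the raw rendered text is compared. Condensed from the test bodies (rich API; generated_table and fake_table stand in for the test's values):

    import rich.console

    console = rich.console.Console()
    with console.capture() as capture:
      console.print(generated_table)
    output = capture.get()

    with console.capture() as fake_capture:
      console.print(fake_table)
    assert output == fake_capture.get()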
From 94ae823562221032f4ca688918352230ba0eb041 Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Fri, 12 Jan 2024 10:30:05 -0800
Subject: [PATCH 13/24] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 42 +++++++++++++++++++++++++++-----
 1 file changed, 36 insertions(+), 6 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index 9da0b13ff600..6e33c8e0d451 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -28,6 +28,8 @@ def setUpClass(cls):
     xr.use_spmd()
     super().setUpClass()
 
+  @unittest.skipIf(xr.device_type() != 'TPU',
+                   f"Requires PJRT_DEVICE set to `TPU`.")
   def test_debugging_spmd_single_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}'
@@ -97,6 +99,8 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
 
+  @unittest.skipIf(xr.device_type() != 'TPU',
+                   f"Requires PJRT_DEVICE set to `TPU`.")
   def test_single_host_partial_replication_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}'
@@ -136,6 +140,8 @@ def test_single_host_partial_replication_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
 
+  @unittest.skipIf(xr.device_type() != 'TPU',
+                   f"Requires PJRT_DEVICE set to `TPU`.")
   def test_single_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{replicated}'
@@ -176,7 +182,16 @@ def test_single_host_replicated_tpu(self):
                    f"Requires PJRT_DEVICE set to `CPU`.")
   def test_debugging_spmd_single_host_tiled_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
-    sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}'
+    device = xm.xla_device()
+    num_devices = self.n_devices
+    mesh_shape = (1, num_devices)
+    device_ids = np.array(range(num_devices))
+    mesh = self._get_mesh(mesh_shape)
+
+    partition_spec = (0, None)
+    t = torch.randn(8, 32, device=device)
+    xs.mark_sharding(t, mesh, (0, None))
+    sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -196,7 +211,7 @@ def test_debugging_spmd_single_host_tiled_cpu(self):
     col = []
     col.append(
         rich.padding.Padding(
-            rich.align.Align('CPU [0]', "center", vertical="middle"),
+            rich.align.Align('CPU 0', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fake_table.add_row(*col)
@@ -210,7 +225,16 @@ def test_debugging_spmd_single_host_tiled_cpu(self):
                    f"Requires PJRT_DEVICE set to `CPU`.")
   def test_single_host_partial_replication_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
-    sharding = '{devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}'
+    device = xm.xla_device()
+    num_devices = self.n_devices
+    mesh_shape = (1, num_devices)
+    device_ids = np.array(range(num_devices))
+    mesh = self._get_mesh(mesh_shape)
+
+    partition_spec = (0, None)
+    t = torch.randn(8, 32, device=device)
+    xs.mark_sharding(t, mesh, (0, None))
+    sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -244,10 +268,16 @@ def test_single_host_partial_replication_cpu(self):
                    f"Requires PJRT_DEVICE set to `CPU`.")
   def test_single_host_replicated_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
-    sharding = '{replicated}'
+    device = xm.xla_device()
     num_devices = self.n_devices
-    if num_devices != 8:
-      self.skipTest("limit test num_devices to 8 for function consistency")
+    mesh_shape = (1, num_devices)
+    device_ids = np.array(range(num_devices))
+    mesh = self._get_mesh(mesh_shape)
+
+    partition_spec_replicated = (None, None)
+    t = torch.randn(8, 32, device=device)
+    xs.mark_sharding(t, mesh, partition_spec_replicated)
+    sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:

From 099391106f44f2af958ac9f0e809bc19842cd57b Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Fri, 12 Jan 2024 10:49:55 -0800
Subject: [PATCH 14/24] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index 6e33c8e0d451..4d026891c473 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -211,7 +211,7 @@ def test_debugging_spmd_single_host_tiled_cpu(self):
     col = []
     col.append(
         rich.padding.Padding(
-            rich.align.Align('CPU 0', "center", vertical="middle"),
+            rich.align.Align('CPU [0]', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fake_table.add_row(*col)

From b2444c9da6d591c60572c2b8b0a394b35fbceaa7 Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Fri, 12 Jan 2024 11:39:38 -0800
Subject: [PATCH 15/24] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index 4d026891c473..c7725ae8dc9f 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -123,14 +123,28 @@ def test_single_host_partial_replication_tpu(self):
     col = []
     col.append(
         rich.padding.Padding(
-            rich.align.Align('TPU [0, 1, 2, 3]', "center", vertical="middle"),
+            rich.align.Align('TPU [0, 1]', "center", vertical="middle"),
+            (1, 1, 1, 1),
+            style=rich.style.Style(bgcolor=color, color=text_color)))
+    fake_table.add_row(*col)
+    col = []
+    col.append(
+        rich.padding.Padding(
+            rich.align.Align('TPU [2, 3]', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fake_table.add_row(*col)
     col = []
     col.append(
         rich.padding.Padding(
-            rich.align.Align('TPU [4, 5, 6, 7]', "center", vertical="middle"),
+            rich.align.Align('TPU [4, 5]', "center", vertical="middle"),
+            (1, 1, 1, 1),
+            style=rich.style.Style(bgcolor=color, color=text_color)))
+    fake_table.add_row(*col)
+    col = []
+    col.append(
+        rich.padding.Padding(
+            rich.align.Align('TPU [6, 7]', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fake_table.add_row(*col)

From e04349851c8a75443bc1aef2436bc0c12962f6ae Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Fri, 12 Jan 2024 14:34:44 -0800
Subject: [PATCH 16/24] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index c7725ae8dc9f..a868246a4ca2 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -156,12 +156,11 @@ def test_single_host_partial_replication_tpu(self):
 
   @unittest.skipIf(xr.device_type() != 'TPU',
                    f"Requires PJRT_DEVICE set to `TPU`.")
+  @unittest.skipIf(xr.global_runtime_device_count() != 8,
+                   f"Limit test num_devices to 8 for function consistency")
   def test_single_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{replicated}'
-    num_devices = self.n_devices
-    if num_devices != 8:
-      self.skipTest("limit test num_devices to 8 for function consistency")
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -523,12 +522,11 @@ def test_multi_host_partial_replication_tpu(self):
 
   @unittest.skipIf(xr.device_type() != 'TPU',
                    f"Requires PJRT_DEVICE set to `TPU`.")
+  @unittest.skipIf(xr.global_runtime_device_count() != 8,
+                   f"Limit test num_devices to 8 for function consistency")
   def test_multi_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{replicated}'
-    num_devices = self.n_devices
-    if num_devices != 8:
-      self.skipTest("limit test num_devices to 8 for function consistency")
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
@@ -755,12 +753,11 @@ def test_multi_host_partial_replication_cpu(self):
 
   @unittest.skipIf(xr.device_type() != 'CPU',
                    f"Requires PJRT_DEVICE set to `CPU`.")
+  @unittest.skipIf(xr.global_runtime_device_count() != 8,
+                   f"Limit test num_devices to 8 for function consistency")
   def test_multi_host_replicated_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{replicated}'
-    num_devices = self.n_devices
-    if num_devices != 8:
-      self.skipTest("limit test num_devices to 8 for function consistency")
     generated_table = visualize_sharding(sharding)
     console = rich.console.Console()
     with console.capture() as capture:
From 2215e9d909800c0c346a77f9640d09530ef4f100 Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Fri, 12 Jan 2024 14:48:31 -0800
Subject: [PATCH 17/24] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index a868246a4ca2..bdb701ac1d77 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -156,8 +156,6 @@ def test_single_host_partial_replication_tpu(self):
 
   @unittest.skipIf(xr.device_type() != 'TPU',
                    f"Requires PJRT_DEVICE set to `TPU`.")
-  @unittest.skipIf(xr.global_runtime_device_count() != 8,
-                   f"Limit test num_devices to 8 for function consistency")
   def test_single_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{replicated}'
@@ -178,10 +176,14 @@ def test_single_host_replicated_tpu(self):
         pad_edge=False,
         box=rich.box.SQUARE if not use_color else None)
     col = []
+    alltpus = 'TPU [0'
+    for i in range(xr.global_runtime_device_count()-1):
+      alltpus = alltpus + ',' + str(i+1)
+    alltpus = alltpus + ']'
     col.append(
         rich.padding.Padding(
             rich.align.Align(
-                'TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"),
+                alltpus, "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fake_table.add_row(*col)
@@ -753,8 +755,6 @@ def test_multi_host_partial_replication_cpu(self):
 
   @unittest.skipIf(xr.device_type() != 'CPU',
                    f"Requires PJRT_DEVICE set to `CPU`.")
-  @unittest.skipIf(xr.global_runtime_device_count() != 8,
-                   f"Limit test num_devices to 8 for function consistency")
   def test_multi_host_replicated_cpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{replicated}'
@@ -774,10 +774,14 @@ def test_multi_host_replicated_cpu(self):
         highlight=not use_color,
         pad_edge=False,
         box=rich.box.SQUARE if not use_color else None)
+    alltpus = 'CPU [0'
+    for i in range(xr.global_runtime_device_count()-1):
+      alltpus = alltpus + ',' + str(i+1)
+    alltpus = alltpus + ']'
     col = []
     col.append(
         rich.padding.Padding(
-            rich.align.Align('CPU [0]', "center", vertical="middle"),
+            rich.align.Align(alltpus, "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fake_table.add_row(*col)

From 057397e2973a1124faebe5e0e74fb1d3bf636bcd Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Fri, 12 Jan 2024 14:55:46 -0800
Subject: [PATCH 18/24] try

---
 test/spmd/test_spmd_debugging.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index bdb701ac1d77..a6ad8789a270 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -28,8 +28,8 @@ def setUpClass(cls):
     xr.use_spmd()
     super().setUpClass()
 
-  @unittest.skipIf(xr.device_type() != 'TPU',
-                   f"Requires PJRT_DEVICE set to `TPU`.")
+  @unittest.skipIf(xr.device_type() not in ('TPU', 'GPU'),
+                   f"Requires PJRT_DEVICE set to `TPU` or `GPU`.")
   def test_debugging_spmd_single_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}'
@@ -52,44 +52,44 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
     col = []
     col.append(
         rich.padding.Padding(
-            rich.align.Align('TPU 0', "center", vertical="middle"),
+            rich.align.Align(xr.device_type() + ' 0', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
-            rich.align.Align('TPU 1', "center", vertical="middle"),
+            rich.align.Align(xr.device_type() + ' 1', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
-            rich.align.Align('TPU 2', "center", vertical="middle"),
+            rich.align.Align(xr.device_type() + ' 2', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
-            rich.align.Align('TPU 3', "center", vertical="middle"),
+            rich.align.Align(xr.device_type() + ' 3', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fake_table.add_row(*col)
     col = []
     col.append(
         rich.padding.Padding(
-            rich.align.Align('TPU 4', "center", vertical="middle"),
+            rich.align.Align(xr.device_type() + ' 4', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
-            rich.align.Align('TPU 5', "center", vertical="middle"),
+            rich.align.Align(xr.device_type() + ' 5', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
-            rich.align.Align('TPU 6', "center", vertical="middle"),
+            rich.align.Align(xr.device_type() + ' 6', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
-            rich.align.Align('TPU 7', "center", vertical="middle"),
+            rich.align.Align(xr.device_type() + ' 7', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fake_table.add_row(*col)

From ffd19f6e49b3a7525e057e656d055c62b7f2daf1 Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Fri, 12 Jan 2024 15:22:46 -0800
Subject: [PATCH 19/24] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index a6ad8789a270..be5d346048cd 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -774,14 +774,11 @@ def test_multi_host_replicated_cpu(self):
         highlight=not use_color,
         pad_edge=False,
         box=rich.box.SQUARE if not use_color else None)
-    alltpus = 'CPU [0'
-    for i in range(xr.global_runtime_device_count()-1):
-      alltpus = alltpus + ',' + str(i+1)
-    alltpus = alltpus + ']'
     col = []
+    # PJRT_DEVICE=CPU will only has one CPU, please update once situation change
     col.append(
         rich.padding.Padding(
-            rich.align.Align(alltpus, "center", vertical="middle"),
+            rich.align.Align('CPU [0]', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fake_table.add_row(*col)
From b05a7396ac18cb04311a41b05c3accda58646c6b Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Wed, 17 Jan 2024 11:52:37 -0800
Subject: [PATCH 20/24] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 35 +++++++++++++++++++-------------
 1 file changed, 21 insertions(+), 14 deletions(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index be5d346048cd..a64e7f2ed588 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -28,8 +28,8 @@ def setUpClass(cls):
     xr.use_spmd()
     super().setUpClass()
 
-  @unittest.skipIf(xr.device_type() not in ('TPU', 'GPU'),
-                   f"Requires PJRT_DEVICE set to `TPU` or `GPU`.")
+  @unittest.skipIf(xr.device_type() == ('CPU'),
+                   f"Requires PJRT_DEVICE set to `TPU`, `GPU`, or `CUDA`.")
   def test_debugging_spmd_single_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}'
@@ -52,44 +52,52 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
     col = []
     col.append(
         rich.padding.Padding(
-            rich.align.Align(xr.device_type() + ' 0', "center", vertical="middle"),
+            rich.align.Align(
+                xr.device_type() + ' 0', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
-            rich.align.Align(xr.device_type() + ' 1', "center", vertical="middle"),
+            rich.align.Align(
+                xr.device_type() + ' 1', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
-            rich.align.Align(xr.device_type() + ' 2', "center", vertical="middle"),
+            rich.align.Align(
+                xr.device_type() + ' 2', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
-            rich.align.Align(xr.device_type() + ' 3', "center", vertical="middle"),
+            rich.align.Align(
+                xr.device_type() + ' 3', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fake_table.add_row(*col)
     col = []
     col.append(
         rich.padding.Padding(
-            rich.align.Align(xr.device_type() + ' 4', "center", vertical="middle"),
+            rich.align.Align(
+                xr.device_type() + ' 4', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
-            rich.align.Align(xr.device_type() + ' 5', "center", vertical="middle"),
+            rich.align.Align(
+                xr.device_type() + ' 5', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
-            rich.align.Align(xr.device_type() + ' 6', "center", vertical="middle"),
+            rich.align.Align(
+                xr.device_type() + ' 6', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
-            rich.align.Align(xr.device_type() + ' 7', "center", vertical="middle"),
+            rich.align.Align(
+                xr.device_type() + ' 7', "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fake_table.add_row(*col)
@@ -177,13 +185,12 @@ def test_single_host_replicated_tpu(self):
         box=rich.box.SQUARE if not use_color else None)
     col = []
     alltpus = 'TPU [0'
-    for i in range(xr.global_runtime_device_count()-1):
-      alltpus = alltpus + ',' + str(i+1)
+    for i in range(xr.global_runtime_device_count() - 1):
+      alltpus = alltpus + ',' + str(i + 1)
     alltpus = alltpus + ']'
     col.append(
         rich.padding.Padding(
-            rich.align.Align(
-                alltpus, "center", vertical="middle"),
+            rich.align.Align(alltpus, "center", vertical="middle"),
             (1, 1, 1, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fake_table.add_row(*col)

From 01488e3bf0e5cef4bd4e8c06690f424ea0b2b476 Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Wed, 17 Jan 2024 11:53:04 -0800
Subject: [PATCH 21/24] Update test_spmd_debugging.py

---
 test/spmd/test_spmd_debugging.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index a64e7f2ed588..6a960699628c 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -28,7 +28,7 @@ def setUpClass(cls):
     xr.use_spmd()
     super().setUpClass()
 
-  @unittest.skipIf(xr.device_type() == ('CPU'),
+  @unittest.skipIf(xr.device_type() == 'CPU',
                    f"Requires PJRT_DEVICE set to `TPU`, `GPU`, or `CUDA`.")
   def test_debugging_spmd_single_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
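PATCH 21 is cosmetic but worth spelling out: in Python, ('CPU') is just the string 'CPU' — parentheses alone do not create a tuple, only a trailing comma does — so PATCH 20's comparison already behaved as intended:

    assert ('CPU') == 'CPU'    # parens are a no-op here
    assert ('CPU',) != 'CPU'   # this would be a 1-tuple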
`TPU`.") + @unittest.skipIf(xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") @unittest.skipIf(xr.global_runtime_device_count() != 8, f"Limit test num_devices to 8 for function consistency") def test_multi_host_replicated_tpu(self): From fc91e8b83d7ed46fbb9a6fa4e254423c77d6c797 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 18 Jan 2024 16:58:28 -0800 Subject: [PATCH 23/24] enable GPU test --- test/spmd/test_spmd_debugging.py | 118 ++++++++++++++++++++----------- 1 file changed, 76 insertions(+), 42 deletions(-) diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py index 61af77932afb..f75ff076e50f 100644 --- a/test/spmd/test_spmd_debugging.py +++ b/test/spmd/test_spmd_debugging.py @@ -28,8 +28,9 @@ def setUpClass(cls): xr.use_spmd() super().setUpClass() - @unittest.skipIf(xr.device_type() == 'CPU', - f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") + @unittest.skipIf( + xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") def test_debugging_spmd_single_host_tiled_tpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}' @@ -107,8 +108,9 @@ def test_debugging_spmd_single_host_tiled_tpu(self): fake_output = fake_capture.get() assert output == fake_output - @unittest.skipIf(xr.device_type() == 'CPU', - f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") + @unittest.skipIf( + xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") def test_single_host_partial_replication_tpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding sharding = '{devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}' @@ -131,28 +133,32 @@ def test_single_host_partial_replication_tpu(self): col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [0, 1]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [0, 1]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [2, 3]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [2, 3]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [4, 5]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [4, 5]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [6, 7]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [6, 7]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) @@ -162,8 +168,9 @@ def test_single_host_partial_replication_tpu(self): fake_output = fake_capture.get() assert output == fake_output - @unittest.skipIf(xr.device_type() == 'CPU', - f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") + @unittest.skipIf( + xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") def test_single_host_replicated_tpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding sharding = '{replicated}' @@ -184,7 
+191,7 @@ def test_single_host_replicated_tpu(self): pad_edge=False, box=rich.box.SQUARE if not use_color else None) col = [] - alltpus = 'TPU [0' + alltpus = xr.device_type() + ' [0' for i in range(xr.global_runtime_device_count() - 1): alltpus = alltpus + ',' + str(i + 1) alltpus = alltpus + ']' @@ -335,8 +342,9 @@ def test_single_host_replicated_cpu(self): # e.g.: sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate} # e.g.: sharding={replicated} - @unittest.skipIf(xr.device_type() == 'CPU', - f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") + @unittest.skipIf( + xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") def test_debugging_spmd_multi_host_tiled_tpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}' @@ -359,84 +367,100 @@ def test_debugging_spmd_multi_host_tiled_tpu(self): col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU 0', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 0', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 4', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 4', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 8', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 8', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 12', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 12', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 2', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 2', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 6', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 6', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 10', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 10', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 14', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 14', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU 1', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 1', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 5', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 5', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 9', "center", vertical="middle"), + rich.align.Align( + 
xr.device_type() + ' 9', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 13', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 13', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 3', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 3', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 7', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 7', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 11', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 11', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) col.append( rich.padding.Padding( - rich.align.Align('TPU 15', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' 15', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) @@ -446,8 +470,9 @@ def test_debugging_spmd_multi_host_tiled_tpu(self): fake_output = fake_capture.get() assert output == fake_output - @unittest.skipIf(xr.device_type() == 'CPU', - f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") + @unittest.skipIf( + xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") def test_multi_host_partial_replication_tpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}' @@ -470,56 +495,64 @@ def test_multi_host_partial_replication_tpu(self): col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [0, 1]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [0, 1]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [4, 5]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [4, 5]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [8, 9]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [8, 9]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [12, 13]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [12, 13]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [2, 3]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [2, 3]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [6, 7]', "center", vertical="middle"), + 
rich.align.Align( + xr.device_type() + ' [6, 7]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [10, 11]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [10, 11]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) col = [] col.append( rich.padding.Padding( - rich.align.Align('TPU [14, 15]', "center", vertical="middle"), + rich.align.Align( + xr.device_type() + ' [14, 15]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) @@ -529,8 +562,9 @@ def test_multi_host_partial_replication_tpu(self): fake_output = fake_capture.get() assert output == fake_output - @unittest.skipIf(xr.device_type() == 'CPU', - f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") + @unittest.skipIf( + xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") @unittest.skipIf(xr.global_runtime_device_count() != 8, f"Limit test num_devices to 8 for function consistency") def test_multi_host_replicated_tpu(self): @@ -556,7 +590,7 @@ def test_multi_host_replicated_tpu(self): col.append( rich.padding.Padding( rich.align.Align( - 'TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"), + xr.device_type() + ' [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) From 8367f997bd25bef5254d6899a300d52ef6db1a77 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Thu, 18 Jan 2024 17:00:45 -0800 Subject: [PATCH 24/24] format --- test/spmd/test_spmd_debugging.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py index f75ff076e50f..20ae3a3f71fd 100644 --- a/test/spmd/test_spmd_debugging.py +++ b/test/spmd/test_spmd_debugging.py @@ -590,8 +590,9 @@ def test_multi_host_replicated_tpu(self): col.append( rich.padding.Padding( rich.align.Align( - xr.device_type() + ' [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"), - (1, 1, 1, 1), + xr.device_type() + ' [0, 1, 2, 3, 4, 5, 6, 7]', + "center", + vertical="middle"), (1, 1, 1, 1), style=rich.style.Style(bgcolor=color, color=text_color))) fake_table.add_row(*col) fake_console = rich.console.Console()