diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 1cd940c15531..612e65db9369 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -374,8 +374,6 @@ def test_mark_sharding_not_ordered_partial_4d(self):
     actual = (xt1 + t2).cpu()
     self.assertTrue(torch.allclose(expected, actual))
 
-  @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
-                       'At least 4 devices required')
   def test_mark_sharding_not_ordered_2d_tensor_3d_mesh(self):
     ct1 = torch.randn(16, 16, device='cpu')
     ct2 = torch.randn(16, 16, device='cpu')
@@ -384,16 +382,15 @@ def test_mark_sharding_not_ordered_2d_tensor_3d_mesh(self):
     t1 = ct1.to(xm.xla_device())
     t2 = ct2.to(xm.xla_device())
     mesh = self._get_mesh((1, self.n_devices, 1))
-    mesh = self._get_mesh((1, self.n_devices // 2, 2))
     # sharding spec here is not ordered.
-    xs.mark_sharding(t1, mesh, partition_spec=(1, 0))
-    sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(t1)
-    devices = f'[{self.n_devices // 2},1,2]' + ','.join(
-        str(x) for x in range(self.n_devices))
-    expected_spec = f'{{devices={devices} last_tile_dim_replicate}}'
-    self.assertEqual(sharding_spec, expected_spec)
-
-    actual = (t1 + t2).cpu()
+    xt1 = xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
+    if self.n_devices > 1:
+      hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt1.global_tensor])
+      sharding_annotation = 'sharding={devices=[1,1,%d]%s}' % (
+          self.n_devices, ','.join(
+              [str(d) for d in mesh.get_logical_mesh().flatten()]))
+      self.assertIn(sharding_annotation, hlo)
+    actual = (xt1 + t2).cpu()
     self.assertTrue(torch.allclose(expected, actual))
 
   def test_partial_replication_addmm(self):
@@ -634,24 +631,22 @@ def test_xla_sharded_hlo_dump(self):
     # scalar 5 should be replicated
     self.assertIn('%p0.2 = f32[] parameter(0), sharding={replicated}', hlo)
 
-  @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
-                       'At least 4 devices required')
   def test_2d_tensor_3d_mesh(self):
-    ct1 = torch.randn(16, 16)
-    ct2 = torch.randn(16, 16)
+    ct1 = torch.randn(16, 16, device='cpu')
+    ct2 = torch.randn(16, 16, device='cpu')
     expected = ct1 + ct2
 
     t1 = ct1.to(xm.xla_device())
     t2 = ct2.to(xm.xla_device())
-    mesh = self._get_mesh((1, self.n_devices // 2, 2))
-    xs.mark_sharding(t1, mesh, partition_spec=(0, 1))
-
-    sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(t1)
-    devices = f'[1,{self.n_devices // 2},2]' + ','.join(
-        str(x) for x in range(self.n_devices))
-    expected_spec = f'{{devices={devices} last_tile_dim_replicate}}'
-    self.assertEqual(sharding_spec, expected_spec)
-
+    mesh = self._get_mesh((1, self.n_devices, 1))
+    t1 = xs.mark_sharding(t1, mesh, partition_spec=(1, 2))
+    if self.n_devices > 1:
+      hlo = torch_xla._XLAC._get_xla_tensors_hlo([t1.global_tensor])
+      # expected string in hlo %param = f32[1,4,16]{2,1,0:T(4,128)} parameter(0), sharding={devices=[1,4,1]0,2,1,3}
+      sharding_annotation = 'sharding={devices=[1,%d,1]%s}' % (
+          self.n_devices, ','.join(
+              [str(d) for d in mesh.get_logical_mesh().flatten()]))
+      self.assertIn(sharding_annotation, hlo)
     actual = (t1 + t2).cpu()
     self.assertTrue(torch.allclose(expected, actual))
 
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 8289d962c18f..419a281691c2 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1375,17 +1375,13 @@ void InitXlaModuleBindings(py::module m) {
   m.def("_xla_mark_sharding",
        [](const at::Tensor& input, const py::list& tile_assignment,
           const py::list& group_assignment, const py::list& replication_groups,
-          int sharding_type, bool tensor_rank_less_than_mesh) {
+          int sharding_type) {
          TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
          XLA_CHECK(UseVirtualDevice()) << "Please set `XLA_USE_SPMD=1`";
          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
          xla::OpSharding sharding = ShardingUtil::CreateOpSharding(
              tile_assignment, group_assignment, replication_groups,
              ShardingUtil::ShardingType(sharding_type));
-          if (tensor_rank_less_than_mesh) {
-            // Replicate the lower-rank tensor along the last mesh dimension.
-            sharding.set_replicate_on_last_tile_dim(true);
-          }
          auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
              sharding,
              MakeShapeWithDeviceLayout(
diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py
index 5dded1854cb9..0ac313c34ad6 100644
--- a/torch_xla/experimental/xla_sharding.py
+++ b/torch_xla/experimental/xla_sharding.py
@@ -412,26 +412,35 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
   assert len(specs) == len(np.unique(specs)), \
     f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."
 
-  tensor_rank_less_than_mesh = len(t.shape) < len(mesh.get_logical_mesh().shape)
-  if tensor_rank_less_than_mesh:
-    assert len(mesh.get_logical_mesh().shape) == len(
-        t.shape) + 1, 'Tensor rank must be equal to or one less than mesh rank'
-    tile_assignment = _get_tile_assignment(mesh, partition_spec + (None,))
-  else:
-    tile_assignment = _get_tile_assignment(mesh, partition_spec)
+  # check for sharding 2D tensor on a 3D mesh
+  original_shape = tuple(t.shape)
+  # number of dims to expand on tensor
+  tensor_expand = 0
+  if tensor_expand < len(mesh.get_logical_mesh().shape) - len(partition_spec):
+    tensor_expand = len(mesh.get_logical_mesh().shape) - len(partition_spec)
+    partition_spec = (None,) * tensor_expand + partition_spec
+    shape = (1,) * tensor_expand + (*original_shape,)
+    t = t.expand(shape)
+
+  tile_assignment = _get_tile_assignment(mesh, partition_spec)
   sharding_type = _get_sharding_type(partition_spec, num_devices)
   group_assignment, replication_groups = _get_group_assignment(
       sharding_type, partition_spec, tile_assignment)
 
+  def tensor_squeeze(t, tensor_expand):
+    if tensor_expand:
+      t = torch.squeeze(t, dim=tuple(range(tensor_expand)))
+    return t
+
   if isinstance(t, XLAShardedTensor):
     torch_xla._XLAC._xla_mark_sharding(t.global_tensor, tile_assignment,
                                        group_assignment, replication_groups,
-                                       int(sharding_type),
-                                       tensor_rank_less_than_mesh)
+                                       int(sharding_type))
+    t = tensor_squeeze(t, tensor_expand)
     return t
   torch_xla._XLAC._xla_mark_sharding(t, tile_assignment, group_assignment,
-                                     replication_groups, int(sharding_type),
-                                     tensor_rank_less_than_mesh)
+                                     replication_groups, int(sharding_type))
+  t = tensor_squeeze(t, tensor_expand)
   return XLAShardedTensor(t)
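
For context, a minimal usage sketch (not part of the patch) of what this change enables: sharding a rank-2 tensor over a rank-3 device mesh with the reworked mark_sharding, mirroring test_2d_tensor_3d_mesh above. It assumes an SPMD-enabled runtime (XLA_USE_SPMD=1, as required by the binding) and at least two devices; it builds the mesh directly with xs.Mesh rather than the tests' self._get_mesh helper, and the axis names are illustrative only. The "left-pad with a replicated dim" behavior described in the comments is derived from the expand logic in the patched mark_sharding.

# Sketch: shard a 2D tensor over a 3D mesh. mark_sharding now left-pads the
# partition spec with a replicated (None) dim internally, so the caller only
# maps the two tensor dims to mesh axes.
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)
# 3D logical mesh; only the middle axis spans more than one device.
mesh = xs.Mesh(device_ids, (1, num_devices, 1), ('w', 'x', 'y'))

t = torch.randn(16, 16).to(xm.xla_device())
# Tensor dim 0 -> mesh axis 1 (size num_devices), dim 1 -> mesh axis 2 (size 1).
xt = xs.mark_sharding(t, mesh, partition_spec=(1, 2))

# The result is an XLAShardedTensor wrapping the original global tensor;
# the attached annotation can be inspected via the sharding spec.
print(torch_xla._XLAC._get_xla_sharding_spec(xt.global_tensor))

Per the patch, the temporary expand to mesh rank is undone with tensor_squeeze before returning, so the sharded tensor keeps its original (16, 16) shape.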