Revert "Partially replicate lower-rank tensors (#5409)" (#5412)
This reverts commit 56a6a02.
yeounoh authored and will-cromar committed Sep 14, 2023
1 parent 88a6865 commit 239119f
Showing 3 changed files with 40 additions and 40 deletions.
43 changes: 19 additions & 24 deletions test/spmd/test_xla_sharding.py
@@ -374,8 +374,6 @@ def test_mark_sharding_not_ordered_partial_4d(self):
actual = (xt1 + t2).cpu()
self.assertTrue(torch.allclose(expected, actual))

@unittest.skipUnless(xr.global_runtime_device_count() >= 4,
'At least 4 devices required')
def test_mark_sharding_not_ordered_2d_tensor_3d_mesh(self):
ct1 = torch.randn(16, 16, device='cpu')
ct2 = torch.randn(16, 16, device='cpu')
@@ -384,16 +382,15 @@ def test_mark_sharding_not_ordered_2d_tensor_3d_mesh(self):
t1 = ct1.to(xm.xla_device())
t2 = ct2.to(xm.xla_device())
mesh = self._get_mesh((1, self.n_devices, 1))
mesh = self._get_mesh((1, self.n_devices // 2, 2))
# sharding spec here is not ordered.
xs.mark_sharding(t1, mesh, partition_spec=(1, 0))
sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(t1)
devices = f'[{self.n_devices // 2},1,2]' + ','.join(
str(x) for x in range(self.n_devices))
expected_spec = f'{{devices={devices} last_tile_dim_replicate}}'
self.assertEqual(sharding_spec, expected_spec)

actual = (t1 + t2).cpu()
xt1 = xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
if self.n_devices > 1:
hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt1.global_tensor])
sharding_annotation = 'sharding={devices=[1,1,%d]%s}' % (
self.n_devices, ','.join(
[str(d) for d in mesh.get_logical_mesh().flatten()]))
self.assertIn(sharding_annotation, hlo)
actual = (xt1 + t2).cpu()
self.assertTrue(torch.allclose(expected, actual))

def test_partial_replication_addmm(self):
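For reference, the deleted assertions in the hunk above compare the sharding spec against a string whose trailing tile dimension is replicated. A minimal standalone sketch of how that expected string is assembled, assuming eight devices purely for illustration:

# Sketch only: rebuilds the expected-spec string from the removed assertions,
# assuming n_devices = 8 for illustration.
n_devices = 8

# Tile assignment [n_devices // 2, 1, 2]: tensor dim 0 is split four ways, dim 1
# is unsplit, and the trailing dimension of size 2 holds replicas of each shard.
devices = f'[{n_devices // 2},1,2]' + ','.join(str(x) for x in range(n_devices))
expected_spec = f'{{devices={devices} last_tile_dim_replicate}}'
print(expected_spec)
# {devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}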
@@ -634,24 +631,22 @@ def test_xla_sharded_hlo_dump(self):
# scalar 5 should be replicated
self.assertIn('%p0.2 = f32[] parameter(0), sharding={replicated}', hlo)

@unittest.skipUnless(xr.global_runtime_device_count() >= 4,
'At least 4 devices required')
def test_2d_tensor_3d_mesh(self):
ct1 = torch.randn(16, 16)
ct2 = torch.randn(16, 16)
ct1 = torch.randn(16, 16, device='cpu')
ct2 = torch.randn(16, 16, device='cpu')
expected = ct1 + ct2

t1 = ct1.to(xm.xla_device())
t2 = ct2.to(xm.xla_device())
mesh = self._get_mesh((1, self.n_devices // 2, 2))
xs.mark_sharding(t1, mesh, partition_spec=(0, 1))

sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(t1)
devices = f'[1,{self.n_devices // 2},2]' + ','.join(
str(x) for x in range(self.n_devices))
expected_spec = f'{{devices={devices} last_tile_dim_replicate}}'
self.assertEqual(sharding_spec, expected_spec)

mesh = self._get_mesh((1, self.n_devices, 1))
t1 = xs.mark_sharding(t1, mesh, partition_spec=(1, 2))
if self.n_devices > 1:
hlo = torch_xla._XLAC._get_xla_tensors_hlo([t1.global_tensor])
# expected string in hlo %param = f32[1,4,16]{2,1,0:T(4,128)} parameter(0), sharding={devices=[1,4,1]0,2,1,3}
sharding_annotation = 'sharding={devices=[1,%d,1]%s}' % (
self.n_devices, ','.join(
[str(d) for d in mesh.get_logical_mesh().flatten()]))
self.assertIn(sharding_annotation, hlo)
actual = (t1 + t2).cpu()
self.assertTrue(torch.allclose(expected, actual))
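The restored test above checks the HLO sharding annotation rather than the sharding-spec string. A standalone sketch of the annotation it searches for, assuming four devices arranged as a (1, 4, 1) logical mesh; the device order in a real run depends on how the mesh is constructed:

# Sketch only: builds the annotation string the restored test asserts on,
# assuming 4 devices and a (1, 4, 1) logical mesh.
import numpy as np

n_devices = 4
logical_mesh = np.arange(n_devices).reshape(1, n_devices, 1)

sharding_annotation = 'sharding={devices=[1,%d,1]%s}' % (
    n_devices, ','.join(str(d) for d in logical_mesh.flatten()))
print(sharding_annotation)  # sharding={devices=[1,4,1]0,1,2,3}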

6 changes: 1 addition & 5 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1401,17 +1401,13 @@ void InitXlaModuleBindings(py::module m) {
m.def("_xla_mark_sharding",
[](const at::Tensor& input, const py::list& tile_assignment,
const py::list& group_assignment, const py::list& replication_groups,
int sharding_type, bool tensor_rank_less_than_mesh) {
int sharding_type) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice()) << "Please set `XLA_USE_SPMD=1`";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
xla::OpSharding sharding = ShardingUtil::CreateOpSharding(
tile_assignment, group_assignment, replication_groups,
ShardingUtil::ShardingType(sharding_type));
if (tensor_rank_less_than_mesh) {
// Replicate the lower-rank tensor along the last mesh dimension.
sharding.set_replicate_on_last_tile_dim(true);
}
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding,
MakeShapeWithDeviceLayout(
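The removed branch above called set_replicate_on_last_tile_dim(true) on the OpSharding, which tells XLA that the trailing axis of the tile assignment enumerates replicas rather than shards. A rough sketch of that semantics (not the XLA implementation), assuming the eight-device [4,1,2] assignment used in the removed test:

# Sketch only: with last_tile_dim_replicate, each group along the trailing axis
# of the tile assignment holds copies of the same tensor shard.
import numpy as np

tile_assignment = np.arange(8).reshape(4, 1, 2)  # assumed [4,1,2] over 8 devices
for shard_idx in range(tile_assignment.shape[0]):
  replicas = tile_assignment[shard_idx, 0, :].tolist()
  print(f'shard {shard_idx} is replicated on devices {replicas}')
# shard 0 is replicated on devices [0, 1]
# shard 1 is replicated on devices [2, 3]
# shard 2 is replicated on devices [4, 5]
# shard 3 is replicated on devices [6, 7]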
31 changes: 20 additions & 11 deletions torch_xla/experimental/xla_sharding.py
@@ -412,26 +412,35 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
assert len(specs) == len(np.unique(specs)), \
f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

tensor_rank_less_than_mesh = len(t.shape) < len(mesh.get_logical_mesh().shape)
if tensor_rank_less_than_mesh:
assert len(mesh.get_logical_mesh().shape) == len(
t.shape) + 1, 'Tensor rank must be equal to or one less than mesh rank'
tile_assignment = _get_tile_assignment(mesh, partition_spec + (None,))
else:
tile_assignment = _get_tile_assignment(mesh, partition_spec)
# check for sharding 2D tensor on a 3D mesh
original_shape = tuple(t.shape)
# number of dims to expand on tensor
tensor_expand = 0
if tensor_expand < len(mesh.get_logical_mesh().shape) - len(partition_spec):
tensor_expand = len(mesh.get_logical_mesh().shape) - len(partition_spec)
partition_spec = (None,) * tensor_expand + partition_spec
shape = (1,) * tensor_expand + (*original_shape,)
t = t.expand(shape)

tile_assignment = _get_tile_assignment(mesh, partition_spec)
sharding_type = _get_sharding_type(partition_spec, num_devices)
group_assignment, replication_groups = _get_group_assignment(
sharding_type, partition_spec, tile_assignment)

def tensor_squeeze(t, tensor_expand):
if tensor_expand:
t = torch.squeeze(t, dim=tuple(range(tensor_expand)))
return t

if isinstance(t, XLAShardedTensor):
torch_xla._XLAC._xla_mark_sharding(t.global_tensor, tile_assignment,
group_assignment, replication_groups,
int(sharding_type),
tensor_rank_less_than_mesh)
int(sharding_type))
t = tensor_squeeze(t, tensor_expand)
return t
torch_xla._XLAC._xla_mark_sharding(t, tile_assignment, group_assignment,
replication_groups, int(sharding_type),
tensor_rank_less_than_mesh)
replication_groups, int(sharding_type))
t = tensor_squeeze(t, tensor_expand)
return XLAShardedTensor(t)
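
The restored flow above pads the partition spec with None, expands the tensor with leading singleton dimensions, applies the sharding, and then squeezes the padding back out. A minimal standalone sketch of that shape bookkeeping, assuming a 16x16 tensor on a (1, 4, 1) mesh; the tuple form of squeeze's dim argument needs PyTorch 2.0 or newer:

# Sketch only: mirrors the expand/shard/squeeze bookkeeping for a 2D tensor on a
# 3D mesh; the mesh shape and partition_spec here are illustrative.
import numpy as np
import torch

logical_mesh = np.arange(4).reshape(1, 4, 1)
t = torch.randn(16, 16)
partition_spec = (1, 2)

tensor_expand = logical_mesh.ndim - len(partition_spec)      # 1 dim to pad
partition_spec = (None,) * tensor_expand + partition_spec    # (None, 1, 2)
t = t.expand((1,) * tensor_expand + tuple(t.shape))          # shape (1, 16, 16)

# Tile-assignment shape implied by the padded spec: [1, 4, 1].
tile_shape = [1 if s is None else logical_mesh.shape[s] for s in partition_spec]

# After the sharding annotation is applied, the padding dims are squeezed away.
t = torch.squeeze(t, dim=tuple(range(tensor_expand)))        # back to (16, 16)
print(tile_shape, tuple(t.shape))  # [1, 4, 1] (16, 16)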


